Skip to content
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions include/ck_tile/core/tensor/buffer_view.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -618,9 +618,9 @@ struct buffer_view<address_space_enum::global,
using scalar_t = typename vector_traits<remove_cvref_t<T>>::scalar_type;

// X contains multiple T
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size; // 1

constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size; // 8

static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X should contain multiple T");
Expand All @@ -643,7 +643,15 @@ struct buffer_view<address_space_enum::global,
bool constexpr use_amd_buffer_addressing = false;
#endif

constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; // 8
// static_assert(scalar_per_x_vector > scalar_per_t_vector, "Condition not met: ((
// scalar_per_x_vector > scalar_per_t_vector ))"); if(threadIdx.x == 0)
//{
// printf("[DEBUG]: BufferView: t_per_x: %d\n",t_per_x);
// printf("[DEBUG]: BufferView: scalar_per_x_vector: %d\n",scalar_per_x_vector);
// printf("[DEBUG]: BufferView: scalar_per_t_vector: %d\n",scalar_per_t_vector);
// printf("[DEBUG]: BufferView: x.size(): %d\n",x.size());
// }

if constexpr(use_amd_buffer_addressing)
{
Expand Down
163 changes: 139 additions & 24 deletions include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,28 @@ struct CShuffleEpilogueProblem

static_assert(NumDTensor == DsLayout::size(),
"The size of DsDataType and DsLayout should be the same");

/// Debug helper: writes this problem's configuration values to stdout, one
/// "[DEBUG]: CShuffleEpilogueProblem: <name>: <value>" line per setting.
CK_TILE_HOST static void PrintInfo()
{
    // Convert the settings that the output shows as plain integers once, up
    // front, so every printf below is a simple "name: value" line.
    const int memory_operation_value     = static_cast<int>(MemoryOperation);
    const int fixed_vector_size_value    = static_cast<int>(FixedVectorSize);
    const int tiled_mma_permute_n_value  = static_cast<int>(TiledMMAPermuteN);

    printf("[DEBUG]: CShuffleEpilogueProblem: kBlockSize: %d\n", kBlockSize);
    printf("[DEBUG]: CShuffleEpilogueProblem: kMPerBlock: %d\n", kMPerBlock);
    printf("[DEBUG]: CShuffleEpilogueProblem: kNPerBlock: %d\n", kNPerBlock);

    printf("[DEBUG]: CShuffleEpilogueProblem: MWave: %d\n", MWave);
    printf("[DEBUG]: CShuffleEpilogueProblem: NWave: %d\n", NWave);

    printf("[DEBUG]: CShuffleEpilogueProblem: MPerXdl: %d\n", MPerXdl);
    printf("[DEBUG]: CShuffleEpilogueProblem: NPerXdl: %d\n", NPerXdl);
    printf("[DEBUG]: CShuffleEpilogueProblem: KPerXdl: %d\n", KPerXdl);

    printf("[DEBUG]: CShuffleEpilogueProblem: isCTransposed: %d\n", isCTransposed);
    printf("[DEBUG]: CShuffleEpilogueProblem: MemoryOperation: %d\n", memory_operation_value);
    printf("[DEBUG]: CShuffleEpilogueProblem: FixedVectorSize: %d\n", fixed_vector_size_value);
    printf("[DEBUG]: CShuffleEpilogueProblem: VectorSizeC: %d\n", VectorSizeC);
    printf("[DEBUG]: CShuffleEpilogueProblem: TiledMMAPermuteN: %d\n", tiled_mma_permute_n_value);

    printf("[DEBUG]: CShuffleEpilogueProblem: kNumWaveGroups: %d\n", kNumWaveGroups);
    printf("[DEBUG]: CShuffleEpilogueProblem: NumDTensor: %d\n", NumDTensor);
}
};

template <typename Problem_, typename Policy_ = void>
Expand Down Expand Up @@ -114,6 +136,21 @@ struct CShuffleEpilogue
static constexpr index_t MRepeat = kMPerBlock / (MPerXdl * MWave);
static constexpr index_t NRepeat = kNPerBlock / (NPerXdl * NWave);

// True when the output layout ELayout is tensor_layout::gemm::RowMajor.
// std::is_same_v already yields a bool, so no `? true : false` is needed.
static constexpr bool IsERowMajor = std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>;

// Debug helper: writes the epilogue's per-iteration, repeat, vector-size and
// warp-size values to stdout, one "[DEBUG]: CShuffleEpilogue: <name>: <value>"
// line each. Takes no arguments and returns nothing.
CK_TILE_HOST static void PrintInfo()
{
printf("[DEBUG]: CShuffleEpilogue: MPerIteration: %d\n", MPerIteration);
printf("[DEBUG]: CShuffleEpilogue: NPerIteration: %d\n", NPerIteration);
printf("[DEBUG]: CShuffleEpilogue: MRepeat: %d\n", MRepeat);
printf("[DEBUG]: CShuffleEpilogue: NRepeat: %d\n", NRepeat);
// The next two values come from function calls rather than named constants.
printf("[DEBUG]: CShuffleEpilogue: GetVectorSizeC: %d\n", GetVectorSizeC());
printf("[DEBUG]: CShuffleEpilogue: get_warp_size: %d\n", get_warp_size());
printf("[DEBUG]: CShuffleEpilogue: MPerIterationShuffle: %d\n", MPerIterationShuffle);
printf("[DEBUG]: CShuffleEpilogue: NPerIterationShuffle: %d\n", NPerIterationShuffle);
}

static_assert(NumDTensor == DsLayout::size(),
"The size of DsDataType and DsLayout should be the same");
/**
Expand Down Expand Up @@ -225,6 +262,22 @@ struct CShuffleEpilogue
static constexpr index_t MPerIterationShuffle = std::get<0>(MNPerIterationShuffle);
static constexpr index_t NPerIterationShuffle = std::get<1>(MNPerIterationShuffle);

// Layout-dependent aliases for the M/N constants. When IsERowMajor is true,
// each Y* constant takes the M-side value and each X* constant the N-side
// value; otherwise the two are swapped.
static constexpr index_t NumYXdlPerWavePerShuffle =
IsERowMajor ? NumMXdlPerWavePerShuffle : NumNXdlPerWavePerShuffle;
static constexpr index_t NumXXdlPerWavePerShuffle =
IsERowMajor ? NumNXdlPerWavePerShuffle : NumMXdlPerWavePerShuffle;

// Per-iteration shuffle extents, selected the same way.
static constexpr index_t YPerIterationShuffle =
IsERowMajor ? MPerIterationShuffle : NPerIterationShuffle;
static constexpr index_t XPerIterationShuffle =
IsERowMajor ? NPerIterationShuffle : MPerIterationShuffle;

// Per-block extents, selected the same way.
static constexpr index_t YPerBlock = IsERowMajor ? kMPerBlock : kNPerBlock;
static constexpr index_t XPerBlock = IsERowMajor ? kNPerBlock : kMPerBlock;

// Wave counts, selected the same way.
static constexpr index_t YWave = IsERowMajor ? MWave : NWave;
static constexpr index_t XWave = IsERowMajor ? NWave : MWave;
Comment on lines +265 to +279
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not use these? With them you get cleaner code, without all the boilerplate `if constexpr (std::is_same_v<..., RowMajor>)` checks.


using WG = WarpGemmDispatcher<ATypeToUse,
BTypeToUse,
AccDataType,
Expand Down Expand Up @@ -254,8 +307,8 @@ struct CShuffleEpilogue
else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
{
return make_naive_tensor_descriptor(
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
make_tuple(number<1>{}, number<MPerIterationShuffle>{}));
make_tuple(number<NPerIterationShuffle>{}, number<MPerIterationShuffle>{}),
make_tuple(number<MPerIterationShuffle>{}, number<1>{}));
}
else
{
Expand All @@ -265,14 +318,32 @@ struct CShuffleEpilogue

CK_TILE_DEVICE static constexpr auto MakeLdsDistributionEncode()
{
constexpr auto block_outer_dstr_encoding =
tile_distribution_encoding<sequence<>,
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
sequence<NumNXdlPerWavePerShuffle, NWave>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto block_outer_dstr_encoding = [] {
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
return tile_distribution_encoding<sequence<>,
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
sequence<NumNXdlPerWavePerShuffle, NWave>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
}
else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
{
return tile_distribution_encoding<sequence<>,
tuple<sequence<NumNXdlPerWavePerShuffle, NWave>,
sequence<NumMXdlPerWavePerShuffle, MWave>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
}
else
{
static_assert(false, "Unsupported ELayout!");
}
}();
constexpr auto block_dstr_encoding = detail::make_embed_tile_distribution_encoding(
block_outer_dstr_encoding, typename CWarpDstr::DstrEncode{});

Expand Down Expand Up @@ -581,22 +652,66 @@ struct CShuffleEpilogue
auto o_lds_block = make_tensor_view<address_space_enum::lds>(
static_cast<ODataType*>(p_smem), lds_block_desc);

auto in_lds_window = make_tile_window(
o_lds_block,
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
{0, 0},
LdsTileDistr);

auto out_lds_window = make_tile_window(
o_lds_block,
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
{0, 0});
auto in_lds_window = [&o_lds_block, &LdsTileDistr] {
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(
o_lds_block,
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
{0, 0},
LdsTileDistr);
}
else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
{
return make_tile_window(
o_lds_block,
make_tuple(number<NPerIterationShuffle>{}, number<MPerIterationShuffle>{}),
{0, 0},
LdsTileDistr);
}
else
{
static_assert(false, "Unsupported ELayout!");
}
}();
// auto in_lds_window = make_tile_window(
// o_lds_block,
// make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
// {0, 0},
// LdsTileDistr);

// auto out_lds_window = make_tile_window(
// o_lds_block,
// make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
// {0, 0});

auto out_lds_window = [&o_lds_block] {
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(
o_lds_block,
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
{0, 0});
}
else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
{
return make_tile_window(
o_lds_block,
make_tuple(number<NPerIterationShuffle>{}, number<MPerIterationShuffle>{}),
{0, 0});
}
else
{
static_assert(false, "Unsupported ELayout!");
}
}();

constexpr index_t num_access = SFC::get_num_of_access();

static_assert(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
"Currently, the CShuffle Epilogue only supports the Row Major Output layout");

// TODO: Add support for Col Major Output Layout - CShuffle Epilogue
// static_assert(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
// "Currently, the CShuffle Epilogue only supports the Row Major Output
// layout");
static_assert(GetVectorSizeC() > 1, "VectorSizeC is not greater than 1!");
using TileEncodingPattern =
tile_distribution_encoding_pattern_2d<kBlockSize,
MPerIterationShuffle,
Expand Down
44 changes: 34 additions & 10 deletions include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -740,9 +740,9 @@ struct UniversalGemmKernel
{
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
e_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(1, kargs.stride_E),
number<1>{},
make_tuple(kargs.N, kargs.M),
make_tuple(kargs.stride_E, 1),
number<EpiloguePipeline::GetVectorSizeC()>{},
number<1>{});
}
}();
Expand Down Expand Up @@ -831,9 +831,9 @@ struct UniversalGemmKernel
else
{
return pad_tensor_view(e_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<GemmPipeline::kPadM, false>{});
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
sequence<false, GemmPipeline::kPadM>{});
}
}();

Expand Down Expand Up @@ -929,10 +929,22 @@ struct UniversalGemmKernel
},
number<NumDTensor>{});

auto e_block_window = make_tile_window(
e_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
const auto e_block_window = [&]() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(e_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
}
else
{
return make_tile_window(e_pad_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
{i_n, i_m});
}
}();

return make_tuple(as_block_window, bs_block_window, ds_block_window, e_block_window);
}
Expand Down Expand Up @@ -986,7 +998,19 @@ struct UniversalGemmKernel
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I3);

// if(threadIdx.x == 0)
//{
// printf("CShuffleEpilogue operator() called! Before\n");
// c_block_window.template print_tile_window_range<EDataType>(0, 4, 0, 8, "A");
// }

EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);

// if(threadIdx.x == 0)
//{
// printf("CShuffleEpilogue operator() called! After\n");
// c_block_window.template print_tile_window_range<EDataType>(0, 4, 0, 8, "A");
// }
}
}

Expand Down
6 changes: 6 additions & 0 deletions include/ck_tile/utility.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/utility/json_dump.hpp"
24 changes: 12 additions & 12 deletions test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,18 +86,18 @@ using KernelTypesMemWmma = ::testing::Types<
>;

// Typed-test configurations for the CompV3 pipeline.
// NOTE(review): the first three tuple entries appear to be the A / B / C(E)
// layouts, followed by data types and tile sizes - confirm against the fixture
// that unpacks these tuples.
//
// All previously covered row-major-output configurations stay enabled; the
// column-major-output case is added alongside them rather than replacing them,
// so existing coverage is not lost.
using KernelTypesCompV3 = ::testing::Types<
    std::tuple<    Row,    Row,    Row,    F16,    F16,    F32,    F16,   I256,   I256,    I64,    I32,    I32,    I16, Intrawave, CompV3>,
    std::tuple<    Row,    Row,    Row,     F8,     F8,    F32,    F16,   I256,   I256,    I64,    I32,    I32,    I16, Intrawave, CompV3>,
    std::tuple<    Row,    Col,    Row,    F16,    F16,    F32,    F16,   I256,   I256,    I64,    I32,    I32,    I16, Intrawave, CompV3>,
    std::tuple<    Row,    Col,    Row,     F8,     F8,    F32,    F16,   I256,   I256,    I64,    I32,    I32,    I16, Intrawave, CompV3>,
    std::tuple<    Col,    Row,    Row,    F16,    F16,    F32,    F16,   I256,   I256,    I64,    I32,    I32,    I16, Intrawave, CompV3>,
    std::tuple<    Col,    Row,    Row,     F8,     F8,    F32,    F16,   I256,   I256,    I64,    I32,    I32,    I16, Intrawave, CompV3>,
    std::tuple<    Col,    Col,    Row,    F16,    F16,    F32,    F16,   I256,   I256,    I64,    I32,    I32,    I16, Intrawave, CompV3>,
    std::tuple<    Col,    Col,    Row,     F8,     F8,    F32,    F16,   I256,   I256,    I64,    I32,    I32,    I16, Intrawave, CompV3>,
    std::tuple<    Row,    Row,    Row,   INT8,   INT8,  INT32,  INT32,   I256,   I256,    I64,    I32,    I32,    I16, Intrawave, CompV3>,
    std::tuple<    Row,    Col,    Row,   INT8,   INT8,  INT32,  INT32,   I256,   I256,    I64,    I32,    I32,    I16, Intrawave, CompV3>,
    std::tuple<    Col,    Row,    Row,   INT8,   INT8,  INT32,  INT32,   I256,   I256,    I64,    I32,    I32,    I16, Intrawave, CompV3>,
    std::tuple<    Col,    Col,    Row,   INT8,   INT8,  INT32,  INT32,   I256,   I256,    I64,    I32,    I32,    I16, Intrawave, CompV3>,
    // Column-major output layout.
    std::tuple<    Row,    Row,    Col,    F16,    F16,    F32,    F16,   I256,   I256,    I64,    I32,    I32,    I16, Intrawave, CompV3>
    >;

using KernelTypesCompV3Wmma = ::testing::Types<
Expand Down
Loading