Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
* Added support for elementwise kernel.
* Added benchmarking support for tile engine GEMM Multi D.
* Added block scaling support in CK_TILE GEMM, allowing flexible use of quantization matrices from either A or B operands.
* Added the row-wise column-wise quantization for CK_TILE GEMM & CK_TILE Grouped GEMM.

### Optimized

Expand Down
4 changes: 4 additions & 0 deletions example/ck_tile/03_gemm/gemm_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,8 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
static constexpr bool Preshuffle = true;
static constexpr bool DoubleSmemBuffer = true;
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
};

template <typename PrecType>
Expand All @@ -298,6 +300,8 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
static constexpr bool Preshuffle = true;
static constexpr bool DoubleSmemBuffer = true;
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
};

template <typename ADataType, typename BDataType = ADataType, typename CDataType = ADataType>
Expand Down
33 changes: 32 additions & 1 deletion example/ck_tile/03_gemm/run_gemm_example.inc
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,26 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t)
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}

template <typename GemmConfig, typename T>
auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t)
{
assert(t.get_lengths().size() == 2);

int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp;
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Tile,
GemmConfig::N_Warp,
GemmConfig::N_Warp_Tile,
NRepeat,
k_ / GemmConfig::K_Warp_Tile,
divisor,
GemmConfig::K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6});
}

template <typename CDataType>
bool do_verify(const ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
const ck_tile::HostTensor<CDataType>& c_m_n_ref,
Expand Down Expand Up @@ -346,7 +366,18 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,

if constexpr(preshuffle)
{
ck_tile::HostTensor<BDataType> b_shuffle_host = shuffle_b<GemmConfig>(b_k_n);
ck_tile::HostTensor<BDataType> b_shuffle_host = [&]() {
if constexpr(GemmConfig::TiledMMAPermuteN)
{
std::cout << "Run with PermuteN" << std::endl;
return shuffle_b_permuteN<GemmConfig>(b_k_n);
}
else
{
std::cout << "Run without PermuteN" << std::endl;
return shuffle_b<GemmConfig>(b_k_n);
}
}();
// shuffled buffer B for device implementation
b_k_n_dev_buf.ToDevice(b_shuffle_host.data());
}
Expand Down
3 changes: 3 additions & 0 deletions include/ck_tile/core/container/sequence.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,9 @@ struct sequence
return sequence<type::get(number<Ids>{})...>{};
}

CK_TILE_HOST_DEVICE static constexpr auto sum() { return (Is + ... + 0); }
CK_TILE_HOST_DEVICE static constexpr auto product() { return (Is * ... * 1); }

// modify element at index "I" with value "X"
template <index_t I, index_t X>
CK_TILE_HOST_DEVICE static constexpr auto modify(number<I>, number<X>)
Expand Down
152 changes: 149 additions & 3 deletions include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ template <typename ADataType_,
memory_operation_enum MemoryOperation_,
index_t kNumWaveGroups_ = 1,
bool FixedVectorSize_ = false,
index_t VectorSizeC_ = 1>
index_t VectorSizeC_ = 1,
bool TiledMMAPermuteN_ = false>
struct CShuffleEpilogueProblem
{
using ADataType = remove_cvref_t<ADataType_>;
Expand All @@ -54,6 +55,7 @@ struct CShuffleEpilogueProblem
static constexpr memory_operation_enum MemoryOperation = MemoryOperation_;
static constexpr bool FixedVectorSize = FixedVectorSize_;
static constexpr index_t VectorSizeC = VectorSizeC_;
static constexpr bool TiledMMAPermuteN = TiledMMAPermuteN_;
static constexpr index_t kNumWaveGroups = kNumWaveGroups_;
static constexpr index_t NumDTensor = DsDataType::size();

Expand Down Expand Up @@ -89,10 +91,13 @@ struct CShuffleEpilogue
static constexpr index_t KPerXdl = Problem::KPerXdl;
static constexpr index_t isCTransposed = Problem::isCTransposed;
static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN;
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
static constexpr index_t MPerIteration = MPerXdl * MWave;
static constexpr index_t NPerIteration = NPerXdl * NWave;
static constexpr index_t NumDTensor = Problem::NumDTensor;
static constexpr index_t MRepeat = kMPerBlock / (MPerXdl * MWave);
static constexpr index_t NRepeat = kNPerBlock / (NPerXdl * NWave);

static_assert(NumDTensor == DsLayout::size(),
"The size of DsDataType and DsLayout should be the same");
Expand Down Expand Up @@ -367,11 +372,152 @@ struct CShuffleEpilogue
struct EmptyScale
{
};

template <typename ODramWindow,
typename OAccTile,
typename DsDramWindows,
typename ScaleM = EmptyScale,
typename ScaleN = EmptyScale,
int EnablePermuateN_ = TiledMMAPermuteN,
std::enable_if_t<EnablePermuateN_, int> = 0>
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
const OAccTile& o_acc_tile,
const DsDramWindows& ds_dram_windows,
void* /*p_smem*/,
const ScaleM& scale_m = {},
const ScaleN& scale_n = {})
{
constexpr int kM0 = MWave;
constexpr int kM2 = 4;
constexpr int kM1 = MPerXdl / kM2;

constexpr int kN0 = NWave;
constexpr int kN1 = NPerXdl;
constexpr int kN2 = NRepeat;

using IntrThreadShuffleEncode =
tile_distribution_encoding<sequence<>,
tuple<sequence<kM0, kM1, kM2>, sequence<kN0, kN1, kN2>>,
tuple<sequence<1, 2>, sequence<1, 2>>,
tuple<sequence<0, 0>, sequence<1, 1>>,
sequence<1, 2>,
sequence<2, 2>>;
constexpr auto dram_tile_distribution =
make_static_tile_distribution(IntrThreadShuffleEncode{});

auto d_dram_windows = generate_tuple(
[&](auto idx) {
return make_tile_window(ds_dram_windows[idx], dram_tile_distribution);
},
number<NumDTensor>{});

constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};

auto shuffle_acc = make_static_distributed_tensor<AccDataType>(dram_tile_distribution);
auto c_out_tensor = make_static_distributed_tensor<ODataType>(dram_tile_distribution);

// Optional scales (must share the same distribution to match per-thread indexing)
constexpr bool has_scales =
!std::is_same<ScaleM, EmptyScale>::value && !std::is_same<ScaleN, EmptyScale>::value;

// Tiles to hold row/col scales when present
using SMType =
std::conditional_t<has_scales, remove_cvref_t<typename ScaleM::DataType>, float>;
using SNType =
std::conditional_t<has_scales, remove_cvref_t<typename ScaleN::DataType>, float>;

auto sm_tile = make_static_distributed_tensor<SMType>(dram_tile_distribution);
auto sn_tile = make_static_distributed_tensor<SNType>(dram_tile_distribution);

// Build windows only if scales are provided
auto scale_m_window = [&]() {
if constexpr(has_scales)
{
return make_tile_window(scale_m, dram_tile_distribution);
}
else
{
return EmptyScale{};
}
}();
auto scale_n_window = [&]() {
if constexpr(has_scales)
{
return make_tile_window(scale_n, dram_tile_distribution);
}
else
{
return EmptyScale{};
}
}();

static_for<0, MRepeat, 1>{}([&](auto mIter) {
// Slice accumulators for this M repeat into the permuted layout
shuffle_acc.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, 0>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths));

// If scales provided, load them with identical distribution
if constexpr(has_scales)
{
sm_tile = load_tile(scale_m_window); // row scales in permuted layout
sn_tile = load_tile(scale_n_window); // col scales in permuted layout
}

// Pack 4 “rows per lane” as you already do
static_for<0, NRepeat, 1>{}([&](auto n_idx) {
// source indices in shuffle_acc: (n_idx * product(Y) + row)
const index_t base = n_idx * c_warp_y_lengths.product();

// local lambda to fuse scale (if present) and convert
auto emit = [&](index_t out_idx, index_t src_row) {
AccDataType v = shuffle_acc.get_thread_buffer()[base + src_row];

if constexpr(has_scales)
{
// same linear index mapping on the permuted distribution
const auto s_m = static_cast<float>(sm_tile.get_thread_buffer()[out_idx]);
const auto s_n = static_cast<float>(sn_tile.get_thread_buffer()[out_idx]);
v = static_cast<AccDataType>(v * s_m * s_n);
}

c_out_tensor.get_thread_buffer()[out_idx] = type_convert<ODataType>(v);
};

// Your current packing pattern (rows 0..3, spaced by NRepeat)
emit(n_idx + 0 * NRepeat, 0);
emit(n_idx + 1 * NRepeat, 1);
emit(n_idx + 2 * NRepeat, 2);
emit(n_idx + 3 * NRepeat, 3);
});

// store/update
if constexpr(MemoryOperation == memory_operation_enum::set)
{
store_tile(out_dram_window, c_out_tensor);
}
else
{
update_tile(out_dram_window, c_out_tensor);
}

// advance output (and any D-tensors) by one MPerXdl*MWave chunk
move_tile_window(out_dram_window, {number<MPerXdl * MWave>{}, number<0>{}});
static_for<0, NumDTensor, 1>{}([&](auto idx) {
move_tile_window(d_dram_windows[idx], {number<MPerXdl * MWave>{}, number<0>{}});
});
});
}

template <typename ODramWindow,
typename OAccTile,
typename DsDramWindows,
typename ScaleM = EmptyScale,
typename ScaleN = EmptyScale>
typename ScaleM = EmptyScale,
typename ScaleN = EmptyScale,
int EnablePermuateN_ = TiledMMAPermuteN,
std::enable_if_t<!EnablePermuateN_, int> = 0>
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
const OAccTile& o_acc_tile,
const DsDramWindows& ds_dram_windows,
Expand Down
Loading