Skip to content

Commit 75570d0

Browse files
[CK_TILE] Add permuteN optimization to remove lds operation in c_shuffle (#2764)
* permuteN optimization to remove lds operation in c_shuffle
* add the change log

---------

Co-authored-by: ThomasNing <[email protected]>
1 parent 92b0738 commit 75570d0

File tree

5 files changed

+189
-4
lines changed

5 files changed

+189
-4
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
2828
* Added support for elementwise kernel.
2929
* Added benchmarking support for tile engine GEMM Multi D.
3030
* Added block scaling support in CK_TILE GEMM, allowing flexible use of quantization matrices from either A or B operands.
31+
* Added the row-wise column-wise quantization for CK_TILE GEMM & CK_TILE Grouped GEMM.
3132

3233
### Optimized
3334

example/ck_tile/03_gemm/gemm_utils.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,8 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase
276276
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
277277
static constexpr bool Preshuffle = true;
278278
static constexpr bool DoubleSmemBuffer = true;
279+
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
280+
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
279281
};
280282

281283
template <typename PrecType>
@@ -298,6 +300,8 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
298300
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
299301
static constexpr bool Preshuffle = true;
300302
static constexpr bool DoubleSmemBuffer = true;
303+
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
304+
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
301305
};
302306

303307
template <typename ADataType, typename BDataType = ADataType, typename CDataType = ADataType>

example/ck_tile/03_gemm/run_gemm_example.inc

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,26 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t)
241241
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
242242
}
243243

244+
template <typename GemmConfig, typename T>
245+
auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t)
246+
{
247+
assert(t.get_lengths().size() == 2);
248+
249+
int n_ = t.get_lengths()[1];
250+
int k_ = t.get_lengths()[0];
251+
constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
252+
constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp;
253+
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Tile,
254+
GemmConfig::N_Warp,
255+
GemmConfig::N_Warp_Tile,
256+
NRepeat,
257+
k_ / GemmConfig::K_Warp_Tile,
258+
divisor,
259+
GemmConfig::K_Warp_Tile / divisor});
260+
std::copy(t.begin(), t.end(), t_view.begin());
261+
return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6});
262+
}
263+
244264
template <typename CDataType>
245265
bool do_verify(const ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
246266
const ck_tile::HostTensor<CDataType>& c_m_n_ref,
@@ -346,7 +366,18 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
346366

347367
if constexpr(preshuffle)
348368
{
349-
ck_tile::HostTensor<BDataType> b_shuffle_host = shuffle_b<GemmConfig>(b_k_n);
369+
ck_tile::HostTensor<BDataType> b_shuffle_host = [&]() {
370+
if constexpr(GemmConfig::TiledMMAPermuteN)
371+
{
372+
std::cout << "Run with PermuteN" << std::endl;
373+
return shuffle_b_permuteN<GemmConfig>(b_k_n);
374+
}
375+
else
376+
{
377+
std::cout << "Run without PermuteN" << std::endl;
378+
return shuffle_b<GemmConfig>(b_k_n);
379+
}
380+
}();
350381
// shuffled buffer B for device implementation
351382
b_k_n_dev_buf.ToDevice(b_shuffle_host.data());
352383
}

include/ck_tile/core/container/sequence.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,9 @@ struct sequence
175175
return sequence<type::get(number<Ids>{})...>{};
176176
}
177177

178+
CK_TILE_HOST_DEVICE static constexpr auto sum() { return (Is + ... + 0); }
179+
CK_TILE_HOST_DEVICE static constexpr auto product() { return (Is * ... * 1); }
180+
178181
// modify element at index "I" with value "X"
179182
template <index_t I, index_t X>
180183
CK_TILE_HOST_DEVICE static constexpr auto modify(number<I>, number<X>)

include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp

Lines changed: 149 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ template <typename ADataType_,
3131
memory_operation_enum MemoryOperation_,
3232
index_t kNumWaveGroups_ = 1,
3333
bool FixedVectorSize_ = false,
34-
index_t VectorSizeC_ = 1>
34+
index_t VectorSizeC_ = 1,
35+
bool TiledMMAPermuteN_ = false>
3536
struct CShuffleEpilogueProblem
3637
{
3738
using ADataType = remove_cvref_t<ADataType_>;
@@ -54,6 +55,7 @@ struct CShuffleEpilogueProblem
5455
static constexpr memory_operation_enum MemoryOperation = MemoryOperation_;
5556
static constexpr bool FixedVectorSize = FixedVectorSize_;
5657
static constexpr index_t VectorSizeC = VectorSizeC_;
58+
static constexpr bool TiledMMAPermuteN = TiledMMAPermuteN_;
5759
static constexpr index_t kNumWaveGroups = kNumWaveGroups_;
5860
static constexpr index_t NumDTensor = DsDataType::size();
5961

@@ -89,10 +91,13 @@ struct CShuffleEpilogue
8991
static constexpr index_t KPerXdl = Problem::KPerXdl;
9092
static constexpr index_t isCTransposed = Problem::isCTransposed;
9193
static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
94+
static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN;
9295
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
9396
static constexpr index_t MPerIteration = MPerXdl * MWave;
9497
static constexpr index_t NPerIteration = NPerXdl * NWave;
9598
static constexpr index_t NumDTensor = Problem::NumDTensor;
99+
static constexpr index_t MRepeat = kMPerBlock / (MPerXdl * MWave);
100+
static constexpr index_t NRepeat = kNPerBlock / (NPerXdl * NWave);
96101

97102
static_assert(NumDTensor == DsLayout::size(),
98103
"The size of DsDataType and DsLayout should be the same");
@@ -367,11 +372,152 @@ struct CShuffleEpilogue
367372
struct EmptyScale
368373
{
369374
};
375+
376+
template <typename ODramWindow,
377+
typename OAccTile,
378+
typename DsDramWindows,
379+
typename ScaleM = EmptyScale,
380+
typename ScaleN = EmptyScale,
381+
int EnablePermuateN_ = TiledMMAPermuteN,
382+
std::enable_if_t<EnablePermuateN_, int> = 0>
383+
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
384+
const OAccTile& o_acc_tile,
385+
const DsDramWindows& ds_dram_windows,
386+
void* /*p_smem*/,
387+
const ScaleM& scale_m = {},
388+
const ScaleN& scale_n = {})
389+
{
390+
constexpr int kM0 = MWave;
391+
constexpr int kM2 = 4;
392+
constexpr int kM1 = MPerXdl / kM2;
393+
394+
constexpr int kN0 = NWave;
395+
constexpr int kN1 = NPerXdl;
396+
constexpr int kN2 = NRepeat;
397+
398+
using IntrThreadShuffleEncode =
399+
tile_distribution_encoding<sequence<>,
400+
tuple<sequence<kM0, kM1, kM2>, sequence<kN0, kN1, kN2>>,
401+
tuple<sequence<1, 2>, sequence<1, 2>>,
402+
tuple<sequence<0, 0>, sequence<1, 1>>,
403+
sequence<1, 2>,
404+
sequence<2, 2>>;
405+
constexpr auto dram_tile_distribution =
406+
make_static_tile_distribution(IntrThreadShuffleEncode{});
407+
408+
auto d_dram_windows = generate_tuple(
409+
[&](auto idx) {
410+
return make_tile_window(ds_dram_windows[idx], dram_tile_distribution);
411+
},
412+
number<NumDTensor>{});
413+
414+
constexpr auto c_warp_y_lengths =
415+
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
416+
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
417+
418+
auto shuffle_acc = make_static_distributed_tensor<AccDataType>(dram_tile_distribution);
419+
auto c_out_tensor = make_static_distributed_tensor<ODataType>(dram_tile_distribution);
420+
421+
// Optional scales (must share the same distribution to match per-thread indexing)
422+
constexpr bool has_scales =
423+
!std::is_same<ScaleM, EmptyScale>::value && !std::is_same<ScaleN, EmptyScale>::value;
424+
425+
// Tiles to hold row/col scales when present
426+
using SMType =
427+
std::conditional_t<has_scales, remove_cvref_t<typename ScaleM::DataType>, float>;
428+
using SNType =
429+
std::conditional_t<has_scales, remove_cvref_t<typename ScaleN::DataType>, float>;
430+
431+
auto sm_tile = make_static_distributed_tensor<SMType>(dram_tile_distribution);
432+
auto sn_tile = make_static_distributed_tensor<SNType>(dram_tile_distribution);
433+
434+
// Build windows only if scales are provided
435+
auto scale_m_window = [&]() {
436+
if constexpr(has_scales)
437+
{
438+
return make_tile_window(scale_m, dram_tile_distribution);
439+
}
440+
else
441+
{
442+
return EmptyScale{};
443+
}
444+
}();
445+
auto scale_n_window = [&]() {
446+
if constexpr(has_scales)
447+
{
448+
return make_tile_window(scale_n, dram_tile_distribution);
449+
}
450+
else
451+
{
452+
return EmptyScale{};
453+
}
454+
}();
455+
456+
static_for<0, MRepeat, 1>{}([&](auto mIter) {
457+
// Slice accumulators for this M repeat into the permuted layout
458+
shuffle_acc.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
459+
merge_sequences(sequence<mIter, 0>{}, c_warp_y_index_zeros),
460+
merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths));
461+
462+
// If scales provided, load them with identical distribution
463+
if constexpr(has_scales)
464+
{
465+
sm_tile = load_tile(scale_m_window); // row scales in permuted layout
466+
sn_tile = load_tile(scale_n_window); // col scales in permuted layout
467+
}
468+
469+
// Pack 4 “rows per lane” as you already do
470+
static_for<0, NRepeat, 1>{}([&](auto n_idx) {
471+
// source indices in shuffle_acc: (n_idx * product(Y) + row)
472+
const index_t base = n_idx * c_warp_y_lengths.product();
473+
474+
// local lambda to fuse scale (if present) and convert
475+
auto emit = [&](index_t out_idx, index_t src_row) {
476+
AccDataType v = shuffle_acc.get_thread_buffer()[base + src_row];
477+
478+
if constexpr(has_scales)
479+
{
480+
// same linear index mapping on the permuted distribution
481+
const auto s_m = static_cast<float>(sm_tile.get_thread_buffer()[out_idx]);
482+
const auto s_n = static_cast<float>(sn_tile.get_thread_buffer()[out_idx]);
483+
v = static_cast<AccDataType>(v * s_m * s_n);
484+
}
485+
486+
c_out_tensor.get_thread_buffer()[out_idx] = type_convert<ODataType>(v);
487+
};
488+
489+
// Your current packing pattern (rows 0..3, spaced by NRepeat)
490+
emit(n_idx + 0 * NRepeat, 0);
491+
emit(n_idx + 1 * NRepeat, 1);
492+
emit(n_idx + 2 * NRepeat, 2);
493+
emit(n_idx + 3 * NRepeat, 3);
494+
});
495+
496+
// store/update
497+
if constexpr(MemoryOperation == memory_operation_enum::set)
498+
{
499+
store_tile(out_dram_window, c_out_tensor);
500+
}
501+
else
502+
{
503+
update_tile(out_dram_window, c_out_tensor);
504+
}
505+
506+
// advance output (and any D-tensors) by one MPerXdl*MWave chunk
507+
move_tile_window(out_dram_window, {number<MPerXdl * MWave>{}, number<0>{}});
508+
static_for<0, NumDTensor, 1>{}([&](auto idx) {
509+
move_tile_window(d_dram_windows[idx], {number<MPerXdl * MWave>{}, number<0>{}});
510+
});
511+
});
512+
}
513+
370514
template <typename ODramWindow,
371515
typename OAccTile,
372516
typename DsDramWindows,
373-
typename ScaleM = EmptyScale,
374-
typename ScaleN = EmptyScale>
517+
typename ScaleM = EmptyScale,
518+
typename ScaleN = EmptyScale,
519+
int EnablePermuateN_ = TiledMMAPermuteN,
520+
std::enable_if_t<!EnablePermuateN_, int> = 0>
375521
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
376522
const OAccTile& o_acc_tile,
377523
const DsDramWindows& ds_dram_windows,

0 commit comments

Comments
 (0)