@@ -31,7 +31,8 @@ template <typename ADataType_,
           memory_operation_enum MemoryOperation_,
           index_t kNumWaveGroups_ = 1,
           bool FixedVectorSize_ = false,
-          index_t VectorSizeC_ = 1>
+          index_t VectorSizeC_ = 1,
+          bool TiledMMAPermuteN_ = false>
 struct CShuffleEpilogueProblem
 {
     using ADataType = remove_cvref_t<ADataType_>;
@@ -54,6 +55,7 @@ struct CShuffleEpilogueProblem
     static constexpr memory_operation_enum MemoryOperation = MemoryOperation_;
     static constexpr bool FixedVectorSize = FixedVectorSize_;
     static constexpr index_t VectorSizeC = VectorSizeC_;
+    static constexpr bool TiledMMAPermuteN = TiledMMAPermuteN_;
     static constexpr index_t kNumWaveGroups = kNumWaveGroups_;
     static constexpr index_t NumDTensor = DsDataType::size();
@@ -89,10 +91,13 @@ struct CShuffleEpilogue
     static constexpr index_t KPerXdl = Problem::KPerXdl;
     static constexpr index_t isCTransposed = Problem::isCTransposed;
     static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
+    static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN;
     static constexpr index_t VectorSizeC = Problem::VectorSizeC;
     static constexpr index_t MPerIteration = MPerXdl * MWave;
     static constexpr index_t NPerIteration = NPerXdl * NWave;
     static constexpr index_t NumDTensor = Problem::NumDTensor;
+    static constexpr index_t MRepeat = kMPerBlock / (MPerXdl * MWave);
+    static constexpr index_t NRepeat = kNPerBlock / (NPerXdl * NWave);

     static_assert(NumDTensor == DsLayout::size(),
                   "The size of DsDataType and DsLayout should be the same");
@@ -367,11 +372,152 @@ struct CShuffleEpilogue
     struct EmptyScale
     {
     };
+
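+    // Overload selected when Problem::TiledMMAPermuteN is set: writes the
+    // accumulators straight to DRAM in a permuted-N layout, so no LDS
+    // staging is needed (note that p_smem is unused here).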
+    template <typename ODramWindow,
+              typename OAccTile,
+              typename DsDramWindows,
+              typename ScaleM = EmptyScale,
+              typename ScaleN = EmptyScale,
+              int EnablePermuteN_ = TiledMMAPermuteN,
+              std::enable_if_t<EnablePermuteN_, int> = 0>
+    CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
+                                   const OAccTile& o_acc_tile,
+                                   const DsDramWindows& ds_dram_windows,
+                                   void* /*p_smem*/,
+                                   const ScaleM& scale_m = {},
+                                   const ScaleN& scale_n = {})
+    {
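+        // Per-thread shuffle shape: M splits into MWave x (MPerXdl / 4) x 4
+        // and N into NWave x NPerXdl x NRepeat; kM2 = 4 matches the four
+        // rows packed per lane in the emit() calls below.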
+        constexpr int kM0 = MWave;
+        constexpr int kM2 = 4;
+        constexpr int kM1 = MPerXdl / kM2;
+
+        constexpr int kN0 = NWave;
+        constexpr int kN1 = NPerXdl;
+        constexpr int kN2 = NRepeat;
+
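+        // Waves map to the first partition dims (kM0, kN0), lanes within a
+        // wave to (kM1, kN1); each thread then owns a (kM2, kN2) slice.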
+        using IntrThreadShuffleEncode =
+            tile_distribution_encoding<sequence<>,
+                                       tuple<sequence<kM0, kM1, kM2>, sequence<kN0, kN1, kN2>>,
+                                       tuple<sequence<1, 2>, sequence<1, 2>>,
+                                       tuple<sequence<0, 0>, sequence<1, 1>>,
+                                       sequence<1, 2>,
+                                       sequence<2, 2>>;
+        constexpr auto dram_tile_distribution =
+            make_static_tile_distribution(IntrThreadShuffleEncode{});
+
+        auto d_dram_windows = generate_tuple(
+            [&](auto idx) {
+                return make_tile_window(ds_dram_windows[idx], dram_tile_distribution);
+            },
+            number<NumDTensor>{});
+
+        constexpr auto c_warp_y_lengths =
+            to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
+        constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
+
+        auto shuffle_acc = make_static_distributed_tensor<AccDataType>(dram_tile_distribution);
+        auto c_out_tensor = make_static_distributed_tensor<ODataType>(dram_tile_distribution);
+
+        // Optional scales (must share the same distribution to match per-thread indexing)
+        constexpr bool has_scales =
+            !std::is_same<ScaleM, EmptyScale>::value && !std::is_same<ScaleN, EmptyScale>::value;
+
+        // Tiles to hold row/col scales when present
+        using SMType =
+            std::conditional_t<has_scales, remove_cvref_t<typename ScaleM::DataType>, float>;
+        using SNType =
+            std::conditional_t<has_scales, remove_cvref_t<typename ScaleN::DataType>, float>;
+
+        auto sm_tile = make_static_distributed_tensor<SMType>(dram_tile_distribution);
+        auto sn_tile = make_static_distributed_tensor<SNType>(dram_tile_distribution);
+
+        // Build windows only if scales are provided
+        auto scale_m_window = [&]() {
+            if constexpr(has_scales)
+            {
+                return make_tile_window(scale_m, dram_tile_distribution);
+            }
+            else
+            {
+                return EmptyScale{};
+            }
+        }();
+        auto scale_n_window = [&]() {
+            if constexpr(has_scales)
+            {
+                return make_tile_window(scale_n, dram_tile_distribution);
+            }
+            else
+            {
+                return EmptyScale{};
+            }
+        }();
+
+        static_for<0, MRepeat, 1>{}([&](auto mIter) {
+            // Slice accumulators for this M repeat into the permuted layout
+            shuffle_acc.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
+                merge_sequences(sequence<mIter, 0>{}, c_warp_y_index_zeros),
+                merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths));
+
+            // If scales are provided, load them with the identical distribution
+            if constexpr(has_scales)
+            {
+                sm_tile = load_tile(scale_m_window); // row scales in permuted layout
+                sn_tile = load_tile(scale_n_window); // col scales in permuted layout
+            }
+
+            // Pack the four rows held by each lane into the output tensor
+            static_for<0, NRepeat, 1>{}([&](auto n_idx) {
+                // source index in shuffle_acc: n_idx * product(Y) + row
+                const index_t base = n_idx * c_warp_y_lengths.product();
+
+                // local lambda to fuse the scale (if present) and convert
+                auto emit = [&](index_t out_idx, index_t src_row) {
+                    AccDataType v = shuffle_acc.get_thread_buffer()[base + src_row];
+
+                    if constexpr(has_scales)
+                    {
+                        // same linear index mapping on the permuted distribution
+                        const auto s_m = static_cast<float>(sm_tile.get_thread_buffer()[out_idx]);
+                        const auto s_n = static_cast<float>(sn_tile.get_thread_buffer()[out_idx]);
+                        v = static_cast<AccDataType>(v * s_m * s_n);
+                    }
+
+                    c_out_tensor.get_thread_buffer()[out_idx] = type_convert<ODataType>(v);
+                };
+
+                // Packing pattern: rows 0..3, spaced by NRepeat
+                emit(n_idx + 0 * NRepeat, 0);
+                emit(n_idx + 1 * NRepeat, 1);
+                emit(n_idx + 2 * NRepeat, 2);
+                emit(n_idx + 3 * NRepeat, 3);
+            });
+
+            // store or update, depending on the configured memory operation
+            if constexpr(MemoryOperation == memory_operation_enum::set)
+            {
+                store_tile(out_dram_window, c_out_tensor);
+            }
+            else
+            {
+                update_tile(out_dram_window, c_out_tensor);
+            }
+
+            // advance the output (and any D tensors) by one MPerXdl * MWave chunk
+            move_tile_window(out_dram_window, {number<MPerXdl * MWave>{}, number<0>{}});
+            static_for<0, NumDTensor, 1>{}([&](auto idx) {
+                move_tile_window(d_dram_windows[idx], {number<MPerXdl * MWave>{}, number<0>{}});
+            });
+        });
+    }
+
     template <typename ODramWindow,
               typename OAccTile,
               typename DsDramWindows,
-              typename ScaleM = EmptyScale,
-              typename ScaleN = EmptyScale>
+              typename ScaleM = EmptyScale,
+              typename ScaleN = EmptyScale,
+              int EnablePermuteN_ = TiledMMAPermuteN,
+              std::enable_if_t<!EnablePermuteN_, int> = 0>
     CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
                                    const OAccTile& o_acc_tile,
                                    const DsDramWindows& ds_dram_windows,
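For context, the two operator() overloads above are selected at compile time by the defaulted int EnablePermuteN_ = TiledMMAPermuteN parameter and the paired enable_if_t guards: for a given problem, exactly one overload survives substitution. Below is a minimal standalone sketch of that dispatch pattern; the Epilogue name and printed strings are illustrative, not part of this PR.

#include <iostream>
#include <type_traits>

// Hypothetical stand-in for CShuffleEpilogue: one bool non-type parameter
// plays the role of Problem::TiledMMAPermuteN.
template <bool PermuteN>
struct Epilogue
{
    // Viable only when PermuteN is true (mirrors the new permuted-N path).
    template <int Enable_ = PermuteN, std::enable_if_t<Enable_, int> = 0>
    void operator()() const
    {
        std::cout << "permuted-N path\n";
    }

    // Viable only when PermuteN is false (mirrors the existing path).
    template <int Enable_ = PermuteN, std::enable_if_t<!Enable_, int> = 0>
    void operator()() const
    {
        std::cout << "default path\n";
    }
};

int main()
{
    Epilogue<true>{}();  // prints "permuted-N path"
    Epilogue<false>{}(); // prints "default path"
}

Callers never pass Enable_ explicitly, so flipping TiledMMAPermuteN_ on the Problem type is enough to reroute every call site without touching them.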