@@ -373,6 +373,22 @@ struct ScaledAccPerRowBias
373
373
static constexpr bool IsPerRowBiasSupported = true ;
374
374
};
375
375
376
+ template <
377
+ class ElementOutput_ ,
378
+ class ElementCompute_ ,
379
+ class ElementBias_ = ElementOutput_,
380
+ class ElementScalar_ = ElementCompute_,
381
+ int AlignmentBias_ = 128 / cute::sizeof_bits_v<ElementBias_>,
382
+ FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest
383
+ >
384
+ struct ScaledAccPerColBias
385
+ : ScaledAcc<ElementOutput_, ElementCompute_, ElementScalar_, RoundStyle_>
386
+ {
387
+ using ElementBias = ElementBias_;
388
+ static constexpr int AlignmentBias = AlignmentBias_;
389
+ static constexpr bool IsPerColBiasSupported = true ;
390
+ };
391
+
376
392
template <
377
393
class GmemLayoutTagOut ,
378
394
class ElementOutput ,
@@ -393,6 +409,26 @@ struct ScaledAccPerRowBiasPerColScaleScatter
393
409
static constexpr bool IsAuxOutSupported = true ;
394
410
};
395
411
412
+ template <
413
+ class GmemLayoutTagOut ,
414
+ class ElementOutput ,
415
+ class ElementCompute ,
416
+ class ElementBias = ElementOutput,
417
+ class ElementScale = ElementCompute,
418
+ class ElementScalar = ElementCompute,
419
+ int AlignmentBias = 128 / cute::sizeof_bits_v<ElementBias>,
420
+ int AlignmentOutput = 128 / cute::sizeof_bits_v<ElementOutput>,
421
+ FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
422
+ >
423
+ struct ScaledAccPerColBiasPerRowScaleScatter
424
+ : ScaledAccPerColBias<ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle>
425
+ {
426
+ using ElementAux = ElementOutput;
427
+ using GmemLayoutTagAux = GmemLayoutTagOut;
428
+ static constexpr int AlignmentAux = AlignmentOutput;
429
+ static constexpr bool IsAuxOutSupported = true ;
430
+ };
431
+
396
432
// D = alpha * acc + per-row bias
397
433
template <
398
434
class CtaTileShapeMNK ,
@@ -410,6 +446,22 @@ using Sm90ScaledAccPerRowBiasPtrArray =
410
446
Sm90ColBroadcast<0 , CtaTileShapeMNK, ElementBias *, ElementCompute, Stride<_1,_0,int64_t >, AlignmentBias> // bias
411
447
>;
412
448
449
+ template <
450
+ class CtaTileShapeMNK ,
451
+ class ElementOutput ,
452
+ class ElementCompute ,
453
+ class ElementBias = ElementOutput,
454
+ class ElementScalar = ElementCompute,
455
+ int AlignmentBias = 128 / sizeof_bits_v<ElementBias>,
456
+ FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
457
+ >
458
+ using Sm90ScaledAccPerColBiasPtrArray =
459
+ Sm90EVT<Sm90Compute<homogeneous_multiply_add, ElementOutput, ElementCompute, RoundStyle>, // alpha * acc + bias
460
+ Sm90ScalarBroadcastPtrArray<ElementScalar, Stride<_0,_0,int64_t >>, // alpha
461
+ Sm90AccFetch, // acc
462
+ Sm90RowBroadcast<0 , CtaTileShapeMNK, ElementBias *, ElementCompute, Stride<_0,_1,int64_t >, AlignmentBias> // bias
463
+ >;
464
+
413
465
template <
414
466
class CtaTileShapeMNK ,
415
467
class EpilogueTile ,
@@ -433,6 +485,29 @@ using Sm90ScaledAccPerRowBiasPerColScaleScatterPtrArray =
433
485
>
434
486
>;
435
487
488
+ template <
489
+ class CtaTileShapeMNK ,
490
+ class EpilogueTile ,
491
+ class StrideOutput ,
492
+ class SmemLayoutAtom ,
493
+ class CopyOpR2S ,
494
+ class ElementOutput ,
495
+ class ElementCompute ,
496
+ class ElementBias = ElementOutput,
497
+ class ElementScale = ElementCompute,
498
+ class ElementScalar = ElementCompute,
499
+ int AlignmentBias = 128 / cute::sizeof_bits_v<ElementBias>,
500
+ int AlignmentOutput = 128 / cute::sizeof_bits_v<ElementOutput>,
501
+ FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
502
+ >
503
+ using Sm90ScaledAccPerColBiasPerRowScaleScatterPtrArray =
504
+ Sm90EVT<Sm90ScatterPtrArray<EpilogueTile, StrideOutput, SmemLayoutAtom, CopyOpR2S, ElementOutput, AlignmentOutput, RoundStyle>, // scatter store
505
+ Sm90EVT<Sm90Compute<multiplies, ElementCompute, ElementCompute, RoundStyle>, // scale * (alpha * acc + bias)
506
+ Sm90ColBroadcast<0 , CtaTileShapeMNK, ElementScalar *, ElementCompute, Stride<_1,_0,int64_t >, 1 >, // scale
507
+ Sm90ScaledAccPerColBiasPtrArray<CtaTileShapeMNK, ElementCompute, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle> // alpha * acc + bias
508
+ >
509
+ >;
510
+
436
511
template <
437
512
int StagesC,
438
513
int StagesD,
@@ -556,6 +631,129 @@ struct FusionCallbacks<
556
631
557
632
};
558
633
634
+ template <
635
+ int StagesC,
636
+ int StagesD,
637
+ int FragmentSize,
638
+ bool ReuseSmemC,
639
+ bool DelayTmaStore,
640
+ int NumEpilogueWarpGroups,
641
+ class GmemLayoutTagOut ,
642
+ class ElementOutput ,
643
+ class ElementCompute ,
644
+ class ElementBias ,
645
+ class ElementScale ,
646
+ class ElementScalar ,
647
+ int AlignmentBias,
648
+ int AlignmentOutput,
649
+ FloatRoundStyle RoundStyle,
650
+ class CtaTileShapeMNK ,
651
+ class EpilogueTile ,
652
+ class SmemLayoutAtom ,
653
+ class CopyOpR2S
654
+ >
655
+ struct FusionCallbacks <
656
+ epilogue::Sm90PtrArrayTmaWarpSpecialized<StagesC,
657
+ StagesD,
658
+ FragmentSize,
659
+ ReuseSmemC,
660
+ DelayTmaStore,
661
+ NumEpilogueWarpGroups
662
+ >,
663
+ fusion::ScaledAccPerColBiasPerRowScaleScatter<GmemLayoutTagOut,
664
+ ElementOutput,
665
+ ElementCompute,
666
+ ElementBias,
667
+ ElementScale,
668
+ ElementScalar,
669
+ AlignmentBias,
670
+ AlignmentOutput,
671
+ RoundStyle>,
672
+ CtaTileShapeMNK,
673
+ EpilogueTile,
674
+ SmemLayoutAtom,
675
+ CopyOpR2S
676
+ > : Sm90ScaledAccPerColBiasPerRowScaleScatterPtrArray<
677
+ CtaTileShapeMNK,
678
+ EpilogueTile,
679
+ cutlass::gemm::TagToStrideC_t<GmemLayoutTagOut>,
680
+ SmemLayoutAtom, CopyOpR2S,
681
+ ElementOutput, ElementCompute, ElementBias, ElementScale, ElementScalar,
682
+ AlignmentBias, AlignmentOutput, RoundStyle
683
+ > {
684
+
685
+ using StrideOutput = cutlass::gemm::TagToStrideC_t<GmemLayoutTagOut>;
686
+
687
+ using Impl = Sm90ScaledAccPerColBiasPerRowScaleScatterPtrArray<
688
+ CtaTileShapeMNK,
689
+ EpilogueTile,
690
+ StrideOutput,
691
+ SmemLayoutAtom, CopyOpR2S,
692
+ ElementOutput, ElementCompute, ElementBias, ElementScale, ElementScalar,
693
+ AlignmentBias, AlignmentOutput, RoundStyle
694
+ >;
695
+ using Operation = fusion::ScaledAccPerColBiasPerRowScaleScatter<
696
+ GmemLayoutTagOut,
697
+ ElementOutput,
698
+ ElementCompute,
699
+ ElementBias,
700
+ ElementScale,
701
+ ElementScalar,
702
+ AlignmentBias,
703
+ AlignmentOutput,
704
+ RoundStyle>;
705
+
706
+ struct Arguments {
707
+
708
+ using StrideAlpha = Stride<_0,_0,int64_t >;
709
+ ElementScalar alpha = ElementScalar(1 );
710
+ ElementScalar const * alpha_ptr{};
711
+ ElementScalar const * const * alpha_ptr_array{};
712
+ StrideAlpha dAlpha{};
713
+
714
+ using StrideBias = Stride<_0,_1,int64_t >;
715
+ ElementBias const * const * bias_ptr{};
716
+ StrideBias dBias{};
717
+
718
+ using StrideScale = Stride<_1,_0,int64_t >;
719
+ ElementScalar const * const * scale_ptr_array{};
720
+ StrideScale dScale{};
721
+
722
+ // Nested args not usable due to a compiler bug with constexpr evaluation
723
+ // using ScatterArguments = typename Sm90ScatterPtrArray<EpilogueTile, StrideOutput, SmemLayoutAtom, CopyOpR2S, ElementOutput, AlignmentOutput, RoundStyle>::Arguments;
724
+ // ScatterArguments scatter{};
725
+
726
+ ElementOutput* ptr_out = nullptr ;
727
+ StrideOutput dOut = {};
728
+ int const * const * ptr_index{}; // per-group pointer to the scatter index
729
+ int index_modulo{}; // modulo used to transform the index before store
730
+ int shape_override = -1 ; // override value for contiguous output tensor mode
731
+ bool use_reduction = true ;
732
+
733
+ operator typename Impl::Arguments () const {
734
+ return
735
+ { // unary op: reduce(scale * (beta * C + (alpha * acc)))
736
+ { // binary op: scale * (beta * C + (alpha * acc))
737
+ { scale_ptr_array, ElementScalar (1 ), dScale }, // leaf args : scale broadcast
738
+ { // ternary op : alpha * acc + bias
739
+ {{alpha}, {alpha_ptr}, {alpha_ptr_array}, {dAlpha}}, // leaf args : alpha
740
+ {}, // leaf args : acc
741
+ {bias_ptr, ElementBias (0 ), dBias}, // leaf args : bias
742
+ {} // ternary args : multiply_add
743
+ }, // end binary op
744
+ {} // binary args: multiply
745
+ }, // end binary op
746
+ // scatter // unary args: reduce
747
+ { ptr_out, dOut, ptr_index, index_modulo, shape_override, use_reduction }
748
+ }; // end unary op
749
+ }
750
+ };
751
+
752
+ // Ctor inheritance
753
+ using Impl::Impl;
754
+
755
+ };
756
+
559
757
} // namespace cutlass::epilogue::fusion
560
758
561
759
// clang-format on
0 commit comments