Skip to content

Commit 8370a17

Browse files
committed
remove more GemmDataTypes
1 parent 30f2193 commit 8370a17

6 files changed

+111
-181
lines changed

include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp

Lines changed: 24 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,8 @@ template <index_t BlockSize,
3838
index_t MRepeat,
3939
index_t NRepeat,
4040
index_t KPack,
41-
typename ComputeTypeA = FloatA,
42-
typename ComputeTypeB = FloatB,
43-
typename ComputeTypeGemm = FloatAcc>
41+
typename ComputeTypeA_ = FloatA,
42+
typename ComputeTypeB_ = FloatB>
4443
struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
4544
{
4645
static constexpr auto I0 = Number<0>{};
@@ -50,6 +49,11 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
5049

5150
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
5251

52+
using ComputeTypeA = conditional_t<is_same_v<ComputeTypeA_, ck::xf32_t>, float, ComputeTypeA_>;
53+
using ComputeTypeB = conditional_t<is_same_v<ComputeTypeB_, ck::xf32_t>, float, ComputeTypeB_>;
54+
using GemmDataTypeA = ComputeTypeA_;
55+
using GemmDataTypeB = ComputeTypeB_;
56+
5357
static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
5458
static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
5559
static constexpr index_t KPerBlock =
@@ -64,14 +68,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
6468
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
6569
static constexpr index_t WaveSize = BlockSize / MWaves / NWaves;
6670

67-
static constexpr auto xdlops_gemm = XdlopsGemm<ComputeTypeA,
68-
MPerXDL,
69-
NPerXDL,
70-
KPack,
71-
ComputeTypeB,
72-
false,
73-
false,
74-
ComputeTypeGemm>{};
71+
static constexpr auto xdlops_gemm =
72+
XdlopsGemm<GemmDataTypeA, MPerXDL, NPerXDL, KPack, GemmDataTypeB, false, false>{};
7573

7674
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
7775

@@ -179,6 +177,12 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
179177

180178
static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
181179
"wrong!");
180+
// Check the raw template arguments (ComputeTypeA_/ComputeTypeB_): the ComputeTypeA/ComputeTypeB
// aliases have already mapped xf32_t to float, so testing them would make this branch dead code.
// When either side requests xf32_t, both sides must use the same type.
if constexpr(is_same_v<ComputeTypeA_, ck::xf32_t> || is_same_v<ComputeTypeB_, ck::xf32_t>)
181+
{
182+
static_assert(
183+
is_same_v<ComputeTypeA_, ComputeTypeB_>,
184+
"ComputeTypeA and ComputeTypeB must be both xf32_t when one of them is xf32_t");
185+
}
182186
}
183187

184188
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
@@ -406,10 +410,9 @@ template <index_t BlockSize,
406410
index_t MRepeat,
407411
index_t NRepeat,
408412
index_t KPack,
409-
typename ComputeTypeA = FloatA,
410-
typename ComputeTypeB = FloatB,
411-
index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS,
412-
typename ComputeTypeGemm = FloatAcc>
413+
typename ComputeTypeA = FloatA,
414+
typename ComputeTypeB = FloatB,
415+
index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS>
413416
struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
414417
: public BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
415418
FloatA,
@@ -423,8 +426,7 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
423426
NRepeat,
424427
KPack,
425428
ComputeTypeA,
426-
ComputeTypeB,
427-
ComputeTypeGemm>
429+
ComputeTypeB>
428430
{
429431
using Base = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
430432
FloatA,
@@ -438,8 +440,7 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
438440
NRepeat,
439441
KPack,
440442
ComputeTypeA,
441-
ComputeTypeB,
442-
ComputeTypeGemm>;
443+
ComputeTypeB>;
443444

444445
#if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
445446
using Base::a_block_desc_m0_m1_m2_k;
@@ -610,9 +611,8 @@ template <index_t BlockSize,
610611
index_t NRepeat,
611612
index_t KPack,
612613
LoopScheduler LoopSched,
613-
typename ComputeTypeA = FloatA,
614-
typename ComputeTypeB = FloatB,
615-
typename ComputeTypeGemm = FloatAcc>
614+
typename ComputeTypeA = FloatA,
615+
typename ComputeTypeB = FloatB>
616616
constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
617617
{
618618
if constexpr(LoopSched == LoopScheduler::Default)
@@ -629,8 +629,7 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
629629
NRepeat,
630630
KPack,
631631
ComputeTypeA,
632-
ComputeTypeB,
633-
ComputeTypeGemm>{};
632+
ComputeTypeB>{};
634633
}
635634
else if constexpr(LoopSched == LoopScheduler::Interwave)
636635
{
@@ -648,8 +647,7 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
648647
KPack,
649648
ComputeTypeA,
650649
ComputeTypeB,
651-
CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS,
652-
ComputeTypeGemm>{};
650+
CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS>{};
653651
}
654652
};
655653

include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
7878
ELayout,
7979
ADataType,
8080
BDataType,
81-
ADataType,
81+
ComputeDataType,
8282
AccDataType,
8383
CShuffleDataType,
8484
ck::Tuple<>,
@@ -115,7 +115,6 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
115115
CDEBlockTransferScalarPerVector_NPerBlock,
116116
LoopSched,
117117
PipelineVer,
118-
BDataType,
119118
ComputeDataType>;
120119

121120
using Argument = typename GridwiseGemm::Argument;

include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -107,14 +107,8 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
107107
using BComputeDataType =
108108
conditional_t<is_same_v<BComputeDataType_, ck::half_t>, ck::bhalf_t, BComputeDataType_>;
109109
#else
110-
static constexpr bool is_xf32 = is_same_v<AComputeDataType_, ck::xf32_t>;
111-
static_assert(!is_xf32 || is_same_v<AComputeDataType_, BComputeDataType_>,
112-
"A and B compute type should be the same when using xf32_t");
113-
using AComputeDataType =
114-
conditional_t<is_same_v<AComputeDataType_, ck::xf32_t>, float, AComputeDataType_>;
115-
using BComputeDataType =
116-
conditional_t<is_same_v<BComputeDataType_, ck::xf32_t>, float, BComputeDataType_>;
117-
using GemmDataType = AComputeDataType_;
110+
using AComputeDataType = AComputeDataType_;
111+
using BComputeDataType = BComputeDataType_;
118112
#endif
119113

120114
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
@@ -697,8 +691,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
697691
NPerXdl,
698692
BComputeDataType,
699693
is_single_rate_mfma,
700-
is_scale_mfma,
701-
GemmDataType>::selected_mfma.k_per_blk);
694+
is_scale_mfma>::selected_mfma.k_per_blk);
702695

703696
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
704697
BlockSize,
@@ -714,8 +707,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
714707
KPack,
715708
LoopSched,
716709
AComputeDataType,
717-
BComputeDataType,
718-
GemmDataType>();
710+
BComputeDataType>();
719711

720712
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
721713

include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp

Lines changed: 24 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -107,14 +107,12 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
107107
using BComputeDataType =
108108
conditional_t<is_same_v<BComputeDataType_, ck::half_t>, ck::bhalf_t, BComputeDataType_>;
109109
#else
110-
static constexpr bool is_xf32 = is_same_v<AComputeDataType_, ck::xf32_t>;
111-
static_assert(!is_xf32 || is_same_v<AComputeDataType_, BComputeDataType_>,
112-
"A and B compute type should be the same when using xf32_t");
113110
using AComputeDataType =
114111
conditional_t<is_same_v<AComputeDataType_, ck::xf32_t>, float, AComputeDataType_>;
115112
using BComputeDataType =
116113
conditional_t<is_same_v<BComputeDataType_, ck::xf32_t>, float, BComputeDataType_>;
117-
using GemmDataType = AComputeDataType_;
114+
using GemmDataTypeA = AComputeDataType_;
115+
using GemmDataTypeB = BComputeDataType_;
118116
#endif
119117

120118
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
@@ -661,31 +659,28 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
661659
? true
662660
: false;
663661
constexpr auto is_scale_mfma = false;
664-
// todo: GemmDataType
665-
constexpr index_t KPack = math::max(lcm_AK1_BK1,
666-
MfmaSelector<AComputeDataType,
667-
MPerXdl,
668-
NPerXdl,
669-
BComputeDataType,
670-
is_single_rate_mfma,
671-
is_scale_mfma,
672-
GemmDataType>::selected_mfma.k_per_blk);
673-
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
674-
BlockSize,
675-
AComputeDataType,
676-
BComputeDataType,
677-
AccDataType,
678-
decltype(a_block_desc_ak0_m_ak1),
679-
decltype(b_block_desc_bk0_n_bk1),
680-
MPerXdl,
681-
NPerXdl,
682-
MXdlPerWave,
683-
NXdlPerWave,
684-
KPack,
685-
LoopSched,
686-
AComputeDataType,
687-
BComputeDataType,
688-
GemmDataType>();
662+
constexpr index_t KPack = math::max(lcm_AK1_BK1,
663+
MfmaSelector<GemmDataTypeA,
664+
MPerXdl,
665+
NPerXdl,
666+
GemmDataTypeB,
667+
is_single_rate_mfma,
668+
is_scale_mfma>::selected_mfma.k_per_blk);
669+
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
670+
BlockSize,
671+
AComputeDataType,
672+
BComputeDataType,
673+
AccDataType,
674+
decltype(a_block_desc_ak0_m_ak1),
675+
decltype(b_block_desc_bk0_n_bk1),
676+
MPerXdl,
677+
NPerXdl,
678+
MXdlPerWave,
679+
NXdlPerWave,
680+
KPack,
681+
LoopSched,
682+
GemmDataTypeA,
683+
GemmDataTypeB>();
689684

690685
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
691686

include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,7 @@ template <typename ALayout,
140140
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
141141
LoopScheduler LoopSched,
142142
PipelineVersion PipelineVer = PipelineVersion::v4,
143-
typename BComputeDataType = AComputeDataType_,
144-
typename GemmDataType = EDataType>
143+
typename BComputeDataType_ = AComputeDataType_>
145144
struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
146145
{
147146
static constexpr index_t NumDTensor = DsDataType::Size();
@@ -169,7 +168,12 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
169168
using AComputeDataType =
170169
conditional_t<is_same_v<AComputeDataType_, ck::half_t>, ck::bhalf_t, AComputeDataType_>;
171170
#else
172-
using AComputeDataType = AComputeDataType_;
171+
using AComputeDataType =
172+
conditional_t<is_same_v<AComputeDataType_, ck::xf32_t>, float, AComputeDataType_>;
173+
using BComputeDataType =
174+
conditional_t<is_same_v<BComputeDataType_, ck::xf32_t>, float, BComputeDataType_>;
175+
using GemmDataTypeA = AComputeDataType_;
176+
using GemmDataTypeB = BComputeDataType_;
173177
#endif
174178

175179
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
@@ -633,13 +637,12 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
633637
constexpr auto is_scale_mfma = false;
634638

635639
constexpr index_t KPack = math::max(lcm_AK1_BK1,
636-
MfmaSelector<AComputeDataType,
640+
MfmaSelector<GemmDataTypeA,
637641
MPerXdl,
638642
NPerXdl,
639-
BComputeDataType,
643+
GemmDataTypeB,
640644
is_single_rate_mfma,
641-
is_scale_mfma,
642-
GemmDataType>::selected_mfma.k_per_blk);
645+
is_scale_mfma>::selected_mfma.k_per_blk);
643646

644647
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
645648
BlockSize,
@@ -654,9 +657,8 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
654657
NXdlPerWave,
655658
KPack,
656659
LoopSched,
657-
AComputeDataType_,
658-
BComputeDataType,
659-
GemmDataType>();
660+
GemmDataTypeA,
661+
GemmDataTypeB>();
660662

661663
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
662664

0 commit comments

Comments
 (0)