Skip to content

Commit e5492d0

Browse files
committed
rename xf32 to tf32
1 parent 8370a17 commit e5492d0

File tree

36 files changed

+180
-151
lines changed

36 files changed

+180
-151
lines changed

example/01_gemm/common.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ bool parse_cmd_args<ProblemSizeSplitK>(int argc,
313313
template <typename DataType, typename GemmType = DataType>
314314
inline __host__ __device__ constexpr double get_rtol()
315315
{
316-
if constexpr(std::is_same_v<DataType, float> && std::is_same_v<GemmType, ck::xf32_t>)
316+
if constexpr(std::is_same_v<DataType, float> && std::is_same_v<GemmType, ck::tf32_t>)
317317
{
318318
return 1e-3;
319319
}
@@ -358,7 +358,7 @@ inline __host__ __device__ constexpr double get_rtol()
358358
template <typename DataType, typename GemmType = DataType>
359359
inline __host__ __device__ constexpr double get_atol()
360360
{
361-
if constexpr(std::is_same_v<DataType, float> && std::is_same_v<GemmType, ck::xf32_t>)
361+
if constexpr(std::is_same_v<DataType, float> && std::is_same_v<GemmType, ck::tf32_t>)
362362
{
363363
return 1e-3;
364364
}

example/01_gemm/gemm_xdl_lds_direct_load_fp32_tf32.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ using BDataType = F32;
2121
using AccDataType = F32;
2222
using CShuffleDataType = F32;
2323
using CDataType = F32;
24-
using GemmDataType = ck::xf32_t;
24+
using GemmDataType = ck::tf32_t;
2525

2626
using ALayout = Row;
2727
using BLayout = Col;

example/09_convnd_fwd/convnd_fwd_common.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ void print_helper_msg()
3030
template <typename DataType, typename GemmType = DataType>
3131
inline __host__ __device__ constexpr double get_rtol()
3232
{
33-
if constexpr(std::is_same_v<DataType, float> && std::is_same_v<GemmType, ck::xf32_t>)
33+
if constexpr(std::is_same_v<DataType, float> && std::is_same_v<GemmType, ck::tf32_t>)
3434
{
3535
return 1e-3;
3636
}
@@ -75,7 +75,7 @@ inline __host__ __device__ constexpr double get_rtol()
7575
template <typename DataType, typename GemmType = DataType>
7676
inline __host__ __device__ constexpr double get_atol()
7777
{
78-
if constexpr(std::is_same_v<DataType, float> && std::is_same_v<GemmType, ck::xf32_t>)
78+
if constexpr(std::is_same_v<DataType, float> && std::is_same_v<GemmType, ck::tf32_t>)
7979
{
8080
return 1e-3;
8181
}

example/09_convnd_fwd/convnd_fwd_xdl_fp32_tf32.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ using WeiDataType = float;
1414
using AccDataType = float;
1515
using CShuffleDataType = float;
1616
using OutDataType = float;
17-
using GemmDataType = ck::xf32_t;
17+
using GemmDataType = ck::tf32_t;
1818

1919
template <ck::index_t... Is>
2020
using S = ck::Sequence<Is...>;

include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
4949

5050
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
5151

52-
using ComputeTypeA = conditional_t<is_same_v<ComputeTypeA_, ck::xf32_t>, float, ComputeTypeA_>;
53-
using ComputeTypeB = conditional_t<is_same_v<ComputeTypeB_, ck::xf32_t>, float, ComputeTypeB_>;
52+
using ComputeTypeA = conditional_t<is_same_v<ComputeTypeA_, ck::tf32_t>, float, ComputeTypeA_>;
53+
using ComputeTypeB = conditional_t<is_same_v<ComputeTypeB_, ck::tf32_t>, float, ComputeTypeB_>;
5454
using GemmDataTypeA = ComputeTypeA_;
5555
using GemmDataTypeB = ComputeTypeB_;
5656

@@ -177,11 +177,10 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
177177

178178
static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
179179
"wrong!");
180-
if constexpr(is_same_v<ComputeTypeA, ck::xf32_t> || is_same_v<ComputeTypeB, ck::xf32_t>)
180+
if constexpr(is_same_v<ComputeTypeA, ck::tf32_t> || is_same_v<ComputeTypeB, ck::tf32_t>)
181181
{
182-
static_assert(
183-
is_same_v<ComputeTypeA_, ComputeTypeA_>,
184-
"ComputeTypeA and ComputeTypeB must be both xf32_t when one of them is xf32_t");
182+
static_assert(is_same_v<ComputeTypeA_, ComputeTypeA_>,
183+
"ComputeTypeA and ComputeTypeB must be same when one of them is tf32");
185184
}
186185
}
187186

include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1043,8 +1043,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
10431043

10441044
float RunGemm(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
10451045
{
1046-
::std::cout << __FILE__ << ":" << __LINE__
1047-
<< " DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle" << std::endl;
10481046
if(stream_config.log_level_ > 0)
10491047
{
10501048
arg.Print();
@@ -1657,6 +1655,23 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
16571655
arg.block_2_etile_map_);
16581656
}
16591657
}
1658+
if constexpr(is_same_v<AComputeDataType, ck::tf32_t> ||
1659+
is_same_v<BComputeDataType, ck::tf32_t>)
1660+
1661+
{
1662+
if(!(ck::get_device_name() == "gfx942"))
1663+
{
1664+
std::cout << "TF32 is enabled on gfx942 only" << std::endl;
1665+
return false;
1666+
}
1667+
if constexpr(!is_same_v<AComputeDataType, BComputeDataType>)
1668+
{
1669+
std::cout << "ComputeDataType for A and B should be same while using TF32"
1670+
<< std::endl;
1671+
return false;
1672+
}
1673+
}
1674+
return true;
16601675
}
16611676

16621677
bool IsSupportedArgument(const BaseArgument* p_arg) override

include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
108108
conditional_t<is_same_v<BComputeDataType_, ck::half_t>, ck::bhalf_t, BComputeDataType_>;
109109
#else
110110
using AComputeDataType =
111-
conditional_t<is_same_v<AComputeDataType_, ck::xf32_t>, float, AComputeDataType_>;
111+
conditional_t<is_same_v<AComputeDataType_, ck::tf32_t>, float, AComputeDataType_>;
112112
using BComputeDataType =
113-
conditional_t<is_same_v<BComputeDataType_, ck::xf32_t>, float, BComputeDataType_>;
113+
conditional_t<is_same_v<BComputeDataType_, ck::tf32_t>, float, BComputeDataType_>;
114114
using GemmDataTypeA = AComputeDataType_;
115115
using GemmDataTypeB = BComputeDataType_;
116116
#endif

include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,9 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
169169
conditional_t<is_same_v<AComputeDataType_, ck::half_t>, ck::bhalf_t, AComputeDataType_>;
170170
#else
171171
using AComputeDataType =
172-
conditional_t<is_same_v<AComputeDataType_, ck::xf32_t>, float, AComputeDataType_>;
172+
conditional_t<is_same_v<AComputeDataType_, ck::tf32_t>, float, AComputeDataType_>;
173173
using BComputeDataType =
174-
conditional_t<is_same_v<BComputeDataType_, ck::xf32_t>, float, BComputeDataType_>;
174+
conditional_t<is_same_v<BComputeDataType_, ck::tf32_t>, float, BComputeDataType_>;
175175
using GemmDataTypeA = AComputeDataType_;
176176
using GemmDataTypeB = BComputeDataType_;
177177
#endif

include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ enum struct MfmaInstr
7878
mfma_f32_16x16x128f8f6f4,
7979
mfma_scale_f32_32x32x64f8f6f4,
8080
mfma_scale_f32_16x16x128f8f6f4,
81-
mfma_f32_16x16x8xf32, // xf32
81+
mfma_f32_16x16x8xf32, // tf32
8282
mfma_f32_32x32x4xf32,
8383
// gfx11
8484
wmma_f32_16x16x16_f16,
@@ -1273,13 +1273,13 @@ struct MfmaSelector
12731273
}
12741274

12751275
template <>
1276-
constexpr auto GetMfma<xf32_t, 32, 32>()
1276+
constexpr auto GetMfma<tf32_t, 32, 32>()
12771277
{
12781278
return MfmaInstr::mfma_f32_32x32x4xf32;
12791279
}
12801280

12811281
template <>
1282-
constexpr auto GetMfma<xf32_t, 16, 16>()
1282+
constexpr auto GetMfma<tf32_t, 16, 16>()
12831283
{
12841284
return MfmaInstr::mfma_f32_16x16x8xf32;
12851285
}
@@ -1998,12 +1998,12 @@ struct XdlopsGemm
19981998
{
19991999
static_assert(
20002000
is_same<base_type, double>::value || is_same<base_type, float>::value ||
2001-
is_same<base_type, xf32_t>::value || is_same<base_type, half_t>::value ||
2001+
is_same<base_type, tf32_t>::value || is_same<base_type, half_t>::value ||
20022002
is_same<base_type, bhalf_t>::value || is_same<base_type, int8_t>::value ||
20032003
is_same<base_type, f8_t>::value || is_same<base_type, bf8_t>::value ||
20042004
(is_same<base_type, f8_t>::value && is_same<additional_type, bf8_t>::value) ||
20052005
(is_same<base_type, bf8_t>::value && is_same<additional_type, f8_t>::value),
2006-
"base_type must be double, float, xf32_t, half, bfloat16, int8_t, f8_t or bf8_t!");
2006+
"base_type must be double, float, tf32_t, half, bfloat16, int8_t, f8_t or bf8_t!");
20072007

20082008
static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
20092009
if constexpr(!TransposeC)

include/ck/utility/amd_xdlops.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1636,7 +1636,7 @@ struct intrin_mfma_f32_16x16x32bf8f8<16, 16>
16361636
}
16371637
};
16381638

1639-
/******************* xf32 *************************************/
1639+
/******************* tf32 *************************************/
16401640
template <index_t MPerWave, index_t NPerWave>
16411641
struct intrin_mfma_f32_16x16x8xf32;
16421642

0 commit comments

Comments
 (0)