Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions example/01_gemm/gemm_wmma_fp8_v3.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ using CDataType = ck::bhalf_t;
using ComputeTypeA = ck::f8_t;
using ComputeTypeB = ck::f8_t;

using ALayout = Row;
using ALayout = Col;
using BLayout = Col;
using CLayout = Row;

Expand All @@ -30,13 +30,13 @@ using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuf
PassThrough, PassThrough, PassThrough, GemmDefault,
128,
128, 64, 64,
8, 8,
16, 16, // AK1, BK1
16, 16,
4, 2,
S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>,
1, 4, 16, 0,
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
2, 16, 16, 0,
1, 1, S<1, 32, 1, 4>, 8,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1,
ComputeTypeA, ComputeTypeB>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ template <BlockGemmPipelineVersion BlkGemmPipelineVer,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
index_t KInner,
bool TransposeC = false>
constexpr auto BlockGemmPipeline_Selector()
{
Expand All @@ -52,6 +53,7 @@ constexpr auto BlockGemmPipeline_Selector()
MRepeat,
NRepeat,
KPack,
KInner,
TransposeC>{};
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
Expand All @@ -75,6 +77,7 @@ constexpr auto BlockGemmPipeline_Selector()
MRepeat,
NRepeat,
KPack,
KInner,
TransposeC>{};
}
else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ template <index_t BlockSize,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
index_t KInner,
bool TransposeC = false>
struct BlockwiseGemmWmmaops_pipeline_base
{
Expand All @@ -38,6 +39,7 @@ struct BlockwiseGemmWmmaops_pipeline_base
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};

using ThisThreadBlock = ThisThreadBlock<BlockSize>;

Expand All @@ -54,15 +56,20 @@ struct BlockwiseGemmWmmaops_pipeline_base
static constexpr index_t B_KRow = 1;
#endif

static constexpr index_t A_K1 = AWmmaTileDesc{}.GetLength(I5);
static constexpr index_t B_K1 = BWmmaTileDesc{}.GetLength(I5);
static constexpr auto wmma_gemm = WmmaGemm<ComputeTypeA,
ComputeTypeB,
AccDataType,
MPerWmma,
NPerWmma,
KPack / KInner,
TransposeC>{};

static constexpr index_t KPerThread = wmma_gemm.wmma_instr.k_per_blk * KInner;
static constexpr index_t A_K1 = ck::math::min(AWmmaTileDesc{}.GetLength(I6), KPerThread);
static constexpr index_t B_K1 = ck::math::min(BWmmaTileDesc{}.GetLength(I6), KPerThread);

static_assert(KPack % (A_K1 * A_KRow) == 0, "wrong!");
static_assert(KPack % (B_K1 * B_KRow) == 0, "wrong!");

static constexpr auto wmma_gemm =
WmmaGemm<ComputeTypeA, ComputeTypeB, AccDataType, MPerWmma, NPerWmma, KPack, TransposeC>{};

static constexpr index_t KRepeat = KPerBlock / KPack;

static constexpr auto WmmaK = Number<wmma_gemm.wmma_instr.k_per_wmma>{};
Expand Down Expand Up @@ -191,8 +198,7 @@ struct BlockwiseGemmWmmaops_pipeline_base
const auto wmma_krow = 0;
#endif

// |KRepeat |MRepeat|MWave |KRow |MLane |KPack
return make_tuple(0, 0, waveId_m, wmma_krow, wmma_a_idx, 0);
return make_tuple(0, 0, 0, waveId_m, wmma_krow, wmma_a_idx, 0);
}

__device__ static auto CalculateBThreadOriginDataIndex()
Expand All @@ -209,8 +215,7 @@ struct BlockwiseGemmWmmaops_pipeline_base
const auto wmma_krow = 0;
#endif

// |KRepeat |NRepeat|Nwave |KRow |NLane |KPack
return make_tuple(0, 0, waveId_n, wmma_krow, wmma_b_idx, 0);
return make_tuple(0, 0, 0, waveId_n, wmma_krow, wmma_b_idx, 0);
}

template <index_t m0, index_t n0>
Expand Down Expand Up @@ -241,7 +246,7 @@ struct BlockwiseGemmWmmaops_pipeline_base
return make_tuple(c_thread_m, c_thread_n);
}

using Tuple6 = decltype(CalculateAThreadOriginDataIndex());
using Tuple7 = decltype(CalculateAThreadOriginDataIndex());

/**
* @brief Constructor for BlockwiseGemmWmmaops_pipeline_base.
Expand All @@ -261,8 +266,8 @@ struct BlockwiseGemmWmmaops_pipeline_base
* repeat dimensions.
*/
__host__ __device__
BlockwiseGemmWmmaops_pipeline_base(Tuple6 a_origin = CalculateAThreadOriginDataIndex(),
Tuple6 b_origin = CalculateBThreadOriginDataIndex())
BlockwiseGemmWmmaops_pipeline_base(Tuple7 a_origin = CalculateAThreadOriginDataIndex(),
Tuple7 b_origin = CalculateBThreadOriginDataIndex())
: a_thread_copy_(a_origin), b_thread_copy_(b_origin)
{
static_assert(AWmmaTileDesc::IsKnownAtCompileTime() &&
Expand Down Expand Up @@ -343,12 +348,14 @@ struct BlockwiseGemmWmmaops_pipeline_base
Number<KRepeat>{},
I1,
I1,
I1,
Number<A_K1>{}),
make_tuple(Number<A_K1>{},
Number<KPack / A_KRow>{},
Number<KPack / A_KRow * MRepeat>{},
I0,
I0,
I0,
I1));

static constexpr auto b_thread_desc_ =
Expand All @@ -357,12 +364,14 @@ struct BlockwiseGemmWmmaops_pipeline_base
Number<KRepeat>{},
I1,
I1,
I1,
Number<B_K1>{}),
make_tuple(Number<B_K1>{},
Number<KPack / B_KRow>{},
Number<KPack / B_KRow * NRepeat>{},
I0,
I0,
I0,
I1));

// C[M, N, NumRegWmma]
Expand All @@ -374,9 +383,9 @@ struct BlockwiseGemmWmmaops_pipeline_base
ComputeTypeA,
decltype(a_block_desc_k0_m0_m1_m2_k1),
decltype(a_thread_desc_),
Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, 1, A_K1>,
Sequence<0, 1, 2, 3, 4, 5, 6>,
6,
A_K1,
A_K1>;

Expand All @@ -385,9 +394,9 @@ struct BlockwiseGemmWmmaops_pipeline_base
ComputeTypeB,
decltype(b_block_desc_k0_n0_n1_n2_k1),
decltype(b_thread_desc_),
Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, 1, B_K1>,
Sequence<0, 1, 2, 3, 4, 5, 6>,
6,
B_K1,
B_K1>;

Expand Down
Loading