Commit dee185d

ecamartins and arai713 authored
[CK_TILE] Stream-K GEMM Implementation (#2781)
* Change the splitk_batch_offset parameter to k_size in UniversalGemmKernel::MakeGemmTensorViews. Prior to this change, the splitk_batch_offset parameter of MakeGemmTensorViews had type SplitKBatchOffset, but the only member of that class used in the function was splitted_k (an int32_t), which helps define the dimensions of the tensor views. Stream-K does not need the SplitKBatchOffset class since it does not use Split-K. This commit therefore changes the splitk_batch_offset parameter to an int32_t called k_size, so callers of MakeGemmTensorViews are no longer required to construct a SplitKBatchOffset while the same functionality is preserved. Calls to UniversalGemmKernel::MakeGemmTensorViews have been updated accordingly.

* Stream-K kernel RunGemm implementation. Stream-K cannot simply use UniversalGemmKernel's RunGemm for the following reasons:
  1. UniversalGemmKernel::RunGemm computes num_loop via a static function of the TilePartitioner, whereas Stream-K must compute num_loop using a member function (namely GetCurrentIterLength from PR #2708).
  2. UniversalGemmKernel::RunGemm requires a SplitKBatchOffset object, which Stream-K does not use since it does not use Split-K.
  This change therefore adds a RunGemm function to the StreamKKernel class.

* Initial implementation of operator() for StreamKKernel: adds the Stream-K algorithm and calls to RunGemm.

* Fix indexing and offset issues for Stream-K:
  - Ensure offsets along the M and N dimensions are multiplied by MPerBlock or NPerBlock, respectively, so that tile window origins land at the correct locations.
  - Fix a bug in the tile partitioner's GetTileIdxWithOffset: divmod is now applied to the given references so that correct values reach the caller.
  - Add documentation in the Stream-K operator().

* Initial gtests for Stream-K. These changes add an initial gtest suite for the CK Tile Stream-K kernel. Currently, due to bugs in the StreamKTilePartitioner (to be handled in a future PR), certain cases have validation issues that may differ across architectures, so we run only cases that are fully data-parallel and skip the others. A guard was added to Stream-K's IsSupportedArgument method to ensure that callers are aware of this constraint. Additionally, to keep testing reproducible, options for setting the number of CUs and the occupancy were added to MakeKernelArgs.

* Use the GemmPipeline operator() variant that takes the hot loop and tail number. In Stream-K, the num_loop value varies per WG and per iteration of the Stream-K loop, so we use the version of GemmPipeline's operator() that takes has_hot_loop and tail_num. This is similar to what is done in grouped GEMM.

* Changes from review: comments, move readfirstlane, remove ifndef.

* Switch the direction of C tensor traversal and add a padding guard. Prior to this change, WGs travelled backwards through their assigned macro tiles in the C tensor: for instance, if WG0 is responsible for C tiles 0 and 1, it would first visit tile 1 and then tile 0, so iter_end decremented on each iteration of the Stream-K while loop. Since we are working with unsigned integers, that subtraction may not be safe. This change makes it such that WGs travel forward, so their iter_start is incremented and their iter_end remains fixed. Additionally, we added a guard against WGs that are neither sk_blocks nor dp_blocks to ensure such WGs do not participate in the GEMM. Together, these changes make the algorithm correct when sk_blocks is greater than zero.

* Disable the StreamK_M256_N256_K256_SKBlocks12 test case. This instance involves >=3 WGs contributing to each macro tile in C; due to the use of atomics, this results in precision errors. These errors will not persist once the reduction strategy is implemented. We will re-enable this test then.

---------

Co-authored-by: Astha Rai <[email protected]>
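The reproducibility hooks mentioned above boil down to default arguments that callers can override. A self-contained sketch of the pattern, using mock stand-in types rather than the real CK Tile API:

    #include <cstdio>

    // Mock stand-ins illustrating the MakeKernelArgs pattern from this commit.
    struct HostArgs { int M, N, K; };
    struct KernelArgs { int M, N, K, num_cu, occupancy; };

    static int NumCU()     { return 304; } // the real code queries the device
    static int Occupancy() { return 2; }   // the real code queries max blocks per CU

    static KernelArgs MakeKernelArgs(const HostArgs& h,
                                     int num_cu    = NumCU(),
                                     int occupancy = Occupancy())
    {
        return {h.M, h.N, h.K, num_cu, occupancy};
    }

    int main()
    {
        HostArgs h{256, 256, 512};
        KernelArgs a = MakeKernelArgs(h);        // defaults: device-derived values
        KernelArgs b = MakeKernelArgs(h, 80, 1); // pinned for reproducible tests
        std::printf("%d CUs vs %d CUs\n", a.num_cu, b.num_cu);
        return 0;
    }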
1 parent b7a806f commit dee185d

File tree

10 files changed: +612 -35 lines changed


include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp

Lines changed: 7 additions & 12 deletions
@@ -646,16 +646,13 @@ struct StreamKTilePartitioner
      * @brief Get length of loop iterations for stream-k loop
      */
     CK_TILE_DEVICE uint32_t GetCurrentIterLength(uint32_t iter_start,
-                                                 uint32_t iter_end,
-                                                 uint32_t total_iter_length) const noexcept
+                                                 uint32_t iter_end) const noexcept
     {
-        uint32_t iter_length_mod, iter_length_quo /*unused*/;
-        k_iters_per_tile.divmod(iter_end, iter_length_quo, iter_length_mod);
-        uint32_t total_iter_length_val = static_cast<uint32_t>(total_iter_length);
-        uint32_t current_iter_length =
-            min(iter_length_mod == 0 ? (iter_end - iter_start) : iter_length_mod,
-                total_iter_length_val);
-        return current_iter_length;
+        // A WG's iter_end is either in the current C macro tile or not.
+        // If it is not, then the macro tile boundary is where the WG must stop.
+        uint32_t distance_to_tile_boundary =
+            k_iters_per_tile.get() - (iter_start % k_iters_per_tile.get());
+        return min(iter_start + distance_to_tile_boundary, iter_end) - iter_start;
     }

     /**
@@ -672,9 +669,7 @@ struct StreamKTilePartitioner
     CK_TILE_DEVICE void
     GetTileIdxWithOffset(uint32_t iter, uint32_t& tile_idx, uint32_t& iter_offset) const noexcept
     {
-        uint32_t tile_idx_val    = static_cast<uint32_t>(tile_idx);
-        uint32_t iter_offset_val = static_cast<uint32_t>(iter_offset);
-        k_iters_per_tile.divmod(iter, tile_idx_val, iter_offset_val);
+        k_iters_per_tile.divmod(iter, tile_idx, iter_offset);
     }

     /**
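To sanity-check the new arithmetic, a small standalone example, with the k_iters_per_tile magic-division helper replaced by a plain integer for illustration:

    // Minimal sketch of GetCurrentIterLength with k_iters_per_tile
    // as a plain uint32_t rather than the real divmod helper.
    #include <algorithm>
    #include <cassert>
    #include <cstdint>

    uint32_t current_iter_length(uint32_t iter_start, uint32_t iter_end,
                                 uint32_t k_iters_per_tile)
    {
        // Distance from iter_start to the end of its C macro tile.
        uint32_t distance_to_tile_boundary =
            k_iters_per_tile - (iter_start % k_iters_per_tile);
        return std::min(iter_start + distance_to_tile_boundary, iter_end) - iter_start;
    }

    int main()
    {
        // 4 K-iterations per C tile; the WG owns iters [6, 11).
        assert(current_iter_length(6, 11, 4) == 2);  // stops at tile boundary 8
        // The next loop iteration starts at 8, a tile boundary.
        assert(current_iter_length(8, 11, 4) == 3);  // iter_end 11 comes first
        return 0;
    }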

include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp

Lines changed: 2 additions & 2 deletions
@@ -374,7 +374,7 @@ struct GroupedGemmKernel
         // Create Gemm tensor views, pad views and tile windows
         const auto& gemm_tensor_views_tuple =
             Base::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
-                {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset);
+                {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset.splitted_k);

         const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple);
         auto gemm_tile_windows =
@@ -436,7 +436,7 @@ struct GroupedGemmKernel
         // Create Gemm tensor views, pad views and tile windows
         const auto& gemm_tensor_views_tuple =
             Base::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
-                {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset);
+                {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset.splitted_k);

         const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple);
         auto gemm_tile_windows =

include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp

Lines changed: 146 additions & 11 deletions
@@ -141,11 +141,17 @@ struct StreamKKernel
         return UniversalGemmKernel::BlockSize();
     }

-    CK_TILE_HOST static StreamKKernelArgs MakeKernelArgs(const StreamKHostArgs& host_args)
+    /// @brief Constructs kernel arguments for the Stream-K kernel.
+    /// @param host_args Stream-K host arguments.
+    /// @param num_cu Number of compute units (CUs). The default is the number of CUs on the device.
+    /// The caller may select their own to assist with test reproducibility, etc.
+    /// @param occupancy The maximum number of active blocks per CU for this kernel. The caller may
+    /// select their own to assist with test reproducibility, etc.
+    /// @return The kernel arguments for Stream-K.
+    CK_TILE_HOST static StreamKKernelArgs MakeKernelArgs(const StreamKHostArgs& host_args,
+                                                         int num_cu    = NumCU(),
+                                                         int occupancy = Occupancy())
     {
-        uint32_t occupancy = static_cast<uint32_t>(Occupancy());
-        uint32_t num_cu    = static_cast<uint32_t>(NumCU());
-
         return StreamKKernelArgs{{host_args.as_ptr,
                                   host_args.bs_ptr,
                                   host_args.ds_ptr,
@@ -166,14 +172,71 @@ struct StreamKKernel
                                  TilePartitioner{static_cast<uint32_t>(host_args.M),
                                                  static_cast<uint32_t>(host_args.N),
                                                  static_cast<uint32_t>(host_args.K),
-                                                 num_cu,
-                                                 occupancy,
+                                                 static_cast<uint32_t>(num_cu),
+                                                 static_cast<uint32_t>(occupancy),
                                                  host_args.num_sk_blocks}};
     }

-    CK_TILE_HOST static bool
-    IsSupportedArgument(const typename UniversalGemmKernel::KernelArgs& kargs)
+    template <bool UseDefaultScheduler = true>
+    CK_TILE_DEVICE static void
+    RunGemm(const std::array<const ADataType*, UniversalGemmKernel::NumATensor>& as_ptr,
+            const std::array<const BDataType*, UniversalGemmKernel::NumBTensor>& bs_ptr,
+            const std::array<const void*, UniversalGemmKernel::NumDTensor>& ds_ptr,
+            CDataType* c_ptr,
+            void* smem_ptr_0,
+            const typename UniversalGemmKernel::KernelArgs& kargs,
+            const index_t num_loop,
+            const index_t block_idx_m,
+            const index_t block_idx_n,
+            const index_t k_size)
     {
+        // Create Gemm tensor views, pad views and tile windows
+        const auto& gemm_tensor_views_tuple =
+            UniversalGemmKernel::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
+                as_ptr, bs_ptr, ds_ptr, c_ptr, kargs, k_size);
+
+        const auto& gemm_pad_views = UniversalGemmKernel::MakeGemmPadViews(gemm_tensor_views_tuple);
+        auto gemm_tile_windows =
+            UniversalGemmKernel::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
+
+        // Run GEMM cooperatively by whole workgroup.
+        const auto& as_block_window = gemm_tile_windows.at(UniversalGemmKernel::I0);
+        const auto& bs_block_window = gemm_tile_windows.at(UniversalGemmKernel::I1);
+        const auto& ds_block_window = gemm_tile_windows.at(UniversalGemmKernel::I2);
+
+        // Since num_loop can vary per WG and per iteration of the Stream-K while loop, we compute
+        // has_hot_loop and tail_num here. This is a similar pattern used by grouped GEMM. In this
+        // case, we call the GemmPipeline's operator() function that takes both has_hot_loop and
+        // tail_num.
+        const bool has_hot_loop   = GemmPipeline::BlockHasHotloop(num_loop);
+        const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
+
+        const auto& c_block_tile = GemmPipeline{}(as_block_window[UniversalGemmKernel::I0],
+                                                  bs_block_window[UniversalGemmKernel::I0],
+                                                  num_loop,
+                                                  has_hot_loop,
+                                                  tail_num,
+                                                  smem_ptr_0);
+
+        if(UseDefaultScheduler || (get_warp_id() == 0))
+        {
+            // Run Epilogue Pipeline
+            auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
+
+            EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
+        }
+    }
+
+    CK_TILE_HOST static bool IsSupportedArgument(const StreamKKernelArgs& kargs)
+    {
+        if(kargs.reduction_strategy == StreamKReductionStrategy::Reduction)
+        {
+            if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
+            {
+                CK_TILE_ERROR("CK Tile Stream-K only supports the atomic reduction strategy.");
+            }
+            return false;
+        }
         return UniversalGemmKernel::IsSupportedArgument(kargs);
     }

@@ -199,9 +262,81 @@ struct StreamKKernel
        kargs.workspace_ptr = workspace_ptr;
    }

-    // Temporary placeholder to support the Occupancy() static function.
-    // Since the Occupancy function uses kentry, this class must have an operator() function
-    CK_TILE_DEVICE void operator()(StreamKKernelArgs /*kargs*/) const {}
+    /// @brief Entry point for the Stream-K Kernel, performing the main Stream-K loop.
+    CK_TILE_DEVICE void operator()(StreamKKernelArgs kargs) const
+    {
+        // Allocate LDS
+        __shared__ char smem_ptr_0[UniversalGemmKernel::GetSmemSize()];
+
+        uint32_t block_idx = ck_tile::get_block_1d_id();
+
+        bool is_padding_block =
+            __builtin_amdgcn_readfirstlane(block_idx >= kargs.tile_partitioner.sk_num_blocks &&
+                                           block_idx < kargs.tile_partitioner.dp_start_block_idx);
+
+        // Padding blocks make it such that the DP blocks are aligned with the number of CUs; they
+        // should not partake in the GEMM
+        if(is_padding_block)
+            return;
+
+        // Determine the K offset of the first and final macro tile in the A and B tensors along the
+        // K dimension.
+        uint32_t iter_start, iter_end;
+        kargs.tile_partitioner.GetBlockItr(block_idx, iter_start, iter_end);
+
+        // Main Stream-K loop
+        while(true)
+        {
+            // Determine the number of macro tiles in A and B this WG is responsible for in the
+            // current C macro tile.
+            uint32_t current_iter_length = __builtin_amdgcn_readfirstlane(
+                kargs.tile_partitioner.GetCurrentIterLength(iter_start, iter_end));
+
+            // Determine the 1D tile_idx and the iter_offset for this WG.
+            // The tile_idx is the 1D macro tile index in the C tensor.
+            // The iter_offset is the starting macro tile index in the K dimension for the WG in the
+            // current iteration of the while loop.
+            uint32_t tile_idx, iter_offset;
+            kargs.tile_partitioner.GetTileIdxWithOffset(iter_start, tile_idx, iter_offset);
+
+            // Get the 2D tile index in the C tensor for this WG using the 1D index (i.e. tile_idx)
+            auto spatial_idx = kargs.tile_partitioner.GetOutputTileIndex(tile_idx);
+
+            // Get the offsets in A, B, C tensors.
+            index_t i_m = static_cast<index_t>(spatial_idx[UniversalGemmKernel::I0] *
                                               TilePartitioner::MPerBlock);
+            index_t i_n = static_cast<index_t>(spatial_idx[UniversalGemmKernel::I1] *
                                               TilePartitioner::NPerBlock);
+            index_t i_k = static_cast<index_t>(iter_offset) * TilePartitioner::KPerBlock;
+
+            // Determine the total size along the K dimension the WG is using in this iteration
+            // (used to construct tensor views).
+            index_t k_size = static_cast<index_t>(current_iter_length * TilePartitioner::KPerBlock);
+
+            // Update pointer offsets for A, B, and C.
+            const ADataType* a_ptr = static_cast<const ADataType*>(kargs.as_ptr[0]) + i_k;
+            const BDataType* b_ptr = static_cast<const BDataType*>(kargs.bs_ptr[0]) + i_k;
+            CDataType* c_ptr       = static_cast<CDataType*>(kargs.e_ptr);
+
+            // Run the GEMM pipeline and Epilogue.
+            RunGemm({a_ptr},
+                    {b_ptr},
+                    {/*ds_ptr*/},
+                    c_ptr,
+                    smem_ptr_0,
+                    kargs,
+                    current_iter_length,
+                    i_m,
+                    i_n,
+                    k_size);
+
+            // Prepare for next Stream-K loop iteration.
+            iter_start += current_iter_length;
+            if(iter_end <= iter_start)
+                break;
+            block_sync_lds();
+        }
+    }

    private:
    CK_TILE_HOST static int NumCU()
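To see why the forward traversal is safe with unsigned integers, here is a host-side simulation of the loop structure above (a sketch only: GetBlockItr is reduced to an assumed fixed [iter_start, iter_end) assignment, and the divmod helper to plain / and %):

    // Host-side sketch of the Stream-K while loop: iter_start only ever
    // increases and iter_end stays fixed, so no unsigned subtraction can wrap.
    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    int main()
    {
        const uint32_t k_iters_per_tile = 4;    // K iterations per C macro tile
        uint32_t iter_start = 6, iter_end = 14; // assumed WG assignment

        while(true)
        {
            uint32_t boundary = k_iters_per_tile - (iter_start % k_iters_per_tile);
            uint32_t current_iter_length =
                std::min(iter_start + boundary, iter_end) - iter_start;

            uint32_t tile_idx    = iter_start / k_iters_per_tile; // divmod quotient
            uint32_t iter_offset = iter_start % k_iters_per_tile; // divmod remainder
            std::printf("C tile %u: K iters [%u, %u)\n",
                        tile_idx, iter_offset, iter_offset + current_iter_length);

            iter_start += current_iter_length;
            if(iter_end <= iter_start)
                break;
        }
        return 0;
    }

With this assignment the WG finishes the tail of tile 1, does all of tile 2, and contributes the head of tile 3, matching the partial-tile sharing that the atomics (and later the reduction strategy) reconcile.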

include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp

Lines changed: 10 additions & 10 deletions
@@ -579,7 +579,7 @@ struct UniversalGemmKernel
                          const std::array<const void*, NumDTensor>& ds_ptr,
                          EDataType* e_ptr,
                          const KernelArgs& kargs,
-                         const SplitKBatchOffset& splitk_batch_offset)
+                         const index_t k_size)
    {
        static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");

@@ -591,7 +591,7 @@ struct UniversalGemmKernel
                {
                    return make_naive_tensor_view<address_space_enum::global>(
                        static_cast<const AiDataType*>(as_ptr[i]),
-                        make_tuple(kargs.M, splitk_batch_offset.splitted_k),
+                        make_tuple(kargs.M, k_size),
                        make_tuple(kargs.stride_As[i], 1),
                        number<GemmPipeline::GetVectorSizeA()>{},
                        number<1>{});
@@ -600,7 +600,7 @@ struct UniversalGemmKernel
                {
                    return make_naive_tensor_view<address_space_enum::global>(
                        static_cast<const AiDataType*>(as_ptr[i]),
-                        make_tuple(splitk_batch_offset.splitted_k, kargs.M),
+                        make_tuple(k_size, kargs.M),
                        make_tuple(kargs.stride_As[i], 1),
                        number<GemmPipeline::GetVectorSizeA()>{},
                        number<1>{});
@@ -617,7 +617,7 @@ struct UniversalGemmKernel
                if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
                {
                    constexpr index_t K1 = GemmPipeline::GetSmemPackB();
-                    const index_t K0     = splitk_batch_offset.splitted_k / K1;
+                    const index_t K0     = k_size / K1;
                    constexpr index_t VectorSizeB =
                        std::min(K1, GemmPipeline::GetVectorSizeB());
                    const auto b_k0_n_k1_desc =
@@ -638,7 +638,7 @@ struct UniversalGemmKernel
                {
                    return make_naive_tensor_view<address_space_enum::global>(
                        bs_ptr[i],
-                        make_tuple(splitk_batch_offset.splitted_k, kargs.N),
+                        make_tuple(k_size, kargs.N),
                        make_tuple(kargs.stride_Bs[i], 1),
                        number<GemmPipeline::GetVectorSizeB()>{},
                        number<1>{});
@@ -649,7 +649,7 @@ struct UniversalGemmKernel
                if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
                {
                    constexpr index_t K1 = GemmPipeline::GetSmemPackB();
-                    const index_t K0     = splitk_batch_offset.splitted_k / K1;
+                    const index_t K0     = k_size / K1;
                    constexpr index_t VectorSizeB =
                        std::min(K1, GemmPipeline::GetVectorSizeB());
                    const auto b_k0_n_k1_desc =
@@ -672,7 +672,7 @@ struct UniversalGemmKernel
                {
                    index_t kFlatK =
                        GemmPipeline::BlockGemmShape::flatKPerWarp *
-                        (splitk_batch_offset.splitted_k /
+                        (k_size /
                         TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}));
                    index_t kFlatN = kargs.N * kargs.K / kFlatK;

@@ -687,7 +687,7 @@ struct UniversalGemmKernel
                {
                    return make_naive_tensor_view<address_space_enum::global>(
                        bs_ptr[i],
-                        make_tuple(kargs.N, splitk_batch_offset.splitted_k),
+                        make_tuple(kargs.N, k_size),
                        make_tuple(kargs.stride_Bs[i], 1),
                        number<GemmPipeline::GetVectorSizeB()>{},
                        number<1>{});
@@ -962,7 +962,7 @@ struct UniversalGemmKernel
        // Create Gemm tensor views, pad views and tile windows
        const auto& gemm_tensor_views_tuple =
            MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
-                as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
+                as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset.splitted_k);

        const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
        auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
@@ -1018,7 +1018,7 @@ struct UniversalGemmKernel
        // Create Gemm tensor views, pad views and tile windows
        const auto& gemm_tensor_views_tuple =
            MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
-                as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
+                as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset.splitted_k);

        const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
        auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
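The practical effect of threading an int32_t through instead of the struct can be shown in isolation. This sketch uses hypothetical stand-in types, not the real CK Tile classes: any caller can now supply a K extent directly, whether it comes from a Split-K batch split or from a Stream-K iteration length.

    #include <cstdint>
    #include <cstdio>

    // Hypothetical stand-in for the real SplitKBatchOffset.
    struct SplitKBatchOffset
    {
        int32_t splitted_k; // K extent owned by this Split-K batch
    };

    // After the change: the view builder needs only the K extent itself.
    void make_views(int32_t k_size) { std::printf("views over K = %d\n", k_size); }

    int main()
    {
        SplitKBatchOffset off{512};
        make_views(off.splitted_k); // Split-K caller passes the member

        const int32_t k_per_block = 64, current_iter_length = 3;
        make_views(current_iter_length * k_per_block); // Stream-K caller, no struct needed
        return 0;
    }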

test/ck_tile/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ add_subdirectory(gemm_weight_preshuffle)
 add_subdirectory(batched_gemm)
 add_subdirectory(grouped_gemm)
 add_subdirectory(gemm_multi_d)
+add_subdirectory(gemm_streamk)
 add_subdirectory(data_type)
 add_subdirectory(container)
 add_subdirectory(elementwise)
Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+# Currently test_ck_tile_streamk is only built on gfx9
+if(GPU_TARGETS MATCHES "gfx9")
+    #TODO: support all arches
+    add_gtest_executable(test_ck_tile_streamk test_gemm_streamk.cpp)
+else()
+    message(DEBUG "Skipping test_ck_tile_streamk tests for current target")
+endif()
Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+// Copyright © Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "test_gemm_streamk_types.hpp"
+#include "test_gemm_streamk_util.hpp"
+#include "gtest/gtest.h"
+
+#define TEST_SUITE_NAME TestCkTileStreamK
+
+TYPED_TEST_SUITE(TestCkTileStreamK, KernelTypesStreamK);
+
+#include "test_gemm_streamk_cases.inc"
+
+#undef TEST_SUITE_NAME
