diff --git a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp index 92ae6411a5..a891d4df55 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp @@ -646,16 +646,13 @@ struct StreamKTilePartitioner * @brief Get length of loop iterations for stream-k loop */ CK_TILE_DEVICE uint32_t GetCurrentIterLength(uint32_t iter_start, - uint32_t iter_end, - uint32_t total_iter_length) const noexcept + uint32_t iter_end) const noexcept { - uint32_t iter_length_mod, iter_length_quo /*unused*/; - k_iters_per_tile.divmod(iter_end, iter_length_quo, iter_length_mod); - uint32_t total_iter_length_val = static_cast(total_iter_length); - uint32_t current_iter_length = - min(iter_length_mod == 0 ? (iter_end - iter_start) : iter_length_mod, - total_iter_length_val); - return current_iter_length; + // A WG's iter_end is either in the current C macro tile or not. + // If it is not, then the macro tile boundary is where the WG must stop. + uint32_t distance_to_tile_boundary = + k_iters_per_tile.get() - (iter_start % k_iters_per_tile.get()); + return min(iter_start + distance_to_tile_boundary, iter_end) - iter_start; } /** @@ -672,9 +669,7 @@ struct StreamKTilePartitioner CK_TILE_DEVICE void GetTileIdxWithOffset(uint32_t iter, uint32_t& tile_idx, uint32_t& iter_offset) const noexcept { - uint32_t tile_idx_val = static_cast(tile_idx); - uint32_t iter_offset_val = static_cast(iter_offset); - k_iters_per_tile.divmod(iter, tile_idx_val, iter_offset_val); + k_iters_per_tile.divmod(iter, tile_idx, iter_offset); } /** diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp index 704d0d01ee..dda38bbc47 100644 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp @@ -374,7 +374,7 @@ struct GroupedGemmKernel // Create Gemm tensor views, pad views and tile windows const auto& gemm_tensor_views_tuple = Base::template MakeGemmTensorViews( - {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset); + {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset.splitted_k); const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows = @@ -436,7 +436,7 @@ struct GroupedGemmKernel // Create Gemm tensor views, pad views and tile windows const auto& gemm_tensor_views_tuple = Base::template MakeGemmTensorViews( - {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset); + {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset.splitted_k); const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows = diff --git a/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp index 77c431e49c..5df1f092d7 100644 --- a/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp @@ -141,11 +141,17 @@ struct StreamKKernel return UniversalGemmKernel::BlockSize(); } - CK_TILE_HOST static StreamKKernelArgs MakeKernelArgs(const StreamKHostArgs& host_args) + /// @brief Constructs kernel arguments for the Stream-K kernel. + /// @param host_args Stream-K host arguments. + /// @param num_cu Number of compute units (CUs). The default is the number of CUs on the device. 
+ /// The caller may select their own to assist with test reproducibility, etc. + /// @param occupancy The maximum number of active blocks per CU for this kernel. The caller may + /// select their own to assist with test reproducibility, etc. + /// @return The kernel arguments for Stream-K. + CK_TILE_HOST static StreamKKernelArgs MakeKernelArgs(const StreamKHostArgs& host_args, + int num_cu = NumCU(), + int occupancy = Occupancy()) { - uint32_t occupancy = static_cast(Occupancy()); - uint32_t num_cu = static_cast(NumCU()); - return StreamKKernelArgs{{host_args.as_ptr, host_args.bs_ptr, host_args.ds_ptr, @@ -166,14 +172,71 @@ struct StreamKKernel TilePartitioner{static_cast(host_args.M), static_cast(host_args.N), static_cast(host_args.K), - num_cu, - occupancy, + static_cast(num_cu), + static_cast(occupancy), host_args.num_sk_blocks}}; } - CK_TILE_HOST static bool - IsSupportedArgument(const typename UniversalGemmKernel::KernelArgs& kargs) + template + CK_TILE_DEVICE static void + RunGemm(const std::array& as_ptr, + const std::array& bs_ptr, + const std::array& ds_ptr, + CDataType* c_ptr, + void* smem_ptr_0, + const typename UniversalGemmKernel::KernelArgs& kargs, + const index_t num_loop, + const index_t block_idx_m, + const index_t block_idx_n, + const index_t k_size) { + // Create Gemm tensor views, pad views and tile windows + const auto& gemm_tensor_views_tuple = + UniversalGemmKernel::template MakeGemmTensorViews( + as_ptr, bs_ptr, ds_ptr, c_ptr, kargs, k_size); + + const auto& gemm_pad_views = UniversalGemmKernel::MakeGemmPadViews(gemm_tensor_views_tuple); + auto gemm_tile_windows = + UniversalGemmKernel::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + + // Run GEMM cooperatively by whole workgroup. + const auto& as_block_window = gemm_tile_windows.at(UniversalGemmKernel::I0); + const auto& bs_block_window = gemm_tile_windows.at(UniversalGemmKernel::I1); + const auto& ds_block_window = gemm_tile_windows.at(UniversalGemmKernel::I2); + + // Since num_loop can vary per WG and per iteration of the Stream-K while loop, we compute + // has_hot_loop and tail_num here. This is a similar pattern used by grouped GEMM. In this + // case, we call the GemmPipeline's operator() function that takes both has_hot_loop and + // tail_num. + const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop); + const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); + + const auto& c_block_tile = GemmPipeline{}(as_block_window[UniversalGemmKernel::I0], + bs_block_window[UniversalGemmKernel::I0], + num_loop, + has_hot_loop, + tail_num, + smem_ptr_0); + + if(UseDefaultScheduler || (get_warp_id() == 0)) + { + // Run Epilogue Pipeline + auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3); + + EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0); + } + } + + CK_TILE_HOST static bool IsSupportedArgument(const StreamKKernelArgs& kargs) + { + if(kargs.reduction_strategy == StreamKReductionStrategy::Reduction) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("CK Tile Stream-K only supports the atomic reduction strategy."); + } + return false; + } return UniversalGemmKernel::IsSupportedArgument(kargs); } @@ -199,9 +262,81 @@ struct StreamKKernel kargs.workspace_ptr = workspace_ptr; } - // Temporary placeholder to support the Occupancy() static function. 
- // Since the Occupancy function uses kentry, this class must have an operator() function - CK_TILE_DEVICE void operator()(StreamKKernelArgs /*kargs*/) const {} + /// @brief Entry point for the Stream-K Kernel, performing the main Stream-K loop. + CK_TILE_DEVICE void operator()(StreamKKernelArgs kargs) const + { + // Allocate LDS + __shared__ char smem_ptr_0[UniversalGemmKernel::GetSmemSize()]; + + uint32_t block_idx = ck_tile::get_block_1d_id(); + + bool is_padding_block = + __builtin_amdgcn_readfirstlane(block_idx >= kargs.tile_partitioner.sk_num_blocks && + block_idx < kargs.tile_partitioner.dp_start_block_idx); + + // Padding blocks make it such that the DP blocks are aligned with the number of CUs; they + // should not partake in the GEMM + if(is_padding_block) + return; + + // Determine the K offset of the first and final macro tile in the A and B tensors along the + // K dimension. + uint32_t iter_start, iter_end; + kargs.tile_partitioner.GetBlockItr(block_idx, iter_start, iter_end); + + // Main Stream-K loop + while(true) + { + // Determine the number of macro tiles in A and B this WG is resposible for in the + // current C macro tile. + uint32_t current_iter_length = __builtin_amdgcn_readfirstlane( + kargs.tile_partitioner.GetCurrentIterLength(iter_start, iter_end)); + + // Determine the 1D tile_idx and the iter_offset for this WG. + // The tile_idx is the 1D macro tile index in the C tensor. + // The iter_offset is the starting macro tile index in the K dimension for the WG in the + // current iteration of the while loop. + uint32_t tile_idx, iter_offset; + kargs.tile_partitioner.GetTileIdxWithOffset(iter_start, tile_idx, iter_offset); + + // Get the 2D tile index in the C tensor for this WG using the 1D index (i.e. tile_idx) + auto spatial_idx = kargs.tile_partitioner.GetOutputTileIndex(tile_idx); + + // Get the offsets in A, B, C tensors. + index_t i_m = static_cast(spatial_idx[UniversalGemmKernel::I0] * + TilePartitioner::MPerBlock); + index_t i_n = static_cast(spatial_idx[UniversalGemmKernel::I1] * + TilePartitioner::NPerBlock); + index_t i_k = static_cast(iter_offset) * TilePartitioner::KPerBlock; + + // Determine the total size along the K dimension the WG is using in this iteration + // (used to construct tensor views). + index_t k_size = static_cast(current_iter_length * TilePartitioner::KPerBlock); + + // Update pointer offsets for A, B, and C. + const ADataType* a_ptr = static_cast(kargs.as_ptr[0]) + i_k; + const BDataType* b_ptr = static_cast(kargs.bs_ptr[0]) + i_k; + CDataType* c_ptr = static_cast(kargs.e_ptr); + + // Run the GEMM pipeline and Epilogue. + RunGemm({a_ptr}, + {b_ptr}, + {/*ds_ptr*/}, + c_ptr, + smem_ptr_0, + kargs, + current_iter_length, + i_m, + i_n, + k_size); + + // Prepare for next Stream-K loop iteration. 
+ iter_start += current_iter_length; + if(iter_end <= iter_start) + break; + block_sync_lds(); + } + } private: CK_TILE_HOST static int NumCU() diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp index 8117d65758..cfba8b6c9d 100644 --- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp @@ -579,7 +579,7 @@ struct UniversalGemmKernel const std::array& ds_ptr, EDataType* e_ptr, const KernelArgs& kargs, - const SplitKBatchOffset& splitk_batch_offset) + const index_t k_size) { static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!"); @@ -591,7 +591,7 @@ struct UniversalGemmKernel { return make_naive_tensor_view( static_cast(as_ptr[i]), - make_tuple(kargs.M, splitk_batch_offset.splitted_k), + make_tuple(kargs.M, k_size), make_tuple(kargs.stride_As[i], 1), number{}, number<1>{}); @@ -600,7 +600,7 @@ struct UniversalGemmKernel { return make_naive_tensor_view( static_cast(as_ptr[i]), - make_tuple(splitk_batch_offset.splitted_k, kargs.M), + make_tuple(k_size, kargs.M), make_tuple(kargs.stride_As[i], 1), number{}, number<1>{}); @@ -617,7 +617,7 @@ struct UniversalGemmKernel if constexpr(TilePartitioner::BlockGemmShape::PermuteB) { constexpr index_t K1 = GemmPipeline::GetSmemPackB(); - const index_t K0 = splitk_batch_offset.splitted_k / K1; + const index_t K0 = k_size / K1; constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB()); const auto b_k0_n_k1_desc = @@ -638,7 +638,7 @@ struct UniversalGemmKernel { return make_naive_tensor_view( bs_ptr[i], - make_tuple(splitk_batch_offset.splitted_k, kargs.N), + make_tuple(k_size, kargs.N), make_tuple(kargs.stride_Bs[i], 1), number{}, number<1>{}); @@ -649,7 +649,7 @@ struct UniversalGemmKernel if constexpr(TilePartitioner::BlockGemmShape::PermuteB) { constexpr index_t K1 = GemmPipeline::GetSmemPackB(); - const index_t K0 = splitk_batch_offset.splitted_k / K1; + const index_t K0 = k_size / K1; constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB()); const auto b_k0_n_k1_desc = @@ -672,7 +672,7 @@ struct UniversalGemmKernel { index_t kFlatK = GemmPipeline::BlockGemmShape::flatKPerWarp * - (splitk_batch_offset.splitted_k / + (k_size / TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{})); index_t kFlatN = kargs.N * kargs.K / kFlatK; @@ -687,7 +687,7 @@ struct UniversalGemmKernel { return make_naive_tensor_view( bs_ptr[i], - make_tuple(kargs.N, splitk_batch_offset.splitted_k), + make_tuple(kargs.N, k_size), make_tuple(kargs.stride_Bs[i], 1), number{}, number<1>{}); @@ -962,7 +962,7 @@ struct UniversalGemmKernel // Create Gemm tensor views, pad views and tile windows const auto& gemm_tensor_views_tuple = MakeGemmTensorViews( - as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset); + as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset.splitted_k); const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); @@ -1018,7 +1018,7 @@ struct UniversalGemmKernel // Create Gemm tensor views, pad views and tile windows const auto& gemm_tensor_views_tuple = MakeGemmTensorViews( - as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset); + as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset.splitted_k); const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, 
block_idx_m, block_idx_n); diff --git a/test/ck_tile/CMakeLists.txt b/test/ck_tile/CMakeLists.txt index 993df2ec40..32230bbce2 100644 --- a/test/ck_tile/CMakeLists.txt +++ b/test/ck_tile/CMakeLists.txt @@ -4,6 +4,7 @@ add_subdirectory(gemm_weight_preshuffle) add_subdirectory(batched_gemm) add_subdirectory(grouped_gemm) add_subdirectory(gemm_multi_d) +add_subdirectory(gemm_streamk) add_subdirectory(data_type) add_subdirectory(container) add_subdirectory(elementwise) diff --git a/test/ck_tile/gemm_streamk/CMakeLists.txt b/test/ck_tile/gemm_streamk/CMakeLists.txt new file mode 100644 index 0000000000..e00874ba07 --- /dev/null +++ b/test/ck_tile/gemm_streamk/CMakeLists.txt @@ -0,0 +1,7 @@ +# Currently test_ck_tile_streamk is only built on gfx9 +if(GPU_TARGETS MATCHES "gfx9") + #TODO: support all arches + add_gtest_executable(test_ck_tile_streamk test_gemm_streamk.cpp) +else() + message(DEBUG "Skipping test_ck_tile_streamk tests for current target") +endif() diff --git a/test/ck_tile/gemm_streamk/test_gemm_streamk.cpp b/test/ck_tile/gemm_streamk/test_gemm_streamk.cpp new file mode 100644 index 0000000000..99c3fb397f --- /dev/null +++ b/test/ck_tile/gemm_streamk/test_gemm_streamk.cpp @@ -0,0 +1,14 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_gemm_streamk_types.hpp" +#include "test_gemm_streamk_util.hpp" +#include "gtest/gtest.h" + +#define TEST_SUITE_NAME TestCkTileStreamK + +TYPED_TEST_SUITE(TestCkTileStreamK, KernelTypesStreamK); + +#include "test_gemm_streamk_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/test_gemm_streamk_cases.inc b/test/ck_tile/gemm_streamk/test_gemm_streamk_cases.inc new file mode 100644 index 0000000000..1db7ef0fb0 --- /dev/null +++ b/test/ck_tile/gemm_streamk/test_gemm_streamk_cases.inc @@ -0,0 +1,118 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +TYPED_TEST(TEST_SUITE_NAME, StreamK_M256_N256_K256_DP) +{ + + ck_tile::index_t M = 256; + ck_tile::index_t N = 256; + ck_tile::index_t K = 256; + uint32_t num_sk_blocks = 0; + + this->Run(M, N, K, num_sk_blocks); +} + +TYPED_TEST(TEST_SUITE_NAME, StreamK_M256_N256_K256_SKBlocks4) +{ + + ck_tile::index_t M = 256; + ck_tile::index_t N = 256; + ck_tile::index_t K = 256; + uint32_t num_sk_blocks = 4; + + this->Run(M, N, K, num_sk_blocks); +} + +// TODO: Renable this test once reduction is implemented +TYPED_TEST(TEST_SUITE_NAME, StreamK_M256_N256_K256_SKBlocks12) +{ + GTEST_SKIP() << "Skipping this test: There are precision issues with atomics due to >=3 WGs " + "contributing to each macro tile in C"; + + ck_tile::index_t M = 256; + ck_tile::index_t N = 256; + ck_tile::index_t K = 256; + uint32_t num_sk_blocks = 12; + + this->Run(M, N, K, num_sk_blocks); +} + +TYPED_TEST(TEST_SUITE_NAME, StreamK_M256_N256_K256_SKBlocks8) +{ + + ck_tile::index_t M = 256; + ck_tile::index_t N = 256; + ck_tile::index_t K = 256; + uint32_t num_sk_blocks = 8; + + this->Run(M, N, K, num_sk_blocks); +} + +TYPED_TEST(TEST_SUITE_NAME, StreamK_M512_N512_K512_DP) +{ + + ck_tile::index_t M = 512; + ck_tile::index_t N = 512; + ck_tile::index_t K = 512; + uint32_t num_sk_blocks = 0; + + this->Run(M, N, K, num_sk_blocks); +} + +TYPED_TEST(TEST_SUITE_NAME, StreamK_M512_N512_K512_SKBlocks16) +{ + + ck_tile::index_t M = 512; + ck_tile::index_t N = 512; + ck_tile::index_t K = 512; + uint32_t num_sk_blocks = 16; + + this->Run(M, N, K, num_sk_blocks); +} + +TYPED_TEST(TEST_SUITE_NAME, StreamK_M512_N512_K512_SKBlocks8) +{ + + ck_tile::index_t M = 512; + ck_tile::index_t N = 512; + ck_tile::index_t K = 512; + uint32_t num_sk_blocks = 8; + + this->Run(M, N, K, num_sk_blocks); +} + +TYPED_TEST(TEST_SUITE_NAME, StreamK_M3840_N4096_K4096_DP) +{ + + ck_tile::index_t M = 3840; + ck_tile::index_t N = 4096; + ck_tile::index_t K = 4096; + uint32_t num_sk_blocks = 0; + + this->Run(M, N, K, num_sk_blocks); +} + +TYPED_TEST(TEST_SUITE_NAME, StreamK_M3840_N4096_K4096_SKBlocks64) +{ + + ck_tile::index_t M = 3840; + ck_tile::index_t N = 4096; + ck_tile::index_t K = 4096; + uint32_t num_sk_blocks = 64; + + this->Run(M, N, K, num_sk_blocks); +} + +TYPED_TEST(TEST_SUITE_NAME, StreamK_Unsupported_Reduction) +{ + + ck_tile::index_t M = 3840; + ck_tile::index_t N = 4096; + ck_tile::index_t K = 4096; + uint32_t num_sk_blocks = 64; + + EXPECT_THROW(this->Run(M, N, K, num_sk_blocks, ck_tile::StreamKReductionStrategy::Reduction), + std::runtime_error); +} diff --git a/test/ck_tile/gemm_streamk/test_gemm_streamk_types.hpp b/test/ck_tile/gemm_streamk/test_gemm_streamk_types.hpp new file mode 100644 index 0000000000..399f3f11e8 --- /dev/null +++ b/test/ck_tile/gemm_streamk/test_gemm_streamk_types.hpp @@ -0,0 +1,25 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include +#include + +#include "gtest/gtest.h" + +#include "ck_tile/host.hpp" + +using F16 = ck_tile::half_t; +using F32 = float; +using BF16 = ck_tile::bf16_t; + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +// clang-format off +using KernelTypesStreamK = ::testing::Types< +// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType + std::tuple< Row, Col, Row, F16, F16, F32, F16>, + std::tuple< Row, Col, Row, BF16, BF16, F32, BF16> +>; + +// clang-format on diff --git a/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp b/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp new file mode 100644 index 0000000000..b8a55b024d --- /dev/null +++ b/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp @@ -0,0 +1,282 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include +#include +#include +#include + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" + +template +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + + // The logic below may need to become more advanced once bugs in Stream-K Tile Partitioner are + // resolved. Because the number of WGs contributing to a macro tile in C may not be the same for + // all macro tiles in C. + + // Calculate error due to more than 1 WG contributing to the same macro tile in C + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + +template +class TestCkTileStreamK : public ::testing::Test +{ + protected: + using ALayout = std::tuple_element_t<0, Tuple>; + using BLayout = std::tuple_element_t<1, Tuple>; + using CLayout = std::tuple_element_t<2, Tuple>; + using ADataType = std::tuple_element_t<3, Tuple>; + using BDataType = std::tuple_element_t<4, Tuple>; + using AccDataType = std::tuple_element_t<5, Tuple>; + using CDataType = std::tuple_element_t<6, Tuple>; + using DsLayout = ck_tile::tuple<>; + using DsDataType = ck_tile::tuple<>; + + template + void invoke_streamk(const ck_tile::StreamKHostArgs& args, + const ck_tile::stream_config& s, + int num_cu, + int occupancy) + { + + constexpr ck_tile::index_t M_Tile = 128; + constexpr ck_tile::index_t N_Tile = 128; + constexpr ck_tile::index_t K_Tile = 32; + + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 16; + + constexpr bool kPadM = PadM; + constexpr bool kPadN = PadN; + constexpr bool kPadK = PadK; + constexpr bool preshuffle = Preshuffle; + + constexpr bool DoubleSmemBuffer = false; + constexpr int kBlockPerCu = 1; + constexpr bool StructuredSparsity = false; + constexpr bool NumWaveGroup = 1; + + using GemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + 
ck_tile::sequence>; + + using TilePartitioner = ck_tile::StreamKTilePartitioner; + + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + + const auto Run = [&](const auto memory_operation_) { + constexpr auto memory_operation = memory_operation_.value; + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + + // We create the GEMM pipeline without specifying has_hot_loop or tail_num. + // This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K + // while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K + // Kernel's RunGemm function. This is a similar pattern used by grouped GEMM. + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + // For initial testing, we will just test with one pipeline. + // More extensive testing is coming later and will test other pipelines. + using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + M_Warp, + N_Warp, + M_Warp_Tile, + N_Warp_Tile, + K_Warp_Tile, + UniversalGemmProblem::TransposeC, + memory_operation>>; + + using Kernel = ck_tile::StreamKKernel; + + auto kargs = Kernel::MakeKernelArgs(args, num_cu, occupancy); + + if(!Kernel::IsSupportedArgument(kargs)) + { + EXPECT_TRUE(false); + } + + dim3 grid_dims = Kernel::GridSize(kargs.tile_partitioner); + dim3 block_dims = Kernel::BlockSize(); + + ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grid_dims, block_dims, 0, kargs)); + }; + + Run(ck_tile::integral_constant{}); + } + + public: + // Since Stream-K is build on gfx9, the lower bound for CUs is 104. Thus, we default num_cu to + // 104 and occupancy to 1 to ensure tests are reproducible on different architectures. 
+ void Run(ck_tile::index_t M, + ck_tile::index_t N, + ck_tile::index_t K, + uint32_t num_sk_blocks = 0xffffffff, + ck_tile::StreamKReductionStrategy reduction_strategy = + ck_tile::StreamKReductionStrategy::Atomic, + int occupancy = 1, + int num_cu = 104, + ck_tile::index_t stride_A = 0, + ck_tile::index_t stride_B = 0, + ck_tile::index_t stride_C = 0) + { + + using namespace ck_tile::literals; + + if(reduction_strategy == ck_tile::StreamKReductionStrategy::Reduction) + { + throw std::runtime_error("Reduction Strategy is current unsupported!\n"); + } + + auto f_host_tensor_descriptor = [](std::size_t row, + std::size_t col, + std::size_t stride, + auto layout) { + if constexpr(std::is_same_v) + { + return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(stride == 0) + { + if constexpr(std::is_same_v) + { + return col; + } + else + { + return row; + } + } + else + return stride; + }; + + stride_A = f_get_default_stride(M, K, stride_A, ALayout{}); + stride_B = f_get_default_stride(K, N, stride_B, BLayout{}); + stride_C = f_get_default_stride(M, N, stride_C, CLayout{}); + + ck_tile::HostTensor a_m_k(f_host_tensor_descriptor(M, K, stride_A, ALayout{})); + ck_tile::HostTensor b_k_n(f_host_tensor_descriptor(K, N, stride_B, BLayout{})); + ck_tile::HostTensor c_m_n_dev_result( + f_host_tensor_descriptor(M, N, stride_C, CLayout{})); + + ck_tile::FillUniformDistributionIntegerValue{-5, 5, /*seed*/ 11939}(a_m_k); + ck_tile::FillUniformDistributionIntegerValue{-5, 5, /*seed*/ 11940}(b_k_n); + + ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); + ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); + + a_m_k_dev_buf.ToDevice(a_m_k.data()); + b_k_n_dev_buf.ToDevice(b_k_n.data()); + c_m_n_dev_buf.SetZero(); + c_m_n_dev_result.SetZero(); + + ck_tile::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(), + b_k_n_dev_buf.GetDeviceBuffer(), + c_m_n_dev_buf.GetDeviceBuffer(), + M, + N, + K, + stride_A, + stride_B, + stride_C, + reduction_strategy, + num_sk_blocks}; + + invoke_streamk( + args, ck_tile::stream_config{nullptr, false, 0, 0, 1}, num_cu, occupancy); + + c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); + + ck_tile::HostTensor c_m_n_host_ref( + f_host_tensor_descriptor(M, N, stride_C, CLayout{})); + c_m_n_host_ref.SetZero(); + + ck_tile::reference_gemm( + a_m_k, b_k_n, c_m_n_host_ref); + + const float max_accumulated_value = + *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); + const auto rtol_atol = calculate_rtol_atol( + K, /*kbatch*/ 1, max_accumulated_value); + + bool pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + EXPECT_TRUE(pass); + }; +};
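
Note on the partitioning logic: the two tile-partitioner helpers changed above (GetCurrentIterLength and GetTileIdxWithOffset) drive the main Stream-K while loop in the kernel's operator(). Each work-group walks its [iter_start, iter_end) range of K iterations and splits it at C macro-tile boundaries. The stand-alone host sketch below (not part of the patch) mimics that walk, using plain integer division and modulo in place of ck_tile's magic-division divmod helper; the sample values for k_iters_per_tile, iter_start and iter_end are made up purely for illustration.

// Minimal host-side sketch of the Stream-K iteration split (assumed sample values).
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main()
{
    const uint32_t k_iters_per_tile = 8;  // K iterations needed to finish one C macro tile
    const uint32_t iter_end         = 21; // this WG's exclusive end iteration
    uint32_t iter_start             = 5;  // this WG starts partway through tile 0

    while(true)
    {
        // Distance from iter_start to the next C macro-tile boundary
        // (mirrors GetCurrentIterLength in the tile partitioner).
        const uint32_t to_boundary = k_iters_per_tile - (iter_start % k_iters_per_tile);
        // Stop at the tile boundary or at this WG's own iter_end, whichever comes first.
        const uint32_t current_iter_length =
            std::min(iter_start + to_boundary, iter_end) - iter_start;

        // 1D C-tile index and starting K iteration inside that tile
        // (mirrors GetTileIdxWithOffset, which uses divmod on the device).
        const uint32_t tile_idx    = iter_start / k_iters_per_tile;
        const uint32_t iter_offset = iter_start % k_iters_per_tile;

        std::printf("tile %u: %u K iterations starting at offset %u\n",
                    tile_idx,
                    current_iter_length,
                    iter_offset);

        // Advance to the next segment, exactly as the kernel's main loop does.
        iter_start += current_iter_length;
        if(iter_end <= iter_start)
            break;
    }
    return 0;
}

For these sample values the sketch prints three segments (3, 8 and 5 K iterations), illustrating how a Stream-K work-group finishes a partial tile, covers a full tile, and then starts the next tile before handing off at iter_end.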