diff --git a/example/ck_tile/18_flatmm/CMakeLists.txt b/example/ck_tile/18_flatmm/CMakeLists.txt index 1641549c98..d2ad442248 100644 --- a/example/ck_tile/18_flatmm/CMakeLists.txt +++ b/example/ck_tile/18_flatmm/CMakeLists.txt @@ -14,6 +14,7 @@ if(has_supported_gpu) add_executable(tile_example_moe_flatmm EXCLUDE_FROM_ALL moe_flatmm.cpp) add_executable(tile_example_a16w4_moe_flatmm EXCLUDE_FROM_ALL mixed_prec/a16w4_moe_flatmm.cpp) add_executable(tile_example_grouped_flatmm EXCLUDE_FROM_ALL grouped_flatmm.cpp) + add_executable(tile_example_mx_flatmm EXCLUDE_FROM_ALL mxgemm/mx_flatmm.cpp) # TODO: 950 only set(EXAMPLE_FLATMM_COMPILE_OPTIONS) set(EXAMPLE_MOE_FLATMM_COMPILE_OPTIONS) @@ -27,6 +28,6 @@ if(has_supported_gpu) target_compile_options(tile_example_moe_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) target_compile_options(tile_example_a16w4_moe_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) target_compile_options(tile_example_grouped_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) - + target_compile_options(tile_example_mx_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) # TODO: 950 only endif() diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp new file mode 100644 index 0000000000..0474e4b1d6 --- /dev/null +++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp @@ -0,0 +1,506 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include +#include +#include +#include +#include +#include + +#include "ck_tile/host.hpp" +#include "mx_flatmm.hpp" + +template +static constexpr inline auto is_row_major(Layout layout_) +{ + return ck_tile::bool_constant, + ck_tile::tensor_layout::gemm::RowMajor>>{}; +} + +template +float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& args, + const ck_tile::stream_config& s) +{ + using CodegenFlatmmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence>; + + using TilePartitioner = + ck_tile::GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + + using CodegenGemmTraits = ck_tile::TileGemmUniversalTraits; + + using ComputeDataType = ADataType; + static_assert(sizeof(ComputeDataType) >= sizeof(BDataType), + "mixed_prec_flatmm requires ADataType is a wider type than BDataType"); + + using GemmPipelineProblem = ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV1; + + const ck_tile::index_t k_grain = args.k_batch * FlatmmConfig::K_Tile; + const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * FlatmmConfig::K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = FlatmmConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; + + constexpr int BlockedXDLN_PerWarp = 2; // determined by scale shuffle pattern + + using CodegenPipelineProblem = ck_tile::MXFlatmmPipelineProblem; + + using CodegenMXFlatmmPipeline = + ck_tile::MXF4FlatmmPipelineAGmemBGmemCRegV1; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + 
ck_tile::CShuffleEpilogueProblem>; + + using Kernel = + ck_tile::MXFlatmmKernel; + + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(kargs); + constexpr dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args:" << CodegenFlatmmShape::GetName() << "\n" + << "Shape: " << CodegenFlatmmShape::GetName() << "\n" + << "problem: " << CodegenPipelineProblem::GetName() << "\n" + << "pipeline: " << CodegenMXFlatmmPipeline::GetName() << "\n" + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + + // Declare rotating_mem_ptr here so it stays in scope until it is needed + std::unique_ptr> rotating_mem_ptr; + std::function preprocess; + + auto clear_gemm_output = [&]() { + if(args.k_batch > 1) + hipGetErrorString(hipMemsetAsync( + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); + }; + + if(s.flush_cache_) + { + std::cout << "Flushing cache..." << std::endl; + constexpr ck_tile::index_t APackedSize = ck_tile::numeric_traits::PackedSize; + constexpr ck_tile::index_t BPackedSize = ck_tile::numeric_traits::PackedSize; + + ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + args.M, args.K, args.stride_A, is_row_major(ALayout{}))); + ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + + auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize; + auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize; + + rotating_mem_ptr = std::make_unique>( + kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer); + rotating_mem_ptr->Print(); + + preprocess = [&]() { + ck_tile::flush_icache(); + rotating_mem_ptr->Next(); + clear_gemm_output(); + }; + } + else + { + preprocess = clear_gemm_output; + } + + ave_time = ck_tile::launch_kernel_time_mask( + s, + preprocess, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + return ave_time; + }; + + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(args.k_batch == 1) + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + else + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + }; + BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + return ave_time; +} + +template +float invoke_mx_flatmm(ck_tile::DeviceMem& a_dev_buf, + ck_tile::DeviceMem& b_shuffle_dev_buf, + ck_tile::DeviceMem& c_dev_buf, + ck_tile::index_t M, + ck_tile::index_t N, + ck_tile::index_t K, + ck_tile::index_t stride_A, + ck_tile::index_t stride_B, + ck_tile::index_t stride_C, + ck_tile::index_t kbatch, + ScaleA scale_a, + ScaleB scale_b, + int n_warmup, + int n_repeat) +{ + ck_tile::ScaleFlatmmHostArgs args = {a_dev_buf.GetDeviceBuffer(), + b_shuffle_dev_buf.GetDeviceBuffer(), + {}, + c_dev_buf.GetDeviceBuffer(), + kbatch, + M, + N, + K, + stride_A, + stride_B, + {}, + stride_C, + scale_a, + scale_b}; + + float ave_time = mx_flatmm_calc( + args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50}); + + constexpr int APackedSize = ck_tile::numeric_traits::PackedSize; + constexpr int BPackedSize = ck_tile::numeric_traits::PackedSize; + + std::size_t flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N * 
K / 32; + std::size_t num_byte = sizeof(ADataType) * M * K / APackedSize + + sizeof(BDataType) * N * K / BPackedSize + sizeof(CDataType) * M * N + + sizeof(ck_tile::e8m0_t) * M * K / 32 + + sizeof(ck_tile::e8m0_t) * N * K / 32; + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << "Run MXFP4_Flatmm kernel " // + << " M =" << M << " N =" << N << " K =" << K << " StrideA =" << stride_A + << " StrideB =" << stride_B << " StrideC =" << stride_C << " : " << ave_time + << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; + + return ave_time; +} + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "32", "m dimension") + .insert("n", "128", "n dimension") + .insert("k", "256", "k dimension") + .insert("a_layout", "R", "A tensor data layout - Row by default") + .insert("b_layout", "C", "B tensor data layout - Row by default") + .insert("c_layout", "R", "C tensor data layout - Row by default") + .insert("stride_a", "0", "Tensor A stride") + .insert("stride_b", "0", "Tensor B stride") + .insert("stride_c", "0", "Tensor C stride") + .insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") + .insert( + "mx_prec", "fp4xfp4", "data type for activation and weight, support: fp6xfp6, fp8xfp8") + .insert("warmup", "50", "number of iterations before benchmark the kernel") + .insert("repeat", "100", "number of iterations to benchmark the kernel") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("split_k", "1", "splitK value") + .insert("init", "0", "0:random, 1:constant(1)") + .insert("persistent", "0", "0: no persistent, 1: persistent kernel") + .insert("warp_tile", + "0", + "0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)"); + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +void preShuffleWeight(const IterSrc src, IterDst dst, int N, int K) +{ + int KPack = 16; + int NLane = FlatmmConfig::N_Warp_Tile; + int KLane = 64 / NLane; + int K_pk = K / 2; + int K0 = K_pk / (KLane * KPack); + // K -> K0 KLane KPack + // N -> N0 NLane + // N, K -> N0 K0 KLane NLane KPack + int tempk; + for(int n = 0; n < N; ++n) + { + for(int k = 0; k < K_pk; ++k) + { + int n0 = n / NLane; + int n1 = n % NLane; + + int k0 = k / (KLane * KPack); + tempk = k % (KLane * KPack); + int k1 = tempk / KPack; + int k2 = tempk % KPack; + + int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane + + k1 * KPack * NLane + n1 * KPack + k2; + + dst[outputIndex] = src[n * K_pk + k]; + } + } +} + +template +auto preShuffleScale(Src& src) +{ + using dtype = typename Src::Data::value_type; + auto src_lengths = src.get_lengths(); + const auto MN = KLast ? src_lengths[0] : src_lengths[1]; + const auto K = KLast ? 
src_lengths[1] : src_lengths[0]; + + size_t MNXdlPack = 2; + size_t KXdlPack = 2; + size_t XdlMNThread = FlatmmConfig::N_Warp_Tile; // 16 + size_t XdlKThread = 64 / XdlMNThread; + + const auto MN_Paded = ck_tile::integer_least_multiple(MN, XdlMNThread * MNXdlPack); + + ck_tile::HostTensor shuffled(ck_tile::HostTensorDescriptor({MN_Paded * K}, {1})); + + size_t K0 = K / KXdlPack / XdlKThread; // KRepeat + + // The 4 16x128 building blocks will be packed into 1 32x256 for F4 + // The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4 + + // unfold the MN32xK(256/32) scale buffer + // 4 16 2 2 + // To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack + // Then, MNRepeat->KRepeat + + for(size_t n = 0; n < MN_Paded; ++n) + { + for(size_t k = 0; k < K; ++k) + { + auto n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat + auto tempn = n % (XdlMNThread * MNXdlPack); + auto n1 = tempn % XdlMNThread; // i XdlMNThread + auto n2 = tempn / XdlMNThread; // i MNXdlPack + + auto k0 = k / (XdlKThread * KXdlPack); // i KRepeat + auto tempk = k % (XdlKThread * KXdlPack); + auto k1 = tempk % XdlKThread; // i XdlKThread + auto k2 = tempk / XdlKThread; // i KXdlPack + + auto outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 + + k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread + + k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack + + k2 * MNXdlPack + n2; + + if constexpr(KLast) + shuffled(outputIndex) = n < MN ? src(n, k) : dtype{}; + else + shuffled(outputIndex) = n < MN ? src(k, n) : dtype{}; + } + } + return shuffled; +} + +#include "run_mx_flatmm.inc" + +template +int run_mx_flatmm_example(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + + std::string mx_prec = arg_parser.get_str("mx_prec"); + std::string a_layout = arg_parser.get_str("a_layout"); + std::string b_layout = arg_parser.get_str("b_layout"); + int persistent_opt = arg_parser.get_int("persistent"); + + if(a_layout == "R" && b_layout == "C") + { + if(mx_prec == "fp4xfp4") + { + if(persistent_opt == 0) + { + run_mx_flatmm_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else + { + run_mx_flatmm_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + } + else if(mx_prec == "fp6xfp6") + { + throw std::runtime_error("Only support fp4xfp4 now!"); + } + else if(mx_prec == "fp8xfp8") + { + throw std::runtime_error("Only support fp4xfp4 now!"); + } + else + { + throw std::runtime_error("Unsupported data_type!"); + } + } + else + { + throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); + } + return -1; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return EXIT_FAILURE; + try + { + int warp_tile = arg_parser.get_int("warp_tile"); + if(warp_tile == 0) + { + return !run_mx_flatmm_example(argc, argv); + } + else if(warp_tile == 1) + { + throw std::runtime_error("Only support MFMA_16x16x128 now!"); + } + else + { + throw std::runtime_error("Unsupported warp_tile!"); + } + } + catch(const std::runtime_error& e) + { + std::cerr << "Runtime error: " << e.what() << '\n'; + return EXIT_FAILURE; + } +} diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp new file mode 100644 index 0000000000..b47d3a95ab --- /dev/null +++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp @@ -0,0 +1,15 @@ + +// 
SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/flatmm.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include "mxfp4_flatmm.hpp" diff --git a/example/ck_tile/18_flatmm/mxgemm/mxfp4_flatmm.hpp b/example/ck_tile/18_flatmm/mxgemm/mxfp4_flatmm.hpp new file mode 100644 index 0000000000..4ef627969c --- /dev/null +++ b/example/ck_tile/18_flatmm/mxgemm/mxfp4_flatmm.hpp @@ -0,0 +1,40 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +// GEMM config with 16x16 warp tile +struct MXfp4_FlatmmConfig16 +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 512; + static constexpr ck_tile::index_t K_Tile = 256; + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 128; + + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool TransposeC = false; + static constexpr bool UseStructuredSparsity = false; + + static constexpr int kBlockPerCu = 1; + static constexpr int TileParitionerGroupNum = 8; + static constexpr int TileParitionerM01 = 4; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; + static constexpr ck_tile::index_t NumWaveGroups = 1; + static constexpr bool DoubleSmemBuffer = false; + + static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; + static constexpr bool TiledMMAPermuteN = false; +}; diff --git a/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc b/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc new file mode 100644 index 0000000000..bc24427780 --- /dev/null +++ b/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc @@ -0,0 +1,167 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
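// ----------------------------------------------------------------------------
// Editorial sketch (not part of this patch): the asserts below only restate what
// the MXfp4_FlatmmConfig16 added in mxgemm/mxfp4_flatmm.hpp already implies for
// the default 128x512x256 block tile (1x4x1 warps of 16x16x128 MFMA tiles, MX
// scale block of 32). The namespace name is illustrative only; this .inc is
// compiled after mx_flatmm.hpp has pulled in the config.
namespace mx_flatmm_config_notes {
using Cfg = MXfp4_FlatmmConfig16;
// 8 M-iterations and 8 N-iterations per warp, 2 K-iterations per warp:
static_assert(Cfg::M_Tile / (Cfg::M_Warp * Cfg::M_Warp_Tile) == 8, "MIterPerWarp");
static_assert(Cfg::N_Tile / (Cfg::N_Warp * Cfg::N_Warp_Tile) == 8, "NIterPerWarp");
static_assert(Cfg::K_Tile / (Cfg::K_Warp * Cfg::K_Warp_Tile) == 2, "KIterPerWarp");
// One e8m0 scale covers 32 contiguous K elements, so a K tile consumes
// K_Tile / 32 = 8 scale entries per row of A and per column of B.
static_assert(Cfg::K_Tile % 32 == 0, "K tile must be a multiple of the MX scale block");
} // namespace mx_flatmm_config_notes
// ----------------------------------------------------------------------------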
+ +template +int run_mx_flatmm_with_layouts(int argc, + char* argv[], + const ALayout a_layout = ALayout{}, + const BLayout b_layout = BLayout{}, + const CLayout c_layout = CLayout{}) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + using ADataType = PrecActType; + using BDataType = PrecWeightType; + using AccDataType = float; + + using ScaleType = ck_tile::e8m0_t; + + constexpr int ScaleGranularityM = 1; + constexpr int ScaleGranularityN = 1; + constexpr int ScaleGranularityK = 32; + + ck_tile::index_t M = arg_parser.get_int("m"); + ck_tile::index_t N = arg_parser.get_int("n"); + ck_tile::index_t K = arg_parser.get_int("k"); + + ck_tile::index_t stride_A = arg_parser.get_int("stride_a"); + ck_tile::index_t stride_B = arg_parser.get_int("stride_b"); + ck_tile::index_t stride_C = arg_parser.get_int("stride_c"); + + ck_tile::index_t kbatch = arg_parser.get_int("split_k"); + ck_tile::index_t init_method = arg_parser.get_int("init"); + ck_tile::index_t n_warmup = arg_parser.get_int("warmup"); + ck_tile::index_t n_repeat = arg_parser.get_int("repeat"); + + stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout)); + stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout)); + stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(c_layout)); + + auto scale_stride_A = ck_tile::get_default_stride( + M / ScaleGranularityM, K / ScaleGranularityK, 0, is_row_major(a_layout)); + auto scale_stride_B = ck_tile::get_default_stride( + K / ScaleGranularityK, N / ScaleGranularityN, 0, is_row_major(b_layout)); + + if(K % ScaleGranularityK != 0) + throw std::runtime_error("wrong! K must be multiple of ScaleGranularityK."); + if(K % ck_tile::numeric_traits::PackedSize != 0 || + K % ck_tile::numeric_traits::PackedSize != 0) + throw std::runtime_error("wrong! K must be multiple of packed size."); + + ck_tile ::HostTensor a_host( + ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout))); + ck_tile::HostTensor b_origin_host( + ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout))); + ck_tile::HostTensor c_rslt_host( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + + ck_tile::HostTensor scale_a(ck_tile::host_tensor_descriptor( + M / ScaleGranularityM, K / ScaleGranularityK, scale_stride_A, is_row_major(a_layout))); + ck_tile::HostTensor scale_b(ck_tile::host_tensor_descriptor( + K / ScaleGranularityK, N / ScaleGranularityN, scale_stride_B, is_row_major(b_layout))); + + if(init_method == 0) + { + ck_tile::FillUniformDistribution{0.0f, 1.0f}(a_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(b_origin_host); + ck_tile::FillUniformDistribution{-2.f, 2.f}(scale_a); + ck_tile::FillUniformDistribution{-2.f, 2.f}(scale_b); + } + else if(init_method == 1) + { + ck_tile::FillUniformDistribution{1.f, 1.f}(a_host); + ck_tile::FillUniformDistribution{1.f, 1.f}(b_origin_host); + ck_tile::FillUniformDistribution{1.f, 1.f}(scale_a); + ck_tile::FillUniformDistribution{1.f, 1.f}(scale_b); + } + else + { + throw std::runtime_error("wrong! 
Unexpected init_method"); + } + + ck_tile::HostTensor b_shuffled_host( + ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout))); + preShuffleWeight(b_origin_host.begin(), b_shuffled_host.begin(), N, K); + + const auto scale_a_shuffled = preShuffleScale(scale_a); + const auto scale_b_shuffled = preShuffleScale(scale_b); + + ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_shuffled_dev_buf(b_shuffled_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem c_dev_buf(c_rslt_host.get_element_space_size_in_bytes()); + + ck_tile::DeviceMem scale_a_dev_buf(scale_a_shuffled.get_element_space_size_in_bytes()); + ck_tile::DeviceMem scale_b_dev_buf(scale_b_shuffled.get_element_space_size_in_bytes()); + + a_dev_buf.ToDevice(a_host.data()); + b_shuffled_dev_buf.ToDevice(b_shuffled_host.data()); + c_rslt_host.SetZero(); + scale_a_dev_buf.ToDevice(scale_a_shuffled.data()); + scale_b_dev_buf.ToDevice(scale_b_shuffled.data()); + + auto scale_a_dev_ptr = ck_tile::FlatmmScalePointer{ + static_cast(scale_a_dev_buf.GetDeviceBuffer()), M / ScaleGranularityM}; + auto scale_b_dev_ptr = ck_tile::FlatmmScalePointer{ + static_cast(scale_b_dev_buf.GetDeviceBuffer()), N / ScaleGranularityN}; + + invoke_mx_flatmm, + AccDataType, + CDataType, + ALayout, + BLayout, + ck_tile::tuple<>, + CLayout, + decltype(scale_a_dev_ptr), + decltype(scale_b_dev_ptr), + UsePersistentKernel>(a_dev_buf, + b_shuffled_dev_buf, + c_dev_buf, + M, + N, + K, + stride_A, + stride_B, + stride_C, + kbatch, + scale_a_dev_ptr, + scale_b_dev_ptr, + n_warmup, + n_repeat); + + c_dev_buf.FromDevice(c_rslt_host.data()); + + bool pass = true; + if(arg_parser.get_int("v") == 1) + { + ck_tile::HostTensor c_m_n_host_ref( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + c_m_n_host_ref.SetZero(); + + ck_tile::reference_mx_gemm( + a_host, b_origin_host, c_m_n_host_ref, scale_a, scale_b); + + const float rtol = std::is_same_v ? 1e-3 : 1e-2; + const float atol = std::is_same_v ? 1e-3 : 1e-2; + + pass = ck_tile::check_err( + c_rslt_host, c_m_n_host_ref, "Error: Incorrect results!", rtol, atol); + + std::cout << "Relative error threshold: " << rtol << " Absolute error threshold: " << atol + << std::endl; + std::cout << "The GPU veification result is: " << (pass ? 
"correct" : "fail") << std::endl; + } + + return pass; +} diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp index 0538bf3dd7..ff7c55fdc9 100644 --- a/include/ck_tile/host/reference/reference_gemm.hpp +++ b/include/ck_tile/host/reference/reference_gemm.hpp @@ -382,6 +382,93 @@ reference_gemm_multiple_abd(const std::array, AsDataType:: make_ParallelTensorFunctor(f_mk_kn_mn, M, N)(std::thread::hardware_concurrency()); } +template +CK_TILE_HOST void reference_mx_gemm(const HostTensor& a_m_k, + const HostTensor& b_k_n, + HostTensor& c_m_n, + const HostTensor& scale_a, + const HostTensor& scale_b, + const AElementOp& = {}, + const BElementOp& = {}, + const ACCElementOp& = {}) +{ + static_assert(std::is_same_v); + static_assert(std::is_same_v); + static_assert(std::is_same_v); + + const std::size_t M = a_m_k.get_length(0); + const std::size_t N = b_k_n.get_length(1); + const std::size_t K = a_m_k.get_length(1); + + const std::size_t ScaleBlockSize = K / scale_a.get_length(1); + + HostTensor a_m_k_scaled({std::size_t(M), std::size_t(K)}, + {std::size_t(K), std::size_t(1)}); + HostTensor b_k_n_scaled({std::size_t(K), std::size_t(N)}, + {std::size_t(1), std::size_t(K)}); + + for(std::size_t m = 0; m < M; ++m) + { + for(std::size_t k = 0; k < K; ++k) + { + if constexpr(std::is_same_v) + { + if(k % 2 == 1) + continue; // skip odd k + + auto a_f4x2 = a_m_k(m, k); + auto a_scale = ck_tile::type_convert(scale_a(m, k / ScaleBlockSize)); + auto a_f4_lo = + ck_tile::type_convert(a_f4x2.template unpack<>(number<0>{})); + auto a_f4_hi = + ck_tile::type_convert(a_f4x2.template unpack<>(number<1>{})); + + a_m_k_scaled(m, k) = a_f4_lo * a_scale; + a_m_k_scaled(m, k + 1) = a_f4_hi * a_scale; + } + } + } + + for(std::size_t n = 0; n < N; n++) + { + for(std::size_t k = 0; k < K; k++) + { + if constexpr(std::is_same_v) + { + if(k % 2 == 1) + continue; // skip odd k + + auto b_f4x2 = b_k_n(k, n); + auto b_scale = ck_tile::type_convert(scale_b(k / ScaleBlockSize, n)); + auto b_f4_lo = + ck_tile::type_convert(b_f4x2.template unpack<>(number<0>{})); + auto b_f4_hi = + ck_tile::type_convert(b_f4x2.template unpack<>(number<1>{})); + + b_k_n_scaled(k, n) = b_f4_lo * b_scale; + b_k_n_scaled(k + 1, n) = b_f4_hi * b_scale; + } + else + { + b_k_n_scaled(k, n) = + ck_tile::type_convert((b_k_n(k, n))) * + ck_tile::type_convert(scale_b(k / ScaleBlockSize, n)); + } + } + } + + // call reference gemm + reference_gemm( + a_m_k_scaled, b_k_n_scaled, c_m_n); +} + template +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" + +#include "ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp" + +namespace ck_tile { + +template +struct MXFlatmmKernel : FlatmmKernel +{ + using Underlying = FlatmmKernel; + + using TilePartitioner = remove_cvref_t; + using FlatmmPipeline = remove_cvref_t; + using BlockGemmShape = + remove_cvref_t; // TileFlatmmShape + using EpiloguePipeline = remove_cvref_t; + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using ELayout = remove_cvref_t; + using DsLayout = remove_cvref_t; + using DsDataType = remove_cvref_t; + static constexpr index_t KernelBlockSize = FlatmmPipeline::BlockSize; + static constexpr bool UsePersistentKernel = FlatmmPipeline::UsePersistentKernel; + + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + // Below type is actually accumulation data type - the output of block GEMM. 
+ using EDataType = remove_cvref_t; + + static constexpr int MThreadPerXdl = BlockGemmShape::WarpTile::at(number<0>{}); + static constexpr int NThreadPerXdl = BlockGemmShape::WarpTile::at(number<1>{}); + static constexpr int KThreadPerXdl = 64 / MThreadPerXdl; + + static constexpr int APackedSize = numeric_traits::PackedSize; + static constexpr int BPackedSize = numeric_traits::PackedSize; + + static constexpr int MXdlPack = FlatmmPipeline::MXdlPack; + static constexpr int NXdlPack = FlatmmPipeline::NXdlPack; + static constexpr int KXdlPack = FlatmmPipeline::KXdlPack; + + static constexpr index_t NumDTensor = DsDataType::size(); + + static constexpr auto I0 = number<0>(); + static constexpr auto I1 = number<1>(); + static constexpr auto I2 = number<2>(); + static constexpr auto I3 = number<3>(); + static constexpr auto I4 = number<4>(); + static constexpr auto I5 = number<5>(); + + static_assert(DsLayout::size() == DsDataType::size(), + "The size of DsLayout and DsDataType should be the same"); + // using KernelArgs = FlatmmKernelArgs; + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + return concat('_', "mx_flatmm_gemm", gemm_prec_str, FlatmmPipeline::GetName()); + // clang-format on + } + + template + CK_TILE_HOST static constexpr auto + GridSize(const FlatmmKernelArgs& kargs) + { + if constexpr(UsePersistentKernel) + { + hipDeviceProp_t prop; + int deviceId = 0; // default device + + constexpr int block_size = MXFlatmmKernel::BlockSize().x; + int dync_smem_size = 0; + int maxActiveBlocksPerCU = 0; + + if(hipGetDeviceProperties(&prop, deviceId) != hipSuccess) + throw std::runtime_error(std::string("hipGetDeviceProperties failed: ") + + hipGetErrorName(hipGetLastError())); + + if(hipOccupancyMaxActiveBlocksPerMultiprocessor( + &maxActiveBlocksPerCU, + reinterpret_cast( + kentry<1, MXFlatmmKernel, remove_cvref_t>), + block_size, + dync_smem_size) != hipSuccess) + throw std::runtime_error( + std::string("hipOccupancyMaxActiveBlocksPerMultiprocessor failed: ") + + hipGetErrorName(hipGetLastError())); + + const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU; + const int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N); + + // std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU + // << ", persistent_block_size: " << persistent_block_size + // << ", total_work_tile_cnt: " << total_work_tile_cnt << std::endl; + + if(kargs.k_batch != 1) + throw std::runtime_error("Wrong! 
k_batch != 1 not supported in persistent kernel"); + return dim3(min(persistent_block_size, total_work_tile_cnt), 1, kargs.k_batch); + } + else + { + return dim3(TilePartitioner::GridSize(kargs.M, kargs.N), 1, kargs.k_batch); + } + } + + using SplitKBatchOffset = typename Underlying::SplitKBatchOffset; + + template + CK_TILE_DEVICE static auto + MakeGemmTensorViews(const ADataType* a_ptr, + const BDataType* b_flat_ptr, + const std::array& ds_ptr, + EDataType* e_ptr, + const KernelArgs& kargs, + const SplitKBatchOffset& splitk_batch_offset) + { + const auto& a_tensor_view = [&]() { + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + a_ptr, + make_tuple(kargs.M, splitk_batch_offset.splitted_k), + make_tuple(kargs.stride_A, 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_view( + a_ptr, + make_tuple(splitk_batch_offset.splitted_k, kargs.M), + make_tuple(kargs.stride_A, 1), + number{}, + number<1>{}); + } + }(); + + index_t kFlatK = kargs.K * BlockGemmShape::WarpTile::at(I1); + index_t kFlatN = kargs.N * kargs.K / kFlatK; + + const auto& b_flat_tensor_view = [&]() { + return make_naive_tensor_view( + b_flat_ptr, + make_tuple(kFlatN, kFlatK), + make_tuple(kFlatK, 1), + number{}, + number<1>{}); + }(); + + const auto& ds_tensor_view = generate_tuple( + [&](auto i) { + using DiLayout = remove_cvref_t>; + using DDataType_ = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + static_cast(ds_ptr[i]), + make_tuple(kargs.M, kargs.N), + make_tuple(kargs.stride_Ds[i], 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_view( + static_cast(ds_ptr[i]), + make_tuple(kargs.N, kargs.M), + make_tuple(kargs.stride_Ds[i], 1), + number{}, + number<1>{}); + } + }, + number{}); + + // TODO: enable vector write for C in ColMajor + const auto& e_tensor_view = [&]() { + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + e_ptr, + make_tuple(kargs.M, kargs.N), + make_tuple(kargs.stride_E, 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_view( + e_ptr, + make_tuple(kargs.N, kargs.M), + make_tuple(kargs.stride_E, 1), + number<1>{}, + number<1>{}); + } + }(); + + auto scale_a = kargs.scale_m_ptr; + auto scale_b = kargs.scale_n_ptr; + + static constexpr int BlockScaleSize = 32; // decltype(scale_n)::GranularityK; + const auto&& scale_packs_m = integer_divide_ceil(kargs.M, (MXdlPack * MThreadPerXdl)); + const auto&& scale_packs_n = integer_divide_ceil(kargs.N, (NXdlPack * NThreadPerXdl)); + const auto&& scale_packs_k = kargs.K / BlockScaleSize / (KXdlPack * KThreadPerXdl); + + // A scale tensor view + const auto& scale_a_tensor_view = [&]() { + // Pack 2x2 e8m0 over M/K dimension into 1 int32_t to trigger dword width load + const auto scale_a_naive_desc = make_naive_tensor_descriptor_packed( + make_tuple(scale_packs_m, scale_packs_k, KThreadPerXdl, MThreadPerXdl)); + const auto scale_a_desc = transform_tensor_descriptor( + scale_a_naive_desc, + make_tuple(make_merge_transform(make_tuple(scale_packs_m, MThreadPerXdl)), + make_merge_transform(make_tuple(scale_packs_k, KThreadPerXdl))), + make_tuple(sequence<0, 3>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return make_tensor_view( + reinterpret_cast(scale_a.ptr), scale_a_desc); + }(); + + // B scale tensor view + const auto& scale_b_tensor_view = [&]() { + const auto scale_b_navie_desc = make_naive_tensor_descriptor_packed( + make_tuple(scale_packs_n, scale_packs_k, KThreadPerXdl, NThreadPerXdl)); + const auto 
scale_b_desc = transform_tensor_descriptor( + scale_b_navie_desc, + make_tuple(make_merge_transform(make_tuple(scale_packs_n, NThreadPerXdl)), + make_merge_transform(make_tuple(scale_packs_k, KThreadPerXdl))), + make_tuple(sequence<0, 3>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return make_tensor_view( + reinterpret_cast(scale_b.ptr), scale_b_desc); + }(); + + return make_tuple(a_tensor_view, + b_flat_tensor_view, + ds_tensor_view, + e_tensor_view, + scale_a_tensor_view, + scale_b_tensor_view); + } + + template + CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views) + { + const auto& a_pad_view = [&]() { + const auto& a_tensor_view = views.at(I0); + if constexpr(std::is_same_v) + { + return pad_tensor_view(a_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(a_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + }(); + + const auto& b_flat_tensor_view = views.at(I1); + + const auto& ds_pad_view = generate_tuple( + [&](auto i) { + const auto& d_tensor_view = views.at(I2); + using DiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return pad_tensor_view(d_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(d_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + }, + number{}); + + // TODO vector write in for C in ColMajor + const auto& e_pad_view = [&]() { + const auto& e_tensor_view = views.at(I3); + if constexpr(std::is_same_v) + { + return pad_tensor_view(e_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(e_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + }(); + + return make_tuple( + a_pad_view, b_flat_tensor_view, ds_pad_view, e_pad_view, views.at(I4), views.at(I5)); + } + + template + CK_TILE_DEVICE static auto + MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) + { + const auto& a_pad_view = views.at(I0); + const auto& b_flat_pad_view = views.at(I1); + const auto& ds_pad_view = views.at(I2); + const auto& e_pad_view = views.at(I3); + + const auto& a_block_window = [&]() { + if constexpr(std::is_same_v) + { + return make_tile_window(a_pad_view, + make_tuple(number{}, + number{}), + {i_m, 0}); + } + else + { + return make_tile_window(a_pad_view, + make_tuple(number{}, + number{}), + {0, i_m}); + } + }(); + + const auto& b_flat_block_window = + make_tile_window(b_flat_pad_view, + make_tuple(number{}, + number{}), + {static_cast(i_n / BlockGemmShape::WarpTile::at(I1)), 0}); + + const auto ds_block_window = generate_tuple( + [&](auto i) { + using DiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return make_tile_window(ds_pad_view[i], + make_tuple(number{}, + number{}), + {i_m, i_n}); + } + else + { + return make_tile_window(ds_pad_view[i], + make_tuple(number{}, + number{}), + {i_n, i_m}); + } + }, + number{}); + + auto e_block_window = make_tile_window( + e_pad_view, + make_tuple(number{}, number{}), + {i_m, i_n}); + + static constexpr int BlockScaleSize = 32; + + auto scale_a_block_window = make_tile_window( + views.at(I4), + make_tuple(number{}, + number{}), + {i_m / MXdlPack, 0}); + + auto scale_b_block_window = make_tile_window( + views.at(I5), + make_tuple(number{}, + number{}), + {i_n / NXdlPack, 0}); + + return make_tuple(a_block_window, + b_flat_block_window, + ds_block_window, + e_block_window, + scale_a_block_window, + scale_b_block_window); + } + + template + 
CK_TILE_DEVICE static void + RunFlatmm(const ADataType* a_ptr, + const BDataType* b_flat_ptr, + const std::array& ds_ptr, + EDataType* e_ptr, + void* smem_ptr_ping, + void* smem_ptr_pong, + const FlatmmKernelArgs& kargs, + const SplitKBatchOffset& splitk_batch_offset, + const index_t block_idx_m, + const index_t block_idx_n) + { + // Create Gemm tensor views, pad views and tile windows + const auto& gemm_tensor_views_tuple = + MakeGemmTensorViews( + a_ptr, b_flat_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset); + const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); + auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + + const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k); + + // Run GEMM cooperatively by whole workgroup. + const auto& a_block_window = gemm_tile_windows.at(I0); + const auto& b_flat_block_window = gemm_tile_windows.at(I1); + const auto& d_block_window = gemm_tile_windows.at(I2); + const auto& scale_a_block_window = gemm_tile_windows.at(I4); + const auto& scale_b_block_window = gemm_tile_windows.at(I5); + + static_assert(ScaleM::GranularityK == ScaleN::GranularityK // have the same granK + || ScaleM::GranularityMN == -1 // or ScaleA is disable + || ScaleN::GranularityMN == -1, // or ScaleB is disable + "ScaleM and ScaleN should have the same GranularityK"); + constexpr bool DoEpiScale = + (ScaleM::GranularityMN != -1 && ScaleM::GranularityK == 0) || // per token + (ScaleN::GranularityMN != -1 && ScaleN::GranularityK == 0); // per channel + + auto a_block_window_with_distr = + ck_tile::make_tile_window(a_block_window.get_bottom_tensor_view(), + a_block_window.get_window_lengths(), + a_block_window.get_window_origin(), + FlatmmPipeline::GetADramTileDistribution()); + const auto& c_block_tile = FlatmmPipeline{}(a_block_window_with_distr, + b_flat_block_window, + scale_a_block_window, + scale_b_block_window, + num_loop, + smem_ptr_ping, + smem_ptr_pong); + + // Run Epilogue Pipeline + if constexpr(DoEpiScale) + { + auto& c_block_window = gemm_tile_windows.at(I3); + EpiloguePipeline{}(c_block_window, + c_block_tile, + d_block_window, + smem_ptr_ping, + kargs.scale_m_ptr + block_idx_m, + kargs.scale_n_ptr + block_idx_n); + } + else if(UseDefaultScheduler || (get_warp_id() == 0)) + { + // Run Epilogue Pipeline + auto& c_block_window = gemm_tile_windows.at(I3); + EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_ping); + } + } + + template + CK_TILE_DEVICE void operator()(FlatmmKernelArgs kargs, + int partition_idx = blockIdx.x) const + { + int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N); + + do + { + const auto [iM, iN] = + TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(partition_idx); + const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock); + const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock); + + const SplitKBatchOffset splitk_batch_offset(kargs); + // options + const ADataType* a_ptr = static_cast(kargs.a_ptr) + + splitk_batch_offset.a_k_split_offset / APackedSize; + const BDataType* b_flat_ptr = static_cast(kargs.b_ptr) + + splitk_batch_offset.b_k_split_offset / BPackedSize; + EDataType* e_ptr = static_cast(kargs.e_ptr); + + // allocate LDS + __shared__ char smem_ptr_ping[Underlying::GetSmemPingSize()]; + __shared__ char smem_ptr_pong[Underlying::GetSmemPongSize()]; + + if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && + 
EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value)) + { + constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1); + RunFlatmm(a_ptr, + b_flat_ptr, + kargs.ds_ptr, + e_ptr, + smem_ptr_ping, + smem_ptr_pong, + kargs, + splitk_batch_offset, + i_m, + i_n); + } + else + { + static_assert(false, + "Unimplemented: atomic_add with odd vector size for fp16/bf16"); + } + partition_idx += gridDim.x; + } while(UsePersistentKernel && partition_idx < total_work_tile_cnt); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp index da5b8102dc..d3da488a88 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -291,10 +291,12 @@ struct UniversalFlatmmPipelineAgBgCrPolicy constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t APackedSize = numeric_traits::PackedSize; + if constexpr(std::is_same_v) { - constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType); - constexpr index_t M0 = MPerBlock / M1; + constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType) * APackedSize; + constexpr index_t M0 = MPerBlock / M1; constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize; static_assert(total_pixels % M1 == 0); constexpr index_t K3 = total_pixels / M1; @@ -331,7 +333,7 @@ struct UniversalFlatmmPipelineAgBgCrPolicy } else { - constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType); + constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType) * APackedSize; constexpr index_t K0 = KPerBlock / K1; // coalesce reading for each blocks if constexpr(get_warp_size() % K0 == 0) diff --git a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp new file mode 100644 index 0000000000..fbed495d25 --- /dev/null +++ b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -0,0 +1,1330 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
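// ----------------------------------------------------------------------------
// Editorial note (not part of the original change) on the fp4 packing this
// pipeline relies on: MXdlPack = NXdlPack = KXdlPack = 2, so eight 16x16x128
// MFMA tiles are grouped into one 32x32x256 step, and each lane owns 32
// contiguous fp4 values (ContinuousKPerThread, i.e. 16 bytes) per K step; with
// the 64-lane wave the surrounding code assumes, that gives
//   flatKPerWarp = 64 * 32 = 2048 fp4 elements per warp per flat-K iteration.
// The matching 2x2 (KXdlPack x MXdlPack) group of e8m0 scales is fetched as a
// single 32-bit word; the byte order shown here is only illustrative:
//   uint32_t packed = s(k0, m0) | (s(k1, m0) << 8) | (s(k0, m1) << 16) | (s(k1, m1) << 24);
// ----------------------------------------------------------------------------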
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/concat.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp" +#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp" +#include "ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp" + +namespace ck_tile { + +template +struct MXFlatmmPipelineProblem : FlatmmPipelineProblem +{ + using BlockGemmShape = BlockGemmShape_; + + // using QuantType = BDataType_; + + static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp; + + static constexpr int ScaleGranularityK = 32; + + static constexpr int ContinuousKPerThread = 32; // it's fixed for fp4 + static constexpr int MXdlPack = 2; // it's fixed for fp4 + static constexpr int NXdlPack = 2; // it's fixed for fp4 + static constexpr int KXdlPack = 2; + // static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp * KXdlPack; + static constexpr index_t flatKPerWarp = 64 * ContinuousKPerThread; +}; + +template +struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1 +{ + using Underlying = FlatmmPipelineAGmemBGmemCRegV1; + + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; // TileFlatmmShape + + using ComputeType = ADataType; + static_assert(sizeof(ADataType) >= sizeof(BDataType)); + + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using BlockFlatmm = + remove_cvref_t())>; + + static constexpr auto config = + BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + static constexpr index_t DsWritePreIssue = 3; // default 2, ds write at MIter - 2 + static constexpr index_t DsReadPreload = 4; // default 4 for MXFP4 (MXdlPack * KXdlPack) + + static constexpr index_t BlockSize = Problem::kBlockSize; + static constexpr index_t WaveSize = get_warp_size(); + + static constexpr index_t kMPerBlock = BlockGemmShape::kM; + static constexpr index_t kNPerBlock = BlockGemmShape::kN; + static constexpr index_t kKPerBlock = BlockGemmShape::kK; + + static constexpr index_t flatKPerWarp = Problem::flatKPerWarp; + static constexpr index_t flatNPerWarp = Problem::flatNPerWarp; + + static constexpr index_t GetVectorSizeA() { return 32; } /* fixed for fp4 shuffle layout*/ + static constexpr index_t GetVectorSizeB() { return 32; } /* fixed for fp4 shuffle layout*/ + static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; } + + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; + static constexpr bool kPadK = Problem::kPadK; + + // static constexpr index_t kLdsAlignmentInBytes = 16; + static constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel; + + static constexpr auto I0 = number<0>(); + static constexpr auto I1 = number<1>(); + static constexpr auto I2 = number<2>(); + static constexpr auto idxM = I0; + static constexpr auto idxN = I1; + static constexpr auto idxK = I2; + using BlockTile = remove_cvref_t; + using BlockWarps = remove_cvref_t; + using WarpTile = remove_cvref_t; + + static constexpr index_t MWarp = config.template at<1>(); + static constexpr index_t NWarp = config.template at<2>(); + + static constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM); + static constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN); + static constexpr 
index_t KIterPerWarp = kKPerBlock / WG::kK; + + static constexpr index_t KFlatPerBlockPerIter = flatKPerWarp; + static constexpr index_t NFlatPerBlockPerIter = flatNPerWarp; + + static constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp; + static constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp; + + static constexpr index_t APackedSize = numeric_traits::PackedSize; + static constexpr index_t BPackedSize = numeric_traits::PackedSize; + + static constexpr index_t MXdlPack = Problem::MXdlPack; + static constexpr index_t NXdlPack = Problem::NXdlPack; + static constexpr index_t KXdlPack = Problem::KXdlPack; + + static constexpr index_t AK1 = Problem::VectorLoadSize / sizeof(ADataType) * APackedSize; + static constexpr index_t BK1 = Problem::VectorLoadSize / sizeof(BDataType) * BPackedSize; + + static constexpr index_t m_preload = (MIterPerWarp * KIterPerWarp >= DsReadPreload) + ? DsReadPreload + : MIterPerWarp * KIterPerWarp; + + static constexpr bool HasHotLoop = Problem::HasHotLoop; + static constexpr auto TailNum = Problem::TailNum; + + static constexpr index_t mfma_per_wg = 1; // 950 only + + static constexpr index_t dsread_per_wg = + WG::kM * WG::kK * sizeof(ADataType) / APackedSize / WaveSize / Problem::VectorLoadSize; + static_assert((WG::kM * WG::kK * sizeof(ADataType) / APackedSize / WaveSize) % + Problem::VectorLoadSize == + 0); + + static constexpr index_t dsread_num_perK = dsread_per_wg * MIterPerWarp; + static constexpr index_t dswrite_num_perK = dsread_num_perK / (MWarp * NWarp); + static constexpr index_t dswrite_rep = (dswrite_num_perK + MIterPerWarp - 1) / MIterPerWarp; + static constexpr index_t Aload_num_perK = dswrite_num_perK; + static constexpr index_t Aload_rep = dswrite_rep; + + static constexpr index_t Bload_num_perK = kNPerBlock * WG::kK / NWarp / BK1 / WaveSize; + static constexpr index_t ScaleBload_K1 = NXdlPack * KXdlPack; // fixed for fp4 + static constexpr index_t ScaleBload_num = + kNPerBlock * kKPerBlock / NWarp / 32 / ScaleBload_K1 / WaveSize; + static constexpr index_t KPerScaleLoad = KIterPerWarp / ScaleBload_num; + static constexpr index_t HalfMIter = (MIterPerWarp + 1) / 2; + static constexpr index_t Bload_rep = (Bload_num_perK + HalfMIter - 1) / HalfMIter; + + static constexpr index_t mfma_perM_perK = NIterPerWarp * mfma_per_wg; + static constexpr index_t dswrite_mIter = (DsWritePreIssue - 1) % MIterPerWarp; + static constexpr index_t dswrite_kIter = (DsWritePreIssue - 1) / MIterPerWarp; + + // For the basic gemm pipelien DoubleSmemBuffer set to be false naturally. + static constexpr bool DoubleSmemBuffer = false; + + CK_TILE_HOST_DEVICE static constexpr auto + SchedulerPerM(index_t dsread_perM, index_t dswrite_perM, index_t load_perM) + { + // Init inst order + index_t max_data_inst = dsread_perM > load_perM + ? (dsread_perM > dswrite_perM ? dsread_perM : dswrite_perM) + : (load_perM > dswrite_perM ? 
load_perM : dswrite_perM); + index_t sum_data_inst = dsread_perM + load_perM + dswrite_perM; + index_t round_data_inst = (sum_data_inst + mfma_perM_perK - 1) / mfma_perM_perK; + + index_t inst_order[NIterPerWarp * 10]; + _Pragma("unroll") for(int idx = 0; idx < NIterPerWarp * 10; idx++) { inst_order[idx] = 0; } + + index_t index = 0; + _Pragma("unroll") for(int j = 0; j < max_data_inst; j++) + { + if(dswrite_perM > j) + { + inst_order[index] = 1; + index++; + } + if(load_perM > j) + { + inst_order[index] = 2; + index++; + } + if(dsread_perM > j) + { + inst_order[index] = 3; + index++; + } + } + + // Schedule IGLP + _Pragma("unroll") for(int j = 0; j < mfma_perM_perK; j++) + { + index_t inst_idx = 0; + if(j == 0) + ; + else if(j == 1) + inst_idx = mfma_perM_perK == 2 ? 1 : mfma_perM_perK - 2; + else if(j == 2) + inst_idx = mfma_perM_perK - 1; + else + inst_idx = mfma_perM_perK - j; + + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + + _Pragma("unroll") for(int r = 0; r < round_data_inst; r++) + { + if(r % 2 == 0) + { + if(inst_order[inst_idx + r * mfma_perM_perK] == 1) + { + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + } + if(inst_order[inst_idx + r * mfma_perM_perK] == 2) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + if(inst_order[inst_idx + r * mfma_perM_perK] == 3) + { + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + } + } + else + { + if(inst_order[(r + 1) * mfma_perM_perK - 1 - inst_idx] == 1) + { + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + } + if(inst_order[(r + 1) * mfma_perM_perK - 1 - inst_idx] == 2) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + if(inst_order[(r + 1) * mfma_perM_perK - 1 - inst_idx] == 3) + { + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + } + } + } + } + } + + CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler() + { + // Keypoint of pipeline optimize is workload balance in time + // instruction schedule example(128X256X256, 1X4, 16X16X128): + // Iter MNK MFMA ds_read ds_write A_load b_load + // -1 M6N0: 57 - 8 - - + // -1 M6N1: 58 1 - - - + // -1 M6N2: 59 - - 7 - + // -1 M6N3: 60 2 - - - + // -1 M7N0: 61 - - - - + // -1 M7N1: 62 3 - - - + // -1 M7N2: 63 - - 8 - + // -1 M7N3: 64 4 - - - + // 0 M0N0K0: 1 - - - 1 + // 0 M0N1: 2 5 - - - + // 0 M0N2: 3 - - - 2 + // 0 M0N3: 4 6 - - - + // 0 M1N0: 5 - - - 3 + // 0 M1N1: 6 7 - - - + // 0 M1N2: 7 - - - 4 + // 0 M1N3: 8 8 - - - + // 0 M2N0: 9 - - - 5 + // 0 M2N1: 10 9 - - - + // 0 M2N2: 11 - - - 6 + // 0 M2N3: 12 10 - - - + // 0 M3N0: 13 - 1 - 7 + // 0 M3N1: 14 11 - - - + // 0 M3N2: 15 - - - 8 + // 0 M3N3: 16 12 - - - + // 0 M4N0: 17 - 2 - - + // 0 M4N1: 18 13 - - - + // 0 M4N2: 19 - - 1 - + // 0 M4N3: 20 14 - - - + // 0 M5N0: 21 - 3 - - + // 0 M5N1: 22 15 - - - + // 0 M5N2: 23 - - 2 - + // 0 M5N3: 24 16 - - - + // 0 M6N0: 25 - 4 - - + // 0 M6N1: 26 17 - - - + // 0 M6N2: 27 - - 3 - + // 0 M6N3: 28 18 - - - + // 0 M7N0: 29 - - - - + // 0 M7N1: 30 19 - - - + // 0 M7N2: 31 - - 4 - + // 0 M7N3: 32 20 - - - + // 0 M0N0K1: 33 - - - 9 + // 0 M0N1: 34 21 - - - + // 0 M0N2: 35 - - - 10 + // 0 M0N3: 36 22 - - - + // 0 M1N0: 37 - - - 11 + // 0 M1N1: 38 23 - - - + // 0 M1N2: 39 - - - 12 + // 0 M1N3: 40 24 - - - + // 0 M2N0: 41 - - - 13 + // 0 M2N1: 42 25 - - - + // 0 M2N2: 43 - - - 14 + // 0 M2N3: 44 26 - - - + // 0 M3N0: 45 - 5 - 15 + // 0 M3N1: 46 27 - - - + // 0 M3N2: 47 - - - 16 + // 0 M3N3: 48 28 - - - + // 0 M4N0: 49 - 6 - - + // 0 M4N1: 50 29 - - - + // 0 M4N2: 51 - - 5 - 
+ // 0 M4N3: 52 30 - - - + // 0 M5N0: 53 - 7 - - + // 0 M5N1: 54 31 - - - + // 0 M5N2: 55 - - 6 - + // 0 M5N3: 56 32 - - - + // 0 M6N0: 57 - 8 - - + // 0 M6N1: 58 1 - - - + // 0 M6N2: 59 - - 7 - + // 0 M6N3: 60 2 - - - + // 0 M7N0: 61 - - - - + // 0 M7N1: 62 3 - - - + // 0 M7N2: 63 - - 8 - + // 0 M7N3: 64 4 - - - + + _Pragma("unroll") for(int kIter = 0; kIter < KIterPerWarp; kIter++) + { + _Pragma("unroll") for(int mIter = 0; mIter < MIterPerWarp; mIter++) + { + index_t dsread_perM = 0; + index_t dswrite_perM = 0; + index_t load_perM = 0; + + // Calculate ds_read number per M + dsread_perM = dsread_per_wg; + + // Calculate ds_write number per M + if(mIter == 0) + { + dswrite_perM = + (dswrite_num_perK - (MIterPerWarp - DsWritePreIssue) * dswrite_rep) > 0 + ? dswrite_num_perK - (MIterPerWarp - DsWritePreIssue) * dswrite_rep + : 0; + } + else if(mIter >= MIterPerWarp - DsWritePreIssue + 1) + { + dswrite_perM = 0; + } + else + { + dswrite_perM = (dswrite_num_perK - + (MIterPerWarp - DsWritePreIssue - mIter) * dswrite_rep) > 0 + ? dswrite_rep + : 0; + } + // Add ds write when ds write data > needed + if(dswrite_num_perK == 0 && kIter == (KIterPerWarp - 1 - dswrite_kIter)) + { + if(mIter == MIterPerWarp - 1 - dswrite_mIter) + dswrite_perM = 1; + } + + // Calculate buffer_load number per M + if(mIter < HalfMIter) + { + load_perM = + ((Aload_num_perK - (MIterPerWarp - 1 - mIter) * Aload_rep) > 0 ? Aload_rep + : 0) + + ((Bload_num_perK - (HalfMIter - 1 - mIter) * Bload_rep) > 0 ? Bload_rep + : 0); + } + else + { + load_perM = (Aload_num_perK - (MIterPerWarp - 1 - mIter) * Aload_rep) > 0 + ? Aload_rep + : 0; + } + // if((kIter % KPerScaleLoad == 0) && (mIter == 0)) + // { + // load_perM = load_perM + 1; + // } + SchedulerPerM(dsread_perM, dswrite_perM, load_perM); + } + } + // Add Aload when Aload data > needed + if(Aload_num_perK == 0) + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_barrier(0); + } + + CK_TILE_HOST_DEVICE static constexpr auto Last2ndHotLoopScheduler() + { + _Pragma("unroll") for(int kIter = 0; kIter < KIterPerWarp; kIter++) + { + _Pragma("unroll") for(int mIter = 0; mIter < MIterPerWarp; mIter++) + { + index_t dsread_perM = 0; + index_t dswrite_perM = 0; + index_t load_perM = 0; + + // Calculate ds_read number per M + dsread_perM = dsread_per_wg; + + // Calculate ds_write number per M + if(mIter == 0) + { + dswrite_perM = + (dswrite_num_perK - (MIterPerWarp - DsWritePreIssue) * dswrite_rep) > 0 + ? dswrite_num_perK - (MIterPerWarp - DsWritePreIssue) * dswrite_rep + : 0; + } + else if(mIter >= MIterPerWarp - DsWritePreIssue + 1) + { + dswrite_perM = 0; + } + else + { + dswrite_perM = (dswrite_num_perK - + (MIterPerWarp - DsWritePreIssue - mIter) * dswrite_rep) > 0 + ? dswrite_rep + : 0; + } + // Add ds write when ds write data > needed + if(dswrite_num_perK == 0 && kIter == (KIterPerWarp - 1 - dswrite_kIter)) + { + if(mIter == MIterPerWarp - 1 - dswrite_mIter) + dswrite_perM = 1; + } + + // Calculate buffer_load number per M + if(mIter < HalfMIter) + { + load_perM = + ((Bload_num_perK - (HalfMIter - 1 - mIter) * Bload_rep) > 0 ? 
Bload_rep + : 0); + } + SchedulerPerM(dsread_perM, dswrite_perM, load_perM); + } + } + __builtin_amdgcn_sched_barrier(0); + } + + CK_TILE_HOST_DEVICE static constexpr auto LastHotLoopScheduler() + { + _Pragma("unroll") for(int kIter = 0; kIter < KIterPerWarp; kIter++) + { + _Pragma("unroll") for(int mIter = 0; mIter < MIterPerWarp; mIter++) + { + index_t dsread_perM = 0; + index_t dswrite_perM = 0; + index_t load_perM = 0; + + // Calculate ds_read number per M + if((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload)) + dsread_perM = dsread_per_wg; + + SchedulerPerM(dsread_perM, dswrite_perM, load_perM); + } + } + // __builtin_amdgcn_sched_barrier(0); + } + + CK_TILE_HOST_DEVICE static constexpr auto GetADramTileDistribution() + { + return PipelinePolicy::template MakeADramTileDistribution(); + } + + template + CK_TILE_HOST_DEVICE auto operator()(ADramBlockWindowTmp a_copy_dram_window, + const AElementFunction& a_element_func, + const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, + const ScaleADramBlockWindowTmp& scale_a_window, + const ScaleBDramBlockWindowTmp& scale_b_window, + index_t num_loop, + void* p_smem_ping, + void* p_smem_pong) const + { +#ifndef __gfx950__ + static_assert(false, "Only gfx950 is supported for MXFP4 flatmm pipeline now."); +#endif + static_assert( + std::is_same_v>, + "wrong!"); + + static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}], + "wrong!"); + static_assert(kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + constexpr auto MIter_2nd_last = (MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1; + const index_t iMWarp = get_warp_id() / NWarp; + // const index_t iNWarp = get_warp_id() % NWarp; + + using CWarpDstr = typename WG::CWarpDstr; + using CWarpTensor = typename WG::CWarpTensor; + + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + __builtin_amdgcn_sched_barrier(0); + + // A tile in LDS + ADataType* p_a_lds_ping = static_cast(p_smem_ping); + ADataType* p_a_lds_pong = static_cast(p_smem_pong); + + constexpr auto a_lds_block_desc = + PipelinePolicy::template MakeMXFP4_ALdsBlockDescriptor(); + + auto a_lds_block_ping = + make_tensor_view(p_a_lds_ping, a_lds_block_desc); + auto a_lds_block_pong = + make_tensor_view(p_a_lds_pong, a_lds_block_desc); + + auto a_copy_lds_window_ping = + make_tile_window(a_lds_block_ping, + make_tuple(number{}, number{}), + {0, 0}, + PipelinePolicy::template MakeADramTileDistribution()); + auto a_copy_lds_window_pong = + make_tile_window(a_lds_block_pong, + make_tuple(number{}, number{}), + {0, 0}, + PipelinePolicy::template MakeADramTileDistribution()); + + // ping-pong window for A LDS + auto a_warp_window_ping_tmp = + make_tile_window(a_lds_block_ping, + make_tuple(number{}, number{}), + {iMWarp * WG::kM, 0}, + PipelinePolicy::template MakeMXF4_ALDS_TileDistribution()); + auto a_warp_window_pong_tmp = + make_tile_window(a_lds_block_pong, + make_tuple(number{}, number{}), + {iMWarp * WG::kM, 0}, + PipelinePolicy::template MakeMXF4_ALDS_TileDistribution()); + + statically_indexed_array< + statically_indexed_array, + MIterPerWarp> + a_warp_windows_ping; + + statically_indexed_array< + statically_indexed_array, + MIterPerWarp> + a_warp_windows_pong; + + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + a_warp_windows_ping(mIter)(kIter) = 
a_warp_window_ping_tmp; + a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp; + + auto packed_m_idx = mIter / number{}; + auto packed_m_rank = mIter % number{}; + + move_tile_window( + a_warp_windows_ping(mIter)(kIter), + {packed_m_idx * MXdlPack * MPerBlockPerIter + packed_m_rank * WG::kM, + kIter * KPerBlockPerIter}); + move_tile_window( + a_warp_windows_pong(mIter)(kIter), + {packed_m_idx * MXdlPack * MPerBlockPerIter + packed_m_rank * WG::kM, + kIter * KPerBlockPerIter}); + }); + }); + + // Block GEMM + auto block_flatmm = BlockFlatmm(); + // Acc register tile + auto c_block_tile = block_flatmm.MakeCBlockTile(); + + // B flat DRAM window for load + auto b_flat_distribution = + PipelinePolicy::template MakeMXFP4_BFlatDramTileDistribution(); + + auto b_flat_dram_window = make_tile_window( + b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views + make_tuple(number{}, number{}), + b_flat_dram_block_window_tmp.get_window_origin(), + b_flat_distribution); + + using MXFP4_B_Buffer = decltype(load_tile(b_flat_dram_window)); + // use v4i32 as the data type between basicblock to avoid unpack and repack operation. + using V4UInt_B_Buffer = thread_buffer; + union UnionBuf + { + V4UInt_B_Buffer u = 0; + MXFP4_B_Buffer mxfp4; + } ub; + + // pingpong buffer for B + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_flat_dram_windows; + statically_indexed_array, + NIterPerWarp> + b_warp_tensor_ping; + statically_indexed_array, + NIterPerWarp> + b_warp_tensor_pong; + + // pingpong buffer for Scale A and Scale B + auto scale_a_dram_window = make_tile_window( + scale_a_window.get_bottom_tensor_view(), + make_tuple(number{}, number<64 / WG::kM>{}), + scale_a_window.get_window_origin(), + PipelinePolicy::template MakeMXFP4_ScaleA_FlatDramTileDistribution()); + + auto scale_b_dram_window = make_tile_window( + scale_b_window.get_bottom_tensor_view(), + make_tuple(number{}, number<64 / WG::kN>{}), + scale_b_window.get_window_origin(), + PipelinePolicy::template MakeMXFP4_ScaleB_DramTileDistribution()); + + // ping pong buffer for scale A + statically_indexed_array< + statically_indexed_array, + MIterPerWarp / MXdlPack> + scale_a_dram_windows; + statically_indexed_array, + MIterPerWarp / MXdlPack> + scale_a_tile_tensor_ping; + statically_indexed_array, + MIterPerWarp / MXdlPack> + scale_a_tile_tensor_pong; + + // ping pong buffer for scale B + statically_indexed_array< + statically_indexed_array, + NIterPerWarp / NXdlPack> + scale_b_dram_windows; + statically_indexed_array, + NIterPerWarp / NXdlPack> + scale_b_tile_tensor_ping; + statically_indexed_array, + NIterPerWarp / NXdlPack> + scale_b_tile_tensor_pong; + + // HEAD + // Prefetch A0 + auto a_block_tile = load_tile(a_copy_dram_window); + // move A window to next k + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + + // prefetch B + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + auto packed_n_idx = nIter / number{}; + auto packed_n_rank = nIter % number{}; + + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank, + kIter * KFlatPerBlockPerIter}); + + ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter)); + b_warp_tensor_ping(nIter)(kIter) = ub.u; + }); + }); + // move B window to next flat K + move_tile_window(b_flat_dram_window, {0, KIterPerWarp * KFlatPerBlockPerIter}); + + // prefetch Scale A + 
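// One E8M0 scale covers a 32-element K block; KXdlPack x MXdlPack scales are fetched
+        // together as a single packed int32, which is why the scale windows below advance
+        // by kKPerBlock / (32 * KXdlPack) per K step.
+ 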
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { + static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { + scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window; + move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack), + {mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)}); + + scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) = + load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack)); + }); + }); + // move Scale A window to next K + move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)}); + + // prefetch Scale B + static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { + static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { + scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window; + move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack), + {nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)}); + + scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) = + load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack)); + }); + }); + // move Scale B window to next K + move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)}); + + // A_Lds_TileDist may differ with ADramTileDistribution + auto a_block_tile_transformed = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window_ping, a_block_tile_transformed); + + __builtin_amdgcn_sched_barrier(0); + + // Prefetch A1 + a_block_tile = load_tile(a_copy_dram_window); + // move A window to next k + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + + // initialize C + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + block_sync_lds(); + + using MXFP4_A_Buffer_ping = + decltype(load_tile(a_warp_windows_ping(number<0>{})(number<0>{}))); + // use v4i32 as the data type between basicblock to avoid unpack and repack operation. + using V4UInt_A_Buffer = thread_buffer; + union UnionBuf_A_ping + { + V4UInt_A_Buffer u = 0; + MXFP4_A_Buffer_ping mxfp4; + } ua_ping; + + using MXFP4_A_Buffer_pong = + decltype(load_tile(a_warp_windows_pong(number<0>{})(number<0>{}))); + union UnionBuf_A_pong + { + V4UInt_A_Buffer u = 0; + MXFP4_A_Buffer_pong mxfp4; + } ua_pong; + + // preload A00,A10... 
from lds + statically_indexed_array a_warp_tensor; + + static_for<0, m_preload, 1>{}([&](auto loadIter) { + constexpr auto mIter = loadIter % MXdlPack; + constexpr auto kIter = loadIter / MXdlPack; + + ua_ping.mxfp4 = load_tile(a_warp_windows_ping(number{})(number{})); + a_warp_tensor(loadIter) = ua_ping.u; + }); + __builtin_amdgcn_sched_barrier(0); + + // MAIN LOOP + index_t iCounter = (num_loop - 1) / 2; + while(iCounter > 0) + { + // prefetch B(2i+1) + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + auto packed_n_idx = nIter / number{}; + auto packed_n_rank = nIter % number{}; + + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + move_tile_window( + b_flat_dram_windows(nIter)(kIter), + {packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank, + kIter * KFlatPerBlockPerIter}); + + ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter)); + b_warp_tensor_pong(nIter)(kIter) = ub.u; + }); + }); + + // prefetch Scale A and Scale B (2i+1) + static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { + static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { + scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window; + move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack), + {mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)}); + + scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) = + load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack)); + }); + }); + + static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { + static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { + scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window; + move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack), + {nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)}); + + scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) = + load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack)); + }); + }); + + // Prefill A(2i+1) + a_block_tile_transformed = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window_pong, a_block_tile_transformed); + + // Prefetch A(2i+2) + a_block_tile = load_tile(a_copy_dram_window); + // move A window to next k + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + + // GEMM 2i + static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { + static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { + static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack; + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = + c_block_tile.get_y_sliced_thread_data( + merge_sequences( + sequence{}, + c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + UnionBuf_A_ping ua_compute; + ua_compute.u = a_warp_tensor(number{}); + + UnionBuf ub_compute; + ub_compute.u = + b_warp_tensor_ping(nIter_pack * number{} + inxdl)( + kIter_pack * number{} + ikxdl); + // warp GEMM + WG{}.template + operator()( + c_warp_tensor, + ua_compute.mxfp4, + ub_compute.mxfp4, + scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) + .get_thread_buffer()[0], + scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) + .get_thread_buffer()[0]); + + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, + 
c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + // preload next A from lds + constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 + + (kIter_pack * KXdlPack + ikxdl) * 2 + + (mIter_pack * MXdlPack + imxdl) / 2 * 4 + + m_preload; + if constexpr(addr < (KIterPerWarp * MIterPerWarp) && + (nIter_pack == NIterPerWarp / NXdlPack - 1)) + { + constexpr auto AmIter = addr % 2 + addr / 4 * 2; + constexpr auto AkIter = addr / 2 % 2; + ua_ping.mxfp4 = load_tile( + a_warp_windows_ping(number{})(number{})); + a_warp_tensor(number{}) = ua_ping.u; + } + + // barrier + if constexpr(kIter_pack * KXdlPack + ikxdl == KIterPerWarp - 1 && + mIter_pack * MXdlPack + imxdl == MIter_2nd_last) + { + block_sync_lds(); + } + }); + }); + }); + }); + }); + + // move B window to next flat K + move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); + move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)}); + move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)}); + + static_for<0, m_preload, 1>{}([&](auto loadIter) { + constexpr auto mIter = loadIter % MXdlPack; + constexpr auto kIter = loadIter / MXdlPack; + ua_pong.mxfp4 = load_tile(a_warp_windows_pong(number{})(number{})); + a_warp_tensor(loadIter) = ua_pong.u; // reload a_warp_tensor with pong buffer + }); + HotLoopScheduler(); + + // Next K + + // prefetch B(2i+2) + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + auto packed_n_idx = nIter / number{}; + auto packed_n_rank = nIter % number{}; + + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + move_tile_window( + b_flat_dram_windows(nIter)(kIter), + {packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank, + kIter * KFlatPerBlockPerIter}); + + ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter)); + b_warp_tensor_ping(nIter)(kIter) = ub.u; + }); + }); + + // prefetch Scale A and Scale B (2i+2) + static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { + static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { + scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window; + move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack), + {mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)}); + + scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) = + load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack)); + }); + }); + + static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { + static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { + scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window; + move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack), + {nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)}); + + scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) = + load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack)); + }); + }); + + // Prefill A(2i+2) + a_block_tile_transformed = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window_ping, a_block_tile_transformed); + + // Prefetch A(2i+3) + a_block_tile = load_tile(a_copy_dram_window); + // move A window to next k + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + + // GEMM 2i+1 + static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { + static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { + static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 
1>{}([&](auto imxdl) { + constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack; + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = + c_block_tile.get_y_sliced_thread_data( + merge_sequences( + sequence{}, + c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + UnionBuf_A_pong ua_compute; + ua_compute.u = a_warp_tensor(number{}); + + UnionBuf ub_compute; + ub_compute.u = + b_warp_tensor_pong(nIter_pack * number{} + inxdl)( + kIter_pack * number{} + ikxdl); + + // warp GEMM + WG{}.template + operator()( + c_warp_tensor, + ua_compute.mxfp4, + ub_compute.mxfp4, + scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) + .get_thread_buffer()[0], // scale A + scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) + .get_thread_buffer()[0]); // scale B + + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, + c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + // preload next A from lds + constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 + + (kIter_pack * KXdlPack + ikxdl) * 2 + + (mIter_pack * MXdlPack + imxdl) / 2 * 4 + + m_preload; + if constexpr(addr < (KIterPerWarp * MIterPerWarp) && + (nIter_pack == NIterPerWarp / NXdlPack - 1)) + { + constexpr auto AmIter = addr % 2 + addr / 4 * 2; + constexpr auto AkIter = addr / 2 % 2; + ua_pong.mxfp4 = load_tile( + a_warp_windows_pong(number{})(number{})); + a_warp_tensor(number{}) = ua_pong.u; + } + + // barrier + if constexpr(kIter_pack * KXdlPack + ikxdl == KIterPerWarp - 1 && + mIter_pack * MXdlPack + imxdl == MIter_2nd_last) + { + block_sync_lds(); + } + }); + }); + }); + }); + }); + + // move B window to next flat K + move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); + move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)}); + move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)}); + + static_for<0, m_preload, 1>{}([&](auto loadIter) { + constexpr auto mIter = loadIter % MXdlPack; + constexpr auto kIter = loadIter / MXdlPack; + ua_ping.mxfp4 = load_tile(a_warp_windows_ping(number{})(number{})); + a_warp_tensor(loadIter) = ua_ping.u; // reload a_warp_tensor with ping buffer + }); + HotLoopScheduler(); + + iCounter--; + } + + // TAIL + if constexpr(TailNum == TailNumber::Even) + { + // prefetch B(loopK) + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + auto packed_n_idx = nIter / number{}; + auto packed_n_rank = nIter % number{}; + + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + + move_tile_window( + b_flat_dram_windows(nIter)(kIter), + {packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank, + kIter * KFlatPerBlockPerIter}); + + ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter)); + b_warp_tensor_pong(nIter)(kIter) = ub.u; + }); + }); + + // prefetch Scale A and Scale B (2i+1) + static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { + static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { + scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window; + move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack), + {mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)}); + + scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) = + load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack)); + }); + }); + + static_for<0, 
NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { + static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { + scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window; + move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack), + {nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)}); + + scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) = + load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack)); + }); + }); + + // Prefill A(loopK) + a_block_tile_transformed = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window_pong, a_block_tile_transformed); + + // GEMM loopK-1 + static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { + static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { + static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack; + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = + c_block_tile.get_y_sliced_thread_data( + merge_sequences( + sequence{}, + c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + UnionBuf_A_ping ua_compute; + ua_compute.u = a_warp_tensor(number{}); + + UnionBuf ub_compute; + ub_compute.u = + b_warp_tensor_ping(nIter_pack * number{} + inxdl)( + kIter_pack * number{} + ikxdl); + + // warp GEMM + WG{}.template + operator()( + c_warp_tensor, + ua_compute.mxfp4, + ub_compute.mxfp4, + scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) + .get_thread_buffer()[0], // scale A + scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) + .get_thread_buffer()[0]); // scale B + + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, + c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + // preload next A from lds + constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 + + (kIter_pack * KXdlPack + ikxdl) * 2 + + (mIter_pack * MXdlPack + imxdl) / 2 * 4 + + m_preload; + if constexpr(addr < (KIterPerWarp * MIterPerWarp) && + (nIter_pack == NIterPerWarp / NXdlPack - 1)) + { + constexpr auto AmIter = addr % 2 + addr / 4 * 2; + constexpr auto AkIter = addr / 2 % 2; + ua_ping.mxfp4 = load_tile( + a_warp_windows_ping(number{})(number{})); + a_warp_tensor(number{}) = ua_ping.u; + } + + // barrier + if constexpr(kIter_pack * KXdlPack + ikxdl == KIterPerWarp - 1 && + mIter_pack * MXdlPack + imxdl == MIter_2nd_last) + { + block_sync_lds(); + } + }); + }); + }); + }); + }); + + static_for<0, m_preload, 1>{}([&](auto loadIter) { + constexpr auto mIter = loadIter % MXdlPack; + constexpr auto kIter = loadIter / MXdlPack; + ua_pong.mxfp4 = load_tile(a_warp_windows_pong(number{})(number{})); + a_warp_tensor(loadIter) = ua_pong.u; // reload a_warp_tensor with pong buffer + }); + + Last2ndHotLoopScheduler(); + + // GEMM loopK + static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { + static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { + static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack; + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + // read C warp tensor from C block tensor + CWarpTensor 
c_warp_tensor; + c_warp_tensor.get_thread_buffer() = + c_block_tile.get_y_sliced_thread_data( + merge_sequences( + sequence{}, + c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + UnionBuf_A_pong ua_compute; + ua_compute.u = a_warp_tensor(number{}); + + UnionBuf ub_compute; + ub_compute.u = + b_warp_tensor_pong(nIter_pack * number{} + inxdl)( + kIter_pack * number{} + ikxdl); + // warp GEMM + WG{}.template + operator()( + c_warp_tensor, + ua_compute.mxfp4, + ub_compute.mxfp4, + scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) + .get_thread_buffer()[0], // scale A + scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) + .get_thread_buffer()[0]); // scale B + + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, + c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + // preload next A from lds + constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 + + (kIter_pack * KXdlPack + ikxdl) * 2 + + (mIter_pack * MXdlPack + imxdl) / 2 * 4 + + m_preload; + if constexpr(addr < (KIterPerWarp * MIterPerWarp) && + (nIter_pack == NIterPerWarp / NXdlPack - 1)) + { + constexpr auto AmIter = addr % 2 + addr / 4 * 2; + constexpr auto AkIter = addr / 2 % 2; + ua_pong.mxfp4 = load_tile( + a_warp_windows_pong(number{})(number{})); + a_warp_tensor(number{}) = ua_pong.u; + } + + // barrier + if constexpr(kIter_pack * KXdlPack + ikxdl == KIterPerWarp - 1 && + mIter_pack * MXdlPack + imxdl == MIter_2nd_last) + { + block_sync_lds(); + } + }); + }); + }); + }); + }); + LastHotLoopScheduler(); + } + else if constexpr(TailNum == TailNumber::Odd) + { + // GEMM loopK + static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { + static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { + static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack; + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = + c_block_tile.get_y_sliced_thread_data( + merge_sequences( + sequence{}, + c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + UnionBuf_A_ping ua_compute; + ua_compute.u = a_warp_tensor(number{}); + + UnionBuf ub_compute; + ub_compute.u = + b_warp_tensor_ping(nIter_pack * number{} + inxdl)( + kIter_pack * number{} + ikxdl); + + // warp GEMM + WG{}.template + operator()( + c_warp_tensor, + ua_compute.mxfp4, + ub_compute.mxfp4, + scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) + .get_thread_buffer()[0], // scale A + scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) + .get_thread_buffer()[0]); // scale B + + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, + c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + // preload next A from lds + constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 + + (kIter_pack * KXdlPack + ikxdl) * 2 + + (mIter_pack * MXdlPack + imxdl) / 2 * 4 + + m_preload; + if constexpr(addr < (KIterPerWarp * MIterPerWarp) && + (nIter_pack == NIterPerWarp / NXdlPack - 1)) + { + constexpr auto AmIter = addr % 2 + addr / 4 * 2; + constexpr auto AkIter = addr / 2 % 2; + ua_ping.mxfp4 = load_tile( + 
a_warp_windows_ping(number{})(number{})); + a_warp_tensor(number{}) = ua_ping.u; + } + + // barrier + if constexpr(kIter_pack * KXdlPack + ikxdl == KIterPerWarp - 1 && + mIter_pack * MXdlPack + imxdl == MIter_2nd_last) + { + block_sync_lds(); + } + }); + }); + }); + }); + }); + LastHotLoopScheduler(); + } + + return c_block_tile; + } + + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, + const ScaleADramBlockWindowTmp& scale_a_flat_window_tmp, + const ScaleBDramBlockWindowTmp& scale_b_flat_window_tmp, + index_t num_loop, + void* p_smem_ping, + void* p_smem_pong) const + { + return operator()( + a_dram_block_window_tmp, + [](const ADataType & a) { return a; }, + b_flat_dram_block_window_tmp, + scale_a_flat_window_tmp, + scale_b_flat_window_tmp, + num_loop, + p_smem_ping, + p_smem_pong); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp new file mode 100644 index 0000000000..f3fc5e9fef --- /dev/null +++ b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -0,0 +1,275 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp" + +namespace ck_tile { + +struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy +{ + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + static constexpr auto I2 = number<2>{}; + + static constexpr index_t KBPerLoad = 32; + + static constexpr int MXdlPack = 2; + static constexpr int NXdlPack = 2; + static constexpr int KXdlPack = 2; + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ALdsBlockDescriptor() + { + using namespace ck_tile; + + using ADataType = remove_cvref_t; + using ALayout = remove_cvref_t; + constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0); + constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1); + + static_assert(MPerXdl == 16 && NPerXdl == 16); + static_assert(std::is_same_v); + + /*reduce transform layers,compare with old ck*/ + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t APackedSize = numeric_traits::PackedSize; + constexpr index_t KPack = GetSmemPackA() * APackedSize; + + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc_0, + make_tuple( + make_xor_transform(make_tuple(number{}, number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); + + constexpr auto a_lds_block_desc = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_pass_through_transform(number{}), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + // return a_lds_block_desc_permuted; + return a_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto 
MakeMXFP4_ADramTileDistribution() + { + using ADataType = remove_cvref_t; + + constexpr index_t BlockSize = Problem::kBlockSize; + + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType); + constexpr index_t K0 = KPerBlock / K1; + constexpr index_t M2 = get_warp_size() / K0; + + constexpr index_t M1 = BlockSize / get_warp_size(); + static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error."); + static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error."); + constexpr index_t M0 = MPerBlock / (M2 * M1); + static_assert(M0 * M1 * M2 == MPerBlock, + "Incorrect M0, M2, M1 configuration! " + "M0, M1, M2 must cover whole MPerBlock!"); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeMXF4_ALDS_TileDistribution() + { + using TileShape = typename Problem::BlockGemmShape; + + static_assert(TileShape::WarpTile::at(I1) == 16, "requires XDL_N == 16"); + static_assert(TileShape::BlockWarps::at(I0) == 1, "requires Wave_M == 1"); + + constexpr int M_warps = TileShape::BlockWarps::at(number<0>{}); + constexpr int N_warps = TileShape::BlockWarps::at(number<1>{}); + constexpr int M_Lane = TileShape::WarpTile::at(I0); + + constexpr int K_Lane = 64 / TileShape::WarpTile::at(I0); // 4 + + constexpr int K1 = TileShape::WarpTile::at(I2) / K_Lane; // 32 + + return make_static_tile_distribution( + tile_distribution_encoding< + sequence, + tuple, sequence>, + tuple, sequence<2, 1>>, + tuple, sequence<0, 2>>, + sequence<2>, + sequence<1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_BFlatDramTileDistribution() + { + using TileShape = typename Problem::BlockGemmShape; + + static_assert(TileShape::WarpTile::at(I1) == 16, "only for XDL_N == 16"); + + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t WaveSize = get_warp_size(); + constexpr index_t WaveNum = BlockSize / WaveSize; + + constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim + constexpr index_t KWavePerBlk = 1; + + constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp + + constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp; + + return make_static_tile_distribution( + tile_distribution_encoding< + sequence, + tuple, + sequence>, // first direction + // wave in blk, // thd in wave + // // + tuple, sequence<2>>, // which direction + tuple, sequence<1>>, // which index + // + sequence<2>, + sequence<2>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleA_DramTileDistribution() + { + using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape + + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t WaveSize = get_warp_size(); + constexpr index_t WaveNum = BlockSize / WaveSize; + + constexpr index_t kMPerBlock = TileShape::BlockTile::at(I0); + + constexpr index_t M_Warps = TileShape::BlockWarps::at(I0); + constexpr index_t N_Warps = TileShape::BlockWarps::at(I1); + + static_assert(WaveNum == M_Warps * N_Warps, "Block warps do not match block size"); + + constexpr index_t M_Lanes = TileShape::WarpTile::at(I0); + constexpr index_t K_Lanes = 64 / M_Lanes; + + // Y dimension (M) decomposition + constexpr index_t Y2 = M_Lanes; + 
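// Y2 = lanes along M within one wave, Y1 = waves along M, and Y0 (below) is the
+        // remaining per-thread M repeat once MXdlPack packing is accounted for.
+ 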
constexpr index_t Y1 = M_Warps; + constexpr index_t Y0 = kMPerBlock / (MXdlPack * Y1 * Y2); + + // X dimension (K) decomposition + constexpr index_t X0 = K_Lanes; + constexpr index_t X1 = 1; // packed 2x2 E8M0 data into 1 int32_t for load + + return make_static_tile_distribution( + tile_distribution_encoding, // repeat N_warps + tuple, sequence>, + tuple, sequence<2, 1>>, + tuple, sequence<0, 2>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleB_DramTileDistribution() + { + using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape + + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t WaveSize = get_warp_size(); + constexpr index_t WaveNum = BlockSize / WaveSize; + + constexpr index_t kNPerBlock = TileShape::BlockTile::at(I1); + + constexpr index_t M_Warps = TileShape::BlockWarps::at(I0); + constexpr index_t N_Warps = TileShape::BlockWarps::at(I1); + + static_assert(WaveNum == M_Warps * N_Warps, "Block warps do not match block size"); + + constexpr index_t N_Lanes = TileShape::WarpTile::at(I1); + constexpr index_t K_Lanes = 64 / N_Lanes; + + // Y dimension (M) decomposition + constexpr index_t Y2 = N_Lanes; + constexpr index_t Y1 = N_Warps; + constexpr index_t Y0 = kNPerBlock / (NXdlPack * Y1 * Y2); + + // X dimension (K) decomposition + constexpr index_t X0 = K_Lanes; + constexpr index_t X1 = 1; // packed 2x2 E8M0 data into 1 int32_t for load + + return make_static_tile_distribution( + tile_distribution_encoding, // ? + tuple, sequence>, + tuple, sequence<2, 1>>, + tuple, sequence<0, 2>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleA_FlatDramTileDistribution() + { + using TileShape = typename Problem::BlockGemmShape; + + constexpr index_t M_Warp = TileShape::BlockWarps::at(number<0>{}); + constexpr index_t K_Lane = 64 / TileShape::WarpTile::at(I0); + constexpr index_t M_Lane = TileShape::WarpTile::at(I0); + constexpr index_t N_Wrap = TileShape::BlockWarps::at(number<1>{}); + constexpr index_t MWavePerBlk = M_Warp; + + return make_static_tile_distribution( + tile_distribution_encoding, // ? + tuple, // second direction + sequence>, // first direction + tuple, sequence<2, 1>>, // which direction + tuple, sequence<0, 1>>, // which index + // + sequence<2>, + sequence<1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleB_FlatDramTileDistribution() + { + using TileShape = typename Problem::BlockGemmShape; + + constexpr index_t N_Warp = TileShape::BlockWarps::at(number<1>{}); + constexpr index_t K_Lane = 64 / TileShape::WarpTile::at(I1); + constexpr index_t N_Lane = TileShape::WarpTile::at(I1); + constexpr index_t M_Wrap = TileShape::BlockWarps::at(number<0>{}); + constexpr index_t NWavePerBlk = N_Warp; + + return make_static_tile_distribution( + tile_distribution_encoding, // ? 
+ tuple, // second direction + sequence>, // first direction + tuple, sequence<2, 1>>, // which direction + tuple, sequence<0, 1>>, // which index + // + sequence<2>, + sequence<1>>{}); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index 04d36cf0ea..a60270b578 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -146,6 +146,9 @@ template<> struct WarpGemmDispatcher struct WarpGemmDispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8; }; +template<> struct WarpGemmDispatcher { + using Type = WarpGemmMfma_f32_16x16x128_fp4; }; + //WMMA cases template struct WarpGemmDispatcher { using Type =WarpGemmWmma_f32_16x16x16_f8_f8; }; template struct WarpGemmDispatcher { using Type =WarpGemmWmma_f32_16x16x16_bf8_bf8; };
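Note on the union buffers used throughout the pipeline above: packed MXFP4 fragments are reinterpreted as v4i32 thread buffers so they can be moved between pipeline stages without per-element unpack/repack. The stand-alone sketch below illustrates only that idea under that assumption; all type and function names here are illustrative placeholders, not CK tile APIs.

#include <cstdint>
#include <cstdio>

// Illustrative stand-in for a thread-local tile of packed MXFP4 data:
// two fp4 values per byte, 16 bytes total (the footprint of four 32-bit words).
struct PackedFp4Tile
{
    uint8_t bytes[16];
};

// The "v4i32" view: four 32-bit words covering the same 16 bytes.
struct V4U32
{
    uint32_t w[4];
};

// Union view of the same storage; copies through the integer member stay
// element-agnostic, so no fp4 values are unpacked or repacked on the way.
// (Reading the other member relies on the usual compiler support for
// union-based type punning, as the kernel code does.)
union TileUnion
{
    V4U32 u;
    PackedFp4Tile fp4;
};

int main()
{
    TileUnion src{};
    for(int i = 0; i < 16; ++i)
        src.fp4.bytes[i] = static_cast<uint8_t>(i); // pretend this came from a tile load

    TileUnion dst{};
    dst.u = src.u; // bulk move as four uint32 words, not 32 fp4 elements

    std::printf("byte[5] after v4i32 copy: %u\n",
                static_cast<unsigned>(dst.fp4.bytes[5])); // prints 5
    return 0;
}

Keeping the inter-stage type a fixed-width integer vector lets the fragment live in plain registers and be copied with full-width moves, which matches the "avoid unpack and repack" intent stated in the pipeline's own comments.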