From 79f48ee15a82dd5fad5cd9beaa393c1f755e6b55 Mon Sep 17 00:00:00 2001
From: yukuai26 <93142162+yukuai26@users.noreply.github.com>
Date: Fri, 12 Sep 2025 17:12:27 +0800
Subject: [PATCH 1/2] Fix multicast bug and optimize masked GEMM (#193)

* Fix multicast bug and profile masked GEMM

* Updates and lint

---------

Co-authored-by: Kuai Yu
Co-authored-by: Chenggang Zhao
---
 csrc/jit_kernels/heuristics/common.hpp | 9 ++++++---
 csrc/jit_kernels/heuristics/sm100.hpp  | 2 +-
 csrc/jit_kernels/heuristics/sm90.hpp   | 4 +++-
 3 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/csrc/jit_kernels/heuristics/common.hpp b/csrc/jit_kernels/heuristics/common.hpp
index 3ed4d2a1..681e6546 100644
--- a/csrc/jit_kernels/heuristics/common.hpp
+++ b/csrc/jit_kernels/heuristics/common.hpp
@@ -152,8 +152,11 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k

     // Select M/N block sizes
     // TODO: support `% 16 == 8` block size on SM90
-    const auto& block_ms = gemm_type == GemmType::MGroupedContiguous ?
-        std::vector{get_mk_alignment_for_contiguous_layout()} : std::vector{64, 128, 256};
+    auto block_ms = std::vector{64, 128, 256};
+    if (gemm_type == GemmType::MGroupedContiguous)
+        block_ms = std::vector{get_mk_alignment_for_contiguous_layout()};
+    if (gemm_type == GemmType::MGroupedMasked) // Exclude 256 for performance
+        block_ms = std::vector{64, 128};
     std::vector<int> block_ns;
     for (int i = 16; i <= 256; i += 16)
         block_ns.push_back(i);
@@ -214,7 +217,7 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
     MulticastConfig best_multicast_config = {1, true};
     const auto& [is_legal_on_a, is_legal_on_b] = ArchSpec::get_multicast_legality(
         gemm_type, m, n, best_block_m, best_block_n, num_sms);
-    const bool is_legal[2] = {is_legal_on_a, is_legal_on_b};
+    const bool is_legal[2] = {is_legal_on_b, is_legal_on_a};
     bool order[2] = {false, true};
     if (best_block_m > best_block_n)
         std::swap(order[0], order[1]);
diff --git a/csrc/jit_kernels/heuristics/sm100.hpp b/csrc/jit_kernels/heuristics/sm100.hpp
index 4e582891..0679cad2 100644
--- a/csrc/jit_kernels/heuristics/sm100.hpp
+++ b/csrc/jit_kernels/heuristics/sm100.hpp
@@ -91,8 +91,8 @@ struct SM100ArchSpec {
                                                   const int& num_sms) {
         // TODO: support other layouts
         return {
-            is_multicast_legal(m, block_m, 2, num_sms, true) and (gemm_type == GemmType::Normal or gemm_type == GemmType::KGroupedContiguous),
             false,
+            is_multicast_legal(m, block_m, 2, num_sms, true) and (gemm_type == GemmType::Normal or gemm_type == GemmType::KGroupedContiguous),
         };
     }

diff --git a/csrc/jit_kernels/heuristics/sm90.hpp b/csrc/jit_kernels/heuristics/sm90.hpp
index 16ca018c..58faecf0 100644
--- a/csrc/jit_kernels/heuristics/sm90.hpp
+++ b/csrc/jit_kernels/heuristics/sm90.hpp
@@ -71,7 +71,9 @@ struct SM90ArchSpec {
                                                   const int& num_sms) {
         return {
             is_multicast_legal(n, block_n, 2, num_sms, gemm_type == GemmType::MGroupedMasked),
-            is_multicast_legal(m, block_m, 2, num_sms, false) and gemm_type != GemmType::MGroupedMasked,
+            // For masked GEMM layout, divisibility on N is also required as we must ensure the total number of blocks is even
+            is_multicast_legal(m, block_m, 2, num_sms, false)
+                and (gemm_type != GemmType::MGroupedMasked or is_multicast_legal(n, block_n, 2, num_sms, true))
         };
     }

From 2991c774d0ae7ee066622bcbec27f7a71ff31ef3 Mon Sep 17 00:00:00 2001
From: root
Date: Thu, 11 Sep 2025 22:14:35 +0800
Subject: [PATCH 2/2] support swapAB for m_grouped_fp8_gemm_nt_masked

---
 csrc/jit_kernels/heuristics/common.hpp        |   6 +
 csrc/jit_kernels/heuristics/sm90.hpp          |  21 +-
 csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp |   9 +-
 .../include/deep_gemm/common/sm90_utils.cuh   |  11 +
 deep_gemm/include/deep_gemm/common/utils.cuh  |   6 +
 .../deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh    | 407 ++++++++++++++++++
 tests/generators.py                           |   2 +-
 7 files changed, 456 insertions(+), 6 deletions(-)

diff --git a/csrc/jit_kernels/heuristics/common.hpp b/csrc/jit_kernels/heuristics/common.hpp
index 681e6546..7169c02b 100644
--- a/csrc/jit_kernels/heuristics/common.hpp
+++ b/csrc/jit_kernels/heuristics/common.hpp
@@ -157,9 +157,15 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
         block_ms = std::vector{get_mk_alignment_for_contiguous_layout()};
     if (gemm_type == GemmType::MGroupedMasked) // Exclude 256 for performance
         block_ms = std::vector{64, 128};
+
     std::vector<int> block_ns;
     for (int i = 16; i <= 256; i += 16)
         block_ns.push_back(i);
+    if (get_env("ENABLE_SWAPAB")) {
+        block_ms = std::vector{32};   // Candidate block M sizes: 32 or 64
+        block_ns = std::vector{256};  // 256 performs best on H20; 64, 128 and 256 are all valid
+    }
+
     // K block size is selected in a fixed manner
     const auto& block_k = 128 / static_cast<int>(c10::elementSize(ab_dtype));

diff --git a/csrc/jit_kernels/heuristics/sm90.hpp b/csrc/jit_kernels/heuristics/sm90.hpp
index 58faecf0..be266b78 100644
--- a/csrc/jit_kernels/heuristics/sm90.hpp
+++ b/csrc/jit_kernels/heuristics/sm90.hpp
@@ -42,9 +42,15 @@ struct SM90ArchSpec {
         // Too many scaling factors in a single block: `block_n > block_k and std::gcd(block_n, block_k) != block_n - block_k`
         // Or too many register spills
-        if (block_n > 128 and kernel_type == KernelType::Kernel1D2D and (block_n != 144 and block_n != 160 and block_n != 192))
-            return false;
+        if (get_env("ENABLE_SWAPAB")) {
+            if (block_n != 64 and block_n != 128 and block_n != 256)
+                return false;
+        } else {
+            if (block_n > 128 and kernel_type == KernelType::Kernel1D2D and (block_n != 144 and block_n != 160 and block_n != 192))
+                return false;
+        }
+
         // Avoid bank conflicts for FP32 output
         if (cd_dtype == torch::kFloat and block_n % 16 == 0)
             return false;

@@ -79,7 +85,13 @@ struct SM90ArchSpec {

     static ThreadConfig get_thread_config(const KernelType& kernel_type,
                                           const int& block_m, const int& block_n) {
-        return ThreadConfig::sm90(128, (block_m == 64 ? 1 : 2) * 128);
+        int tile = 64;
+        if (get_env("ENABLE_SWAPAB")) {
+            tile = block_n;
+        } else {
+            tile = block_m;
+        }
+        return ThreadConfig::sm90(128, (tile > 64 ? 2 : 1) * 128);
     }

     static int get_smem_cd_size(const KernelType& kernel_type,
@@ -104,7 +116,8 @@ struct SM90ArchSpec {

     static int get_extra_sfb_smem_size(const int& m, const int& n, const int& k,
                                        const int& block_m, const int& block_n, const int& block_k) {
-        const auto& use_uniform_sfb = block_k % block_n == 0 ? 1 : 2;
+        const auto& use_uniform_sfb = get_env("ENABLE_SWAPAB") ? (block_n / 64) : (block_k % block_n == 0 ? 1 : 2);
+
         return align(ceil_div(k, block_k) * static_cast<int>(sizeof(float)) * use_uniform_sfb, 8);
     }

diff --git a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp
index 3afc2d33..93c4adb1 100644
--- a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp
+++ b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp
@@ -29,13 +29,19 @@ class SM90FP8Gemm1D2DRuntime final: public LaunchRuntime
     };

     static std::string generate_impl(const Args& args) {
+
+        const char* kernel_name =
+            get_env("ENABLE_SWAPAB") ?
+ "swapAB_sm90_fp8_gemm_1d2d_impl" : + "sm90_fp8_gemm_1d2d_impl"; + return fmt::format(R"( #include using namespace deep_gemm; static void __instantiate_kernel() {{ - auto ptr = reinterpret_cast(&sm90_fp8_gemm_1d2d_impl< + auto ptr = reinterpret_cast(&{}< {}, {}, {}, {}, {}, {}, {}, @@ -47,6 +53,7 @@ static void __instantiate_kernel() {{ >); }}; )", + kernel_name, // TODO: add CD dtype get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims), args.num_groups, diff --git a/deep_gemm/include/deep_gemm/common/sm90_utils.cuh b/deep_gemm/include/deep_gemm/common/sm90_utils.cuh index e590b479..e54b9969 100644 --- a/deep_gemm/include/deep_gemm/common/sm90_utils.cuh +++ b/deep_gemm/include/deep_gemm/common/sm90_utils.cuh @@ -144,6 +144,17 @@ struct SM90_U32x2_STSM_N { } }; +template +struct SM90_U32x2_STSM_T +{ + __device__ __forceinline__ static void copy(dtype_t src_0, dtype_t src_1, void* smem_dst) + { + const uint32_t src[2] = {*reinterpret_cast(&src_0), *reinterpret_cast(&src_1)}; + asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16.trans [%0], {%1, %2};\n" ::"l"(smem_dst), "r"(src[0]), + "r"(src[1])); + } +}; + __forceinline__ __device__ void warpgroup_arrive() { asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory"); } diff --git a/deep_gemm/include/deep_gemm/common/utils.cuh b/deep_gemm/include/deep_gemm/common/utils.cuh index fc84b696..4f1b100a 100644 --- a/deep_gemm/include/deep_gemm/common/utils.cuh +++ b/deep_gemm/include/deep_gemm/common/utils.cuh @@ -122,6 +122,12 @@ __device__ __forceinline__ float ld_shared(const float* ptr) { return ret; } +__device__ __forceinline__ float2 ld_shared(const float2* __restrict__ ptr) { + float2 ret; + asm volatile("ld.shared.v2.f32 {%0, %1}, [%2];" : "=f"(ret.x), "=f"(ret.y) : "l"(ptr)); + return ret; +} + __device__ __forceinline__ void st_shared(const float* ptr, float val) { asm volatile("st.shared.f32 [%0], %1;" :: "l"(ptr), "f"(val)); } diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh index 5a65d69e..213166c4 100644 --- a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh @@ -436,6 +436,413 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, #endif } +template +__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void +swapAB_sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, + uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_d, + const __grid_constant__ cute::TmaDescriptor tensor_map_sfa) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) + // Scaling checks + DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling"); + DG_STATIC_ASSERT(BLOCK_N == 64 or BLOCK_N == 128 or BLOCK_N == 256 , "Only support BLOCK_N=64 or 128 or 256"); + + // Types + using WGMMA = typename FP8MMASelector::type; + using Barrier = cutlass::arch::ClusterTransactionBarrier; + DG_STATIC_ASSERT(BLOCK_M == WGMMA::N, "Invalid block size"); + + // Overwrite shape constants if the compiler gives + shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; + shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; + shape_k = SHAPE_K != 0 ? 
SHAPE_K : shape_k; + + constexpr uint32_t n_nums = (BLOCK_N+127)/128; + // Shared memory + static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(__nv_bfloat16); + static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float); + const uint32_t& shape_k_scales = ceil_div(shape_k, BLOCK_K); + const uint32_t& smem_sfb_size = align(shape_k_scales * n_nums * sizeof(float), sizeof(Barrier)); + + // Configs + constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K; + const uint32_t num_iterations = ceil_div(shape_k, kFullKOfAllStages); + const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const uint32_t lane_idx = get_lane_idx(); + + // Prefetch TMA descriptors at the very beginning + if (threadIdx.x == kNumMathThreads) { + // NOTES: `reinterpret_cast` must be here, or NVRTC will fail + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_sfa); + cute::prefetch_tma_descriptor(&tensor_map_d); + } + __syncwarp(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); + + // Data on shared memory + auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer); + __nv_fp8_e4m3* smem_a[kNumStages]; + __nv_fp8_e4m3* smem_b[kNumStages]; + float* smem_sfa[kNumStages]; + float* smem_sfb; + + // TMA Barrier for both divisible and non-divisible cases + Barrier* full_barriers[kNumStages]; + Barrier* empty_barriers[kNumStages]; + + // Fill shared memory pointers + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + smem_a[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE); + smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + smem_sfa[i] = reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + i * SMEM_SFA_SIZE_PER_STAGE); + } + smem_sfb = reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE)); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(reinterpret_cast(smem_sfb) + smem_sfb_size); + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i] = barrier_start_ptr + i; + empty_barriers[i] = barrier_start_ptr + kNumStages + i; + } + + // Initialize barriers + DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast"); + if (threadIdx.x == kNumMathThreads) { + // NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster, + // even with TMA multicast disabled, we want to make the behavior aligned + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_view_async_shared(); + cutlass::arch::fence_barrier_init(); + } + + // Synchronize all threads to make barrier visible in normal memory model + (kNumTMAMulticast > 1) ? 
cute::cluster_sync() : __syncthreads(); + + // For pipeline unrolling + struct DivisibleK {}; + struct NotDivisibleK {}; + struct SkipComputation {}; + struct NotSkipComputation {}; + auto launch_k_iterations = [=](const auto& func, bool skip_computation) { + constexpr uint32_t kGap = constexpr_gcd(BLOCK_K, BLOCK_N) / 8; + constexpr uint32_t kEnd = 0; + + // NOTES: for too-many branches (> 5), we disable this optimization + outer_launch_k_iterations<0, kGap, kEnd>([=](const auto& func, auto num_former_iters_type) { + if (skip_computation) { + for (uint32_t k_iter = 0; k_iter < num_iterations; ++ k_iter) + func(k_iter, DivisibleK{}, SkipComputation{}, num_former_iters_type); + } else if (shape_k % kFullKOfAllStages == 0) { + for (uint32_t k_iter = 0; k_iter < num_iterations; ++ k_iter) + func(k_iter, DivisibleK{}, NotSkipComputation{}, num_former_iters_type); + } else { + for (uint32_t k_iter = 0; k_iter < num_iterations - 1; ++ k_iter) + func(k_iter, DivisibleK{}, NotSkipComputation{}, num_former_iters_type); + func(num_iterations - 1, NotDivisibleK{}, NotSkipComputation{}, num_former_iters_type); + } + }, func, 0); + }; + + // Register reconfigurations + constexpr uint32_t kNumTMARegisters = 40; + constexpr uint32_t kNumMathRegisters = 232; + + // Block scheduler + uint32_t m_block_idx, n_block_idx; + auto scheduler = Scheduler(shape_m, shape_n, grouped_layout); + + if (threadIdx.x >= kNumMathThreads) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + + // NOTES: only one thread (or warp) will be used + if (threadIdx.x < kNumMathThreads + 32 and cute::elect_one_sync()) { + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + launch_k_iterations([&](uint32_t k_iter, auto divisible_type, auto _, auto __) { + constexpr bool kHasDivisibleStages = cute::is_same_v; + constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages; + + // Assign TMA multicast number into A and B + // NOTES: there may be additional odd rows/columns or cases where multicast is not possible. + const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx); + const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; + const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; + DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast"); + + // NOTES: unrolling and `kNumInnerStages` are vital for performance, NVCC will try to eliminate all + // shared memory pointers, e.g. 
`full_barriers` registers, if all the access indices are constant + #pragma unroll + for (uint32_t s = 0; s < kNumInnerStages; ++ s) { + // Wait consumer release + empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1); + + // Issue TMA A + constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked; + auto& full_barrier = *full_barriers[s]; + uint32_t k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K; + tma_copy(&tensor_map_a, reinterpret_cast(&full_barrier), + smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx), + num_tma_multicast_a); + tma_copy(&tensor_map_sfa, reinterpret_cast(&full_barrier), + smem_sfa[s], m_block_idx * BLOCK_M, + scheduler.get_global_idx(shape_k_scales, 1, k_idx / BLOCK_K), + num_tma_multicast_a); + + // Issue TMA B + tma_copy(&tensor_map_b, reinterpret_cast(&full_barrier), + smem_b[s], k_idx, scheduler.get_global_idx(shape_n, BLOCK_N, n_block_idx, m_block_idx), + num_tma_multicast_b); + full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE); + } + + // Wait unaligned cases + #pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { + empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1); + full_barriers[s]->arrive(); + } + }, false); + } + + // To safely deconstruct distributed shared barriers, we need another round of empty waits + if constexpr (kNumTMAMulticast > 1) { + #pragma unroll + for (uint32_t s = 0; s < kNumStages; ++ s) + empty_barriers[s]->wait((scheduler.current_iter * num_iterations + 1) & 1); + } + } + } else { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0); + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Decide the number of scales B to load + DG_TRAP_ONLY_DEVICE_ASSERT(shape_n % 8 == 0); + uint32_t num_sfb = shape_k_scales * n_nums; + + // Load B scales with math warp-groups + // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks + if (threadIdx.x >= 32) { + auto num_previous_lines = scheduler.get_global_idx(ceil_div(shape_n, BLOCK_K), 0, 0, m_block_idx); + auto local_sfb = sfb + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * shape_k_scales; + #pragma unroll + for (uint32_t i = threadIdx.x - 32; i < num_sfb; i += kNumMathThreads - 32) + st_shared(smem_sfb + i, __ldg(local_sfb + i)); + } + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + + // Accumulation for WGMMA or CUDA promotion + float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum * n_nums] = {0}; + + // Empty barrier arrival + auto empty_barrier_arrive = [&](uint32_t s) { + if constexpr (kNumTMAMulticast == 1) { + lane_idx == 0 ? empty_barriers[s]->arrive() : void(); + } else { + auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster(); + lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta) : void(); + } + }; + + constexpr uint32_t WGMMA_M_PER_WARP = WGMMA::M / 4; + // Launch MMAs + launch_k_iterations([&](uint32_t k_iter, auto divisible_type, auto skip_type, auto _) { + constexpr bool kSkipComputation = cute::is_same_v; + constexpr bool kHasDivisibleStages = cute::is_same_v; + constexpr uint32_t kNumInnerStages = kSkipComputation ? 0 : (kHasDivisibleStages ? 
kNumStages : kNumLastStages); + + #pragma unroll + for (uint32_t s = 0; s < kNumInnerStages; ++ s) { + // NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks + + // Wait TMA arrivals + full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1); + + #pragma unroll + for (uint32_t local_idx = 0; local_idx < n_nums; ++ local_idx) { + auto n_offset = local_idx * 128; + // Read B scales + float scale_b = ld_shared(smem_sfb + local_idx * shape_k_scales + k_iter * kNumStages + s); + + // Commit WGMMA instructions + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + auto desc_a = make_smem_desc(smem_a[s] + k * WGMMA::K, 1); + auto desc_b = make_smem_desc(smem_b[s] + (math_wg_idx * WGMMA::M + n_offset) * BLOCK_K + k * WGMMA::K, 1); + WGMMA::wgmma(desc_b, desc_a, accum, k); + } + warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Notify barrier arrival at the last warpgroup wave + if (local_idx == n_nums - 1) + empty_barrier_arrive(s); + + // Promote with scales + // NOTES: making it as predicates is very important for performance, comparing to two loops + auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx; + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + int in_blk_c = (lane_idx & 3) * 2; + int col_base = (i << 3) + in_blk_c; + float2 sa = ld_shared(reinterpret_cast(smem_sfa[s] + col_base)); + float scale_0_0 = sa.x * scale_b, scale_1_0 = sa.y * scale_b; + shifted_accum[i * 4 + 0] += (scale_0_0) * accum[i * 4 + 0]; + shifted_accum[i * 4 + 1] += (scale_1_0) * accum[i * 4 + 1]; + shifted_accum[i * 4 + 2] += (scale_0_0) * accum[i * 4 + 2]; + shifted_accum[i * 4 + 3] += (scale_1_0) * accum[i * 4 + 3]; + } + } + } + + // Wait unaligned cases + #pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) { + full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1); + empty_barrier_arrive(s); + } + }, not scheduler.is_computation_valid(m_block_idx, 0)); + + // TMA checks + constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16); + constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode == 0 ? 
BLOCK_N : (kSwizzleDMode / kNumElemBytes); + DG_STATIC_ASSERT(BLOCK_M % 8 == 0, "Invalid swizzling atom"); + DG_STATIC_ASSERT(BLOCK_N % TMA_D_BLOCK_N == 0 and BLOCK_N / TMA_D_BLOCK_N <= 32, + "Unaligned TMA store or too many TMA store instructions"); + DG_STATIC_ASSERT(TMA_D_BLOCK_N % 8 == 0, "Invalid TMA block N"); + + // Wait last TMA store to be finished + if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) + cute::tma_store_wait<0>(); + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + + // Write back to shared memory using STSM and issue TMA stores + DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); + + + #pragma unroll + for (uint32_t local_idx = 0; local_idx < n_nums; ++ local_idx) { + auto n_offset = local_idx * 128; + auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx; + + #pragma unroll + for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) { // WGMMA::M * WGMMA::N / 128 / 4 + // Swizzle or padding into the correct address + uint8_t* smem_ptr = nullptr; + if constexpr (kSwizzleDMode > 0) { + uint32_t src_row = warp_idx * WGMMA_M_PER_WARP + lane_idx; + uint32_t src_col8 = i * 8; + uint32_t blk_r = src_row >> 3; + uint32_t blk_c = src_col8 >> 3; + uint32_t in_r = src_row & 7; + uint32_t in_c = src_col8 & 7; + uint32_t dst_row = blk_c * 8 + in_r; + uint32_t dst_col8 = blk_r * 8 + in_c; + dst_col8 += n_offset; + uint32_t warp_idx_tr = dst_row / WGMMA_M_PER_WARP; + uint32_t lane_idx_tr = dst_row % WGMMA_M_PER_WARP; + constexpr uint32_t kNumBankGroupBytes = 16; + uint32_t i_tr = dst_col8 >> 3; + uint32_t atom_offset = i_tr / (TMA_D_BLOCK_N / 8); + uint32_t in_atom_offset = i_tr % (TMA_D_BLOCK_N / 8); + uint32_t bank_group_idx = + in_atom_offset + lane_idx_tr * (kSwizzleDMode / kNumBankGroupBytes); + constexpr bool kHasShortcut = (kSwizzleDMode / kNumBankGroupBytes) == 8; + uint32_t row = kHasShortcut ? (in_atom_offset / 8 + lane_idx_tr) + : (bank_group_idx / 8); + uint32_t col = kHasShortcut ? 
(in_atom_offset) + : (bank_group_idx % 8); + col ^= row % (kSwizzleDMode / 16); + smem_ptr = reinterpret_cast(smem_d) + + warp_idx_tr * (WGMMA_M_PER_WARP * kSwizzleDMode) + + atom_offset * BLOCK_M * kSwizzleDMode + + row * (kNumBankGroupBytes * 8) + + col * kNumBankGroupBytes; + } else { + // No swizzling, just padding + int row = warp_idx * WGMMA_M_PER_WARP + lane_idx; + int col8 = i * 8; + + int blk_r = row >> 3; + int blk_c = col8 >> 3; + int in_r = row & 7; + int in_c = col8 & 7; + int dst_blk_r = blk_c; + int dst_blk_c = blk_r; + int dst_row = dst_blk_r * 8 + in_r; + int dst_col8 = dst_blk_c * 8 + in_c; + + smem_ptr = reinterpret_cast( + smem_d + n_offset + dst_row * BLOCK_N + dst_col8); + } + + // NOTES: only 16 lanes' addresses are used + SM90_U32x2_STSM_T::copy( + __float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}), + __float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}), + smem_ptr + ); + } + } + cute::tma_store_fence(); + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + + // Use TMA store to write back to global memory + // TODO: compatible with FP32 output + constexpr bool kWithGroupOffsetD = kGemmType == GemmType::MGroupedMasked; + DG_STATIC_ASSERT(kNumMathThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks"); + if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) { + auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N; + auto smem_ptr = smem_d + in_block_n_offset * BLOCK_M; + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr, + n_block_idx * BLOCK_N + in_block_n_offset, + scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx)); + cute::tma_store_arrive(); + } + __syncwarp(); + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_90a"); +#endif +} + }; // namespace deep_gemm #pragma clang diagnostic pop diff --git a/tests/generators.py b/tests/generators.py index 82cdbdcc..331ed865 100644 --- a/tests/generators.py +++ b/tests/generators.py @@ -88,7 +88,7 @@ def enumerate_m_grouped_contiguous(use_bf16: bool = False) -> Generator: def enumerate_m_grouped_masked() -> Generator: max_m = 4096 for kernel_type in get_kernel_types(): - for num_groups, m in ((1, 1024), (2, 512), (4, 256)): + for num_groups, m in ((1, 1024), (2, 512), (4, 256), (16, 32), (16, 64)): for n, k in ((4096, 7168), (7168, 2048), ): yield kernel_type, num_groups, max_m, m, n, k
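
Usage note for the swapAB path: the new kernel is selected purely through the ENABLE_SWAPAB environment variable read by get_env() in the heuristics above, so it can be exercised via the existing masked grouped-GEMM tests, including the new (16, 32) and (16, 64) group shapes added to tests/generators.py. A minimal sketch follows; the test entry point name and the assumption that get_env() reads the variable name verbatim are guesses, not something these patches confirm.

import os
import subprocess

# Enable the swapAB SM90 FP8 1D2D kernel before any kernel config is selected,
# so the heuristics pick BLOCK_M = 32 / BLOCK_N = 256 and instantiate
# swapAB_sm90_fp8_gemm_1d2d_impl instead of the default kernel.
env = dict(os.environ, ENABLE_SWAPAB="1")

# Run the masked grouped-GEMM tests; "tests/test_fp8.py" is an assumed entry
# point, so substitute whichever test driver exists in the checkout.
subprocess.run(["python", "tests/test_fp8.py"], env=env, check=True)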