Skip to content

Commit 2991c77

Browse files
rootWangzheee
authored and committed
support swapAB for m_grouped_fp8_gemm_nt_masked
1 parent 79f48ee commit 2991c77

File tree

7 files changed

+456
-6
lines changed

7 files changed

+456
-6
lines changed

csrc/jit_kernels/heuristics/common.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,9 +157,15 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
157157
block_ms = std::vector{get_mk_alignment_for_contiguous_layout()};
158158
if (gemm_type == GemmType::MGroupedMasked) // Exclude 256 for performance
159159
block_ms = std::vector{64, 128};
160+
160161
std::vector<int> block_ns;
161162
for (int i = 16; i <= 256; i += 16)
162163
block_ns.push_back(i);
164+
if(get_env<int>("ENABLE_SWAPAB")){
165+
block_ms = std::vector{32}; // 32, 64
166+
block_ns = std::vector{256}; // 256 for H20, and can choose 64, 128, 256
167+
}
168+
163169

164170
// K block size is selected in a fixed manner
165171
const auto& block_k = 128 / static_cast<int>(c10::elementSize(ab_dtype));

csrc/jit_kernels/heuristics/sm90.hpp

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,15 @@ struct SM90ArchSpec {
4242

4343
// Too many scaling factors in a single block: `block_n > block_k and std::gcd(block_n, block_k) != block_n - block_k`
4444
// Or too many register spills
45-
if (block_n > 128 and kernel_type == KernelType::Kernel1D2D and (block_n != 144 and block_n != 160 and block_n != 192))
46-
return false;
4745

46+
if(get_env<int>("ENABLE_SWAPAB")){
47+
if (block_n != 64 and block_n != 128 and block_n != 256)
48+
return false;
49+
}else{
50+
if (block_n > 128 and kernel_type == KernelType::Kernel1D2D and (block_n != 144 and block_n != 160 and block_n != 192))
51+
return false;
52+
}
53+
4854
// Avoid bank conflicts for FP32 output
4955
if (cd_dtype == torch::kFloat and block_n % 16 == 0)
5056
return false;
@@ -79,7 +85,13 @@ struct SM90ArchSpec {
7985

8086
// Builds the SM90 thread configuration for a given block shape.
//
// This span is a diff rendering: the line following the `82-` marker is the
// removed implementation, the lines following `NN+` markers are the added one.
//
// Added behavior: the first argument to `ThreadConfig::sm90` is always 128; the
// second is 256 when the selected tile dimension is greater than 64, otherwise 128.
// The tile dimension is `block_n` when `get_env<int>("ENABLE_SWAPAB")` evaluates
// as true, and `block_m` otherwise.
//
// `kernel_type` is not read in the visible body.
// NOTE(review): the two arguments are presumably the producer (TMA) and consumer
// (math) thread counts — confirm against `ThreadConfig::sm90`.
// NOTE(review): the removed code tested `block_m == 64`, the added code tests
// `tile > 64`; the two disagree for tile sizes below 64 (e.g. 32) — confirm that
// one group of 128 threads is intended for such tiles.
static ThreadConfig get_thread_config(const KernelType& kernel_type,
8187
const int& block_m, const int& block_n) {
82-
return ThreadConfig::sm90(128, (block_m == 64 ? 1 : 2) * 128);
88+
int tile = 64;
89+
if(get_env<int>("ENABLE_SWAPAB")){
90+
tile = block_n;
91+
}else{
92+
tile = block_m;
93+
}
94+
return ThreadConfig::sm90(128, (tile > 64 ? 2 : 1) * 128);
8395
}
8496

8597
static int get_smem_cd_size(const KernelType& kernel_type,
@@ -104,7 +116,8 @@ struct SM90ArchSpec {
104116

105117
// Returns `ceil_div(k, block_k) * sizeof(float) * factor`, passed through
// `align<int>(..., 8)`.
//
// This span is a diff rendering: the declaration after the `107-` marker is the
// removed one, the declaration after the `119+` marker replaces it.
//
// Added behavior for `factor` (stored in `use_uniform_sfb`):
//   - `block_n / 64` (integer division) when `get_env<int>("ENABLE_SWAPAB")`
//     evaluates as true;
//   - otherwise 1 if `block_k % block_n == 0`, else 2 (same as the removed line).
//
// `m`, `n` and `block_m` are not read in the visible body.
// NOTE(review): presumably the result is a shared-memory byte count for the
// B-side scaling factors — confirm against the caller.
// NOTE(review): with ENABLE_SWAPAB set and `block_n < 64`, the factor is 0 and
// the unaligned size becomes 0; verify callers never pass such a `block_n`.
static int get_extra_sfb_smem_size(const int& m, const int& n, const int& k,
106118
const int& block_m, const int& block_n, const int& block_k) {
107-
const auto& use_uniform_sfb = block_k % block_n == 0 ? 1 : 2;
119+
const auto& use_uniform_sfb = get_env<int>("ENABLE_SWAPAB") ? (block_n / 64):(block_k % block_n == 0 ? 1 : 2);
120+
108121
return align<int>(ceil_div(k, block_k) * static_cast<int>(sizeof(float)) * use_uniform_sfb, 8);
109122
}
110123

csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,19 @@ class SM90FP8Gemm1D2DRuntime final: public LaunchRuntime<SM90FP8Gemm1D2DRuntime>
2929
};
3030

3131
static std::string generate_impl(const Args& args) {
32+
33+
const char* kernel_name =
34+
get_env<int>("ENABLE_SWAPAB") ?
35+
"swapAB_sm90_fp8_gemm_1d2d_impl" :
36+
"sm90_fp8_gemm_1d2d_impl";
37+
3238
return fmt::format(R"(
3339
#include <deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh>
3440
3541
using namespace deep_gemm;
3642
3743
static void __instantiate_kernel() {{
38-
auto ptr = reinterpret_cast<void*>(&sm90_fp8_gemm_1d2d_impl<
44+
auto ptr = reinterpret_cast<void*>(&{}<
3945
{}, {}, {},
4046
{},
4147
{}, {}, {},
@@ -47,6 +53,7 @@ static void __instantiate_kernel() {{
4753
>);
4854
}};
4955
)",
56+
kernel_name,
5057
// TODO: add CD dtype
5158
get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims),
5259
args.num_groups,

deep_gemm/include/deep_gemm/common/sm90_utils.cuh

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,17 @@ struct SM90_U32x2_STSM_N {
144144
}
145145
};
146146

147+
// Wrapper around the PTX instruction
// `stmatrix.sync.aligned.x2.m8n8.shared.b16.trans`.
//
// `copy` reinterprets `src_0` and `src_1` as two 32-bit words and passes them,
// together with the destination address `smem_dst`, to the instruction.
//
// NOTE(review): each source is read through a `uint32_t*`, so this assumes
// `sizeof(dtype_t) == 4` (e.g. a packed pair of 16-bit values) — confirm at call sites.
// NOTE(review): `smem_dst` is bound with the "l" (64-bit) constraint to a
// `.shared` instruction; presumably callers pass a pointer into shared memory — verify.
// NOTE(review): as a `.sync.aligned` instruction this presumably must be executed
// by all threads of the warp — confirm against the PTX ISA.
template <typename dtype_t>
148+
struct SM90_U32x2_STSM_T
149+
{
150+
__device__ __forceinline__ static void copy(dtype_t src_0, dtype_t src_1, void* smem_dst)
151+
{
152+
const uint32_t src[2] = {*reinterpret_cast<uint32_t*>(&src_0), *reinterpret_cast<uint32_t*>(&src_1)};
153+
asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16.trans [%0], {%1, %2};\n" ::"l"(smem_dst), "r"(src[0]),
154+
"r"(src[1]));
155+
}
156+
};
157+
147158
__forceinline__ __device__ void warpgroup_arrive() {
148159
asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory");
149160
}

deep_gemm/include/deep_gemm/common/utils.cuh

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,12 @@ __device__ __forceinline__ float ld_shared(const float* ptr) {
122122
return ret;
123123
}
124124

125+
// Loads a `float2` from `ptr` with a single `ld.shared.v2.f32` instruction and
// returns it (`ret.x` receives the first element, `ret.y` the second).
//
// NOTE(review): the instruction uses the `.shared` state space, so `ptr` is
// presumably expected to point into shared memory — verify at call sites.
// NOTE(review): a two-element vector load presumably requires `ptr` to be
// 8-byte aligned — confirm against the PTX ISA.
__device__ __forceinline__ float2 ld_shared(const float2* __restrict__ ptr) {
126+
float2 ret;
127+
asm volatile("ld.shared.v2.f32 {%0, %1}, [%2];" : "=f"(ret.x), "=f"(ret.y) : "l"(ptr));
128+
return ret;
129+
}
130+
125131
__device__ __forceinline__ void st_shared(const float* ptr, float val) {
126132
asm volatile("st.shared.f32 [%0], %1;" :: "l"(ptr), "f"(val));
127133
}

0 commit comments

Comments
 (0)