Commit 78d434c

refactor: move TileShape into launch_mha_kernel_sm80 (#468)
1 parent 459b943 commit 78d434c

3 files changed: +43 −100 lines

src/kernels/attention/generate_instantiation_cu.py

Lines changed: 5 additions & 26 deletions
@@ -22,19 +22,18 @@
 MHA_KERNEL_TEMPLATE = """
 #include "mha_kernel_sm80.cuh"  // IWYU pragma: export
 #include "mha_params.h"  // IWYU pragma: export
-#include "mha_traits_sm80.h"  // IWYU pragma: export

 namespace llm {{

-using Traits = MHATraitsSM80<{DTYPE}, {HEAD_DIM}, {BLK_M}, {BLK_N}, {BLK_K}>;
 using Params = MHAPagedKVParams;

-template void launch_mha_kernel_sm80<Traits,
-                                     Params,
+template void launch_mha_kernel_sm80</*DTYPE=*/{DTYPE},
+                                     /*HEAD_DIM=*/{HEAD_DIM},
                                      /*EVEN_K=*/{EVEN_K},
                                      /*ALIBI=*/{ALIBI},
                                      /*SOFT_CAP=*/{SOFT_CAP},
-                                     /*LOCAL=*/{LOCAL}>(const Params& params,
+                                     /*LOCAL=*/{LOCAL},
+                                     Params>(const Params& params,
                                              cudaStream_t stream);
 }}  // namespace llm
 """
@@ -59,24 +58,16 @@
 class MHAKernel:
     dtype: str
     head_dim: int
-    blk_m: int
-    blk_n: int
-    blk_k: int
     even_k: bool
     alibi: bool
     soft_cap: bool
     local: bool

     @property
     def template(self) -> str:
-        assert self.head_dim % self.blk_k == 0
-
         return MHA_KERNEL_TEMPLATE.format(
             DTYPE=DTYPE_MAP[self.dtype],
             HEAD_DIM=self.head_dim,
-            BLK_M=self.blk_m,
-            BLK_N=self.blk_n,
-            BLK_K=self.blk_k,
             EVEN_K=BOOL_MAP[self.even_k],
             ALIBI=BOOL_MAP[self.alibi],
             SOFT_CAP=BOOL_MAP[self.soft_cap],
@@ -88,7 +79,7 @@ def filename(self) -> str:
         def to_str(val: bool) -> str:
             return "1" if val else "0"

-        return f"mha_{self.dtype}_hd{self.head_dim}_m{self.blk_m}_n{self.blk_n}_k{self.blk_k}_ek{to_str(self.even_k)}_al{to_str(self.alibi)}_sc{to_str(self.soft_cap)}_lc{to_str(self.local)}_sm80.cu"
+        return f"mha_{self.dtype}_hd{self.head_dim}_ek{to_str(self.even_k)}_al{to_str(self.alibi)}_sc{to_str(self.soft_cap)}_lc{to_str(self.local)}_sm80.cu"


 @dataclass
@@ -125,33 +116,21 @@ def gen_mha_kernels() -> Iterator[MHAKernel]:
     for (
         dtype,
         head_dim,
-        blk_m,
-        blk_n,
-        blk_k,
         even_k,
         alibi,
         soft_cap,
         local,
     ) in itertools.product(
         ["fp16", "bf16"],  # dtype
         [64, 96, 128, 256],  # head_dim
-        [64],  # blk_m
-        [64],  # blk_n
-        [32, 64],  # blk_k
         [False, True],  # even_k
         [False, True],  # alibi
         [False, True],  # soft_cap
         [False, True],  # local
     ):
-        # skip invalid configurations
-        if head_dim % blk_k != 0:
-            continue
         yield MHAKernel(
             dtype=dtype,
             head_dim=head_dim,
-            blk_m=blk_m,
-            blk_n=blk_n,
-            blk_k=blk_k,
             even_k=even_k,
             alibi=alibi,
             soft_cap=soft_cap,
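With the tile shape gone from the generator, each emitted instantiation unit now encodes only dtype, head_dim, and the four boolean feature flags. As an illustrative sketch (not a file from this commit), the template above would expand roughly as follows for fp16, head_dim 64, and all flags off, assuming DTYPE_MAP maps "fp16" to cute::half_t and BOOL_MAP maps False to false:

// mha_fp16_hd64_ek0_al0_sc0_lc0_sm80.cu -- illustrative generator output
#include "mha_kernel_sm80.cuh"  // IWYU pragma: export
#include "mha_params.h"  // IWYU pragma: export

namespace llm {

using Params = MHAPagedKVParams;

// One explicit instantiation per (dtype, head_dim, flags) combination; the
// tile shape is now chosen inside launch_mha_kernel_sm80 itself.
template void launch_mha_kernel_sm80</*DTYPE=*/cute::half_t,
                                     /*HEAD_DIM=*/64,
                                     /*EVEN_K=*/false,
                                     /*ALIBI=*/false,
                                     /*SOFT_CAP=*/false,
                                     /*LOCAL=*/false,
                                     Params>(const Params& params,
                                             cudaStream_t stream);

}  // namespace llm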

src/kernels/attention/mha_dispatch_sm80.cuh

Lines changed: 15 additions & 70 deletions
@@ -3,96 +3,41 @@
 #include <cute/int_tuple.hpp>
 #include <cute/layout.hpp>

-#include "mha_traits_sm80.h"
 #include "static_dispatch.h"

 namespace llm {
 // forward declaration
-template <typename Traits,
-          typename Params,
+template <typename Dtype,
+          int HEAD_DIM,
           bool EVEN_K,
           bool ALIBI,
           bool SOFT_CAP,
-          bool LOCAL>
+          bool LOCAL,
+          typename Params>
 void launch_mha_kernel_sm80(const Params& params, cudaStream_t stream);

-namespace detail {
+// user-facing function to run the attention kernel
+template <typename Dtype, int HEAD_DIM, typename Params>
+void run_mha_kernel_sm80(Params& params, cudaStream_t stream = nullptr) {
+  // normalize params that for performance optimization
+  params.normalize();

-template <typename Traits, typename Params>
-void dispatch_mha_kernel_sm80(const Params& params, cudaStream_t stream) {
   // dispatch to proper kernel instantiation based on params
-  DISPATCH_BOOL(params.head_dim == Traits::kHeadDim, EVEN_K, [&] {
+  DISPATCH_BOOL(params.head_dim == HEAD_DIM, EVEN_K, [&] {
     DISPATCH_BOOL(params.alibi_slopes_ptr != nullptr, ALIBI, [&] {
       DISPATCH_BOOL(params.logits_soft_cap > 0, SOFT_CAP, [&] {
         DISPATCH_BOOL(params.sliding_window >= 0, LOCAL, [&] {
-          launch_mha_kernel_sm80<Traits,
-                                 Params,
+          launch_mha_kernel_sm80<Dtype,
+                                 HEAD_DIM,
                                  EVEN_K,
                                  ALIBI,
                                  SOFT_CAP,
-                                 LOCAL>(params, stream);
+                                 LOCAL,
+                                 Params>(params, stream);
         });
       });
     });
   });
 }

-}  // namespace detail
-
-// user-facing function to run the attention kernel
-template <typename Dtype, int HEAD_DIM, typename Params>
-void run_mha_kernel_sm80(Params& params, cudaStream_t stream = nullptr) {
-  // normalize params that for performance optimization
-  params.normalize();
-
-  // TODO: tune block shape MNK based on the head dim and smem size
-  // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications-technical-specifications-per-compute-capability
-  // SM           | 7.0 | 7.2 | 7.5 | 8.0 | 8.6 | 8.7 | 8.9 | 9.0 | 10.x | 12.0|
-  // Max SMEM (KB)|  96 |  64 | 164 | 100 | 164 | 100 | 228 | 100 |
-  // valid dynamic shared memory sizes for different compute capabilities:
-  // * 7.0 | 7.2 : 0, 8, 16, 32, 64, 96
-  // * 7.5       : 0, 32, 64
-  // * 8.0 | 8.7 : 0, 8, 16, 32, 64, 100, 132, 164
-  // * 8.6 | 8.9 : 0, 8, 16, 32, 64, 100
-  // * 9.0 | 10.x: 0, 8, 16, 32, 64, 100, 132, 164, 196, 228
-  // * 12.0      : 0, 8, 16, 32, 64, 100
-  if constexpr (HEAD_DIM == 64) {
-    using Traits = MHATraitsSM80<Dtype,
-                                 HEAD_DIM,
-                                 /*BLK_M=*/64,
-                                 /*BLK_N=*/64,
-                                 /*BLK_K=*/64>;
-    detail::dispatch_mha_kernel_sm80<Traits>(params, stream);
-  } else if constexpr (HEAD_DIM == 96) {
-    using Traits = MHATraitsSM80<Dtype,
-                                 HEAD_DIM,
-                                 /*BLK_M=*/64,
-                                 /*BLK_N=*/64,
-                                 /*BLK_K=*/32>;
-    detail::dispatch_mha_kernel_sm80<Traits>(params, stream);
-  } else if constexpr (HEAD_DIM == 128) {
-    using Traits = MHATraitsSM80<Dtype,
-                                 HEAD_DIM,
-                                 /*BLK_M=*/64,
-                                 /*BLK_N=*/64,
-                                 /*BLK_K=*/64>;
-    detail::dispatch_mha_kernel_sm80<Traits>(params, stream);
-  } else if constexpr (HEAD_DIM == 256) {
-    using Traits = MHATraitsSM80<Dtype,
-                                 HEAD_DIM,
-                                 /*BLK_M=*/64,
-                                 /*BLK_N=*/64,
-                                 /*BLK_K=*/64>;
-    detail::dispatch_mha_kernel_sm80<Traits>(params, stream);
-  } else {
-    // use the default block size
-    using Traits = MHATraitsSM80<Dtype,
-                                 HEAD_DIM,
-                                 /*BLK_M=*/64,
-                                 /*BLK_N=*/64,
-                                 /*BLK_K=*/64>;
-    detail::dispatch_mha_kernel_sm80<Traits>(params, stream);
-  }
-}
-
-}  // namespace llm
+}  // namespace llm
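For reference, run_mha_kernel_sm80 is now the only user-facing entry point: the caller fixes Dtype and HEAD_DIM at compile time, and the four boolean features are resolved at runtime by the nested DISPATCH_BOOL calls. A minimal caller-side sketch, assuming cute::half_t is the fp16 element type reachable through the included headers and that MHAPagedKVParams is the Params type (the wrapper function name is hypothetical):

#include <cuda_runtime.h>

#include "mha_dispatch_sm80.cuh"
#include "mha_params.h"

namespace llm {

// Hypothetical wrapper for the fp16, head_dim=128 attention path.
void run_mha_fp16_hd128(MHAPagedKVParams& params, cudaStream_t stream) {
  // EVEN_K/ALIBI/SOFT_CAP/LOCAL are selected inside run_mha_kernel_sm80 via
  // DISPATCH_BOOL, based on params.head_dim, alibi_slopes_ptr,
  // logits_soft_cap, and sliding_window.
  run_mha_kernel_sm80</*Dtype=*/cute::half_t, /*HEAD_DIM=*/128>(params, stream);
}

}  // namespace llm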

src/kernels/attention/mha_kernel_sm80.cuh

Lines changed: 23 additions & 4 deletions
@@ -13,6 +13,7 @@
 #include "layout_convertor.h"
 #include "mask.h"
 #include "mha_tile.h"
+#include "mha_traits_sm80.h"
 #include "online_softmax.cuh"

 namespace llm {
@@ -436,17 +437,35 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mha_kernel_sm80(
   epilogue(tOrO);
 }

-template <typename Traits,
-          typename Params,
+template <typename Dtype,
+          int HEAD_DIM,
           bool EVEN_K,
           bool ALIBI,
           bool SOFT_CAP,
-          bool LOCAL>
+          bool LOCAL,
+          typename Params>
 void launch_mha_kernel_sm80(const Params& params, cudaStream_t stream) {
   const auto batch_size = params.batch_size;
   const auto n_kv_heads = params.n_kv_heads;
   const auto max_q_packed_len = params.max_q_len * params.group_size;

+  // TODO: tune block shape MNK based on the head dim and smem size
+  // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications-technical-specifications-per-compute-capability
+  // SM           | 7.0 | 7.2 | 7.5 | 8.0 | 8.6 | 8.7 | 8.9 | 9.0 | 10.x | 12.0|
+  // Max SMEM (KB)|  96 |  64 | 164 | 100 | 164 | 100 | 228 | 100 |
+  // valid dynamic shared memory sizes for different compute capabilities:
+  // * 7.0 | 7.2 : 0, 8, 16, 32, 64, 96
+  // * 7.5       : 0, 32, 64
+  // * 8.0 | 8.7 : 0, 8, 16, 32, 64, 100, 132, 164
+  // * 8.6 | 8.9 : 0, 8, 16, 32, 64, 100
+  // * 9.0 | 10.x: 0, 8, 16, 32, 64, 100, 132, 164, 196, 228
+  // * 12.0      : 0, 8, 16, 32, 64, 100
+
+  constexpr int BLK_M = 64;
+  constexpr int BLK_N = 64;
+  constexpr int BLK_K = HEAD_DIM % 64 == 0 ? 64 : 32;
+  using Traits = MHATraitsSM80<Dtype, HEAD_DIM, BLK_M, BLK_N, BLK_K>;
+
   const auto smem_size = sizeof(MHASharedStorage<Traits>);
   auto mha_kernel =
       mha_kernel_sm80<Traits, Params, EVEN_K, ALIBI, SOFT_CAP, LOCAL>;
@@ -460,4 +479,4 @@ void launch_mha_kernel_sm80(const Params& params, cudaStream_t stream) {
   mha_kernel<<<grid, block, smem_size, stream>>>(params);
 }

-}  // namespace llm
+}  // namespace llm
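Taken together, the tile shape is now derived inside the launcher: BLK_M and BLK_N stay at 64, and BLK_K is 64 whenever HEAD_DIM is a multiple of 64, otherwise 32. A standalone sketch of that rule (the helper name is hypothetical), just to make the mapping over the generated head dims explicit:

// Hypothetical helper mirroring the BLK_K selection in launch_mha_kernel_sm80.
constexpr int blk_k_for(int head_dim) {
  return head_dim % 64 == 0 ? 64 : 32;
}

static_assert(blk_k_for(64) == 64, "head_dim 64 keeps BLK_K=64");
static_assert(blk_k_for(96) == 32, "head_dim 96 falls back to BLK_K=32");
static_assert(blk_k_for(128) == 64, "head_dim 128 keeps BLK_K=64");
static_assert(blk_k_for(256) == 64, "head_dim 256 keeps BLK_K=64");

This reproduces the per-head-dim table that the old run_mha_kernel_sm80 in mha_dispatch_sm80.cuh spelled out with if constexpr, where only HEAD_DIM == 96 selected BLK_K=32.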
