
Commit 98984e0

[FA2] flash-attn-mma 3080/L20/4090 bench✔️ (#184)
* Create .gitignore
* Create matrix_trans_swizzle.cu
* Create hgemm_mma_naive_swizzle.cu
* Update matrix_trans_swizzle.cu
* Create ncu_bank_conflicts.md
* Rename hgemm_mma_naive_swizzle.cu to hgemm_mma_swizzle.cu
* Update and rename ncu_bank_conflicts.md to bank_conflicts.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
1 parent b9d1590 commit 98984e0

6 files changed: +493, −7 lines changed

README.md

Lines changed: 4 additions & 3 deletions
```diff
@@ -52,8 +52,8 @@ I have also implemented **FlashAttention-2** using pure MMA PTX instructions, wh
 |✔️|✔️|✔️|✔️|
 |Tile Warp (More Values)|Multi Stages (1/2)|Collective Store (Shfl)|**Split KV/Q**|
 |✔️|✔️|✔️|✔️|
-|**Shared QKV/KV** SMEM|**Prefetch Q** s2r|**Prefetch K/V** g2s|SMEM/Block Swizzle|
-|✔️|✔️|✔️|?|
+|**Shared QKV/KV** SMEM|**Prefetch Q** s2r|**Prefetch K/V** g2s|**QK Fine-grained Tiling**|
+|✔️|✔️|✔️|✔️|
 
 Currently, for small-scale attention `(B<=4, H <=48, SeqLen <= 8192)` it can run faster than FA2/SDPA on some Devices. For example, on NVIDIA RTX 3080 Laptop, [📚 Split Q + Fully Shared QKV SMEM](#mma-share-qkv) can achieve **55 TFLOPS (D=64)** that almost **~1.5x** 🎉 faster than FA2. On NVIDIA L20, [📚 Split Q + QK Fine-grained Tiling](#mma-tiling-qk) can achieve **81 TFLOPS (D=512)** that almost **~1.4x** 🎉 faster than SDPA (EFFICIENT ATTENTION). However, for large-scale attention, there remains a performance gap. Stay tuned for updates ~ (👇Benchmark)
 
```
```diff
@@ -66,7 +66,7 @@ Currently, for small-scale attention `(B<=4, H <=48, SeqLen <= 8192)` it can run
 |SDPA(EFFICIENT ATTENTION)|(1,48,8192,512)|16 TFLOPS|58 TFLOPS|85 TFLOPS|
 |mma(split-q+tiling-qk+stage2)|(1,48,8192,512)|**23 TFLOPS**|**81 TFLOPS**|**120 TFLOPS**|
 
-The `Split KV` and `Split Q` implementations have been carried out in [flash-attention-mma⚡️⚡️](./kernels/flash-attn) for performance comparison. The `Split KV` method, which involves splitting all QKV across MMA (Warps), is slower than `Split Q` policy, which splitting Q across MMA(Warps) and keep access KV for all MMA(Warps).
+The `Split KV` and `Split Q` implementations have been carried out in [flash-attention-mma⚡️⚡️](./kernels/flash-attn) for performance comparison. The `Split KV` method, which involves splitting all QKV across MMA (Warps), is slower than `Split Q` method, which splitting Q across MMA(Warps) and keep access KV for all MMA(Warps).
 
 - 📚 Split KV (Basic, FlashAttention-1)
 <div id="mma-split-kv"></div>
```
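To make the `Split KV` vs `Split Q` contrast above concrete, here is a minimal CUDA sketch written for this note, not taken from the repository's kernels: it only prints which slice of a `Br x Bc` score tile each warp would own under the two policies. The tile sizes (`kBr`, `kBc`), the warp count, and the 2x2 warp layout assumed for `Split KV` are illustrative assumptions, not the actual MMA tiling.

```cuda
// partition_sketch.cu -- illustrative only; kBr/kBc/kWarps and the 2x2 warp
// layout for Split KV are assumptions, not this repository's actual tiling.
#include <cstdio>
#include <cuda_runtime.h>

constexpr int kBr    = 64;  // Q rows in one block's score tile (assumed)
constexpr int kBc    = 64;  // KV columns in one block's score tile (assumed)
constexpr int kWarps = 4;   // warps (MMA groups) per block (assumed)

struct WarpWork { int q_begin, q_end, kv_begin, kv_end; };

// Split KV: a 2x2 warp grid covers the Br x Bc score tile, so each score row
// is split between two warps and their softmax row-max/row-sum must be merged.
__global__ void partition_split_kv(WarpWork* out) {
  int warp = threadIdx.x / 32;
  if (threadIdx.x % 32 != 0) return;  // one writer per warp
  int wr = warp / 2, wc = warp % 2;
  out[warp] = { wr * (kBr / 2), (wr + 1) * (kBr / 2),
                wc * (kBc / 2), (wc + 1) * (kBc / 2) };
}

// Split Q: warps split only the Q rows; every warp sweeps the full KV tile,
// so the per-row softmax state never has to leave a single warp.
__global__ void partition_split_q(WarpWork* out) {
  int warp = threadIdx.x / 32;
  if (threadIdx.x % 32 != 0) return;
  int rows = kBr / kWarps;
  out[warp] = { warp * rows, (warp + 1) * rows, 0, kBc };
}

int main() {
  WarpWork *d = nullptr, h[kWarps];
  cudaMalloc(&d, kWarps * sizeof(WarpWork));

  partition_split_kv<<<1, kWarps * 32>>>(d);
  cudaMemcpy(h, d, sizeof(h), cudaMemcpyDeviceToHost);
  printf("Split KV (QKV split across warps):\n");
  for (int w = 0; w < kWarps; ++w)
    printf("  warp %d owns Q[%d,%d) x KV[%d,%d)\n",
           w, h[w].q_begin, h[w].q_end, h[w].kv_begin, h[w].kv_end);

  partition_split_q<<<1, kWarps * 32>>>(d);
  cudaMemcpy(h, d, sizeof(h), cudaMemcpyDeviceToHost);
  printf("Split Q (only Q split across warps, KV shared by all warps):\n");
  for (int w = 0; w < kWarps; ++w)
    printf("  warp %d owns Q[%d,%d) x KV[%d,%d)\n",
           w, h[w].q_begin, h[w].q_end, h[w].kv_begin, h[w].kv_end);

  cudaFree(d);
  return 0;
}
```

The point of the contrast: under `Split KV` each score row is shared by several warps, so the per-row online-softmax max/sum and partial outputs must be merged across warps, while under `Split Q` every warp owns whole rows and only read access to the K/V tiles is shared.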
```diff
@@ -427,6 +427,7 @@ The kernels listed here will guide you through a step-by-step progression, rangi
 | [[cute系列详解][Swizzle]📖cute Swizzle细谈](https://zhuanlan.zhihu.com/p/684250988)|@进击的Killua|
 | [[cute系列详解][Swizzle]📖cutlass swizzle机制解析(一)](https://zhuanlan.zhihu.com/p/710337546)|@Titus|
 | [[cute系列详解][Swizzle]📖cutlass swizzle机制解析(二)](https://zhuanlan.zhihu.com/p/711398930)|@Titus|
+| [[cute系列详解][Swizzle]📖CUDA避免bank conflict的swizzle机制解析](https://zhuanlan.zhihu.com/p/4746910252)|@frankshi|
 | [[cute系列详解][GEMM]📖cute 之 简单GEMM实现](https://zhuanlan.zhihu.com/p/667521327)|@reed|
 | [[cute系列详解][GEMM]📖cute 之 GEMM流水线](https://zhuanlan.zhihu.com/p/665082713)|@reed|
 | [[cute系列详解][GEMM]📖cute 之 高效GEMM实现](https://zhuanlan.zhihu.com/p/675308830)|@reed|
```

kernels/flash-attn/README.md

Lines changed: 4 additions & 4 deletions
```diff
@@ -7,14 +7,14 @@
 |✔️|✔️|✔️|✔️|
 |Pack LDST (pack 128 bits)|SMEM Padding|Copy Async (cp.async.cg/ca)|Tile MMA (More Threads)
 |✔️|✔️|✔️|✔️|
-|Tile Warp (More Values)|Multi Stages (1/2)|Collective Store (Warp Shuffle & Reg Reuse)|**Split KV/Q**|
+|Tile Warp (More Values)|Multi Stages (1/2)|Collective Store (Warp Shfl & Reg Reuse)|**Split KV/Q**|
+|✔️|✔️|✔️|✔️|
+|**Shared QKV/KV** SMEM|**Prefetch Q** s2r|**Prefetch K/V** g2s|**QK Fine-grained Tiling**|
 |✔️|✔️|✔️|✔️|
-|**Shared QKV/KV** SMEM|**Prefetch Q** s2r|**Prefetch K/V** g2s|SMEM/Block Swizzle|
-|✔️|✔️|✔️|?|
 
 This repository's implementation of FlashAttention is intended solely for learning CUDA programming. For optimal performance, please use the official [flash-attention](https://github.com/Dao-AILab/flash-attention). Currently, for small-scale attention `(B<=4, H <=48, SeqLen <= 8192)` it can run faster than offical FA2/SDPA on some Devices. However, for large-scale attention, there remains a performance gap. Performance is continuously being optimized. Stay tuned for updates ~ (👇Benchmark)
 
-|Algorithm| (B,H,N,D) | NVIDIA GeForce RTX 3080 Laptop | NVIDIA L20 | NVIDIA RTX 4090 |
+|Algorithm| (B,H,N,D) | NVIDIA RTX 3080 Laptop | NVIDIA L20 | NVIDIA GeForce RTX 4090 |
 |:---:|:---:|:---:|:---:|:---:|
 |FlashAttention-2|(1,8,8192,64)|37 TFLOPS|100 TFLOPS|145 TFLOPS|
 |mma(split-q+share-qkv+stage2)|(1,8,8192,64)|**55 TFLOPS**|96 TFLOPS|**218 TFLOPS**|
```
bank_conflicts.md

Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
## Check Bank Conflicts via NCU

- Check which bank-conflict metrics the device supports
```bash
# ncu check bank conflicts
# First, list the metrics supported by the current device
ncu --query-metrics | grep data | grep bank | grep l1tex
```
metrics:
```bash
ncu --query-metrics | grep data | grep bank | grep l1tex
l1tex__data_bank_conflicts_pipe_lsu Counter # of data bank conflicts generated by LSU pipe
l1tex__data_bank_conflicts_pipe_lsu_cmd_read Counter # of data bank conflicts generated by LSU reads
l1tex__data_bank_conflicts_pipe_lsu_cmd_write Counter # of data bank conflicts generated by LSU writes
l1tex__data_bank_conflicts_pipe_lsu_mem_global Counter # of data bank conflicts generated by global ops
l1tex__data_bank_conflicts_pipe_lsu_mem_global_op_atom Counter # of data bank conflicts generated by global atomics
l1tex__data_bank_conflicts_pipe_lsu_mem_global_op_ld Counter # of data bank conflicts generated by global loads
l1tex__data_bank_conflicts_pipe_lsu_mem_global_op_red Counter # of data bank conflicts generated by global reductions
l1tex__data_bank_conflicts_pipe_lsu_mem_global_op_st Counter # of data bank conflicts generated by global stores
l1tex__data_bank_conflicts_pipe_lsu_mem_shared Counter # of shared memory data bank conflicts generated by LDS, LD, 3D
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_atom Counter # of shared memory data bank conflicts generated by ATOMS, ATOM
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld Counter # of shared memory data bank conflicts generated by LDS, LD, 3D
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ldgsts Counter # of data bank conflicts generated by shared ldgsts ops
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st Counter # of shared memory data bank conflicts generated by STS, ST, 3D
l1tex__data_bank_reads Counter # of data bank reads
l1tex__data_bank_writes Counter # of data bank writes
sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldgsts Counter # of shared memory data bank conflicts generated by LDGSTS
sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldgsts_cache_access Counter # of shared memory data bank conflicts generated by LDGSTS.ACCESS
sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldgsts_cache_bypass Counter # of shared memory data bank conflicts generated by LDGSTS.BYPASS
sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm Counter # of shared memory data bank conflicts generated by LDSM
sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_st Counter # of shared memory data bank conflicts generated by STS, ST
sm__sass_l1tex_data_bank_writes_pipe_lsu_mem_shared_op_ldgsts_cache_access Counter # of LDGSTS.ACCESS shared data bank writes
smsp__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldgsts Counter # of shared memory data bank conflicts generated by LDGSTS
smsp__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldgsts_cache_access Counter # of shared memory data bank conflicts generated by LDGSTS.ACCESS
smsp__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldgsts_cache_bypass Counter # of shared memory data bank conflicts generated by LDGSTS.BYPASS
smsp__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm Counter # of shared memory data bank conflicts generated by LDSM
smsp__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_st Counter # of shared memory data bank conflicts generated by STS, ST
smsp__sass_l1tex_data_bank_writes_pipe_lsu_mem_shared_op_ldgsts_cache_access Counter # of LDGSTS.ACCESS shared data bank writes
```

- Bank conflicts generated by LD instructions
```bash
# profile l1tex smem data bank conflicts
# bank conflicts generated by LDS/LD instructions
ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum hgemm_mma_stage.89.bin
ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum hgemm_cute.89.debug.bin
ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld \
python3 flash_attn_mma.py --B 1 --H 1 --D 64 --N 4096 --w 0 --i 1
```
log:
```bash
void flash_fwd_splitkv_combine_kernel<Flash_fwd_kernel_traits<64, 64, 256, 4, 0, 0, cutlass::half_t, Flash_kernel_traits<64, 64, 256, 4, cutlass::half_t>>, 8, 3, 1>(Flash_fwd_params) (512, 1, 1)x(128, 1, 1), Context 1, Stream 7, Device 0, CC 8.9
  Section: Command line profiler metrics
  -------------------------------------------------------- ----------- ------------
  Metric Name                                               Metric Unit Metric Value
  -------------------------------------------------------- ----------- ------------
  l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.avg                     11.18
  l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.max                        13
  l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.min                        10
  l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum                      1029
  -------------------------------------------------------- ----------- ------------
```
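For a standalone way to see a nonzero `*_op_ld` count, the toy kernel below is written for this note and is not part of this commit: it stages a 32x32 float tile in shared memory and then reads it column-wise, so all 32 lanes of a warp land in the same 4-byte bank and every LDS becomes a 32-way conflict. Build it with nvcc and profile the binary with the same `l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum` command shown above.

```cuda
// bank_conflict_naive.cu -- toy example for this note (not a repo kernel).
// Build:   nvcc -O3 -o bank_conflict_naive bank_conflict_naive.cu
// Profile: ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum ./bank_conflict_naive
#include <cstdio>
#include <cuda_runtime.h>

constexpr int TILE = 32;

// Shared-memory transpose: the store tile[ty][tx] is conflict-free, but the
// load tile[tx][ty] makes a warp (fixed ty, tx = 0..31) read with a stride of
// 32 floats = 128 bytes, so all 32 lanes hit the same bank -> 32-way conflicts.
__global__ void transpose_naive_smem(const float* __restrict__ in,
                                     float* __restrict__ out, int n) {
  __shared__ float tile[TILE][TILE];
  int x = blockIdx.x * TILE + threadIdx.x;   // column in `in`
  int y = blockIdx.y * TILE + threadIdx.y;   // row in `in`
  if (x < n && y < n) tile[threadIdx.y][threadIdx.x] = in[y * n + x];
  __syncthreads();
  int tx = blockIdx.y * TILE + threadIdx.x;  // column in `out`
  int ty = blockIdx.x * TILE + threadIdx.y;  // row in `out`
  if (tx < n && ty < n) out[ty * n + tx] = tile[threadIdx.x][threadIdx.y];
}

int main() {
  const int n = 1024;
  float *in, *out;
  cudaMalloc(&in, n * n * sizeof(float));
  cudaMalloc(&out, n * n * sizeof(float));
  dim3 block(TILE, TILE), grid(n / TILE, n / TILE);
  transpose_naive_smem<<<grid, block>>>(in, out, n);
  cudaDeviceSynchronize();
  cudaFree(in); cudaFree(out);
  return 0;
}
```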

- Bank conflicts generated by LDSM instructions

```bash
# bank conflicts generated by LDSM (ldmatrix) instructions
ncu --metrics sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm \
python3 flash_attn_mma.py --B 1 --H 1 --D 64 --N 4096 --w 0 --i 1
ncu --metrics smsp__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm \
python3 flash_attn_mma.py --B 1 --H 1 --D 64 --N 4096 --w 0 --i 1
```
log:
```bash
void flash_fwd_splitkv_combine_kernel<Flash_fwd_kernel_traits<64, 64, 256, 4, 0, 0, cutlass::half_t, Flash_kernel_traits<64, 64, 256, 4, cutlass::half_t>>, 8, 3, 1>(Flash_fwd_params) (512, 1, 1)x(128, 1, 1), Context 1, Stream 7, Device 0, CC 8.9
  Section: Command line profiler metrics
  ------------------------------------------------------------------ ----------- ------------
  Metric Name                                                         Metric Unit Metric Value
  ------------------------------------------------------------------ ----------- ------------
  sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.avg                         0
  sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.max                         0
  sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.min                         0
  sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.sum                         0
  ------------------------------------------------------------------ ----------- ------------
```
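In the same spirit as the swizzle kernels this commit adds (matrix_trans_swizzle.cu, hgemm_mma_swizzle.cu), but not copied from them, the variant below removes the conflicts from the toy transpose above by XOR-swizzling the shared-memory column index: logical element (r, c) is stored at `tile[r][c ^ r]`, so both the row-wise store and the column-wise load of a warp spread across 32 distinct banks. Re-running the `*_op_ld` metric on this version should report a sum of 0; padding the tile to `[TILE][TILE + 1]` is the simpler alternative that trades a little shared memory for the same effect.

```cuda
// bank_conflict_swizzle.cu -- toy example for this note (not the repo's
// matrix_trans_swizzle.cu, just the same XOR-swizzle idea).
// Profile: ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum ./bank_conflict_swizzle
#include <cstdio>
#include <cuda_runtime.h>

constexpr int TILE = 32;

// Logical element (r, c) lives at tile[r][c ^ r]. For a fixed row, XOR with r
// is a permutation of the 32 columns, so a warp reading a logical column
// (row index = lane id) now touches 32 different banks -> no LDS conflicts.
__global__ void transpose_swizzled_smem(const float* __restrict__ in,
                                        float* __restrict__ out, int n) {
  __shared__ float tile[TILE][TILE];
  int x = blockIdx.x * TILE + threadIdx.x;
  int y = blockIdx.y * TILE + threadIdx.y;
  if (x < n && y < n)
    tile[threadIdx.y][threadIdx.x ^ threadIdx.y] = in[y * n + x];        // swizzled store
  __syncthreads();
  int tx = blockIdx.y * TILE + threadIdx.x;
  int ty = blockIdx.x * TILE + threadIdx.y;
  if (tx < n && ty < n)
    out[ty * n + tx] = tile[threadIdx.x][threadIdx.y ^ threadIdx.x];     // swizzled load
}

int main() {
  const int n = 1024;
  float *in, *out;
  cudaMalloc(&in, n * n * sizeof(float));
  cudaMalloc(&out, n * n * sizeof(float));
  dim3 block(TILE, TILE), grid(n / TILE, n / TILE);
  transpose_swizzled_smem<<<grid, block>>>(in, out, n);
  cudaDeviceSynchronize();
  cudaFree(in); cudaFree(out);
  return 0;
}
```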

kernels/swizzle/.gitignore

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
```
*.so
*.a
*.dylib
*.dll
*.lib
.DS_Store
build
*.whl
tmp
__pycache__
*.onnx
*.engine
*.pt
*.pth
*.nsys*
*.ncu*
*.sqlite*
*.engine
*.bin
*.out
*bin
bin
output
*.egg-info
*.whl
dist
*.pdf
*.tex
*.log
*.md5
*.aux*
*.dpth
```

0 commit comments
