
Commit e276d45

Aya-ZIbra authored and facebook-github-bot committed
Fix input to TritonSplitK performance benchmark
Summary: The current code applies the uint8 conversion after `expand`. This results in a non-zero stride in the qhead dim of the KV tensors, and Triton performance suffers significantly, probably because GQA packing is not applied.

Before the change, `_k` shape : stride = `torch.Size([1, 524288, 5, 32]) : (83886080, 160, 32, 1)`, 220 GB/s:

| **(Batch, SeqLenQ, SeqLenKV, MaxLenKV, HeadQ, HeadKV, HeadD)** | triton_splitk (GB/s) | triton_splitk_fp8kv (GB/s) | **FP8/BF16** |
|---------------|---------------------|---------------------------|---------------|
| (16, 1, 1024, 32768, 5, 1, 128) | 971.46 | 226.55 | 0.23x |
| (16, 1, 2048, 32768, 5, 1, 128) | 1559.77 | 226.45 | 0.15x |
| (16, 1, 4096, 32768, 5, 1, 128) | 2347.07 | 221.76 | 0.09x |
| (16, 1, 8190, 32768, 5, 1, 128) | 1536.70 | 241.50 | 0.16x |
| (16, 1, 32760, 32768, 5, 1, 128) | 2064.07 | 325.89 | 0.16x |
| (32, 1, 1024, 32768, 5, 1, 128) | 1684.13 | 225.33 | 0.13x |
| (32, 1, 2048, 32768, 5, 1, 128) | 2615.06 | 226.94 | 0.09x |
| (32, 1, 4096, 32768, 5, 1, 128) | 1642.11 | 241.08 | 0.15x |
| (32, 1, 8190, 32768, 5, 1, 128) | 1652.16 | 255.77 | 0.15x |
| (32, 1, 32760, 32768, 5, 1, 128) | 2136.59 | 335.58 | 0.16x |
| (64, 1, 1024, 32768, 5, 1, 128) | 2530.08 | 238.56 | 0.09x |
| (64, 1, 2048, 32768, 5, 1, 128) | 1731.80 | 257.47 | 0.15x |
| (64, 1, 4096, 32768, 5, 1, 128) | 1755.01 | 275.44 | 0.16x |
| (64, 1, 8190, 32768, 5, 1, 128) | 1823.19 | 282.90 | 0.16x |
| (64, 1, 32760, 32768, 5, 1, 128) | 2183.06 | 339.31 | 0.16x |
| (128, 1, 1024, 32768, 5, 1, 128) | 1692.65 | 206.45 | 0.12x |
| (128, 1, 2048, 32768, 5, 1, 128) | 1785.97 | 297.12 | 0.17x |
| (128, 1, 4096, 32768, 5, 1, 128) | 1911.85 | 315.38 | 0.16x |
| (128, 1, 8190, 32768, 5, 1, 128) | 1922.52 | 328.28 | 0.17x |
| (128, 1, 32760, 32768, 5, 1, 128) | 2221.66 | 324.47 | 0.15x |

After the change, `_k` shape : stride = `torch.Size([1, 524288, 5, 32]) : (16777216, 32, 0, 1)`, ~1000 GB/s:

| **(Batch, SeqLenQ, SeqLenKV, MaxLenKV, HeadQ, HeadKV, HeadD)** | triton_splitk (GB/s) | triton_splitk_fp8kv (GB/s) | **FP8/BF16** |
|---------------|---------------------|---------------------------|---------------|
| (16, 1, 1024, 32768, 5, 1, 128) | 974.43 | 368.21 | 0.38x |
| (16, 1, 2048, 32768, 5, 1, 128) | 1547.81 | 664.53 | 0.43x |
| (16, 1, 4096, 32768, 5, 1, 128) | 2464.77 | 1060.36 | 0.43x |
| (16, 1, 8190, 32768, 5, 1, 128) | 1582.76 | 929.04 | 0.59x |
| (16, 1, 32760, 32768, 5, 1, 128) | 2078.04 | 1443.88 | 0.69x |
| (32, 1, 1024, 32768, 5, 1, 128) | 1674.33 | 694.27 | 0.41x |
| (32, 1, 2048, 32768, 5, 1, 128) | 2630.66 | 1101.50 | 0.42x |
| (32, 1, 4096, 32768, 5, 1, 128) | 1670.73 | 1147.36 | 0.69x |
| (32, 1, 8190, 32768, 5, 1, 128) | 1664.33 | 907.95 | 0.55x |
| (32, 1, 32760, 32768, 5, 1, 128) | 2152.65 | 1524.07 | 0.71x |
| (64, 1, 1024, 32768, 5, 1, 128) | 2558.07 | 1161.96 | 0.45x |
| (64, 1, 2048, 32768, 5, 1, 128) | 1672.36 | 1195.78 | 0.72x |
| (64, 1, 4096, 32768, 5, 1, 128) | 1754.12 | 1126.56 | 0.64x |
| (64, 1, 8190, 32768, 5, 1, 128) | 1824.65 | 961.22 | 0.53x |
| (64, 1, 32760, 32768, 5, 1, 128) | 2181.59 | 1591.11 | 0.73x |
| (128, 1, 1024, 32768, 5, 1, 128) | 1712.90 | 1190.96 | 0.70x |
| (128, 1, 2048, 32768, 5, 1, 128) | 1788.32 | 1156.16 | 0.65x |
| (128, 1, 4096, 32768, 5, 1, 128) | 1909.89 | 1228.37 | 0.64x |
| (128, 1, 8190, 32768, 5, 1, 128) | 1922.10 | 1016.18 | 0.53x |
| (128, 1, 32760, 32768, 5, 1, 128) | 2203.02 | 1644.25 | 0.75x |

Reviewed By: y-sq, sijiac

Differential Revision: D79282737
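For context, here is a minimal standalone sketch (not part of this commit) of why the conversion order matters. The variable names and shapes are hypothetical, chosen to mirror the benchmark configuration above, and the data is just zeros; the printed strides assume PyTorch's standard behavior of materializing a broadcast (stride-0) dim when a dtype cast copies the data.

```python
import torch

# Hypothetical shapes mirroring the benchmark config above: 1 KV head shared by
# 5 query heads (GQA), head_dim 128. Allocates roughly 0.5 GB of host memory.
batch, seq_kv, n_kv_heads, n_q_heads, head_dim = 1, 524288, 1, 5, 128

k_cache = torch.zeros(batch, seq_kv, n_kv_heads, head_dim, dtype=torch.bfloat16)

# Old order: expand first, then cast. The dtype cast copies the data, which
# materializes the broadcast, so the qhead dim ends up with a non-zero stride.
k_old = (
    k_cache.expand(batch, seq_kv, n_q_heads, head_dim)
    .to(torch.uint8)
    .view(torch.int32)
)
print(k_old.shape, k_old.stride())
# torch.Size([1, 524288, 5, 32]) (83886080, 160, 32, 1)

# New order: cast first, then expand. The broadcast survives the cast, so the
# qhead stride stays 0 and all 5 query heads read the same KV data.
k_new = (
    k_cache.to(torch.uint8)
    .view(torch.int32)
    .expand(batch, seq_kv, n_q_heads, -1)
)
print(k_new.shape, k_new.stride())
# torch.Size([1, 524288, 5, 32]) (16777216, 32, 0, 1)
```

If the stride-0 qhead dim reaches the kernel inputs intact, the split-K kernel can recognize the GQA broadcast, which is consistent with the fp8 KV bandwidth recovery shown in the tables above.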
1 parent 2ed2c59 commit e276d45

File tree

1 file changed: +2 -3 lines changed

tritonbench/operators/decoding_attention/operator.py

Lines changed: 2 additions & 3 deletions
@@ -469,11 +469,10 @@ def triton_splitk_fp8kv(
 ) -> Callable:
     _, _, num_q_heads, _ = q.shape
     batch_size, max_sequence_length, _, _ = k_cache.shape
+    k_cache = k_cache.to(torch.uint8).view(torch.int32)
+    v_cache = v_cache.to(torch.uint8).view(torch.int32)
     _q, _k, _v, attn_bias = _pack_xformer_input(q, k_cache, v_cache, cache_seqlens)
 
-    _k = _k.to(torch.uint8).view(torch.int32)
-    _v = _v.to(torch.uint8).view(torch.int32)
-
     k_fp8_scales_shifts = torch.zeros(
         batch_size * max_sequence_length,
         dtype=torch.int32,

0 commit comments