From 9fb783e828d0351311e05ccadeaa68cf3f0e5ac4 Mon Sep 17 00:00:00 2001
From: Aya Ibrahim
Date: Thu, 28 Aug 2025 19:15:07 -0700
Subject: [PATCH] Fix input to TritonSplitK performance benchmark (#323)

Summary:
Pull Request resolved: https://github.com/meta-pytorch/tritonbench/pull/323

The current code applies the uint8 conversion after the expand. This results in a non-zero stride in the q-head dimension of the KV tensors, and Triton performance is significantly affected, most likely because the GQA packing is lost.
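To make the stride effect concrete, here is a minimal sketch with toy shapes, using a plain `Tensor.expand` as a stand-in for the head expansion that `_pack_xformer_input` is assumed to perform internally; the FP8 cache dtype and the subsequent `.view(torch.int32)` are omitted, since they do not change whether the head-dim stride is zero.

```python
import torch

# Toy stand-in for a GQA KV cache: 1 KV head shared by 5 query heads.
kv = torch.rand(1, 4, 1, 8, dtype=torch.bfloat16)

# Convert AFTER expanding (old behavior): .to() with a new dtype copies,
# so the broadcast q-head dimension is materialized and gets a non-zero stride.
after_expand = kv.expand(1, 4, 5, 8).to(torch.uint8)
print(after_expand.stride())   # (160, 40, 8, 1) -- head-dim stride != 0

# Convert BEFORE expanding (this patch): expand stays a zero-copy view,
# so the q-head dimension keeps stride 0 and the GQA sharing stays visible.
before_expand = kv.to(torch.uint8).expand(1, 4, 5, 8)
print(before_expand.stride())  # (32, 8, 0, 1) -- head-dim stride == 0
```

This mirrors the `_k` strides quoted below: a head-dim stride of 32 before the change versus 0 after it.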
Before the change:
_k shape : stride = torch.Size([1, 524288, 5, 32]) : (83886080, 160, 32, 1) -> 220 GB/s

| **(Batch, SeqLenQ, SeqLenKV, MaxLenKV, HeadQ, HeadKV, HeadD)** | triton_splitk (GB/s) | triton_splitk_fp8kv (GB/s) | **FP8/BF16** |
|---------------|---------------------|---------------------------|---------------|
| (16, 1, 1024, 32768, 5, 1, 128) | 971.46 | 226.55 | 0.23x |
| (16, 1, 2048, 32768, 5, 1, 128) | 1559.77 | 226.45 | 0.15x |
| (16, 1, 4096, 32768, 5, 1, 128) | 2347.07 | 221.76 | 0.09x |
| (16, 1, 8190, 32768, 5, 1, 128) | 1536.70 | 241.50 | 0.16x |
| (16, 1, 32760, 32768, 5, 1, 128) | 2064.07 | 325.89 | 0.16x |
| (32, 1, 1024, 32768, 5, 1, 128) | 1684.13 | 225.33 | 0.13x |
| (32, 1, 2048, 32768, 5, 1, 128) | 2615.06 | 226.94 | 0.09x |
| (32, 1, 4096, 32768, 5, 1, 128) | 1642.11 | 241.08 | 0.15x |
| (32, 1, 8190, 32768, 5, 1, 128) | 1652.16 | 255.77 | 0.15x |
| (32, 1, 32760, 32768, 5, 1, 128) | 2136.59 | 335.58 | 0.16x |
| (64, 1, 1024, 32768, 5, 1, 128) | 2530.08 | 238.56 | 0.09x |
| (64, 1, 2048, 32768, 5, 1, 128) | 1731.80 | 257.47 | 0.15x |
| (64, 1, 4096, 32768, 5, 1, 128) | 1755.01 | 275.44 | 0.16x |
| (64, 1, 8190, 32768, 5, 1, 128) | 1823.19 | 282.90 | 0.16x |
| (64, 1, 32760, 32768, 5, 1, 128) | 2183.06 | 339.31 | 0.16x |
| (128, 1, 1024, 32768, 5, 1, 128) | 1692.65 | 206.45 | 0.12x |
| (128, 1, 2048, 32768, 5, 1, 128) | 1785.97 | 297.12 | 0.17x |
| (128, 1, 4096, 32768, 5, 1, 128) | 1911.85 | 315.38 | 0.16x |
| (128, 1, 8190, 32768, 5, 1, 128) | 1922.52 | 328.28 | 0.17x |
| (128, 1, 32760, 32768, 5, 1, 128) | 2221.66 | 324.47 | 0.15x |

After the change:
_k shape : stride = torch.Size([1, 524288, 5, 32]) : (16777216, 32, 0, 1) -> ~1000 GB/s

| **(Batch, SeqLenQ, SeqLenKV, MaxLenKV, HeadQ, HeadKV, HeadD)** | triton_splitk (GB/s) | triton_splitk_fp8kv (GB/s) | **FP8/BF16** |
|---------------|---------------------|---------------------------|---------------|
| (16, 1, 1024, 32768, 5, 1, 128) | 974.43 | 368.21 | 0.38x |
| (16, 1, 2048, 32768, 5, 1, 128) | 1547.81 | 664.53 | 0.43x |
| (16, 1, 4096, 32768, 5, 1, 128) | 2464.77 | 1060.36 | 0.43x |
| (16, 1, 8190, 32768, 5, 1, 128) | 1582.76 | 929.04 | 0.59x |
| (16, 1, 32760, 32768, 5, 1, 128) | 2078.04 | 1443.88 | 0.69x |
| (32, 1, 1024, 32768, 5, 1, 128) | 1674.33 | 694.27 | 0.41x |
| (32, 1, 2048, 32768, 5, 1, 128) | 2630.66 | 1101.50 | 0.42x |
| (32, 1, 4096, 32768, 5, 1, 128) | 1670.73 | 1147.36 | 0.69x |
| (32, 1, 8190, 32768, 5, 1, 128) | 1664.33 | 907.95 | 0.55x |
| (32, 1, 32760, 32768, 5, 1, 128) | 2152.65 | 1524.07 | 0.71x |
| (64, 1, 1024, 32768, 5, 1, 128) | 2558.07 | 1161.96 | 0.45x |
| (64, 1, 2048, 32768, 5, 1, 128) | 1672.36 | 1195.78 | 0.72x |
| (64, 1, 4096, 32768, 5, 1, 128) | 1754.12 | 1126.56 | 0.64x |
| (64, 1, 8190, 32768, 5, 1, 128) | 1824.65 | 961.22 | 0.53x |
| (64, 1, 32760, 32768, 5, 1, 128) | 2181.59 | 1591.11 | 0.73x |
| (128, 1, 1024, 32768, 5, 1, 128) | 1712.90 | 1190.96 | 0.70x |
| (128, 1, 2048, 32768, 5, 1, 128) | 1788.32 | 1156.16 | 0.65x |
| (128, 1, 4096, 32768, 5, 1, 128) | 1909.89 | 1228.37 | 0.64x |
| (128, 1, 8190, 32768, 5, 1, 128) | 1922.10 | 1016.18 | 0.53x |
| (128, 1, 32760, 32768, 5, 1, 128) | 2203.02 | 1644.25 | 0.75x |

Reviewed By: y-sq, sijiac

Differential Revision: D79282737

---
 tritonbench/operators/decoding_attention/operator.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/tritonbench/operators/decoding_attention/operator.py b/tritonbench/operators/decoding_attention/operator.py
index 680b0d5df..7d6a7adf7 100644
--- a/tritonbench/operators/decoding_attention/operator.py
+++ b/tritonbench/operators/decoding_attention/operator.py
@@ -469,11 +469,10 @@ def triton_splitk_fp8kv(
 ) -> Callable:
     _, _, num_q_heads, _ = q.shape
     batch_size, max_sequence_length, _, _ = k_cache.shape
+    k_cache = k_cache.to(torch.uint8).view(torch.int32)
+    v_cache = v_cache.to(torch.uint8).view(torch.int32)
     _q, _k, _v, attn_bias = _pack_xformer_input(q, k_cache, v_cache, cache_seqlens)

-    _k = _k.to(torch.uint8).view(torch.int32)
-    _v = _v.to(torch.uint8).view(torch.int32)
-
     k_fp8_scales_shifts = torch.zeros(
         batch_size * max_sequence_length,
         dtype=torch.int32,