
Conversation

Aya-ZIbra (Contributor)

Summary:
The current code applies the conversion to uint8 after expand. This results in a non-zero stride in the qhead dim of the KV tensors.

Triton performance drops significantly, probably because GQA packing is lost.

Before:
~220 GB/s

After:
~1000 GB/s
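As a minimal sketch of the issue (a plain uint8 cast stands in for the FP8-KV conversion here; shapes are taken from the stride logs in the commit message below):

```python
import torch

# Shapes from the benchmark logs: 1 KV head broadcast to 5 query heads.
B, S, H_KV, H_Q, D = 1, 524288, 1, 5, 32
kv = torch.randn(B, S, H_KV, D, dtype=torch.bfloat16)

# Before: expand first, then convert. The dtype conversion materializes
# the expanded tensor, so the qhead dim gets a real (non-zero) stride.
k_before = kv.expand(B, S, H_Q, D).to(torch.uint8)
print(k_before.stride())  # (83886080, 160, 32, 1)

# After: convert first, then expand. expand() returns a broadcast view,
# so the qhead dim keeps stride 0 and no per-head copy is made.
k_after = kv.to(torch.uint8).expand(B, S, H_Q, D)
print(k_after.stride())   # (16777216, 32, 0, 1)
```

With the stride-0 layout the kernel can keep treating the KV cache as a single head, which matches the GQA-packing explanation above.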

Differential Revision: D79282737

meta-cla bot added the cla signed label (Aug 1, 2025)
facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D79282737

Aya-ZIbra temporarily deployed to docker-s3-upload August 1, 2025 18:00 with GitHub Actions (inactive)
Aya-ZIbra added a commit to Aya-ZIbra/tritonbench that referenced this pull request Aug 6, 2025
Summary:
X-link: meta-pytorch#323

The current code applies the conversion to uint8 after expand. This results in a non-zero stride in the qhead dim of the KV tensors.

Triton performance drops significantly, probably because GQA packing is lost.


Before:

  [before change] _k shape : stride = torch.Size([1, 524288, 5, 32]):(83886080, 160, 32, 1)
~220 GB/s

| **(Batch, SeqLenQ, SeqLenKV, MaxLenKV, HeadQ, HeadKV, HeadD)** | triton_splitk (GB/s) | triton_splitk_fp8kv (GB/s) | **FP8/BF16** |
|---------------|---------------------|---------------------------|---------------|
| (16, 1, 1024, 32768, 5, 1, 128) | 971.46 | 226.55 | 0.23x |
| (16, 1, 2048, 32768, 5, 1, 128) | 1559.77 | 226.45 | 0.15x |
| (16, 1, 4096, 32768, 5, 1, 128) | 2347.07 | 221.76 | 0.09x |
| (16, 1, 8190, 32768, 5, 1, 128) | 1536.70 | 241.50 | 0.16x |
| (16, 1, 32760, 32768, 5, 1, 128) | 2064.07 | 325.89 | 0.16x |
| (32, 1, 1024, 32768, 5, 1, 128) | 1684.13 | 225.33 | 0.13x |
| (32, 1, 2048, 32768, 5, 1, 128) | 2615.06 | 226.94 | 0.09x |
| (32, 1, 4096, 32768, 5, 1, 128) | 1642.11 | 241.08 | 0.15x |
| (32, 1, 8190, 32768, 5, 1, 128) | 1652.16 | 255.77 | 0.15x |
| (32, 1, 32760, 32768, 5, 1, 128) | 2136.59 | 335.58 | 0.16x |
| (64, 1, 1024, 32768, 5, 1, 128) | 2530.08 | 238.56 | 0.09x |
| (64, 1, 2048, 32768, 5, 1, 128) | 1731.80 | 257.47 | 0.15x |
| (64, 1, 4096, 32768, 5, 1, 128) | 1755.01 | 275.44 | 0.16x |
| (64, 1, 8190, 32768, 5, 1, 128) | 1823.19 | 282.90 | 0.16x |
| (64, 1, 32760, 32768, 5, 1, 128) | 2183.06 | 339.31 | 0.16x |
| (128, 1, 1024, 32768, 5, 1, 128) | 1692.65 | 206.45 | 0.12x |
| (128, 1, 2048, 32768, 5, 1, 128) | 1785.97 | 297.12 | 0.17x |
| (128, 1, 4096, 32768, 5, 1, 128) | 1911.85 | 315.38 | 0.16x |
| (128, 1, 8190, 32768, 5, 1, 128) | 1922.52 | 328.28 | 0.17x |
| (128, 1, 32760, 32768, 5, 1, 128) | 2221.66 | 324.47 | 0.15x |

After:

   [after change] _k shape : stride = torch.Size([1, 524288, 5, 32]):(16777216, 32, 0, 1)

~1000 GB/s

| **(Batch, SeqLenQ, SeqLenKV, MaxLenKV, HeadQ, HeadKV, HeadD)** | triton_splitk (GB/s) | triton_splitk_fp8kv (GB/s) | **FP8/BF16** |
|---------------|---------------------|---------------------------|---------------|
| (16, 1, 1024, 32768, 5, 1, 128) | 974.43 | 368.21 | 0.38x |
| (16, 1, 2048, 32768, 5, 1, 128) | 1547.81 | 664.53 | 0.43x |
| (16, 1, 4096, 32768, 5, 1, 128) | 2464.77 | 1060.36 | 0.43x |
| (16, 1, 8190, 32768, 5, 1, 128) | 1582.76 | 929.04 | 0.59x |
| (16, 1, 32760, 32768, 5, 1, 128) | 2078.04 | 1443.88 | 0.69x |
| (32, 1, 1024, 32768, 5, 1, 128) | 1674.33 | 694.27 | 0.41x |
| (32, 1, 2048, 32768, 5, 1, 128) | 2630.66 | 1101.50 | 0.42x |
| (32, 1, 4096, 32768, 5, 1, 128) | 1670.73 | 1147.36 | 0.69x |
| (32, 1, 8190, 32768, 5, 1, 128) | 1664.33 | 907.95 | 0.55x |
| (32, 1, 32760, 32768, 5, 1, 128) | 2152.65 | 1524.07 | 0.71x |
| (64, 1, 1024, 32768, 5, 1, 128) | 2558.07 | 1161.96 | 0.45x |
| (64, 1, 2048, 32768, 5, 1, 128) | 1672.36 | 1195.78 | 0.72x |
| (64, 1, 4096, 32768, 5, 1, 128) | 1754.12 | 1126.56 | 0.64x |
| (64, 1, 8190, 32768, 5, 1, 128) | 1824.65 | 961.22 | 0.53x |
| (64, 1, 32760, 32768, 5, 1, 128) | 2181.59 | 1591.11 | 0.73x |
| (128, 1, 1024, 32768, 5, 1, 128) | 1712.90 | 1190.96 | 0.70x |
| (128, 1, 2048, 32768, 5, 1, 128) | 1788.32 | 1156.16 | 0.65x |
| (128, 1, 4096, 32768, 5, 1, 128) | 1909.89 | 1228.37 | 0.64x |
| (128, 1, 8190, 32768, 5, 1, 128) | 1922.10 | 1016.18 | 0.53x |
| (128, 1, 32760, 32768, 5, 1, 128) | 2203.02 | 1644.25 | 0.75x |
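The zero qhead stride is what makes GQA packing possible: every query head aliases the same KV data, so a kernel can collapse that dim and read each byte once. A hedged illustration (names and the packing step are illustrative, not the tritonbench kernel):

```python
import torch

B, S, H_Q, D = 1, 524288, 5, 32
k = torch.zeros(B, S, 1, D, dtype=torch.uint8).expand(B, S, H_Q, D)
assert k.stride(2) == 0                    # broadcast view: no per-head data

# Collapsing the broadcast dim is free: same storage, no copy.
k_packed = k[:, :, 0, :]                   # (B, S, D) view
assert k_packed.data_ptr() == k.data_ptr()

# Bytes streamed per batch: S*D packed vs. S*H_Q*D unpacked, a 5x
# difference here, an upper bound on the gains in the fp8kv column.
print(S * D, S * H_Q * D)                  # 16777216 vs 83886080
```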

Reviewed By: y-sq, sijiac

Differential Revision: D79282737
facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D79282737

1 similar comment

Aya-ZIbra added a commit to Aya-ZIbra/tritonbench that referenced this pull request Aug 29, 2025
Summary:
Pull Request resolved: meta-pytorch#323

(Same description and benchmark tables as the Aug 6 commit above.)

Reviewed By: y-sq, sijiac

Differential Revision: D79282737
facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D79282737
