
Add Triton kernels for fp8 blockwise quantization and GEMMs #2617


Merged

1 commit merged into main on Jul 30, 2025

Conversation

@danielvegamyhre (Contributor) commented Jul 28, 2025

Stacked PRs:


Add Triton kernels for fp8 blockwise quantization and GEMMs

  • I wrote the following Triton kernels to perform the fp8 blockwise quantization ops and GEMMs needed for the forward + backward pass of a linear layer
  • Added unit tests verifying numerics

GEMMs:

  • blockwise_fp8_gemm_1x128_128x128 for:
    • out = input @ weight.T
    • grad_input = grad_output @ weight
  • blockwise_fp8_gemm_1x128_128x1 for:
    • grad_weight = grad_output.T @ input
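
For reference, these are the three matmuls in a linear layer's forward and backward pass. A plain, non-quantized sketch (shapes and names are illustrative, not the kernel signatures):

```python
import torch

# Plain (non-quantized) reference for the three matmuls the blockwise fp8
# GEMMs cover; M/K/N and tensor names are illustrative only.
M, K, N = 128, 256, 512
input = torch.randn(M, K)
weight = torch.randn(N, K)
grad_output = torch.randn(M, N)

out = input @ weight.T               # forward (blockwise_fp8_gemm_1x128_128x128)
grad_input = grad_output @ weight    # backward (blockwise_fp8_gemm_1x128_128x128)
grad_weight = grad_output.T @ input  # backward (blockwise_fp8_gemm_1x128_128x1)
```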

Quantization:

  • fp8_blockwise_act_quant_lhs
  • fp8_blockwise_act_quant_rhs
  • fp8_blockwise_act_quant_transposed_lhs
  • fp8_blockwise_weight_quant_rhs
  • fp8_blockwise_weight_quant_transposed_rhs
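
For intuition, here is a minimal pure-PyTorch sketch of 1x128 groupwise activation quantization of the kind these kernels implement in Triton. The 128-element group size and e4m3 dtype are assumptions based on the kernel names, and the scale layout is illustrative rather than the kernels' exact output format (the sketch also assumes a 2D input whose last dim is divisible by the block size):

```python
import torch

def blockwise_act_quant_ref(x: torch.Tensor, block_size: int = 128):
    # Group each row into contiguous 1x128 blocks and scale each block by its
    # absolute max so it fits the fp8 e4m3 range.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    x_blocks = x.reshape(x.shape[0], -1, block_size)             # (rows, groups, 128)
    scales = x_blocks.abs().amax(dim=-1, keepdim=True) / fp8_max
    scales = scales.clamp(min=1e-12)                             # guard all-zero blocks
    x_fp8 = (x_blocks / scales).to(torch.float8_e4m3fn)
    return x_fp8.reshape_as(x), scales.squeeze(-1)               # fp8 data + fp32 scales

x = torch.randn(64, 256)
x_q, s = blockwise_act_quant_ref(x)   # x_q: (64, 256) fp8, s: (64, 2) scales
```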

Test plan

  • pytest test/prototype/blockwise_fp8_training/test_blockwise_kernels.py

Attempted usage of DeepGEMM CUTLASS kernels

Unfortunately, the GEMM APIs used in @vkuzo's PoC here no longer exist in DeepGEMM. I tried using the new GEMM APIs (fp8_gemm_nt, etc.) and hit the following:

  • On B200, with both Vasiliy's PR and mine, I got device-side asserts on this line that were not immediately clear how to resolve.
  • On H100, I only tried Vasiliy's PR, but got an undefined-symbols error from CUDA, despite using CUDA toolkit 12.8+ as stated in the README.

Attempted usage of torch._scaled_mm

I also tried using torch._scaled_mm in torch nightly, but the error messages indicated it does not support a groupwise-scaled "A" tensor with a blockwise-scaled "B" tensor, so I went ahead and finished writing these Triton GEMMs.

Today, however, I talked to Luca, and it seems the error message is inaccurate: this combination is indeed supported, but it will require some changes to scale strides and alignment to adhere to the API requirements.

If we want to make this prototype more performant, I can make these changes and swap out the GEMMs to use torch._scaled_mm (and update the PT core error message to be more accurate).
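
For context, a minimal sketch of the simpler per-tensor torch._scaled_mm path (a private PyTorch API), only to illustrate the call shape; the groupwise/blockwise scaling discussed above needs different scale shapes, strides, and alignment, which this sketch does not attempt to reproduce:

```python
import torch

# Per-tensor fp8 scaling sketch: scales are dequantization factors (amax / fp8_max).
M, K, N = 128, 256, 64
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
b = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)

fp8_max = torch.finfo(torch.float8_e4m3fn).max
scale_a = (a.abs().max() / fp8_max).float()
scale_b = (b.abs().max() / fp8_max).float()
a_fp8 = (a / scale_a).to(torch.float8_e4m3fn)
b_fp8 = (b / scale_b).to(torch.float8_e4m3fn)

# _scaled_mm expects the second operand in column-major layout; b_fp8.t()
# gives a K x N column-major view of the row-major N x K tensor.
out = torch._scaled_mm(
    a_fp8, b_fp8.t(), scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16
)
```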


pytorch-bot bot commented Jul 28, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2617

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 14420ef with merge base 0e00df3:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Jul 28, 2025
@danielvegamyhre force-pushed the danielvegamyhre/stack/17 branch from 4c0250f to 89357e5 on July 28, 2025 16:06
@danielvegamyhre added the topic: not user facing label on Jul 28, 2025
@danielvegamyhre force-pushed the danielvegamyhre/stack/17 branch 2 times, most recently from b2e5d4d to 7f21a73 on July 28, 2025 17:21
ref_fp32 = ref_fp8.to(torch.float32)

# Check that the quantized tensors are close
assert torch.allclose(triton_fp32, ref_fp32, rtol=1e-3, atol=1e-3), (
Contributor

shouldn't this be bit exact? If it's not bit exact and there's no exact reason why not, I wouldn't really trust the triton kernel.

IMO I would just go with torch-native kernels for everything except the gemms for now (since they are the easiest to verify numerical correctness for), and leave writing triton kernels for quant of weights/activations as a future thing

@danielvegamyhre (Contributor Author) Jul 28, 2025

Yes, updated the tests to use torch.equal instead of allclose to assert bitwise equivalence; let me know what you think.

@danielvegamyhre force-pushed the danielvegamyhre/stack/17 branch from 7f21a73 to 5ce55f6 on July 28, 2025 18:15

sqnr = compute_error(C, C_q)
min_sqnr = 28.0
print(f"blockwise_fp8_gemm_1x128_128x128 ({M},{N},{K}) SQNR: {sqnr}")
Contributor

remove the prints before landing
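
As background for the check in the snippet above, a rough sketch of an SQNR-style metric like the compute_error used in the test (illustrative; not necessarily torchao's exact helper):

```python
import torch

def sqnr_db(ref: torch.Tensor, test: torch.Tensor) -> float:
    # Signal-to-quantization-noise ratio in dB: higher means the quantized
    # result is closer to the reference.
    signal = torch.linalg.norm(ref.float())
    noise = torch.linalg.norm(ref.float() - test.float())
    return (20 * torch.log10(signal / noise)).item()
```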

]


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
Contributor

check for sm90?
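
One possible way to gate the tests on SM90 (Hopper), as suggested here; a sketch, not necessarily the helper torchao already provides:

```python
import pytest
import torch

# Skip unless running on a GPU with compute capability 9.0 or newer.
is_sm90 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)

@pytest.mark.skipif(not is_sm90, reason="requires CUDA compute capability 9.0+")
def test_blockwise_fp8_gemm_smoke():
    ...
```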

ref_fp32 = ref_fp8.to(torch.float32)

# Check that the quantized tensors are close
assert torch.equal(triton_fp32, ref_fp32), (
Contributor

nit: torch.testing.assert_close(..., rtol=0, atol=0) everywhere
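
The suggested assertion style: zero tolerances enforce bitwise equality while still giving a readable mismatch report on failure (tensor names here are illustrative stand-ins):

```python
import torch

triton_fp32 = torch.zeros(4)
ref_fp32 = torch.zeros(4)
torch.testing.assert_close(triton_fp32, ref_fp32, rtol=0, atol=0)
```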

stack-info: PR: #2617, branch: danielvegamyhre/stack/17
@danielvegamyhre force-pushed the danielvegamyhre/stack/17 branch from 5ce55f6 to 14420ef on July 28, 2025 19:49
@danielvegamyhre (Contributor Author)

Thanks for the review @vkuzo. I've addressed your comments; this is ready for another look.

@danielvegamyhre merged commit 6b82931 into main on Jul 30, 2025
20 checks passed