Add Triton kernels for fp8 blockwise quantization and GEMMs #2617
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2617
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 14420ef with merge base 0e00df3. This comment was automatically generated by Dr. CI and updates every 15 minutes.
ref_fp32 = ref_fp8.to(torch.float32)

# Check that the quantized tensors are close
assert torch.allclose(triton_fp32, ref_fp32, rtol=1e-3, atol=1e-3), (
shouldn't this be bit exact? If it's not bit exact and there's no exact reason why not, I wouldn't really trust the triton kernel.
IMO I would just go with torch native kernels for everything except the gemms for now (since they are the easiest to verify numerical correctness for), and leave writing triton kernels for quant of weights/activations as a future thing
Yes, updated the tests to use torch.equal instead of allclose to assert bitwise equivalence, let me know what you think.
sqnr = compute_error(C, C_q)
min_sqnr = 28.0
print(f"blockwise_fp8_gemm_1x128_128x128 ({M},{N},{K}) SQNR: {sqnr}")
remove the prints before landing
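For context on the threshold above, compute_error returns an SQNR in dB. A minimal sketch of how such a helper is typically defined (an assumption for illustration, not necessarily torchao's exact implementation):

```python
import torch

def compute_error_sketch(ref: torch.Tensor, test: torch.Tensor) -> torch.Tensor:
    # SQNR in dB: ratio of signal power to quantization-noise power.
    signal = torch.linalg.norm(ref.float())
    noise = torch.linalg.norm(ref.float() - test.float())
    return 20 * torch.log10(signal / noise)

# The test then requires compute_error(C, C_q) >= min_sqnr (28 dB above).
```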
]

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
check for sm90?
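One way to implement the suggested check, sketched here; the exact predicate and skip message in the final tests may differ:

```python
import pytest
import torch

# Skip unless running on an SM90 (Hopper) GPU, since the fp8 blockwise
# kernels target that architecture. `is_sm90` is a name used only for
# this sketch.
is_sm90 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)

@pytest.mark.skipif(not is_sm90, reason="requires CUDA compute capability 9.0 (sm90)")
def test_blockwise_fp8_kernels_placeholder():
    ...
```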
ref_fp32 = ref_fp8.to(torch.float32)

# Check that the quantized tensors are close
assert torch.equal(triton_fp32, ref_fp32), (
nit: torch.testing.assert_close(..., rtol=0, atol=0) everywhere
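A minimal standalone illustration of the suggested pattern (toy tensors here, not the PR's actual test variables):

```python
import torch

x = torch.randn(4, 4)
y = x.clone()

# rtol=0, atol=0 enforces exact value equality (like torch.equal) while
# still printing a descriptive mismatch report on failure.
torch.testing.assert_close(x, y, rtol=0, atol=0)
```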
Thanks for the review @vkuzo, I finished addressing your comments; this is ready for another look.
Stacked PRs:
Add Triton kernels for fp8 blockwise quantization and GEMMs
GEMMs (reference math for these is sketched below):
- blockwise_fp8_gemm_1x128_128x128, used for:
  - out = input @ weight.T
  - grad_input = grad_output @ weight
- blockwise_fp8_gemm_1x128_128x1, used for:
  - grad_weight = grad_output.T @ input
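For reference, a plain (unquantized) sketch of the three matmuls the two GEMMs above cover in a linear layer's forward and backward. The kernel names are from the list above, but the shapes here are made up and this is only the math, not the kernel API:

```python
import torch

M, K, N = 256, 512, 128  # example shapes, not from the PR
input = torch.randn(M, K, dtype=torch.bfloat16)
weight = torch.randn(N, K, dtype=torch.bfloat16)
grad_output = torch.randn(M, N, dtype=torch.bfloat16)

# blockwise_fp8_gemm_1x128_128x128 covers these two:
out = input @ weight.T             # forward
grad_input = grad_output @ weight  # backward w.r.t. input

# blockwise_fp8_gemm_1x128_128x1 covers this one:
grad_weight = grad_output.T @ input  # backward w.r.t. weight
```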
Quantization (a pure-PyTorch sketch of the 1x128 scaling scheme follows this list):
- fp8_blockwise_act_quant_lhs
- fp8_blockwise_act_quant_rhs
- fp8_blockwise_act_quant_transposed_lhs
- fp8_blockwise_weight_quant_rhs
- fp8_blockwise_weight_quant_transposed_rhs
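As referenced above, a hedged pure-PyTorch sketch of 1x128 groupwise quantization to float8_e4m3fn, i.e. the kind of scaling the activation-quant kernels implement in Triton; the actual kernel signatures, output dtypes, and scale layouts in this PR may differ:

```python
import torch

def quant_1x128_reference(x: torch.Tensor, block_size: int = 128):
    # x: (M, K) with K divisible by block_size. Each row is split into
    # 1x128 groups; each group gets one scale mapping its absmax to fp8 max.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    M, K = x.shape
    xb = x.reshape(M, K // block_size, block_size)
    amax = xb.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scale = amax / fp8_max                        # dequant scale per group
    x_fp8 = (xb / scale).to(torch.float8_e4m3fn).reshape(M, K)
    return x_fp8, scale.squeeze(-1)               # scales: (M, K // block_size)

x = torch.randn(4, 256)
x_fp8, scales = quant_1x128_reference(x)
```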
Test plan
pytest test/prototype/blockwise_fp8_training/test_blockwise_kernels.py
Attempted usage of DeepGEMM cutlass kernels
Unfortunately the GEMM APIs in @vkuzo's PoC here no longer exist in DeepGEMM. I tried using the new GEMM APIs (fp8_gemm_nt, etc.), and:

Attempted usage of torch._scaled_mm
I also tried using torch._scaled_mm in torch nightly, and the error messages indicate it does not support a groupwise scaled "A" tensor with a blockwise scaled "B" tensor, so I went ahead and finished writing these Triton GEMMs.
Today, however, I talked to Luca and it seems the error message is inaccurate and this is indeed supported, but it will require some changes to scale strides and alignment to adhere to the API requirements.
If we want to make this prototype more performant, I can make these changes and swap out the GEMMs to use torch._scaled_mm (and update the PT core error message to be more accurate).
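For reference, the general shape of such a call is sketched below. The scale shapes/layouts for the groupwise-A x blockwise-B case are exactly the open question about strides and alignment, so treat them as assumptions; as written this may still raise the error described above:

```python
import torch

M, K, N = 256, 512, 128
a_fp8 = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
# _scaled_mm expects the second operand in column-major layout.
b_fp8 = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()  # (K, N)

# Assumed scale layouts: 1x128 groupwise for A, 128x128 blockwise for B;
# the required strides/alignment here are the unresolved part.
scale_a = torch.ones(M, K // 128, device="cuda", dtype=torch.float32)
scale_b = torch.ones(K // 128, N // 128, device="cuda", dtype=torch.float32)

out = torch._scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype=torch.bfloat16)
```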