add differentiable mxfp8 grouped gemm with dynamic quant (forward pass) #2627
Helpful links: artifacts and rendered test results are at hud.pytorch.org/pr/pytorch/ao/2627. Links to docs will display an error until the docs builds have been completed.
No failures as of commit 5376f65 with merge base 9834869. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
stack-info: PR: #2627, branch: danielvegamyhre/stack/23
    offs: Optional[torch.Tensor] = None,
    block_size: int = 32,
    out_dtype: Optional[torch.dtype] = torch.bfloat16,
) -> torch.Tensor:
Add an `emulated` flag and assert that it's True until we have a real kernel, to make the intent crystal clear?
Done
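For readers following the thread, the suggestion above could look roughly like the sketch below. This is not the code from the PR: `grouped_mm_sketch` and `_emulated_matmul` are hypothetical names, and a plain nested-list matmul stands in for the emulated mxfp8 path so the example runs without torch.

```python
def _emulated_matmul(a, b):
    # Hypothetical stand-in for the emulated path: plain matmul on nested lists.
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def grouped_mm_sketch(a, b, block_size: int = 32, emulated: bool = True):
    # Fail loudly if a caller asks for the real kernel, which does not exist yet
    # in this sketch. block_size is accepted only to mirror the reviewed
    # signature; the stand-in ignores it.
    assert emulated, "only the emulated path is supported until a real kernel lands"
    return _emulated_matmul(a, b)
```

Calling `grouped_mm_sketch(a, b, emulated=False)` raises an `AssertionError`, which makes the current limitation explicit at the call site.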
@@ -249,3 +249,25 @@ def test_emulate_mxfp8_grouped_gemm(M, K, N, num_experts):
    sqnr = compute_error(ref_out, out)
    min_sqnr = 27.0
    assert sqnr >= min_sqnr, f"sqnr {sqnr} is too low, must be >= {min_sqnr}"


@pytest.mark.parametrize("M", (1024, 4096))
nit: might be good to test one case where M, K, and N are all the same, and one where they are all different. To do that while keeping the number of tests manageable, it would probably be best to parametrize over (M, K, N) together rather than over each dimension individually.
Updated to parameterize M,N,K together and test "all same" and "all different" cases
Ergh my stack-pr got in a weird state, changes didn't go through somehow... let me try again
Ok it's updated now.
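A minimal sketch of the joint parametrization discussed in this thread, assuming pytest. The shape values, the `num_experts` values, and the test name are illustrative assumptions rather than the ones used in the PR, and the body is a placeholder for the real grouped GEMM check.

```python
import pytest

# Illustrative shapes: one "all same" case and one "all different" case.
MKN_SHAPES = [
    (1024, 1024, 1024),
    (1024, 2048, 4096),
]


@pytest.mark.parametrize("M,K,N", MKN_SHAPES)
@pytest.mark.parametrize("num_experts", (1, 8))
def test_shapes_sketch(M, K, N, num_experts):
    # Placeholder body: each case is either all-equal or all-distinct.
    assert len({M, K, N}) in (1, 3)
    assert num_experts > 0
```

With two shape tuples and two expert counts this generates 4 test cases, whereas parametrizing M, K, and N separately with two values each would generate 16.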
Stacked PRs:
add differentiable mxfp8 grouped gemm with dynamic quant (forward pass)