Make scaling type configurable for MoE training #2642
base: main
Conversation
stack-info: PR: #2642, branch: danielvegamyhre/stack/26
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2642
Note: links to docs will display an error until the docs builds have completed. ✅ No failures as of commit bb05933 with merge base 1f0d2bb. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
## Summary

- For mxfp8, token group sizes must be multiples of the block size: in the backward pass for `grad_weight = grad_output_t @ input`, the "M" (token) dimension is the contracting dimension, and each token group is a logically distinct subtensor that is scaled separately. Each token group's contracting dimension must therefore be divisible by the mxfp8 block_size (default 32). Diagram of the problem: https://www.internalfb.com/excalidraw/EX521879
- To solve this, this PR makes the token group M alignment configurable. A minimal sketch of the alignment logic follows the test plan below.

## Test plan

- Integration test with torchao passes: pytorch/ao#2642
- Ran a manual test with the llama4 debug model using bf16
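Below is a minimal sketch of what "aligning token group sizes to the mxfp8 block size" means. It is illustrative only and not torchao's or torchtitan's actual API; the helper names (`round_up_group_sizes`) and the constant `MXFP8_BLOCK_SIZE` are assumptions for the example.

```python
# Hedged sketch: pad each expert's token count (the M dim of its group) up to
# a multiple of the mxfp8 block size, so that the contracting dimension in
# grad_weight = grad_output_t @ input is divisible by the block size when each
# group is scaled independently. Helper names here are hypothetical.
import torch

MXFP8_BLOCK_SIZE = 32  # default mxfp8 scaling block size along the contracting dim


def round_up_group_sizes(
    group_sizes: torch.Tensor, alignment: int = MXFP8_BLOCK_SIZE
) -> torch.Tensor:
    """Round each token group size up to the nearest multiple of `alignment`."""
    return ((group_sizes + alignment - 1) // alignment) * alignment


# Example: raw token counts per expert after routing
raw_sizes = torch.tensor([45, 7, 96, 20])
print(round_up_group_sizes(raw_sizes))  # tensor([64, 32, 96, 32])
```

Making the alignment configurable (rather than hard-coding 32) lets the same grouping code serve both bf16 training (no alignment constraint) and mxfp8 training (block-size-aligned groups).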
lg for prototype, we might need to change this later