Make scaling type configurable for MoE training #2642

Open
danielvegamyhre wants to merge 1 commit into main from danielvegamyhre/stack/26

Conversation

danielvegamyhre (Contributor) commented Jul 30, 2025

Stacked PRs:


Make scaling type configurable for MoE training

Summary

  • Update the user-facing MoE conversion API to make the scaling type configurable, as illustrated in the sketch below
  • Note: once torchtitan#1503 (Make token group alignment size configurable) lands, I don't think we'll actually need "per token group" scaling for mxfp8 when using torchtitan, since the scaling groups will no longer cross token group boundaries. However, for now I am leaving this in, since it's still numerically equivalent and we may need this functionality for other pretraining frameworks/models; it's still early days.
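As a rough illustration of what a configurable scaling type on the MoE conversion API could look like, here is a minimal sketch; the enum, config class, and field names below are hypothetical placeholders for illustration, not torchao's actual API:

```python
from dataclasses import dataclass
from enum import Enum

# Hypothetical names for illustration only; the real torchao config may differ.
class ScalingType(Enum):
    FP8_ROWWISE = "fp8_rowwise"
    MXFP8 = "mxfp8"

@dataclass
class MoETrainingConversionConfig:
    # Which low-precision scaling recipe the converted MoE layers should use.
    scaling_type: ScalingType = ScalingType.FP8_ROWWISE

# The caller picks the scaling type when converting the model for MoE training:
config = MoETrainingConversionConfig(scaling_type=ScalingType.MXFP8)
print(config.scaling_type)  # ScalingType.MXFP8
```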

Test plan

  • Added an integration test using torchtitan with the changes that make the token group alignment size configurable

stack-info: PR: #2642, branch: danielvegamyhre/stack/26

pytorch-bot bot commented Jul 30, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2642

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit bb05933 with merge base 1f0d2bb:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

danielvegamyhre added a commit that referenced this pull request Jul 30, 2025
stack-info: PR: #2642, branch: danielvegamyhre/stack/26
danielvegamyhre force-pushed the danielvegamyhre/stack/26 branch from 507b6cc to 4fbf578 on July 30, 2025 23:19
meta-cla bot added the "CLA Signed" label Jul 30, 2025
danielvegamyhre added the "topic: not user facing" label Jul 30, 2025
danielvegamyhre changed the base branch from danielvegamyhre/stack/25 to main July 31, 2025 17:24
danielvegamyhre added a commit that referenced this pull request Jul 31, 2025
stack-info: PR: #2642, branch: danielvegamyhre/stack/26
danielvegamyhre force-pushed the danielvegamyhre/stack/26 branch from 4fbf578 to 1434e9b on July 31, 2025 17:24
danielvegamyhre changed the base branch from main to danielvegamyhre/stack/25 July 31, 2025 17:24
danielvegamyhre changed the base branch from danielvegamyhre/stack/25 to main July 31, 2025 17:33
danielvegamyhre force-pushed the danielvegamyhre/stack/26 branch from 1434e9b to a5403ac on July 31, 2025 17:33
danielvegamyhre added a commit that referenced this pull request Jul 31, 2025
stack-info: PR: #2642, branch: danielvegamyhre/stack/26
danielvegamyhre changed the base branch from main to danielvegamyhre/stack/25 July 31, 2025 17:33
danielvegamyhre changed the base branch from danielvegamyhre/stack/25 to main July 31, 2025 22:31
danielvegamyhre added a commit that referenced this pull request Jul 31, 2025
stack-info: PR: #2642, branch: danielvegamyhre/stack/26
danielvegamyhre force-pushed the danielvegamyhre/stack/26 branch from a5403ac to a828d09 on July 31, 2025 22:31
danielvegamyhre changed the base branch from main to danielvegamyhre/stack/25 July 31, 2025 22:31
danielvegamyhre added a commit to pytorch/torchtitan that referenced this pull request Aug 1, 2025
## Summary
- For mxfp8, token group sizes must be multiples of "block_size": in the backward pass for `grad_weight = grad_output_t @ input`, the "M" (token) dimension is the contracting dimension, and each token group is a logically distinct subtensor, so we scale each group separately. This means each token group's contracting dimension must be divisible by the mxfp8 block_size (default 32). Here is a diagram showing the problem: https://www.internalfb.com/excalidraw/EX521879
- To solve this, this PR makes the token group M alignment configurable (see the padding sketch after this commit message).

## Test plan
- Integration test with torchao passes: pytorch/ao#2642
- Did a manual test run with the llama4 debug model using bf16
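To make the alignment requirement above concrete, here is a minimal sketch (a hypothetical helper, not torchtitan's actual implementation) that rounds each expert's token group size up to a multiple of the mxfp8 block size, so every group's contracting "M" dimension is divisible by block_size:

```python
import torch

def pad_token_group_sizes(group_sizes: torch.Tensor, alignment: int = 32) -> torch.Tensor:
    # Round each token group size up to the next multiple of `alignment`,
    # e.g. [40, 17, 63] with alignment=32 -> [64, 32, 64].
    return ((group_sizes + alignment - 1) // alignment) * alignment

group_sizes = torch.tensor([40, 17, 63])
print(pad_token_group_sizes(group_sizes))  # tensor([64, 32, 64])
```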
danielvegamyhre added a commit to pytorch/torchtitan that referenced this pull request Aug 1, 2025
danielvegamyhre added a commit to pytorch/torchtitan that referenced this pull request Aug 1, 2025
danielvegamyhre added a commit to pytorch/torchtitan that referenced this pull request Aug 1, 2025
danielvegamyhre added a commit to pytorch/torchtitan that referenced this pull request Aug 1, 2025
vkuzo (Contributor) left a comment

lg for prototype, we might need to change this later

danielvegamyhre changed the base branch from danielvegamyhre/stack/25 to main August 1, 2025 18:37
danielvegamyhre added a commit that referenced this pull request Aug 1, 2025
stack-info: PR: #2642, branch: danielvegamyhre/stack/26
danielvegamyhre force-pushed the danielvegamyhre/stack/26 branch from a828d09 to 82e707e on August 1, 2025 18:37
danielvegamyhre changed the base branch from main to danielvegamyhre/stack/25 August 1, 2025 18:37
danielvegamyhre force-pushed the danielvegamyhre/stack/25 branch from e9ba18b to 2aabb15 on August 1, 2025 20:18
danielvegamyhre added a commit that referenced this pull request Aug 1, 2025
stack-info: PR: #2642, branch: danielvegamyhre/stack/26
danielvegamyhre force-pushed the danielvegamyhre/stack/26 branch 2 times, most recently from 1b362ee to bb05933 on August 1, 2025 20:18
danielvegamyhre changed the base branch from danielvegamyhre/stack/25 to main August 1, 2025 20:18
Labels: CLA Signed, topic: not user facing
Projects: None yet
2 participants