Make token group alignment size configurable #1503

Merged
merged 4 commits into main on Aug 1, 2025
Conversation

@danielvegamyhre (Contributor) commented on Jul 31, 2025

Summary

  • For mxfp8, token group sizes must be multiples of the scaling block_size because in the backward pass for `grad_weight = grad_output_t @ input`, the "M" (token) dimension is the contracting dimension, and each token group is a logically distinct subtensor that is scaled separately. This means each token group's extent along the contracting dimension must be divisible by the mxfp8 block_size (default 32). Here is a diagram showing the problem: https://www.internalfb.com/excalidraw/EX521879
  • To solve this, this PR makes the token group M alignment size configurable (see the sketch after this list).
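A toy sketch of the divisibility constraint above, using made-up helper names rather than code from this PR:

```python
# Toy illustration (hypothetical names, not torchtitan code): in
# grad_weight = grad_output_t @ input, the GEMM shapes are (N, M) @ (M, K).
# M is split into per-expert token groups that are scaled independently with
# (1 x 32) mxfp8 blocks, so each group's size along M must be a multiple of 32.
MXFP8_BLOCK_SIZE = 32

def groups_divisible(group_sizes: list[int], block_size: int = MXFP8_BLOCK_SIZE) -> bool:
    # True only if every token group's M extent is a multiple of block_size.
    return all(size % block_size == 0 for size in group_sizes)

assert groups_divisible([64, 96, 32])      # fine for mxfp8 per-group scaling
assert not groups_divisible([40, 17, 96])  # these group sizes would break it
```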

Test plan

  • Integration test with torchao passes: pytorch/ao#2642
  • Did a manual test run with the llama4 debug model using bf16

meta-cla bot added the CLA Signed label (managed by the Meta Open Source bot) on Jul 31, 2025
@@ -59,6 +59,13 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
and job_config.parallelism.tensor_parallel_degree > 1
), "TP not yet supported with torch.compile for mxfp8"

# For MoE training with mxfp8, token group sizes must be multiples of 32
if job_config.mx.moe_fqns_prototype:
    from torchtitan.experiments.llama4.infra.expert_parallel import set_token_group_alignment_size
Contributor
Looks OK to me.
I will do some refactoring to move expert_parallel into torchtitan/distributed; then the import will look nicer.

if job_config.mx.moe_fqns_prototype:
    from torchtitan.experiments.llama4.infra.expert_parallel import set_token_group_alignment_size
    mxfp8_block_size = 32
    set_token_group_alignment_size(mxfp8_block_size)
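For context, a minimal sketch of how a configurable alignment like this is typically consumed when forming token groups (hypothetical helper, not the torchtitan implementation): each group's token count is rounded up to the alignment so the grouped GEMM sees divisible group sizes.

```python
# Hypothetical padding helper (not the torchtitan implementation): round each
# token group's size up to the next multiple of the configured alignment so
# that per-group mxfp8 scaling sees an M extent divisible by the block size.
def pad_group_sizes(group_sizes: list[int], align_size_m: int) -> list[int]:
    return [(size + align_size_m - 1) // align_size_m * align_size_m for size in group_sizes]

print(pad_group_sizes([40, 17, 96], align_size_m=32))  # -> [64, 32, 96]
```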
Contributor

Don't we need to do this for Float8 as well? IIRC it supports grouped GEMM too.

Contributor Author

Yes, but the default (16) is what is needed for float8, so there's no need to set it manually.

Contributor
I'm not sure we should use 16 as the default.
For bf16, is 16 needed, or is 8 enough?
I think we should still set it explicitly, in case the default changes later.

@danielvegamyhre (Contributor, Author) commented on Jul 31, 2025
Actually yeah I think you're right.

  • For bf16, 8 is enough (16-byte alignment / 2 bytes per element = 8 elements).
  • For fp8, 16-byte alignment / 1 byte per element = 16 elements.
  • For mxfp8, we need 32 (i.e., block_size) because the scaling block size is (1 x 32), so when doing per-token-group quantization on each logically distinct subtensor, we need to ensure the contracting dim is divisible by block_size. In the backward pass, grad_weight = (grad_output_t @ input).t() has GEMM dims (N, M) @ (M, K), so M is the contracting dim, and group offsets are along M, so we need 32-element alignment.

Updated this accordingly.
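A minimal sketch of the reasoning in this comment (hypothetical helper, not the torchtitan API): alignment is 16-byte hardware alignment divided by element size, except for mxfp8, which needs the full (1 x 32) scaling block along M.

```python
# Hypothetical helper capturing the alignment rules above (not torchtitan API).
def min_token_group_alignment(recipe: str) -> int:
    if recipe == "mxfp8":
        return 32  # (1 x 32) scaling block along the contracting (M) dimension
    bytes_per_elem = {"bf16": 2, "fp8": 1}[recipe]
    return 16 // bytes_per_elem  # 16-byte alignment: bf16 -> 8, fp8 -> 16

assert min_token_group_alignment("bf16") == 8
assert min_token_group_alignment("fp8") == 16
assert min_token_group_alignment("mxfp8") == 32
```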

@danielvegamyhre (Contributor, Author)
@tianyu-l I addressed your comments and did a test run with the llama4 debug model with bf16 to make sure it runs correctly with the new default. However, I keep getting linter errors despite pre-commit passing locally. I uninstalled requirements-dev.txt and re-installed, ran pre-commit again, but it still says no errors locally and fails in CI. Any thoughts on how to proceed?

tianyu-l (Contributor) left a comment

Please address the comments before merging.

@@ -24,6 +24,29 @@
from torch.distributed.tensor.placement_types import Placement


TOKEN_GROUP_ALIGN_SIZE_M = 8
Contributor
OK for now. Later we may want to make this a private field and provide a getter function too.
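A short sketch of that suggested follow-up (assumed names, not what this PR ships): keep the alignment in a private module-level variable and expose it only through accessors.

```python
# Sketch of the suggested refactor (assumed names): private module-level field
# plus accessor functions, instead of a public module-level constant.
_TOKEN_GROUP_ALIGN_SIZE_M: int = 8  # bf16-safe default

def set_token_group_alignment_size(size: int) -> None:
    global _TOKEN_GROUP_ALIGN_SIZE_M
    assert size > 0, "alignment size must be positive"
    _TOKEN_GROUP_ALIGN_SIZE_M = size

def get_token_group_alignment_size() -> int:
    return _TOKEN_GROUP_ALIGN_SIZE_M
```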

tianyu-l (Contributor)
> I addressed your comments, and did a test run with llama4 debug model with bf16 to make sure it runs correctly with the new default. However, I keep getting linter errors, despite pre-commit passing locally. I uninstalled requirements-dev.txt and re-installed, ran pre-commit again, but it still says no errors locally and fails in CI. Any thoughts on how to proceed?

The error says it's extra whitespace in expert_parallel.py, is it not legit?

@danielvegamyhre (Contributor, Author)
> The error says it's extra whitespace in expert_parallel.py, is it not legit?

Maybe it's legit, but my confusion is how the local linter and the CI linter got out of sync somehow, even after re-installation.

@danielvegamyhre merged commit d655e16 into main on Aug 1, 2025. 8 checks passed.
danielvegamyhre added a commit that referenced this pull request Aug 1, 2025
bentherien pushed a commit to bentherien/torchtitan_ that referenced this pull request Aug 5, 2025
Labels: CLA Signed (This label is managed by the Meta Open Source bot.)
3 participants