Make token group alignment size configurable #1503
Conversation
@@ -59,6 +59,13 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
            and job_config.parallelism.tensor_parallel_degree > 1
        ), "TP not yet supported with torch.compile for mxfp8"

        # For MoE training with mxfp8, token group sizes must be multiples of 32
        if job_config.mx.moe_fqns_prototype:
            from torchtitan.experiments.llama4.infra.expert_parallel import set_token_group_alignment_size
Looks OK to me. I will do some refactoring to move `expert_parallel` into `torchtitan/distributed`. Then the import will look nicer.
if job_config.mx.moe_fqns_prototype:
    from torchtitan.experiments.llama4.infra.expert_parallel import set_token_group_alignment_size
    mxfp8_block_size = 32
    set_token_group_alignment_size(mxfp8_block_size)
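For readers of this thread, a minimal sketch of what such a setter might boil down to, assuming it simply validates and updates the module-level `TOKEN_GROUP_ALIGN_SIZE_M` constant shown in the second diff below (the actual torchtitan implementation may differ):

```python
# Hypothetical sketch, not the actual torchtitan code: the setter just
# validates and overrides the module-level alignment constant that the
# token-group padding logic reads.
TOKEN_GROUP_ALIGN_SIZE_M = 8  # default; sufficient for bf16


def set_token_group_alignment_size(alignment_size: int) -> None:
    """Require per-expert token group sizes (dim M) to be multiples of alignment_size."""
    global TOKEN_GROUP_ALIGN_SIZE_M
    if alignment_size <= 0:
        raise ValueError(f"alignment size must be positive, got {alignment_size}")
    TOKEN_GROUP_ALIGN_SIZE_M = alignment_size
```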
Don't we need to do this for Float8 as well? IIRC it supports grouped GEMM too.
Yes, but the default (16) is what's needed for float8, so there's no need to set it manually.
I'm not sure we should use 16 as the default. For bf16, is 16 needed, or is 8 enough?
I think we should still set it explicitly, in case the default changes later.
Actually, yeah, I think you're right.
- For bf16, 8 is enough (16-byte alignment / 2 bytes per elem = 8 elements).
- For fp8, 16-byte alignment / 1 byte per elem = 16 elements.
- For mxfp8, we need 32 (or `block_size`), because the scaling block size is (1 x 32), so when doing per-token-group quantization on each logically distinct subtensor, we need to ensure the contracting dim is divisible by block_size. In the backward pass, `grad_weight = (grad_output_t @ input).t()` has GEMM dims (N, M) @ (M, K), so M is the contracting dim, and group offsets are along M, so we need 32-element alignment.

Updated this accordingly.
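To make the arithmetic above concrete, here is an illustrative sketch (hypothetical helper names, not torchtitan APIs and not part of this PR) of how the per-dtype alignment and the group padding fall out:

```python
import torch

# Illustrative only: hypothetical helpers, not part of this PR or torchtitan.

def alignment_for_dtype(dtype: torch.dtype, mxfp8: bool = False, block_size: int = 32) -> int:
    """Minimum multiple for per-expert token group sizes along the M dim."""
    if mxfp8:
        # mxfp8 scales in (1 x block_size) blocks along the contracting dim,
        # so each token group must be a multiple of block_size (default 32).
        return block_size
    # Otherwise: 16-byte alignment divided by element size.
    # bf16 -> 16 / 2 = 8 elements, fp8 -> 16 / 1 = 16 elements.
    return 16 // (torch.finfo(dtype).bits // 8)


def pad_group_size(num_tokens: int, align_m: int) -> int:
    """Round a token group size up to the next multiple of align_m."""
    return ((num_tokens + align_m - 1) // align_m) * align_m


assert alignment_for_dtype(torch.bfloat16) == 8
assert alignment_for_dtype(torch.float8_e4m3fn) == 16
assert pad_group_size(70, 32) == 96  # mxfp8: a 70-token group pads up to 96
```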
@tianyu-l I addressed your comments and did a test run with the llama4 debug model in bf16 to make sure it runs correctly with the new default. However, I keep getting linter errors in CI despite pre-commit passing locally. I uninstalled the requirements-dev.txt packages, re-installed, and ran pre-commit again, but it still reports no errors locally and fails in CI. Any thoughts on how to proceed?
Please address the comments before merging.
@@ -24,6 +24,29 @@
from torch.distributed.tensor.placement_types import Placement


TOKEN_GROUP_ALIGN_SIZE_M = 8
OK for now. Later we may want to make this a private field and provide a getter function too.
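A sketch of what that follow-up could look like (hypothetical, not part of this PR): keep the alignment in a private module-level field and expose it through accessors, so callers stop reading the global directly:

```python
# Hypothetical refactor sketch, not part of this PR.
_TOKEN_GROUP_ALIGN_SIZE_M = 8  # private; default works for bf16


def set_token_group_alignment_size(alignment_size: int) -> None:
    global _TOKEN_GROUP_ALIGN_SIZE_M
    if alignment_size <= 0:
        raise ValueError(f"alignment size must be positive, got {alignment_size}")
    _TOKEN_GROUP_ALIGN_SIZE_M = alignment_size


def get_token_group_alignment_size() -> int:
    return _TOKEN_GROUP_ALIGN_SIZE_M
```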
The error says it's extra whitespace in expert_parallel.py; is it not legit?
Maybe it's legit, but my confusion is how the local linter and the CI linter are out of sync, even after re-installation.
## Summary

- For mxfp8, token group sizes must be multiples of "block_size": in the backward pass for `grad_weight = grad_output_t @ input`, the "M" (token) dimension is the contracting dimension, and each token group is a logically distinct subtensor that is scaled separately, so each token group's contracting dimension must be divisible by the mxfp8 block_size (default 32). Here is a diagram showing the problem: https://www.internalfb.com/excalidraw/EX521879
- To solve this, this PR makes the token group M alignment size configurable.

## Test plan

- Integration test with torchao passes: pytorch/ao#2642
- Did a manual test run with the llama4 debug model using bf16.