[moe training] set token group alignment size to 16 for fp8 training test #2678
In pytorch/torchtitan#1503 the default TOKEN_GROUP_ALIGNMENT_SIZE_M was changed from 16 (required for fp8) to 8 (the minimum for bf16). See that PR's description for details.
Thus, in our fp8 training tests, we need to set it to 16. This is required so that each logically distinct gemm in the grouped gemm
grad_weight = grad_output_t @ input
has a contraction dim divisible by 16. 16-byte alignment is required for the dim with stride 1, and at 1 byte per fp8 element that works out to 16 bytes / 1 byte per element = 16 elements.
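As a rough illustration of that arithmetic (not the torchao implementation), the sketch below pads hypothetical per-expert token counts so that each group's contraction dim is a multiple of 16 for fp8:

```python
# Hedged illustration only: why per-expert token group sizes must be padded
# to a multiple of 16 for fp8. fp8 elements are 1 byte each and the stride-1
# (contraction) dim needs 16-byte alignment, so 16 bytes / 1 byte = 16 elements.

FP8_ALIGNMENT_M = 16  # elements; 8 would suffice for bf16 (2 bytes per element)


def round_up(n: int, multiple: int) -> int:
    """Round n up to the nearest multiple."""
    return ((n + multiple - 1) // multiple) * multiple


# Hypothetical per-expert token counts produced by routing.
token_counts = [5, 23, 40, 12]

# Pad each logical group so its size (the contraction dim of
# grad_weight = grad_output_t @ input) is divisible by 16.
padded_counts = [round_up(c, FP8_ALIGNMENT_M) for c in token_counts]
assert all(c % FP8_ALIGNMENT_M == 0 for c in padded_counts)
print(padded_counts)  # [16, 32, 48, 16]
```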
Test plan
Test:
pytest test/prototype/moe_training/test_training.py
Without the change, the test fails; with the change, the tests pass.
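For reference, a minimal sketch of what the test change amounts to, assuming a hypothetical setter named set_token_group_alignment_size and an assumed import path (the actual torchao helper may differ):

```python
# Minimal sketch only; the setter name and import path are assumptions,
# not necessarily the actual torchao API used in the test.
import pytest


@pytest.fixture(autouse=True)
def use_fp8_token_group_alignment():
    # fp8 needs 16-element alignment on the contraction dim; the bf16
    # default of 8 is not enough for the fp8 grouped gemm backward.
    from torchao.prototype import moe_training  # assumed import path
    moe_training.set_token_group_alignment_size(16)  # hypothetical helper
    yield
```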