
[moe training] set token group alignment size to 16 for fp8 training test #2678


Merged
3 commits merged into main on Aug 5, 2025

Conversation

danielvegamyhre (Contributor)

@danielvegamyhre commented Aug 4, 2025

In pytorch/torchtitan#1503, the default TOKEN_GROUP_ALIGNMENT_SIZE_M was changed from 16 (required for fp8) to 8 (the minimum for bf16). See that PR's description for details.

Thus, in our fp8 training tests we need to set it back to 16. This is required so that each logically distinct gemm in the grouped gemm `grad_weight = grad_output_t @ input` has a contraction dim divisible by 16: 16-byte alignment is required for the fastest-moving dim (stride 1), and fp8 is 1 byte per element, so 16 bytes / 1 byte per element = 16 elements.
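
For intuition, here is a minimal sketch of the alignment arithmetic (illustrative Python only, not the torchao API):

```python
# Minimal sketch of the alignment arithmetic described above (illustrative,
# not the torchao API). The grouped gemm requires each token group's byte
# size along the stride-1 dim to be a multiple of 16 bytes.
REQUIRED_BYTE_ALIGNMENT = 16

def required_alignment_m(bytes_per_element: int) -> int:
    """Token group alignment (in elements) implied by 16-byte alignment."""
    return REQUIRED_BYTE_ALIGNMENT // bytes_per_element

print(required_alignment_m(1))  # fp8:  1 byte/elem  -> 16 elements
print(required_alignment_m(2))  # bf16: 2 bytes/elem -> 8 elements

def pad_to_alignment(num_tokens: int, alignment_m: int) -> int:
    """Round a token group size up to the next multiple of alignment_m."""
    return -(-num_tokens // alignment_m) * alignment_m

print(pad_to_alignment(100, 16))  # -> 112
```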

Test plan

Test: pytest test/prototype/moe_training/test_training.py

Error without change:

E       torch.AcceleratorError: CUDA error: device-side assert triggered
E       Search for `cudaErrorAssert' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
E       Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

torchao/prototype/moe_training/scaled_grouped_mm.py:259: AcceleratorError
---------------------------------------------------------------------------------- Captured stderr call ----------------------------------------------------------------------------------
/pytorch/aten/src/ATen/native/cuda/GroupMMCommon.cuh:64: prepare_grouped_gemm_data: block: [0,0,0], thread: [0,0,0] Assertion `delta % align == 0 && "expected input tensor dynamic dimension byte size to be non-negative multiple of 16\n"` failed.
/pytorch/aten/src/ATen/native/cuda/GroupMMCommon.cuh:64: prepare_grouped_gemm_data: block: [0,0,0], thread: [1,0,0] Assertion `delta % align == 0 && "expected input tensor dynamic dimension byte size to be non-negative multiple of 16\n"` failed.
/pytorch/aten/src/ATen/native/cuda/GroupMMCommon.cuh:64: prepare_grouped_gemm_data: block: [0,0,0], thread: [2,0,0] Assertion `delta % align == 0 && "expected input tensor dynamic dimension byte size to be non-negative multiple of 16\n"` failed.
/pytorch/aten/src/ATen/native/cuda/GroupMMCommon.cuh:64: prepare_grouped_gemm_data: block: [0,0,0], thread: [3,0,0] Assertion `delta % align == 0 && "expected input tensor dynamic dimension byte size to be non-negative multiple of 16\n"` failed.
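
The assertion amounts to a per-group check along these lines (a Python restatement of the CUDA-side check, assuming `offsets` holds cumulative token counts per expert group):

```python
# Python restatement of the CUDA-side check that fires above (illustrative;
# the real check lives in prepare_grouped_gemm_data in GroupMMCommon.cuh).
# `offsets` are assumed to be cumulative token counts per expert group.
def check_group_alignment(offsets, bytes_per_element, align=16):
    prev = 0
    for off in offsets:
        delta = (off - prev) * bytes_per_element  # group byte size on stride-1 dim
        assert delta >= 0 and delta % align == 0, (
            "expected input tensor dynamic dimension byte size to be "
            "non-negative multiple of 16"
        )
        prev = off

# fp8 (1 byte/elem): a 40-token group is 40 bytes, not a multiple of 16.
check_group_alignment([16, 56], bytes_per_element=1)  # raises AssertionError
```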

With this change, the tests pass.
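
For reference, the shape of the fix in sketch form (hedged; `set_token_group_alignment_size_m` is a hypothetical name for illustration, not necessarily torchao's actual API):

```python
# Hedged sketch of the mechanism (illustrative only, not torchao's actual API).
# A module-level value controls how token groups are padded; fp8 tests must
# raise it from the bf16 minimum (8) to 16 before building group offsets.
TOKEN_GROUP_ALIGNMENT_SIZE_M = 8  # default after pytorch/torchtitan#1503

def set_token_group_alignment_size_m(m: int) -> None:
    """Hypothetical setter mirroring the pattern this PR applies in the tests."""
    global TOKEN_GROUP_ALIGNMENT_SIZE_M
    TOKEN_GROUP_ALIGNMENT_SIZE_M = m

# fp8 tests: 16 bytes / 1 byte per fp8 element = 16 elements
set_token_group_alignment_size_m(16)
```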


pytorch-bot (bot) commented Aug 4, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2678

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 New Failures, 1 Cancelled Job

As of commit 3a87a56 with merge base 7dbc816:

NEW FAILURES - The following jobs have failed:

CANCELLED JOB - The following job was cancelled. Please retry:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla (bot) added the CLA Signed label Aug 4, 2025
@danielvegamyhre added the topic: not user facing label and removed the CLA Signed label Aug 4, 2025
@meta-cla (bot) added the CLA Signed label Aug 4, 2025
@danielvegamyhre merged commit be40518 into main Aug 5, 2025
16 of 20 checks passed
Labels: CLA Signed, topic: not user facing