
Commit b55fbc7

address comments
1 parent 8f27e56 commit b55fbc7

File tree: 3 files changed, +21 −4 lines


torchtitan/components/quantization/float8.py

Lines changed: 5 additions & 0 deletions
@@ -10,6 +10,7 @@

 from torchtitan.config.job_config import Float8, JobConfig
 from torchtitan.distributed import ParallelDims
+from torchtitan.experiments.llama4.infra.expert_parallel import set_token_group_alignment_size_m
 from torchtitan.protocols.model_converter import (
     ModelConverter,
     register_model_converter,
@@ -66,6 +67,10 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
             job_config.parallelism.context_parallel_degree == 1
         ), "Float8 MoE training prototype does not yet support context parallelism"

+        # For fp8 grouped GEMM, token group sizes must be multiples of 16
+        # (16 byte alignment / 1 byte per elem = 16 elements)
+        set_token_group_alignment_size_m(16)
+
         if float8_config.recipe_name is not None:
             assert not float8_config.enable_fsdp_float8_all_gather, (
                 "using `float8_config.enable_fsdp_float8_all_gather` together "

torchtitan/config/job_config.py

Lines changed: 1 addition & 1 deletion
@@ -577,7 +577,7 @@ class MX:
     """
     Comma-separated list of fully qualified names of MoE modules to apply mxfp8 training to.
     This is a prototype feature that requires the torchao nightly build.
-    Example: --float8.moe_fqns_prototype="experts"
+    Example: --mx.moe_fqns_prototype="experts"
     """

 @dataclass
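
As a usage illustration only (not code from this commit), the corrected flag takes a comma-separated string of module FQNs; a consumer of the config could split it into a list with a hypothetical helper such as:

# Hypothetical sketch: turn the comma-separated value passed via
# --mx.moe_fqns_prototype into a list of module FQNs to convert.
def parse_moe_fqns(moe_fqns_prototype: str) -> list[str]:
    return [fqn.strip() for fqn in moe_fqns_prototype.split(",") if fqn.strip()]

print(parse_moe_fqns("experts"))                # ['experts']
print(parse_moe_fqns("experts,shared_expert"))  # ['experts', 'shared_expert']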

torchtitan/experiments/llama4/infra/expert_parallel.py

Lines changed: 15 additions & 3 deletions
@@ -24,14 +24,26 @@
 from torch.distributed.tensor.placement_types import Placement


-TOKEN_GROUP_ALIGN_SIZE_M = 16
+TOKEN_GROUP_ALIGN_SIZE_M = 8


 def set_token_group_alignment_size_m(m: int) -> None:
-    """Set the alignment size for token groups in MoE."""
+    """
+    Set the token group alignment size for token groups in MoE. This is implemented by
+    padding each token group size to the next multiple of TOKEN_GROUP_ALIGN_SIZE_M.
+    Different values are needed for different cases:
+
+    * For bf16, 8 is enough (16 byte alignment / 2 bytes per elem = 8 elements).
+    * For fp8, 16 byte alignment / 1 byte per elem = 16 elements.
+    * For mxfp8, we need 32 (or block_size) because scaling block size is (1 x 32),
+      so when doing per-token-group quantization on each logically distinct subtensor,
+      we need to ensure the contracting dim is divisible by block_size.
+      In the backward pass, grad_weight = (grad_output_t @ input).t() has gemm dims
+      of (N, M) @ (M, K) so M is the contracting dim, and group offsets are along M,
+      so we need 32 element alignment.
+    """
     global TOKEN_GROUP_ALIGN_SIZE_M
     assert m > 0, "Alignment size must be positive"
-    assert m % 16 == 0, "Alignment size must always be a multiple of 16 due to hardware constraints"
     TOKEN_GROUP_ALIGN_SIZE_M = m

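
To make the padding behavior the new docstring describes concrete, here is a minimal sketch, not code from this file, of rounding each token-group size up to the next multiple of the alignment value (the helper name pad_group_sizes is ours):

import torch

def pad_group_sizes(group_sizes: torch.Tensor, align_m: int) -> torch.Tensor:
    # Round each token-group size up to the next multiple of align_m.
    return ((group_sizes + align_m - 1) // align_m) * align_m

sizes = torch.tensor([5, 13, 24])
print(pad_group_sizes(sizes, 8))   # bf16:  tensor([ 8, 16, 24])
print(pad_group_sizes(sizes, 16))  # fp8:   tensor([16, 16, 32])
print(pad_group_sizes(sizes, 32))  # mxfp8: tensor([32, 32, 32])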
