From 8f27e56edafba025c14727e4a9301ccd8807e25f Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Thu, 31 Jul 2025 10:25:28 -0700 Subject: [PATCH 1/4] Make token group alignment size configurable --- torchtitan/components/quantization/mx.py | 7 +++++++ torchtitan/config/job_config.py | 8 +++++++- .../experiments/llama4/infra/expert_parallel.py | 17 ++++++++++++++--- 3 files changed, 28 insertions(+), 4 deletions(-) diff --git a/torchtitan/components/quantization/mx.py b/torchtitan/components/quantization/mx.py index f2c6820a7..8c5b7b4fe 100644 --- a/torchtitan/components/quantization/mx.py +++ b/torchtitan/components/quantization/mx.py @@ -59,6 +59,13 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): and job_config.parallelism.tensor_parallel_degree > 1 ), "TP not yet supported with torch.compile for mxfp8" + # For MoE training with mxfp8, token group sizes must be multiples of 32 + if job_config.mx.moe_fqns_prototype: + from torchtitan.experiments.llama4.infra.expert_parallel import set_token_group_alignment_size + mxfp8_block_size = 32 + set_token_group_alignment_size(mxfp8_block_size) + logger.info(f"Setting token group alignment size to {mxfp8_block_size}") + # Configure MXFP8 from torchao.prototype.mx_formats.config import ( MXFP8Dim1CastKernelChoice, diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index 5255de3da..1052494d6 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -567,12 +567,18 @@ class MX: filter_fqns: list[str] = field(default_factory=lambda: ["output"]) """ - Comma-separated list of fully qualified names of modules to skip applying mxfloat8 training to. + Comma-separated list of fully qualified names of modules to skip applying mxfp8 training to. nn.Linear modules with any dim size not divisible by 16 are also always skipped due to hardware requirements. By default we always skip the output layer. Example: --mx.filter_fqns "attention.wq,attention.wk,attention.wv,output" """ + moe_fqns_prototype: list[str] | str = field(default_factory=list) + """ + Comma-separated list of fully qualified names of MoE modules to apply mxfp8 training to. + This is a prototype feature that requires the torchao nightly build. + Example: --float8.moe_fqns_prototype="experts" + """ @dataclass class Comm: diff --git a/torchtitan/experiments/llama4/infra/expert_parallel.py b/torchtitan/experiments/llama4/infra/expert_parallel.py index 0e8aef8ee..88daff4dd 100644 --- a/torchtitan/experiments/llama4/infra/expert_parallel.py +++ b/torchtitan/experiments/llama4/infra/expert_parallel.py @@ -24,6 +24,17 @@ from torch.distributed.tensor.placement_types import Placement +TOKEN_GROUP_ALIGN_SIZE_M = 16 + + +def set_token_group_alignment_size_m(m: int) -> None: + """Set the alignment size for token groups in MoE.""" + global TOKEN_GROUP_ALIGN_SIZE_M + assert m > 0, "Alignment size must be positive" + assert m % 16 == 0, "Alignment size must always be a multiple of 16 due to hardware constraints" + TOKEN_GROUP_ALIGN_SIZE_M = m + + # implementation of Tensor Parallel for the GroupedExperts in MoE class TensorParallel(ParallelStyle): def _partition_fn(self, name, module, device_mesh): @@ -251,6 +262,7 @@ def wrapper( x: torch.Tensor, num_tokens_per_expert: torch.Tensor | None = None, ) -> torch.Tensor: + global TOKEN_GROUP_ALIGN_SIZE_M if isinstance(w1, DTensor): w1 = w1.to_local() w2 = w2.to_local() @@ -264,7 +276,6 @@ def wrapper( experts_per_ep_rank = w1.shape[0] num_ep_ranks = num_tokens_per_expert.shape[0] // experts_per_ep_rank - ALIGN_SIZE_M = 16 with torch.no_grad(): ( permuted_indices, @@ -274,8 +285,8 @@ def wrapper( num_tokens_per_expert, experts_per_ep_rank, num_ep_ranks, - x.shape[0] + experts_per_ep_rank * ALIGN_SIZE_M, - ALIGN_SIZE_M, + x.shape[0] + experts_per_ep_rank * TOKEN_GROUP_ALIGN_SIZE_M, + TOKEN_GROUP_ALIGN_SIZE_M, ) x = torch.vstack((x, x.new_zeros((x.shape[-1])))) From b55fbc751630ee54878f00fc4dc15a24370e8d5c Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Thu, 31 Jul 2025 13:53:32 -0700 Subject: [PATCH 2/4] address comments --- torchtitan/components/quantization/float8.py | 5 +++++ torchtitan/config/job_config.py | 2 +- .../llama4/infra/expert_parallel.py | 18 +++++++++++++++--- 3 files changed, 21 insertions(+), 4 deletions(-) diff --git a/torchtitan/components/quantization/float8.py b/torchtitan/components/quantization/float8.py index 863ea266f..42e09996a 100644 --- a/torchtitan/components/quantization/float8.py +++ b/torchtitan/components/quantization/float8.py @@ -10,6 +10,7 @@ from torchtitan.config.job_config import Float8, JobConfig from torchtitan.distributed import ParallelDims +from torchtitan.experiments.llama4.infra.expert_parallel import set_token_group_alignment_size_m from torchtitan.protocols.model_converter import ( ModelConverter, register_model_converter, @@ -66,6 +67,10 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): job_config.parallelism.context_parallel_degree == 1 ), "Float8 MoE training prototype does not yet support context parallelism" + # For fp8 grouped GEMM, token group sizes must be multiples of 16 + # (16 byte alignment / 1 byte per elem = 16 elements) + set_token_group_alignment_size_m(16) + if float8_config.recipe_name is not None: assert not float8_config.enable_fsdp_float8_all_gather, ( "using `float8_config.enable_fsdp_float8_all_gather` together " diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index 1052494d6..833a8a00f 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -577,7 +577,7 @@ class MX: """ Comma-separated list of fully qualified names of MoE modules to apply mxfp8 training to. This is a prototype feature that requires the torchao nightly build. - Example: --float8.moe_fqns_prototype="experts" + Example: --mx.moe_fqns_prototype="experts" """ @dataclass diff --git a/torchtitan/experiments/llama4/infra/expert_parallel.py b/torchtitan/experiments/llama4/infra/expert_parallel.py index 88daff4dd..a57b14565 100644 --- a/torchtitan/experiments/llama4/infra/expert_parallel.py +++ b/torchtitan/experiments/llama4/infra/expert_parallel.py @@ -24,14 +24,26 @@ from torch.distributed.tensor.placement_types import Placement -TOKEN_GROUP_ALIGN_SIZE_M = 16 +TOKEN_GROUP_ALIGN_SIZE_M = 8 def set_token_group_alignment_size_m(m: int) -> None: - """Set the alignment size for token groups in MoE.""" + """ + Set the token group alignment size for token groups in MoE. This is implemented by + padding each token group size to the next multiple of TOKEN_GROUP_ALIGN_SIZE_M. + Different values are needed for different cases: + + * For bf16, 8 is enough (16 byte alignment / 2 bytes per elem = 8 elements). + * For fp8, 16 byte alignment / 1 byte per elem = 16 elements. + * For mxfp8, we need 32 (or block_size) because scaling block size is (1 x 32), + so when doing per-token-group quantization on each logically distinct subtensor, + we need to ensure the contracting dim is divisible by block_size. + In the backward pass, grad_weight = (grad_output_t @ input).t() has gemm dims + of (N, M) @ (M, K) so M is the contracting dim, and group offsets are along M, + so we need 32 element alignment. + """ global TOKEN_GROUP_ALIGN_SIZE_M assert m > 0, "Alignment size must be positive" - assert m % 16 == 0, "Alignment size must always be a multiple of 16 due to hardware constraints" TOKEN_GROUP_ALIGN_SIZE_M = m From 41230f4e413cf1b83954a29d14104cfd82dee091 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Thu, 31 Jul 2025 16:33:35 -0700 Subject: [PATCH 3/4] lint --- torchtitan/components/quantization/float8.py | 4 +++- torchtitan/components/quantization/mx.py | 5 ++++- torchtitan/config/job_config.py | 1 + torchtitan/experiments/llama4/infra/expert_parallel.py | 10 +++++----- 4 files changed, 13 insertions(+), 7 deletions(-) diff --git a/torchtitan/components/quantization/float8.py b/torchtitan/components/quantization/float8.py index 42e09996a..58699b92e 100644 --- a/torchtitan/components/quantization/float8.py +++ b/torchtitan/components/quantization/float8.py @@ -10,7 +10,9 @@ from torchtitan.config.job_config import Float8, JobConfig from torchtitan.distributed import ParallelDims -from torchtitan.experiments.llama4.infra.expert_parallel import set_token_group_alignment_size_m +from torchtitan.experiments.llama4.infra.expert_parallel import ( + set_token_group_alignment_size_m, +) from torchtitan.protocols.model_converter import ( ModelConverter, register_model_converter, diff --git a/torchtitan/components/quantization/mx.py b/torchtitan/components/quantization/mx.py index 8c5b7b4fe..ce4d89ffe 100644 --- a/torchtitan/components/quantization/mx.py +++ b/torchtitan/components/quantization/mx.py @@ -61,7 +61,10 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): # For MoE training with mxfp8, token group sizes must be multiples of 32 if job_config.mx.moe_fqns_prototype: - from torchtitan.experiments.llama4.infra.expert_parallel import set_token_group_alignment_size + from torchtitan.experiments.llama4.infra.expert_parallel import ( + set_token_group_alignment_size, + ) + mxfp8_block_size = 32 set_token_group_alignment_size(mxfp8_block_size) logger.info(f"Setting token group alignment size to {mxfp8_block_size}") diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index 833a8a00f..39e81f7a9 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -580,6 +580,7 @@ class MX: Example: --mx.moe_fqns_prototype="experts" """ + @dataclass class Comm: init_timeout_seconds: int = 300 diff --git a/torchtitan/experiments/llama4/infra/expert_parallel.py b/torchtitan/experiments/llama4/infra/expert_parallel.py index a57b14565..de25b206f 100644 --- a/torchtitan/experiments/llama4/infra/expert_parallel.py +++ b/torchtitan/experiments/llama4/infra/expert_parallel.py @@ -30,16 +30,16 @@ def set_token_group_alignment_size_m(m: int) -> None: """ Set the token group alignment size for token groups in MoE. This is implemented by - padding each token group size to the next multiple of TOKEN_GROUP_ALIGN_SIZE_M. + padding each token group size to the next multiple of TOKEN_GROUP_ALIGN_SIZE_M. Different values are needed for different cases: - + * For bf16, 8 is enough (16 byte alignment / 2 bytes per elem = 8 elements). * For fp8, 16 byte alignment / 1 byte per elem = 16 elements. - * For mxfp8, we need 32 (or block_size) because scaling block size is (1 x 32), + * For mxfp8, we need 32 (or block_size) because scaling block size is (1 x 32), so when doing per-token-group quantization on each logically distinct subtensor, - we need to ensure the contracting dim is divisible by block_size. + we need to ensure the contracting dim is divisible by block_size. In the backward pass, grad_weight = (grad_output_t @ input).t() has gemm dims - of (N, M) @ (M, K) so M is the contracting dim, and group offsets are along M, + of (N, M) @ (M, K) so M is the contracting dim, and group offsets are along M, so we need 32 element alignment. """ global TOKEN_GROUP_ALIGN_SIZE_M From 444c9e9c7788d504046fbf1f4aa6eafa9a161baf Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Thu, 31 Jul 2025 17:34:52 -0700 Subject: [PATCH 4/4] use literal --- .../experiments/llama4/infra/expert_parallel.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/torchtitan/experiments/llama4/infra/expert_parallel.py b/torchtitan/experiments/llama4/infra/expert_parallel.py index de25b206f..f40dbae2b 100644 --- a/torchtitan/experiments/llama4/infra/expert_parallel.py +++ b/torchtitan/experiments/llama4/infra/expert_parallel.py @@ -6,7 +6,7 @@ from functools import partial -from typing import Callable +from typing import Callable, Literal import torch import torch.distributed as dist @@ -25,12 +25,17 @@ TOKEN_GROUP_ALIGN_SIZE_M = 8 +ValidTokenGroupAlignmentSize = Literal[8, 16, 32] -def set_token_group_alignment_size_m(m: int) -> None: +def set_token_group_alignment_size_m( + alignment_size: ValidTokenGroupAlignmentSize, +) -> None: """ Set the token group alignment size for token groups in MoE. This is implemented by padding each token group size to the next multiple of TOKEN_GROUP_ALIGN_SIZE_M. + + Valid values are: 8, 16, or 32. Different values are needed for different cases: * For bf16, 8 is enough (16 byte alignment / 2 bytes per elem = 8 elements). @@ -43,8 +48,7 @@ def set_token_group_alignment_size_m(m: int) -> None: so we need 32 element alignment. """ global TOKEN_GROUP_ALIGN_SIZE_M - assert m > 0, "Alignment size must be positive" - TOKEN_GROUP_ALIGN_SIZE_M = m + TOKEN_GROUP_ALIGN_SIZE_M = alignment_size # implementation of Tensor Parallel for the GroupedExperts in MoE