Skip to content

Make token group alignment size configurable #1503

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions torchtitan/components/quantization/float8.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@

from torchtitan.config.job_config import Float8, JobConfig
from torchtitan.distributed import ParallelDims
from torchtitan.experiments.llama4.infra.expert_parallel import (
set_token_group_alignment_size_m,
)
from torchtitan.protocols.model_converter import (
ModelConverter,
register_model_converter,
Expand Down Expand Up @@ -66,6 +69,10 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
job_config.parallelism.context_parallel_degree == 1
), "Float8 MoE training prototype does not yet support context parallelism"

# For fp8 grouped GEMM, token group sizes must be multiples of 16
# (16 byte alignment / 1 byte per elem = 16 elements)
set_token_group_alignment_size_m(16)

if float8_config.recipe_name is not None:
assert not float8_config.enable_fsdp_float8_all_gather, (
"using `float8_config.enable_fsdp_float8_all_gather` together "
Expand Down
10 changes: 10 additions & 0 deletions torchtitan/components/quantization/mx.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,16 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
and job_config.parallelism.tensor_parallel_degree > 1
), "TP not yet supported with torch.compile for mxfp8"

# For MoE training with mxfp8, token group sizes must be multiples of 32
if job_config.mx.moe_fqns_prototype:
from torchtitan.experiments.llama4.infra.expert_parallel import (
set_token_group_alignment_size,
)

mxfp8_block_size = 32
set_token_group_alignment_size(mxfp8_block_size)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

don't we need to do this for Float8 as well, as IIRC it supports grouped gemm too

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes but the default (16) is what is needed for float8, so no need to manually set it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if we should use 16 as default.
For bf16, is 16 enough or is 8 enough?
I think we should still set it, in case the default changes later.

Copy link
Contributor Author

@danielvegamyhre danielvegamyhre Jul 31, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually yeah I think you're right.

  • For bf16, 8 is enough (16 byte alignment / 2 bytes per elem = 8 elements).
  • For fp8, 16 byte alignment / 1 byte per elem = 16 elements.
  • For mxfp8, we need 32 (or block_size) because scaling block size is (1 x 32), so when doing per-token-group quantization on each logically distinct subtensor, we need to ensure the contracting dim is divisible by block_size. In the backward pass, grad_weight = (grad_output_t @ input).t() has gemm dims (N, M) @ (M, K) so M is the contracting dim, and group offsets are along M, so we need 32 element alignment.

Updated this accordingly.

logger.info(f"Setting token group alignment size to {mxfp8_block_size}")

# Configure MXFP8
from torchao.prototype.mx_formats.config import (
MXFP8Dim1CastKernelChoice,
Expand Down
9 changes: 8 additions & 1 deletion torchtitan/config/job_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,12 +567,19 @@ class MX:

filter_fqns: list[str] = field(default_factory=lambda: ["output"])
"""
Comma-separated list of fully qualified names of modules to skip applying mxfloat8 training to.
Comma-separated list of fully qualified names of modules to skip applying mxfp8 training to.
nn.Linear modules with any dim size not divisible by 16 are also always skipped due to hardware requirements.
By default we always skip the output layer.
Example: --mx.filter_fqns "attention.wq,attention.wk,attention.wv,output"
"""

moe_fqns_prototype: list[str] | str = field(default_factory=list)
"""
Comma-separated list of fully qualified names of MoE modules to apply mxfp8 training to.
This is a prototype feature that requires the torchao nightly build.
Example: --mx.moe_fqns_prototype="experts"
"""


@dataclass
class Comm:
Expand Down
35 changes: 31 additions & 4 deletions torchtitan/experiments/llama4/infra/expert_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


from functools import partial
from typing import Callable
from typing import Callable, Literal

import torch
import torch.distributed as dist
Expand All @@ -24,6 +24,33 @@
from torch.distributed.tensor.placement_types import Placement


TOKEN_GROUP_ALIGN_SIZE_M = 8
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK for now. Later we may want to set this as private field and provide a getter function too.

ValidTokenGroupAlignmentSize = Literal[8, 16, 32]


def set_token_group_alignment_size_m(
alignment_size: ValidTokenGroupAlignmentSize,
) -> None:
"""
Set the token group alignment size for token groups in MoE. This is implemented by
padding each token group size to the next multiple of TOKEN_GROUP_ALIGN_SIZE_M.

Valid values are: 8, 16, or 32.
Different values are needed for different cases:

* For bf16, 8 is enough (16 byte alignment / 2 bytes per elem = 8 elements).
* For fp8, 16 byte alignment / 1 byte per elem = 16 elements.
* For mxfp8, we need 32 (or block_size) because scaling block size is (1 x 32),
so when doing per-token-group quantization on each logically distinct subtensor,
we need to ensure the contracting dim is divisible by block_size.
In the backward pass, grad_weight = (grad_output_t @ input).t() has gemm dims
of (N, M) @ (M, K) so M is the contracting dim, and group offsets are along M,
so we need 32 element alignment.
"""
global TOKEN_GROUP_ALIGN_SIZE_M
TOKEN_GROUP_ALIGN_SIZE_M = alignment_size


# implementation of Tensor Parallel for the GroupedExperts in MoE
class TensorParallel(ParallelStyle):
def _partition_fn(self, name, module, device_mesh):
Expand Down Expand Up @@ -251,6 +278,7 @@ def wrapper(
x: torch.Tensor,
num_tokens_per_expert: torch.Tensor | None = None,
) -> torch.Tensor:
global TOKEN_GROUP_ALIGN_SIZE_M
if isinstance(w1, DTensor):
w1 = w1.to_local()
w2 = w2.to_local()
Expand All @@ -264,7 +292,6 @@ def wrapper(
experts_per_ep_rank = w1.shape[0]
num_ep_ranks = num_tokens_per_expert.shape[0] // experts_per_ep_rank

ALIGN_SIZE_M = 16
with torch.no_grad():
(
permuted_indices,
Expand All @@ -274,8 +301,8 @@ def wrapper(
num_tokens_per_expert,
experts_per_ep_rank,
num_ep_ranks,
x.shape[0] + experts_per_ep_rank * ALIGN_SIZE_M,
ALIGN_SIZE_M,
x.shape[0] + experts_per_ep_rank * TOKEN_GROUP_ALIGN_SIZE_M,
TOKEN_GROUP_ALIGN_SIZE_M,
)

x = torch.vstack((x, x.new_zeros((x.shape[-1]))))
Expand Down
Loading