
Commit 2122f27: Make token group alignment size configurable
Parent: 881f0ca

3 files changed: +22 additions, -4 deletions

torchtitan/components/quantization/mx.py

Lines changed: 9 additions & 0 deletions
```diff
@@ -59,6 +59,15 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
             and job_config.parallelism.tensor_parallel_degree > 1
         ), "TP not yet supported with torch.compile for mxfp8"
 
+        # For MoE training with mxfp8, token group sizes must be multiples of 32
+        if job_config.mx.moe_fqns_prototype:
+            from torchtitan.experiments.llama4.infra import expert_parallel
+
+            expert_parallel.TOKEN_GROUP_ALIGN_SIZE_M = 32
+            print(
+                f"Setting TOKEN_GROUP_ALIGN_SIZE_M to {expert_parallel.TOKEN_GROUP_ALIGN_SIZE_M}"
+            )
+
         # Configure MXFP8
         from torchao.prototype.mx_formats.config import (
             MXFP8Dim1CastKernelChoice,
```

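The override above works because `expert_parallel` exposes the alignment as a module-level attribute that is read at call time, so assigning `expert_parallel.TOKEN_GROUP_ALIGN_SIZE_M = 32` during quantization setup changes the padding granularity for all later MoE token permutations. Below is a minimal, self-contained sketch of that pattern; the `SimpleNamespace` stand-in and the `round_up`/`permute_tokens` helpers are illustrative only, not torchtitan APIs.

```python
# Minimal sketch of overriding a module-level alignment constant at setup time.
# The SimpleNamespace below stands in for torchtitan's expert_parallel module;
# round_up/permute_tokens are illustrative helpers, not torchtitan APIs.
from types import SimpleNamespace

expert_parallel = SimpleNamespace(TOKEN_GROUP_ALIGN_SIZE_M=16)  # bf16 default

def round_up(n: int, align: int) -> int:
    """Round n up to the nearest multiple of align."""
    return ((n + align - 1) // align) * align

def permute_tokens(num_tokens: int) -> int:
    # Reads the attribute at call time, so an earlier override takes effect here.
    return round_up(num_tokens, expert_parallel.TOKEN_GROUP_ALIGN_SIZE_M)

# Stand-in for the mxfp8 converter setup: MX formats scale blocks of 32
# elements, so token groups fed to the grouped GEMM must be 32-aligned.
moe_fqns_prototype = ["experts"]  # as if the prototype option were set
if moe_fqns_prototype:
    expert_parallel.TOKEN_GROUP_ALIGN_SIZE_M = 32

print(permute_tokens(45))  # 64 with the 32-alignment override; 48 with the default 16
```

Keeping the value as a module attribute confines the change to a single assignment at one import site, at the cost of mutable global state that every caller implicitly shares.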
torchtitan/config/job_config.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -544,12 +544,18 @@ class MX:
 
     filter_fqns: list[str] = field(default_factory=lambda: ["output"])
     """
-    Comma-separated list of fully qualified names of modules to skip applying mxfloat8 training to.
+    Comma-separated list of fully qualified names of modules to skip applying mxfp8 training to.
     nn.Linear modules with any dim size not divisible by 16 are also always skipped due to hardware requirements.
     By default we always skip the output layer.
     Example: --mx.filter_fqns "attention.wq,attention.wk,attention.wv,output"
     """
 
+    moe_fqns_prototype: list[str] | str = field(default_factory=list)
+    """
+    Comma-separated list of fully qualified names of MoE modules to apply mxfp8 training to.
+    This is a prototype feature that requires the torchao nightly build.
+    Example: --float8.moe_fqns_prototype="experts"
+    """
 
 @dataclass
 class Comm:
```

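The new `moe_fqns_prototype` field accepts either a list or a single comma-separated string of FQN fragments, mirroring `filter_fqns`. A hedged sketch of how such an option is typically normalized and matched against module names follows; `parse_fqns`, `should_convert`, and the toy `Block` model are illustrative only, not the actual torchtitan conversion code.

```python
# Hedged sketch: normalizing a "list or comma-separated string" FQN option and
# using it as a module filter. parse_fqns/should_convert are illustrative only.
import torch.nn as nn

def parse_fqns(value: list[str] | str) -> list[str]:
    """Accept either a list of FQN fragments or one comma-separated string."""
    if isinstance(value, str):
        return [fqn.strip() for fqn in value.split(",") if fqn.strip()]
    return list(value)

def should_convert(module_fqn: str, target_fqns: list[str]) -> bool:
    # Convention used here: convert when any target fragment appears in the FQN.
    return any(target in module_fqn for target in target_fqns)

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.attention = nn.Linear(8, 8)
        self.experts = nn.Linear(8, 8)  # stand-in for a grouped-experts module

targets = parse_fqns("experts")  # e.g. a comma-separated value from the CLI
for fqn, module in Block().named_modules():
    if fqn and should_convert(fqn, targets):
        print(f"would apply mxfp8 MoE conversion to {fqn}")  # -> experts
```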
torchtitan/experiments/llama4/infra/expert_parallel.py

Lines changed: 6 additions & 3 deletions
```diff
@@ -24,6 +24,9 @@
 from torch.distributed.tensor.placement_types import Placement
 
 
+TOKEN_GROUP_ALIGN_SIZE_M = 16
+
+
 # implementation of Tensor Parallel for the GroupedExperts in MoE
 class TensorParallel(ParallelStyle):
     def _partition_fn(self, name, module, device_mesh):
@@ -251,6 +254,7 @@ def wrapper(
         x: torch.Tensor,
         num_tokens_per_expert: torch.Tensor | None = None,
     ) -> torch.Tensor:
+        global TOKEN_GROUP_ALIGN_SIZE_M
         if isinstance(w1, DTensor):
             w1 = w1.to_local()
             w2 = w2.to_local()
@@ -264,7 +268,6 @@ def wrapper(
         experts_per_ep_rank = w1.shape[0]
         num_ep_ranks = num_tokens_per_expert.shape[0] // experts_per_ep_rank
 
-        ALIGN_SIZE_M = 16
         with torch.no_grad():
             (
                 permuted_indices,
@@ -274,8 +277,8 @@ def wrapper(
                 num_tokens_per_expert,
                 experts_per_ep_rank,
                 num_ep_ranks,
-                x.shape[0] + experts_per_ep_rank * ALIGN_SIZE_M,
-                ALIGN_SIZE_M,
+                x.shape[0] + experts_per_ep_rank * TOKEN_GROUP_ALIGN_SIZE_M,
+                TOKEN_GROUP_ALIGN_SIZE_M,
             )
 
         x = torch.vstack((x, x.new_zeros((x.shape[-1]))))
```

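For context on how the constant is used: the permutation step pads each expert's token group up to a multiple of `TOKEN_GROUP_ALIGN_SIZE_M`, so the buffer is sized as `x.shape[0] + experts_per_ep_rank * TOKEN_GROUP_ALIGN_SIZE_M`, a safe upper bound because each local expert adds at most `align - 1` padding rows. Below is a small arithmetic sketch of that bound; `padded_total_tokens` is an illustrative helper, not the actual permutation kernel.

```python
# Arithmetic sketch of the padding bound used when sizing the permuted buffer.
# padded_total_tokens is an illustrative helper, not a torchtitan function.
import torch

def padded_total_tokens(num_tokens_per_expert: torch.Tensor, align: int) -> int:
    """Total tokens after rounding each expert's group up to a multiple of align."""
    padded = ((num_tokens_per_expert + align - 1) // align) * align
    return int(padded.sum())

counts = torch.tensor([5, 0, 33, 17])    # tokens routed to 4 local experts
total = int(counts.sum())                # 55 rows in x before padding
for align in (16, 32):                   # bf16 default vs. mxfp8 override
    actual = padded_total_tokens(counts, align)
    bound = total + len(counts) * align  # x.shape[0] + experts_per_ep_rank * ALIGN
    print(f"align={align}: padded total = {actual}, preallocated bound = {bound}")
    assert actual <= bound               # each expert adds at most align - 1 padding rows
```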