Commit 444c9e9

use literal
1 parent 41230f4 commit 444c9e9

File tree

1 file changed: +8 -4 lines changed


torchtitan/experiments/llama4/infra/expert_parallel.py

Lines changed: 8 additions & 4 deletions
@@ -6,7 +6,7 @@
 
 
 from functools import partial
-from typing import Callable
+from typing import Callable, Literal
 
 import torch
 import torch.distributed as dist
@@ -25,12 +25,17 @@
 
 
 TOKEN_GROUP_ALIGN_SIZE_M = 8
+ValidTokenGroupAlignmentSize = Literal[8, 16, 32]
 
 
-def set_token_group_alignment_size_m(m: int) -> None:
+def set_token_group_alignment_size_m(
+    alignment_size: ValidTokenGroupAlignmentSize,
+) -> None:
     """
     Set the token group alignment size for token groups in MoE. This is implemented by
     padding each token group size to the next multiple of TOKEN_GROUP_ALIGN_SIZE_M.
+
+    Valid values are: 8, 16, or 32.
     Different values are needed for different cases:
 
     * For bf16, 8 is enough (16 byte alignment / 2 bytes per elem = 8 elements).
@@ -43,8 +48,7 @@ def set_token_group_alignment_size_m(m: int) -> None:
     so we need 32 element alignment.
     """
     global TOKEN_GROUP_ALIGN_SIZE_M
-    assert m > 0, "Alignment size must be positive"
-    TOKEN_GROUP_ALIGN_SIZE_M = m
+    TOKEN_GROUP_ALIGN_SIZE_M = alignment_size
 
 
 # implementation of Tensor Parallel for the GroupedExperts in MoE
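For context, a minimal usage sketch of the updated API (illustrative, not part of this commit): the Literal annotation moves validation from runtime to type-check time, so mypy/pyright reject alignment sizes outside {8, 16, 32}, while the old runtime assert is removed.

# Usage sketch, assuming the import path of the file in this diff.
from torchtitan.experiments.llama4.infra.expert_parallel import (
    set_token_group_alignment_size_m,
)

# Per the docstring, mxfp8 needs 32-element alignment.
set_token_group_alignment_size_m(32)

# A value outside Literal[8, 16, 32] is flagged by a static type checker;
# note there is no longer a runtime assert to catch it during execution.
set_token_group_alignment_size_m(12)  # type checker error: not a valid literal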
