Commit 6b953db

use literal
1 parent 41230f4 commit 6b953db


torchtitan/experiments/llama4/infra/expert_parallel.py

Lines changed: 5 additions & 2 deletions
@@ -6,7 +6,7 @@
 
 
 from functools import partial
-from typing import Callable
+from typing import Callable, Literal
 
 import torch
 import torch.distributed as dist
@@ -25,12 +25,15 @@
 
 
 TOKEN_GROUP_ALIGN_SIZE_M = 8
+ValidTokenGroupAlignmentSize = Literal[8, 16, 32]
 
 
-def set_token_group_alignment_size_m(m: int) -> None:
+def set_token_group_alignment_size_m(m: ValidTokenGroupAlignmentSize) -> None:
     """
     Set the token group alignment size for token groups in MoE. This is implemented by
     padding each token group size to the next multiple of TOKEN_GROUP_ALIGN_SIZE_M.
+
+    Valid values are: 8, 16, or 32.
     Different values are needed for different cases:
 
     * For bf16, 8 is enough (16 byte alignment / 2 bytes per elem = 8 elements).
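The effect of the change is that a call like set_token_group_alignment_size_m(12) now fails static type checking instead of silently setting an unsupported alignment. Below is a minimal sketch of the idea, assuming a simplified module; the setter body and the pad_to_multiple helper are illustrative, not torchtitan's implementation, though the names TOKEN_GROUP_ALIGN_SIZE_M, ValidTokenGroupAlignmentSize, and set_token_group_alignment_size_m come from the diff above.

# Minimal sketch; the setter body and pad_to_multiple are illustrative,
# not torchtitan's code.
from typing import Literal

# Names taken from the diff above.
TOKEN_GROUP_ALIGN_SIZE_M = 8
ValidTokenGroupAlignmentSize = Literal[8, 16, 32]


def set_token_group_alignment_size_m(m: ValidTokenGroupAlignmentSize) -> None:
    # Literal narrows the accepted values at type-check time: a checker such
    # as mypy or pyright flags set_token_group_alignment_size_m(12) as an
    # error, while 8, 16, and 32 are accepted.
    global TOKEN_GROUP_ALIGN_SIZE_M
    TOKEN_GROUP_ALIGN_SIZE_M = m


def pad_to_multiple(group_size: int) -> int:
    # Hypothetical helper: round a token group size up to the next multiple
    # of TOKEN_GROUP_ALIGN_SIZE_M, as the docstring describes.
    m = TOKEN_GROUP_ALIGN_SIZE_M
    return (group_size + m - 1) // m * m


set_token_group_alignment_size_m(16)  # OK: 16 is one of the allowed literals
print(pad_to_multiple(100))           # 112: padded up to a multiple of 16

Note that Literal is enforced only by static type checkers; at runtime any int is still accepted unless the function validates its argument explicitly.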
