torchtitan/experiments/llama4/infra (1 file changed, +5 -2 lines)

  from functools import partial
- from typing import Callable
+ from typing import Callable, Literal

  import torch
  import torch.distributed as dist

...

  TOKEN_GROUP_ALIGN_SIZE_M = 8
+ ValidTokenGroupAlignmentSize = Literal[8, 16, 32]


- def set_token_group_alignment_size_m(m: int) -> None:
+ def set_token_group_alignment_size_m(m: ValidTokenGroupAlignmentSize) -> None:
      """
      Set the token group alignment size for token groups in MoE. This is implemented by
      padding each token group size to the next multiple of TOKEN_GROUP_ALIGN_SIZE_M.
+
+     Valid values are: 8, 16, or 32.

      Different values are needed for different cases:

      * For bf16, 8 is enough (16 byte alignment / 2 bytes per elem = 8 elements).
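
The padding rule described in the docstring rounds each token group size up to the next multiple of the global alignment. Below is a minimal sketch of that behavior; the _pad_to_alignment helper name and the use of a module-level global are assumptions for illustration, while TOKEN_GROUP_ALIGN_SIZE_M, ValidTokenGroupAlignmentSize, and set_token_group_alignment_size_m are the names taken from the diff.

    # Minimal sketch of the alignment behavior; _pad_to_alignment is a
    # hypothetical helper, not part of the actual change.
    from typing import Literal

    TOKEN_GROUP_ALIGN_SIZE_M = 8
    ValidTokenGroupAlignmentSize = Literal[8, 16, 32]


    def set_token_group_alignment_size_m(m: ValidTokenGroupAlignmentSize) -> None:
        # Narrowing m to Literal[8, 16, 32] lets type checkers flag unsupported
        # alignment values (e.g. 4) at call sites instead of at runtime.
        global TOKEN_GROUP_ALIGN_SIZE_M
        TOKEN_GROUP_ALIGN_SIZE_M = m


    def _pad_to_alignment(group_size: int) -> int:
        # Hypothetical helper: round a token group size up to the next multiple
        # of TOKEN_GROUP_ALIGN_SIZE_M, as the docstring describes.
        return -(-group_size // TOKEN_GROUP_ALIGN_SIZE_M) * TOKEN_GROUP_ALIGN_SIZE_M


    set_token_group_alignment_size_m(16)
    assert _pad_to_alignment(70) == 80  # 70 tokens padded up to the next multiple of 16

Typing the argument as a Literal rather than a plain int documents the three supported alignments directly in the signature, which is the substance of this change.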