from functools import partial
-from typing import Callable
+from typing import Callable, Literal

import torch
import torch.distributed as dist


TOKEN_GROUP_ALIGN_SIZE_M = 8
+ValidTokenGroupAlignmentSize = Literal[8, 16, 32]


-def set_token_group_alignment_size_m(m: int) -> None:
+def set_token_group_alignment_size_m(
+    alignment_size: ValidTokenGroupAlignmentSize,
+) -> None:
    """
    Set the token group alignment size for token groups in MoE. This is implemented by
    padding each token group size to the next multiple of TOKEN_GROUP_ALIGN_SIZE_M.
+
+    Valid values are: 8, 16, or 32.
    Different values are needed for different cases:

    * For bf16, 8 is enough (16 byte alignment / 2 bytes per elem = 8 elements).
@@ -43,8 +48,7 @@ def set_token_group_alignment_size_m(m: int) -> None:
    so we need 32 element alignment.
    """
    global TOKEN_GROUP_ALIGN_SIZE_M
-    assert m > 0, "Alignment size must be positive"
-    TOKEN_GROUP_ALIGN_SIZE_M = m
+    TOKEN_GROUP_ALIGN_SIZE_M = alignment_size


# implementation of Tensor Parallel for the GroupedExperts in MoE
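The removed runtime assert is presumably superseded by the Literal annotation: with ValidTokenGroupAlignmentSize = Literal[8, 16, 32], invalid values are caught by a static type checker instead of at runtime. A small usage sketch, assuming mypy or pyright is run over callers (the stub below only mirrors the new signature and is not the PR's code):

from typing import Literal

ValidTokenGroupAlignmentSize = Literal[8, 16, 32]

def set_token_group_alignment_size_m(alignment_size: ValidTokenGroupAlignmentSize) -> None:
    # Stub mirroring the new signature, for illustration only.
    ...

set_token_group_alignment_size_m(16)  # accepted: 16 is a member of the Literal
set_token_group_alignment_size_m(7)   # flagged by mypy/pyright as an invalid argument

Note that Literal annotations are not enforced when the code runs, so the static check only helps callers that are actually type-checked.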
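To make the docstring's arithmetic concrete, here is a minimal sketch (illustration only, not code from this PR; round_up_to_alignment and alignment_for_dtype are hypothetical helper names) of how a per-expert token count is padded to the next multiple of TOKEN_GROUP_ALIGN_SIZE_M, and how the bf16 value of 8 follows from a 16-byte alignment requirement:

import torch

TOKEN_GROUP_ALIGN_SIZE_M = 8

def round_up_to_alignment(num_tokens: int, align_size_m: int = TOKEN_GROUP_ALIGN_SIZE_M) -> int:
    # Pad a per-expert token count up to the next multiple of the alignment size.
    return ((num_tokens + align_size_m - 1) // align_size_m) * align_size_m

def alignment_for_dtype(dtype: torch.dtype, byte_alignment: int = 16) -> int:
    # Back out the element alignment implied by a byte alignment:
    # e.g. bf16: 16 bytes / 2 bytes per element = 8 elements, matching the docstring.
    return byte_alignment // dtype.itemsize

assert alignment_for_dtype(torch.bfloat16) == 8  # the docstring's bf16 case
assert round_up_to_alignment(13) == 16           # 13 tokens padded up to a multiple of 8
assert round_up_to_alignment(16) == 16           # already aligned, no padding added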