@@ -23,76 +23,27 @@ def __init__(
     ):
         super().__init__()
         self.num_experts = num_experts
-        self.w1 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
-        self.w2 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim))
-        self.w3 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
+        # Combine w1 and w3 into a single tensor so we can combine
+        # `x @ w1` and `x @ w3` into a single grouped mm.
+        self.w13 = nn.Parameter(torch.empty(num_experts, hidden_dim * 2, dim))
+        self.w2 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
         self.use_grouped_mm = use_grouped_mm

     def forward(
         self,
         x: torch.Tensor,
         num_tokens_per_expert: torch.Tensor | None = None,
     ) -> torch.Tensor:
-        if self.use_grouped_mm:
-            return GroupedExperts._run_experts_grouped_mm(
-                self.w1, self.w2, self.w3, x, num_tokens_per_expert
-            )
-        else:
-            return GroupedExperts._run_experts_for_loop(
-                self.w1, self.w2, self.w3, x, num_tokens_per_expert
-            )
-
-    # TODO: keeping this for-loop implementation for comparison
-    # and readability, may remove later
-    @expert_parallel
-    @staticmethod
-    def _run_experts_for_loop(
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        w3: torch.Tensor,
-        x: torch.Tensor,
-        num_tokens_per_expert: torch.Tensor | None = None,
-    ) -> torch.Tensor:
-        if num_tokens_per_expert is not None:
-            # NOTE: this would incur a synchronization between device and host
-            num_tokens_per_expert = num_tokens_per_expert.tolist()
-
-            # side-effect code due to the usage of generate_permute_indices
-            num_padding = x.shape[0] - sum(num_tokens_per_expert)
-
-            # a tuple of tensors indexed by experts
-            # each with shape (tokens_per_expert(varying), dim)
-            x = torch.split(
-                x[: sum(num_tokens_per_expert)],
-                split_size_or_sections=num_tokens_per_expert,
-                dim=0,
-            )
-            out_experts_splits = []
-            for expert_idx, x_expert in enumerate(x):
-                h = F.silu(torch.matmul(x_expert, w1[expert_idx]))
-                h = h * torch.matmul(x_expert, w3[expert_idx])
-                h = torch.matmul(h, w2[expert_idx])
-                # h shape (tokens_per_expert(varying), dim)
-                out_experts_splits.append(h)
-            out = torch.cat(out_experts_splits, dim=0)
-
-            # side-effect code due to the usage of generate_permute_indices
-            out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1]))))
-        else:
-            # x shape (num_experts, tokens_per_expert, dim)
-            h = F.silu(torch.bmm(x, w1))
-            h = h * torch.bmm(x, w3)
-            # out shape (num_experts, tokens_per_expert, dim)
-            out = torch.bmm(h, w2)
+        return GroupedExperts._run_experts_grouped_mm(
+            self.w13, self.w2, x, num_tokens_per_expert
+        )

-        return out

     @expert_parallel
     @staticmethod
     def _run_experts_grouped_mm(
-        w1: torch.Tensor,
+        w13: torch.Tensor,
         w2: torch.Tensor,
-        w3: torch.Tensor,
         x: torch.Tensor,
         num_tokens_per_expert: torch.Tensor | None = None,
     ) -> torch.Tensor:
@@ -105,16 +56,14 @@ def _run_experts_grouped_mm(
         # fall back to regular bmm between 3D tensors
         assert x.dim() == 3

-        h = F.silu(torch._grouped_mm(x.bfloat16(), w1.bfloat16(), offs=offsets))
-        h = h * torch._grouped_mm(x.bfloat16(), w3.bfloat16(), offs=offsets)
-        out = torch._grouped_mm(h, w2.bfloat16(), offs=offsets).type_as(x)
-
+        x1, x3 = torch._grouped_mm(x, w13.transpose(-2, -1), offs=offsets).chunk(2, dim=-1)
+        y = F.silu(x1) * x3
+        out = torch._grouped_mm(y, w2.transpose(-2, -1), offs=offsets).type_as(x)
         return out

     def init_weights(self, init_std: float):
-        nn.init.trunc_normal_(self.w1, mean=0.0, std=0.02)
+        nn.init.trunc_normal_(self.w13, mean=0.0, std=0.02)
         nn.init.trunc_normal_(self.w2, mean=0.0, std=init_std)
-        nn.init.trunc_normal_(self.w3, mean=0.0, std=init_std)


 class TokenChoiceTopKRouter(nn.Module):
@@ -299,7 +248,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

         # shared expert
         if self.shared_expert is not None:
-            out = self.shared_expert(x.reshape(1, bs * slen, dim)).reshape(
+            out = self.shared_expert(x.reshape(1, bs * slen, dim))
+            out = out.reshape(
                 bs * slen, dim
             )
         else:
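
The change above folds the SwiGLU gate and up projections (`w1`, `w3`) into one `w13` tensor, so the two expert matmuls become a single grouped mm whose output is split back out with `chunk`. Below is a minimal standalone sketch of that fusion, not the torchtitan code: it uses plain `torch.bmm` instead of the private `torch._grouped_mm`, and the shapes and names are illustrative assumptions, just to show that chunking the fused projection reproduces the separate `w1`/`w3` results.

```python
import torch
import torch.nn.functional as F

# Standalone sketch (illustrative, not the torchtitan implementation):
# check that a fused w13 projection matches separate w1/w3 projections
# for a SwiGLU expert. Shapes and names below are assumptions.
num_experts, tokens_per_expert, dim, hidden_dim = 4, 8, 16, 32

x = torch.randn(num_experts, tokens_per_expert, dim)
w1 = torch.randn(num_experts, hidden_dim, dim)  # gate projection
w3 = torch.randn(num_experts, hidden_dim, dim)  # up projection
w2 = torch.randn(num_experts, dim, hidden_dim)  # down projection

# Reference: two separate matmuls, as in the pre-change implementation.
h = F.silu(torch.bmm(x, w1.transpose(-2, -1))) * torch.bmm(x, w3.transpose(-2, -1))
out_ref = torch.bmm(h, w2.transpose(-2, -1))

# Fused: concatenate w1 and w3 along the output axis, run one batched mm,
# then split the result back into the gate and up activations.
w13 = torch.cat([w1, w3], dim=1)  # (num_experts, hidden_dim * 2, dim)
x1, x3 = torch.bmm(x, w13.transpose(-2, -1)).chunk(2, dim=-1)
out_fused = torch.bmm(F.silu(x1) * x3, w2.transpose(-2, -1))

torch.testing.assert_close(out_ref, out_fused)
print("fused w13 matches separate w1/w3")
```

One practical consequence of this layout is that existing `w1`/`w3` checkpoints would need to be concatenated along the same output axis before they can be loaded into `w13`.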