@@ -515,7 +515,7 @@ def apply_ffn_activation(self, layer_w0, layer_w1):
    intermediate_layer = jnp.multiply(layer_act, layer_w1)
    return intermediate_layer.astype(self.dtype)

-  def permute(self, inputs, gate_logits, pre_bias_logits, use_custom_sort_vjp=True, rngs=None):
+  def permute(self, inputs, gate_logits, pre_bias_logits, use_custom_sort_vjp=True, rngs=None, roll_to_expert_id=None):
    """Permute tokens to group by expert to fit gmm call."""
    # reshape inputs (batch, sequence, emb) to (batch * sequence, emb)
    inputs_shape = inputs.shape
@@ -530,6 +530,8 @@ def permute(self, inputs, gate_logits, pre_bias_logits, use_custom_sort_vjp=True
    inputs_2d = inputs_2d * router_scores.reshape(bsz_times_seq_len, -1)

    flatten_selected_experts = jnp.ravel(selected_experts)
+   if roll_to_expert_id is not None:
+     flatten_selected_experts = (flatten_selected_experts - roll_to_expert_id) % self.num_experts
    sorted_selected_experts = jnp.argsort(flatten_selected_experts)
    # sort inputs for number of selected experts
    replicated_inputs_2d = jnp.repeat(inputs_2d, self.num_experts_per_tok, axis=0)
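A minimal standalone sketch (not from this diff) of what the new `roll_to_expert_id` shift does inside `permute`: subtracting the shard's first expert id modulo `num_experts` relabels the shard-local experts as 0..k-1, so `argsort` places their tokens at the front of the sorted array. The concrete values below (`num_experts`, the token assignments) are assumptions for illustration only.

```python
# Illustration only; values are assumed, not taken from the diff.
import jax.numpy as jnp

num_experts = 8
roll_to_expert_id = 4  # first expert id owned by the current shard

# Flattened token-to-expert assignments, as in flatten_selected_experts.
flatten_selected_experts = jnp.array([6, 1, 4, 7, 0, 5])

# Shift so shard-local experts 4..7 become 0..3 and sort ahead of experts 0..3.
shifted = (flatten_selected_experts - roll_to_expert_id) % num_experts
order = jnp.argsort(shifted)

print(shifted)                          # [2 5 0 3 4 1]
print(flatten_selected_experts[order])  # [4 5 6 7 0 1]: shard-local experts first
```

Because the shift is a bijection on expert ids, group sizes computed from the shifted ids are a rotation of the per-expert counts, so the counts for this shard's experts end up in the first `num_experts_per_shard` entries.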
@@ -942,23 +944,17 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
      )

        # "Route" tokens within each shard.
-       x, sorted_selected_experts, weights, full_group_sizes, selected_experts = self.permute(
-           x, logits, pre_bias_logits, self.config.use_custom_sort_vjp)
+       num_experts_per_shard = self.config.num_experts // num_expert_parallelism
+       x, sorted_selected_experts, weights, group_sizes, selected_experts = self.permute(
+           x, logits, pre_bias_logits, self.config.use_custom_sort_vjp,
+           roll_to_expert_id=num_experts_per_shard * expert_shard_id)

        # Filter down to the group sizes that apply to only the experts in the
        # current shard.
-       full_group_sizes = jnp.reshape(full_group_sizes, (num_expert_parallelism, -1))
-       group_sizes = full_group_sizes[expert_shard_id]
-
-       # Move the tokens for the experts in the current shard to the start of
-       # the inputs array.
-       num_tokens_to_skip = (
-           jnp.cumsum(jnp.sum(full_group_sizes, axis=-1))[expert_shard_id] -
-           jnp.sum(full_group_sizes[expert_shard_id])
-       )
-       x = jnp.roll(x, shift=-num_tokens_to_skip, axis=0)
+       group_sizes = group_sizes[:num_experts_per_shard]
+       mask = jnp.arange(x.shape[0]) < jnp.sum(group_sizes)
+       x = jnp.where(mask[:, None], x, 0)
      else:
-       num_tokens_to_skip = None
        x, sorted_selected_experts, weights, group_sizes, selected_experts = self.permute(
            x, logits, pre_bias_logits, self.config.use_custom_sort_vjp, rngs)

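To make the shard-local filtering concrete, here is a small sketch (again with assumed shapes and values, not from the diff) of the new slice-and-mask step: because the permute call above already sorted this shard's tokens to the front, the shard keeps only the first `num_experts_per_shard` group sizes and zeroes every row past their sum instead of rolling the array.

```python
# Illustration only; shapes and values are assumed.
import jax.numpy as jnp

num_experts_per_shard = 2
# Per-(shifted)-expert token counts; entries 0..1 belong to this shard.
group_sizes = jnp.array([3, 2, 1, 2])
x = jnp.ones((8, 4))  # 8 routed rows (sum of group_sizes), sorted by shifted expert id

group_sizes = group_sizes[:num_experts_per_shard]     # [3 2] -> 5 local rows
mask = jnp.arange(x.shape[0]) < jnp.sum(group_sizes)  # True for rows 0..4
x = jnp.where(mask[:, None], x, 0)                    # rows 5..7 zeroed, shape unchanged
```

Zeroing rather than slicing keeps the array shape fixed, which presumably keeps the downstream gmm shapes static under jit; the same mask-and-zero pattern is applied to `intermediate_output` in the hunk below, which is why the roll back to the original positions is no longer needed there.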
@@ -1056,9 +1052,6 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
      mask = jnp.arange(intermediate_output.shape[0]) < jnp.sum(group_sizes)
      intermediate_output = jnp.where(mask[:, None], intermediate_output, 0)

-     # Move the tokens back to their original positions.
-     intermediate_output = jnp.roll(intermediate_output, shift=num_tokens_to_skip, axis=0)
-
      # Unsort and deduplicate the outputs locally.
      output = self.unpermute(
          intermediate_output,