@@ -53,11 +53,14 @@ def _sort_activations(
 ) -> jax.Array:
   """Sort activations by `sort_indices`.
 
-  If `use_custom_vjp` is True, then we use a custom backward pass that
+  If `use_custom_vjp=True`, then we use a custom backward pass that
   reverses the sort order. Specifically, this unsort operation is simply a sort
   with `jnp.argsort(sort_indices)` as the sort indices. This is only needed in
   the case where the compiler generates a less efficient backward pass op.
-
+
+  Note that `use_custom_vjp=True` assumes that `sort_indices` is a permutation
+  of `jnp.arange(inputs.shape[0])`.
+
   Args:
     inputs: `(tokens, ...)`-shaped array of input activations to sort.
     sort_indices: `(tokens,)`-shaped array containing the sort order.
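
For readers unfamiliar with the pattern the docstring describes, here is a minimal, self-contained sketch of a sort with a custom backward pass (illustrative only, not MaxText's `_sort_activations`; the function names are made up). It relies on the assumption stated in the new docstring note: `sort_indices` must be a permutation of `jnp.arange(inputs.shape[0])`, so the backward pass can invert the sort with `jnp.argsort(sort_indices)`.

# Sketch only: a sort whose VJP is an explicit unsort (a gather with the
# inverse permutation). Assumes `sort_indices` is a permutation of
# jnp.arange(inputs.shape[0]), per the docstring note above.
import jax
import jax.numpy as jnp

@jax.custom_vjp
def sort_rows(inputs, sort_indices):
  return jnp.take(inputs, sort_indices, axis=0)

def _sort_rows_fwd(inputs, sort_indices):
  # Keep only the indices as residuals for the backward pass.
  return jnp.take(inputs, sort_indices, axis=0), sort_indices

def _sort_rows_bwd(sort_indices, grad):
  # Unsorting is just sorting with the inverse permutation; None marks the
  # integer `sort_indices` argument as having no cotangent.
  return jnp.take(grad, jnp.argsort(sort_indices), axis=0), None

sort_rows.defvjp(_sort_rows_fwd, _sort_rows_bwd)

# Usage: cotangents are gathered back into the pre-sort row order.
x = jnp.arange(6.0).reshape(3, 2)
perm = jnp.array([2, 0, 1])
y, vjp_fn = jax.vjp(lambda v: sort_rows(v, perm), x)
(dx,) = vjp_fn(jnp.ones_like(y))

Keeping the backward pass a single gather is the point: the docstring says the custom rule is only needed when the compiler would otherwise derive a less efficient backward op for the sort.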
@@ -511,7 +514,7 @@ def apply_ffn_activation(self, layer_w0, layer_w1):
     layer_act = self.activation_fn(layer_w0)
     intermediate_layer = jnp.multiply(layer_act, layer_w1)
     return intermediate_layer.astype(self.dtype)
-
+
   def permute(self, inputs, gate_logits, pre_bias_logits, use_custom_sort_vjp=True, rngs=None):
     """Permute tokens to group by expert to fit gmm call."""
     # reshape inputs (batch, sequence, emb) to (batch * sequence, emb)
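
As an aside, the `apply_ffn_activation` context lines above implement the GLU-family gating common in transformer FFNs. A standalone sketch, with `jax.nn.silu` assumed as the activation (in MaxText, `self.activation_fn` is configurable, so the concrete choice here is an assumption):

# Sketch of the gated-FFN pattern from apply_ffn_activation above; jax.nn.silu
# is an assumption, since the real activation_fn comes from config.
import jax
import jax.numpy as jnp

def gated_ffn_activation(layer_w0, layer_w1, dtype=jnp.bfloat16):
  layer_act = jax.nn.silu(layer_w0)                       # nonlinear branch: act(x @ W0)
  intermediate_layer = jnp.multiply(layer_act, layer_w1)  # gate: act(x @ W0) * (x @ W1)
  return intermediate_layer.astype(dtype)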
@@ -528,12 +531,9 @@ def permute(self, inputs, gate_logits, pre_bias_logits, use_custom_sort_vjp=True
 
     flatten_selected_experts = jnp.ravel(selected_experts)
     sorted_selected_experts = jnp.argsort(flatten_selected_experts)
-    sorted_indices = sorted_selected_experts // self.num_experts_per_tok
     # sort inputs for number of selected experts
-    replicated_inputs_2d = jnp.reshape(
-        jnp.broadcast_to(inputs_2d[None, ...], (self.num_experts_per_tok, *inputs_2d.shape)),
-        (self.num_experts_per_tok * inputs_2d.shape[0], inputs_2d.shape[1]))
-    sorted_inputs = _sort_activations(replicated_inputs_2d, sorted_indices, use_custom_sort_vjp).astype(self.dtype)
+    replicated_inputs_2d = jnp.repeat(inputs_2d, self.num_experts_per_tok, axis=0)
+    sorted_inputs = _sort_activations(replicated_inputs_2d, sorted_selected_experts, use_custom_sort_vjp).astype(self.dtype)
 
     group_size = jnp.bincount(flatten_selected_experts, length=self.num_experts)
     # Return the experts for each sorted input.
     expert_indices = jnp.arange(self.num_experts)
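
To see why the two-line replacement is equivalent: `jnp.ravel(selected_experts)` is token-major (copy `i` of token `t` lands at row `t * num_experts_per_tok + i`), which is exactly the row layout `jnp.repeat` produces, so `sorted_selected_experts` can index the replicated rows directly. The old copy-major `broadcast_to`/`reshape` layout forced the extra `// num_experts_per_tok` step, whose output repeats values and hence is not a permutation; the new indices are a true permutation, which is what the `use_custom_vjp` docstring note above requires. A small standalone check (a sketch; the `(tokens, num_experts_per_tok)` shape of `selected_experts` is an assumption read off this diff):

# Numeric check that the jnp.repeat path reproduces the old
# broadcast_to/reshape path (standalone sketch, not MaxText code).
import jax.numpy as jnp

tokens, k, emb = 3, 2, 4
inputs_2d = jnp.arange(tokens * emb, dtype=jnp.float32).reshape(tokens, emb)
selected_experts = jnp.array([[1, 0], [2, 1], [0, 2]])  # (tokens, k) expert ids

flat = jnp.ravel(selected_experts)  # copy i of token t sits at row t*k + i
order = jnp.argsort(flat)           # permutation grouping rows by expert id

# Old path: copy-major replication, so the permutation had to be collapsed
# back to token indices with `// k` (which repeats values).
old_replicated = jnp.reshape(
    jnp.broadcast_to(inputs_2d[None, ...], (k, *inputs_2d.shape)),
    (k * inputs_2d.shape[0], inputs_2d.shape[1]))
old_sorted = jnp.take(old_replicated, order // k, axis=0)

# New path: token-major replication matches `flat`'s layout, so `order`
# indexes the replicated rows directly.
new_sorted = jnp.take(jnp.repeat(inputs_2d, k, axis=0), order, axis=0)

assert jnp.array_equal(old_sorted, new_sorted)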