Commit fd2ff2b

Fix token sorting with custom vjp in moe.py
PiperOrigin-RevId: 808698648
1 parent cbd599f commit fd2ff2b

1 file changed (+8 −8 lines changed)

src/MaxText/layers/moe.py

Lines changed: 8 additions & 8 deletions
@@ -53,11 +53,14 @@ def _sort_activations(
 ) -> jax.Array:
   """Sort activations by `sort_indices`.
 
-  If `use_custom_vjp` is True, then we use a custom backward pass that
+  If `use_custom_vjp=True`, then we use a custom backward pass that
   reverses the sort order. Specifically, this unsort operation is simply a sort
   with `jnp.argsort(sort_indices)` as the sort indices. This is only needed in
   the case where the compiler generates a less efficient backward pass op.
-
+
+  Note that `use_custom_vjp=True` assumes that `sort_indices` is a permutation
+  of `jnp.arange(inputs.shape[0])`.
+
   Args:
     inputs: `(tokens, ...)`-shaped array of input activations to sort.
     sort_indices: `(tokens,)`-shaped array containing the sort order.
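
For readers following the docstring change above, the pattern it refers to, a sort whose backward pass is an unsort, can be sketched with jax.custom_vjp roughly as follows. This is an illustrative sketch under stated assumptions, not the moe.py implementation; the name sort_rows and the exact residual handling are assumptions:

import jax
import jax.numpy as jnp
import numpy as np
from jax import dtypes

# Hypothetical sketch of a sort with a hand-written backward pass.
@jax.custom_vjp
def sort_rows(inputs, sort_indices):
  # Forward: gather rows of `inputs` in the order given by `sort_indices`.
  return jnp.take(inputs, sort_indices, axis=0)

def sort_rows_fwd(inputs, sort_indices):
  # Save the sort order as the residual needed by the backward pass.
  return jnp.take(inputs, sort_indices, axis=0), sort_indices

def sort_rows_bwd(sort_indices, grad_out):
  # Unsorting is itself a sort: permute the cotangent back to the original row
  # order with `jnp.argsort(sort_indices)`. This is only an inverse when
  # `sort_indices` is a permutation of `jnp.arange(inputs.shape[0])`.
  grad_inputs = jnp.take(grad_out, jnp.argsort(sort_indices), axis=0)
  # Integer indices carry no gradient; float0 is the cotangent dtype JAX
  # expects for integer-valued primal arguments.
  grad_indices = np.zeros(sort_indices.shape, dtype=dtypes.float0)
  return grad_inputs, grad_indices

sort_rows.defvjp(sort_rows_fwd, sort_rows_bwd)

# Usage: gradients flow back to the original (unsorted) row positions.
x = jnp.arange(12.0).reshape(4, 3)
idx = jnp.array([2, 0, 3, 1])  # a permutation of arange(4)
grads = jax.grad(lambda a: sort_rows(a, idx).sum())(x)

Under jax.grad, the cotangent flowing back to inputs is just the output cotangent gathered with the inverse permutation, which is why sort_indices must be a permutation of the row indices, exactly the assumption the new docstring note spells out.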
@@ -511,7 +514,7 @@ def apply_ffn_activation(self, layer_w0, layer_w1):
     layer_act = self.activation_fn(layer_w0)
     intermediate_layer = jnp.multiply(layer_act, layer_w1)
     return intermediate_layer.astype(self.dtype)
-
+
   def permute(self, inputs, gate_logits, pre_bias_logits, use_custom_sort_vjp=True, rngs=None):
     """Permute tokens to group by expert to fit gmm call."""
     # reshape inputs (batch, sequence, emb) to (batch * sequence, emb)
@@ -528,12 +531,9 @@ def permute(self, inputs, gate_logits, pre_bias_logits, use_custom_sort_vjp=True
 
     flatten_selected_experts = jnp.ravel(selected_experts)
     sorted_selected_experts = jnp.argsort(flatten_selected_experts)
-    sorted_indices = sorted_selected_experts // self.num_experts_per_tok
     # sort inputs for number of selected experts
-    replicated_inputs_2d = jnp.reshape(
-        jnp.broadcast_to(inputs_2d[None, ...], (self.num_experts_per_tok, *inputs_2d.shape)),
-        (self.num_experts_per_tok * inputs_2d.shape[0], inputs_2d.shape[1]))
-    sorted_inputs = _sort_activations(replicated_inputs_2d, sorted_indices, use_custom_sort_vjp).astype(self.dtype)
+    replicated_inputs_2d = jnp.repeat(inputs_2d, self.num_experts_per_tok, axis=0)
+    sorted_inputs = _sort_activations(replicated_inputs_2d, sorted_selected_experts, use_custom_sort_vjp).astype(self.dtype)
     group_size = jnp.bincount(flatten_selected_experts, length=self.num_experts)
     # Return the experts for each sorted input.
     expert_indices = jnp.arange(self.num_experts)
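
Context on why the replication change matters (a reading of the diff, not text from the commit): jnp.ravel(selected_experts) flattens the (tokens, num_experts_per_tok) routing array row by row, so flat position i belongs to token i // num_experts_per_tok. jnp.repeat(inputs_2d, num_experts_per_tok, axis=0) lays out the replicated activations in exactly that order, which lets sorted_selected_experts be passed to _sort_activations directly; being an argsort output, it is a permutation of jnp.arange(tokens * num_experts_per_tok), satisfying the assumption the new docstring documents for use_custom_vjp=True. The previous broadcast_to + reshape produced whole concatenated copies of the token sequence and needed the separate sorted_indices gather, whose indices repeat and therefore do not form a permutation. A small self-contained sketch of the two layouts (routing values below are made up for illustration):

import jax.numpy as jnp

tokens, k, emb = 3, 2, 4
inputs_2d = jnp.arange(tokens * emb, dtype=jnp.float32).reshape(tokens, emb)

# Layout after this commit: each token's row repeated k times in place,
# matching the row-major flattening of `selected_experts`.
repeated = jnp.repeat(inputs_2d, k, axis=0)   # rows: t0, t0, t1, t1, t2, t2

# Layout before this commit: k whole copies of the token sequence.
broadcast = jnp.reshape(
    jnp.broadcast_to(inputs_2d[None, ...], (k, tokens, emb)),
    (k * tokens, emb))                        # rows: t0, t1, t2, t0, t1, t2

selected_experts = jnp.array([[1, 0], [0, 1], [1, 1]])   # (tokens, k), made-up routing
flatten_selected_experts = jnp.ravel(selected_experts)   # expert id per row of `repeated`
sorted_selected_experts = jnp.argsort(flatten_selected_experts)
# A permutation of arange(tokens * k), so it is a valid sort order for a
# custom-vjp sort of `repeated`.
sorted_inputs = jnp.take(repeated, sorted_selected_experts, axis=0)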
