Commit d4495e1

Fixes accuracy issues with the backward pass of the ring-of-experts technique
PiperOrigin-RevId: 808794168
Parent: fd2ff2b

File tree (1 file changed)

src/MaxText/layers/moe.py

Lines changed: 10 additions & 17 deletions
@@ -515,7 +515,7 @@ def apply_ffn_activation(self, layer_w0, layer_w1):
     intermediate_layer = jnp.multiply(layer_act, layer_w1)
     return intermediate_layer.astype(self.dtype)
 
-  def permute(self, inputs, gate_logits, pre_bias_logits, use_custom_sort_vjp=True, rngs=None):
+  def permute(self, inputs, gate_logits, pre_bias_logits, use_custom_sort_vjp=True, rngs=None, roll_to_expert_id=None):
     """Permute tokens to group by expert to fit gmm call."""
     # reshape inputs (batch, sequence, emb) to (batch * sequence, emb)
     inputs_shape = inputs.shape
@@ -530,6 +530,8 @@ def permute(self, inputs, gate_logits, pre_bias_logits, use_custom_sort_vjp=True
       inputs_2d = inputs_2d * router_scores.reshape(bsz_times_seq_len, -1)
 
     flatten_selected_experts = jnp.ravel(selected_experts)
+    if roll_to_expert_id is not None:
+      flatten_selected_experts = (flatten_selected_experts - roll_to_expert_id) % self.num_experts
     sorted_selected_experts = jnp.argsort(flatten_selected_experts)
     # sort inputs for number of selected experts
     replicated_inputs_2d = jnp.repeat(inputs_2d, self.num_experts_per_tok, axis=0)
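
For orientation, here is a minimal sketch (toy values, not from the commit) of what the new roll_to_expert_id re-basing accomplishes: subtracting the first expert id owned by the current shard, modulo the total expert count, rotates the ids so that the shard's own experts sort to the front of the permuted token array. It assumes 8 experts, 4-way expert parallelism, and a contiguous expert-to-shard assignment.

import jax.numpy as jnp

num_experts = 8               # assumed total number of experts
num_experts_per_shard = 2     # assumed 4-way expert parallelism
expert_shard_id = 1           # the shard owning experts 2 and 3
roll_to_expert_id = num_experts_per_shard * expert_shard_id    # = 2

selected_experts = jnp.array([5, 2, 0, 3, 7, 2])                # per-token routing choices (toy)
rolled = (selected_experts - roll_to_expert_id) % num_experts   # [3, 0, 6, 1, 5, 0]

order = jnp.argsort(rolled)
print(selected_experts[order])   # [2 2 3 5 7 0]: this shard's experts 2 and 3 come first
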
@@ -942,23 +944,17 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
         )
 
         # "Route" tokens within each shard.
-        x, sorted_selected_experts, weights, full_group_sizes, selected_experts = self.permute(
-            x, logits, pre_bias_logits, self.config.use_custom_sort_vjp)
+        num_experts_per_shard = self.config.num_experts // num_expert_parallelism
+        x, sorted_selected_experts, weights, group_sizes, selected_experts = self.permute(
+            x, logits, pre_bias_logits, self.config.use_custom_sort_vjp,
+            roll_to_expert_id=num_experts_per_shard * expert_shard_id)
 
         # Filter down to the group sizes that apply to only the experts in the
         # current shard.
-        full_group_sizes = jnp.reshape(full_group_sizes, (num_expert_parallelism, -1))
-        group_sizes = full_group_sizes[expert_shard_id]
-
-        # Move the tokens for the experts in the current shard to the start of
-        # the inputs array.
-        num_tokens_to_skip = (
-            jnp.cumsum(jnp.sum(full_group_sizes, axis=-1))[expert_shard_id] -
-            jnp.sum(full_group_sizes[expert_shard_id])
-        )
-        x = jnp.roll(x, shift=-num_tokens_to_skip, axis=0)
+        group_sizes = group_sizes[:num_experts_per_shard]
+        mask = jnp.arange(x.shape[0]) < jnp.sum(group_sizes)
+        x = jnp.where(mask[:, None], x, 0)
       else:
-        num_tokens_to_skip = None
         x, sorted_selected_experts, weights, group_sizes, selected_experts = self.permute(
            x, logits, pre_bias_logits, self.config.use_custom_sort_vjp, rngs)
 
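Continuing the toy numbers above, and assuming permute() now returns group_sizes in the rolled expert order (which the replaced reshape/cumsum bookkeeping suggests), each shard can take just the leading num_experts_per_shard counts and zero out the trailing rows instead of shifting the whole array with jnp.roll. A rough sketch under those assumptions:

import jax.numpy as jnp

num_experts_per_shard = 2
# per-expert token counts in rolled order, as permute() would return them (toy values)
group_sizes = jnp.array([2, 1, 0, 1, 0, 1, 1, 0])

local_group_sizes = group_sizes[:num_experts_per_shard]      # [2, 1] -> counts for experts 2 and 3
x = jnp.ones((6, 4))                                         # tokens already sorted by rolled expert id

mask = jnp.arange(x.shape[0]) < jnp.sum(local_group_sizes)   # first 3 rows belong to this shard
x_local = jnp.where(mask[:, None], x, 0)                     # remote-expert rows are zeroed, not shifted
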
@@ -1056,9 +1052,6 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
         mask = jnp.arange(intermediate_output.shape[0]) < jnp.sum(group_sizes)
         intermediate_output = jnp.where(mask[:, None], intermediate_output, 0)
 
-        # Move the tokens back to their original positions.
-        intermediate_output = jnp.roll(intermediate_output, shift=num_tokens_to_skip, axis=0)
-
         # Unsort and deduplicate the outputs locally.
         output = self.unpermute(
             intermediate_output,
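
The roll-back before unpermute() can be dropped because the token rows are never shifted in the new path, so inverting the argsort produced by permute() is enough to put contributions back in place. A toy illustration of that inversion (a sketch, not the MaxText unpermute itself):

import jax.numpy as jnp

rolled = jnp.array([3, 0, 6, 1, 5, 0])
sorted_selected_experts = jnp.argsort(rolled)        # permutation applied by permute()

sorted_vals = rolled[sorted_selected_experts]        # grouped-by-expert order
restored = jnp.zeros_like(rolled).at[sorted_selected_experts].set(sorted_vals)
print((restored == rolled).all())                    # True: the sort is undone without any jnp.roll
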
