
Conversation


@suexu1025 commented on Sep 19, 2025

Description

Land the sharding strategy for the MoE layer.

  • DeepSeek-V3 (dsv3) step time decreases from 47s to 43s.
  • No change for the Mixtral 8x7B model.



Tests

Please describe how you tested this change, and include any instructions and/or
commands to reproduce.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

@suexu1025 changed the title from "Better Sharding for dsv3 moe layer" to "Better sharding for dsv3 moe layer" on Sep 19, 2025

@gobbleturk left a comment


@richjames0 for thoughts =D

@@ -300,8 +300,13 @@ def __init__(
    self.quant = quant
    self.rngs = rngs

    self.wi_kernel_axes = ("exp", "embed_no_exp", "mlp")
    self.wo_kernel_axes = ("exp", "mlp", "embed_no_exp")
    # special sharding for dsv3
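For context, a minimal sketch of how logical kernel axes like the ones in this hunk are typically attached to a parameter in Flax. This is illustrative code only, not MaxText's actual module; the class name, shapes, and einsum are invented for the example.

import jax.numpy as jnp
import flax.linen as nn


class TinyExpertMLP(nn.Module):
  # Illustrative only: attach the quoted logical axis names
  # ("exp", "embed_no_exp", "mlp") to an expert weight at creation time.
  num_experts: int
  emb: int
  mlp: int

  @nn.compact
  def __call__(self, x):  # x: [num_experts, tokens, emb]
    wi = self.param(
        "wi",
        nn.with_logical_partitioning(
            nn.initializers.lecun_normal(), ("exp", "embed_no_exp", "mlp")),
        (self.num_experts, self.emb, self.mlp),
        jnp.float32,
    )
    # Per-expert projection into the MLP hidden dimension.
    return jnp.einsum("etd,edm->etm", x, wi)

The logical names only become a concrete device layout once they are resolved against the mesh through logical-axis rules, which is why changing these tuples changes the sharding without touching the math.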


@richjames0 for all sharding changes =D

    self.wi_kernel_axes = ("exp", "embed_no_exp", "mlp")
    self.wo_kernel_axes = ("exp", "mlp", "embed_no_exp")
    # special sharding for dsv3
    if self.config.num_experts == 256:


We need better logic than such a specific conditional. It should be controllable via a base.yml config argument, something like "expert_first_dim".

It doesn't surprise me that expert-first is the most performant. Are there any downsides to just flipping it here, other than having to modify all of our checkpoint conversion scripts? =D
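A sketch of what that gate could look like, purely illustrative: expert_first_dim is only the name floated in this comment (it is not an existing base.yml key), and the special-case axes are left out because the quoted diff does not show them.

def pick_moe_kernel_axes(config):
  # Hypothetical helper: gate the layout on a config flag instead of a
  # hard-coded expert count.
  wi_axes = ("exp", "embed_no_exp", "mlp")
  wo_axes = ("exp", "mlp", "embed_no_exp")
  if getattr(config, "expert_first_dim", False):
    # Expert-first special case; the concrete axes are not visible in the
    # quoted hunk, so this branch is left as a placeholder.
    pass
  return wi_axes, wo_axes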


@gobbleturk commented on Sep 19, 2025


If modifying the checkpoint conversion scripts is really the hardest part, we can support both options (with a base.yml config) for a while, with a deprecation warning on the current behavior so users migrate to the new one.
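A sketch of the migration path being described, with an invented option name (moe_kernel_layout) standing in for whatever base.yml key would actually be added:

import warnings


def maybe_warn_legacy_moe_layout(config):
  # Hypothetical: keep both layouts selectable for a while and warn on the
  # legacy one so users have time to re-convert checkpoints.
  if getattr(config, "moe_kernel_layout", "legacy") == "legacy":
    warnings.warn(
        "The legacy MoE kernel layout is deprecated; set "
        "moe_kernel_layout='expert_first' and re-run checkpoint conversion.",
        DeprecationWarning,
    )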


@RissyRan left a comment


Thanks Qinwen! Could you also attach test results in the description?

w1_pspec = nn.logical_to_mesh_axes(("exp", "embed_tensor_transpose", "mlp_no_fsdp"))
wo_pspec = nn.logical_to_mesh_axes(("exp", "mlp_no_fsdp", "embed_tensor_transpose"))
# special sharding for dsv3 to remove overhead between gmm/AG
if self.config.num_experts == 256:
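For readers unfamiliar with the call in the quoted lines: nn.logical_to_mesh_axes resolves a tuple of logical axis names into a jax.sharding.PartitionSpec using the logical-axis rules in scope. A toy example with made-up rules (MaxText defines its real mapping under logical_axis_rules in base.yml):

import flax.linen as nn

example_rules = (
    ("exp", "expert"),
    ("embed_tensor_transpose", "fsdp"),
    ("mlp_no_fsdp", "tensor"),
)

w1_pspec = nn.logical_to_mesh_axes(
    ("exp", "embed_tensor_transpose", "mlp_no_fsdp"), rules=example_rules)
print(w1_pspec)  # PartitionSpec('expert', 'fsdp', 'tensor')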

Similar comment to Matt's. For instance, we could have a flag similar to expert_shard_attention_option, e.g. expert_shard_mlp_option or similar.

The condition self.config.num_experts == 256 may raise questions: what about 128 or 512 experts?
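A sketch of the option-style gate being suggested; the name expert_shard_mlp_option and its values are hypothetical, modeled on the existing expert_shard_attention_option mentioned above.

def validate_expert_shard_mlp_option(config):
  # Hypothetical: pick the MLP sharding layout from an explicit option rather
  # than matching a specific expert count such as 256.
  valid = ("default", "expert_first")  # invented values
  option = getattr(config, "expert_shard_mlp_option", "default")
  if option not in valid:
    raise ValueError(
        f"expert_shard_mlp_option must be one of {valid}, got {option!r}")
  return option

An explicit option would behave the same for 128, 256, or 512 experts and make the intended layout visible in the run config.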
