Better sharding for dsv3 moe layer #2373
base: main
Conversation
@richjames0 for thoughts =D
@@ -300,8 +300,13 @@ def __init__(
self.quant = quant
self.rngs = rngs

self.wi_kernel_axes = ("exp", "embed_no_exp", "mlp")
self.wo_kernel_axes = ("exp", "mlp", "embed_no_exp")
# special sharding for dsv3
@richjames0 for all sharding changes =D
self.wi_kernel_axes = ("exp", "embed_no_exp", "mlp")
self.wo_kernel_axes = ("exp", "mlp", "embed_no_exp")
# special sharding for dsv3
if self.config.num_experts == 256: |
We need better logic than such a specific conditional. It should be controllable via a base.yml config argument, something like `expert_first_dim`.
It doesn't surprise me that the expert-first layout is the most performant. Are there any downsides to just flipping it here, other than having to modify all of our checkpoint conversion scripts? =D
If modifying the checkpoint conversion scripts is really the hardest part, we can support both options (via a base.yml config) for a while, with a deprecation warning on the current behavior to encourage migration to the new one.
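A minimal sketch of that migration path, assuming a hypothetical `moe_kernel_layout` key in base.yml (the name and values are illustrative, not part of this PR):

```python
import warnings

def resolve_moe_kernel_layout(config):
  """Pick the MoE kernel layout, warning when the legacy layout is requested."""
  # `moe_kernel_layout` is a hypothetical config key used only for this sketch.
  layout = getattr(config, "moe_kernel_layout", "legacy")
  if layout == "legacy":
    warnings.warn(
        "The legacy MoE kernel layout is deprecated; switch to "
        "moe_kernel_layout='expert_first' and reconvert checkpoints.",
        DeprecationWarning,
    )
  return layout
```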
Thanks Qinwen! Could you also attach test results in the description?
w1_pspec = nn.logical_to_mesh_axes(("exp", "embed_tensor_transpose", "mlp_no_fsdp"))
wo_pspec = nn.logical_to_mesh_axes(("exp", "mlp_no_fsdp", "embed_tensor_transpose"))
# special sharding for dsv3 to remove overhead between gmm/AG
if self.config.num_experts == 256: |
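For context, `nn.logical_to_mesh_axes` resolves logical axis names into a `jax.sharding.PartitionSpec` using whatever logical-to-mesh rules are active. A minimal sketch with made-up rules (MaxText derives its real rules from config):

```python
import flax.linen as nn

# Illustrative rules only: each logical axis name maps to a mesh axis,
# or to None for an unsharded dimension.
rules = (
    ("exp", "expert"),
    ("mlp_no_fsdp", "tensor"),
    ("embed_tensor_transpose", None),
)

with nn.logical_axis_rules(rules):
  w1_pspec = nn.logical_to_mesh_axes(
      ("exp", "embed_tensor_transpose", "mlp_no_fsdp"))
  # w1_pspec == PartitionSpec('expert', None, 'tensor')
```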
Similar comment to Matt's. For instance, we could have a flag analogous to `expert_shard_attention_option`, say `expert_shard_mlp_option` or similar. The condition `self.config.num_experts == 256` raises the question: what about 128 or 512 experts?
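A minimal sketch of the suggested switch, assuming a hypothetical `expert_shard_mlp_option` key modeled on the existing `expert_shard_attention_option`; the dsv3-specific branch body is not reproduced from the PR:

```python
def select_moe_kernel_axes(config):
  """Sketch: choose wi/wo logical kernel axes from a config option
  instead of hard-coding `num_experts == 256`."""
  # Default logical axis ordering, as shown in the diff above.
  wi_kernel_axes = ("exp", "embed_no_exp", "mlp")
  wo_kernel_axes = ("exp", "mlp", "embed_no_exp")
  # `expert_shard_mlp_option` is a hypothetical base.yml key for this sketch.
  if getattr(config, "expert_shard_mlp_option", "default") == "dsv3":
    # Apply the dsv3-specific layout here (not reproduced from the PR).
    pass
  return wi_kernel_axes, wo_kernel_axes
```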
Description
Land a better sharding strategy for the MoE layer.
If the change fixes a bug or a GitHub issue, please include a link, e.g.,:
FIXES: b/123456
FIXES: #123456
Notice 1: Once all tests pass, the "pull ready" label will automatically be assigned.
This label is used for administrative purposes. Please do not add it manually.
Notice 2: For external contributions, our settings currently require an approval from a MaxText maintainer to trigger CI tests.
Tests
Please describe how you tested this change, and include any instructions and/or
commands to reproduce.
Checklist
Before submitting this PR, please make sure (put X in square brackets):