Add conversion script for Qwen3 Next and Readme #2672
Conversation
Commits:
- add debug statements
- Conversion script ran without failing
- test verify orbax hf tensors
- Add unscanned conversion script for qwen3 next
- Move gating op to after sharding optimizations
- added zero centered rmsnorm
- Add layer by layer comparison script
- Remove debug files
src/MaxText/convert_qwen3_next.py
Outdated
```python
# llama_or_mistral_ckpt.save_weights_to_checkpoint(
#     args.maxtext_model_path,
#     jax_weights,
#     args.simulated_cpu_devices_count,
#     args.use_ocdbt,
#     args.use_zarr3,
# )
# max_logging.log("Checkpoint saved successfully.")
```
Remove unused code?
src/MaxText/convert_qwen3_next.py
Outdated
```python
model_params = MODEL_PARAMS_DICT[args.model_size]
max_logging.log(f"Starting conversion for Qwen3-Next model size: {args.model_size}")
# jax_weights = convert_hf_to_maxtext(args.base_model_path, model_params)
```
Is this not required?
src/MaxText/convert_qwen3_next.py
Outdated
```python
}


def verify_conversion(maxtext_weights: Dict[str, Any], chkpt_vars: Dict[str, torch.Tensor], model_params: Dict[str, Any]):
```
Can you elaborate on this?
🤖 Hi @Rohan-Bierneni, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
📋 Review Summary
This Pull Request introduces comprehensive support for the Qwen3-Next model, including both scanned and unscanned checkpoint conversion scripts. The integration of heterogeneous layers and the new configuration validation are positive additions, demonstrating a thoughtful approach to supporting this new model.
🔍 General Feedback
- The overall structure for Qwen3-Next integration appears well-designed, particularly the handling of alternating Gated Delta Net and Gated Attention layers (see the sketch after this list).
- The addition of configuration validation for `gdn_num_value_heads` is a good practice.
- There are a few areas identified for potential improvement in terms of code clarity, naming conventions, and a critical logic change in the attention mechanism that warrants further review and verification.
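For readers unfamiliar with the architecture, the sketch below illustrates the alternating-layer idea in plain Python. It is not MaxText code; the 3:1 ratio of Gated DeltaNet to Gated Attention layers is an assumption based on the published Qwen3-Next configuration and may not match the schedule implemented in this PR.

```python
def layer_type(layer_idx: int, full_attention_interval: int = 4) -> str:
  """Returns the block type for a decoder layer under an alternating schedule.

  Every `full_attention_interval`-th layer uses full gated attention; all other
  layers use the linear-attention Gated DeltaNet block.
  """
  return "gated_attention" if (layer_idx + 1) % full_attention_interval == 0 else "gated_delta_net"
```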
```python
if self.is_qwen3_next:
  out = out.reshape(batch_size, seq_len, self.config.num_query_heads * self.config.head_dim)
  out = out * jax.nn.sigmoid(gate)
out = self.out_projection(out, out_sharding=out_sharding)
```
🟠 High - The if self.is_qwen3_next: block has been moved to after the sharding logic. This means the reshaping and sigmoid gating for Qwen3-Next will now occur after the output has potentially been sharded. This could lead to incorrect behavior if the sharding expects a different shape or if the reshape/gating needs to happen before sharding. Please verify if this change is intentional and correct, or if the block should remain before the sharding logic.
Suggested change:
```python
out = self.out_projection(out, out_sharding=out_sharding)
if self.is_qwen3_next:
  out = out.reshape(batch_size, seq_len, self.config.num_query_heads * self.config.head_dim)
  out = out * jax.nn.sigmoid(gate)
if model_mode == MODEL_MODE_PREFILL:
  out = self._maybe_shard_with_logical(out, self.prefill_out_axis_names)
elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
  out = self._maybe_shard_with_logical(out, self.out_axis_names)
else:
  out = self._maybe_shard_with_logical(out, self.decode_out_axis_names)
```
```diff
 # limitations under the License.


-""""Module for decoder layers."""
+""" "Module for decoder layers."""
```
🟢 Low - Minor formatting: remove the extra space after the opening triple quotes in the docstring.
| """ "Module for decoder layers.""" | |
| """Module for decoder layers.""" |
```python
nn.logical_to_mesh_axes(
    (
        "activation_batch",
        "activation_length_no_exp",
```
🟢 Low - The removal of extra parentheses around the tuple in nn.logical_to_mesh_axes is a stylistic improvement for readability.
| "activation_length_no_exp", | |
| nn.logical_to_mesh_axes( | |
| ( | |
| "activation_batch", | |
| "activation_length_no_exp", | |
| "activation_embed", | |
| ) | |
| ), |
| """Applies final normalization and projects hidden states to logits.""" | ||
|
|
||
| cfg = self.config | ||
| if cfg.shard_mode == ShardMode.EXPLICIT: |
🟢 Low - The removal of extra parentheses around the tuple in nn.logical_to_mesh_axes is a stylistic improvement for readability.
Suggested change:
```python
nn.logical_to_mesh_axes((
    "activation_batch",
    "activation_length_no_exp",
    "activation_embed",
)),
```
```python
nn.logical_to_mesh_axes(
    (
        "activation_embed_and_logits_batch",
        "activation_length_no_exp",
```
🟢 Low - The removal of extra parentheses around the tuple in nn.logical_to_mesh_axes is a stylistic improvement for readability.
| "activation_length_no_exp", | |
| nn.logical_to_mesh_axes( | |
| ( | |
| "activation_embed_and_logits_batch", | |
| "activation_length_no_exp", | |
| "activation_vocab", | |
| ) | |
| ), |
```python
  out_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes((None, None, "activation_vocab")))
else:
  out_sharding = NamedSharding(
      self.mesh,
```
🟢 Low - The removal of extra parentheses around the tuple in nn.logical_to_mesh_axes is a stylistic improvement for readability.
Suggested change:
```python
nn.logical_to_mesh_axes((
    "activation_embed_and_logits_batch",
    "activation_length_no_exp",
    "activation_vocab",
)),
```
```python
)


def __call__(self, hidden_states: Array) -> Array:
  # hidden_states: (B, S, E)
```
🟡 Medium - The reshaping and splitting logic for qkvz and ba in the __call__ method is quite complex. While the comments are helpful, consider encapsulating some of this logic into smaller, well-named helper functions. This could improve readability, maintainability, and potentially reusability if similar patterns are used elsewhere.
Suggested change:
```python
# STEP A: Input Projections
# hidden_states: (B, S, E)
qkvz = self.in_proj_qkvz(hidden_states)
ba = self.in_proj_ba(hidden_states)
query, key, value, z, b, a = self._split_and_reshape_qkvz_ba(batch, seq_len, qkvz, ba)
# Flatten head dimensions for concatenation before conv
q = query.reshape(batch, seq_len, -1)
k = key.reshape(batch, seq_len, -1)
v = value.reshape(batch, seq_len, -1)
```
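To make the suggestion above concrete, here is a rough sketch of what such a helper could look like. The attribute names (`num_k_heads`, `num_v_heads`, `head_k_dim`, `head_v_dim`) and the simple last-axis split are illustrative assumptions; the real fused projection in Qwen3-Next may interleave q/k/v/z per head group, so the actual split would need to follow that layout.

```python
import jax.numpy as jnp

def _split_and_reshape_qkvz_ba(self, batch, seq_len, qkvz, ba):
  """Splits the fused qkvz and ba projections into per-head tensors (illustrative sketch)."""
  key_dim = self.num_k_heads * self.head_k_dim
  value_dim = self.num_v_heads * self.head_v_dim
  # Assume qkvz concatenates query, key, value, and the gating branch z on the last axis.
  query, key, value, z = jnp.split(qkvz, [key_dim, 2 * key_dim, 2 * key_dim + value_dim], axis=-1)
  # Assume ba concatenates the per-head scalars b and a on the last axis.
  b, a = jnp.split(ba, [self.num_v_heads], axis=-1)
  query = query.reshape(batch, seq_len, self.num_k_heads, self.head_k_dim)
  key = key.reshape(batch, seq_len, self.num_k_heads, self.head_k_dim)
  value = value.reshape(batch, seq_len, self.num_v_heads, self.head_v_dim)
  z = z.reshape(batch, seq_len, self.num_v_heads, self.head_v_dim)
  return query, key, value, z, b, a
```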
```python
    keys: the raw config in dict form
  """
  if keys["sparse_matmul"]:
```
🟡 Medium - The sparse_matmul check for Qwen3-Next has been removed. Please clarify if sparse_matmul is now supported for Qwen3-Next, or if the dense path is always intended for this model. If it's the latter, a comment explaining this decision would be beneficial for future maintainers.
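If the dense path is indeed the only supported one, the kind of guard being asked about might look roughly like the sketch below. The `decoder_block` key name and the error message are hypothetical placeholders, not the actual MaxText config fields.

```python
def validate_sparse_matmul(keys: dict) -> None:
  """Rejects sparse_matmul for models that only implement the dense MoE path."""
  # "decoder_block" is an assumed key name, used here purely for illustration.
  if keys.get("sparse_matmul") and keys.get("decoder_block") == "qwen3_next":
    raise ValueError("sparse_matmul is not supported for Qwen3-Next; use the dense matmul path.")
```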
```python


def create_scanned_layer_pytree(layer_idx) -> Dict[str, Any]:
  """Creates the nested dictionary for one scanned layer."""
```
🟡 Medium - The function create_scanned_layer_pytree is used in the unscanned conversion script. This name is misleading as it suggests it's creating a structure for scanned layers. Please rename it to create_unscanned_layer_pytree or a more general name like create_layer_pytree_structure to accurately reflect its purpose in this context.
Description
This PR is the final PR needed to fully support the Qwen3-Next model in MaxText. It adds conversion scripts from the HuggingFace checkpoint and verifies that the logits match between the HF and MaxText models.
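For context, a logit check of this kind typically runs both models on the same prompt and fails on a large deviation. The sketch below is illustrative only and is not the checker used in the linked tests; the tolerance value is an assumption.

```python
import numpy as np

def compare_logits(hf_logits: np.ndarray, maxtext_logits: np.ndarray, atol: float = 1e-2) -> float:
  """Returns the max absolute logit difference and fails if it exceeds atol."""
  max_diff = float(np.max(np.abs(hf_logits - maxtext_logits)))
  print(f"max |hf - maxtext| logit difference: {max_diff:.6f}")
  if max_diff > atol:
    raise AssertionError(f"Logit mismatch: {max_diff} > {atol}")
  return max_diff
```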
If the change fixes a bug or a Github issue, please include a link, e.g.,:
FIXES: b/123456
Tests
(Unscanned) Forward pass logit checker: https://paste.googleplex.com/6146326802857984
(Scanned) Forward pass logit checker: https://paste.googleplex.com/6195553369194496
Checklist
Before submitting this PR, please make sure (put X in square brackets):
`gemini-review` label.