
Conversation

@Rohan-Bierneni (Collaborator) commented Nov 12, 2025

Description

This PR is the final one needed to fully support the Qwen3 Next model in MaxText. It adds the conversion scripts from Hugging Face and verifies that the logits match between the HF and MaxText models (a sketch of this comparison is included below for context).

If the change fixes a bug or a Github issue, please include a link, e.g.,:
FIXES: b/123456
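
For context, here is a minimal sketch of the kind of forward-pass logits comparison these checkers perform. The helper names (run_hf_forward, run_maxtext_forward) and the tolerance are illustrative assumptions, not the actual MaxText forward_pass_logit_checker implementation:

import numpy as np

def compare_logits(prompt_ids, atol=1e-2):
  # Hypothetical helpers: run the same token ids through the HF reference
  # model and the converted MaxText model, each returning (seq_len, vocab) logits.
  hf_logits = run_hf_forward(prompt_ids)
  mt_logits = run_maxtext_forward(prompt_ids)
  max_diff = np.max(np.abs(hf_logits - mt_logits))
  top1_match = np.mean(np.argmax(hf_logits, -1) == np.argmax(mt_logits, -1))
  print(f"max |logit diff| = {max_diff:.4f}, top-1 agreement = {top1_match:.2%}")
  assert top1_match == 1.0 or max_diff < atol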

Tests

(Unscanned) Forward pass logit checker: https://paste.googleplex.com/6146326802857984

(Scanned) Forward pass logit checker: https://paste.googleplex.com/6195553369194496

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

add debug statements

Conversion script ran without failing

test verify orbax hf tensors

Add unscanned conversion script for qwen3 next

Move gating op to after sharding optimizations

added zero centered rmsnorm

Add layer-by-layer comparison script

Remove debug files
Comment on lines 676 to 683
# llama_or_mistral_ckpt.save_weights_to_checkpoint(
# args.maxtext_model_path,
# jax_weights,
# args.simulated_cpu_devices_count,
# args.use_ocdbt,
# args.use_zarr3,
# )
# max_logging.log("Checkpoint saved successfully.")
Collaborator

Remove unused code?


model_params = MODEL_PARAMS_DICT[args.model_size]
max_logging.log(f"Starting conversion for Qwen3-Next model size: {args.model_size}")
# jax_weights = convert_hf_to_maxtext(args.base_model_path, model_params)
Collaborator

Is this not required?

}


def verify_conversion(maxtext_weights: Dict[str, Any], chkpt_vars: Dict[str, torch.Tensor], model_params: Dict[str, Any]):
@parambole (Collaborator) commented Nov 12, 2025

Can you elaborate on this?

@github-actions

🤖 Hi @Rohan-Bierneni, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions bot left a comment

📋 Review Summary

This Pull Request introduces comprehensive support for the Qwen3-Next model, including both scanned and unscanned checkpoint conversion scripts. The integration of heterogeneous layers and the new configuration validation are positive additions, demonstrating a thoughtful approach to supporting this new model.

🔍 General Feedback

  • The overall structure for Qwen3-Next integration appears well-designed, particularly the handling of alternating Gated Delta Net and Gated Attention layers.
  • The addition of configuration validation for gdn_num_value_heads is a good practice (a sketch of such a check is shown after this list).
  • There are a few areas identified for potential improvement in terms of code clarity, naming conventions, and a critical logic change in the attention mechanism that warrants further review and verification.
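
A minimal sketch of the kind of gdn_num_value_heads validation described above; the field names and the divisibility constraint are assumptions and may differ from the actual check added in this PR:

def validate_gdn_heads(gdn_num_value_heads, gdn_num_key_heads):
  # Assumed constraint: value heads are grouped per key head, so the value
  # head count must be an exact multiple of the key head count.
  if gdn_num_value_heads % gdn_num_key_heads != 0:
    raise ValueError(
        f"gdn_num_value_heads ({gdn_num_value_heads}) must be divisible by "
        f"gdn_num_key_heads ({gdn_num_key_heads})."
    )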

if self.is_qwen3_next:
out = out.reshape(batch_size, seq_len, self.config.num_query_heads * self.config.head_dim)
out = out * jax.nn.sigmoid(gate)
out = self.out_projection(out, out_sharding=out_sharding)


🟠 High - The if self.is_qwen3_next: block has been moved to after the sharding logic. This means the reshaping and sigmoid gating for Qwen3-Next will now occur after the output has potentially been sharded. This could lead to incorrect behavior if the sharding expects a different shape or if the reshape/gating needs to happen before sharding. Please verify if this change is intentional and correct, or if the block should remain before the sharding logic.

Suggested change
out = self.out_projection(out, out_sharding=out_sharding)
if self.is_qwen3_next:
out = out.reshape(batch_size, seq_len, self.config.num_query_heads * self.config.head_dim)
out = out * jax.nn.sigmoid(gate)
if model_mode == MODEL_MODE_PREFILL:
out = self._maybe_shard_with_logical(out, self.prefill_out_axis_names)
elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
out = self._maybe_shard_with_logical(out, self.out_axis_names)
else:
out = self._maybe_shard_with_logical(out, self.decode_out_axis_names)

# limitations under the License.

""""Module for decoder layers."""
""" "Module for decoder layers."""


🟢 Low - Minor formatting: remove the extra space after the opening triple quotes in the docstring.

Suggested change
""" "Module for decoder layers."""
"""Module for decoder layers."""

nn.logical_to_mesh_axes(
(
"activation_batch",
"activation_length_no_exp",


🟢 Low - The removal of extra parentheses around the tuple in nn.logical_to_mesh_axes is a stylistic improvement for readability.

Suggested change
"activation_length_no_exp",
nn.logical_to_mesh_axes(
(
"activation_batch",
"activation_length_no_exp",
"activation_embed",
)
),

"""Applies final normalization and projects hidden states to logits."""

cfg = self.config
if cfg.shard_mode == ShardMode.EXPLICIT:


🟢 Low - The removal of extra parentheses around the tuple in nn.logical_to_mesh_axes is a stylistic improvement for readability.

Suggested change
if cfg.shard_mode == ShardMode.EXPLICIT:
nn.logical_to_mesh_axes((
"activation_batch",
"activation_length_no_exp",
"activation_embed",
)),

nn.logical_to_mesh_axes(
(
"activation_embed_and_logits_batch",
"activation_length_no_exp",


🟢 Low - The removal of extra parentheses around the tuple in nn.logical_to_mesh_axes is a stylistic improvement for readability.

Suggested change
"activation_length_no_exp",
nn.logical_to_mesh_axes(
(
"activation_embed_and_logits_batch",
"activation_length_no_exp",
"activation_vocab",
)
),

out_sharding = NamedSharding(self.mesh, nn.logical_to_mesh_axes((None, None, "activation_vocab")))
else:
out_sharding = NamedSharding(
self.mesh,


🟢 Low - The removal of extra parentheses around the tuple in nn.logical_to_mesh_axes is a stylistic improvement for readability.

Suggested change
self.mesh,
nn.logical_to_mesh_axes((
"activation_embed_and_logits_batch",
"activation_length_no_exp",
"activation_vocab",
)),

)

def __call__(self, hidden_states: Array) -> Array:
# hidden_states: (B, S, E)


🟡 Medium - The reshaping and splitting logic for qkvz and ba in the __call__ method is quite complex. While the comments are helpful, consider encapsulating some of this logic into smaller, well-named helper functions; a rough sketch of such a helper follows the suggestion below. This could improve readability, maintainability, and potentially reusability if similar patterns are used elsewhere.

Suggested change
# hidden_states: (B, S, E)
# STEP A: Input Projections
# hidden_states: (B, S, E)
qkvz = self.in_proj_qkvz(hidden_states)
ba = self.in_proj_ba(hidden_states)
query, key, value, z, b, a = self._split_and_reshape_qkvz_ba(batch, seq_len, qkvz, ba)
# Flatten head dimensions for concatenation before conv
q = query.reshape(batch, seq_len, -1)
k = key.reshape(batch, seq_len, -1)
v = value.reshape(batch, seq_len, -1)
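
For illustration only, a sketch of what such a helper could look like. The config field names (gdn_num_key_heads, gdn_head_k_dim, gdn_num_value_heads, gdn_head_v_dim) and the flat [Q | K | V | Z] / [B | A] layouts are assumptions; the actual projection layout may interleave these per key head, so this shows only the encapsulation, not the exact split order:

import jax.numpy as jnp

def _split_and_reshape_qkvz_ba(self, batch, seq_len, qkvz, ba):
  """Splits the fused projections into per-head query/key/value/z and b/a."""
  cfg = self.config
  k_dim = cfg.gdn_num_key_heads * cfg.gdn_head_k_dim    # assumed field names
  v_dim = cfg.gdn_num_value_heads * cfg.gdn_head_v_dim  # assumed field names
  # Assumed flat layout: qkvz is (B, S, 2 * k_dim + 2 * v_dim).
  q, k, v, z = jnp.split(qkvz, [k_dim, 2 * k_dim, 2 * k_dim + v_dim], axis=-1)
  # Assumed flat layout: ba is (B, S, 2 * num_value_heads).
  b, a = jnp.split(ba, 2, axis=-1)
  query = q.reshape(batch, seq_len, cfg.gdn_num_key_heads, cfg.gdn_head_k_dim)
  key = k.reshape(batch, seq_len, cfg.gdn_num_key_heads, cfg.gdn_head_k_dim)
  value = v.reshape(batch, seq_len, cfg.gdn_num_value_heads, cfg.gdn_head_v_dim)
  z = z.reshape(batch, seq_len, cfg.gdn_num_value_heads, cfg.gdn_head_v_dim)
  return query, key, value, z, b, a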

keys: the raw config in dict form
"""
if keys["sparse_matmul"]:


🟡 Medium - The sparse_matmul check for Qwen3-Next has been removed. Please clarify if sparse_matmul is now supported for Qwen3-Next, or if the dense path is always intended for this model. If it's the latter, a comment explaining this decision would be beneficial for future maintainers.
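
If the intent is that Qwen3-Next always takes the dense path, an explicit guard along these lines would document that decision. The key names and the restriction itself are assumptions pending the author's clarification:

# Illustrative only; assumes the raw config dict exposes "sparse_matmul"
# and "decoder_block" keys, and that sparse matmul is unsupported here.
if keys["sparse_matmul"] and keys.get("decoder_block") == "qwen3_next":
  raise ValueError(
      "sparse_matmul is not supported for qwen3_next; use the dense MoE path (sparse_matmul=False)."
  )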



def create_scanned_layer_pytree(layer_idx) -> Dict[str, Any]:
"""Creates the nested dictionary for one scanned layer."""


🟡 Medium - The function create_scanned_layer_pytree is used in the unscanned conversion script. This name is misleading as it suggests it's creating a structure for scanned layers. Please rename it to create_unscanned_layer_pytree or a more general name like create_layer_pytree_structure to accurately reflect its purpose in this context.
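
For clarity, the rename could look roughly like this (body elided; only the name and docstring change):

def create_unscanned_layer_pytree(layer_idx) -> Dict[str, Any]:
  """Creates the nested dictionary for one unscanned layer."""
  ...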

