3 changes: 3 additions & 0 deletions src/MaxText/configs/models/qwen3-next-80b-a3b.yml
@@ -46,3 +46,6 @@ gdn_chunk_size: 64
# RoPE Settings
rope_max_timescale: 10000000
partial_rotary_factor: 0.25

# General Model Settings
enable_dropout: False
6 changes: 3 additions & 3 deletions src/MaxText/layers/attentions.py
@@ -1015,9 +1015,6 @@ def __call__(
bidirectional_mask,
self.sinks,
)
if self.is_qwen3_next:
out = out.reshape(batch_size, seq_len, self.config.num_query_heads * self.config.head_dim)
out = out * jax.nn.sigmoid(gate)
if model_mode == MODEL_MODE_PREFILL:
out = self._maybe_shard_with_logical(out, self.prefill_out_axis_names)
elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
@@ -1026,6 +1023,9 @@
out = self._maybe_shard_with_logical(out, self.out_axis_names)
else:
out = self._maybe_shard_with_logical(out, self.decode_out_axis_names)
if self.is_qwen3_next:
out = out.reshape(batch_size, seq_len, self.config.num_query_heads * self.config.head_dim)
out = out * jax.nn.sigmoid(gate)
out = self.out_projection(out, out_sharding=out_sharding)

🟠 High - The if self.is_qwen3_next: block has been moved to after the sharding logic. This means the reshaping and sigmoid gating for Qwen3-Next will now occur after the output has potentially been sharded. This could lead to incorrect behavior if the sharding expects a different shape or if the reshape/gating needs to happen before sharding. Please verify if this change is intentional and correct, or if the block should remain before the sharding logic.

Suggested change
out = self.out_projection(out, out_sharding=out_sharding)
if self.is_qwen3_next:
out = out.reshape(batch_size, seq_len, self.config.num_query_heads * self.config.head_dim)
out = out * jax.nn.sigmoid(gate)
if model_mode == MODEL_MODE_PREFILL:
out = self._maybe_shard_with_logical(out, self.prefill_out_axis_names)
elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
out = self._maybe_shard_with_logical(out, self.out_axis_names)
else:
out = self._maybe_shard_with_logical(out, self.decode_out_axis_names)
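As a follow-up, a small defensive wrapper would make a rank mismatch between the tensor and its logical axis names fail loudly under either ordering. This is a sketch only; `_maybe_shard_with_logical` and the axis-name attributes come from the diff above, while the wrapper name is hypothetical:

```python
def shard_with_rank_check(shard_fn, out, axis_names):
  """Apply a logical sharding constraint, asserting the tensor rank matches the axis names."""
  if out.ndim != len(axis_names):
    raise ValueError(
        f"Logical axis names {axis_names} describe rank {len(axis_names)}, "
        f"but the tensor has rank {out.ndim}; check the reshape/gating ordering."
    )
  return shard_fn(out, axis_names)

# Hypothetical usage inside __call__:
# out = shard_with_rank_check(self._maybe_shard_with_logical, out, self.out_axis_names)
```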

out = checkpoint_name(out, "out_proj")
return out
34 changes: 18 additions & 16 deletions src/MaxText/layers/decoders.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

""""Module for decoder layers."""
""" "Module for decoder layers."""

🟢 Low - Minor formatting: remove the stray space and extra quotation mark after the opening triple quotes of the docstring.

Suggested change
""" "Module for decoder layers."""
"""Module for decoder layers."""

# pylint: disable=arguments-differ
# pylint: disable=no-name-in-module

@@ -26,6 +26,7 @@

from flax import linen as nn
from flax import nnx
from flax.linen import initializers as linen_initializers
from flax.linen.partitioning import ScanIn

from MaxText.common_types import DecoderBlockType, ShardMode, Config, EP_AS_CONTEXT
@@ -34,6 +35,7 @@
from MaxText import max_utils
from MaxText.inference import page_manager
from MaxText.layers import linears
from MaxText.layers import normalizations
from MaxText.layers import quantizations
from MaxText.layers import pipeline
from MaxText import maxtext_utils
@@ -58,6 +60,7 @@
qwen3,
simple_layer,
)
from MaxText.layers import nnx_wrappers

# ------------------------------------------------------------------------------
# The network: Decoder Definitions
@@ -465,7 +468,6 @@ def get_norm_layer(self, num_features: int):
DecoderBlockType.GEMMA3,
DecoderBlockType.QWEN3,
DecoderBlockType.QWEN3_MOE,
DecoderBlockType.QWEN3_NEXT,
DecoderBlockType.GPT_OSS,
DecoderBlockType.SIMPLE,
DecoderBlockType.SIMPLE_MLP,
@@ -474,6 +476,10 @@ def get_norm_layer(self, num_features: int):
return functools.partial(rms_norm, num_features=num_features, shard_mode=self.config.shard_mode)
elif self.config.decoder_block == DecoderBlockType.GPT3:
return functools.partial(gpt3.gpt3_layer_norm, num_features=num_features, reductions_in_fp32=False, use_bias=True)
elif self.config.decoder_block == DecoderBlockType.QWEN3_NEXT:
return functools.partial(
normalizations.Qwen3NextRMSNormLinen, num_features=num_features, shard_mode=self.config.shard_mode
)
else:
raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block.value=}")

@@ -595,13 +601,11 @@ def _apply_output_head(self, shared_embedding: nn.Module | nnx.Module, y, determ
if cfg.shard_mode == ShardMode.EXPLICIT:

🟢 Low - The removal of extra parentheses around the tuple in nn.logical_to_mesh_axes is a stylistic improvement for readability.

Suggested change
if cfg.shard_mode == ShardMode.EXPLICIT:
nn.logical_to_mesh_axes((
"activation_batch",
"activation_length_no_exp",
"activation_embed",
)),

norm_out_sharding = NamedSharding(
self.mesh,
nn.logical_to_mesh_axes(
(
"activation_batch",
"activation_length_no_exp",

🟢 Low - The removal of extra parentheses around the tuple in nn.logical_to_mesh_axes is a stylistic improvement for readability.

Suggested change
"activation_length_no_exp",
nn.logical_to_mesh_axes(
(
"activation_batch",
"activation_length_no_exp",
"activation_embed",
)
),

"activation_embed",
)
),
nn.logical_to_mesh_axes((
"activation_batch",
"activation_length_no_exp",
"activation_embed",
)),
)
else:
norm_out_sharding = None
@@ -621,13 +625,11 @@ def _apply_output_head(self, shared_embedding: nn.Module | nnx.Module, y, determ
else:
out_sharding = NamedSharding(
self.mesh,

🟢 Low - The removal of extra parentheses around the tuple in nn.logical_to_mesh_axes is a stylistic improvement for readability.

Suggested change
self.mesh,
nn.logical_to_mesh_axes((
"activation_embed_and_logits_batch",
"activation_length_no_exp",
"activation_vocab",
)),

nn.logical_to_mesh_axes(
(
"activation_embed_and_logits_batch",
"activation_length_no_exp",

🟢 Low - The removal of extra parentheses around the tuple in nn.logical_to_mesh_axes is a stylistic improvement for readability.

Suggested change
"activation_length_no_exp",
nn.logical_to_mesh_axes(
(
"activation_embed_and_logits_batch",
"activation_length_no_exp",
"activation_vocab",
)
),

"activation_vocab",
)
),
nn.logical_to_mesh_axes((
"activation_embed_and_logits_batch",
"activation_length_no_exp",
"activation_vocab",
)),
)

# [batch, length, emb_dim] -> [batch, length, vocab_size]
8 changes: 8 additions & 0 deletions src/MaxText/layers/normalizations.py
@@ -197,3 +197,11 @@ def l2norm(x: Array, dim: int = -1, eps: float = 1e-6) -> Array:

inv_norm = jax.lax.rsqrt((x * x).sum(axis=dim, keepdims=True) + jnp.array(eps, dtype=x.dtype))
return x * inv_norm


Qwen3NextRMSNormLinen = nnx_wrappers.to_linen_class(
RMSNorm,
base_metadata_fn=variable_to_logically_partitioned,
scale_init=linen_initializers.zeros,
scale_offset=1.0,
)
78 changes: 65 additions & 13 deletions src/MaxText/layers/qwen3.py
@@ -323,6 +323,7 @@ def __init__(self, config: Config, dtype: DType = jnp.float32, *, rngs: nnx.Rngs
self.value_dim = self.head_v_dim * self.num_v_heads
conv_dim = self.key_dim * 2 + self.value_dim
conv_kernel_size = cfg.gdn_conv_kernel_dim
self.v_heads_per_k_head = self.num_v_heads // self.num_k_heads

# Submodule instantiations
self.in_proj_qkvz = linears.DenseGeneral(
@@ -380,33 +381,86 @@ def a_log_init(key, shape, dtype=jnp.float32):
)

def __call__(self, hidden_states: Array) -> Array:
# hidden_states: (B, S, E)

🟡 Medium - The reshaping and splitting logic for qkvz and ba in the __call__ method is quite complex. While the comments are helpful, consider encapsulating some of this logic into smaller, well-named helper functions. This could improve readability, maintainability, and potentially reusability if similar patterns are used elsewhere.

Suggested change
# hidden_states: (B, S, E)
# STEP A: Input Projections
# hidden_states: (B, S, E)
qkvz = self.in_proj_qkvz(hidden_states)
ba = self.in_proj_ba(hidden_states)
query, key, value, z, b, a = self._split_and_reshape_qkvz_ba(batch, seq_len, qkvz, ba)
# Flatten head dimensions for concatenation before conv
q = query.reshape(batch, seq_len, -1)
k = key.reshape(batch, seq_len, -1)
v = value.reshape(batch, seq_len, -1)
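
If the refactor is adopted, a minimal self-contained sketch of such a helper could look like the following. The name and free-function form are assumptions; inside the module it would be a method that reads the head counts and dims from `self`:

```python
import jax.numpy as jnp


def split_and_reshape_qkvz_ba(qkvz, ba, num_k_heads, num_v_heads, head_k_dim, head_v_dim):
  """Splits the fused qkvz/ba projections into per-head query, key, value, z, b, a."""
  batch, seq_len, _ = qkvz.shape
  v_per_k = num_v_heads // num_k_heads
  # (B, S, H_k, 2*D_k + 2*D_v*V_per_K)
  mixed_qkvz = qkvz.reshape(batch, seq_len, num_k_heads, 2 * head_k_dim + 2 * head_v_dim * v_per_k)
  query, key, value_raw, z_raw = jnp.split(
      mixed_qkvz,
      [head_k_dim, 2 * head_k_dim, 2 * head_k_dim + v_per_k * head_v_dim],
      axis=3,
  )
  value = value_raw.reshape(batch, seq_len, num_v_heads, head_v_dim)
  z = z_raw.reshape(batch, seq_len, num_v_heads, head_v_dim)
  # (B, S, H_k, 2*V_per_K) -> b, a each of shape (B, S, H_v)
  mixed_ba = ba.reshape(batch, seq_len, num_k_heads, 2 * v_per_k)
  b_raw, a_raw = jnp.split(mixed_ba, [v_per_k], axis=3)
  b = b_raw.reshape(batch, seq_len, num_v_heads)
  a = a_raw.reshape(batch, seq_len, num_v_heads)
  return query, key, value, z, b, a
```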

cfg = self.config
batch, seq_len, _ = hidden_states.shape

# =========================================================================
# STEP A: Input Projections
# =========================================================================
# hidden_states shape: (B, S, E)
# qkvz shape: (B, S, 2*key_dim + 2*value_dim)
# qkvz: (B, S, 2 * K_dim + 2 * V_dim)
qkvz = self.in_proj_qkvz(hidden_states)
# ba shape: (B, S, 2*H_v)
# ba: (B, S, 2 * H_v)
ba = self.in_proj_ba(hidden_states)

# q shape: (B, S, key_dim), k shape: (B, S, key_dim), v shape: (B, S, value_dim), z shape: (B, S, value_dim)
q, k, v, z = jnp.split(qkvz, [self.key_dim, 2 * self.key_dim, 2 * self.key_dim + self.value_dim], axis=-1)
# b shape: (B, S, H_v), a shape: (B, S, H_v)
b, a = jnp.split(ba, [self.num_v_heads], axis=-1)
# QKVZ Reshaping and Splitting
# Per-K_head group dim: 2 * D_k + 2 * D_v * V_per_K
new_shape_qkvz = (
batch,
seq_len,
self.num_k_heads, # H_k
2 * self.head_k_dim + 2 * self.head_v_dim * self.v_heads_per_k_head,
)
# mixed_qkvz: (B, S, H_k, 2*D_k + 2*D_v*V_per_K)
mixed_qkvz = qkvz.reshape(new_shape_qkvz)

split_indices_qkvz = [
self.head_k_dim, # D_k
2 * self.head_k_dim, # 2 * D_k
2 * self.head_k_dim + (self.v_heads_per_k_head * self.head_v_dim), # 2 * D_k + V_per_K * D_v
]
# query: (B, S, H_k, D_k)
# key: (B, S, H_k, D_k)
# value_raw: (B, S, H_k, V_per_K * D_v)
# z_raw: (B, S, H_k, V_per_K * D_v)
query, key, value_raw, z_raw = jnp.split(mixed_qkvz, split_indices_qkvz, axis=3)

# value: (B, S, H_v, D_v)
value = value_raw.reshape(batch, seq_len, self.num_v_heads, self.head_v_dim)
# z: (B, S, H_v, D_v)
z = z_raw.reshape(batch, seq_len, self.num_v_heads, self.head_v_dim)

# BA Reshaping and Splitting
new_shape_ba = (
batch,
seq_len,
self.num_k_heads, # H_k
2 * self.v_heads_per_k_head,
)
# mixed_ba: (B, S, H_k, 2 * V_per_K)
mixed_ba = ba.reshape(new_shape_ba)

split_indices_ba = [self.v_heads_per_k_head]
# b_raw: (B, S, H_k, V_per_K)
# a_raw: (B, S, H_k, V_per_K)
b_raw, a_raw = jnp.split(mixed_ba, split_indices_ba, axis=3)

# b: (B, S, H_v)
b = b_raw.reshape(batch, seq_len, self.num_v_heads)
# a: (B, S, H_v)
a = a_raw.reshape(batch, seq_len, self.num_v_heads)

# Flatten head dimensions for concatenation before conv
# q: (B, S, K_dim)
q = query.reshape(batch, seq_len, -1)
# k: (B, S, K_dim)
k = key.reshape(batch, seq_len, -1)
# v: (B, S, V_dim)
v = value.reshape(batch, seq_len, -1)

# =========================================================================
# STEP B: 1D Convolution
# =========================================================================
# qkv shape: (B, S, conv_dim)
# conv_dim = 2 * K_dim + V_dim
# qkv: (B, S, 2 * K_dim + V_dim)
qkv = jnp.concatenate([q, k, v], axis=-1)

# TODO(parambole): Implement caching logic for conv_state and recurrent_state

# Input to conv_layer should be (B, S, C)
# qkv_conv shape: (B, S, conv_dim)
qkv_conv = jax.nn.silu(self.conv1d(qkv).astype(jnp.float32)).astype(cfg.dtype)
conv_out = self.conv1d(qkv)
qkv_conv = jax.nn.silu(conv_out.astype(jnp.float32)).astype(cfg.dtype)
# q_conv shape: (B, S, key_dim), k_conv shape: (B, S, key_dim), v_conv shape: (B, S, value_dim)
q_conv, k_conv, v_conv = jnp.split(qkv_conv, [self.key_dim, 2 * self.key_dim], axis=-1)

@@ -449,13 +503,11 @@ def __call__(self, hidden_states: Array) -> Array:
# =========================================================================
# STEP D: Final Output Stage
# =========================================================================

# The normalization and gating is applied per-head on the value dimension.
# We first reshape the `z` tensor to match the multi-head structure of `core_attn_out`.
# z shape from (B, S, value_dim) -> (B, S, H_v, D_v)
z_reshaped = z.reshape(batch, seq_len, self.num_v_heads, self.head_v_dim)

# Apply the norm and gate. Output shape: (B, S, H_v, D_v)
gated_output_reshaped = self.norm(core_attn_out, z_reshaped)
gated_output_reshaped = self.norm(core_attn_out, z)

# Reshape back to a single feature dimension for the final projection.
# Shape from (B, S, H_v, D_v) -> (B, S, value_dim)
6 changes: 2 additions & 4 deletions src/MaxText/pyconfig.py
@@ -445,10 +445,8 @@ def validate_qwen3_next_config(keys: dict):
keys: the raw config in dict form

"""
if keys["sparse_matmul"]:

🟡 Medium - The sparse_matmul check for Qwen3-Next has been removed. Please clarify if sparse_matmul is now supported for Qwen3-Next, or if the dense path is always intended for this model. If it's the latter, a comment explaining this decision would be beneficial for future maintainers.
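
If the dense-only restriction is still intended, keeping the check with an explanatory message along these lines (a sketch; the wording is hypothetical and not part of this PR) would record the decision for future maintainers:

```python
def validate_qwen3_next_matmul_path(keys: dict):
  """Hypothetical sketch: reject sparse_matmul until the sparse path is verified for Qwen3-Next."""
  if keys["sparse_matmul"]:
    raise ValueError(
        "Qwen3-Next currently supports only the dense matmul path (sparse_matmul=False); "
        "only the dense path has been verified against the reference implementation. "
        "Remove this check once the sparse path is validated."
    )
```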

raise ValueError(
"For Qwen3-Next, sparse_matmul must be False for now. The dense path has been verified against reference."
)
if int(keys["gdn_num_value_heads"]) % int(keys["gdn_num_key_heads"]) != 0:
raise ValueError("gdn_num_value_heads must be divisible by gdn_num_key_heads")
rotary_dim = int(keys["head_dim"] * keys["partial_rotary_factor"])
if rotary_dim % 2 != 0:
raise ValueError(f"Calculated rotary dimension ({rotary_dim}) must be a multiple of 2.")