-
Notifications
You must be signed in to change notification settings - Fork 421
Add conversion script for Qwen3 Next and Readme #2672
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
144729d
0e825d1
c818a4c
f5a4198
37fea75
2d53f0a
3bb6a72
40ae07c
7fa9967
cbab3e4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -12,7 +12,7 @@ | |||||||||||||||||
| # See the License for the specific language governing permissions and | ||||||||||||||||||
| # limitations under the License. | ||||||||||||||||||
|
|
||||||||||||||||||
| """"Module for decoder layers.""" | ||||||||||||||||||
| """ "Module for decoder layers.""" | ||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🟢 Low - Minor formatting: remove the extra space after the opening triple quotes in the docstring.
Suggested change
|
||||||||||||||||||
| # pylint: disable=arguments-differ | ||||||||||||||||||
| # pylint: disable=no-name-in-module | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
@@ -26,6 +26,7 @@ | |||||||||||||||||
|
|
||||||||||||||||||
| from flax import linen as nn | ||||||||||||||||||
| from flax import nnx | ||||||||||||||||||
| from flax.linen import initializers as linen_initializers | ||||||||||||||||||
| from flax.linen.partitioning import ScanIn | ||||||||||||||||||
|
|
||||||||||||||||||
| from MaxText.common_types import DecoderBlockType, ShardMode, Config, EP_AS_CONTEXT | ||||||||||||||||||
|
|
@@ -34,6 +35,7 @@ | |||||||||||||||||
| from MaxText import max_utils | ||||||||||||||||||
| from MaxText.inference import page_manager | ||||||||||||||||||
| from MaxText.layers import linears | ||||||||||||||||||
| from MaxText.layers import normalizations | ||||||||||||||||||
| from MaxText.layers import quantizations | ||||||||||||||||||
| from MaxText.layers import pipeline | ||||||||||||||||||
| from MaxText import maxtext_utils | ||||||||||||||||||
|
|
@@ -58,6 +60,7 @@ | |||||||||||||||||
| qwen3, | ||||||||||||||||||
| simple_layer, | ||||||||||||||||||
| ) | ||||||||||||||||||
| from MaxText.layers import nnx_wrappers | ||||||||||||||||||
|
|
||||||||||||||||||
| # ------------------------------------------------------------------------------ | ||||||||||||||||||
| # The network: Decoder Definitions | ||||||||||||||||||
|
|
@@ -465,7 +468,6 @@ def get_norm_layer(self, num_features: int): | |||||||||||||||||
| DecoderBlockType.GEMMA3, | ||||||||||||||||||
| DecoderBlockType.QWEN3, | ||||||||||||||||||
| DecoderBlockType.QWEN3_MOE, | ||||||||||||||||||
| DecoderBlockType.QWEN3_NEXT, | ||||||||||||||||||
| DecoderBlockType.GPT_OSS, | ||||||||||||||||||
| DecoderBlockType.SIMPLE, | ||||||||||||||||||
| DecoderBlockType.SIMPLE_MLP, | ||||||||||||||||||
|
|
@@ -474,6 +476,10 @@ def get_norm_layer(self, num_features: int): | |||||||||||||||||
| return functools.partial(rms_norm, num_features=num_features, shard_mode=self.config.shard_mode) | ||||||||||||||||||
| elif self.config.decoder_block == DecoderBlockType.GPT3: | ||||||||||||||||||
| return functools.partial(gpt3.gpt3_layer_norm, num_features=num_features, reductions_in_fp32=False, use_bias=True) | ||||||||||||||||||
| elif self.config.decoder_block == DecoderBlockType.QWEN3_NEXT: | ||||||||||||||||||
| return functools.partial( | ||||||||||||||||||
| normalizations.Qwen3NextRMSNormLinen, num_features=num_features, shard_mode=self.config.shard_mode | ||||||||||||||||||
| ) | ||||||||||||||||||
| else: | ||||||||||||||||||
| raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block.value=}") | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
@@ -595,13 +601,11 @@ def _apply_output_head(self, shared_embedding: nn.Module | nnx.Module, y, determ | |||||||||||||||||
| if cfg.shard_mode == ShardMode.EXPLICIT: | ||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🟢 Low - The removal of extra parentheses around the tuple in
Suggested change
|
||||||||||||||||||
| norm_out_sharding = NamedSharding( | ||||||||||||||||||
| self.mesh, | ||||||||||||||||||
| nn.logical_to_mesh_axes( | ||||||||||||||||||
| ( | ||||||||||||||||||
| "activation_batch", | ||||||||||||||||||
| "activation_length_no_exp", | ||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🟢 Low - The removal of extra parentheses around the tuple in
Suggested change
|
||||||||||||||||||
| "activation_embed", | ||||||||||||||||||
| ) | ||||||||||||||||||
| ), | ||||||||||||||||||
| nn.logical_to_mesh_axes(( | ||||||||||||||||||
| "activation_batch", | ||||||||||||||||||
| "activation_length_no_exp", | ||||||||||||||||||
| "activation_embed", | ||||||||||||||||||
| )), | ||||||||||||||||||
| ) | ||||||||||||||||||
| else: | ||||||||||||||||||
| norm_out_sharding = None | ||||||||||||||||||
|
|
@@ -621,13 +625,11 @@ def _apply_output_head(self, shared_embedding: nn.Module | nnx.Module, y, determ | |||||||||||||||||
| else: | ||||||||||||||||||
| out_sharding = NamedSharding( | ||||||||||||||||||
| self.mesh, | ||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🟢 Low - The removal of extra parentheses around the tuple in
Suggested change
|
||||||||||||||||||
| nn.logical_to_mesh_axes( | ||||||||||||||||||
| ( | ||||||||||||||||||
| "activation_embed_and_logits_batch", | ||||||||||||||||||
| "activation_length_no_exp", | ||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🟢 Low - The removal of extra parentheses around the tuple in
Suggested change
|
||||||||||||||||||
| "activation_vocab", | ||||||||||||||||||
| ) | ||||||||||||||||||
| ), | ||||||||||||||||||
| nn.logical_to_mesh_axes(( | ||||||||||||||||||
| "activation_embed_and_logits_batch", | ||||||||||||||||||
| "activation_length_no_exp", | ||||||||||||||||||
| "activation_vocab", | ||||||||||||||||||
| )), | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| # [batch, length, emb_dim] -> [batch, length, vocab_size] | ||||||||||||||||||
|
|
||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -323,6 +323,7 @@ def __init__(self, config: Config, dtype: DType = jnp.float32, *, rngs: nnx.Rngs | |||||||||||||||||||||||||
| self.value_dim = self.head_v_dim * self.num_v_heads | ||||||||||||||||||||||||||
| conv_dim = self.key_dim * 2 + self.value_dim | ||||||||||||||||||||||||||
| conv_kernel_size = cfg.gdn_conv_kernel_dim | ||||||||||||||||||||||||||
| self.v_heads_per_k_head = self.num_v_heads // self.num_k_heads | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # Submodule instantiations | ||||||||||||||||||||||||||
| self.in_proj_qkvz = linears.DenseGeneral( | ||||||||||||||||||||||||||
|
|
@@ -380,33 +381,86 @@ def a_log_init(key, shape, dtype=jnp.float32): | |||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def __call__(self, hidden_states: Array) -> Array: | ||||||||||||||||||||||||||
| # hidden_states: (B, S, E) | ||||||||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🟡 Medium - The reshaping and splitting logic for
Suggested change
|
||||||||||||||||||||||||||
| cfg = self.config | ||||||||||||||||||||||||||
| batch, seq_len, _ = hidden_states.shape | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # ========================================================================= | ||||||||||||||||||||||||||
| # STEP A: Input Projections | ||||||||||||||||||||||||||
| # ========================================================================= | ||||||||||||||||||||||||||
| # hidden_states shape: (B, S, E) | ||||||||||||||||||||||||||
| # qkvz shape: (B, S, 2*key_dim + 2*value_dim) | ||||||||||||||||||||||||||
| # qkvz: (B, S, 2 * K_dim + 2 * V_dim) | ||||||||||||||||||||||||||
| qkvz = self.in_proj_qkvz(hidden_states) | ||||||||||||||||||||||||||
| # ba shape: (B, S, 2*H_v) | ||||||||||||||||||||||||||
| # ba: (B, S, 2 * H_v) | ||||||||||||||||||||||||||
| ba = self.in_proj_ba(hidden_states) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # q shape: (B, S, key_dim), k shape: (B, S, key_dim), v shape: (B, S, value_dim), z shape: (B, S, value_dim) | ||||||||||||||||||||||||||
| q, k, v, z = jnp.split(qkvz, [self.key_dim, 2 * self.key_dim, 2 * self.key_dim + self.value_dim], axis=-1) | ||||||||||||||||||||||||||
| # b shape: (B, S, H_v), a shape: (B, S, H_v) | ||||||||||||||||||||||||||
| b, a = jnp.split(ba, [self.num_v_heads], axis=-1) | ||||||||||||||||||||||||||
| # QKVZ Reshaping and Splitting | ||||||||||||||||||||||||||
| # Per-K_head group dim: 2 * D_k + 2 * D_v * V_per_K | ||||||||||||||||||||||||||
| new_shape_qkvz = ( | ||||||||||||||||||||||||||
| batch, | ||||||||||||||||||||||||||
| seq_len, | ||||||||||||||||||||||||||
| self.num_k_heads, # H_k | ||||||||||||||||||||||||||
| 2 * self.head_k_dim + 2 * self.head_v_dim * self.v_heads_per_k_head, | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| # mixed_qkvz: (B, S, H_k, 2*D_k + 2*D_v*V_per_K) | ||||||||||||||||||||||||||
| mixed_qkvz = qkvz.reshape(new_shape_qkvz) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| split_indices_qkvz = [ | ||||||||||||||||||||||||||
| self.head_k_dim, # D_k | ||||||||||||||||||||||||||
| 2 * self.head_k_dim, # 2 * D_k | ||||||||||||||||||||||||||
| 2 * self.head_k_dim + (self.v_heads_per_k_head * self.head_v_dim), # 2 * D_k + V_per_K * D_v | ||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||
| # query: (B, S, H_k, D_k) | ||||||||||||||||||||||||||
| # key: (B, S, H_k, D_k) | ||||||||||||||||||||||||||
| # value_raw: (B, S, H_k, V_per_K * D_v) | ||||||||||||||||||||||||||
| # z_raw: (B, S, H_k, V_per_K * D_v) | ||||||||||||||||||||||||||
| query, key, value_raw, z_raw = jnp.split(mixed_qkvz, split_indices_qkvz, axis=3) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # value: (B, S, H_v, D_v) | ||||||||||||||||||||||||||
| value = value_raw.reshape(batch, seq_len, self.num_v_heads, self.head_v_dim) | ||||||||||||||||||||||||||
| # z: (B, S, H_v, D_v) | ||||||||||||||||||||||||||
| z = z_raw.reshape(batch, seq_len, self.num_v_heads, self.head_v_dim) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # BA Reshaping and Splitting | ||||||||||||||||||||||||||
| new_shape_ba = ( | ||||||||||||||||||||||||||
| batch, | ||||||||||||||||||||||||||
| seq_len, | ||||||||||||||||||||||||||
| self.num_k_heads, # H_k | ||||||||||||||||||||||||||
| 2 * self.v_heads_per_k_head, | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| # mixed_ba: (B, S, H_k, 2 * V_per_K) | ||||||||||||||||||||||||||
| mixed_ba = ba.reshape(new_shape_ba) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| split_indices_ba = [self.v_heads_per_k_head] | ||||||||||||||||||||||||||
| # b_raw: (B, S, H_k, V_per_K) | ||||||||||||||||||||||||||
| # a_raw: (B, S, H_k, V_per_K) | ||||||||||||||||||||||||||
| b_raw, a_raw = jnp.split(mixed_ba, split_indices_ba, axis=3) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # b: (B, S, H_v) | ||||||||||||||||||||||||||
| b = b_raw.reshape(batch, seq_len, self.num_v_heads) | ||||||||||||||||||||||||||
| # a: (B, S, H_v) | ||||||||||||||||||||||||||
| a = a_raw.reshape(batch, seq_len, self.num_v_heads) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # Flatten head dimensions for concatenation before conv | ||||||||||||||||||||||||||
| # q: (B, S, K_dim) | ||||||||||||||||||||||||||
| q = query.reshape(batch, seq_len, -1) | ||||||||||||||||||||||||||
| # k: (B, S, K_dim) | ||||||||||||||||||||||||||
| k = key.reshape(batch, seq_len, -1) | ||||||||||||||||||||||||||
| # v: (B, S, V_dim) | ||||||||||||||||||||||||||
| v = value.reshape(batch, seq_len, -1) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # ========================================================================= | ||||||||||||||||||||||||||
| # STEP B: 1D Convolution | ||||||||||||||||||||||||||
| # ========================================================================= | ||||||||||||||||||||||||||
| # qkv shape: (B, S, conv_dim) | ||||||||||||||||||||||||||
| # conv_dim = 2 * K_dim + V_dim | ||||||||||||||||||||||||||
| # qkv: (B, S, 2 * K_dim + V_dim) | ||||||||||||||||||||||||||
| qkv = jnp.concatenate([q, k, v], axis=-1) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # TODO(parambole): Implement caching logic for conv_state and recurrent_state | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # Input to conv_layer should be (B, S, C) | ||||||||||||||||||||||||||
| # qkv_conv shape: (B, S, conv_dim) | ||||||||||||||||||||||||||
| qkv_conv = jax.nn.silu(self.conv1d(qkv).astype(jnp.float32)).astype(cfg.dtype) | ||||||||||||||||||||||||||
| conv_out = self.conv1d(qkv) | ||||||||||||||||||||||||||
| qkv_conv = jax.nn.silu(conv_out.astype(jnp.float32)).astype(cfg.dtype) | ||||||||||||||||||||||||||
| # q_conv shape: (B, S, key_dim), k_conv shape: (B, S, key_dim), v_conv shape: (B, S, value_dim) | ||||||||||||||||||||||||||
| q_conv, k_conv, v_conv = jnp.split(qkv_conv, [self.key_dim, 2 * self.key_dim], axis=-1) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
@@ -449,13 +503,11 @@ def __call__(self, hidden_states: Array) -> Array: | |||||||||||||||||||||||||
| # ========================================================================= | ||||||||||||||||||||||||||
| # STEP D: Final Output Stage | ||||||||||||||||||||||||||
| # ========================================================================= | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # The normalization and gating is applied per-head on the value dimension. | ||||||||||||||||||||||||||
| # We first reshape the `z` tensor to match the multi-head structure of `core_attn_out`. | ||||||||||||||||||||||||||
| # z shape from (B, S, value_dim) -> (B, S, H_v, D_v) | ||||||||||||||||||||||||||
| z_reshaped = z.reshape(batch, seq_len, self.num_v_heads, self.head_v_dim) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # Apply the norm and gate. Output shape: (B, S, H_v, D_v) | ||||||||||||||||||||||||||
| gated_output_reshaped = self.norm(core_attn_out, z_reshaped) | ||||||||||||||||||||||||||
| gated_output_reshaped = self.norm(core_attn_out, z) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # Reshape back to a single feature dimension for the final projection. | ||||||||||||||||||||||||||
| # Shape from (B, S, H_v, D_v) -> (B, S, value_dim) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -445,10 +445,8 @@ def validate_qwen3_next_config(keys: dict): | |
| keys: the raw config in dict form | ||
|
|
||
| """ | ||
| if keys["sparse_matmul"]: | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🟡 Medium - The |
||
| raise ValueError( | ||
| "For Qwen3-Next, sparse_matmul must be False for now. The dense path has been verified against reference." | ||
| ) | ||
| if int(keys["gdn_num_value_heads"]) % int(keys["gdn_num_key_heads"]) != 0: | ||
| raise ValueError("gdn_num_value_heads must be divisible by gdn_num_key_heads") | ||
| rotary_dim = int(keys["head_dim"] * keys["partial_rotary_factor"]) | ||
| if rotary_dim % 2 != 0: | ||
| raise ValueError(f"Calculated rotary dimension ({rotary_dim}) must be a multiple of 2.") | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🟠 High - The
if self.is_qwen3_next:block has been moved to after the sharding logic. This means the reshaping and sigmoid gating for Qwen3-Next will now occur after the output has potentially been sharded. This could lead to incorrect behavior if the sharding expects a different shape or if the reshape/gating needs to happen before sharding. Please verify if this change is intentional and correct, or if the block should remain before the sharding logic.