diff --git a/end_to_end/tpu/qwen/next/qwen3-next-80b-a3b/1_test_qwen3_next_80b_a3b.sh b/end_to_end/tpu/qwen/next/qwen3-next-80b-a3b/1_test_qwen3_next_80b_a3b.sh new file mode 100644 index 000000000..db6a08dfb --- /dev/null +++ b/end_to_end/tpu/qwen/next/qwen3-next-80b-a3b/1_test_qwen3_next_80b_a3b.sh @@ -0,0 +1,59 @@ +#!/bin/bash + +# This script validates a pre-converted MaxText checkpoint against its original +# HuggingFace counterpart to ensure numerical correctness. + +# --- +# Example Usage: +# +# # (Required) Path to the converted MaxText checkpoint +# export MAXTEXT_CHECKPOINT_PATH=gs://path/to/converted_ckpt/0/items/ +# +# # (Optional) Override the default HF model +# export HF_MODEL_PATH=MyCustom/Qwen3-variant +# +# bash end_to_end/tpu/qwen/next/qwen3-next-80b-a3b/1_test_qwen3_next_80b_a3b.sh +# --- + +set -ex + +# --- Configuration & Input Validation --- + +if [ -z "${MAXTEXT_CHECKPOINT_PATH}" ]; then + echo "ERROR: The MAXTEXT_CHECKPOINT_PATH environment variable is not set." + echo "Please set it to the full GCS path of the pre-converted MaxText checkpoint weights." + exit 1 +fi + +# Set a default for the HF model path if it's not provided by the user +if [ -z "${HF_MODEL_PATH}" ]; then + export HF_MODEL_PATH="Qwen/Qwen3-Coder-480B-A35B-Instruct" + echo "HF_MODEL_PATH is not set, using default: ${HF_MODEL_PATH}" +fi + +# Install dependencies required for the logit checker. +python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu + +# --- Run the Forward Pass Logit Checker --- + +echo "Validating MaxText checkpoint at ${MAXTEXT_CHECKPOINT_PATH}" +echo "Against original HF model: ${HF_MODEL_PATH}" + +# This command runs the core validation logic. +JAX_PLATFORMS=cpu python3 -m MaxText.tests.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml \ + tokenizer_type=huggingface \ + tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText/assets}}"/qwen3-tokenizer \ + megablox=False \ + sparse_matmul=False \ + load_parameters_path=${MAXTEXT_CHECKPOINT_PATH} \ + model_name=qwen3-next-80b-a3b \ + checkpoint_storage_concurrent_gb=1024 \ + skip_jax_distributed_system=True \ + dtype=float32 \ + weight_dtype=float32 \ + matmul_precision=highest \ + --hf_model_path=${HF_MODEL_PATH} \ + --max_kl_div=0.03 \ + --run_hf_model=True + +echo "Validation complete." \ No newline at end of file diff --git a/end_to_end/tpu/qwen/next/run_qwen3_next.md b/end_to_end/tpu/qwen/next/run_qwen3_next.md new file mode 100644 index 000000000..e0678db1d --- /dev/null +++ b/end_to_end/tpu/qwen/next/run_qwen3_next.md @@ -0,0 +1,97 @@ +Qwen3 Next +========= + +Qwen3-Next is Alibaba 80B Mixture-of-Experts (MoE) model (activating only 3B parameters) that features a novel **hybrid attention** architecture combining Gated DeltaNet and standard attention for massive context scaling This documentation covers the integration of **Qwen3-Next-80B-A3B** into MaxText: + +For more details on the architecture, see the [Qwen3 Technical Blog](https://qwen.ai/blog?id=4074cca80393150c248e508aa62983f9cb7d27cd&from=research.latest-advancements-list). + +* * * * * + +Checkpoint Conversion +--------------------- + +To get started, you first need a MaxText-compatible checkpoint. + +1. **Download the Model**: Download the official model from Hugging Face. You can use a tool like `hf_transfer` for a fast download. 
+ + ``` + # Example for Qwen3-Next-80B-A3B-Instruct + hf_transfer download Qwen/Qwen3-Next-80B-A3B-Instruct --local-dir /path/to/qwen3_next_hf_checkpoint + ``` + +2. **Convert the Checkpoint**: Run the `convert_qwen3_next_scanned.py` script to convert the downloaded Hugging Face weights into the Orbax format required by MaxText. + + ``` + python3 -m MaxText.utils.ckpt_scripts.convert_qwen3_next_scanned \ + --base_model_path /path/to/qwen3_next_hf_checkpoint \ + --maxtext_model_path gs://your-gcs-bucket/qwen3_next_maxtext_ckpt \ + --model_size qwen3-next-80b-a3b + ``` + +* * * * * + +Pre-training and Fine-tuning +---------------------------- + +After converting the checkpoint, you can use it for fine-tuning or start a pre-training run from scratch. The command below is an example for fine-tuning on a v5p-512 slice. To pre-train, simply remove the `load_parameters_path` argument. + +``` +python3 -m MaxText.train src/MaxText/configs/base.yml \ + base_output_directory=${BASE_OUTPUT_DIRECTORY} \ + dataset_path=${DATASET_PATH} \ + load_parameters_path=gs://your-gcs-bucket/qwen3_next_maxtext_ckpt/0/items \ + run_name=qwen3_next_finetuning \ + per_device_batch_size=1 \ + model_name=qwen3-next-80b-a3b \ + steps=500 \ + max_target_length=8192 \ + ici_fsdp_parallelism=256 \ + tokenizer_type=huggingface \ + tokenizer_path=src/MaxText/assets/qwen3-tokenizer + +``` + +* * * * * + +Decoding +-------- + +To generate text with a trained model, use the `decode` command. The command below is an example for decoding on a v5p-512 slice. + +``` +python3 -m MaxText.decode src/MaxText/configs/base.yml \ + load_parameters_path=gs://your-gcs-bucket/qwen3_next_maxtext_ckpt/0/items \ + tokenizer_type=huggingface \ + tokenizer_path=src/MaxText/assets/qwen3-tokenizer \ + prompt="Today is a beautiful day to" \ + model_name=qwen3-next-80b-a3b \ + per_device_batch_size=1 \ + max_target_length=128 \ + ici_fsdp_parallelism=256 + +``` + +* * * * * + +Correctness Validation +---------------------- + +To verify that the MaxText implementation is numerically equivalent to the original Hugging Face model, you can run the end-to-end test scripts. These scripts automate the logit comparison test for each model. + +Before running, you must set the `MAXTEXT_CHECKPOINT_PATH` environment variable. You can also optionally set `HF_MODEL_PATH` to point to a local copy of the Hugging Face model. 
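+Under the hood, the test script runs a single forward pass through both the MaxText and the Hugging Face implementation on the same prompts and compares their output logits, asserting that the per-token KL divergence stays below a small threshold (the script passes `--max_kl_div=0.03`). The snippet below is only a minimal, self-contained sketch of that kind of check; it is not the actual `forward_pass_logit_checker` implementation, and the array shapes are illustrative.
+
+```
+import numpy as np
+
+def per_token_kl(hf_logits, maxtext_logits):
+  """KL(P_hf || P_maxtext) per position, computed from raw logits of shape (B, S, V)."""
+  def log_softmax(x):
+    x = x - x.max(axis=-1, keepdims=True)
+    return x - np.log(np.exp(x).sum(axis=-1, keepdims=True))
+  log_p = log_softmax(hf_logits.astype(np.float64))
+  log_q = log_softmax(maxtext_logits.astype(np.float64))
+  return (np.exp(log_p) * (log_p - log_q)).sum(axis=-1)
+
+# Toy example with a small vocabulary: nearly identical logits stay under the 0.03 threshold.
+rng = np.random.default_rng(0)
+p = rng.normal(size=(1, 8, 1024))
+q = p + 0.01 * rng.normal(size=p.shape)
+assert per_token_kl(p, q).max() < 0.03
+```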
+ +### Qwen3-Next-80B-A3B + +Bash + +``` +# Set the required path to your converted MaxText checkpoint +export MAXTEXT_CHECKPOINT_PATH=gs://your-gcs-bucket/qwen3-next-80b-a3b_maxtext_ckpt/0/items/ + +# (Optional) Set the path to your local Hugging Face checkpoint +# export HF_MODEL_PATH=/path/to/local/qwen3-next-80b-a3b_hf_checkpoint + +# Execute the validation script +bash end_to_end/tpu/qwen/next/qwen3-next-80b-a3b/1_test_qwen3_next_80b_a3b.sh + +``` diff --git a/src/MaxText/configs/models/qwen3-next-80b-a3b.yml b/src/MaxText/configs/models/qwen3-next-80b-a3b.yml index 654869fb2..f48d7da34 100644 --- a/src/MaxText/configs/models/qwen3-next-80b-a3b.yml +++ b/src/MaxText/configs/models/qwen3-next-80b-a3b.yml @@ -46,3 +46,6 @@ gdn_chunk_size: 64 # RoPE Settings rope_max_timescale: 10000000 partial_rotary_factor: 0.25 + +# General Model Settings +enable_dropout: False diff --git a/src/MaxText/layers/attentions.py b/src/MaxText/layers/attentions.py index b12648e63..ef92dbe9b 100644 --- a/src/MaxText/layers/attentions.py +++ b/src/MaxText/layers/attentions.py @@ -1015,9 +1015,6 @@ def __call__( bidirectional_mask, self.sinks, ) - if self.is_qwen3_next: - out = out.reshape(batch_size, seq_len, self.config.num_query_heads * self.config.head_dim) - out = out * jax.nn.sigmoid(gate) if model_mode == MODEL_MODE_PREFILL: out = self._maybe_shard_with_logical(out, self.prefill_out_axis_names) elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT: @@ -1026,6 +1023,9 @@ def __call__( out = self._maybe_shard_with_logical(out, self.out_axis_names) else: out = self._maybe_shard_with_logical(out, self.decode_out_axis_names) + if self.is_qwen3_next: + out = out.reshape(batch_size, seq_len, self.config.num_query_heads * self.config.head_dim) + out = out * jax.nn.sigmoid(gate) out = self.out_projection(out, out_sharding=out_sharding) out = checkpoint_name(out, "out_proj") return out diff --git a/src/MaxText/layers/decoders.py b/src/MaxText/layers/decoders.py index 5ab5b3a8f..8b8c6f96d 100644 --- a/src/MaxText/layers/decoders.py +++ b/src/MaxText/layers/decoders.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-""""Module for decoder layers.""" +"""Module for decoder layers""" # pylint: disable=arguments-differ # pylint: disable=no-name-in-module @@ -34,6 +34,7 @@ from MaxText import max_utils from MaxText.inference import page_manager from MaxText.layers import linears +from MaxText.layers import normalizations from MaxText.layers import quantizations from MaxText.layers import pipeline from MaxText import maxtext_utils @@ -465,7 +466,6 @@ def get_norm_layer(self, num_features: int): DecoderBlockType.GEMMA3, DecoderBlockType.QWEN3, DecoderBlockType.QWEN3_MOE, - DecoderBlockType.QWEN3_NEXT, DecoderBlockType.GPT_OSS, DecoderBlockType.SIMPLE, DecoderBlockType.SIMPLE_MLP, @@ -474,6 +474,10 @@ def get_norm_layer(self, num_features: int): return functools.partial(rms_norm, num_features=num_features, shard_mode=self.config.shard_mode) elif self.config.decoder_block == DecoderBlockType.GPT3: return functools.partial(gpt3.gpt3_layer_norm, num_features=num_features, reductions_in_fp32=False, use_bias=True) + elif self.config.decoder_block == DecoderBlockType.QWEN3_NEXT: + return functools.partial( + normalizations.Qwen3NextRMSNormLinen, num_features=num_features, shard_mode=self.config.shard_mode + ) else: raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block.value=}") diff --git a/src/MaxText/layers/normalizations.py b/src/MaxText/layers/normalizations.py index d868abec7..bbd4c9914 100644 --- a/src/MaxText/layers/normalizations.py +++ b/src/MaxText/layers/normalizations.py @@ -197,3 +197,11 @@ def l2norm(x: Array, dim: int = -1, eps: float = 1e-6) -> Array: inv_norm = jax.lax.rsqrt((x * x).sum(axis=dim, keepdims=True) + jnp.array(eps, dtype=x.dtype)) return x * inv_norm + + +Qwen3NextRMSNormLinen = nnx_wrappers.to_linen_class( + RMSNorm, + base_metadata_fn=variable_to_logically_partitioned, + scale_init=linen_initializers.zeros, + scale_offset=1.0, +) diff --git a/src/MaxText/layers/qwen3.py b/src/MaxText/layers/qwen3.py index 26d8748ba..0a26e4edc 100644 --- a/src/MaxText/layers/qwen3.py +++ b/src/MaxText/layers/qwen3.py @@ -323,6 +323,7 @@ def __init__(self, config: Config, dtype: DType = jnp.float32, *, rngs: nnx.Rngs self.value_dim = self.head_v_dim * self.num_v_heads conv_dim = self.key_dim * 2 + self.value_dim conv_kernel_size = cfg.gdn_conv_kernel_dim + self.v_heads_per_k_head = self.num_v_heads // self.num_k_heads # Submodule instantiations self.in_proj_qkvz = linears.DenseGeneral( @@ -380,33 +381,86 @@ def a_log_init(key, shape, dtype=jnp.float32): ) def __call__(self, hidden_states: Array) -> Array: + # hidden_states: (B, S, E) cfg = self.config + batch, seq_len, _ = hidden_states.shape # ========================================================================= # STEP A: Input Projections # ========================================================================= - # hidden_states shape: (B, S, E) - # qkvz shape: (B, S, 2*key_dim + 2*value_dim) + # qkvz: (B, S, 2 * K_dim + 2 * V_dim) qkvz = self.in_proj_qkvz(hidden_states) - # ba shape: (B, S, 2*H_v) + # ba: (B, S, 2 * H_v) ba = self.in_proj_ba(hidden_states) - # q shape: (B, S, key_dim), k shape: (B, S, key_dim), v shape: (B, S, value_dim), z shape: (B, S, value_dim) - q, k, v, z = jnp.split(qkvz, [self.key_dim, 2 * self.key_dim, 2 * self.key_dim + self.value_dim], axis=-1) - # b shape: (B, S, H_v), a shape: (B, S, H_v) - b, a = jnp.split(ba, [self.num_v_heads], axis=-1) + # QKVZ Reshaping and Splitting + # Per-K_head group dim: 2 * D_k + 2 * D_v * V_per_K + new_shape_qkvz = ( + batch, + seq_len, + self.num_k_heads, # H_k 
+ 2 * self.head_k_dim + 2 * self.head_v_dim * self.v_heads_per_k_head, + ) + # mixed_qkvz: (B, S, H_k, 2*D_k + 2*D_v*V_per_K) + mixed_qkvz = qkvz.reshape(new_shape_qkvz) + + split_indices_qkvz = [ + self.head_k_dim, # D_k + 2 * self.head_k_dim, # 2 * D_k + 2 * self.head_k_dim + (self.v_heads_per_k_head * self.head_v_dim), # 2 * D_k + V_per_K * D_v + ] + # query: (B, S, H_k, D_k) + # key: (B, S, H_k, D_k) + # value_raw: (B, S, H_k, V_per_K * D_v) + # z_raw: (B, S, H_k, V_per_K * D_v) + query, key, value_raw, z_raw = jnp.split(mixed_qkvz, split_indices_qkvz, axis=3) + + # value: (B, S, H_v, D_v) + value = value_raw.reshape(batch, seq_len, self.num_v_heads, self.head_v_dim) + # z: (B, S, H_v, D_v) + z = z_raw.reshape(batch, seq_len, self.num_v_heads, self.head_v_dim) + + # BA Reshaping and Splitting + new_shape_ba = ( + batch, + seq_len, + self.num_k_heads, # H_k + 2 * self.v_heads_per_k_head, + ) + # mixed_ba: (B, S, H_k, 2 * V_per_K) + mixed_ba = ba.reshape(new_shape_ba) + + split_indices_ba = [self.v_heads_per_k_head] + # b_raw: (B, S, H_k, V_per_K) + # a_raw: (B, S, H_k, V_per_K) + b_raw, a_raw = jnp.split(mixed_ba, split_indices_ba, axis=3) + + # b: (B, S, H_v) + b = b_raw.reshape(batch, seq_len, self.num_v_heads) + # a: (B, S, H_v) + a = a_raw.reshape(batch, seq_len, self.num_v_heads) + + # Flatten head dimensions for concatenation before conv + # q: (B, S, K_dim) + q = query.reshape(batch, seq_len, -1) + # k: (B, S, K_dim) + k = key.reshape(batch, seq_len, -1) + # v: (B, S, V_dim) + v = value.reshape(batch, seq_len, -1) # ========================================================================= # STEP B: 1D Convolution # ========================================================================= - # qkv shape: (B, S, conv_dim) + # conv_dim = 2 * K_dim + V_dim + # qkv: (B, S, 2 * K_dim + V_dim) qkv = jnp.concatenate([q, k, v], axis=-1) # TODO(parambole): Implement caching logic for conv_state and recurrent_state # Input to conv_layer should be (B, S, C) # qkv_conv shape: (B, S, conv_dim) - qkv_conv = jax.nn.silu(self.conv1d(qkv).astype(jnp.float32)).astype(cfg.dtype) + conv_out = self.conv1d(qkv) + qkv_conv = jax.nn.silu(conv_out.astype(jnp.float32)).astype(cfg.dtype) # q_conv shape: (B, S, key_dim), k_conv shape: (B, S, key_dim), v_conv shape: (B, S, value_dim) q_conv, k_conv, v_conv = jnp.split(qkv_conv, [self.key_dim, 2 * self.key_dim], axis=-1) @@ -449,13 +503,11 @@ def __call__(self, hidden_states: Array) -> Array: # ========================================================================= # STEP D: Final Output Stage # ========================================================================= + # The normalization and gating is applied per-head on the value dimension. - # We first reshape the `z` tensor to match the multi-head structure of `core_attn_out`. - # z shape from (B, S, value_dim) -> (B, S, H_v, D_v) - z_reshaped = z.reshape(batch, seq_len, self.num_v_heads, self.head_v_dim) # Apply the norm and gate. Output shape: (B, S, H_v, D_v) - gated_output_reshaped = self.norm(core_attn_out, z_reshaped) + gated_output_reshaped = self.norm(core_attn_out, z) # Reshape back to a single feature dimension for the final projection. 
# Shape from (B, S, H_v, D_v) -> (B, S, value_dim) diff --git a/src/MaxText/pyconfig_deprecated.py b/src/MaxText/pyconfig_deprecated.py index 5b4c0df85..4dec37825 100644 --- a/src/MaxText/pyconfig_deprecated.py +++ b/src/MaxText/pyconfig_deprecated.py @@ -434,10 +434,8 @@ def validate_qwen3_next_config(keys: dict): keys: the raw config in dict form """ - if keys["sparse_matmul"]: - raise ValueError( - "For Qwen3-Next, sparse_matmul must be False for now. The dense path has been verified against reference." - ) + if int(keys["gdn_num_value_heads"]) % int(keys["gdn_num_key_heads"]) != 0: + raise ValueError("gdn_num_value_heads must be divisible by gdn_num_key_heads") rotary_dim = int(keys["head_dim"] * keys["partial_rotary_factor"]) if rotary_dim % 2 != 0: raise ValueError(f"Calculated rotary dimension ({rotary_dim}) must be a multiple of 2.") diff --git a/src/MaxText/utils/ckpt_scripts/convert_qwen3_next_scanned.py b/src/MaxText/utils/ckpt_scripts/convert_qwen3_next_scanned.py new file mode 100644 index 000000000..6996b861b --- /dev/null +++ b/src/MaxText/utils/ckpt_scripts/convert_qwen3_next_scanned.py @@ -0,0 +1,433 @@ +""" +Copyright 2025 Google LLC +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + https://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +r"""Convert weights from a Qwen3-Next style model to a MaxText one. +This script rigorously follows the two-stage conversion process (map-then-transform) +required for generating a MaxText checkpoint compatible with the model structure, +specifically for scanned heterogeneous layers. +Example cmd: +python3 -m MaxText.utils.ckpt_scripts.convert_qwen3_next_scanned \ + --base_model_path . \ + --maxtext_model_path gs:/// \ + --model_size qwen3-next-80b-a3b +""" + +import argparse +import gc +import os +import pathlib +import numpy as np +import torch +import jax.numpy as jnp +from safetensors import safe_open +from functools import partial +from tqdm import tqdm + +from MaxText.utils.ckpt_scripts import llama_or_mistral_ckpt +from MaxText import max_logging +from MaxText.inference_utils import str2bool + +# Static model parameters dictionary +MODEL_PARAMS_DICT = { + "qwen3-next-80b-a3b": { + "num_layers": 48, + "num_q_heads": 16, + "num_kv_heads": 2, + "head_dim": 256, + "emb_dim": 2048, + "vocab_size": 151936, + "moe_intermediate_size": 512, # base_moe_mlp_dim + "num_experts": 512, + "num_experts_per_tok": 10, + # Qwen3-Next Specific Parameters for Linear Attention (Gated Delta Net) + "inhomogeneous_layer_cycle_interval": 4, + "gdn_conv_kernel_dim": 4, + "gdn_key_head_dim": 128, + "gdn_value_head_dim": 128, + "gdn_num_key_heads": 16, + "gdn_num_value_heads": 32, + } +} + + +def to_np_bfloat16(tensor): + """Converts a torch tensor to a numpy array with bfloat16 dtype.""" + return tensor.to(torch.float32).numpy().astype(jnp.bfloat16) + + +def hf_to_maxtext_mapping(layer_idx: int, num_experts: int, inhomogeneous_layer_cycle_interval: int) -> dict: + """Creates a mapping from HF weight names to MaxText weight names for a specific layer.""" + + # 1. 
Define base prefixes to shorten line lengths + block_idx = layer_idx % inhomogeneous_layer_cycle_interval + hf_prefix = f"model.layers.{layer_idx}" + mt_prefix = f"decoder.layers.layer_{block_idx}" + mt_attn_prefix = f"{mt_prefix}.attention" + mt_mlp_prefix = f"{mt_prefix}.mlp" + + # 2. Initialize mapping with global weights and standard layer norms + mapping = { + "model.embed_tokens.weight": "token_embedder.embedding", + "model.norm.weight": "decoder.decoder_norm.scale", + "lm_head.weight": "decoder.logits_dense.kernel", + f"{hf_prefix}.input_layernorm.weight": f"{mt_prefix}.input_layernorm.scale", + f"{hf_prefix}.post_attention_layernorm.weight": f"{mt_prefix}.post_attention_layernorm.scale", + } + + # 3. Handle Attention Logic (Full vs Linear) + is_full_attention_layer = (layer_idx + 1) % inhomogeneous_layer_cycle_interval == 0 + + if is_full_attention_layer: + mapping.update( + { + f"{hf_prefix}.self_attn.q_proj.weight": f"{mt_attn_prefix}.attention.query.kernel", + f"{hf_prefix}.self_attn.k_proj.weight": f"{mt_attn_prefix}.attention.key.kernel", + f"{hf_prefix}.self_attn.v_proj.weight": f"{mt_attn_prefix}.attention.value.kernel", + f"{hf_prefix}.self_attn.o_proj.weight": f"{mt_attn_prefix}.attention.out.kernel", + f"{hf_prefix}.self_attn.q_norm.weight": f"{mt_attn_prefix}.attention.query_norm.scale", + f"{hf_prefix}.self_attn.k_norm.weight": f"{mt_attn_prefix}.attention.key_norm.scale", + } + ) + else: + mapping.update( + { + f"{hf_prefix}.linear_attn.in_proj_qkvz.weight": f"{mt_attn_prefix}.in_proj_qkvz.kernel", + f"{hf_prefix}.linear_attn.in_proj_ba.weight": f"{mt_attn_prefix}.in_proj_ba.kernel", + f"{hf_prefix}.linear_attn.conv1d.weight": f"{mt_attn_prefix}.conv1d.kernel", + f"{hf_prefix}.linear_attn.A_log": f"{mt_attn_prefix}.A_log", + f"{hf_prefix}.linear_attn.dt_bias": f"{mt_attn_prefix}.dt_bias", + f"{hf_prefix}.linear_attn.norm.weight": f"{mt_attn_prefix}.norm.rms_norm.scale", + f"{hf_prefix}.linear_attn.out_proj.weight": f"{mt_attn_prefix}.out_proj.kernel", + } + ) + + # 4. Handle MLP (Gates and Shared Experts) + mapping.update( + { + f"{hf_prefix}.mlp.gate.weight": f"{mt_mlp_prefix}.routed_experts.gate.kernel", + f"{hf_prefix}.mlp.shared_expert.gate_proj.weight": f"{mt_mlp_prefix}.shared_expert.wi_0.kernel", + f"{hf_prefix}.mlp.shared_expert.up_proj.weight": f"{mt_mlp_prefix}.shared_expert.wi_1.kernel", + f"{hf_prefix}.mlp.shared_expert.down_proj.weight": f"{mt_mlp_prefix}.shared_expert.wo.kernel", + f"{hf_prefix}.mlp.shared_expert_gate.weight": f"{mt_mlp_prefix}.shared_expert_gate.kernel", + } + ) + + # 5. 
Handle Routed Experts Loop + for i in range(num_experts): + # Note: Ensure these don't require '.kernel' suffix (common in Flax, but absent in your original code) + mapping[f"{hf_prefix}.mlp.experts.{i}.gate_proj.weight"] = f"{mt_mlp_prefix}.routed_experts.{i}.wi_0" + mapping[f"{hf_prefix}.mlp.experts.{i}.up_proj.weight"] = f"{mt_mlp_prefix}.routed_experts.{i}.wi_1" + mapping[f"{hf_prefix}.mlp.experts.{i}.down_proj.weight"] = f"{mt_mlp_prefix}.routed_experts.{i}.wo" + + return mapping + + +def init_maxtext_weights(model_params, num_layers_to_convert, num_experts_to_convert): + """Initializes an empty pytree for the hf weights to be loaded in""" + emb_dim = model_params["emb_dim"] + num_q_heads = model_params["num_q_heads"] + num_kv_heads = model_params["num_kv_heads"] + head_dim = model_params["head_dim"] + moe_intermediate_size = model_params["moe_intermediate_size"] + # num_experts = model_params["num_experts"] + cycle = model_params["inhomogeneous_layer_cycle_interval"] + num_stacked_layers = num_layers_to_convert // cycle + + gdn_num_v_heads = model_params["gdn_num_value_heads"] + gdn_key_dim = model_params["gdn_num_key_heads"] * model_params["gdn_key_head_dim"] + gdn_value_dim = gdn_num_v_heads * model_params["gdn_value_head_dim"] + gdn_conv_dim = gdn_key_dim * 2 + gdn_value_dim + gdn_conv_kernel_dim = model_params["gdn_conv_kernel_dim"] + + weights = { + "decoder": { + "layers": {}, + "decoder_norm": {"scale": None}, + "logits_dense": {"kernel": None}, + }, + "token_embedder": {"embedding": None}, + } + + for i in range(cycle): + layer_key = f"layer_{i}" + layer_struct = { + "input_layernorm": {"scale": np.zeros((emb_dim, num_stacked_layers), dtype=jnp.bfloat16)}, + "post_attention_layernorm": {"scale": np.zeros((emb_dim, num_stacked_layers), dtype=jnp.bfloat16)}, + "attention": {}, + "mlp": { + "routed_experts": { + "gate": {"kernel": np.zeros((emb_dim, num_stacked_layers, num_experts_to_convert), dtype=jnp.bfloat16)}, + "wi_0": np.zeros( + (num_experts_to_convert, num_stacked_layers, emb_dim, moe_intermediate_size), dtype=jnp.bfloat16 + ), + "wi_1": np.zeros( + (num_experts_to_convert, num_stacked_layers, emb_dim, moe_intermediate_size), dtype=jnp.bfloat16 + ), + "wo": np.zeros( + (num_experts_to_convert, num_stacked_layers, moe_intermediate_size, emb_dim), dtype=jnp.bfloat16 + ), + }, + "shared_expert": { + "wi_0": {"kernel": np.zeros((emb_dim, num_stacked_layers, moe_intermediate_size), dtype=jnp.bfloat16)}, + "wi_1": {"kernel": np.zeros((emb_dim, num_stacked_layers, moe_intermediate_size), dtype=jnp.bfloat16)}, + "wo": {"kernel": np.zeros((moe_intermediate_size, num_stacked_layers, emb_dim), dtype=jnp.bfloat16)}, + }, + "shared_expert_gate": {"kernel": np.zeros((emb_dim, num_stacked_layers, 1), dtype=jnp.bfloat16)}, + }, + } + + is_full_attention_layer = (i + 1) % cycle == 0 + if is_full_attention_layer: + layer_struct["attention"] = { + "attention": { + "query": {"kernel": np.zeros((emb_dim, num_stacked_layers, num_q_heads, head_dim * 2), dtype=jnp.bfloat16)}, + "key": {"kernel": np.zeros((emb_dim, num_stacked_layers, num_kv_heads, head_dim), dtype=jnp.bfloat16)}, + "value": {"kernel": np.zeros((emb_dim, num_stacked_layers, num_kv_heads, head_dim), dtype=jnp.bfloat16)}, + "out": {"kernel": np.zeros((num_q_heads * head_dim, num_stacked_layers, emb_dim), dtype=jnp.bfloat16)}, + "query_norm": {"scale": np.zeros((head_dim, num_stacked_layers), dtype=jnp.bfloat16)}, + "key_norm": {"scale": np.zeros((head_dim, num_stacked_layers), dtype=jnp.bfloat16)}, + } + } + else: + 
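+      # Gated DeltaNet (linear attention) branch: every parameter below carries a
+      # num_stacked_layers axis at position 1 so the layers sharing this block
+      # position can be scanned together.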
layer_struct["attention"] = { + "in_proj_qkvz": { + "kernel": np.zeros((emb_dim, num_stacked_layers, gdn_key_dim * 2 + gdn_value_dim * 2), dtype=jnp.bfloat16) + }, + "in_proj_ba": {"kernel": np.zeros((emb_dim, num_stacked_layers, gdn_num_v_heads * 2), dtype=jnp.bfloat16)}, + "conv1d": {"kernel": np.zeros((gdn_conv_kernel_dim, num_stacked_layers, 1, gdn_conv_dim), dtype=jnp.bfloat16)}, + "A_log": np.zeros((gdn_num_v_heads, num_stacked_layers), dtype=jnp.bfloat16), + "dt_bias": np.zeros((gdn_num_v_heads, num_stacked_layers), dtype=jnp.bfloat16), + "norm": { + "rms_norm": { + "scale": np.zeros((model_params["gdn_value_head_dim"], num_stacked_layers), dtype=jnp.bfloat16) + } + }, + "out_proj": {"kernel": np.zeros((gdn_value_dim, num_stacked_layers, emb_dim), dtype=jnp.bfloat16)}, + } + weights["decoder"]["layers"][layer_key] = layer_struct + return weights + + +def convert_hf_to_maxtext(base_model_path: str, model_params: dict, args) -> dict: + """Converts a Hugging Face Qwen3-Next checkpoint to a MaxText compatible format.""" + num_layers = model_params["num_layers"] + num_experts = model_params["num_experts"] + emb_dim = model_params["emb_dim"] + num_q_heads = model_params["num_q_heads"] + num_kv_heads = model_params["num_kv_heads"] + head_dim = model_params["head_dim"] + inhomogeneous_layer_cycle_interval = model_params["inhomogeneous_layer_cycle_interval"] + cycle = inhomogeneous_layer_cycle_interval + # num_stacked_layers = num_layers_to_convert // cycle + + num_layers_to_convert = args.num_layers_to_convert if args.num_layers_to_convert > 0 else num_layers + num_experts_to_convert = args.num_experts_to_convert if args.num_experts_to_convert > 0 else num_experts + if num_layers_to_convert % cycle != 0: + raise ValueError(f"num_layers_to_convert ({num_layers_to_convert}) must be a multiple of the cycle length ({cycle})") + + # Part 1: Load weights from safetensors + ckpt_paths = sorted(pathlib.Path(base_model_path).glob("model-*-of-*.safetensors")) + chkpt_vars = {} + max_logging.log(f"Loading {len(ckpt_paths)} checkpoint files...") + for i, ckpt_path in enumerate(tqdm(ckpt_paths, desc="Loading HF Checkpoints")): + with safe_open(ckpt_path, framework="pt", device="cpu") as f: + for key in f.keys(): + chkpt_vars[key] = f.get_tensor(key) + gc.collect() + max_logging.log("HF weights loaded.") + + # Part 2: Initialize, populate, and transform weights + maxtext_weights = init_maxtext_weights(model_params, num_layers_to_convert, num_experts_to_convert) + + # Non-layer weights + max_logging.log("Populating non-layer weights...") + if "model.embed_tokens.weight" in chkpt_vars: + # HF: [vocab_size, emb_dim] -> MaxText: [vocab_size, emb_dim] + maxtext_weights["token_embedder"]["embedding"] = to_np_bfloat16(chkpt_vars["model.embed_tokens.weight"]) + if "model.norm.weight" in chkpt_vars: + # HF: [emb_dim] -> MaxText: [emb_dim] + maxtext_weights["decoder"]["decoder_norm"]["scale"] = to_np_bfloat16(chkpt_vars["model.norm.weight"]) + if "lm_head.weight" in chkpt_vars: + # HF: [vocab_size, emb_dim] -> MaxText: [emb_dim, vocab_size] (Transposed) + maxtext_weights["decoder"]["logits_dense"]["kernel"] = to_np_bfloat16(chkpt_vars["lm_head.weight"]).transpose() + + max_logging.log(f"Populating layer weights for {num_layers_to_convert} layers...") + + def _get_hf_tensor(maxtext_key_suffix, hf_map, l, chkpt_vars): + for hf_key, mt_key in hf_map.items(): + if mt_key.endswith(maxtext_key_suffix): + if hf_key in chkpt_vars: + return chkpt_vars[hf_key] + else: + raise ValueError( + f"HF Key {hf_key} not found in 
chkpt_vars for MaxText suffix: {maxtext_key_suffix} in layer {l}" + ) + raise ValueError(f"Could not find HF key for MaxText suffix: {maxtext_key_suffix} in layer {l}") + + for l in tqdm(range(num_layers_to_convert), desc="Processing Layers"): + block_idx = l % cycle + stack_idx = l // cycle + layer_key = f"layer_{block_idx}" + hf_map = hf_to_maxtext_mapping(l, num_experts, cycle) + + get_hf_tensor = partial(_get_hf_tensor, hf_map=hf_map, l=l, chkpt_vars=chkpt_vars) + + ln = maxtext_weights["decoder"]["layers"][layer_key] + + # Layernorms + # HF: [emb_dim] -> slice of MaxText: [emb_dim, num_stacked_layers] + ln["input_layernorm"]["scale"][:, stack_idx] = to_np_bfloat16(get_hf_tensor(".input_layernorm.scale")) + # HF: [emb_dim] -> slice of MaxText: [emb_dim, num_stacked_layers] + ln["post_attention_layernorm"]["scale"][:, stack_idx] = to_np_bfloat16( + get_hf_tensor(".post_attention_layernorm.scale") + ) + + attn_block = ln["attention"] + is_full_attention_layer = (l + 1) % cycle == 0 + if is_full_attention_layer: + attn_params = attn_block["attention"] + # HF: [8192, 2048] -> Transpose [2048, 8192] -> Reshape -> + # slice of MaxText: [emb_dim, num_stacked_layers, num_q_heads, head_dim * 2] + q_kernel = to_np_bfloat16(get_hf_tensor(".attention.attention.query.kernel")).transpose() + attn_params["query"]["kernel"][:, stack_idx, :, :] = q_kernel.reshape(emb_dim, num_q_heads, head_dim * 2) + # HF: [512, 2048] -> Transpose [2048, 512] -> Reshape -> + # slice of MaxText: [emb_dim, num_stacked_layers, num_kv_heads, head_dim] + k_kernel = to_np_bfloat16(get_hf_tensor(".attention.attention.key.kernel")).transpose() + attn_params["key"]["kernel"][:, stack_idx, :, :] = k_kernel.reshape(emb_dim, num_kv_heads, head_dim) + # HF: [512, 2048] -> Transpose [2048, 512] -> Reshape -> + # slice of MaxText: [emb_dim, num_stacked_layers, num_kv_heads, head_dim] + v_kernel = to_np_bfloat16(get_hf_tensor(".attention.attention.value.kernel")).transpose() + attn_params["value"]["kernel"][:, stack_idx, :, :] = v_kernel.reshape(emb_dim, num_kv_heads, head_dim) + # HF: [2048, 4096] -> Transpose -> slice of MaxText: [num_q_heads * head_dim, num_stacked_layers, emb_dim] + attn_params["out"]["kernel"][:, stack_idx, :] = to_np_bfloat16( + get_hf_tensor(".attention.attention.out.kernel") + ).transpose() + # HF: [256] -> slice of MaxText: [head_dim, num_stacked_layers] + attn_params["query_norm"]["scale"][:, stack_idx] = to_np_bfloat16( + get_hf_tensor(".attention.attention.query_norm.scale") + ) + # HF: [256] -> slice of MaxText: [head_dim, num_stacked_layers] + attn_params["key_norm"]["scale"][:, stack_idx] = to_np_bfloat16( + get_hf_tensor(".attention.attention.key_norm.scale") + ) + else: # Gated Delta Net + # HF: [12288, 2048] -> Transpose -> slice of MaxText: [emb_dim, num_stacked_layers, 12288] + attn_block["in_proj_qkvz"]["kernel"][:, stack_idx, :] = to_np_bfloat16( + get_hf_tensor(".attention.in_proj_qkvz.kernel") + ).transpose() + # HF: [64, 2048] -> Transpose -> slice of MaxText: [emb_dim, num_stacked_layers, 64] + attn_block["in_proj_ba"]["kernel"][:, stack_idx, :] = to_np_bfloat16( + get_hf_tensor(".attention.in_proj_ba.kernel") + ).transpose() + # HF: [8192, 1, 4] -> Transpose(2,1,0) -> slice of MaxText: [gdn_conv_kernel_dim, num_stacked_layers, 1, gdn_conv_dim] + conv1d_kernel = to_np_bfloat16(get_hf_tensor(".attention.conv1d.kernel")) + attn_block["conv1d"]["kernel"][:, stack_idx, :, :] = conv1d_kernel.transpose(2, 1, 0) + # HF: [32] -> slice of MaxText: [32, num_stacked_layers] + attn_block["A_log"][:, 
stack_idx] = to_np_bfloat16(get_hf_tensor(".attention.A_log")) + # HF: [32] -> slice of MaxText: [32, num_stacked_layers] + attn_block["dt_bias"][:, stack_idx] = to_np_bfloat16(get_hf_tensor(".attention.dt_bias")) + # HF: [128] -> slice of MaxText: [128, num_stacked_layers] + attn_block["norm"]["rms_norm"]["scale"][:, stack_idx] = to_np_bfloat16( + get_hf_tensor(".attention.norm.rms_norm.scale") + ) + # HF: [2048, 4096] -> Transpose -> slice of MaxText: [4096, num_stacked_layers, 2048] + attn_block["out_proj"]["kernel"][:, stack_idx, :] = to_np_bfloat16( + get_hf_tensor(".attention.out_proj.kernel") + ).transpose() + + # MoE + mlp_block = ln["mlp"] + # HF: [512, 2048] -> Transpose -> slice of MaxText: [emb_dim, num_stacked_layers, 512] + mlp_block["routed_experts"]["gate"]["kernel"][:, stack_idx, :] = to_np_bfloat16( + get_hf_tensor(".mlp.routed_experts.gate.kernel") + ).transpose() + # HF: [512, 2048] -> Transpose -> slice of MaxText: [emb_dim, num_stacked_layers, 512] + mlp_block["shared_expert"]["wi_0"]["kernel"][:, stack_idx, :] = to_np_bfloat16( + get_hf_tensor(".mlp.shared_expert.wi_0.kernel") + ).transpose() + # HF: [512, 2048] -> Transpose -> slice of MaxText: [emb_dim, num_stacked_layers, 512] + mlp_block["shared_expert"]["wi_1"]["kernel"][:, stack_idx, :] = to_np_bfloat16( + get_hf_tensor(".mlp.shared_expert.wi_1.kernel") + ).transpose() + # HF: [2048, 512] -> Transpose -> slice of MaxText: [512, num_stacked_layers, 2048] + mlp_block["shared_expert"]["wo"]["kernel"][:, stack_idx, :] = to_np_bfloat16( + get_hf_tensor(".mlp.shared_expert.wo.kernel") + ).transpose() + # HF: [1, 2048] -> Transpose -> slice of MaxText: [emb_dim, num_stacked_layers, 1] + mlp_block["shared_expert_gate"]["kernel"][:, stack_idx, :] = to_np_bfloat16( + get_hf_tensor(".mlp.shared_expert_gate.kernel") + ).transpose() + + for i in range(num_experts_to_convert): + # HF: [512, 2048] -> Transpose -> slice of MaxText: [num_experts, num_stacked_layers, emb_dim, moe_intermediate_size] + mlp_block["routed_experts"]["wi_0"][i, stack_idx, :, :] = to_np_bfloat16( + get_hf_tensor(f".mlp.routed_experts.{i}.wi_0") + ).transpose() + # HF: [512, 2048] -> Transpose -> slice of MaxText: [num_experts, num_stacked_layers, emb_dim, moe_intermediate_size] + mlp_block["routed_experts"]["wi_1"][i, stack_idx, :, :] = to_np_bfloat16( + get_hf_tensor(f".mlp.routed_experts.{i}.wi_1") + ).transpose() + # HF: [2048, 512] -> Transpose -> slice of MaxText: [num_experts, num_stacked_layers, moe_intermediate_size, emb_dim] + mlp_block["routed_experts"]["wo"][i, stack_idx, :, :] = to_np_bfloat16( + get_hf_tensor(f".mlp.routed_experts.{i}.wo") + ).transpose() + gc.collect() + return maxtext_weights + + +def main(args): + """Main function to run the conversion.""" + os.environ["JAX_PLATFORMS"] = "cpu" + os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={args.simulated_cpu_devices_count}" + + if args.model_size not in MODEL_PARAMS_DICT: + raise ValueError(f"Model size '{args.model_size}' not found in MODEL_PARAMS_DICT.") + + model_params = MODEL_PARAMS_DICT[args.model_size] + max_logging.log(f"Starting conversion for Qwen3-Next model size: {args.model_size}") + jax_weights = convert_hf_to_maxtext(args.base_model_path, model_params, args) + max_logging.log(f"Conversion complete. 
Saving MaxText checkpoint to {args.maxtext_model_path}") + + llama_or_mistral_ckpt.save_weights_to_checkpoint( + args.maxtext_model_path, jax_weights, args.simulated_cpu_devices_count, args.use_ocdbt, args.use_zarr3 + ) + max_logging.log("Checkpoint saved successfully.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert Qwen3-Next HF weights to MaxText.") + parser.add_argument("--base_model_path", type=str, required=True, help="Path to the HF Qwen3-Next checkpoint files.") + parser.add_argument( + "--maxtext_model_path", type=str, required=True, help="Path to save the MaxText checkpoint (local or GCS)." + ) + parser.add_argument( + "--model_size", type=str, required=True, choices=MODEL_PARAMS_DICT.keys(), help="The model size to convert." + ) + + # Dry run options + parser.add_argument( + "--num_layers_to_convert", type=int, default=-1, help="Number of layers to convert for a dry run. -1 for all." + ) + parser.add_argument( + "--num_experts_to_convert", type=int, default=-1, help="Number of experts to convert for a dry run. -1 for all." + ) + + # Saving options + parser.add_argument( + "--simulated_cpu_devices_count", type=int, default=16, help="Number of simulated CPU devices for saving." + ) + parser.add_argument("--use_ocdbt", type=str2bool, default=True, help="Use OCDBT format for saving.") + parser.add_argument("--use_zarr3", type=str2bool, default=True, help="Use Zarr3 format for saving.") + + parsed_args = parser.parse_args() + main(parsed_args) diff --git a/src/MaxText/utils/ckpt_scripts/convert_qwen3_next_unscanned.py b/src/MaxText/utils/ckpt_scripts/convert_qwen3_next_unscanned.py new file mode 100644 index 000000000..b4b595612 --- /dev/null +++ b/src/MaxText/utils/ckpt_scripts/convert_qwen3_next_unscanned.py @@ -0,0 +1,449 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Convert weights from a Qwen3 Next model to a MaxText one in unscanned orbax format. 
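+Unlike convert_qwen3_next_scanned.py, every decoder layer is written out as its own
+layers_{i} subtree rather than being stacked along a scan axis, so the result matches
+MaxText's unscanned parameter layout.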
+ +Example cmd: + +python3 -m MaxText.utils.ckpt_scripts.convert_qwen3_next_unscanned --base-model-path \ + --maxtext-model-path --model-size qwen3-next-80b-a3b +""" + +# pylint: disable=g-line-too-long +import argparse +import gc +import logging +import os +import pathlib + +os.environ["JAX_PLATFORMS"] = "cpu" + +import ml_dtypes +import psutil +import numpy as np +from safetensors import safe_open +import torch +from tqdm import tqdm +from typing import Any, Dict + +from MaxText import max_logging +from MaxText.inference_utils import str2bool +from MaxText.utils.ckpt_scripts.llama_or_mistral_ckpt import save_weights_to_checkpoint + + +# NOTE: numpy doesn't have native support for bfloat16, so +# we'll use ml_dtypes instead (which is quasi native) +# NOTE: it's incredibly silly but you can't directly cast from +# a torch tensor of type bfloat16 to a numpy array of type bfloat16 +# so we have to cast to float32 first +CAST_DTYPE = ml_dtypes.bfloat16 + + +def _pt_to_np(pt_weight, cast_dtype=None, transpose=False): + if cast_dtype: + np_weight = pt_weight.to(torch.float32).numpy().astype(cast_dtype) + else: + np_weight = pt_weight.to(torch.float32).numpy() + if transpose: + np_weight = np_weight.transpose() + return np_weight + + +MODEL_PARAMS_DICT = { + "qwen3-next-80b-a3b": { + "num_hidden_layers": 48, + "num_blocks": 12, # 48 layers / 4 types + "num_layers_per_block": 4, + "hidden_size": 2048, + # MoE Params + "num_experts": 512, + "num_shared_experts": 1, + "moe_intermediate_size": 512, + # Gated Attention (GA) params (layer_3) + "head_dim": 256, + "ga_num_q_heads": 16, + "ga_num_kv_heads": 2, + "ga_o_proj_input_dim": 4096, + # Gated DeltaNet (GDN) params (layers_0, _1, _2) + "gdn_num_value_heads": 32, + "gdn_num_key_heads": 16, + "gdn_chunk_size": 64, + "inhomogeneous_layer_cycle_interval": 4, + "gdn_conv_kernel_dim": 4, + "gdn_key_head_dim": 128, + "gdn_value_head_dim": 128, + "gdn_a_log_dim": 32, + "gdn_conv_features": 8192, + "gdn_in_proj_ba_dim": 64, + "gdn_in_proj_qkvz_dim": 12288, + "gdn_norm_dim": 128, + "gdn_out_proj_input_dim": 4096, + }, +} + + +# These will host the simple 1 to 1 mappings +def _hf_to_maxtext_mapping(layer_idx: int = -1, num_experts: int = 512) -> dict: + """ + Returns a mapping from HuggingFace model weight names to MaxText model weight names. + + Args: + layer_idx (int): Layer index. + + Returns: + dict [str, str]: Mapping from HuggingFace model weight names to MaxText model weight names. 
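+
+  Layers with (layer_idx % 4 == 3) use the full gated-attention mapping; all other
+  layers use the Gated DeltaNet (linear attention) mapping.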
+ """ + # pylint: disable=line-too-long + mapping = { + "model.embed_tokens.weight": "token_embedder.embedding", + "model.norm.weight": "decoder.decoder_norm.scale", + "lm_head.weight": "decoder.logits_dense.kernel", + # moe + # shared + f"model.layers.{layer_idx}.mlp.shared_expert.down_proj.weight": f"decoder.layers_{layer_idx}.mlp.shared_expert.wo.kernel", + f"model.layers.{layer_idx}.mlp.shared_expert.gate_proj.weight": f"decoder.layers_{layer_idx}.mlp.shared_expert.wi_0.kernel", + f"model.layers.{layer_idx}.mlp.shared_expert.up_proj.weight": f"decoder.layers_{layer_idx}.mlp.shared_expert.wi_1.kernel", + f"model.layers.{layer_idx}.mlp.shared_expert_gate.weight": f"decoder.layers_{layer_idx}.mlp.shared_expert_gate.kernel", + # routed + f"model.layers.{layer_idx}.mlp.gate.weight": f"decoder.layers_{layer_idx}.mlp.routed_experts.gate.kernel", + } + + # Gated attention + if layer_idx % 4 == 3: + # Gated Attention (GA) Layer + mapping.update( + { + f"model.layers.{layer_idx}.self_attn.k_norm.weight": ( + f"decoder.layers_{layer_idx}.attention.attention.key_norm.scale" + ), + f"model.layers.{layer_idx}.self_attn.k_proj.weight": ( + f"decoder.layers_{layer_idx}.attention.attention.key.kernel" + ), + f"model.layers.{layer_idx}.self_attn.o_proj.weight": ( + f"decoder.layers_{layer_idx}.attention.attention.out.kernel" + ), + f"model.layers.{layer_idx}.self_attn.q_norm.weight": ( + f"decoder.layers_{layer_idx}.attention.attention.query_norm.scale" + ), + f"model.layers.{layer_idx}.self_attn.q_proj.weight": ( + f"decoder.layers_{layer_idx}.attention.attention.query.kernel" + ), + f"model.layers.{layer_idx}.self_attn.v_proj.weight": ( + f"decoder.layers_{layer_idx}.attention.attention.value.kernel" + ), + } + ) + else: + # Gated DeltaNet (GDN) Layer | Linear Attention + mapping.update( + { + f"model.layers.{layer_idx}.linear_attn.A_log": f"decoder.layers_{layer_idx}.attention.A_log", + f"model.layers.{layer_idx}.linear_attn.conv1d.weight": f"decoder.layers_{layer_idx}.attention.conv1d.kernel", + f"model.layers.{layer_idx}.linear_attn.dt_bias": f"decoder.layers_{layer_idx}.attention.dt_bias", + f"model.layers.{layer_idx}.linear_attn.in_proj_ba.weight": f"decoder.layers_{layer_idx}.attention.in_proj_ba.kernel", + f"model.layers.{layer_idx}.linear_attn.in_proj_qkvz.weight": f"decoder.layers_{layer_idx}.attention.in_proj_qkvz.kernel", + f"model.layers.{layer_idx}.linear_attn.norm.weight": f"decoder.layers_{layer_idx}.attention.norm.rms_norm.scale", + f"model.layers.{layer_idx}.linear_attn.out_proj.weight": f"decoder.layers_{layer_idx}.attention.out_proj.kernel", + } + ) + return mapping + + +def create_unscanned_layer_pytree(layer_idx) -> Dict[str, Any]: + """Creates the nested dictionary for one scanned layer.""" + if layer_idx % 4 == 3: + return { + # Common + "input_layernorm": {"scale": None}, + "post_attention_layernorm": {"scale": None}, + # MoE + "mlp": { + "shared_expert": { + "wi_0": {"kernel": None}, + "wi_1": {"kernel": None}, + "wo": {"kernel": None}, + }, + "shared_expert_gate": {"kernel": None}, + "routed_experts": { + "gate": {"kernel": None}, + "wi_0": None, + "wi_1": None, + "wo": None, + }, + }, + # Attention (will hold both GA and GDN params) + "attention": { + "attention": { + "query": {"kernel": None}, + "key": {"kernel": None}, + "value": {"kernel": None}, + "out": {"kernel": None}, + "query_norm": {"scale": None}, + "key_norm": {"scale": None}, + }, + }, + } + else: + return { + # Common + "input_layernorm": {"scale": None}, + "post_attention_layernorm": {"scale": None}, + # MoE + 
"mlp": { + "shared_expert": { + "wi_0": {"kernel": None}, + "wi_1": {"kernel": None}, + "wo": {"kernel": None}, + }, + "shared_expert_gate": {"kernel": None}, + "routed_experts": { + "gate": {"kernel": None}, + "wi_0": None, + "wi_1": None, + "wo": None, + }, + }, + # Attention (will hold both GA and GDN params) + "attention": { + # GDN Params + "A_log": None, + "conv1d": {"kernel": None}, + "dt_bias": None, + "in_proj_ba": {"kernel": None}, + "in_proj_qkvz": {"kernel": None}, + "norm": {"rms_norm": {"scale": None}}, + "out_proj": {"kernel": None}, + }, + } + + +def _convert_huggingface_to_jax_weights( + base_model_path: str, model_size: str, model_params: dict, mem_info: psutil.Process +): + """Convert a Huggingface Checkpoint to a dictionary of Numpy arrays representing the weights. + + Args: + base_model_path (str): Path to the base model checkpoint. + model_size (str): Size of the base model. + model_params (dict): Dictionary containing model parameters. + mem_info (psutil.Process): Process object to track memory usage. + + Returns: + jax_weights (dict): Dictionary containing the converted weights. + """ + # Load all params from config + num_layers = model_params["num_hidden_layers"] + hidden_size = model_params["hidden_size"] + num_experts = model_params["num_experts"] + ga_num_q_heads = model_params["ga_num_q_heads"] + head_dim = model_params["head_dim"] + ga_num_kv_heads = model_params["ga_num_kv_heads"] + + # load model + max_logging.log(f"Loading the base model from {base_model_path}") + ckpt_paths = sorted(pathlib.Path(base_model_path).glob("*.safetensors")) + chkpt_vars = {} + + for i, ckpt_path in enumerate(ckpt_paths): + max_logging.log(f"Loading checkpoint {i+1} of {len(ckpt_paths)} ...") + with safe_open(ckpt_path, framework="pt", device="cpu") as f: + for key in f.keys(): + if key.startswith("model.") or key.startswith("lm_head."): + chkpt_vars[key] = f.get_tensor(key) + + logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3)) + + # Part 2: Initialize the nested MaxText weights dictionary + jax_weights = { + "token_embedder": {"embedding": None}, + "decoder": { + "decoder_norm": {"scale": None}, + "logits_dense": {"kernel": None}, + }, + } + for l in range(num_layers): + jax_weights["decoder"][f"layers_{l}"] = create_unscanned_layer_pytree(l) + + # Part 3: Populate weights + # Non-layer weights + max_logging.log("Populating non-layer weights...") + jax_weights["decoder"]["decoder_norm"]["scale"] = _pt_to_np(chkpt_vars["model.norm.weight"], cast_dtype=CAST_DTYPE) + jax_weights["token_embedder"]["embedding"] = _pt_to_np(chkpt_vars["model.embed_tokens.weight"], cast_dtype=CAST_DTYPE) + jax_weights["decoder"]["logits_dense"]["kernel"] = _pt_to_np( + chkpt_vars["lm_head.weight"], cast_dtype=CAST_DTYPE + ).transpose() + logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3)) + + # Linear + Gated Attention layers + max_logging.log("Processing linear & gated attention layers") + for l in tqdm(range(num_layers), desc="layers", leave=False): + if l % 4 == 3: + gated_attn = jax_weights["decoder"][f"layers_{l}"]["attention"]["attention"] + + k_kernel = ( + _pt_to_np(chkpt_vars[f"model.layers.{l}.self_attn.k_proj.weight"], cast_dtype=CAST_DTYPE) + .transpose() + .reshape(hidden_size, ga_num_kv_heads, head_dim) + ) + k_norm = _pt_to_np(chkpt_vars[f"model.layers.{l}.self_attn.k_norm.weight"], cast_dtype=CAST_DTYPE) + o_kernel = _pt_to_np(chkpt_vars[f"model.layers.{l}.self_attn.o_proj.weight"], cast_dtype=CAST_DTYPE).transpose() + q_kernel = ( + 
_pt_to_np(chkpt_vars[f"model.layers.{l}.self_attn.q_proj.weight"], cast_dtype=CAST_DTYPE) + .transpose() + .reshape(hidden_size, ga_num_q_heads, head_dim * 2) + ) + q_norm = _pt_to_np(chkpt_vars[f"model.layers.{l}.self_attn.q_norm.weight"], cast_dtype=CAST_DTYPE) + v_kernel = ( + _pt_to_np(chkpt_vars[f"model.layers.{l}.self_attn.v_proj.weight"], cast_dtype=CAST_DTYPE) + .transpose() + .reshape(hidden_size, ga_num_kv_heads, head_dim) + ) + + gated_attn["key"]["kernel"] = k_kernel + gated_attn["key_norm"]["scale"] = k_norm + gated_attn["out"]["kernel"] = o_kernel + gated_attn["query"]["kernel"] = q_kernel + gated_attn["query_norm"]["scale"] = q_norm + gated_attn["value"]["kernel"] = v_kernel + else: + lin_attn = jax_weights["decoder"][f"layers_{l}"]["attention"] + + a_log = _pt_to_np(chkpt_vars[f"model.layers.{l}.linear_attn.A_log"], cast_dtype=CAST_DTYPE) + conv1d_kernel = _pt_to_np( + chkpt_vars[f"model.layers.{l}.linear_attn.conv1d.weight"], cast_dtype=CAST_DTYPE + ).transpose(2, 1, 0) + dt_bias = _pt_to_np(chkpt_vars[f"model.layers.{l}.linear_attn.dt_bias"], cast_dtype=CAST_DTYPE) + ba_kernel = _pt_to_np( + chkpt_vars[f"model.layers.{l}.linear_attn.in_proj_ba.weight"], cast_dtype=CAST_DTYPE + ).transpose() + qkvz_kernel = _pt_to_np( + chkpt_vars[f"model.layers.{l}.linear_attn.in_proj_qkvz.weight"], cast_dtype=CAST_DTYPE + ).transpose() + gated_rms_norm = _pt_to_np(chkpt_vars[f"model.layers.{l}.linear_attn.norm.weight"], cast_dtype=CAST_DTYPE) + o_kernel = _pt_to_np(chkpt_vars[f"model.layers.{l}.linear_attn.out_proj.weight"], cast_dtype=CAST_DTYPE).transpose() + + lin_attn["A_log"] = a_log + lin_attn["conv1d"]["kernel"] = conv1d_kernel + lin_attn["dt_bias"] = dt_bias + lin_attn["in_proj_ba"]["kernel"] = ba_kernel + lin_attn["in_proj_qkvz"]["kernel"] = qkvz_kernel + lin_attn["norm"]["rms_norm"]["scale"] = gated_rms_norm + lin_attn["out_proj"]["kernel"] = o_kernel + logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3)) + + # layer weight pre and post self attention norm + max_logging.log("Processing pre and post self attention norms") + for l in tqdm(range(num_layers), desc="layers", leave=False): + layer_weight = jax_weights["decoder"][f"layers_{l}"] + + input_layernorm = _pt_to_np(chkpt_vars[f"model.layers.{l}.input_layernorm.weight"], cast_dtype=CAST_DTYPE) + post_attention_layernorm = _pt_to_np( + chkpt_vars[f"model.layers.{l}.post_attention_layernorm.weight"], cast_dtype=CAST_DTYPE + ) + + layer_weight["input_layernorm"]["scale"] = input_layernorm + layer_weight["post_attention_layernorm"]["scale"] = post_attention_layernorm + logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3)) + + # mlp weights + max_logging.log("Processing mlp layer weights") + for l in tqdm(range(num_layers), desc="layers", leave=False): + mlp_weights = jax_weights["decoder"][f"layers_{l}"]["mlp"] + + shared_wi_0 = _pt_to_np( + chkpt_vars[f"model.layers.{l}.mlp.shared_expert.gate_proj.weight"], cast_dtype=CAST_DTYPE + ).transpose() + shared_wi_1 = _pt_to_np( + chkpt_vars[f"model.layers.{l}.mlp.shared_expert.up_proj.weight"], cast_dtype=CAST_DTYPE + ).transpose() + shared_wo = _pt_to_np( + chkpt_vars[f"model.layers.{l}.mlp.shared_expert.down_proj.weight"], cast_dtype=CAST_DTYPE + ).transpose() + shared_gate_kernel = _pt_to_np( + chkpt_vars[f"model.layers.{l}.mlp.shared_expert_gate.weight"], cast_dtype=CAST_DTYPE + ).transpose() + + mlp_weights["shared_expert_gate"]["kernel"] = shared_gate_kernel + mlp_weights["shared_expert"]["wi_0"]["kernel"] = shared_wi_0 + 
mlp_weights["shared_expert"]["wi_1"]["kernel"] = shared_wi_1 + mlp_weights["shared_expert"]["wo"]["kernel"] = shared_wo + + wi_0_list = [] + wi_1_list = [] + wo_list = [] + routed_gate_kernel = _pt_to_np(chkpt_vars[f"model.layers.{l}.mlp.gate.weight"], cast_dtype=CAST_DTYPE).transpose() + for i in range(num_experts): + wi_0_list.append( + _pt_to_np(chkpt_vars[f"model.layers.{l}.mlp.experts.{i}.gate_proj.weight"], cast_dtype=CAST_DTYPE).transpose() + ) + wi_1_list.append( + _pt_to_np(chkpt_vars[f"model.layers.{l}.mlp.experts.{i}.up_proj.weight"], cast_dtype=CAST_DTYPE).transpose() + ) + wo_list.append( + _pt_to_np(chkpt_vars[f"model.layers.{l}.mlp.experts.{i}.down_proj.weight"], cast_dtype=CAST_DTYPE).transpose() + ) + + mlp_weights["routed_experts"]["gate"]["kernel"] = routed_gate_kernel + mlp_weights["routed_experts"]["wi_0"] = np.stack(wi_0_list, axis=0) + mlp_weights["routed_experts"]["wi_1"] = np.stack(wi_1_list, axis=0) + mlp_weights["routed_experts"]["wo"] = np.stack(wo_list, axis=0) + + gc.collect() + logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3)) + + del chkpt_vars + gc.collect() + return jax_weights + + +def convert_to_jax_weights(base_model_path: str, model_size: str): + """ + Function to convert the checkpoint at base_model_path into Orbax checkpoint + for MaxText and output jax_weights ready for MaxText + + Attributes: + base_model_path: checkpoint path + model_size: gpt-oss-20b, gpt-oss-120b + """ + model_params = MODEL_PARAMS_DICT[model_size] + mem_info = psutil.Process() + logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3)) + max_logging.log(f"Loading the base model from {base_model_path}") + return _convert_huggingface_to_jax_weights(base_model_path, model_size, model_params, mem_info) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--base-model-path", type=str, required=True) + parser.add_argument("--maxtext-model-path", type=str, required=True) + parser.add_argument("--model-size", type=str, required=True) + parser.add_argument("--simulated-cpu-devices-count", type=int, required=False, default=16) + parser.add_argument("--use-ocdbt", type=str2bool, required=False, default=True) + parser.add_argument("--use-zarr3", type=str2bool, required=False, default=True) + args = parser.parse_args() + + if args.model_size not in MODEL_PARAMS_DICT: + raise NotImplementedError(f"Model '{args.model_size}' is not supported.") + + os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={args.simulated_cpu_devices_count}" + base_weights_path = args.maxtext_model_path + + save_weights_to_checkpoint( + args.maxtext_model_path, + convert_to_jax_weights(args.base_model_path, args.model_size), + args.simulated_cpu_devices_count, + args.use_ocdbt, + args.use_zarr3, + ) + max_logging.log(f"Successfully saved base_weights to {base_weights_path}.")