4 changes: 3 additions & 1 deletion src/MaxText/integration/tunix/utils.py
@@ -130,7 +130,9 @@ def to_hf_mapping(self):
return STANDALONE_VLLM_WEIGHT_MAPPING[self.model_name].to_hf_mapping()

config = self.config
mapping = self.convert_hf_map_to_sharding_map(PARAM_MAPPING[self.model_name](config, scan_layers=True))
mapping = self.convert_hf_map_to_sharding_map(
PARAM_MAPPING[self.model_name](config, maxtext_config=None, scan_layers=True)
)
return mapping

def to_hf_transpose_keys(self):
108 changes: 94 additions & 14 deletions src/MaxText/utils/ckpt_conversion/to_huggingface.py
@@ -53,9 +53,10 @@

import jax
import os
from typing import Sequence, Any
from typing import Sequence
import time
from tqdm import tqdm
import numpy as np

from transformers import AutoTokenizer, AutoProcessor

@@ -71,20 +72,23 @@
)
from MaxText.utils.ckpt_conversion.utils.hf_shape import HF_SHAPE
from MaxText.utils.ckpt_conversion.utils.hf_model_configs import HF_MODEL_CONFIGS
from MaxText.utils.ckpt_conversion.utils.utils import (process_leaf_param, save_model_files, HF_IDS)
from MaxText.utils.ckpt_conversion.utils.utils import process_maxtext_param, save_model_files, HF_IDS


os.environ["JAX_PLATFORMS"] = "cpu"
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=16"


def _get_model_mappings(model_name: str, scan_layers: bool, config_dict: dict):
def _get_model_mappings(
model_name: str, scan_layers: bool, hf_config_dict: dict, maxtext_config: pyconfig.HyperParameters
):
"""Retrieves parameter, shape, and hook function mappings for the model.

Args:
model_name: The name of the model (e.g., "gemma2-2b").
scan_layers: Boolean indicating if the model was trained with scanned layers.
config_dict: The Hugging Face model configuration dictionary.
hf_config_dict: The Hugging Face model configuration dictionary.
maxtext_config: The maxtext model configuration.

Returns:
A dictionary containing the parameter mapping, shape mapping, and hook
@@ -97,12 +101,65 @@ def _get_model_mappings(model_name: str, scan_layers: bool, config_dict: dict):
raise ValueError(f"Mappings not found for model: {model_name}. Available PARAM_MAPPING keys: {PARAM_MAPPING.keys()}")

return {
"param_mapping": PARAM_MAPPING[model_name](config_dict, scan_layers),
"shape_mapping": HF_SHAPE[model_name](config_dict),
"hook_fn_mapping": HOOK_FNS[model_name](config_dict, scan_layers, saving_to_hf=True),
"param_mapping": PARAM_MAPPING[model_name](hf_config_dict, maxtext_config, scan_layers),
"shape_mapping": HF_SHAPE[model_name](hf_config_dict),
"hook_fn_mapping": HOOK_FNS[model_name](hf_config_dict, maxtext_config, scan_layers, saving_to_hf=True),
}


def _check_param_map_keys(param_map_keys, maxtext_state_keys):
"""Validates map coverage, handles N-to-1 mappings, and filters unused keys.

Ensures every MaxText checkpoint key (`maxtext_state_keys`) is covered by
the flattened parameter map. Keys in the map that are not present in the
checkpoint (common for multi-variant maps like gemma3, qwen3, deepseek) are skipped.

Tuple keys represent N-to-1 mappings (multiple MaxText keys combining into one
target key) and are only returned if all constituent keys exist in the checkpoint.

Args:
param_map_keys: Keys from the parameter mapping (strings or N-to-1 tuples).
maxtext_state_keys: Set of parameter keys loaded from the MaxText checkpoint.

Returns:
A list of 'filtered' mapping keys (strings or tuples) that are fully present
and valid based on `maxtext_state_keys`.

Raises:
ValueError: If `maxtext_state_keys` is NOT a subset of the flattened
`param_map_keys`.
"""
flattened_map_keys = set()
for key in param_map_keys:
if isinstance(key, tuple):
flattened_map_keys.update(key)
else:
flattened_map_keys.add(key)

# every maxtext state key must be covered by param map
missing_keys = maxtext_state_keys - flattened_map_keys
if missing_keys:
raise ValueError(
"maxtext_state_dict must be a subset of flattened param_map"
+ f"\nparam map\n{param_map_keys}"
+ f"\nmaxtext:\n{maxtext_state_keys}"
)

# param map may have extra keys
extra_keys = flattened_map_keys - maxtext_state_keys
if extra_keys:
max_logging.log(f"Warning: extra keys in param_map are skipped: {extra_keys}")

# skip extra keys in param map
filtered_map_keys = []
for key in param_map_keys:
if (isinstance(key, str) and key in maxtext_state_keys) or (
isinstance(key, tuple) and all(k in maxtext_state_keys for k in key)
):
filtered_map_keys.append(key)
return filtered_map_keys


def main(argv: Sequence[str]) -> None:
"""Main function to convert a MaxText checkpoint to HuggingFace format.

@@ -156,29 +213,52 @@ def main(argv: Sequence[str]) -> None:
processor = AutoProcessor.from_pretrained(hf_tokenizer_id, token=hf_token) if config.use_multimodal else None

# 3. Get parameter mappings
mappings = _get_model_mappings(model_key, config.scan_layers, hf_config_obj.to_dict())
mappings = _get_model_mappings(model_key, config.scan_layers, hf_config_obj.to_dict(), config)
param_map = mappings["param_mapping"]
shape_map = mappings["shape_mapping"] # HF target shapes
hook_fn_map = mappings["hook_fn_mapping"]

# 4. Transform Weights
transformed_hf_weights: dict[str, Any] = {}

# MaxText `engine.load_params()` returns `state.params` (a FrozenDict).
# The actual weights are typically under `state.params['params']`.
actual_weights_dict = loaded_params_from_engine.get("params")
if actual_weights_dict is None:
raise ValueError("Loaded parameters from engine do not contain a 'params' key. Structure might be unexpected.")

leaves_with_paths = jax.tree_util.tree_leaves_with_path(actual_weights_dict)

# Traverse the leaves to build maxtext_state_dict: {parameter name: parameter weight}
maxtext_state_dict = {}
for path_tuple, leaf_value in leaves_with_paths:
# Construct maxtext_param_key from path_tuple
maxtext_param_key = "params-" + "-".join(k.key for k in path_tuple)
# Check leaf value is an array
if not isinstance(leaf_value, (jax.Array, np.ndarray)):
raise ValueError(f"Leaf value for {maxtext_param_key} is not an array. Type: {type(leaf_value)}.")
maxtext_state_dict[maxtext_param_key] = leaf_value

# The param_map may contain tuples as keys, which represent N-to-1 mappings from maxtext to huggingface
# Check maxtext_state_dict is a subset of flattened param_map
# Skip extra keys from param_map
filtered_map_keys = _check_param_map_keys(param_map.keys(), maxtext_state_dict.keys())

# Iterate through the parameter map to transform and collect weights.
# This loop handles both simple 1-to-1 mappings and complex N-to-1 mappings
# (where multiple MaxText weights are combined into a single HF weight).
max_logging.log("\nProccessing weight...")
start = time.time()
processed_params_list = []
for path_tuple_iter, leaf_value_iter in tqdm(leaves_with_paths, total=len(leaves_with_paths)):
processed_params = process_leaf_param(path_tuple_iter, leaf_value_iter, param_map, shape_map, hook_fn_map, config)

for key in tqdm(filtered_map_keys, total=len(filtered_map_keys)):
if isinstance(key, tuple):
# if key is tuple of param names, weight is list of param weights
weight = [maxtext_state_dict[subkey] for subkey in key]
else:
# if key is single param name, weight is single param weight
weight = maxtext_state_dict[key]

processed_params = process_maxtext_param(key, weight, param_map, hook_fn_map, shape_map, config)
processed_params_list.extend(processed_params)

transformed_hf_weights = dict(processed_params_list)
max_logging.log(f"Elapse: {(time.time() - start) / 60:.2f} min")

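The new `_check_param_map_keys` helper above is the core of this refactor: it verifies that every checkpoint key is covered by the parameter map, keeps tuple keys (N-to-1 mappings) only when all constituent keys exist, and drops map keys that belong to other model variants. A minimal, self-contained sketch of that filtering behavior, using made-up parameter names rather than keys from any real checkpoint:

# Toy illustration only (not part of this PR); key names below are hypothetical.
maxtext_state_keys = {
    "params-decoder-layers-mlp-wi_0-kernel",
    "params-decoder-layers-mlp-wi_1-kernel",
}
param_map_keys = [
    # N-to-1 mapping: two MaxText tensors combine into one HF tensor,
    # so the tuple is kept only if every constituent key is present.
    ("params-decoder-layers-mlp-wi_0-kernel", "params-decoder-layers-mlp-wi_1-kernel"),
    # Key for a variant feature absent from this checkpoint; it is skipped.
    "params-vision_encoder-patch_embedding-kernel",
]

filtered = []
for key in param_map_keys:
    constituents = key if isinstance(key, tuple) else (key,)
    if all(k in maxtext_state_keys for k in constituents):
        filtered.append(key)

# Only the tuple survives; the real helper also logs a warning for the
# skipped key and raises if a checkpoint key has no mapping at all.
print(filtered)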
4 changes: 2 additions & 2 deletions src/MaxText/utils/ckpt_conversion/to_maxtext.py
@@ -466,11 +466,11 @@ def main(args: Sequence[str], test_args: Sequence[str]) -> None:
# f"model.layers.{global_layer_idx}.input_layernorm.weight",

model_key = config.model_name
param_map_mt_to_hf = PARAM_MAPPING[model_key](hf_config_obj.to_dict(), config.scan_layers)
param_map_mt_to_hf = PARAM_MAPPING[model_key](hf_config_obj.to_dict(), config, config.scan_layers)

# Example of Hook FN mapping, to perform reshape:
# f"params-decoder-layers_{maxtext_layer_idx}-self_attention_global-key-kernel": reshape_kernel,
hook_fn_map_mt = HOOK_FNS[model_key](hf_config_obj.to_dict(), config.scan_layers, saving_to_hf=False)
hook_fn_map_mt = HOOK_FNS[model_key](hf_config_obj.to_dict(), config, config.scan_layers, saving_to_hf=False)
max_logging.log("Parameter mappings and hooks obtained.")

max_logging.log("Starting weight transformation...")
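Both call sites above now pass the MaxText config through to the mapping and hook factories. A rough sketch of a factory conforming to the new `(hf_config_dict, maxtext_config, scan_layers)` calling convention; the model layout and key names here are illustrative assumptions, not taken from the repo:

# Hypothetical mapping factory under the new signature; not part of the PR.
def example_param_mapping(hf_config_dict: dict, maxtext_config, scan_layers: bool) -> dict:
  """Returns {MaxText param key: HF param key(s)} for a made-up dense model."""
  # maxtext_config is available for models whose layout depends on MaxText
  # settings; it is unused in this toy example.
  n_layers = hf_config_dict["num_hidden_layers"]
  if scan_layers:
    # Scanned layers: one stacked MaxText tensor maps to a list of per-layer HF tensors.
    return {
        "params-decoder-layers-mlp-wo-kernel": [
            f"model.layers.{i}.mlp.down_proj.weight" for i in range(n_layers)
        ],
    }
  # Unscanned layers: one MaxText key per layer, mirroring the commented example above.
  return {
      f"params-decoder-layers_{i}-mlp-wo-kernel": f"model.layers.{i}.mlp.down_proj.weight"
      for i in range(n_layers)
  }

print(example_param_mapping({"num_hidden_layers": 2}, None, scan_layers=True))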
157 changes: 156 additions & 1 deletion src/MaxText/utils/ckpt_conversion/utils/hf_model_configs.py
@@ -469,6 +469,8 @@
vocab_size=151936,
)

# Copied from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/config.json,
# with the fp8 quantization_config removed since we use bf16.
deepseek3_671b_dict = {
"architectures": ["DeepseekV3ForCausalLM"],
"attention_bias": False,
@@ -527,7 +529,157 @@
}
deepseek3_671b_config = transformers.DeepseekV3Config(**deepseek3_671b_dict)

# {maxtext model name: hf model config}
# Copied from https://huggingface.co/openai/gpt-oss-20b/blob/main/config.json,
# with the mxfp4 quantization_config removed since we use bf16.
gpt_oss_20b_dict = {
"architectures": ["GptOssForCausalLM"],
"attention_bias": True,
"attention_dropout": 0.0,
"eos_token_id": 200002,
"experts_per_token": 4,
"head_dim": 64,
"hidden_act": "silu",
"hidden_size": 2880,
"initial_context_length": 4096,
"initializer_range": 0.02,
"intermediate_size": 2880,
"layer_types": [
"sliding_attention",
"full_attention",
"sliding_attention",
"full_attention",
"sliding_attention",
"full_attention",
"sliding_attention",
"full_attention",
"sliding_attention",
"full_attention",
"sliding_attention",
"full_attention",
"sliding_attention",
"full_attention",
"sliding_attention",
"full_attention",
"sliding_attention",
"full_attention",
"sliding_attention",
"full_attention",
"sliding_attention",
"full_attention",
"sliding_attention",
"full_attention",
],
"max_position_embeddings": 131072,
"model_type": "gpt_oss",
"num_attention_heads": 64,
"num_experts_per_tok": 4,
"num_hidden_layers": 24,
"num_key_value_heads": 8,
"num_local_experts": 32,
"output_router_logits": False,
"pad_token_id": 199999,
"rms_norm_eps": 1e-05,
"rope_scaling": {
"beta_fast": 32.0,
"beta_slow": 1.0,
"factor": 32.0,
"original_max_position_embeddings": 4096,
"rope_type": "yarn",
"truncate": False,
},
"rope_theta": 150000,
"router_aux_loss_coef": 0.9,
"sliding_window": 128,
"swiglu_limit": 7.0,
"tie_word_embeddings": False,
"transformers_version": "4.55.0.dev0",
"use_cache": True,
"vocab_size": 201088,
}
gpt_oss_20b_config = transformers.GptOssConfig(**gpt_oss_20b_dict)

# Copied from https://huggingface.co/openai/gpt-oss-120b/blob/main/config.json,
# with the mxfp4 quantization_config removed since we use bf16.
gpt_oss_120b_dict = {
"architectures": ["GptOssForCausalLM"],
"attention_bias": True,
"attention_dropout": 0.0,
"eos_token_id": 200002,
"experts_per_token": 4,
"head_dim": 64,
"hidden_act": "silu",
"hidden_size": 2880,
"initial_context_length": 4096,
"initializer_range": 0.02,
"intermediate_size": 2880,
"layer_types": [
"sliding_attention",
"full_attention",
"sliding_attention",
"full_attention",
"sliding_attention",
"full_attention",
"sliding_attention",
"full_attention",
"sliding_attention",
"full_attention",
"sliding_attention",
"full_attention",
"sliding_attention",
"full_attention",
"sliding_attention",
"full_attention",
"sliding_attention",
"full_attention",
"sliding_attention",
"full_attention",
"sliding_attention",
"full_attention",
"sliding_attention",
"full_attention",
"sliding_attention",
"full_attention",
"sliding_attention",
"full_attention",
"sliding_attention",
"full_attention",
"sliding_attention",
"full_attention",
"sliding_attention",
"full_attention",
"sliding_attention",
"full_attention",
],
"max_position_embeddings": 131072,
"model_type": "gpt_oss",
"num_attention_heads": 64,
"num_experts_per_tok": 4,
"num_hidden_layers": 36,
"num_key_value_heads": 8,
"num_local_experts": 128,
"output_router_logits": False,
"pad_token_id": 199999,
"rms_norm_eps": 1e-05,
"rope_scaling": {
"beta_fast": 32.0,
"beta_slow": 1.0,
"factor": 32.0,
"original_max_position_embeddings": 4096,
"rope_type": "yarn",
"truncate": False,
},
"rope_theta": 150000,
"router_aux_loss_coef": 0.9,
"sliding_window": 128,
"swiglu_limit": 7.0,
"tie_word_embeddings": False,
"transformers_version": "4.55.0.dev0",
"use_cache": True,
"vocab_size": 201088,
}
gpt_oss_120b_config = transformers.GptOssConfig(**gpt_oss_120b_dict)


qwen3_omni_30b_a3b_config = transformers.Qwen3OmniMoeConfig(
# TODO(hengtaoguo): Pure-text Omni model, need to fill in visual/audio/code2wav parts
architectures=["Qwen3OmniMoeForConditionalGeneration"],
@@ -539,6 +691,7 @@
},
)

# {maxtext model name: hf model config}
HF_MODEL_CONFIGS = {
"gemma2-2b": gemma2_2b_config,
"gemma2-9b": gemma2_9b_config,
@@ -560,5 +713,7 @@
"qwen3-235b-a22b": qwen3_235b_a22b_thinking_2507_config,
"qwen3-480b-a35b": qwen3_coder_480b_a35b_config,
"deepseek3-671b": deepseek3_671b_config,
"gpt-oss-20b": gpt_oss_20b_config,
"gpt-oss-120b": gpt_oss_120b_config,
"qwen3-omni-30b-a3b": qwen3_omni_30b_a3b_config,
}
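The two new gpt-oss entries are consumed the same way as the existing ones: the conversion scripts index this registry by MaxText model name and call `.to_dict()` on the result before building the parameter, shape, and hook mappings. A quick sanity-check sketch, with the expected values read off the dicts above:

# Illustrative lookup; mirrors how the conversion scripts use HF_MODEL_CONFIGS.
from MaxText.utils.ckpt_conversion.utils.hf_model_configs import HF_MODEL_CONFIGS

hf_config_obj = HF_MODEL_CONFIGS["gpt-oss-20b"]
print(type(hf_config_obj).__name__)       # GptOssConfig
print(hf_config_obj.num_hidden_layers)    # 24
print(hf_config_obj.num_local_experts)    # 32
hf_config_dict = hf_config_obj.to_dict()  # passed to PARAM_MAPPING / HF_SHAPE / HOOK_FNS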