
Commit 2fdb878

Checkpoint conversion utility: gpt-oss orbax scan to hf, many-to-one transform, inhomogeneous scan block
1 parent 11d5e08 commit 2fdb878

7 files changed: +607 -107 lines

src/MaxText/integration/tunix/utils.py

Lines changed: 3 additions & 1 deletion

```diff
@@ -130,7 +130,9 @@ def to_hf_mapping(self):
       return STANDALONE_VLLM_WEIGHT_MAPPING[self.model_name].to_hf_mapping()
 
     config = self.config
-    mapping = self.convert_hf_map_to_sharding_map(PARAM_MAPPING[self.model_name](config, scan_layers=True))
+    mapping = self.convert_hf_map_to_sharding_map(
+        PARAM_MAPPING[self.model_name](config, maxtext_config=None, scan_layers=True)
+    )
     return mapping
 
   def to_hf_transpose_keys(self):
```

src/MaxText/utils/ckpt_conversion/to_huggingface.py

Lines changed: 92 additions & 14 deletions

```diff
@@ -53,9 +53,10 @@
 
 import jax
 import os
-from typing import Sequence, Any
+from typing import Sequence
 import time
 from tqdm import tqdm
+import numpy as np
 
 from transformers import AutoTokenizer, AutoProcessor
 
@@ -71,20 +72,21 @@
 )
 from MaxText.utils.ckpt_conversion.utils.hf_shape import HF_SHAPE
 from MaxText.utils.ckpt_conversion.utils.hf_model_configs import HF_MODEL_CONFIGS
-from MaxText.utils.ckpt_conversion.utils.utils import (process_leaf_param, save_model_files, HF_IDS)
+from MaxText.utils.ckpt_conversion.utils.utils import process_maxtext_param, save_model_files, HF_IDS
 
 
 os.environ["JAX_PLATFORMS"] = "cpu"
 os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=16"
 
 
-def _get_model_mappings(model_name: str, scan_layers: bool, config_dict: dict):
+def _get_model_mappings(model_name: str, scan_layers: bool, hf_config_dict: dict, maxtext_config: pyconfig.HyperParameters):
   """Retrieves parameter, shape, and hook function mappings for the model.
 
   Args:
     model_name: The name of the model (e.g., "gemma2-2b").
     scan_layers: Boolean indicating if the model was trained with scanned layers.
-    config_dict: The Hugging Face model configuration dictionary.
+    hf_config_dict: The Hugging Face model configuration dictionary.
+    maxtext_config: The maxtext model configuration.
 
   Returns:
     A dictionary containing the parameter mapping, shape mapping, and hook
@@ -97,12 +99,65 @@ def _get_model_mappings(model_name: str, scan_layers: bool, config_dict: dict):
     raise ValueError(f"Mappings not found for model: {model_name}. Available PARAM_MAPPING keys: {PARAM_MAPPING.keys()}")
 
   return {
-      "param_mapping": PARAM_MAPPING[model_name](config_dict, scan_layers),
-      "shape_mapping": HF_SHAPE[model_name](config_dict),
-      "hook_fn_mapping": HOOK_FNS[model_name](config_dict, scan_layers, saving_to_hf=True),
+      "param_mapping": PARAM_MAPPING[model_name](hf_config_dict, maxtext_config, scan_layers),
+      "shape_mapping": HF_SHAPE[model_name](hf_config_dict),
+      "hook_fn_mapping": HOOK_FNS[model_name](hf_config_dict, maxtext_config, scan_layers, saving_to_hf=True),
   }
 
 
+def _check_param_map_keys(param_map_keys, maxtext_state_keys):
+  """Validates map coverage, handles N-to-1 mappings, and filters unused keys.
+
+  Ensures every MaxText checkpoint key (`maxtext_state_keys`) is covered by
+  the flattened parameter map. Keys in the map that are not present in the
+  checkpoint (common for multi-variant maps like gemma3, qwen3, deepseek) are skipped.
+
+  Tuple keys represent N-to-1 mappings (multiple MaxText keys combining into one
+  target key) and are only returned if all constituent keys exist in the checkpoint.
+
+  Args:
+    param_map_keys: Keys from the parameter mapping (strings or N-to-1 tuples).
+    maxtext_state_keys: Set of parameter keys loaded from the MaxText checkpoint.
+
+  Returns:
+    A list of 'filtered' mapping keys (strings or tuples) that are fully present
+    and valid based on `maxtext_state_keys`.
+
+  Raises:
+    ValueError: If `maxtext_state_keys` is NOT a subset of the flattened
+      `param_map_keys`.
+  """
+  flattened_map_keys = set()
+  for key in param_map_keys:
+    if isinstance(key, tuple):
+      flattened_map_keys.update(key)
+    else:
+      flattened_map_keys.add(key)
+
+  # every maxtext state key must be covered by param map
+  missing_keys = maxtext_state_keys - flattened_map_keys
+  if missing_keys:
+    raise ValueError(
+        "maxtext_state_dict must be a subset of flattened param_map"
+        + f"\nparam map\n{param_map_keys}"
+        + f"\nmaxtext:\n{maxtext_state_keys}"
+    )
+
+  # param map may have extra keys
+  extra_keys = flattened_map_keys - maxtext_state_keys
+  if extra_keys:
+    max_logging.log(f"Warning: extra keys in param_map are skipped: {extra_keys}")
+
+  # skip extra keys in param map
+  filtered_map_keys = []
+  for key in param_map_keys:
+    if (isinstance(key, str) and key in maxtext_state_keys) or (
+        isinstance(key, tuple) and all(k in maxtext_state_keys for k in key)
+    ):
+      filtered_map_keys.append(key)
+  return filtered_map_keys
+
+
 def main(argv: Sequence[str]) -> None:
   """Main function to convert a MaxText checkpoint to HuggingFace format.
 
@@ -156,29 +211,52 @@ def main(argv: Sequence[str]) -> None:
   processor = AutoProcessor.from_pretrained(hf_tokenizer_id, token=hf_token) if config.use_multimodal else None
 
   # 3. Get parameter mappings
-  mappings = _get_model_mappings(model_key, config.scan_layers, hf_config_obj.to_dict())
+  mappings = _get_model_mappings(model_key, config.scan_layers, hf_config_obj.to_dict(), config)
   param_map = mappings["param_mapping"]
   shape_map = mappings["shape_mapping"]  # HF target shapes
   hook_fn_map = mappings["hook_fn_mapping"]
 
   # 4. Transform Weights
-  transformed_hf_weights: dict[str, Any] = {}
-
   # MaxText `engine.load_params()` returns `state.params` (a FrozenDict).
   # The actual weights are typically under `state.params['params']`.
   actual_weights_dict = loaded_params_from_engine.get("params")
   if actual_weights_dict is None:
     raise ValueError("Loaded parameters from engine do not contain a 'params' key. Structure might be unexpected.")
-
   leaves_with_paths = jax.tree_util.tree_leaves_with_path(actual_weights_dict)
 
-  # traverse leavse to build: mt_param_key:mt_weights
+  # Construct maxtext_state_dict: {parameter name: parameter weight}
+  maxtext_state_dict = {}
+  for path_tuple, leaf_value in leaves_with_paths:
+    # Construct maxtext_param_key from path_tuple
+    maxtext_param_key = "params-" + "-".join(k.key for k in path_tuple)
+    # Check leaf value is an array
+    if not isinstance(leaf_value, (jax.Array, np.ndarray)):
+      raise ValueError(f"Leaf value for {maxtext_param_key} is not an array. Type: {type(leaf_value)}.")
+    maxtext_state_dict[maxtext_param_key] = leaf_value
+
+  # The param_map may contain tuples as keys, which represent N-to-1 mappings from maxtext to huggingface
+  # Check maxtext_state_dict is a subset of flattened param_map
+  # Skip extra keys from param_map
+  filtered_map_keys = _check_param_map_keys(param_map.keys(), maxtext_state_dict.keys())
+
+  # Iterate through the parameter map to transform and collect weights.
+  # This loop handles both simple 1-to-1 mappings and complex N-to-1 mappings
+  # (where multiple MaxText weights are combined into a single HF weight).
   max_logging.log("\nProccessing weight...")
   start = time.time()
   processed_params_list = []
-  for path_tuple_iter, leaf_value_iter in tqdm(leaves_with_paths, total=len(leaves_with_paths)):
-    processed_params = process_leaf_param(path_tuple_iter, leaf_value_iter, param_map, shape_map, hook_fn_map, config)
+
+  for key in tqdm(filtered_map_keys, total=len(filtered_map_keys)):
+    if isinstance(key, tuple):
+      # if key is tuple of param names, weight is list of param weights
+      weight = [maxtext_state_dict[subkey] for subkey in key]
+    else:
+      # if key is single param name, weight is single param weight
+      weight = maxtext_state_dict[key]
+
+    processed_params = process_maxtext_param(key, weight, param_map, hook_fn_map, shape_map, config)
     processed_params_list.extend(processed_params)
+
   transformed_hf_weights = dict(processed_params_list)
   max_logging.log(f"Elapse: {(time.time() - start) / 60:.2f} min")
```
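The new conversion path keys every tensor by its MaxText parameter name and lets the parameter map declare many-to-one fusions with tuple keys, which `_check_param_map_keys` validates and filters before the main loop. A self-contained toy sketch of that mechanism (the parameter names and HF target names below are invented for illustration and are not the real gpt-oss mapping):

```python
import numpy as np

# Hypothetical MaxText arrays, keyed the way to_huggingface.py builds maxtext_state_dict.
maxtext_state = {
    "params-decoder-layers-moe_block-wi_0": np.ones((2, 4)),
    "params-decoder-layers-moe_block-wi_1": np.ones((2, 4)),
}

# Hypothetical parameter map: a tuple key is an N-to-1 mapping, a string key is 1-to-1.
param_map = {
    ("params-decoder-layers-moe_block-wi_0", "params-decoder-layers-moe_block-wi_1"): "model.layers.mlp.gate_up_proj",
    "params-vision_encoder-kernel": "model.vision_tower.kernel",  # not in this checkpoint, so it is filtered out
}

# Same filtering rule as _check_param_map_keys: keep a key only if every
# constituent MaxText name exists in the checkpoint.
filtered = [
    k
    for k in param_map
    if (isinstance(k, str) and k in maxtext_state) or (isinstance(k, tuple) and all(s in maxtext_state for s in k))
]

for key in filtered:
  # A tuple key gathers a list of arrays; the model-specific hook function is
  # then responsible for fusing them into the single target HF tensor.
  weight = [maxtext_state[s] for s in key] if isinstance(key, tuple) else maxtext_state[key]
  print(key, "->", param_map[key], np.shape(weight))
```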

src/MaxText/utils/ckpt_conversion/to_maxtext.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -466,11 +466,11 @@ def main(args: Sequence[str], test_args: Sequence[str]) -> None:
   # f"model.layers.{global_layer_idx}.input_layernorm.weight",
 
   model_key = config.model_name
-  param_map_mt_to_hf = PARAM_MAPPING[model_key](hf_config_obj.to_dict(), config.scan_layers)
+  param_map_mt_to_hf = PARAM_MAPPING[model_key](hf_config_obj.to_dict(), config, config.scan_layers)
 
   # Example of Hook FN mapping, to perform reshape:
   # f"params-decoder-layers_{maxtext_layer_idx}-self_attention_global-key-kernel": reshape_kernel,
-  hook_fn_map_mt = HOOK_FNS[model_key](hf_config_obj.to_dict(), config.scan_layers, saving_to_hf=False)
+  hook_fn_map_mt = HOOK_FNS[model_key](hf_config_obj.to_dict(), config, config.scan_layers, saving_to_hf=False)
   max_logging.log("Parameter mappings and hooks obtained.")
 
   max_logging.log("Starting weight transformation...")
```

src/MaxText/utils/ckpt_conversion/utils/hf_model_configs.py

Lines changed: 156 additions & 1 deletion

```diff
@@ -469,6 +469,8 @@
     vocab_size=151936,
 )
 
+# copy from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/config.json
+# remove fp8 quantization_config, since we are using bf16
 deepseek3_671b_dict = {
     "architectures": ["DeepseekV3ForCausalLM"],
     "attention_bias": False,
@@ -527,7 +529,157 @@
 }
 deepseek3_671b_config = transformers.DeepseekV3Config(**deepseek3_671b_dict)
 
-# {maxtext model name: hf model config}
+# copy from https://huggingface.co/openai/gpt-oss-20b/blob/main/config.json
+# remove mxfp4 quantization_config, since we are using bf16
+gpt_oss_20b_dict = {
+    "architectures": ["GptOssForCausalLM"],
+    "attention_bias": True,
+    "attention_dropout": 0.0,
+    "eos_token_id": 200002,
+    "experts_per_token": 4,
+    "head_dim": 64,
+    "hidden_act": "silu",
+    "hidden_size": 2880,
+    "initial_context_length": 4096,
+    "initializer_range": 0.02,
+    "intermediate_size": 2880,
+    "layer_types": [
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+    ],
+    "max_position_embeddings": 131072,
+    "model_type": "gpt_oss",
+    "num_attention_heads": 64,
+    "num_experts_per_tok": 4,
+    "num_hidden_layers": 24,
+    "num_key_value_heads": 8,
+    "num_local_experts": 32,
+    "output_router_logits": False,
+    "pad_token_id": 199999,
+    "rms_norm_eps": 1e-05,
+    "rope_scaling": {
+        "beta_fast": 32.0,
+        "beta_slow": 1.0,
+        "factor": 32.0,
+        "original_max_position_embeddings": 4096,
+        "rope_type": "yarn",
+        "truncate": False,
+    },
+    "rope_theta": 150000,
+    "router_aux_loss_coef": 0.9,
+    "sliding_window": 128,
+    "swiglu_limit": 7.0,
+    "tie_word_embeddings": False,
+    "transformers_version": "4.55.0.dev0",
+    "use_cache": True,
+    "vocab_size": 201088,
+}
+gpt_oss_20b_config = transformers.GptOssConfig(**gpt_oss_20b_dict)
+
+# copy from https://huggingface.co/openai/gpt-oss-120b/blob/main/config.json
+# remove mxfp4 quantization_config, since we are using bf16
+gpt_oss_120b_dict = {
+    "architectures": ["GptOssForCausalLM"],
+    "attention_bias": True,
+    "attention_dropout": 0.0,
+    "eos_token_id": 200002,
+    "experts_per_token": 4,
+    "head_dim": 64,
+    "hidden_act": "silu",
+    "hidden_size": 2880,
+    "initial_context_length": 4096,
+    "initializer_range": 0.02,
+    "intermediate_size": 2880,
+    "layer_types": [
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+    ],
+    "max_position_embeddings": 131072,
+    "model_type": "gpt_oss",
+    "num_attention_heads": 64,
+    "num_experts_per_tok": 4,
+    "num_hidden_layers": 36,
+    "num_key_value_heads": 8,
+    "num_local_experts": 128,
+    "output_router_logits": False,
+    "pad_token_id": 199999,
+    "rms_norm_eps": 1e-05,
+    "rope_scaling": {
+        "beta_fast": 32.0,
+        "beta_slow": 1.0,
+        "factor": 32.0,
+        "original_max_position_embeddings": 4096,
+        "rope_type": "yarn",
+        "truncate": False,
+    },
+    "rope_theta": 150000,
+    "router_aux_loss_coef": 0.9,
+    "sliding_window": 128,
+    "swiglu_limit": 7.0,
+    "tie_word_embeddings": False,
+    "transformers_version": "4.55.0.dev0",
+    "use_cache": True,
+    "vocab_size": 201088,
+}
+gpt_oss_120b_config = transformers.GptOssConfig(**gpt_oss_120b_dict)
+
 qwen3_omni_30b_a3b_config = transformers.Qwen3OmniMoeConfig(
     # TODO(hengtaoguo): Pure-text Omni model, need to fill in visual/audio/code2wav parts
     architectures=["Qwen3OmniMoeForConditionalGeneration"],
@@ -539,6 +691,7 @@
     },
 )
 
+# {maxtext model name: hf model config}
 HF_MODEL_CONFIGS = {
     "gemma2-2b": gemma2_2b_config,
     "gemma2-9b": gemma2_9b_config,
@@ -560,5 +713,7 @@
     "qwen3-235b-a22b": qwen3_235b_a22b_thinking_2507_config,
     "qwen3-480b-a35b": qwen3_coder_480b_a35b_config,
     "deepseek3-671b": deepseek3_671b_config,
+    "gpt-oss-20b": gpt_oss_20b_config,
+    "gpt-oss-120b": gpt_oss_120b_config,
     "qwen3-omni-30b-a3b": qwen3_omni_30b_a3b_config,
 }
```
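As an illustrative sanity check (not part of the commit), the newly registered entries are ordinary `transformers` config objects, so downstream conversion code can read architecture fields directly; `GptOssConfig` requires transformers >= 4.55:

```python
from transformers import GptOssConfig

# Mirror a few fields from gpt_oss_20b_dict above.
cfg = GptOssConfig(num_hidden_layers=24, num_local_experts=32, hidden_size=2880, head_dim=64)
print(cfg.model_type, cfg.num_hidden_layers, cfg.num_local_experts, cfg.head_dim)
```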
