
Commit 86e7885 (1 parent: bfdb7ed)

Checkpoint conversion utility: gpt-oss orbax scan to hf, inhomogeneous scan block, many-to-one transform

File tree: 6 files changed, +530 -90 lines

src/MaxText/utils/ckpt_conversion/to_huggingface.py

Lines changed: 68 additions & 13 deletions
@@ -56,6 +56,7 @@
 from typing import Sequence, Any
 import time
 from tqdm import tqdm
+import numpy as np

 from transformers import AutoTokenizer, AutoProcessor

@@ -71,20 +72,22 @@
 )
 from MaxText.utils.ckpt_conversion.utils.hf_shape import HF_SHAPE
 from MaxText.utils.ckpt_conversion.utils.hf_model_configs import HF_MODEL_CONFIGS
-from MaxText.utils.ckpt_conversion.utils.utils import (process_leaf_param, save_model_files, HF_IDS)
+from MaxText.utils.ckpt_conversion.utils.utils import process_maxtext_param, save_model_files, HF_IDS


 os.environ["JAX_PLATFORMS"] = "cpu"
 os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=16"


-def _get_model_mappings(model_name: str, scan_layers: bool, config_dict: dict):
+def _get_model_mappings(model_name: str, scan_layers: bool, hf_config_dict: dict, inhomogeneous_layer_cycle_interval: int):
   """Retrieves parameter, shape, and hook function mappings for the model.

   Args:
     model_name: The name of the model (e.g., "gemma2-2b").
     scan_layers: Boolean indicating if the model was trained with scanned layers.
-    config_dict: The Hugging Face model configuration dictionary.
+    hf_config_dict: The Hugging Face model configuration dictionary.
+    inhomogeneous_layer_cycle_interval: For complex architectures (llama4, gpt-oss), there are repeated sets of
+      n inhomogeneous layers. The scan structure has n blocks.

   Returns:
     A dictionary containing the parameter mapping, shape mapping, and hook
@@ -97,12 +100,42 @@ def _get_model_mappings(model_name: str, scan_layers: bool, config_dict: dict):
     raise ValueError(f"Mappings not found for model: {model_name}. Available PARAM_MAPPING keys: {PARAM_MAPPING.keys()}")

   return {
-      "param_mapping": PARAM_MAPPING[model_name](config_dict, scan_layers),
-      "shape_mapping": HF_SHAPE[model_name](config_dict),
-      "hook_fn_mapping": HOOK_FNS[model_name](config_dict, scan_layers, saving_to_hf=True),
+      "param_mapping": PARAM_MAPPING[model_name](hf_config_dict, scan_layers, inhomogeneous_layer_cycle_interval),
+      "shape_mapping": HF_SHAPE[model_name](hf_config_dict),
+      "hook_fn_mapping": HOOK_FNS[model_name](
+          hf_config_dict, scan_layers, inhomogeneous_layer_cycle_interval, saving_to_hf=True
+      ),
   }


+def _check_param_map_keys(param_map_keys, maxtext_state_keys):
+  """Verifies that the keys in the parameter map match the keys in the MaxText state.
+
+  This function handles cases where the parameter map contains tuples as keys,
+  which represent N-to-1 mappings (multiple MaxText parameters combined into one
+  Hugging Face parameter). It flattens these tuples to ensure every individual
+  MaxText parameter is accounted for.
+
+  Args:
+    param_map_keys: The keys from the parameter mapping dictionary.
+    maxtext_state_keys: A set of all parameter keys from the loaded MaxText
+      checkpoint.
+  """
+  flattened_map_keys = set()
+  for key in param_map_keys:
+    if isinstance(key, tuple):
+      flattened_map_keys.update(key)
+    else:
+      flattened_map_keys.add(key)
+
+  if flattened_map_keys != maxtext_state_keys:
+    raise ValueError(
+        f"param_map and maxtext_state_dict have different keys."
+        + f"\nparam map\n{param_map_keys}"
+        + f"\nmaxtext:\n{maxtext_state_keys}"
+    )
+
+
 def main(argv: Sequence[str]) -> None:
   """Main function to convert a MaxText checkpoint to HuggingFace format.

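Note on the new key check: a tuple key in param_map groups several MaxText parameters that are merged into a single Hugging Face tensor (the many-to-one transform from the commit title), while a plain string key is an ordinary one-to-one mapping. A minimal sketch of how _check_param_map_keys behaves, using made-up parameter names rather than the real gpt-oss mapping, with the function defined above in scope:

# Illustrative only: hypothetical parameter names, not the actual gpt-oss mapping.
param_map = {
    # Two MaxText tensors fused into one HF tensor -> tuple key (many-to-one).
    ("params-decoder-layers-moe-wi_0", "params-decoder-layers-moe-wi_1"): "model.layers.mlp.experts.gate_up_proj",
    # Ordinary one-to-one mapping -> plain string key.
    "params-decoder-decoder_norm-scale": "model.norm.weight",
}
checkpoint_keys = {
    "params-decoder-layers-moe-wi_0",
    "params-decoder-layers-moe-wi_1",
    "params-decoder-decoder_norm-scale",
}

_check_param_map_keys(param_map.keys(), checkpoint_keys)  # passes: every tuple member is present

try:
  _check_param_map_keys(param_map.keys(), checkpoint_keys - {"params-decoder-layers-moe-wi_1"})
except ValueError as e:
  print("mismatch detected:", e)  # one tuple member is missing from the checkpoint

The check passes only when the flattened tuple members cover exactly the keys present in the loaded checkpoint; any missing or extra key raises a ValueError listing both sides.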
@@ -156,29 +189,51 @@ def main(argv: Sequence[str]) -> None:
   processor = AutoProcessor.from_pretrained(hf_tokenizer_id, token=hf_token) if config.use_multimodal else None

   # 3. Get parameter mappings
-  mappings = _get_model_mappings(model_key, config.scan_layers, hf_config_obj.to_dict())
+  mappings = _get_model_mappings(
+      model_key, config.scan_layers, hf_config_obj.to_dict(), config.inhomogeneous_layer_cycle_interval
+  )
   param_map = mappings["param_mapping"]
   shape_map = mappings["shape_mapping"]  # HF target shapes
   hook_fn_map = mappings["hook_fn_mapping"]

   # 4. Transform Weights
-  transformed_hf_weights: dict[str, Any] = {}
-
   # MaxText `engine.load_params()` returns `state.params` (a FrozenDict).
   # The actual weights are typically under `state.params['params']`.
   actual_weights_dict = loaded_params_from_engine.get("params")
   if actual_weights_dict is None:
     raise ValueError("Loaded parameters from engine do not contain a 'params' key. Structure might be unexpected.")
-
   leaves_with_paths = jax.tree_util.tree_leaves_with_path(actual_weights_dict)

-  # traverse leavse to build: mt_param_key:mt_weights
+  # Construct maxtext_state_dict: {parameter name: parameter weight}
+  maxtext_state_dict = {}
+  for path_tuple, leaf_value in leaves_with_paths:
+    # Construct maxtext_param_key from path_tuple
+    maxtext_param_key = "params-" + "-".join(k.key for k in path_tuple)
+    # Check leaf value is an array
+    if not isinstance(leaf_value, (jax.Array, np.ndarray)):
+      raise ValueError(f"Leaf value for {maxtext_param_key} is not an array. Type: {type(leaf_value)}.")
+    maxtext_state_dict[maxtext_param_key] = leaf_value
+
+  # The param_map may contain tuples as keys, which represent N-to-1 mappings from maxtext to huggingface
+  # Check param_map after flattening has the same keys as maxtext_state_dict
+  _check_param_map_keys(param_map.keys(), maxtext_state_dict.keys())
+
+  # Iterate over param_map to build: mt_param_key:mt_weights
   max_logging.log("\nProccessing weight...")
   start = time.time()
   processed_params_list = []
-  for path_tuple_iter, leaf_value_iter in tqdm(leaves_with_paths, total=len(leaves_with_paths)):
-    processed_params = process_leaf_param(path_tuple_iter, leaf_value_iter, param_map, shape_map, hook_fn_map, config)
+
+  for key in tqdm(param_map, total=len(param_map)):
+    if isinstance(key, tuple):
+      # key is tuple of param names, weight is list of param weights
+      weight = [maxtext_state_dict[subkey] for subkey in key]
+    else:
+      # key is single param name, weight is single param weight
+      weight = maxtext_state_dict[key]
+
+    processed_params = process_maxtext_param(key, weight, param_map, hook_fn_map, shape_map, config)
     processed_params_list.extend(processed_params)
+
   transformed_hf_weights = dict(processed_params_list)
   max_logging.log(f"Elapse: {(time.time() - start) / 60:.2f} min")

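The loop above only gathers the raw arrays; the actual combine/reshape work happens inside process_maxtext_param (in ckpt_conversion/utils/utils.py, not shown in this diff), driven by hook_fn_map and shape_map. A simplified stand-in, with made-up names and a plain stack as the combine op, only to illustrate the shape of the many-to-one step:

import numpy as np

# Simplified stand-in for the many-to-one step -- NOT the real process_maxtext_param,
# whose combine/reshape logic comes from hook_fn_map and shape_map. Names are made up.
def combine_many_to_one(key, weight, hf_name, hook_fn=None):
  """Mirrors the loop above: tuple key -> list of arrays, string key -> single array."""
  if isinstance(key, tuple):
    combined = np.stack([np.asarray(w) for w in weight], axis=0)  # e.g. stack fused pieces
  else:
    combined = np.asarray(weight)
  if hook_fn is not None:
    combined = hook_fn(combined)
  return [(hf_name, combined)]  # a list, since one entry can yield several HF tensors

w0, w1 = np.ones((2, 2)), np.zeros((2, 2))
out = combine_many_to_one(
    ("params-decoder-layers-moe-wi_0", "params-decoder-layers-moe-wi_1"),
    [w0, w1],
    "model.layers.mlp.experts.gate_up_proj",
)
print(out[0][0], out[0][1].shape)  # -> model.layers.mlp.experts.gate_up_proj (2, 2, 2)

In the real utility the combine op comes from hook_fn_map and the target name(s) from param_map/shape_map; the sketch hard-codes both for brevity.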

src/MaxText/utils/ckpt_conversion/to_maxtext.py

Lines changed: 6 additions & 2 deletions
@@ -273,11 +273,15 @@ def main(argv: Sequence[str]) -> None:
   # f"model.layers.{global_layer_idx}.input_layernorm.weight",

   model_key = config.model_name
-  param_map_mt_to_hf = PARAM_MAPPING[model_key](hf_config_obj.to_dict(), config.scan_layers)
+  param_map_mt_to_hf = PARAM_MAPPING[model_key](
+      hf_config_obj.to_dict(), config.scan_layers, config.inhomogeneous_layer_cycle_interval
+  )

   # Example of Hook FN mapping, to perform reshape:
   # f"params-decoder-layers_{maxtext_layer_idx}-self_attention_global-key-kernel": reshape_kernel,
-  hook_fn_map_mt = HOOK_FNS[model_key](hf_config_obj.to_dict(), config.scan_layers, saving_to_hf=False)
+  hook_fn_map_mt = HOOK_FNS[model_key](
+      hf_config_obj.to_dict(), config.scan_layers, config.inhomogeneous_layer_cycle_interval, saving_to_hf=False
+  )
   max_logging.log("Parameter mappings and hooks obtained.")

   # Transform weights
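Both conversion directions now thread config.inhomogeneous_layer_cycle_interval into PARAM_MAPPING and HOOK_FNS. Per the docstring added above: architectures like gpt-oss repeat a cycle of n unlike layers (e.g. sliding vs. full attention), so a scanned checkpoint carries n separately stacked blocks instead of one. A rough sketch of that bookkeeping, under the assumption (for illustration only) that layer i within each cycle lands in scan block i; the real index arithmetic lives in the mapping utilities and may differ:

# Rough sketch of scan-block bookkeeping for an inhomogeneous layer cycle.
# Assumption (illustration only): layer i within each cycle of length n is stored
# in scan block i, so block b holds global layers b, b + n, b + 2n, ...
def layers_per_scan_block(num_hidden_layers, cycle_interval):
  blocks = {b: [] for b in range(cycle_interval)}
  for global_layer_idx in range(num_hidden_layers):
    blocks[global_layer_idx % cycle_interval].append(global_layer_idx)
  return blocks

# gpt-oss-20b: 24 decoder layers alternating sliding/full attention -> cycle of 2.
blocks = layers_per_scan_block(24, 2)
assert blocks[0] == list(range(0, 24, 2))  # block 0: the sliding_attention layers
assert blocks[1] == list(range(1, 24, 2))  # block 1: the full_attention layers

Under that reading, gpt-oss-20b has a cycle interval of 2, so each of the two scan blocks stacks 12 of the 24 decoder layers.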

src/MaxText/utils/ckpt_conversion/utils/hf_model_configs.py

Lines changed: 156 additions & 1 deletion
@@ -469,6 +469,8 @@
     vocab_size=151936,
 )

+# copy from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/config.json
+# remove fp8 quantization_config, since we are using bf16
 deepseek3_671b_dict = {
     "architectures": ["DeepseekV3ForCausalLM"],
     "attention_bias": False,
@@ -527,7 +529,157 @@
 }
 deepseek3_671b_config = transformers.DeepseekV3Config(**deepseek3_671b_dict)

-# {maxtext model name: hf model config}
+# copy from https://huggingface.co/openai/gpt-oss-20b/blob/main/config.json
+# remove mxfp4 quantization_config, since we are using bf16
+gpt_oss_20b_dict = {
+    "architectures": ["GptOssForCausalLM"],
+    "attention_bias": True,
+    "attention_dropout": 0.0,
+    "eos_token_id": 200002,
+    "experts_per_token": 4,
+    "head_dim": 64,
+    "hidden_act": "silu",
+    "hidden_size": 2880,
+    "initial_context_length": 4096,
+    "initializer_range": 0.02,
+    "intermediate_size": 2880,
+    "layer_types": [
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+    ],
+    "max_position_embeddings": 131072,
+    "model_type": "gpt_oss",
+    "num_attention_heads": 64,
+    "num_experts_per_tok": 4,
+    "num_hidden_layers": 24,
+    "num_key_value_heads": 8,
+    "num_local_experts": 32,
+    "output_router_logits": False,
+    "pad_token_id": 199999,
+    "rms_norm_eps": 1e-05,
+    "rope_scaling": {
+        "beta_fast": 32.0,
+        "beta_slow": 1.0,
+        "factor": 32.0,
+        "original_max_position_embeddings": 4096,
+        "rope_type": "yarn",
+        "truncate": False,
+    },
+    "rope_theta": 150000,
+    "router_aux_loss_coef": 0.9,
+    "sliding_window": 128,
+    "swiglu_limit": 7.0,
+    "tie_word_embeddings": False,
+    "transformers_version": "4.55.0.dev0",
+    "use_cache": True,
+    "vocab_size": 201088,
+}
+gpt_oss_20b_config = transformers.GptOssConfig(**gpt_oss_20b_dict)
+
+# copy from https://huggingface.co/openai/gpt-oss-120b/blob/main/config.json
+# remove mxfp4 quantization_config, since we are using bf16
+gpt_oss_120b_dict = {
+    "architectures": ["GptOssForCausalLM"],
+    "attention_bias": True,
+    "attention_dropout": 0.0,
+    "eos_token_id": 200002,
+    "experts_per_token": 4,
+    "head_dim": 64,
+    "hidden_act": "silu",
+    "hidden_size": 2880,
+    "initial_context_length": 4096,
+    "initializer_range": 0.02,
+    "intermediate_size": 2880,
+    "layer_types": [
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+        "sliding_attention",
+        "full_attention",
+    ],
+    "max_position_embeddings": 131072,
+    "model_type": "gpt_oss",
+    "num_attention_heads": 64,
+    "num_experts_per_tok": 4,
+    "num_hidden_layers": 36,
+    "num_key_value_heads": 8,
+    "num_local_experts": 128,
+    "output_router_logits": False,
+    "pad_token_id": 199999,
+    "rms_norm_eps": 1e-05,
+    "rope_scaling": {
+        "beta_fast": 32.0,
+        "beta_slow": 1.0,
+        "factor": 32.0,
+        "original_max_position_embeddings": 4096,
+        "rope_type": "yarn",
+        "truncate": False,
+    },
+    "rope_theta": 150000,
+    "router_aux_loss_coef": 0.9,
+    "sliding_window": 128,
+    "swiglu_limit": 7.0,
+    "tie_word_embeddings": False,
+    "transformers_version": "4.55.0.dev0",
+    "use_cache": True,
+    "vocab_size": 201088,
+}
+gpt_oss_120b_config = transformers.GptOssConfig(**gpt_oss_120b_dict)
+
+
 qwen3_omni_30b_a3b_config = transformers.Qwen3OmniMoeConfig(
     # TODO(hengtaoguo): Pure-text Omni model, need to fill in visual/audio/code2wav parts
     architectures=["Qwen3OmniMoeForConditionalGeneration"],
@@ -539,6 +691,7 @@
     },
 )

+# {maxtext model name: hf model config}
 HF_MODEL_CONFIGS = {
     "gemma2-2b": gemma2_2b_config,
     "gemma2-9b": gemma2_9b_config,
@@ -560,5 +713,7 @@
     "qwen3-235b-a22b": qwen3_235b_a22b_thinking_2507_config,
     "qwen3-480b-a35b": qwen3_coder_480b_a35b_config,
     "deepseek3-671b": deepseek3_671b_config,
+    "gpt-oss-20b": gpt_oss_20b_config,
+    "gpt-oss-120b": gpt_oss_120b_config,
     "qwen3-omni-30b-a3b": qwen3_omni_30b_a3b_config,
 }
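With the registry entries in place, the converter can look up a ready-made GptOssConfig by MaxText model name. A small usage sketch (the output path is illustrative, and transformers >= 4.55 is assumed for GptOssConfig):

# Usage sketch: look up the registered config by MaxText model name and write it out.
from MaxText.utils.ckpt_conversion.utils.hf_model_configs import HF_MODEL_CONFIGS

hf_config = HF_MODEL_CONFIGS["gpt-oss-20b"]
assert hf_config.model_type == "gpt_oss"
assert hf_config.num_hidden_layers == 24 and hf_config.num_local_experts == 32

hf_config.save_pretrained("/tmp/gpt-oss-20b-hf")  # writes config.json; path is illustrative

Writing the config out like this mirrors what the conversion utility's save step needs to place next to the converted weights.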
