
Commit 20f2de7

1 parent c7e12b4 commit 20f2de7

3 files changed: 29 additions, 37 deletions


src/MaxText/utils/ckpt_conversion/to_huggingface.py

Lines changed: 5 additions & 13 deletions
@@ -79,20 +79,14 @@
 os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=16"


-def _get_model_mappings(model_name: str, scan_layers: bool, hf_config_dict: dict, inhomogeneous_layer_cycle_interval: int):
+def _get_model_mappings(model_name: str, scan_layers: bool, hf_config_dict: dict, maxtext_config: Any):
   """Retrieves parameter, shape, and hook function mappings for the model.

   Args:
     model_name: The name of the model (e.g., "gemma2-2b").
     scan_layers: Boolean indicating if the model was trained with scanned layers.
     hf_config_dict: The Hugging Face model configuration dictionary.
-    inhomogeneous_layer_cycle_interval: For models with complex, non-uniform
-      layer structures (e.g., a repeating pattern of different layer types),
-      this specifies the number of unique layers in one cycle of the pattern.
-      For example, gpt-oss has a 'sliding_attention' layer followed by a
-      'full_attention' layer, so this value would be 2. This allows the
-      conversion to correctly map parameters from a scanned MaxText model
-      where these inhomogeneous layers are packed into a single scanned block.
+    maxtext_config: The MaxText model configuration.

   Returns:
     A dictionary containing the parameter mapping, shape mapping, and hook
@@ -105,11 +99,9 @@ def _get_model_mappings(model_name: str, scan_layers: bool, hf_config_dict: dict
     raise ValueError(f"Mappings not found for model: {model_name}. Available PARAM_MAPPING keys: {PARAM_MAPPING.keys()}")

   return {
-      "param_mapping": PARAM_MAPPING[model_name](hf_config_dict, scan_layers, inhomogeneous_layer_cycle_interval),
+      "param_mapping": PARAM_MAPPING[model_name](hf_config_dict, maxtext_config, scan_layers),
       "shape_mapping": HF_SHAPE[model_name](hf_config_dict),
-      "hook_fn_mapping": HOOK_FNS[model_name](
-          hf_config_dict, scan_layers, inhomogeneous_layer_cycle_interval, saving_to_hf=True
-      ),
+      "hook_fn_mapping": HOOK_FNS[model_name](hf_config_dict, maxtext_config, scan_layers, saving_to_hf=True),
   }


@@ -118,7 +110,7 @@ def _check_param_map_keys(param_map_keys, maxtext_state_keys):

   Ensures every MaxText checkpoint key (`maxtext_state_keys`) is covered by
   the flattened parameter map. Keys in the map that are not present in the
-  checkpoint (common for multi-variant maps like gemma3 or qwen3) are skipped.
+  checkpoint (common for multi-variant maps like gemma3, qwen3, or deepseek) are skipped.

   Tuple keys represent N-to-1 mappings (multiple MaxText keys combining into one
   target key) and are only returned if all constituent keys exist in the checkpoint.
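
The coverage rule described in the `_check_param_map_keys` docstring above can be summarized in a short sketch. This is an illustrative reimplementation of the documented behavior, not the repository's code; the helper name, return value, and error message are assumptions.

# Illustrative sketch only (not MaxText's implementation) of the coverage check
# described in the docstring above.
def check_param_map_keys(param_map_keys, maxtext_state_keys):
  maxtext_state_keys = set(maxtext_state_keys)
  covered, usable = set(), []
  for key in param_map_keys:
    # Tuple keys are N-to-1 mappings: keep them only if every constituent
    # MaxText key is present in the checkpoint.
    parts = key if isinstance(key, tuple) else (key,)
    if all(p in maxtext_state_keys for p in parts):
      usable.append(key)
      covered.update(parts)
    # Map keys absent from the checkpoint (multi-variant maps such as
    # gemma3, qwen3, or deepseek) are simply skipped.
  missing = maxtext_state_keys - covered
  if missing:
    raise ValueError(f"MaxText checkpoint keys not covered by the map: {sorted(missing)}")
  return usable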

src/MaxText/utils/ckpt_conversion/to_maxtext.py

Lines changed: 2 additions & 6 deletions
@@ -273,15 +273,11 @@ def main(argv: Sequence[str]) -> None:
   # f"model.layers.{global_layer_idx}.input_layernorm.weight",

   model_key = config.model_name
-  param_map_mt_to_hf = PARAM_MAPPING[model_key](
-      hf_config_obj.to_dict(), config.scan_layers, config.inhomogeneous_layer_cycle_interval
-  )
+  param_map_mt_to_hf = PARAM_MAPPING[model_key](hf_config_obj.to_dict(), config, config.scan_layers)

   # Example of Hook FN mapping, to perform reshape:
   # f"params-decoder-layers_{maxtext_layer_idx}-self_attention_global-key-kernel": reshape_kernel,
-  hook_fn_map_mt = HOOK_FNS[model_key](
-      hf_config_obj.to_dict(), config.scan_layers, config.inhomogeneous_layer_cycle_interval, saving_to_hf=False
-  )
+  hook_fn_map_mt = HOOK_FNS[model_key](hf_config_obj.to_dict(), config, config.scan_layers, saving_to_hf=False)
   max_logging.log("Parameter mappings and hooks obtained.")

   # Transform weights
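
For orientation, here is a rough sketch of how a MaxText-to-HF parameter map and a hook-function map of this shape can drive the HF-to-MaxText direction used in this file. The helper name, the loop, and the target-shape handling are assumptions for illustration, not the repository's transform code, and it covers only the simple 1-to-1 case (real maps also contain N-to-1 tuple keys).

# Hypothetical helper, for illustration only: walk the MaxText->HF name map,
# pull the matching Hugging Face tensor, and apply the per-key hook (e.g. a
# kernel reshape) before storing it under the MaxText parameter name.
def build_maxtext_weights(hf_weights, param_map_mt_to_hf, hook_fn_map_mt, target_shapes):
  maxtext_weights = {}
  for mt_key, hf_key in param_map_mt_to_hf.items():
    if hf_key not in hf_weights:
      continue  # map entries not present in this checkpoint are skipped
    tensor = hf_weights[hf_key]
    hook = hook_fn_map_mt.get(mt_key)
    if hook is not None:
      tensor = hook(tensor, target_shapes.get(mt_key))  # e.g. reshape_kernel
    maxtext_weights[mt_key] = tensor
  return maxtext_weights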

src/MaxText/utils/ckpt_conversion/utils/param_mapping.py

Lines changed: 22 additions & 18 deletions
@@ -41,7 +41,7 @@
 import jax.numpy as jnp


-def GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING(config, scan_layers=False, layer_cycle_interval=1):
+def GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
   """Generates a parameter mapping from MaxText to Hugging Face for Gemma3.

   This function creates a dictionary that maps the parameter names from a
@@ -143,7 +143,7 @@ def GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING(config, scan_layers=False, layer_cycle_in
   return mapping


-def GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, scan_layers=False, layer_cycle_interval=1, saving_to_hf=False):
+def GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
   """Hook functions for Gemma3 parameter conversion.

   This function provides a dictionary of transformation functions (hooks) for
@@ -298,7 +298,7 @@ def pos_embed(x, target_shape):
   return hooks


-def GEMMA2_MAXTEXT_TO_HF_PARAM_MAPPING(config, scan_layers=False, layer_cycle_interval=1):
+def GEMMA2_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
   """Returns mapping between MaxText and HuggingFace Gemma2 weight paths.

   Args:
@@ -431,7 +431,7 @@ def GEMMA2_MAXTEXT_TO_HF_PARAM_MAPPING(config, scan_layers=False, layer_cycle_in
   return mapping


-def GEMMA2_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, scan_layers=False, layer_cycle_interval=1, saving_to_hf=False):
+def GEMMA2_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
   """Creates parameter transformation functions for Gemma2 conversion.

   This function generates a mapping of transformation functions that handle the
@@ -596,7 +596,7 @@ def from_hf():
   return mapping


-def QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING(config, scan_layers=False, layer_cycle_interval=1):
+def QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
   """Returns mapping from MaxText to HuggingFace Qwen3 weight paths.

   This function generates a dictionary that maps parameter names from a MaxText
@@ -729,7 +729,7 @@ def QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING(config, scan_layers=False, layer_cycle_int
   return mapping


-def QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, scan_layers=False, layer_cycle_interval=1, saving_to_hf=False):
+def QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
   """Creates parameter transformation functions for Qwen3.

   This function provides a dictionary of transformation functions (hooks) for
@@ -814,7 +814,7 @@ def reshape_kernel(input_tensor, target_shape):
   return mapping


-def DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING(config, scan_layers=False, layer_cycle_interval=1):
+def DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
   """Returns mapping from MaxText to HuggingFace Deepseek weight paths using f-strings."""
   # TODO(shuningjin): add unscan support, b/457820735
   if not scan_layers:
@@ -885,7 +885,7 @@ def DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING(config, scan_layers=False, layer_cycle_
   return mapping


-def DEEPSEEK_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, scan_layers=False, layer_cycle_interval=1, saving_to_hf=False):
+def DEEPSEEK_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
   """Creates parameter transformation functions for Deepseek using f-strings."""
   # TODO(shuningjin): support hf->orbax(scan), b/457820372
   if not saving_to_hf:
@@ -937,14 +937,16 @@ def reshape_kernel(input_tensor, target_shape):
       mapping[key] = reshape_kernel
   return mapping

+
 def DEEPSEEK_NNX_TO_VLLM_PARAM_HOOK_FN():
   """Creates parameter transformation functions for Deepseek."""
   return {}

-def GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING(hf_config, scan_layers=True, layer_cycle_interval=1):
+
+def GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
   """Returns mapping from MaxText gpt-oss to Hugging Face weight paths.

-  Handles the inhomogeneous scan block structure (layer_cycle_interval)
+  Handles the inhomogeneous scan block structure (inhomogeneous_layer_cycle_interval)

   Handles N-to-1 mapping from maxtext to huggingface
   - (GptOssMlp-wi_0, GptOssMlp-wi_1): mlp.experts.gate_up_proj
@@ -954,7 +956,8 @@ def GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING(hf_config, scan_layers=True, layer_cycle
   if not scan_layers:
     raise NotImplementedError("Current gpt-oss mapping only supports scan_layers=True")

-  n_layers = hf_config["num_hidden_layers"]
+  n_layers = config["num_hidden_layers"]  # hf config
+  layer_cycle_interval = maxtext_config.inhomogeneous_layer_cycle_interval

   # Base mapping for non-layer parameters (targeting standard HF keys)
   mapping = {
@@ -965,7 +968,7 @@

   for block_idx in range(layer_cycle_interval):
     # Identify all original HF layer indices that collapse into this block
-    hf_indices = list(range(block_idx, n_layers, layer_cycle_interval))
+    hf_indices = list(range(block_idx, n_layers, maxtext_config.inhomogeneous_layer_cycle_interval))
     prefix = f"params-decoder-layers-layers_{block_idx}"

     # Layer Norms
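
As a concrete illustration of the interleaving above (toy values, not gpt-oss's real layer count): with an alternating sliding/full attention pattern the cycle interval is 2, and each scanned block collects every second HF layer.

n_layers = 6              # hypothetical HF layer count
layer_cycle_interval = 2  # inhomogeneous_layer_cycle_interval for gpt-oss
for block_idx in range(layer_cycle_interval):
  hf_indices = list(range(block_idx, n_layers, layer_cycle_interval))
  print(block_idx, hf_indices)
# block 0 -> HF layers [0, 2, 4] (e.g. the 'sliding_attention' layers)
# block 1 -> HF layers [1, 3, 5] (e.g. the 'full_attention' layers)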
@@ -1024,10 +1027,10 @@ def GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING(hf_config, scan_layers=True, layer_cycle
   return mapping


-def GPT_OSS_TO_HF_PARAM_HOOK_FN(hf_config, scan_layers=False, layer_cycle_interval=1, saving_to_hf=False):
+def GPT_OSS_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
   """Transformation hooks for gpt-oss parameters.

-  Handles the inhomogeneous scan block structure (layer_cycle_interval)
+  Handles the inhomogeneous scan block structure (inhomogeneous_layer_cycle_interval)

   Handles N-to-1 mapping from maxtext to huggingface
   - (GptOssMlp-wi_0, GptOssMlp-wi_1): mlp.experts.gate_up_proj
@@ -1085,6 +1088,7 @@ def interleave(input_tensor, target_shape=None):
   }

   # Scan over blocks
+  layer_cycle_interval = maxtext_config.inhomogeneous_layer_cycle_interval
   for block_idx in range(layer_cycle_interval):
     prefix = f"params-decoder-layers-layers_{block_idx}"
     # Attention Kernels & Biases
@@ -1103,7 +1107,7 @@
   return hooks


-def QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_MAPPING(config, scan_layers=False, layer_cycle_interval=1):
+def QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
   """Returns mapping from MaxText to HuggingFace Qwen3-Omni weight paths.

   This function combines mappings from different modalities (text, vision, audio, etc.)
@@ -1137,7 +1141,7 @@ def QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_MAPPING(config, scan_layers=False, layer_
   return mapping


-def QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, scan_layers=False, layer_cycle_interval=1, saving_to_hf=False):
+def QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
   """Creates parameter transformation functions for Qwen3-Omni.

   This function provides a dictionary of transformation functions (hooks) for
@@ -1188,7 +1192,7 @@ def QWEN3_NNX_TO_VLLM_PARAM_HOOK_FN(target_shape=None):
   return {}


-def LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING(config, scan_layers=False, layer_cycle_interval=1):
+def LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
   """
   Returns a dictionary mapping from MaxText parameter names to
   HuggingFace LLaMA3.1 parameter names.
@@ -1266,7 +1270,7 @@ def LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING(config, scan_layers=False, layer_cycle_i
   return mapping


-def LLAMA31_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, scan_layers=False, layer_cycle_interval=1, saving_to_hf=False):
+def LLAMA31_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
   """Creates parameter transformation functions for converting between MaxText and
   HuggingFace formats.
