4141import jax .numpy as jnp
4242
4343
44- def GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING (config , scan_layers = False , layer_cycle_interval = 1 ):
44+ def GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING (config , maxtext_config , scan_layers = False ):
4545 """Generates a parameter mapping from MaxText to Hugging Face for Gemma3.
4646
4747 This function creates a dictionary that maps the parameter names from a
@@ -143,7 +143,7 @@ def GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING(config, scan_layers=False, layer_cycle_in
143143 return mapping
144144
145145
146- def GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN (config , scan_layers = False , layer_cycle_interval = 1 , saving_to_hf = False ):
146+ def GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN (config , maxtext_config , scan_layers = False , saving_to_hf = False ):
147147 """Hook functions for Gemma3 parameter conversion.
148148
149149 This function provides a dictionary of transformation functions (hooks) for
@@ -298,7 +298,7 @@ def pos_embed(x, target_shape):
298298 return hooks
299299
300300
301- def GEMMA2_MAXTEXT_TO_HF_PARAM_MAPPING (config , scan_layers = False , layer_cycle_interval = 1 ):
301+ def GEMMA2_MAXTEXT_TO_HF_PARAM_MAPPING (config , maxtext_config , scan_layers = False ):
302302 """Returns mapping between MaxText and HuggingFace Gemma2 weight paths.
303303
304304 Args:
@@ -431,7 +431,7 @@ def GEMMA2_MAXTEXT_TO_HF_PARAM_MAPPING(config, scan_layers=False, layer_cycle_in
431431 return mapping
432432
433433
434- def GEMMA2_MAXTEXT_TO_HF_PARAM_HOOK_FN (config , scan_layers = False , layer_cycle_interval = 1 , saving_to_hf = False ):
434+ def GEMMA2_MAXTEXT_TO_HF_PARAM_HOOK_FN (config , maxtext_config , scan_layers = False , saving_to_hf = False ):
435435 """Creates parameter transformation functions for Gemma2 conversion.
436436
437437 This function generates a mapping of transformation functions that handle the
@@ -596,7 +596,7 @@ def from_hf():
596596 return mapping
597597
598598
599- def QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING (config , scan_layers = False , layer_cycle_interval = 1 ):
599+ def QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING (config , maxtext_config , scan_layers = False ):
600600 """Returns mapping from MaxText to HuggingFace Qwen3 weight paths.
601601
602602 This function generates a dictionary that maps parameter names from a MaxText
@@ -729,7 +729,7 @@ def QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING(config, scan_layers=False, layer_cycle_int
729729 return mapping
730730
731731
732- def QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN (config , scan_layers = False , layer_cycle_interval = 1 , saving_to_hf = False ):
732+ def QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN (config , maxtext_config , scan_layers = False , saving_to_hf = False ):
733733 """Creates parameter transformation functions for Qwen3.
734734
735735 This function provides a dictionary of transformation functions (hooks) for
@@ -814,7 +814,7 @@ def reshape_kernel(input_tensor, target_shape):
814814 return mapping
815815
816816
817- def DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING (config , scan_layers = False , layer_cycle_interval = 1 ):
817+ def DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING (config , maxtext_config , scan_layers = False ):
818818 """Returns mapping from MaxText to HuggingFace Deepseek weight paths using f-strings."""
819819 # TODO(shuningjin): add unscan support, b/457820735
820820 if not scan_layers :
@@ -885,7 +885,7 @@ def DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING(config, scan_layers=False, layer_cycle_
885885 return mapping
886886
887887
888- def DEEPSEEK_MAXTEXT_TO_HF_PARAM_HOOK_FN (config , scan_layers = False , layer_cycle_interval = 1 , saving_to_hf = False ):
888+ def DEEPSEEK_MAXTEXT_TO_HF_PARAM_HOOK_FN (config , maxtext_config , scan_layers = False , saving_to_hf = False ):
889889 """Creates parameter transformation functions for Deepseek using f-strings."""
890890 # TODO(shuningjin): support hf->orbax(scan), b/457820372
891891 if not saving_to_hf :
@@ -937,14 +937,16 @@ def reshape_kernel(input_tensor, target_shape):
937937 mapping [key ] = reshape_kernel
938938 return mapping
939939
940+
940941def DEEPSEEK_NNX_TO_VLLM_PARAM_HOOK_FN ():
941942 """Creates parameter transformation functions for Deepseek."""
942943 return {}
943944
944- def GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING (hf_config , scan_layers = True , layer_cycle_interval = 1 ):
945+
946+ def GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING (config , maxtext_config , scan_layers = False ):
945947 """Returns mapping from MaxText gpt-oss to Hugging Face weight paths.
946948
947- Handles the inhomogeneous scan block structure (layer_cycle_interval )
949+ Handles the inhomogeneous scan block structure (inhomogeneous_layer_cycle_interval )
948950
949951 Handles N-to-1 mapping from maxtext to huggingface
950952 - (GptOssMlp-wi_0, GptOssMlp-wi_1): mlp.experts.gate_up_proj
@@ -954,7 +956,8 @@ def GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING(hf_config, scan_layers=True, layer_cycle
954956 if not scan_layers :
955957 raise NotImplementedError ("Current gpt-oss mapping only supports scan_layers=True" )
956958
957- n_layers = hf_config ["num_hidden_layers" ]
959+ n_layers = config ["num_hidden_layers" ] # hf config
960+ layer_cycle_interval = maxtext_config .inhomogeneous_layer_cycle_interval
958961
959962 # Base mapping for non-layer parameters (targeting standard HF keys)
960963 mapping = {
@@ -965,7 +968,7 @@ def GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING(hf_config, scan_layers=True, layer_cycle
965968
966969 for block_idx in range (layer_cycle_interval ):
967970 # Identify all original HF layer indices that collapse into this block
968- hf_indices = list (range (block_idx , n_layers , layer_cycle_interval ))
971+ hf_indices = list (range (block_idx , n_layers , maxtext_config . layer_cycle_interval ))
969972 prefix = f"params-decoder-layers-layers_{ block_idx } "
970973
971974 # Layer Norms
@@ -1024,10 +1027,10 @@ def GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING(hf_config, scan_layers=True, layer_cycle
10241027 return mapping
10251028
10261029
1027- def GPT_OSS_TO_HF_PARAM_HOOK_FN (hf_config , scan_layers = False , layer_cycle_interval = 1 , saving_to_hf = False ):
1030+ def GPT_OSS_TO_HF_PARAM_HOOK_FN (config , maxtext_config , scan_layers = False , saving_to_hf = False ):
10281031 """Transformation hooks for gpt-oss parameters.
10291032
1030- Handles the inhomogeneous scan block structure (layer_cycle_interval )
1033+ Handles the inhomogeneous scan block structure (inhomogeneous_layer_cycle_interval )
10311034
10321035 Handles N-to-1 mapping from maxtext to huggingface
10331036 - (GptOssMlp-wi_0, GptOssMlp-wi_1): mlp.experts.gate_up_proj
@@ -1085,6 +1088,7 @@ def interleave(input_tensor, target_shape=None):
10851088 }
10861089
10871090 # Scan over blocks
1091+ layer_cycle_interval = maxtext_config .inhomogeneous_layer_cycle_interval
10881092 for block_idx in range (layer_cycle_interval ):
10891093 prefix = f"params-decoder-layers-layers_{ block_idx } "
10901094 # Attention Kernels & Biases
@@ -1103,7 +1107,7 @@ def interleave(input_tensor, target_shape=None):
11031107 return hooks
11041108
11051109
1106- def QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_MAPPING (config , scan_layers = False , layer_cycle_interval = 1 ):
1110+ def QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_MAPPING (config , maxtext_config , scan_layers = False ):
11071111 """Returns mapping from MaxText to HuggingFace Qwen3-Omni weight paths.
11081112
11091113 This function combines mappings from different modalities (text, vision, audio, etc.)
@@ -1137,7 +1141,7 @@ def QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_MAPPING(config, scan_layers=False, layer_
11371141 return mapping
11381142
11391143
1140- def QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_HOOK_FN (config , scan_layers = False , layer_cycle_interval = 1 , saving_to_hf = False ):
1144+ def QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_HOOK_FN (config , maxtext_config , scan_layers = False , saving_to_hf = False ):
11411145 """Creates parameter transformation functions for Qwen3-Omni.
11421146
11431147 This function provides a dictionary of transformation functions (hooks) for
@@ -1188,7 +1192,7 @@ def QWEN3_NNX_TO_VLLM_PARAM_HOOK_FN(target_shape=None):
11881192 return {}
11891193
11901194
1191- def LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING (config , scan_layers = False , layer_cycle_interval = 1 ):
1195+ def LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING (config , maxtext_config , scan_layers = False ):
11921196 """
11931197 Returns a dictionary mapping from MaxText parameter names to
11941198 HuggingFace LLaMA3.1 parameter names.
@@ -1266,7 +1270,7 @@ def LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING(config, scan_layers=False, layer_cycle_i
12661270 return mapping
12671271
12681272
1269- def LLAMA31_MAXTEXT_TO_HF_PARAM_HOOK_FN (config , scan_layers = False , layer_cycle_interval = 1 , saving_to_hf = False ):
1273+ def LLAMA31_MAXTEXT_TO_HF_PARAM_HOOK_FN (config , maxtext_config , scan_layers = False , saving_to_hf = False ):
12701274 """Creates parameter transformation functions for converting between MaxText and
12711275 HuggingFace formats.
12721276
0 commit comments