from typing import Sequence, Any
import time
from tqdm import tqdm
+import numpy as np

from transformers import AutoTokenizer, AutoProcessor

)
from MaxText.utils.ckpt_conversion.utils.hf_shape import HF_SHAPE
from MaxText.utils.ckpt_conversion.utils.hf_model_configs import HF_MODEL_CONFIGS
-from MaxText.utils.ckpt_conversion.utils.utils import (process_leaf_param, save_model_files, HF_IDS)
+from MaxText.utils.ckpt_conversion.utils.utils import process_maxtext_param, save_model_files, HF_IDS


os.environ["JAX_PLATFORMS"] = "cpu"
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=16"

-def _get_model_mappings(model_name: str, scan_layers: bool, config_dict: dict):
+def _get_model_mappings(model_name: str, scan_layers: bool, hf_config_dict: dict, inhomogeneous_layer_cycle_interval: int):
  """Retrieves parameter, shape, and hook function mappings for the model.

  Args:
    model_name: The name of the model (e.g., "gemma2-2b").
    scan_layers: Boolean indicating if the model was trained with scanned layers.
-    config_dict: The Hugging Face model configuration dictionary.
+    hf_config_dict: The Hugging Face model configuration dictionary.
+    inhomogeneous_layer_cycle_interval: For complex architectures (e.g., llama4, gpt-oss)
+      whose layers repeat in sets of n inhomogeneous layers; the scanned structure then has n blocks.

  Returns:
    A dictionary containing the parameter mapping, shape mapping, and hook
@@ -97,12 +100,42 @@ def _get_model_mappings(model_name: str, scan_layers: bool, config_dict: dict):
    raise ValueError(f"Mappings not found for model: {model_name}. Available PARAM_MAPPING keys: {PARAM_MAPPING.keys()}")

  return {
-      "param_mapping": PARAM_MAPPING[model_name](config_dict, scan_layers),
-      "shape_mapping": HF_SHAPE[model_name](config_dict),
-      "hook_fn_mapping": HOOK_FNS[model_name](config_dict, scan_layers, saving_to_hf=True),
+      "param_mapping": PARAM_MAPPING[model_name](hf_config_dict, scan_layers, inhomogeneous_layer_cycle_interval),
+      "shape_mapping": HF_SHAPE[model_name](hf_config_dict),
+      "hook_fn_mapping": HOOK_FNS[model_name](
+          hf_config_dict, scan_layers, inhomogeneous_layer_cycle_interval, saving_to_hf=True
+      ),
  }


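For orientation, the three tables look roughly like the sketch below. This is a hypothetical illustration with made-up keys and an assumed hook signature, not the real gemma2-2b tables; the actual entries come from the PARAM_MAPPING, HF_SHAPE, and HOOK_FNS registries. Note the tuple key, which expresses an N-to-1 mapping (several MaxText arrays folded into one Hugging Face tensor):

    # Hypothetical sketch only; real keys, shapes, and hooks live in the registries.
    mappings = {
        "param_mapping": {
            "params-decoder-layers-mlp-wi-kernel": "model.layers.0.mlp.gate_proj.weight",
            (
                "params-decoder-layers-self_attention-query-kernel",
                "params-decoder-layers-self_attention-key-kernel",
                "params-decoder-layers-self_attention-value-kernel",
            ): "model.layers.0.self_attn.qkv_proj.weight",
        },
        "shape_mapping": {"model.layers.0.mlp.gate_proj.weight": [8192, 2048]},
        "hook_fn_mapping": {"params-decoder-layers-mlp-wi-kernel": lambda w, target_shape: w.T},
    }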
+def _check_param_map_keys(param_map_keys, maxtext_state_keys):
+  """Verifies that the keys in the parameter map match the keys in the MaxText state.
+
+  This function handles cases where the parameter map contains tuples as keys,
+  which represent N-to-1 mappings (multiple MaxText parameters combined into one
+  Hugging Face parameter). It flattens these tuples to ensure every individual
+  MaxText parameter is accounted for.
+
+  Args:
+    param_map_keys: The keys from the parameter mapping dictionary.
+    maxtext_state_keys: A set of all parameter keys from the loaded MaxText
+      checkpoint.
+
+  Raises:
+    ValueError: If the flattened mapping keys do not exactly match the MaxText state keys.
+  """
+  flattened_map_keys = set()
+  for key in param_map_keys:
+    if isinstance(key, tuple):
+      flattened_map_keys.update(key)
+    else:
+      flattened_map_keys.add(key)
+
+  if flattened_map_keys != maxtext_state_keys:
+    raise ValueError(
+        "param_map and maxtext_state_dict have different keys."
+        f"\nparam map:\n{param_map_keys}"
+        f"\nmaxtext:\n{maxtext_state_keys}"
+    )
+
+
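A quick sanity check of the flattening behaviour, using hypothetical key names:

    # Tuple keys are flattened, so every member must appear in the MaxText state.
    keys = [("params-q", "params-k", "params-v"), "params-mlp"]
    _check_param_map_keys(keys, {"params-q", "params-k", "params-v", "params-mlp"})  # passes
    _check_param_map_keys(keys, {"params-q", "params-mlp"})  # raises ValueError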
def main(argv: Sequence[str]) -> None:
  """Main function to convert a MaxText checkpoint to HuggingFace format.

@@ -156,29 +189,51 @@ def main(argv: Sequence[str]) -> None:
  processor = AutoProcessor.from_pretrained(hf_tokenizer_id, token=hf_token) if config.use_multimodal else None

  # 3. Get parameter mappings
-  mappings = _get_model_mappings(model_key, config.scan_layers, hf_config_obj.to_dict())
+  mappings = _get_model_mappings(
+      model_key, config.scan_layers, hf_config_obj.to_dict(), config.inhomogeneous_layer_cycle_interval
+  )
  param_map = mappings["param_mapping"]
  shape_map = mappings["shape_mapping"]  # HF target shapes
  hook_fn_map = mappings["hook_fn_mapping"]

  # 4. Transform Weights
-  transformed_hf_weights: dict[str, Any] = {}
-
  # MaxText engine.load_params() returns state.params (a FrozenDict).
  # The actual weights are typically under state.params['params'].
  actual_weights_dict = loaded_params_from_engine.get("params")
  if actual_weights_dict is None:
    raise ValueError("Loaded parameters from engine do not contain a 'params' key. Structure might be unexpected.")
-
  leaves_with_paths = jax.tree_util.tree_leaves_with_path(actual_weights_dict)

-  # traverse leavse to build: mt_param_key:mt_weights
+  # Construct maxtext_state_dict: {parameter name: parameter weight}
+  maxtext_state_dict = {}
+  for path_tuple, leaf_value in leaves_with_paths:
+    # Construct maxtext_param_key from path_tuple
+    maxtext_param_key = "params-" + "-".join(k.key for k in path_tuple)
+    # Ensure the leaf value is an array
+    if not isinstance(leaf_value, (jax.Array, np.ndarray)):
+      raise ValueError(f"Leaf value for {maxtext_param_key} is not an array. Type: {type(leaf_value)}.")
+    maxtext_state_dict[maxtext_param_key] = leaf_value
+
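To make the key format concrete: jax.tree_util.tree_leaves_with_path yields (path, leaf) pairs whose path entries are DictKey objects, and the loop above joins their .key attributes with dashes. A minimal, self-contained sketch on a toy pytree (not a real MaxText checkpoint):

    import jax
    import numpy as np

    toy_params = {"decoder": {"layers": {"mlp": {"wi": {"kernel": np.zeros((4, 8))}}}}}
    for path_tuple, leaf in jax.tree_util.tree_leaves_with_path(toy_params):
      print("params-" + "-".join(k.key for k in path_tuple), leaf.shape)
    # -> params-decoder-layers-mlp-wi-kernel (4, 8)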
+  # The param_map may contain tuples as keys, representing N-to-1 mappings from MaxText to Hugging Face.
+  # Check that param_map, after flattening, has the same keys as maxtext_state_dict.
+  _check_param_map_keys(param_map.keys(), maxtext_state_dict.keys())
+
+  # Iterate over param_map, transforming each MaxText weight into its Hugging Face equivalent.
  max_logging.log("\nProcessing weights...")
  start = time.time()
  processed_params_list = []
-  for path_tuple_iter, leaf_value_iter in tqdm(leaves_with_paths, total=len(leaves_with_paths)):
-    processed_params = process_leaf_param(path_tuple_iter, leaf_value_iter, param_map, shape_map, hook_fn_map, config)
+
+  for key in tqdm(param_map, total=len(param_map)):
+    if isinstance(key, tuple):
+      # key is a tuple of param names; weight is a list of param weights
+      weight = [maxtext_state_dict[subkey] for subkey in key]
+    else:
+      # key is a single param name; weight is a single param weight
+      weight = maxtext_state_dict[key]
+
+    processed_params = process_maxtext_param(key, weight, param_map, hook_fn_map, shape_map, config)
    processed_params_list.extend(processed_params)
+
  transformed_hf_weights = dict(processed_params_list)
  max_logging.log(f"Elapsed: {(time.time() - start) / 60:.2f} min")

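For tuple keys, the downstream hook sees all constituent arrays at once, which is what makes N-to-1 fusions possible. As a rough illustration only (this is not the project's process_maxtext_param or any real hook, and the shapes are invented), fusing separate Q/K/V kernels into one HF-style projection could look like:

    import numpy as np

    def fuse_qkv(weights):
      # weights: [q, k, v], each (hidden, heads * head_dim)
      # -> one (3 * heads * head_dim, hidden) fused projection matrix
      return np.concatenate([w.T for w in weights], axis=0)

    q = k = v = np.ones((16, 32))
    print(fuse_qkv([q, k, v]).shape)  # (96, 16)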