import jax
import os
-from typing import Sequence, Any
+from typing import Sequence
import time
from tqdm import tqdm
+import numpy as np

from transformers import AutoTokenizer, AutoProcessor

)
from MaxText.utils.ckpt_conversion.utils.hf_shape import HF_SHAPE
from MaxText.utils.ckpt_conversion.utils.hf_model_configs import HF_MODEL_CONFIGS
-from MaxText.utils.ckpt_conversion.utils.utils import (process_leaf_param, save_model_files, HF_IDS)
+from MaxText.utils.ckpt_conversion.utils.utils import process_maxtext_param, save_model_files, HF_IDS


os.environ["JAX_PLATFORMS"] = "cpu"
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=16"

-def _get_model_mappings(model_name: str, scan_layers: bool, config_dict: dict):
+def _get_model_mappings(model_name: str, scan_layers: bool, hf_config_dict: dict, maxtext_config: pyconfig.HyperParameters):
  """Retrieves parameter, shape, and hook function mappings for the model.

  Args:
    model_name: The name of the model (e.g., "gemma2-2b").
    scan_layers: Boolean indicating if the model was trained with scanned layers.
-    config_dict: The Hugging Face model configuration dictionary.
+    hf_config_dict: The Hugging Face model configuration dictionary.
+    maxtext_config: The MaxText model configuration.

  Returns:
    A dictionary containing the parameter mapping, shape mapping, and hook
@@ -97,12 +99,65 @@ def _get_model_mappings(model_name: str, scan_layers: bool, config_dict: dict):
    raise ValueError(f"Mappings not found for model: {model_name}. Available PARAM_MAPPING keys: {PARAM_MAPPING.keys()}")

  return {
-      "param_mapping": PARAM_MAPPING[model_name](config_dict, scan_layers),
-      "shape_mapping": HF_SHAPE[model_name](config_dict),
-      "hook_fn_mapping": HOOK_FNS[model_name](config_dict, scan_layers, saving_to_hf=True),
+      "param_mapping": PARAM_MAPPING[model_name](hf_config_dict, maxtext_config, scan_layers),
+      "shape_mapping": HF_SHAPE[model_name](hf_config_dict),
+      "hook_fn_mapping": HOOK_FNS[model_name](hf_config_dict, maxtext_config, scan_layers, saving_to_hf=True),
  }

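# Illustrative sketch, not part of the converter script: with the new signature, the
# MaxText config is passed alongside the HF config dict, mirroring the call made later
# in main(). The model key "gemma2-2b" below is only an example value.
#
#   mappings = _get_model_mappings("gemma2-2b", config.scan_layers, hf_config_obj.to_dict(), config)
#   param_map, shape_map, hook_fn_map = (
#       mappings["param_mapping"], mappings["shape_mapping"], mappings["hook_fn_mapping"])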
+def _check_param_map_keys(param_map_keys, maxtext_state_keys):
+  """Validates map coverage, handles N-to-1 mappings, and filters unused keys.
+
+  Ensures every MaxText checkpoint key (`maxtext_state_keys`) is covered by
+  the flattened parameter map. Keys in the map that are not present in the
+  checkpoint (common for multi-variant maps like gemma3, qwen3, deepseek) are skipped.
+
+  Tuple keys represent N-to-1 mappings (multiple MaxText keys combining into one
+  target key) and are only returned if all constituent keys exist in the checkpoint.
+
+  Args:
+    param_map_keys: Keys from the parameter mapping (strings or N-to-1 tuples).
+    maxtext_state_keys: Set of parameter keys loaded from the MaxText checkpoint.
+
+  Returns:
+    A list of filtered mapping keys (strings or tuples) that are fully present
+    and valid based on `maxtext_state_keys`.
+
+  Raises:
+    ValueError: If `maxtext_state_keys` is not a subset of the flattened
+      `param_map_keys`.
+  """
+  flattened_map_keys = set()
+  for key in param_map_keys:
+    if isinstance(key, tuple):
+      flattened_map_keys.update(key)
+    else:
+      flattened_map_keys.add(key)
+
+  # Every MaxText state key must be covered by the param map.
+  missing_keys = maxtext_state_keys - flattened_map_keys
+  if missing_keys:
+    raise ValueError(
+        "maxtext_state_dict must be a subset of the flattened param_map"
+        + f"\nparam map:\n{param_map_keys}"
+        + f"\nmaxtext:\n{maxtext_state_keys}"
+    )
+
+  # The param map may have extra keys; warn and skip them.
+  extra_keys = flattened_map_keys - maxtext_state_keys
+  if extra_keys:
+    max_logging.log(f"Warning: extra keys in param_map are skipped: {extra_keys}")
+
+  # Keep only map keys whose MaxText parameters are all present in the checkpoint.
+  filtered_map_keys = []
+  for key in param_map_keys:
+    if (isinstance(key, str) and key in maxtext_state_keys) or (
+        isinstance(key, tuple) and all(k in maxtext_state_keys for k in key)
+    ):
+      filtered_map_keys.append(key)
+  return filtered_map_keys
+
+
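# Illustrative sketch, not part of the converter script: how _check_param_map_keys
# filters keys. The key names below are hypothetical; real keys come from the
# MaxText checkpoint tree.
_example_map_keys = [
    "params-decoder-norm-scale",
    ("params-decoder-wi_0", "params-decoder-wi_1"),  # N-to-1 mapping
]
_example_state_keys = {"params-decoder-norm-scale", "params-decoder-wi_0"}
# "params-decoder-wi_1" is absent from the checkpoint, so the tuple key is dropped
# (with a logged warning about the extra map key) while the string key is kept,
# and no error is raised because every checkpoint key is covered by the map.
assert _check_param_map_keys(_example_map_keys, _example_state_keys) == ["params-decoder-norm-scale"]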
def main(argv: Sequence[str]) -> None:
  """Main function to convert a MaxText checkpoint to HuggingFace format.

@@ -156,29 +211,52 @@ def main(argv: Sequence[str]) -> None:
  processor = AutoProcessor.from_pretrained(hf_tokenizer_id, token=hf_token) if config.use_multimodal else None

  # 3. Get parameter mappings
-  mappings = _get_model_mappings(model_key, config.scan_layers, hf_config_obj.to_dict())
+  mappings = _get_model_mappings(model_key, config.scan_layers, hf_config_obj.to_dict(), config)
  param_map = mappings["param_mapping"]
  shape_map = mappings["shape_mapping"]  # HF target shapes
  hook_fn_map = mappings["hook_fn_mapping"]

  # 4. Transform Weights
-  transformed_hf_weights: dict[str, Any] = {}
-
  # MaxText `engine.load_params()` returns `state.params` (a FrozenDict).
  # The actual weights are typically under `state.params['params']`.
  actual_weights_dict = loaded_params_from_engine.get("params")
  if actual_weights_dict is None:
    raise ValueError("Loaded parameters from engine do not contain a 'params' key. Structure might be unexpected.")
-
  leaves_with_paths = jax.tree_util.tree_leaves_with_path(actual_weights_dict)

-  # traverse leavse to build: mt_param_key:mt_weights
+  # Construct maxtext_state_dict: {parameter name: parameter weight}
+  maxtext_state_dict = {}
+  for path_tuple, leaf_value in leaves_with_paths:
+    # Construct maxtext_param_key from path_tuple
+    maxtext_param_key = "params-" + "-".join(k.key for k in path_tuple)
+    # Check that the leaf value is an array
+    if not isinstance(leaf_value, (jax.Array, np.ndarray)):
+      raise ValueError(f"Leaf value for {maxtext_param_key} is not an array. Type: {type(leaf_value)}.")
+    maxtext_state_dict[maxtext_param_key] = leaf_value
+
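  # Illustrative sketch, not part of this change (the path below is hypothetical):
  # tree_leaves_with_path yields jax.tree_util.DictKey entries, so a leaf stored at
  # actual_weights_dict["decoder"]["norm"]["scale"] gets the key
  # "params-decoder-norm-scale" in maxtext_state_dict.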
+  # The param_map may contain tuple keys, which represent N-to-1 mappings from MaxText to HuggingFace.
+  # Check that maxtext_state_dict is a subset of the flattened param_map and skip extra param_map keys.
+  filtered_map_keys = _check_param_map_keys(param_map.keys(), maxtext_state_dict.keys())
+
+  # Iterate through the parameter map to transform and collect weights.
+  # This loop handles both simple 1-to-1 mappings and complex N-to-1 mappings
+  # (where multiple MaxText weights are combined into a single HF weight).
  max_logging.log("\nProcessing weight...")
  start = time.time()
  processed_params_list = []
-  for path_tuple_iter, leaf_value_iter in tqdm(leaves_with_paths, total=len(leaves_with_paths)):
-    processed_params = process_leaf_param(path_tuple_iter, leaf_value_iter, param_map, shape_map, hook_fn_map, config)
+
+  for key in tqdm(filtered_map_keys, total=len(filtered_map_keys)):
+    if isinstance(key, tuple):
+      # A tuple key is a tuple of param names; the weight is the list of their arrays.
+      weight = [maxtext_state_dict[subkey] for subkey in key]
+    else:
+      # A string key is a single param name; the weight is a single array.
+      weight = maxtext_state_dict[key]
+
+    processed_params = process_maxtext_param(key, weight, param_map, hook_fn_map, shape_map, config)
    processed_params_list.extend(processed_params)
+
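  # Illustrative sketch, not part of this change: for a hypothetical tuple key such as
  # ("params-decoder-layers-wi_0-kernel", "params-decoder-layers-wi_1-kernel"), `weight`
  # is the list of both arrays and process_maxtext_param combines them into a single HF
  # weight; each call returns (hf_param_name, tensor) pairs, which is why
  # dict(processed_params_list) below yields the final HF weight dict.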
  transformed_hf_weights = dict(processed_params_list)
  max_logging.log(f"Elapsed: {(time.time() - start) / 60:.2f} min")
