@@ -120,7 +120,7 @@ def _check_param_map_keys(param_map_keys, maxtext_state_keys):
120120 maxtext_state_keys: Set of parameter keys loaded from the MaxText checkpoint.
121121
122122 Returns:
123- A set of 'filtered' mapping keys (strings or tuples) that are fully present
123+ A list of 'filtered' mapping keys (strings or tuples) that are fully present
124124 and valid based on `maxtext_state_keys`.
125125
126126 Raises:
@@ -149,12 +149,12 @@ def _check_param_map_keys(param_map_keys, maxtext_state_keys):
149149 max_logging .log (f"Warning: extra keys in param_map are skipped: { extra_keys } " )
150150
151151 # skip extra keys in param map
152- filtered_map_keys = set ()
152+ filtered_map_keys = []
153153 for key in param_map_keys :
154- if (isinstance (key , tuple ) and all ( k in maxtext_state_keys for k in key ) ) or (
155- isinstance (key , str ) and key in maxtext_state_keys
154+ if (isinstance (key , str ) and key in maxtext_state_keys ) or (
155+ isinstance (key , tuple ) and all ( k in maxtext_state_keys for k in key )
156156 ):
157- filtered_map_keys .add (key )
157+ filtered_map_keys .append (key )
158158 return filtered_map_keys
159159
160160
@@ -211,9 +211,7 @@ def main(argv: Sequence[str]) -> None:
211211 processor = AutoProcessor .from_pretrained (hf_tokenizer_id , token = hf_token ) if config .use_multimodal else None
212212
213213 # 3. Get parameter mappings
214- mappings = _get_model_mappings (
215- model_key , config .scan_layers , hf_config_obj .to_dict (), config .inhomogeneous_layer_cycle_interval
216- )
214+ mappings = _get_model_mappings (model_key , config .scan_layers , hf_config_obj .to_dict (), config )
217215 param_map = mappings ["param_mapping" ]
218216 shape_map = mappings ["shape_mapping" ] # HF target shapes
219217 hook_fn_map = mappings ["hook_fn_mapping" ]
0 commit comments