Skip to content

Commit 84590a4

Browse files
committed
.
1 parent 20f2de7 commit 84590a4

File tree

2 files changed

+7
-9
lines changed

2 files changed

+7
-9
lines changed

src/MaxText/utils/ckpt_conversion/to_huggingface.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def _check_param_map_keys(param_map_keys, maxtext_state_keys):
120120
maxtext_state_keys: Set of parameter keys loaded from the MaxText checkpoint.
121121
122122
Returns:
123-
A set of 'filtered' mapping keys (strings or tuples) that are fully present
123+
A list of 'filtered' mapping keys (strings or tuples) that are fully present
124124
and valid based on `maxtext_state_keys`.
125125
126126
Raises:
@@ -149,12 +149,12 @@ def _check_param_map_keys(param_map_keys, maxtext_state_keys):
149149
max_logging.log(f"Warning: extra keys in param_map are skipped: {extra_keys}")
150150

151151
# skip extra keys in param map
152-
filtered_map_keys = set()
152+
filtered_map_keys = []
153153
for key in param_map_keys:
154-
if (isinstance(key, tuple) and all(k in maxtext_state_keys for k in key)) or (
155-
isinstance(key, str) and key in maxtext_state_keys
154+
if (isinstance(key, str) and key in maxtext_state_keys) or (
155+
isinstance(key, tuple) and all(k in maxtext_state_keys for k in key)
156156
):
157-
filtered_map_keys.add(key)
157+
filtered_map_keys.append(key)
158158
return filtered_map_keys
159159

160160

@@ -211,9 +211,7 @@ def main(argv: Sequence[str]) -> None:
211211
processor = AutoProcessor.from_pretrained(hf_tokenizer_id, token=hf_token) if config.use_multimodal else None
212212

213213
# 3. Get parameter mappings
214-
mappings = _get_model_mappings(
215-
model_key, config.scan_layers, hf_config_obj.to_dict(), config.inhomogeneous_layer_cycle_interval
216-
)
214+
mappings = _get_model_mappings(model_key, config.scan_layers, hf_config_obj.to_dict(), config)
217215
param_map = mappings["param_mapping"]
218216
shape_map = mappings["shape_mapping"] # HF target shapes
219217
hook_fn_map = mappings["hook_fn_mapping"]

src/MaxText/utils/ckpt_conversion/utils/param_mapping.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -968,7 +968,7 @@ def GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=Fals
968968

969969
for block_idx in range(layer_cycle_interval):
970970
# Identify all original HF layer indices that collapse into this block
971-
hf_indices = list(range(block_idx, n_layers, maxtext_config.layer_cycle_interval))
971+
hf_indices = list(range(block_idx, n_layers, layer_cycle_interval))
972972
prefix = f"params-decoder-layers-layers_{block_idx}"
973973

974974
# Layer Norms

0 commit comments

Comments
 (0)