
Conversation

@shuningjin (Collaborator) commented Nov 10, 2025

Description

The goal is to support gpt-oss in the checkpoint utility tool. However, two features were previously missing from the utility and need to be implemented first:

  • multiple MaxText params mapped to a single HF param
  • inhomogeneous scan blocks

1 Support many-to-one transform

gpt-oss has two MaxText params mapped to one HuggingFace key:

  • GptOssMlp-wi_0, GptOssMlp-wi_1 -- mlp.experts.gate_up_proj
  • GptOssMlp-wi_0_bias, GptOssMlp-wi_1_bias -- mlp.experts.gate_up_proj_bias
  • the split is interleaved: wi_0 = wi_0_1[..., ::2], wi_1 = wi_0_1[..., 1::2]

To implement this many-to-one mapping (a sketch follows this list):

  • param_mapping.py: previously each entry was mt: hf; a tuple key (mt1, mt2): hf is now also allowed
  • to_huggingface.py: loop over the param_map instead of the leaves (pre-checking coverage); when the key is a tuple, collect the weights into a list
  • utils.py: process_maxtext_param previously handled a str key with a single array; it now also handles a tuple key with a list of weights
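A minimal sketch of the tuple-key idea, assuming plain numpy arrays; PARAM_MAPPING and interleave_gate_up are illustrative names, not the exact MaxText API:

import numpy as np

# A tuple key declares that two MaxText params map to one HF param.
PARAM_MAPPING = {
    ("GptOssMlp-wi_0", "GptOssMlp-wi_1"): "mlp.experts.gate_up_proj",
}

def interleave_gate_up(weights):
  """Inverse of the split wi_0 = w[..., ::2], wi_1 = w[..., 1::2]."""
  wi_0, wi_1 = weights
  out = np.empty(wi_0.shape[:-1] + (2 * wi_0.shape[-1],), dtype=wi_0.dtype)
  out[..., ::2] = wi_0   # even columns hold wi_0 (gate projection)
  out[..., 1::2] = wi_1  # odd columns hold wi_1 (up projection)
  return out

A hook like this would be invoked once the conversion loop has collected both weights for the tuple key into a list.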

Other models can follow this structure similarly, e.g.:

  • llama4, where an HF weight is split into two MaxText keys (here)

2 Support inhomogeneous scan block

gpt-oss has an inhomogeneous cycle with interval = 2, for both 20b and 120b: sliding attention -> full attention

param_mapping.py

  • add a layer_cycle_interval argument to all MAXTEXT_TO_HF_PARAM_MAPPING and MAXTEXT_TO_HF_PARAM_HOOK_FN
  • used in GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING and GPT_OSS_TO_HF_PARAM_HOOK_FN
    • interval = 2: MaxText block 0 -- HF layers 0, 2, 4, ...; MaxText block 1 -- HF layers 1, 3, 5, ... (see the sketch below)
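The index arithmetic as a sketch; hf_layer_index is a hypothetical helper name, and this assumes n_layers is divisible by the interval (gemma3's extra layers would need separate handling):

def hf_layer_index(block, scan_idx, layer_cycle_interval):
  """MaxText scan block `block`, scan position `scan_idx` -> flat HF layer index."""
  return scan_idx * layer_cycle_interval + block

# interval = 2: block 0 holds HF layers 0, 2, 4, ...; block 1 holds 1, 3, 5, ...
assert [hf_layer_index(0, i, 2) for i in range(3)] == [0, 2, 4]
assert [hf_layer_index(1, i, 2) for i in range(3)] == [1, 3, 5]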

Other models can follow this structure similarly, e.g.:

  • llama4: interval = 4; scout: rope->rope->rope->nope, maverick: mlp+rope->moe+rope->mlp+rope->moe+nope
  • qwen3-next: interval = 4, interleaved attention
  • note gemma3 is slightly different: it has extra layers, since n_layers is not divisible by the interval

3 Add gpt-oss: orbax (scan) to hf

Future work

  • gpt-oss: hf -> orbax (scan), accommodating the other direction of the many-to-one transform
  • gpt-oss: add unscan
  • both tracked in b/459541579

Tests

Run orbax scan -> hf conversion, followed by a forward-pass logit check.

gpt-oss-20b

# cpu
ID=$(date +%Y-%m-%d-%H-%M-%S); \
python3 -m MaxText.utils.ckpt_conversion.to_huggingface src/MaxText/configs/base.yml \
model_name=gpt-oss-20b \
load_parameters_path=gs://shuningjin-multipod-dev/gpt-oss-20b/scan-bf16-v2-2025-09-08-06-52-03/0/items \
base_output_directory=/home/shuningjin/gpt-oss-20b/gpt-oss-20b-hf-$ID \
scan_layers=true \
attention=dot_product skip_jax_distributed_system=True \
weight_dtype=bfloat16 checkpoint_storage_concurrent_gb=1024

https://paste.googleplex.com/6054915302621184

/home/shuningjin/gpt-oss-20b/gpt-oss-20b-hf-2025-11-13-08-11-24

# cpu
python3 -m tests.forward_pass_logit_checker src/MaxText/configs/base.yml \
model_name=gpt-oss-20b \
load_parameters_path=gs://shuningjin-multipod-dev/gpt-oss-20b/scan-bf16-v2-2025-09-08-06-52-03/0/items \
scan_layers=true \
per_device_batch_size=1 max_prefill_predict_length=4 max_target_length=4 \
attention=dot_product sparse_matmul=false \
--max_kl_div=0.015 --atol=0.5 --rtol=0.5 \
--run_hf_model=True \
--hf_model_path=/home/shuningjin/gpt-oss-20b/gpt-oss-20b-hf-2025-11-13-08-11-24 \
tokenizer_path=openai/gpt-oss-20b tokenizer_type=huggingface \
skip_jax_distributed_system=True

https://paste.googleplex.com/4806878806802432

gpt-oss-120b

# cpu
ID=$(date +%Y-%m-%d-%H-%M-%S); \
python3 -m MaxText.utils.ckpt_conversion.to_huggingface src/MaxText/configs/base.yml \
model_name=gpt-oss-120b \
load_parameters_path=gs://shuningjin-multipod-dev/gpt-oss-120b/scan-bf16-v2-2025-09-08-07-19-09/0/items \
base_output_directory=/home/shuningjin/gpt-oss-120b/gpt-oss-120b-hf-$ID \
scan_layers=true \
attention=dot_product skip_jax_distributed_system=True \
weight_dtype=bfloat16 checkpoint_storage_concurrent_gb=1024

https://paste.googleplex.com/6024552484306944

/home/shuningjin/gpt-oss-120b/gpt-oss-120b-hf-2025-11-10-11-55-23

# cpu
python3 -m tests.forward_pass_logit_checker src/MaxText/configs/base.yml \
model_name=gpt-oss-120b \
load_parameters_path=gs://shuningjin-multipod-dev/gpt-oss-120b/scan-bf16-v2-2025-09-08-07-19-09/0/items \
scan_layers=true \
per_device_batch_size=1 max_prefill_predict_length=4 max_target_length=4 \
attention=dot_product sparse_matmul=false \
--max_kl_div=0.015 --atol=0.5 --rtol=0.5 \
--run_hf_model=True \
--hf_model_path=/home/shuningjin/gpt-oss-120b/gpt-oss-120b-hf-2025-11-10-11-55-23 \
tokenizer_path=openai/gpt-oss-120b tokenizer_type=huggingface \
skip_jax_distributed_system=True

https://paste.googleplex.com/5431988781711360

Check other models just in case.

qwen3-4b

# cpu
ID=$(date +%Y-%m-%d-%H-%M-%S); \
python3 -m MaxText.utils.ckpt_conversion.to_huggingface src/MaxText/configs/base.yml \
model_name=qwen3-4b \
load_parameters_path=gs://maxtext-qwen/qwen3/4b/unscanned/2025-08-04-21-31/0/items \
base_output_directory=/tmp/conversion/$ID \
scan_layers=false \
attention=dot_product skip_jax_distributed_system=True \
weight_dtype=bfloat16 checkpoint_storage_concurrent_gb=1024

https://paste.googleplex.com/4788178485641216

CKPT=gs://maxtext-qwen/qwen3/4b/unscanned/2025-08-04-21-31/0/items
hf_model_path=/tmp/conversion/2025-11-13-08-31-06
python3 -m tests.forward_pass_logit_checker src/MaxText/configs/base.yml \
model_name=qwen3-4b attention=dot_product \
override_model_config=true enable_dropout=false tokenizer_type=huggingface \
load_parameters_path=$CKPT scan_layers=false \
per_device_batch_size=1 max_prefill_predict_length=4 max_target_length=8 \
tokenizer_path=Qwen/Qwen3-4B --run_hf_model=True --hf_model_path=$hf_model_path \
--max_kl_div=0.015 --atol=0.5 --rtol=0.5 \
skip_jax_distributed_system=True

https://paste.googleplex.com/5246501760663552

gemma3-4b

ID=$(date +%Y-%m-%d-%H-%M-%S); \
python3 -m MaxText.utils.ckpt_conversion.to_huggingface src/MaxText/configs/base.yml \
model_name=gemma3-4b \
load_parameters_path=gs://maxtext-gemma/unified/gemma3/4b/unscanned/2025-08-05-18-18/0/items \
base_output_directory=/tmp/conversion/$ID \
scan_layers=false \
attention=dot_product skip_jax_distributed_system=True \
weight_dtype=bfloat16 checkpoint_storage_concurrent_gb=1024 \
hf_access_token=$HF_TOKEN

https://paste.googleplex.com/5830899908345856

CKPT=gs://maxtext-gemma/unified/gemma3/4b/unscanned/2025-08-05-18-18/0/items
hf_model_path=/tmp/conversion/2025-11-13-08-37-47
python3 -m tests.forward_pass_logit_checker src/MaxText/configs/base.yml \
model_name=gemma3-4b attention=dot_product \
override_model_config=true enable_dropout=false tokenizer_type=huggingface \
load_parameters_path=$CKPT scan_layers=false \
per_device_batch_size=1 max_prefill_predict_length=4 max_target_length=8 \
tokenizer_path=$hf_model_path --run_hf_model=True --hf_model_path=$hf_model_path \
--max_kl_div=0.015 --atol=0.5 --rtol=0.5 \
skip_jax_distributed_system=True

https://paste.googleplex.com/6449070222737408

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@github-actions bot left a comment


📋 Review Summary

This pull request introduces significant enhancements to the checkpoint conversion utility to support gpt-oss models. The addition of many-to-one parameter mapping and support for inhomogeneous scan blocks are well-implemented and make the tool more flexible for future models. The overall code quality is high, with clear logic and necessary updates to configurations and mappings.

🔍 General Feedback

  • The refactoring in to_huggingface.py and utils.py to support the new mapping features is excellent and improves robustness.
  • The use of NotImplementedError for features planned as future work is a good practice.
  • I've identified one potential bug where the logic for detecting unscanned MoE layers might fail with the new tuple-based keys for many-to-one mappings. Please see the inline comment.

@RissyRan (Collaborator) left a comment


Thanks for the great work! LGTM at high level, just a few comments.

@shuningjin shuningjin changed the title Checkpoint conversion utility: gpt-oss orbax scan to hf, inhomogeneous scan block, many-to-one transform Checkpoint conversion utility: gpt-oss orbax scan to hf, many-to-one transform, inhomogeneous scan block Nov 13, 2025
@github-actions bot left a comment


📋 Review Summary

This Pull Request introduces support for GPT-OSS models in the checkpoint conversion utility, addressing complex scenarios like many-to-one parameter mappings and inhomogeneous scan blocks. The changes are well-structured and include robust validation for parameter keys.

🔍 General Feedback

  • The new _check_param_map_keys function significantly improves the robustness of parameter mapping validation.
  • The process_maxtext_param function (formerly process_leaf_param) has been refactored effectively to handle various mapping complexities.
  • Comprehensive docstrings and comments enhance the maintainability of the new and modified code.
  • Consider adding tracking issues or TODOs for the NotImplementedError cases to ensure future support for unscanned layers and reverse conversions.

@copybara-service copybara-service bot merged commit c92d9d9 into main Nov 14, 2025
38 checks passed
@copybara-service copybara-service bot deleted the shuningjin-ckpt-gpt branch November 14, 2025 05:09