
mx: expose scaling calculation methods in training UX #2620


Merged: 3 commits merged into main on Aug 4, 2025

Conversation

vkuzo (Contributor) commented on Jul 28, 2025

Summary:

To prepare MX for graduation out of prototype, this PR exposes the scale calculation mode in the top-level training config and recipe. The two well-supported scaling modes are FLOOR and RCEIL. CEIL has not been well tested, and EVEN does not yet work with torch.compile. We may further adjust these options in future PRs.
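For orientation, here is a hedged sketch of how the exposed option might be set from Python. The `scale_calculation_mode` field name and the `from_recipe_name` classmethod are assumptions inferred from the recipe flags in the test plan below, not verified torchao API:

```python
# Hypothetical sketch only: the field name and recipe plumbing are
# assumptions based on this PR's description, not verified torchao API.
from torchao.prototype.mx_formats.config import (
    MXLinearConfig,
    ScaleCalculationMode,
)

# explicit config: FLOOR (the default) and RCEIL are the two
# well-supported scale calculation modes
config = MXLinearConfig(
    block_size=32,
    scale_calculation_mode=ScaleCalculationMode.RCEIL,  # assumed field name
)

# or via a recipe name, matching the torchtitan flag used in the test plan
config = MXLinearConfig.from_recipe_name("mxfp8_cublas_rceil")  # assumed classmethod
```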

Note that for RCEIL, the dim0 casts do not yet use hardware-accelerated instructions, so overall performance is currently slightly below FLOOR; we can improve this in a future PR.
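For intuition on the difference, here is a minimal self-contained sketch of the two scale calculations for one 32-element block, using fp8 e4m3 constants. It mirrors the spec-level idea (FLOOR per the OCP MX spec; RCEIL rounding the exponent of amax / dest_max up so the scaled block cannot overflow), not torchao's actual kernels:

```python
# Minimal sketch of FLOOR vs RCEIL scale exponents for one MX block.
# Constants are for float8_e4m3fn elements; this is not torchao kernel code.
import torch

F8E4M3_MAX = 448.0   # max representable magnitude of float8_e4m3fn
F8E4M3_MAX_POW2 = 8  # floor(log2(448))

def scale_exp_floor(block: torch.Tensor) -> torch.Tensor:
    # FLOOR: floor(log2(amax)) minus the element format's max power of two
    amax = block.abs().amax()
    return torch.floor(torch.log2(amax)) - F8E4M3_MAX_POW2

def scale_exp_rceil(block: torch.Tensor) -> torch.Tensor:
    # RCEIL: round the exponent of amax / dest_max up, so that
    # block / 2**exp never exceeds the fp8 max
    amax = block.abs().amax()
    return torch.ceil(torch.log2(amax / F8E4M3_MAX))

block = torch.randn(32) * 100
for f in (scale_exp_floor, scale_exp_rceil):
    e = f(block)
    print(f.__name__, int(e), (block / 2**e).abs().max().item())
```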

Test Plan:

unit tests:

```bash
pytest test/prototype/mx_formats/ -s -x
```

performance on llama 3 8b training:

```bash
# requires https://github.com/pytorch/torchtitan/pull/1512
# requires hardcoding dim1 kernel choice to CUDA
with-proxy CONFIG_FILE="torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --model.print_after_conversion --training.compile --training.steps 50 --model.converters mx --mx.recipe_name "mxfp8_cublas"
# ...tps ~10.5k

with-proxy CONFIG_FILE="torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --model.print_after_conversion --training.compile --training.steps 50 --model.converters mx --mx.recipe_name "mxfp8_cublas_rceil"
# ...tps ~10.3k
```

performance on individual casts:

```bash
(pytorch_nightly) [[email protected] ~/local/ao (20250728_mx_expose_scale)]$ python benchmarks/mx_formats/cast_bench.py --mode dim0_mx_floor
M 16384 K 16384 BLOCK_SIZE 32
GPU: NVIDIA B200
torch version: 2.9.0.dev20250724+cu128
triton version: 3.4.0
mode: dim0_mx_floor
time_us 184.38400328159332
mem_bw_gbps 4413.045391781173

(pytorch_nightly) [[email protected] ~/local/ao (20250728_mx_expose_scale)]$ python benchmarks/mx_formats/cast_bench.py --mode dim0_mx_rceil
M 16384 K 16384 BLOCK_SIZE 32
GPU: NVIDIA B200
torch version: 2.9.0.dev20250724+cu128
triton version: 3.4.0
mode: dim0_mx_rceil
time_us 143.39199662208557
mem_bw_gbps 5674.619191924083
```
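
As a sanity check on the reported numbers: assuming the benchmark counts a bf16 read (2 bytes/element), an fp8 write (1 byte/element), and one E8M0 scale byte per 32-element block (an assumption about cast_bench.py's accounting, though it matches both runs closely), the bandwidth reproduces from the times:

```python
# Reproduce mem_bw_gbps from time_us under the traffic model above.
M, K = 16384, 16384
bytes_moved = M * K * (2 + 1 + 1 / 32)  # bf16 in + fp8 out + e8m0 scales

for mode, time_us in (("dim0_mx_floor", 184.384), ("dim0_mx_rceil", 143.392)):
    gbps = bytes_moved / (time_us * 1e-6) / 1e9
    print(mode, round(gbps, 1))  # ~4413.1 and ~5674.7, matching the logs
```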

Reviewers:

Subscribers:

Tasks:

Tags:


pytorch-bot bot commented Jul 28, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2620


❌ 1 New Failure

As of commit 29354b7 with merge base d05e54f, one new job failure was reported.


vkuzo added a commit that referenced this pull request Jul 28, 2025
facebook-github-bot added the CLA Signed label on Jul 28, 2025
vkuzo added a commit that referenced this pull request Aug 1, 2025
vkuzo added the topic: improvement label on Aug 1, 2025
vkuzo added a commit that referenced this pull request Aug 1, 2025
vkuzo merged commit c993d64 into main on Aug 4, 2025
49 of 54 checks passed
Labels
CLA Signed · topic: improvement