[mxfp8] Make mxfp8 dim1 cast kernel configurable #1401

Closed
wants to merge 2 commits into from
Conversation

danielvegamyhre
Contributor

@danielvegamyhre danielvegamyhre commented Jul 15, 2025

Summary

  • We recently developed a CUDA kernel in torchao to perform mxfp8 casting with scaling along dim1, which is ~1.4x faster than the previous Triton implementation. This yields an e2e training speedup of 1.5%-2.5% with torchtitan Llama3 8B with FSDP=4/8: Add CUDA kernel for MXFP8 dim1 casting ao#2513
  • The integration into torchao is finished (integration of new mxfp8 casting cuda kernel ao#2564), so we need to update torchtitan to make the kernel choice for the mxfp8 dim1 cast configurable as "triton", "cuda", or "torch".
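A minimal sketch of what the configurable kernel choice could look like on the torchtitan side. The enum and function names here are hypothetical illustrations, not the actual torchtitan/torchao API; the only facts taken from the PR are the config flag name `mxfp8_dim1_cast_kernel_choice` and its three valid values.

```python
from enum import Enum


class MXFP8Dim1CastKernelChoice(Enum):
    """Hypothetical enum for the three kernel choices named in this PR."""
    TRITON = "triton"
    CUDA = "cuda"
    TORCH = "torch"


def parse_kernel_choice(name: str) -> MXFP8Dim1CastKernelChoice:
    """Validate a --mx.mxfp8_dim1_cast_kernel_choice value from the CLI/TOML config.

    Illustrative helper only; torchtitan's real config parsing may differ.
    """
    try:
        return MXFP8Dim1CastKernelChoice(name.lower())
    except ValueError:
        valid = ", ".join(c.value for c in MXFP8Dim1CastKernelChoice)
        raise ValueError(
            f"Unknown mxfp8 dim1 cast kernel {name!r}; expected one of: {valid}"
        ) from None
```

The resulting enum value would then be forwarded to torchao's mxfp8 config when building the model converter, so an invalid string fails fast at startup rather than mid-training.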

Test plan

  • Triton: NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --model.converters="mx" --mx.recipe_name="mxfp8" --training.compile --mx.mxfp8_dim1_cast_kernel_choice="triton"
  • Cuda: NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --model.converters="mx" --mx.recipe_name="mxfp8" --training.compile --mx.mxfp8_dim1_cast_kernel_choice="cuda"

@danielvegamyhre
Contributor Author

cc @tianyu-l @vkuzo

Labels: CLA Signed

2 participants