Stacked PRs:
* __->__#1427
---
make mxfp8 dim1 cast kernel configurable
## Summary
- We recently added a new CUDA kernel for the mxfp8 dim1 cast which is
~1.4x faster than the existing Triton kernel and torch.compile, and using
it yields an e2e training speedup of +1.5-2.5% TPS with Llama3 8b
using FSDP=4/8 (pytorch/ao#2513). The
integration work for composability with torch.compile + FSDP is complete
as well: pytorch/ao#2564
- This PR updates the mxfp8 user-facing API, replacing the boolean flag
`--mx.use_triton_for_dim1_cast=[true|false]` with
`--mx.mxfp8_dim1_cast_kernel_choice=[triton|cuda|torch]` (see the sketch after this list).
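For illustration only, here is a minimal sketch of how a string-valued kernel choice might be validated and dispatched in place of the old boolean flag. The enum, helper names, and backend stubs below are hypothetical and are not the actual torchao/torchtitan implementation.

```python
from enum import Enum


class MXFP8Dim1CastKernelChoice(Enum):
    """Hypothetical enum mirroring the [triton|cuda|torch] config values."""
    TRITON = "triton"
    CUDA = "cuda"
    TORCH = "torch"


# Placeholder backends; the real kernels live in torchao and are not shown here.
def _cast_dim1_triton(x):
    raise NotImplementedError("placeholder for the Triton kernel path")


def _cast_dim1_cuda(x):
    raise NotImplementedError("placeholder for the CUDA kernel path")


def _cast_dim1_torch(x):
    raise NotImplementedError("placeholder for the torch.compile path")


_DISPATCH = {
    MXFP8Dim1CastKernelChoice.TRITON: _cast_dim1_triton,
    MXFP8Dim1CastKernelChoice.CUDA: _cast_dim1_cuda,
    MXFP8Dim1CastKernelChoice.TORCH: _cast_dim1_torch,
}


def cast_dim1(x, choice: str = "triton"):
    """Select the dim1 cast implementation from a config string."""
    try:
        kernel = MXFP8Dim1CastKernelChoice(choice.lower())
    except ValueError:
        valid = ", ".join(c.value for c in MXFP8Dim1CastKernelChoice)
        raise ValueError(f"unknown kernel choice {choice!r}; expected one of: {valid}")
    return _DISPATCH[kernel](x)
```

One upside of an enum-style choice over the old boolean is that adding another backend later only requires a new enum member and dispatch entry rather than another flag.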
## Test plan
- Triton: `NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --model.converters="mx" --mx.recipe_name="mxfp8" --training.compile --mx.mxfp8_dim1_cast_kernel_choice="triton"`
- CUDA: `NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --model.converters="mx" --mx.recipe_name="mxfp8" --training.compile --mx.mxfp8_dim1_cast_kernel_choice="cuda"`
- Torch: `NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --model.converters="mx" --mx.recipe_name="mxfp8" --training.compile --mx.mxfp8_dim1_cast_kernel_choice="torch"`
## Limitations
- TP is currently not supported yet, as both the Triton kernel and CUDA
kernel are affected by an issue: `RuntimeError: Attempting to use
FunctionalTensor on its own. Instead, please use it with a corresponding
FunctionalTensorMode()`. This is a known issue we were talking to Brian
about, will continue following up on it.