
Commit f3e2a75

make mxfp8 dim1 cast kernel configurable (#1427)
Stacked PRs:
* __->__ #1427

---

make mxfp8 dim1 cast kernel configurable

## Summary

- We recently added a new CUDA kernel for the mxfp8 dim1 cast that is ~1.4x faster than the existing Triton kernel or torch.compile; using it gives an e2e training speedup of +1.5-2.5% TPS on Llama3 8b with FSDP=4/8 (pytorch/ao#2513). The integration work for composability with torch.compile + FSDP is complete as well: pytorch/ao#2564.
- This PR updates the mxfp8 user-facing API, replacing the boolean flag `--mx.use_fp8_dim1_cast_triton_kernel=[true|false]` with `--mx.mxfp8_dim1_cast_kernel_choice=[triton|cuda|torch]` (see the migration sketch below).

## Test plan

- Triton: `NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --model.converters="mx" --mx.recipe_name="mxfp8" --training.compile --mx.mxfp8_dim1_cast_kernel_choice="triton"`
- CUDA: `NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --model.converters="mx" --mx.recipe_name="mxfp8" --training.compile --mx.mxfp8_dim1_cast_kernel_choice="cuda"`
- Torch: `NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --model.converters="mx" --mx.recipe_name="mxfp8" --training.compile --mx.mxfp8_dim1_cast_kernel_choice="torch"`

## Limitations

- TP is not yet supported: both the Triton and the CUDA kernel currently hit `RuntimeError: Attempting to use FunctionalTensor on its own. Instead, please use it with a corresponding FunctionalTensorMode()`. This is a known issue we have discussed with Brian and will continue to follow up on.
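For users updating existing job configs, here is a minimal, hypothetical migration sketch (not part of this PR); it assumes the old `use_fp8_dim1_cast_triton_kernel=False` setting corresponds to the torch/torch.compile cast path:

```python
# Hypothetical migration helper, not part of this PR: maps the removed boolean
# flag to the new three-way kernel choice. Assumes False previously meant
# falling back to the torch/torch.compile dim1 cast path.
def migrate_mxfp8_dim1_cast_flag(use_fp8_dim1_cast_triton_kernel: bool) -> str:
    """Return the value to pass as --mx.mxfp8_dim1_cast_kernel_choice."""
    return "triton" if use_fp8_dim1_cast_triton_kernel else "torch"


assert migrate_mxfp8_dim1_cast_flag(True) == "triton"
assert migrate_mxfp8_dim1_cast_flag(False) == "torch"
```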
1 parent 38a9d30 commit f3e2a75

File tree

2 files changed: +21, -12 lines


torchtitan/components/quantization/mx.py

Lines changed: 20 additions & 11 deletions
@@ -40,30 +40,39 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
                 "torchao is not installed. Please install it to use MXFP8 linear layers."
             )
         torchao_version = version("torchao")
-        mxfp8_min_version = "0.11.0"
-        if torchao_version < mxfp8_min_version:
+
+        # Last torchao release was 0.12.0, so nightly build starts with 0.13.0+git...
+        is_nightly_build = torchao_version.startswith("0.13.0")
+        if not is_nightly_build:
             raise ImportError(
-                f"torchao version {torchao_version} is too old, please install torchao {mxfp8_min_version} or later and try again"
+                f"torchao version {torchao_version} is too old, please install torchao nightly build and try again"
             )
 
         # Can be removed if we enable the emulated versions
         assert has_cuda_capability(
             10, 0
         ), "MXFP8 is only supported on SM100 or architectures"
 
-        self.enabled = True
-        mx_job_config: MX = job_config.mx
-        self.filter_fqns = mx_job_config.filter_fqns
+        # TP not yet supported with torch.compile
+        assert not (
+            job_config.training.compile
+            and job_config.parallelism.tensor_parallel_degree > 1
+        ), "TP not yet supported with torch.compile for mxfp8"
 
         # Configure MXFP8
-        from torchao.prototype.mx_formats.config import MXLinearConfig
+        from torchao.prototype.mx_formats.config import (
+            MXFP8Dim1CastKernelChoice,
+            MXLinearConfig,
+        )
 
+        mx_job_config: MX = job_config.mx
         config = MXLinearConfig.from_recipe_name(NAME_MAP[mx_job_config.recipe_name])
-        config.use_fp8_dim1_cast_triton_kernel = (
-            mx_job_config.use_fp8_dim1_cast_triton_kernel
-        )
+        config.mxfp8_dim1_cast_kernel_choice = MXFP8Dim1CastKernelChoice[
+            mx_job_config.mxfp8_dim1_cast_kernel_choice.upper()
+        ]
+        self.filter_fqns = mx_job_config.filter_fqns
         self.config = config
-
+        self.enabled = True
         logger.info(f"Float8 training active with recipe {mx_job_config.recipe_name}")
 
     def convert(self, model: nn.Module):

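For reference, a standalone sketch of the kernel-selection logic added above. The import path and the name-based enum lookup come from the diff; the helper function itself is illustrative and assumes the torchao enum exposes `TRITON`, `CUDA`, and `TORCH` members.

```python
# Illustrative sketch of the kernel-choice lookup used in mx.py above; assumes
# MXFP8Dim1CastKernelChoice defines TRITON, CUDA, and TORCH members.
from torchao.prototype.mx_formats.config import MXFP8Dim1CastKernelChoice


def select_dim1_cast_kernel(choice: str) -> MXFP8Dim1CastKernelChoice:
    """Map a config string ("triton" | "cuda" | "torch") to the torchao enum."""
    try:
        return MXFP8Dim1CastKernelChoice[choice.upper()]
    except KeyError as e:
        raise ValueError(
            f"Unknown mxfp8 dim1 cast kernel choice: {choice!r}; "
            "expected one of 'triton', 'cuda', 'torch'"
        ) from e
```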
torchtitan/config/job_config.py

Lines changed: 1 addition & 1 deletion
@@ -534,7 +534,7 @@ class Float8:
 
 @dataclass
 class MX:
-    use_fp8_dim1_cast_triton_kernel: bool = True
+    mxfp8_dim1_cast_kernel_choice: Literal["triton", "cuda", "torch"] = "triton"
     """Temp work around for inductor performance gap"""
 
     recipe_name: Literal["mxfp8"] = "mxfp8"

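A minimal sketch (not from this PR) of how the `Literal` annotation above could back runtime validation of the new option; the field name mirrors the diff, while the validation helper is hypothetical.

```python
# Minimal validation sketch: the field name mirrors job_config.py above,
# validate_kernel_choice is a hypothetical helper.
from dataclasses import dataclass
from typing import Literal, get_args

MXFP8Dim1CastKernel = Literal["triton", "cuda", "torch"]


@dataclass
class MX:
    mxfp8_dim1_cast_kernel_choice: MXFP8Dim1CastKernel = "triton"


def validate_kernel_choice(choice: str) -> None:
    allowed = get_args(MXFP8Dim1CastKernel)
    if choice not in allowed:
        raise ValueError(
            f"mxfp8_dim1_cast_kernel_choice must be one of {allowed}, got {choice!r}"
        )


validate_kernel_choice(MX().mxfp8_dim1_cast_kernel_choice)  # passes for default "triton"
```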