From 4806fdbb4acf3c307607be66312d779e8117e62a Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Tue, 15 Jul 2025 14:54:09 -0700 Subject: [PATCH] make mxfp8 dim1 cast kernel configurable stack-info: PR: https://github.com/pytorch/torchtitan/pull/1427, branch: danielvegamyhre/stack/1 --- torchtitan/components/quantization/mx.py | 31 +++++++++++++++--------- torchtitan/config/job_config.py | 2 +- 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/torchtitan/components/quantization/mx.py b/torchtitan/components/quantization/mx.py index f22ac4bd04..f2c6820a70 100644 --- a/torchtitan/components/quantization/mx.py +++ b/torchtitan/components/quantization/mx.py @@ -40,10 +40,12 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): "torchao is not installed. Please install it to use MXFP8 linear layers." ) torchao_version = version("torchao") - mxfp8_min_version = "0.11.0" - if torchao_version < mxfp8_min_version: + + # Last torchao release was 0.12.0, so nightly build starts with 0.13.0+git... + is_nightly_build = torchao_version.startswith("0.13.0") + if not is_nightly_build: raise ImportError( - f"torchao version {torchao_version} is too old, please install torchao {mxfp8_min_version} or later and try again" + f"torchao version {torchao_version} is too old, please install torchao nightly build and try again" ) # Can be removed if we enable the emulated versions @@ -51,19 +53,26 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): 10, 0 ), "MXFP8 is only supported on SM100 or architectures" - self.enabled = True - mx_job_config: MX = job_config.mx - self.filter_fqns = mx_job_config.filter_fqns + # TP not yet supported with torch.compile + assert not ( + job_config.training.compile + and job_config.parallelism.tensor_parallel_degree > 1 + ), "TP not yet supported with torch.compile for mxfp8" # Configure MXFP8 - from torchao.prototype.mx_formats.config import MXLinearConfig + from torchao.prototype.mx_formats.config import ( + MXFP8Dim1CastKernelChoice, + MXLinearConfig, + ) + mx_job_config: MX = job_config.mx config = MXLinearConfig.from_recipe_name(NAME_MAP[mx_job_config.recipe_name]) - config.use_fp8_dim1_cast_triton_kernel = ( - mx_job_config.use_fp8_dim1_cast_triton_kernel - ) + config.mxfp8_dim1_cast_kernel_choice = MXFP8Dim1CastKernelChoice[ + mx_job_config.mxfp8_dim1_cast_kernel_choice.upper() + ] + self.filter_fqns = mx_job_config.filter_fqns self.config = config - + self.enabled = True logger.info(f"Float8 training active with recipe {mx_job_config.recipe_name}") def convert(self, model: nn.Module): diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index fdf38c63fe..d673999810 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -534,7 +534,7 @@ class Float8: @dataclass class MX: - use_fp8_dim1_cast_triton_kernel: bool = True + mxfp8_dim1_cast_kernel_choice: Literal["triton", "cuda", "torch"] = "triton" """Temp work around for inductor performance gap""" recipe_name: Literal["mxfp8"] = "mxfp8"