From affe1a8291eb64f49b39f188a3bf2fa609f752f7 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Tue, 15 Jul 2025 14:54:09 -0700 Subject: [PATCH 1/2] make mxfp8 dim1 cast kernel configurable --- torchtitan/components/quantization/mx.py | 15 ++++++++++++--- torchtitan/config_manager.py | 2 +- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/torchtitan/components/quantization/mx.py b/torchtitan/components/quantization/mx.py index f2b1bdb5f..76dbb2af3 100644 --- a/torchtitan/components/quantization/mx.py +++ b/torchtitan/components/quantization/mx.py @@ -56,11 +56,20 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): self.filter_fqns = mx_job_config.filter_fqns # Configure MXFP8 - from torchao.prototype.mx_formats.config import MXLinearConfig + from torchao.prototype.mx_formats.config import ( + MXFP8Dim1CastKernelChoice, + MXLinearConfig, + ) config = MXLinearConfig.from_recipe_name(NAME_MAP[mx_job_config.recipe_name]) - config.use_fp8_dim1_cast_triton_kernel = ( - mx_job_config.use_fp8_dim1_cast_triton_kernel + + dim1_cast_kernel_choice_str = ( + mx_job_config.mxfp8_dim1_cast_kernel_choice.upper() + ) + config.mxfp8_dim1_cast_kernel_choice = ( + MXFP8Dim1CastKernelChoice[dim1_cast_kernel_choice_str] + if mx_job_config.mxfp8_dim1_cast_kernel_choice != "NONE" + else None ) self.config = config diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 3f8d25688..439f1f700 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -523,7 +523,7 @@ class Float8: @dataclass class MX: - use_fp8_dim1_cast_triton_kernel: bool = True + mxfp8_dim1_cast_kernel_choice: Literal["triton", "cuda", "none"] = "triton" """Temp work around for inductor performance gap""" recipe_name: Literal["mxfp8"] = "mxfp8" From 5e84ec7646f61da95c1cf3ec1bd396319ecca8ea Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Wed, 16 Jul 2025 11:27:38 -0700 Subject: [PATCH 2/2] update api name --- torchtitan/components/quantization/mx.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchtitan/components/quantization/mx.py b/torchtitan/components/quantization/mx.py index 76dbb2af3..aa19829e4 100644 --- a/torchtitan/components/quantization/mx.py +++ b/torchtitan/components/quantization/mx.py @@ -57,7 +57,7 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): # Configure MXFP8 from torchao.prototype.mx_formats.config import ( - MXFP8Dim1CastKernelChoice, + MXFP8CastKernelChoice, MXLinearConfig, ) @@ -67,7 +67,7 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): mx_job_config.mxfp8_dim1_cast_kernel_choice.upper() ) config.mxfp8_dim1_cast_kernel_choice = ( - MXFP8Dim1CastKernelChoice[dim1_cast_kernel_choice_str] + MXFP8CastKernelChoice[dim1_cast_kernel_choice_str] if mx_job_config.mxfp8_dim1_cast_kernel_choice != "NONE" else None )