Skip to content

Commit a6466e7

Browse files
make mxfp8 dim1 cast kernel configurable
stack-info: PR: #1427, branch: danielvegamyhre/stack/1
1 parent d69a737 commit a6466e7

File tree

2 files changed

+15
-8
lines changed

2 files changed

+15
-8
lines changed

torchtitan/components/quantization/mx.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,12 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
4040
"torchao is not installed. Please install it to use MXFP8 linear layers."
4141
)
4242
torchao_version = version("torchao")
43-
mxfp8_min_version = "0.11.0"
44-
if torchao_version < mxfp8_min_version:
43+
44+
# Last torchao release was 0.12.0, so nightly build starts with 0.13.0+git...
45+
is_nightly_build = torchao_version.startswith("0.13.0")
46+
if not is_nightly_build:
4547
raise ImportError(
46-
f"torchao version {torchao_version} is too old, please install torchao {mxfp8_min_version} or later and try again"
48+
f"torchao version {torchao_version} is too old, please install torchao nightly build and try again"
4749
)
4850

4951
# Can be removed if we enable the emulated versions
@@ -56,12 +58,17 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
5658
self.filter_fqns = mx_job_config.filter_fqns
5759

5860
# Configure MXFP8
59-
from torchao.prototype.mx_formats.config import MXLinearConfig
61+
from torchao.prototype.mx_formats.config import (
62+
MXFP8Dim1CastKernelChoice,
63+
MXLinearConfig,
64+
)
6065

6166
config = MXLinearConfig.from_recipe_name(NAME_MAP[mx_job_config.recipe_name])
62-
config.use_fp8_dim1_cast_triton_kernel = (
63-
mx_job_config.use_fp8_dim1_cast_triton_kernel
64-
)
67+
68+
# String to enum
69+
config.mxfp8_dim1_cast_kernel_choice = MXFP8Dim1CastKernelChoice[
70+
mx_job_config.mxfp8_dim1_cast_kernel_choice.upper()
71+
]
6572
self.config = config
6673

6774
logger.info(f"Float8 training active with recipe {mx_job_config.recipe_name}")

torchtitan/config_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -556,7 +556,7 @@ class Float8:
556556

557557
@dataclass
558558
class MX:
559-
use_fp8_dim1_cast_triton_kernel: bool = True
559+
mxfp8_dim1_cast_kernel_choice: Literal["triton", "cuda", "torch"] = "triton"
560560
"""Temp work around for inductor performance gap"""
561561

562562
recipe_name: Literal["mxfp8"] = "mxfp8"

0 commit comments

Comments (0)