Skip to content

Commit 24c3d3a

Browse files
make mxfp8 dim1 cast kernel configurable
stack-info: PR: #1427, branch: danielvegamyhre/stack/1
1 parent d69a737 commit 24c3d3a

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

torchtitan/components/quantization/mx.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,17 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
5656
self.filter_fqns = mx_job_config.filter_fqns
5757

5858
# Configure MXFP8
59-
from torchao.prototype.mx_formats.config import MXLinearConfig
59+
from torchao.prototype.mx_formats.config import (
60+
MXFP8Dim1CastKernelChoice,
61+
MXLinearConfig,
62+
)
6063

6164
config = MXLinearConfig.from_recipe_name(NAME_MAP[mx_job_config.recipe_name])
62-
config.use_fp8_dim1_cast_triton_kernel = (
63-
mx_job_config.use_fp8_dim1_cast_triton_kernel
64-
)
65+
66+
# String to enum
67+
config.mxfp8_dim1_cast_kernel_choice = MXFP8Dim1CastKernelChoice[
68+
mx_job_config.mxfp8_dim1_cast_kernel_choice.upper()
69+
]
6570
self.config = config
6671

6772
logger.info(f"Float8 training active with recipe {mx_job_config.recipe_name}")

torchtitan/config_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -556,7 +556,7 @@ class Float8:
556556

557557
@dataclass
558558
class MX:
559-
use_fp8_dim1_cast_triton_kernel: bool = True
559+
mxfp8_dim1_cast_kernel_choice: Literal["triton", "cuda", "torch"] = "triton"
560560
"""Temp workaround for inductor performance gap"""
561561

562562
recipe_name: Literal["mxfp8"] = "mxfp8"

0 commit comments

Comments
 (0)