Skip to content

Commit 372b083

Browse files
make mxfp8 dim1 cast kernel configurable
stack-info: PR: #1427, branch: danielvegamyhre/stack/1
1 parent d69a737 commit 372b083

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

torchtitan/components/quantization/mx.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
```diff
@@ -56,11 +56,20 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
         self.filter_fqns = mx_job_config.filter_fqns

         # Configure MXFP8
-        from torchao.prototype.mx_formats.config import MXLinearConfig
+        from torchao.prototype.mx_formats.config import (
+            MXFP8Dim1CastKernelChoice,
+            MXLinearConfig,
+        )

         config = MXLinearConfig.from_recipe_name(NAME_MAP[mx_job_config.recipe_name])
-        config.use_fp8_dim1_cast_triton_kernel = (
-            mx_job_config.use_fp8_dim1_cast_triton_kernel
+
+        dim1_cast_kernel_choice_str = (
+            mx_job_config.mxfp8_dim1_cast_kernel_choice.upper()
+        )
+        config.mxfp8_dim1_cast_kernel_choice = (
+            MXFP8Dim1CastKernelChoice[dim1_cast_kernel_choice_str]
+            if mx_job_config.mxfp8_dim1_cast_kernel_choice != "NONE"
+            else None
         )
         self.config = config
```

torchtitan/config_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
```diff
@@ -556,7 +556,7 @@ class Float8:

 @dataclass
 class MX:
-    use_fp8_dim1_cast_triton_kernel: bool = True
+    mxfp8_dim1_cast_kernel_choice: Literal["triton", "cuda", "none"] = "triton"
     """Temp work around for inductor performance gap"""

     recipe_name: Literal["mxfp8"] = "mxfp8"
```

0 commit comments

Comments
 (0)