Skip to content

Commit affe1a8

Browse files
make mxfp8 dim1 cast kernel configurable
1 parent 7d5f3cc commit affe1a8

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

torchtitan/components/quantization/mx.py

Lines changed: 12 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -56,11 +56,20 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
5656
self.filter_fqns = mx_job_config.filter_fqns
5757

5858
# Configure MXFP8
59-
from torchao.prototype.mx_formats.config import MXLinearConfig
59+
from torchao.prototype.mx_formats.config import (
60+
MXFP8Dim1CastKernelChoice,
61+
MXLinearConfig,
62+
)
6063

6164
config = MXLinearConfig.from_recipe_name(NAME_MAP[mx_job_config.recipe_name])
62-
config.use_fp8_dim1_cast_triton_kernel = (
63-
mx_job_config.use_fp8_dim1_cast_triton_kernel
65+
66+
dim1_cast_kernel_choice_str = (
67+
mx_job_config.mxfp8_dim1_cast_kernel_choice.upper()
68+
)
69+
config.mxfp8_dim1_cast_kernel_choice = (
70+
MXFP8Dim1CastKernelChoice[dim1_cast_kernel_choice_str]
71+
if mx_job_config.mxfp8_dim1_cast_kernel_choice != "NONE"
72+
else None
6473
)
6574
self.config = config
6675

torchtitan/config_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -523,7 +523,7 @@ class Float8:
523523

524524
@dataclass
525525
class MX:
526-
use_fp8_dim1_cast_triton_kernel: bool = True
526+
mxfp8_dim1_cast_kernel_choice: Literal["triton", "cuda", "none"] = "triton"
527527
"""Temp work around for inductor performance gap"""
528528

529529
recipe_name: Literal["mxfp8"] = "mxfp8"

0 commit comments

Comments (0)