
Commit d420d93

make mxfp8 dim1 cast kernel configurable
1 parent 7d5f3cc commit d420d93

2 files changed: +8 −4 lines changed

torchtitan/components/quantization/mx.py

Lines changed: 7 additions & 3 deletions
@@ -56,11 +56,15 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
         self.filter_fqns = mx_job_config.filter_fqns
 
         # Configure MXFP8
-        from torchao.prototype.mx_formats.config import MXLinearConfig
+        from torchao.prototype.mx_formats.config import MXLinearConfig, MXFP8Dim1CastKernelChoice
 
         config = MXLinearConfig.from_recipe_name(NAME_MAP[mx_job_config.recipe_name])
-        config.use_fp8_dim1_cast_triton_kernel = (
-            mx_job_config.use_fp8_dim1_cast_triton_kernel
+
+        dim1_cast_kernel_choice_str = mx_job_config.mxfp8_dim1_cast_kernel_choice.upper()
+        config.mxfp8_dim1_cast_kernel_choice = (
+            MXFP8Dim1CastKernelChoice[dim1_cast_kernel_choice_str]
+            if mx_job_config.mxfp8_dim1_cast_kernel_choice != "none"
+            else None
         )
         self.config = config
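
For reference, the string-to-enum lookup in the hunk above follows the pattern in the standalone Python sketch below. KernelChoice is a stand-in enum for illustration only, not torchao's MXFP8Dim1CastKernelChoice; the assumptions are that the real enum's member names match the upper-cased config strings and that "none" means no dedicated dim1 cast kernel (the else None branch in the diff).

from enum import Enum, auto
from typing import Literal, Optional


class KernelChoice(Enum):
    # Stand-in for MXFP8Dim1CastKernelChoice; member names are assumed to
    # match the upper-cased config strings.
    TRITON = auto()
    CUDA = auto()


def resolve_dim1_cast_kernel(
    choice: Literal["triton", "cuda", "none"]
) -> Optional[KernelChoice]:
    # "none" disables the dedicated dim1 cast kernel, mirroring the
    # else None branch in the diff; other values are looked up by name.
    return None if choice == "none" else KernelChoice[choice.upper()]


assert resolve_dim1_cast_kernel("triton") is KernelChoice.TRITON
assert resolve_dim1_cast_kernel("none") is None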

torchtitan/config_manager.py

Lines changed: 1 addition & 1 deletion
@@ -523,7 +523,7 @@ class Float8:
 
 @dataclass
 class MX:
-    use_fp8_dim1_cast_triton_kernel: bool = True
+    mxfp8_dim1_cast_kernel_choice: Literal["triton", "cuda", "none"] = "triton"
     """Temp work around for inductor performance gap"""
 
     recipe_name: Literal["mxfp8"] = "mxfp8"
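
A minimal sketch of how the new Literal-typed field can be checked when the config is constructed. This MX class is a simplified stand-in, not torchtitan's actual dataclass, and the __post_init__ validation is illustrative only, not part of the commit.

from dataclasses import dataclass
from typing import Literal, get_args

Dim1CastChoice = Literal["triton", "cuda", "none"]


@dataclass
class MX:
    # Simplified stand-in for the MX config section touched by this commit.
    mxfp8_dim1_cast_kernel_choice: Dim1CastChoice = "triton"
    recipe_name: Literal["mxfp8"] = "mxfp8"

    def __post_init__(self) -> None:
        # Literal annotations are not enforced at runtime, so reject
        # unknown strings early with a clear error.
        if self.mxfp8_dim1_cast_kernel_choice not in get_args(Dim1CastChoice):
            raise ValueError(
                f"unknown mxfp8_dim1_cast_kernel_choice: "
                f"{self.mxfp8_dim1_cast_kernel_choice!r}"
            )


mx = MX(mxfp8_dim1_cast_kernel_choice="cuda")  # accepted
# MX(mxfp8_dim1_cast_kernel_choice="typo") would raise ValueError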
