Skip to content

Commit 9cabb91

Browse files
update api name
stack-info: PR: #1428, branch: danielvegamyhre/stack/2
1 parent 372b083 commit 9cabb91

File tree

2 files changed

+4
-8
lines changed

2 files changed

+4
-8
lines changed

torchtitan/components/quantization/mx.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,14 +63,10 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
6363

6464
config = MXLinearConfig.from_recipe_name(NAME_MAP[mx_job_config.recipe_name])
6565

66-
dim1_cast_kernel_choice_str = (
66+
# String to enum
67+
config.mxfp8_dim1_cast_kernel_choice = MXFP8Dim1CastKernelChoice[
6768
mx_job_config.mxfp8_dim1_cast_kernel_choice.upper()
68-
)
69-
config.mxfp8_dim1_cast_kernel_choice = (
70-
MXFP8Dim1CastKernelChoice[dim1_cast_kernel_choice_str]
71-
if mx_job_config.mxfp8_dim1_cast_kernel_choice != "NONE"
72-
else None
73-
)
69+
]
7470
self.config = config
7571

7672
logger.info(f"Float8 training active with recipe {mx_job_config.recipe_name}")

torchtitan/config_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -556,7 +556,7 @@ class Float8:
556556

557557
@dataclass
558558
class MX:
559-
mxfp8_dim1_cast_kernel_choice: Literal["triton", "cuda", "none"] = "triton"
559+
mxfp8_dim1_cast_kernel_choice: Literal["triton", "cuda", "torch"] = "triton"
560560
"""Temp work around for inductor performance gap"""
561561

562562
recipe_name: Literal["mxfp8"] = "mxfp8"

0 commit comments

Comments
 (0)