Skip to content

Commit 775d2a2

Browse files
committed
mx: make CUDA kernel for dim1 cast in mxfp8_cublas recipe
Summary: As titled, this is the fastest option right now so should be default. Test Plan: ```bash pytest test/prototype/mx_formats/ -s -x ``` Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: 52eb5cd ghstack-comment-id: 3144899603 Pull-Request: #2661
1 parent c993d64 commit 775d2a2

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

torchao/prototype/mx_formats/config.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,8 +184,10 @@ def from_recipe_name(
184184
if recipe_name is MXLinearRecipeName.MXFP8_EMULATED:
185185
return MXLinearConfig()
186186
elif recipe_name is MXLinearRecipeName.MXFP8_CUBLAS:
187-
# TODO(future PR): default to CUDA dim1 kernel
188-
return MXLinearConfig(gemm_kernel_choice=MXGemmKernelChoice.CUBLAS)
187+
return MXLinearConfig(
188+
gemm_kernel_choice=MXGemmKernelChoice.CUBLAS,
189+
mxfp8_cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
190+
)
189191
elif recipe_name is MXLinearRecipeName.MXFP8_CUBLAS_RCEIL:
190192
return MXLinearConfig(
191193
gemm_kernel_choice=MXGemmKernelChoice.CUBLAS,

0 commit comments

Comments
 (0)