From b7c508c282dc54294115d73bae18df83ae2dfe78 Mon Sep 17 00:00:00 2001
From: Daniel Vega-Myhre
Date: Tue, 15 Jul 2025 13:17:17 -0700
Subject: [PATCH] integrate mxfp8 dim1 cast kernel choice enum into MXLinear

stack-info: PR: https://github.com/pytorch/ao/pull/2554, branch: danielvegamyhre/stack/10
---
 test/prototype/mx_formats/test_mx_dtensor.py |  3 +-
 test/prototype/mx_formats/test_mx_linear.py  | 27 +++++++++------
 torchao/prototype/mx_formats/config.py       | 13 ++++---
 torchao/prototype/mx_formats/mx_linear.py    | 36 ++++++++++++++++----
 4 files changed, 56 insertions(+), 23 deletions(-)

diff --git a/test/prototype/mx_formats/test_mx_dtensor.py b/test/prototype/mx_formats/test_mx_dtensor.py
index 4f5cce1a2a..495d1600b4 100644
--- a/test/prototype/mx_formats/test_mx_dtensor.py
+++ b/test/prototype/mx_formats/test_mx_dtensor.py
@@ -25,6 +25,7 @@
 from tqdm import tqdm
 
 from torchao.prototype.mx_formats import MXLinearConfig
+from torchao.prototype.mx_formats.config import MXFP8CastKernelChoice
 from torchao.prototype.mx_formats.mx_tensor import MXTensor
 from torchao.testing.training.dtensor_utils import (
     _test_lowp_mlp_tensor_parallelism_base,
@@ -82,7 +83,7 @@ def _test_mxfp8_mlp_tensor_parallelism(mesh: DeviceMesh, size=128):
 def _test_mxfp8_mlp_tensor_parallelism_dim1_triton(mesh: DeviceMesh, size=128):
     config = MXLinearConfig.from_recipe_name("mxfp8_emulated")
     config.block_size = 32
-    config.use_fp8_dim1_cast_triton_kernel = True
+    config.mxfp8_cast_kernel_choice = MXFP8CastKernelChoice.TRITON
     _test_lowp_mlp_tensor_parallelism_base(
         mesh, config, size, compile=False, allgather_in_lowp=False
     )
diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py
index fbf115b1bb..478159effd 100644
--- a/test/prototype/mx_formats/test_mx_linear.py
+++ b/test/prototype/mx_formats/test_mx_linear.py
@@ -12,6 +12,7 @@
 import torch.nn.functional as F
 
 from torchao.prototype.mx_formats.config import (
+    MXFP8CastKernelChoice,
     MXGemmKernelChoice,
     MXInferenceLinearConfig,
     MXLinearConfig,
@@ -81,16 +82,17 @@ def run_around_tests():
 @pytest.mark.parametrize("elem_dtype", elem_dtypes)
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize("input_shape", [(128, 256), (1, 128, 256), (1, 1, 128, 256)])
-@pytest.mark.parametrize("use_fp8_dim1_cast_triton_kernel", [False, True])
-def test_linear_eager_vs_hp(
-    elem_dtype, bias, input_shape, use_fp8_dim1_cast_triton_kernel
-):
+@pytest.mark.parametrize(
+    "mxfp8_cast_kernel_choice",
+    [None, MXFP8CastKernelChoice.TRITON, MXFP8CastKernelChoice.CUDA],
+)
+def test_linear_eager_vs_hp(elem_dtype, bias, input_shape, mxfp8_cast_kernel_choice):
     """
     Smoke test for training linear module with mx weight, compares the following:
     * baseline: float32
     * experiment: emulated MX
     """
-    if use_fp8_dim1_cast_triton_kernel:
+    if mxfp8_cast_kernel_choice is not None:
         if elem_dtype != (
             torch.float8_e4m3fn,
             torch.float8_e4m3fn,
@@ -109,11 +111,11 @@ def test_linear_eager_vs_hp(
     )
     m_mx = copy.deepcopy(m)
     config = MXLinearConfig(
-        block_size=4,
+        block_size=32,  # Only 32 is supported for now
         elem_dtype=elem_dtype[0],
         elem_dtype_weight_override=elem_dtype[1],
         elem_dtype_grad_output_override=elem_dtype[2],
-        use_fp8_dim1_cast_triton_kernel=use_fp8_dim1_cast_triton_kernel,
+        mxfp8_cast_kernel_choice=mxfp8_cast_kernel_choice,
     )
     quantize_(m_mx, config)
 
@@ -227,8 +229,11 @@ def test_activation_checkpointing():
 @pytest.mark.parametrize("bias", [False, True])
 # TODO(future PR): figure out why torch.compile does not match eager when
 # autocast is on
-@pytest.mark.parametrize("use_fp8_dim1_cast_triton_kernel", [False, True])
-def test_linear_compile(hp_dtype, recipe_name, bias, use_fp8_dim1_cast_triton_kernel):
+@pytest.mark.parametrize(
+    "mxfp8_cast_kernel_choice",
+    [None, MXFP8CastKernelChoice.TRITON, MXFP8CastKernelChoice.CUDA],
+)
+def test_linear_compile(hp_dtype, recipe_name, bias, mxfp8_cast_kernel_choice):
     """
     Verify that compile does not change numerics of MX linear fw + bw
     """
@@ -246,7 +251,7 @@ def test_linear_compile(hp_dtype, recipe_name, bias, use_fp8_dim1_cast_triton_ke
         # TODO(future PR): fix this, things are clearly broken with bias=True
         pytest.skip("this test is broken for non-emulated recipes with bias=True")
 
-    if use_fp8_dim1_cast_triton_kernel:
+    if mxfp8_cast_kernel_choice is not None:
         if recipe_name not in ("mxfp8_emulated", "mxfp8_cublas"):
             pytest.skip("unsupported configuration")
         if not is_sm_at_least_89():
@@ -267,7 +272,7 @@ def test_linear_compile(hp_dtype, recipe_name, bias, use_fp8_dim1_cast_triton_ke
         nn.Linear(K, N, bias=bias, device="cuda", dtype=hp_dtype),
     )
     config = MXLinearConfig.from_recipe_name(recipe_name)
-    config.use_fp8_dim1_cast_triton_kernel = use_fp8_dim1_cast_triton_kernel
+    config.mxfp8_cast_kernel_choice = mxfp8_cast_kernel_choice
     quantize_(m_mx, config=config)
     m_mx_c = copy.deepcopy(m_mx)
 
diff --git a/torchao/prototype/mx_formats/config.py b/torchao/prototype/mx_formats/config.py
index 525bf21fc6..9c588c01fe 100644
--- a/torchao/prototype/mx_formats/config.py
+++ b/torchao/prototype/mx_formats/config.py
@@ -33,6 +33,12 @@ class MXGemmKernelChoice(Enum):
     CUBLAS = "cublas"
 
 
+class MXFP8CastKernelChoice(Enum):
+    TRITON = "triton"
+    CUDA = "cuda"
+    TORCH = "torch"
+
+
 # Pre-made recipes for common configurations
 class MXLinearRecipeName(Enum):
     MXFP8_EMULATED = "mxfp8_emulated"
@@ -85,10 +91,10 @@ class MXLinearConfig(AOBaseConfig):
     # on the given hardware an exception will be thrown
     gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED
 
-    # If True, uses a custom triton kernel for cast to mxfp8 across dim1
+    # define which kernel to use for dim1 cast
     # TODO(1945): remove this config option once torch.compile gives us
     # a fast kernel
-    use_fp8_dim1_cast_triton_kernel: bool = False
+    mxfp8_cast_kernel_choice: MXFP8CastKernelChoice = MXFP8CastKernelChoice.TORCH
 
     # If True, uses a custom triton kernel for fp4 dequantize
     use_fp4_custom_triton_dequant_kernel: bool = False
@@ -146,8 +152,7 @@ def short_str(self) -> str:
         if self.elem_dtype_grad_output_override is not None:
             s += f", lp_go_override={DTYPE_TO_SHORT_STR[self.elem_dtype_grad_output_override]}"
         s += f", kernel={self.gemm_kernel_choice.value}"
-        if self.use_fp8_dim1_cast_triton_kernel:
-            s += ", use_fp8_dim1_cast_triton_kernel=True"
+        s += f", mxfp8_cast_kernel_choice={self.mxfp8_cast_kernel_choice.value}"
         if self.use_fp4_custom_triton_dequant_kernel:
             s += ", use_fp4_custom_triton_dequant_kernel=True"
         return s
diff --git a/torchao/prototype/mx_formats/mx_linear.py b/torchao/prototype/mx_formats/mx_linear.py
index a405238ad8..59df87f6d0 100644
--- a/torchao/prototype/mx_formats/mx_linear.py
+++ b/torchao/prototype/mx_formats/mx_linear.py
@@ -15,6 +15,7 @@
 from torch.distributed._tensor import DTensor
 
 from torchao.prototype.mx_formats.config import (
+    MXFP8CastKernelChoice,
     MXGemmKernelChoice,
     MXInferenceLinearConfig,
     MXLinearConfig,
@@ -134,7 +135,7 @@ def forward(
         grad_elem_dtype: Any,
         block_size: int,
         gemm_kernel_choice: MXGemmKernelChoice,
-        use_fp8_dim1_cast_triton_kernel: bool,
+        mxfp8_cast_kernel_choice: MXFP8CastKernelChoice,
     ):
         ctx.save_for_backward(input_hp, weight_hp)
         ctx.in_elem_dtype = in_elem_dtype
@@ -142,7 +143,7 @@ def forward(
         ctx.grad_elem_dtype = grad_elem_dtype
         ctx.block_size = block_size
         ctx.gemm_kernel_choice = gemm_kernel_choice
-        ctx.use_fp8_dim1_cast_triton_kernel = use_fp8_dim1_cast_triton_kernel
+        ctx.mxfp8_cast_kernel_choice = mxfp8_cast_kernel_choice
 
         # input @ weight_t = output
         input_orig_shape = input_hp.shape
@@ -167,7 +168,7 @@ def backward(ctx, grad_output_hp: torch.Tensor):
         grad_elem_dtype = ctx.grad_elem_dtype
         block_size = ctx.block_size
         gemm_kernel_choice = ctx.gemm_kernel_choice
-        use_fp8_dim1_cast_triton_kernel = ctx.use_fp8_dim1_cast_triton_kernel
+        mxfp8_cast_kernel_choice = ctx.mxfp8_cast_kernel_choice
 
         grad_output_orig_shape = grad_output_hp.shape
         grad_output_hp_r = grad_output_hp.reshape(-1, grad_output_orig_shape[-1])
@@ -183,10 +184,14 @@ def backward(ctx, grad_output_hp: torch.Tensor):
             gemm_kernel_choice=gemm_kernel_choice,
         )
 
-        if use_fp8_dim1_cast_triton_kernel:
+        if mxfp8_cast_kernel_choice == MXFP8CastKernelChoice.TRITON:
             weight_mx_dim1 = _triton_to_mxfp8_dim1_wrapper(
                 weight_hp, block_size, w_elem_dtype, weight_hp.dtype, gemm_kernel_choice
             )
+        elif mxfp8_cast_kernel_choice == MXFP8CastKernelChoice.CUDA:
+            weight_mx_dim1 = _cuda_to_mxfp8_dim1_wrapper(
+                weight_hp, block_size, w_elem_dtype, weight_hp.dtype, gemm_kernel_choice
+            )
         else:
             weight_hp_t_c = weight_hp.t().contiguous()
             weight_mx_dim1 = MXTensor.to_mx(
@@ -201,7 +206,7 @@ def backward(ctx, grad_output_hp: torch.Tensor):
         )
 
         # input_t @ grad_output = grad_weight
-        if use_fp8_dim1_cast_triton_kernel:
+        if mxfp8_cast_kernel_choice == MXFP8CastKernelChoice.TRITON:
             grad_output_mx_dim1 = _triton_to_mxfp8_dim1_wrapper(
                 grad_output_hp_r,
                 block_size,
@@ -209,6 +214,14 @@ def backward(ctx, grad_output_hp: torch.Tensor):
                 grad_output_hp_r.dtype,
                 gemm_kernel_choice,
             )
+        elif mxfp8_cast_kernel_choice == MXFP8CastKernelChoice.CUDA:
+            grad_output_mx_dim1 = _cuda_to_mxfp8_dim1_wrapper(
+                grad_output_hp_r,
+                block_size,
+                grad_elem_dtype,
+                grad_output_hp_r.dtype,
+                gemm_kernel_choice,
+            )
         else:
             grad_output_mx_dim1 = MXTensor.to_mx(
                 grad_output_hp_r.t().contiguous(),
@@ -217,7 +230,7 @@ def backward(ctx, grad_output_hp: torch.Tensor):
                 gemm_kernel_choice=gemm_kernel_choice,
             )
 
-        if use_fp8_dim1_cast_triton_kernel:
+        if mxfp8_cast_kernel_choice == MXFP8CastKernelChoice.TRITON:
             input_t_mx_dim0_tmp = _triton_to_mxfp8_dim1_wrapper(
                 input_hp_r,
                 block_size,
@@ -226,6 +239,15 @@ def backward(ctx, grad_output_hp: torch.Tensor):
                 gemm_kernel_choice,
             )
             input_t_mx_dim0 = input_t_mx_dim0_tmp.t()
+        elif mxfp8_cast_kernel_choice == MXFP8CastKernelChoice.CUDA:
+            input_t_mx_dim0_tmp = _cuda_to_mxfp8_dim1_wrapper(
+                input_hp_r,
+                block_size,
+                in_elem_dtype,
+                input_hp_r.dtype,
+                gemm_kernel_choice,
+            )
+            input_t_mx_dim0 = input_t_mx_dim0_tmp.t()
         else:
             input_t_mx_dim0_tmp = MXTensor.to_mx(
                 input_hp_r.t().contiguous(),
@@ -280,7 +302,7 @@ def forward(self, x):
             config.elem_dtype_grad_output_override or config.elem_dtype,
             config.block_size,
             config.gemm_kernel_choice,
-            config.use_fp8_dim1_cast_triton_kernel,
+            config.mxfp8_cast_kernel_choice,
         )
         if self.bias is not None:
             y = y + self.bias
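
Usage sketch (not part of the diff above): the snippet below shows how the new MXFP8CastKernelChoice enum is meant to replace the old use_fp8_dim1_cast_triton_kernel flag when quantizing a model with MXLinearConfig, mirroring the updated tests in this patch. The quantize_ import path, layer sizes, and dtypes are illustrative assumptions; the TRITON and CUDA choices assume an mxfp8-capable GPU (the tests skip them below SM 8.9), while TORCH is the plain torch.compile path.

    import copy

    import torch
    import torch.nn as nn

    from torchao.prototype.mx_formats.config import MXFP8CastKernelChoice, MXLinearConfig
    from torchao.quantization import quantize_  # assumed import path for quantize_

    # Toy module; sizes and dtype are illustrative, not taken from the patch.
    m = nn.Sequential(nn.Linear(256, 256, bias=False, device="cuda", dtype=torch.bfloat16))
    m_mx = copy.deepcopy(m)

    # Build a config from a pre-made recipe, then pick which kernel performs the dim1 cast.
    config = MXLinearConfig.from_recipe_name("mxfp8_emulated")
    config.block_size = 32
    config.mxfp8_cast_kernel_choice = MXFP8CastKernelChoice.TRITON  # or .CUDA / .TORCH

    # Swap nn.Linear modules for MXLinear in place.
    quantize_(m_mx, config=config)

    # Forward + backward exercise both the dim0 and dim1 casts in mx_mm.
    x = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    y = m_mx(x)
    y.sum().backward()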