Integration of new mxfp8 casting CUDA kernel #2564


Merged
1 commit merged on Jul 18, 2025
13 changes: 12 additions & 1 deletion test/prototype/mx_formats/test_mx_dtensor.py
@@ -25,6 +25,7 @@
from tqdm import tqdm

from torchao.prototype.mx_formats import MXLinearConfig
from torchao.prototype.mx_formats.config import MXFP8Dim1CastKernelChoice
from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.testing.training.dtensor_utils import (
_test_lowp_mlp_tensor_parallelism_base,
@@ -82,7 +83,7 @@ def _test_mxfp8_mlp_tensor_parallelism(mesh: DeviceMesh, size=128):
def _test_mxfp8_mlp_tensor_parallelism_dim1_triton(mesh: DeviceMesh, size=128):
config = MXLinearConfig.from_recipe_name("mxfp8_emulated")
config.block_size = 32
config.use_fp8_dim1_cast_triton_kernel = True
config.mxfp8_cast_kernel_choice = MXFP8Dim1CastKernelChoice.TRITON
_test_lowp_mlp_tensor_parallelism_base(
mesh, config, size, compile=False, allgather_in_lowp=False
)
@@ -93,12 +94,22 @@ def _test_mxfp8_mlp_tensor_parallelism_dim1_triton(mesh: DeviceMesh, size=128):
# )


def _test_mxfp8_mlp_tensor_parallelism_dim1_cuda(mesh: DeviceMesh, size=128):
config = MXLinearConfig.from_recipe_name("mxfp8_emulated")
config.block_size = 32
config.mxfp8_cast_kernel_choice = MXFP8Dim1CastKernelChoice.CUDA
_test_lowp_mlp_tensor_parallelism_base(
mesh, config, size, compile=False, allgather_in_lowp=False
)


if __name__ == "__main__":
device_mesh = setup_distributed()
tests = [
_test_dtensor_cast_to_mxfp8,
_test_mxfp8_mlp_tensor_parallelism,
_test_mxfp8_mlp_tensor_parallelism_dim1_triton,
_test_mxfp8_mlp_tensor_parallelism_dim1_cuda,
]

for test in tqdm(tests, desc="Running tests"):
35 changes: 24 additions & 11 deletions test/prototype/mx_formats/test_mx_linear.py
@@ -12,6 +12,7 @@
import torch.nn.functional as F

from torchao.prototype.mx_formats.config import (
MXFP8Dim1CastKernelChoice,
MXGemmKernelChoice,
MXInferenceLinearConfig,
MXLinearConfig,
@@ -81,16 +82,21 @@ def run_around_tests():
@pytest.mark.parametrize("elem_dtype", elem_dtypes)
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("input_shape", [(128, 256), (1, 128, 256), (1, 1, 128, 256)])
@pytest.mark.parametrize("use_fp8_dim1_cast_triton_kernel", [False, True])
def test_linear_eager_vs_hp(
elem_dtype, bias, input_shape, use_fp8_dim1_cast_triton_kernel
):
@pytest.mark.parametrize(
"mxfp8_cast_kernel_choice",
[
MXFP8Dim1CastKernelChoice.TORCH,
MXFP8Dim1CastKernelChoice.TRITON,
MXFP8Dim1CastKernelChoice.CUDA,
],
)
def test_linear_eager_vs_hp(elem_dtype, bias, input_shape, mxfp8_cast_kernel_choice):
"""
Smoke test for training linear module with mx weight, compares the following:
* baseline: float32
* experiment: emulated MX
"""
if use_fp8_dim1_cast_triton_kernel:
if mxfp8_cast_kernel_choice != MXFP8Dim1CastKernelChoice.TORCH:
if elem_dtype != (
torch.float8_e4m3fn,
torch.float8_e4m3fn,
@@ -109,11 +115,11 @@ def test_linear_eager_vs_hp(
)
m_mx = copy.deepcopy(m)
config = MXLinearConfig(
block_size=4,
block_size=32, # Only 32 is supported for now
elem_dtype=elem_dtype[0],
elem_dtype_weight_override=elem_dtype[1],
elem_dtype_grad_output_override=elem_dtype[2],
use_fp8_dim1_cast_triton_kernel=use_fp8_dim1_cast_triton_kernel,
mxfp8_cast_kernel_choice=mxfp8_cast_kernel_choice,
)
quantize_(m_mx, config)

@@ -227,8 +233,15 @@ def test_activation_checkpointing():
@pytest.mark.parametrize("bias", [False, True])
# TODO(future PR): figure out why torch.compile does not match eager when
# autocast is on
@pytest.mark.parametrize("use_fp8_dim1_cast_triton_kernel", [False, True])
def test_linear_compile(hp_dtype, recipe_name, bias, use_fp8_dim1_cast_triton_kernel):
@pytest.mark.parametrize(
"mxfp8_cast_kernel_choice",
[
MXFP8Dim1CastKernelChoice.TORCH,
MXFP8Dim1CastKernelChoice.TRITON,
MXFP8Dim1CastKernelChoice.CUDA,
],
)
def test_linear_compile(hp_dtype, recipe_name, bias, mxfp8_cast_kernel_choice):
"""
Verify that compile does not change numerics of MX linear fw + bw
"""
@@ -246,7 +259,7 @@ def test_linear_compile(hp_dtype, recipe_name, bias, use_fp8_dim1_cast_triton_ke
# TODO(future PR): fix this, things are clearly broken with bias=True
pytest.skip("this test is broken for non-emulated recipes with bias=True")

if use_fp8_dim1_cast_triton_kernel:
if mxfp8_cast_kernel_choice != MXFP8Dim1CastKernelChoice.TORCH:
if recipe_name not in ("mxfp8_emulated", "mxfp8_cublas"):
pytest.skip("unsupported configuration")
if not is_sm_at_least_89():
@@ -267,7 +280,7 @@ def test_linear_compile(hp_dtype, recipe_name, bias, use_fp8_dim1_cast_triton_ke
nn.Linear(K, N, bias=bias, device="cuda", dtype=hp_dtype),
)
config = MXLinearConfig.from_recipe_name(recipe_name)
config.use_fp8_dim1_cast_triton_kernel = use_fp8_dim1_cast_triton_kernel
config.mxfp8_cast_kernel_choice = mxfp8_cast_kernel_choice

quantize_(m_mx, config=config)
m_mx_c = copy.deepcopy(m_mx)
20 changes: 16 additions & 4 deletions torchao/prototype/mx_formats/config.py
@@ -33,6 +33,17 @@ class MXGemmKernelChoice(Enum):
CUBLAS = "cublas"


class MXFP8Dim1CastKernelChoice(Enum):
"""
Defines which kernel to use for mxfp8 casting. Currently, the custom casting kernels
only cover scaling along dim1; torch-native code is always used for scaling along dim0.
"""

TRITON = "triton"
CUDA = "cuda"
TORCH = "torch"


# Pre-made recipes for common configurations
class MXLinearRecipeName(Enum):
MXFP8_EMULATED = "mxfp8_emulated"
@@ -85,10 +96,12 @@ class MXLinearConfig(AOBaseConfig):
# on the given hardware an exception will be thrown
gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED

# If True, uses a custom triton kernel for cast to mxfp8 across dim1
# define which kernel to use for mxfp8 casting
# TODO(1945): remove this config option once torch.compile gives us
# a fast kernel
use_fp8_dim1_cast_triton_kernel: bool = False
mxfp8_cast_kernel_choice: MXFP8Dim1CastKernelChoice = (
MXFP8Dim1CastKernelChoice.TORCH
)

# If True, uses a custom triton kernel for fp4 dequantize
use_fp4_custom_triton_dequant_kernel: bool = False
@@ -146,8 +159,7 @@ def short_str(self) -> str:
if self.elem_dtype_grad_output_override is not None:
s += f", lp_go_override={DTYPE_TO_SHORT_STR[self.elem_dtype_grad_output_override]}"
s += f", kernel={self.gemm_kernel_choice.value}"
if self.use_fp8_dim1_cast_triton_kernel:
s += ", use_fp8_dim1_cast_triton_kernel=True"
s += f", mxfp8_cast_kernel_choice={self.mxfp8_cast_kernel_choice.value}"
if self.use_fp4_custom_triton_dequant_kernel:
s += ", use_fp4_custom_triton_dequant_kernel=True"
return s
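
For context, a small sketch of how the new knob is set from user code and how it now appears in short_str(). It only uses names visible in this diff (MXLinearConfig, MXFP8Dim1CastKernelChoice, from_recipe_name); the short_str() output shown in the comment is indicative, not verbatim.

# Sketch only: selecting the new mxfp8_cast_kernel_choice option.
from torchao.prototype.mx_formats.config import (
    MXFP8Dim1CastKernelChoice,
    MXLinearConfig,
)

config = MXLinearConfig.from_recipe_name("mxfp8_emulated")
config.block_size = 32  # only 32 is supported for now

# Default is TORCH (torch-native dim1 cast); TRITON and CUDA select the custom kernels.
config.mxfp8_cast_kernel_choice = MXFP8Dim1CastKernelChoice.CUDA

# The choice is now always reflected in the short string, roughly:
#   "..., kernel=emulated, mxfp8_cast_kernel_choice=cuda"
print(config.short_str())
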
47 changes: 44 additions & 3 deletions torchao/prototype/mx_formats/kernels.py
@@ -1404,6 +1404,7 @@ def triton_scale_swizzle(
scale_cols,
output_ptr,
input_row_stride,
input_col_stride,
output_block_stride,
BLOCK_ROWS: tl.constexpr,
BLOCK_COLS: tl.constexpr,
@@ -1423,7 +1424,7 @@
mask = (global_rows < scale_rows) & (global_cols < scale_cols)

input_scales = tl.load(
scale_ptr + global_rows * input_row_stride + global_cols,
scale_ptr + global_rows * input_row_stride + global_cols * input_col_stride,
mask=mask,
other=0.0,
)
@@ -1463,7 +1464,6 @@ def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
assert scale_tensor.element_size() == 1, (
"Expected element size to be 1 byte (8 bits)"
)
assert scale_tensor.is_contiguous(), "Input tensor must be contiguous"

rows, cols = scale_tensor.shape

@@ -1476,7 +1476,8 @@ def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
out = scale_tensor.new_empty((padded_rows, padded_cols))

# Input stride (for row-major format)
input_row_stride = cols
input_row_stride = scale_tensor.stride()[0]
input_col_stride = scale_tensor.stride()[1]

# We probably want to handle multiple blocks per tile, but for now keep it simple
BLOCK_ROWS, BLOCK_COLS = 128, 4
@@ -1495,6 +1496,7 @@ def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
cols,
out.view(torch.uint8),
input_row_stride,
input_col_stride,
output_block_stride,
BLOCK_ROWS=BLOCK_ROWS,
BLOCK_COLS=BLOCK_COLS,
@@ -1740,6 +1742,9 @@ def triton_quantize_nvfp4(
if is_sm_at_least_100():
from torchao.prototype import mxfp8_cuda

# TODO: Make `scaling_mode` a choice (enum-like) rather than arbitrary string.
# Currently we have to use an arbitrary string because custom ops don't support enum
# params.
@torch.library.custom_op("torchao::mxfp8_quantize_cuda", mutates_args=())
def mxfp8_quantize_cuda(
x: torch.Tensor,
@@ -1812,6 +1817,42 @@ def _(

return output_rowwise, output_colwise, scales_rowwise, scales_colwise

@register_sharding(torch.ops.torchao.mxfp8_quantize_cuda.default)
def custom_mxfp8_quantize_cuda_dim1_sharding(
x: torch.Tensor,
rowwise: bool = False,
colwise: bool = True,
scaling_mode: str = "floor",
):
# This function signature can be used to understand the shardings:
# _, colwise_data, _, colwise_scales = mxfp8_quantize_cuda(x, rowwise=False, colwise=True)

# When inputs and scale are replicated, we return a quantized output tensor (replicated).
inputs_replicated = [None, Replicate(), None, Replicate()]
outputs_replicated = [None, Replicate(), None, None]
rule_for_input_replicated = (
inputs_replicated,
outputs_replicated,
)

# When inputs and scale are sharded along dim 0,
# we return a quantized output tensor (sharded along dim1 due to transpose).
inputs_sharded_dim0 = [None, Shard(0), None, Shard(0)]
outputs_sharded_dim1 = [None, Shard(1), None, None]
rule_for_input_sharded_dim0 = (inputs_sharded_dim0, outputs_sharded_dim1)

# When inputs and scale are sharded along dim 1,
# we return a quantized output tensor (sharded along dim0 due to transpose).
inputs_sharded_dim1 = [None, Shard(1), None, Shard(1)]
outputs_sharded_dim0 = [None, Shard(0), None, None]
rule_for_input_sharded_dim1 = (inputs_sharded_dim1, outputs_sharded_dim0)

acceptable_shardings = [
rule_for_input_replicated,
rule_for_input_sharded_dim0,
rule_for_input_sharded_dim1,
]
return acceptable_shardings
else:

def mxfp8_quantize_cuda(
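
To make the new custom op's calling convention concrete, here is a sketch of a direct call for a dim1 (column-wise) cast, mirroring how the wrapper in mx_linear.py below uses it. It assumes a GPU for which is_sm_at_least_100() is True (the op is only registered behind that guard); the output shapes and dtypes noted in comments are assumptions, since the diff does not spell them out.

# Sketch only: requires a GPU where is_sm_at_least_100() is True.
import torch

from torchao.prototype.mx_formats.kernels import mxfp8_quantize_cuda

x = torch.randn(256, 512, device="cuda", dtype=torch.bfloat16)

# Column-wise (dim1) quantization with floor scaling, as in _to_mxfp8_dim1_kernel_wrapper;
# row-wise outputs are not requested and are ignored here.
_, data_colwise, _, scales_colwise = mxfp8_quantize_cuda(
    x, rowwise=False, colwise=True, scaling_mode="floor"
)

# Assumed: data_colwise is fp8 data the size of x, and scales_colwise holds one shared
# e8m0 exponent per 32-element block along the scaled dimension.
print(data_colwise.shape, scales_colwise.shape)
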
57 changes: 42 additions & 15 deletions torchao/prototype/mx_formats/mx_linear.py
@@ -15,21 +15,41 @@
from torch.distributed._tensor import DTensor

from torchao.prototype.mx_formats.config import (
MXFP8Dim1CastKernelChoice,
MXGemmKernelChoice,
MXInferenceLinearConfig,
MXLinearConfig,
)
from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim1
from torchao.prototype.mx_formats.kernels import (
mxfp8_quantize_cuda,
triton_to_mxfp8_dim1,
)
from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.quantization.transform_module import (
register_quantize_module_handler,
)


def _triton_to_mxfp8_dim1_wrapper(
a, block_size, elem_dtype, hp_dtype, gemm_kernel_choice
def _to_mxfp8_dim1_kernel_wrapper(
a,
block_size,
elem_dtype,
hp_dtype,
gemm_kernel_choice,
cast_kernel_choice,
):
a_data, a_scale = triton_to_mxfp8_dim1(a, block_size)
if cast_kernel_choice == MXFP8Dim1CastKernelChoice.TRITON:
a_data, a_scale = triton_to_mxfp8_dim1(a, block_size)
elif cast_kernel_choice == MXFP8Dim1CastKernelChoice.CUDA:
_, a_data, _, a_scale = mxfp8_quantize_cuda(
a,
rowwise=False,
colwise=True,
scaling_mode="floor",
Contributor (review comment): TODO for later to allow a choice of scaling modes.
Contributor Author (reply): Added a TODO on the custom op itself, with an explanation of why we currently use a string param.
)
else:
raise ValueError(f"must be one of [CUDA, TRITON], got {cast_kernel_choice}")

if isinstance(a_data, DTensor):
assert isinstance(a_scale, DTensor)
a_data_local = a_data.to_local()
@@ -86,15 +106,15 @@ def forward(
grad_elem_dtype: Any,
block_size: int,
gemm_kernel_choice: MXGemmKernelChoice,
use_fp8_dim1_cast_triton_kernel: bool,
mxfp8_cast_kernel_choice: MXFP8Dim1CastKernelChoice,
):
ctx.save_for_backward(input_hp, weight_hp)
ctx.in_elem_dtype = in_elem_dtype
ctx.w_elem_dtype = w_elem_dtype
ctx.grad_elem_dtype = grad_elem_dtype
ctx.block_size = block_size
ctx.gemm_kernel_choice = gemm_kernel_choice
ctx.use_fp8_dim1_cast_triton_kernel = use_fp8_dim1_cast_triton_kernel
ctx.mxfp8_cast_kernel_choice = mxfp8_cast_kernel_choice

# input @ weight_t = output
input_orig_shape = input_hp.shape
@@ -119,7 +139,7 @@ def backward(ctx, grad_output_hp: torch.Tensor):
grad_elem_dtype = ctx.grad_elem_dtype
block_size = ctx.block_size
gemm_kernel_choice = ctx.gemm_kernel_choice
use_fp8_dim1_cast_triton_kernel = ctx.use_fp8_dim1_cast_triton_kernel
mxfp8_cast_kernel_choice = ctx.mxfp8_cast_kernel_choice

grad_output_orig_shape = grad_output_hp.shape
grad_output_hp_r = grad_output_hp.reshape(-1, grad_output_orig_shape[-1])
@@ -135,9 +155,14 @@ def backward(ctx, grad_output_hp: torch.Tensor):
gemm_kernel_choice=gemm_kernel_choice,
)

if use_fp8_dim1_cast_triton_kernel:
weight_mx_dim1 = _triton_to_mxfp8_dim1_wrapper(
weight_hp, block_size, w_elem_dtype, weight_hp.dtype, gemm_kernel_choice
if mxfp8_cast_kernel_choice != MXFP8Dim1CastKernelChoice.TORCH:
weight_mx_dim1 = _to_mxfp8_dim1_kernel_wrapper(
weight_hp,
block_size,
w_elem_dtype,
weight_hp.dtype,
gemm_kernel_choice,
mxfp8_cast_kernel_choice,
)
else:
weight_hp_t_c = weight_hp.t().contiguous()
@@ -153,13 +178,14 @@ def backward(ctx, grad_output_hp: torch.Tensor):
)

# input_t @ grad_output = grad_weight
if use_fp8_dim1_cast_triton_kernel:
grad_output_mx_dim1 = _triton_to_mxfp8_dim1_wrapper(
if mxfp8_cast_kernel_choice != MXFP8Dim1CastKernelChoice.TORCH:
grad_output_mx_dim1 = _to_mxfp8_dim1_kernel_wrapper(
grad_output_hp_r,
block_size,
grad_elem_dtype,
grad_output_hp_r.dtype,
gemm_kernel_choice,
mxfp8_cast_kernel_choice,
)
else:
grad_output_mx_dim1 = MXTensor.to_mx(
Expand All @@ -169,13 +195,14 @@ def backward(ctx, grad_output_hp: torch.Tensor):
gemm_kernel_choice=gemm_kernel_choice,
)

if use_fp8_dim1_cast_triton_kernel:
input_t_mx_dim0_tmp = _triton_to_mxfp8_dim1_wrapper(
if mxfp8_cast_kernel_choice != MXFP8Dim1CastKernelChoice.TORCH:
input_t_mx_dim0_tmp = _to_mxfp8_dim1_kernel_wrapper(
input_hp_r,
block_size,
in_elem_dtype,
input_hp_r.dtype,
gemm_kernel_choice,
mxfp8_cast_kernel_choice,
)
input_t_mx_dim0 = input_t_mx_dim0_tmp.t()
else:
@@ -232,7 +259,7 @@ def forward(self, x):
config.elem_dtype_grad_output_override or config.elem_dtype,
config.block_size,
config.gemm_kernel_choice,
config.use_fp8_dim1_cast_triton_kernel,
config.mxfp8_cast_kernel_choice,
)
if self.bias is not None:
y = y + self.bias
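
Putting the pieces together, a rough sketch of a training step that exercises the backward-path dispatch added above in mx_linear.py. It follows the shape of test_linear_compile; the import path for quantize_, the sizes, and the dummy loss are illustrative, and the CUDA choice needs a GPU supported by the new kernel.

# Rough end-to-end sketch; sizes, dtypes, and the dummy loss are illustrative.
import torch
import torch.nn as nn

from torchao.prototype.mx_formats.config import (
    MXFP8Dim1CastKernelChoice,
    MXLinearConfig,
)
from torchao.quantization import quantize_  # assumed import path for quantize_

M, K, N = 128, 256, 512
m = nn.Sequential(nn.Linear(K, N, bias=False, device="cuda", dtype=torch.bfloat16))

config = MXLinearConfig.from_recipe_name("mxfp8_emulated")
config.block_size = 32
config.mxfp8_cast_kernel_choice = MXFP8Dim1CastKernelChoice.CUDA

quantize_(m, config=config)  # swaps nn.Linear for the MX linear module

x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16, requires_grad=True)
y = m(x)            # forward: dim0 casts use torch-native code
y.sum().backward()  # backward: dim1 casts dispatch through _to_mxfp8_dim1_kernel_wrapper
                    # to the CUDA kernel, since the choice is not TORCH
print(x.grad.shape)
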