Commit d828f91

integration of new mxfp8 casting cuda kernel (#2564)
1 parent 11f1a76 commit d828f91

File tree: 5 files changed (+138, -34 lines)

test/prototype/mx_formats/test_mx_dtensor.py

Lines changed: 12 additions & 1 deletion
@@ -25,6 +25,7 @@
 from tqdm import tqdm

 from torchao.prototype.mx_formats import MXLinearConfig
+from torchao.prototype.mx_formats.config import MXFP8Dim1CastKernelChoice
 from torchao.prototype.mx_formats.mx_tensor import MXTensor
 from torchao.testing.training.dtensor_utils import (
     _test_lowp_mlp_tensor_parallelism_base,
@@ -82,7 +83,7 @@ def _test_mxfp8_mlp_tensor_parallelism(mesh: DeviceMesh, size=128):
 def _test_mxfp8_mlp_tensor_parallelism_dim1_triton(mesh: DeviceMesh, size=128):
     config = MXLinearConfig.from_recipe_name("mxfp8_emulated")
     config.block_size = 32
-    config.use_fp8_dim1_cast_triton_kernel = True
+    config.mxfp8_cast_kernel_choice = MXFP8Dim1CastKernelChoice.TRITON
     _test_lowp_mlp_tensor_parallelism_base(
         mesh, config, size, compile=False, allgather_in_lowp=False
     )
@@ -93,12 +94,22 @@ def _test_mxfp8_mlp_tensor_parallelism_dim1_triton(mesh: DeviceMesh, size=128):
 # )


+def _test_mxfp8_mlp_tensor_parallelism_dim1_cuda(mesh: DeviceMesh, size=128):
+    config = MXLinearConfig.from_recipe_name("mxfp8_emulated")
+    config.block_size = 32
+    config.mxfp8_cast_kernel_choice = MXFP8Dim1CastKernelChoice.CUDA
+    _test_lowp_mlp_tensor_parallelism_base(
+        mesh, config, size, compile=False, allgather_in_lowp=False
+    )
+
+
 if __name__ == "__main__":
     device_mesh = setup_distributed()
     tests = [
         _test_dtensor_cast_to_mxfp8,
         _test_mxfp8_mlp_tensor_parallelism,
         _test_mxfp8_mlp_tensor_parallelism_dim1_triton,
+        _test_mxfp8_mlp_tensor_parallelism_dim1_cuda,
     ]

     for test in tqdm(tests, desc="Running tests"):

test/prototype/mx_formats/test_mx_linear.py

Lines changed: 24 additions & 11 deletions
@@ -12,6 +12,7 @@
 import torch.nn.functional as F

 from torchao.prototype.mx_formats.config import (
+    MXFP8Dim1CastKernelChoice,
     MXGemmKernelChoice,
     MXInferenceLinearConfig,
     MXLinearConfig,
@@ -81,16 +82,21 @@ def run_around_tests():
 @pytest.mark.parametrize("elem_dtype", elem_dtypes)
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize("input_shape", [(128, 256), (1, 128, 256), (1, 1, 128, 256)])
-@pytest.mark.parametrize("use_fp8_dim1_cast_triton_kernel", [False, True])
-def test_linear_eager_vs_hp(
-    elem_dtype, bias, input_shape, use_fp8_dim1_cast_triton_kernel
-):
+@pytest.mark.parametrize(
+    "mxfp8_cast_kernel_choice",
+    [
+        MXFP8Dim1CastKernelChoice.TORCH,
+        MXFP8Dim1CastKernelChoice.TRITON,
+        MXFP8Dim1CastKernelChoice.CUDA,
+    ],
+)
+def test_linear_eager_vs_hp(elem_dtype, bias, input_shape, mxfp8_cast_kernel_choice):
     """
     Smoke test for training linear module with mx weight, compares the following:
     * baseline: float32
     * experiment: emulated MX
     """
-    if use_fp8_dim1_cast_triton_kernel:
+    if mxfp8_cast_kernel_choice != MXFP8Dim1CastKernelChoice.TORCH:
         if elem_dtype != (
             torch.float8_e4m3fn,
             torch.float8_e4m3fn,
@@ -109,11 +115,11 @@ def test_linear_eager_vs_hp(
     )
     m_mx = copy.deepcopy(m)
     config = MXLinearConfig(
-        block_size=4,
+        block_size=32,  # Only 32 is supported for now
         elem_dtype=elem_dtype[0],
         elem_dtype_weight_override=elem_dtype[1],
         elem_dtype_grad_output_override=elem_dtype[2],
-        use_fp8_dim1_cast_triton_kernel=use_fp8_dim1_cast_triton_kernel,
+        mxfp8_cast_kernel_choice=mxfp8_cast_kernel_choice,
     )
     quantize_(m_mx, config)

@@ -227,8 +233,15 @@ def test_activation_checkpointing():
 @pytest.mark.parametrize("bias", [False, True])
 # TODO(future PR): figure out why torch.compile does not match eager when
 # autocast is on
-@pytest.mark.parametrize("use_fp8_dim1_cast_triton_kernel", [False, True])
-def test_linear_compile(hp_dtype, recipe_name, bias, use_fp8_dim1_cast_triton_kernel):
+@pytest.mark.parametrize(
+    "mxfp8_cast_kernel_choice",
+    [
+        MXFP8Dim1CastKernelChoice.TORCH,
+        MXFP8Dim1CastKernelChoice.TRITON,
+        MXFP8Dim1CastKernelChoice.CUDA,
+    ],
+)
+def test_linear_compile(hp_dtype, recipe_name, bias, mxfp8_cast_kernel_choice):
     """
     Verify that compile does not change numerics of MX linear fw + bw
     """
@@ -246,7 +259,7 @@ def test_linear_compile(hp_dtype, recipe_name, bias, use_fp8_dim1_cast_triton_kernel):
         # TODO(future PR): fix this, things are clearly broken with bias=True
         pytest.skip("this test is broken for non-emulated recipes with bias=True")

-    if use_fp8_dim1_cast_triton_kernel:
+    if mxfp8_cast_kernel_choice != MXFP8Dim1CastKernelChoice.TORCH:
         if recipe_name not in ("mxfp8_emulated", "mxfp8_cublas"):
             pytest.skip("unsupported configuration")
         if not is_sm_at_least_89():
@@ -267,7 +280,7 @@ def test_linear_compile(hp_dtype, recipe_name, bias, use_fp8_dim1_cast_triton_kernel):
         nn.Linear(K, N, bias=bias, device="cuda", dtype=hp_dtype),
     )
     config = MXLinearConfig.from_recipe_name(recipe_name)
-    config.use_fp8_dim1_cast_triton_kernel = use_fp8_dim1_cast_triton_kernel
+    config.mxfp8_cast_kernel_choice = mxfp8_cast_kernel_choice

     quantize_(m_mx, config=config)
     m_mx_c = copy.deepcopy(m_mx)
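
For orientation, the parametrization above drives the same eager numerics check for all three cast kernel choices. The snippet below is a minimal standalone sketch of that pattern, not the test itself; the bfloat16 baseline, model shape, and error metric are illustrative assumptions (TRITON and CUDA additionally require sufficiently new GPUs, per the skips in the tests).

# Standalone sketch (illustrative, not the test suite's code): run fw + bw with each
# mxfp8 cast kernel choice and report the deviation from a high-precision baseline.
import copy

import torch
import torch.nn as nn

from torchao.prototype.mx_formats.config import MXFP8Dim1CastKernelChoice, MXLinearConfig
from torchao.quantization import quantize_

x_hp = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16, requires_grad=True)
m_hp = nn.Sequential(nn.Linear(256, 256, bias=False, device="cuda", dtype=torch.bfloat16))
m_hp(x_hp).sum().backward()

for choice in (
    MXFP8Dim1CastKernelChoice.TORCH,
    MXFP8Dim1CastKernelChoice.TRITON,
    MXFP8Dim1CastKernelChoice.CUDA,
):
    config = MXLinearConfig.from_recipe_name("mxfp8_emulated")
    config.block_size = 32  # only 32 is supported by the custom cast kernels
    config.mxfp8_cast_kernel_choice = choice

    m_mx = copy.deepcopy(m_hp)
    quantize_(m_mx, config)

    x_mx = x_hp.detach().clone().requires_grad_()
    m_mx(x_mx).sum().backward()  # backward is where the dim1 cast kernel runs

    # mxfp8 is lossy, so look at relative error rather than exact equality
    err = (x_mx.grad - x_hp.grad).norm() / x_hp.grad.norm()
    print(f"{choice}: relative grad error vs bf16 baseline = {err.item():.4f}")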

torchao/prototype/mx_formats/config.py

Lines changed: 16 additions & 4 deletions
@@ -33,6 +33,17 @@ class MXGemmKernelChoice(Enum):
     CUBLAS = "cublas"


+class MXFP8Dim1CastKernelChoice(Enum):
+    """
+    Defines which kernel to use for mxfp8 casting. Currently custom casting kernels are
+    only for scaling along dim1, and torch native code is always used for scaling along dim0.
+    """
+
+    TRITON = "triton"
+    CUDA = "cuda"
+    TORCH = "torch"
+
+
 # Pre-made recipes for common configurations
 class MXLinearRecipeName(Enum):
     MXFP8_EMULATED = "mxfp8_emulated"
@@ -85,10 +96,12 @@ class MXLinearConfig(AOBaseConfig):
     # on the given hardware an exception will be thrown
     gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED

-    # If True, uses a custom triton kernel for cast to mxfp8 across dim1
+    # define which kernel to use for mxfp8 casting
     # TODO(1945): remove this config option once torch.compile gives us
     # a fast kernel
-    use_fp8_dim1_cast_triton_kernel: bool = False
+    mxfp8_cast_kernel_choice: MXFP8Dim1CastKernelChoice = (
+        MXFP8Dim1CastKernelChoice.TORCH
+    )

     # If True, uses a custom triton kernel for fp4 dequantize
     use_fp4_custom_triton_dequant_kernel: bool = False
@@ -146,8 +159,7 @@ def short_str(self) -> str:
         if self.elem_dtype_grad_output_override is not None:
             s += f", lp_go_override={DTYPE_TO_SHORT_STR[self.elem_dtype_grad_output_override]}"
         s += f", kernel={self.gemm_kernel_choice.value}"
-        if self.use_fp8_dim1_cast_triton_kernel:
-            s += ", use_fp8_dim1_cast_triton_kernel=True"
+        s += f", mxfp8_cast_kernel_choice={self.mxfp8_cast_kernel_choice.value}"
         if self.use_fp4_custom_triton_dequant_kernel:
             s += ", use_fp4_custom_triton_dequant_kernel=True"
         return s
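
With this change, selecting the dim1 cast kernel is a one-line config edit. A minimal sketch, mirroring how the tests in this commit configure it:

# Sketch: pick the cast kernel on an MXLinearConfig built from a recipe.
# TORCH (the default) keeps the torch-native dim1 cast; TRITON and CUDA opt into
# the custom kernels. The dim0 cast always stays on the torch-native path.
from torchao.prototype.mx_formats.config import MXFP8Dim1CastKernelChoice, MXLinearConfig

config = MXLinearConfig.from_recipe_name("mxfp8_emulated")
config.block_size = 32
config.mxfp8_cast_kernel_choice = MXFP8Dim1CastKernelChoice.CUDA

# short_str() now reports the selection, e.g. "..., mxfp8_cast_kernel_choice=cuda"
print(config.short_str())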

torchao/prototype/mx_formats/kernels.py

Lines changed: 44 additions & 3 deletions
@@ -1404,6 +1404,7 @@ def triton_scale_swizzle(
     scale_cols,
     output_ptr,
     input_row_stride,
+    input_col_stride,
     output_block_stride,
     BLOCK_ROWS: tl.constexpr,
     BLOCK_COLS: tl.constexpr,
@@ -1423,7 +1424,7 @@ def triton_scale_swizzle(
     mask = (global_rows < scale_rows) & (global_cols < scale_cols)

     input_scales = tl.load(
-        scale_ptr + global_rows * input_row_stride + global_cols,
+        scale_ptr + global_rows * input_row_stride + global_cols * input_col_stride,
         mask=mask,
         other=0.0,
     )
@@ -1463,7 +1464,6 @@ def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
     assert scale_tensor.element_size() == 1, (
         "Expected element size to be 1 byte (8 bits)"
     )
-    assert scale_tensor.is_contiguous(), "Input tensor must be contiguous"

     rows, cols = scale_tensor.shape

@@ -1476,7 +1476,8 @@ def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
     out = scale_tensor.new_empty((padded_rows, padded_cols))

     # Input stride (for row-major format)
-    input_row_stride = cols
+    input_row_stride = scale_tensor.stride()[0]
+    input_col_stride = scale_tensor.stride()[1]

     # We probably want handle multiple blocks per tile but for now keep it simple
     BLOCK_ROWS, BLOCK_COLS = 128, 4
@@ -1495,6 +1496,7 @@ def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
         cols,
         out.view(torch.uint8),
         input_row_stride,
+        input_col_stride,
         output_block_stride,
         BLOCK_ROWS=BLOCK_ROWS,
         BLOCK_COLS=BLOCK_COLS,
@@ -1740,6 +1742,9 @@ def triton_quantize_nvfp4(
 if is_sm_at_least_100():
     from torchao.prototype import mxfp8_cuda

+    # TODO: Make `scaling_mode` a choice (enum-like) rather than arbitrary string.
+    # Currently we have to use an arbitrary string because custom ops don't support enum
+    # params.
     @torch.library.custom_op("torchao::mxfp8_quantize_cuda", mutates_args=())
     def mxfp8_quantize_cuda(
         x: torch.Tensor,
@@ -1812,6 +1817,42 @@ def _(

         return output_rowwise, output_colwise, scales_rowwise, scales_colwise

+    @register_sharding(torch.ops.torchao.mxfp8_quantize_cuda.default)
+    def custom_mxfp8_quantize_cuda_dim1_sharding(
+        x: torch.Tensor,
+        rowwise: bool = False,
+        colwise: bool = True,
+        scaling_mode: str = "floor",
+    ):
+        # This function signature can be used to understand the shardings:
+        # _, colwise_data, _, colwise_scales = mxfp8_quantize_cuda(x, rowwise=False, colwise=True)

+        # When inputs and scale are replicated, we return a quantized output tensor (replicated).
+        inputs_replicated = [None, Replicate(), None, Replicate()]
+        outputs_replicated = [None, Replicate(), None, None]
+        rule_for_input_replicated = (
+            inputs_replicated,
+            outputs_replicated,
+        )

+        # When inputs and scale are sharded along dim 0,
+        # we return a quantized output tensor (sharded along dim1 due to transpose).
+        inputs_sharded_dim0 = [None, Shard(0), None, Shard(0)]
+        outputs_sharded_dim1 = [None, Shard(1), None, None]
+        rule_for_input_sharded_dim0 = (inputs_sharded_dim0, outputs_sharded_dim1)

+        # When inputs and scale are sharded along dim 1,
+        # we return a quantized output tensor (sharded along dim0 due to transpose).
+        inputs_sharded_dim1 = [None, Shard(1), None, Shard(1)]
+        outputs_sharded_dim0 = [None, Shard(0), None, None]
+        rule_for_input_sharded_dim1 = (inputs_sharded_dim1, outputs_sharded_dim0)

+        acceptable_shardings = [
+            rule_for_input_replicated,
+            rule_for_input_sharded_dim0,
+            rule_for_input_sharded_dim1,
+        ]
+        return acceptable_shardings
 else:

     def mxfp8_quantize_cuda(
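
For orientation, the custom op wrapped above is what the new dim1 cast path calls into: it returns a 4-tuple of (rowwise data, colwise data, rowwise scales, colwise scales), and the newly registered sharding rule lets the colwise call propagate Replicate/Shard placements when the input is a DTensor. A hedged sketch of a direct call, assuming an SM100+ GPU (shape and dtype are illustrative):

# Direct use of the custom op (requires SM100+, where the mxfp8_cuda extension loads).
import torch

from torchao.prototype.mx_formats.kernels import mxfp8_quantize_cuda

x = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)

# Column-wise (dim1) quantization only, matching how mx_linear.py calls this op;
# the rowwise outputs come back unused when rowwise=False.
_, data_colwise, _, scales_colwise = mxfp8_quantize_cuda(
    x,
    rowwise=False,
    colwise=True,
    scaling_mode="floor",  # plain string until custom ops support enum params (see TODO above)
)

The stride changes to triton_scale_swizzle / triton_mx_block_rearrange serve the same integration: the swizzle no longer requires a contiguous scale tensor, presumably so that non-contiguous (for example transposed) scale layouts produced by the dim1 path can be rearranged without an extra copy.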

torchao/prototype/mx_formats/mx_linear.py

Lines changed: 42 additions & 15 deletions
@@ -15,21 +15,41 @@
 from torch.distributed._tensor import DTensor

 from torchao.prototype.mx_formats.config import (
+    MXFP8Dim1CastKernelChoice,
     MXGemmKernelChoice,
     MXInferenceLinearConfig,
     MXLinearConfig,
 )
-from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim1
+from torchao.prototype.mx_formats.kernels import (
+    mxfp8_quantize_cuda,
+    triton_to_mxfp8_dim1,
+)
 from torchao.prototype.mx_formats.mx_tensor import MXTensor
 from torchao.quantization.transform_module import (
     register_quantize_module_handler,
 )


-def _triton_to_mxfp8_dim1_wrapper(
-    a, block_size, elem_dtype, hp_dtype, gemm_kernel_choice
+def _to_mxfp8_dim1_kernel_wrapper(
+    a,
+    block_size,
+    elem_dtype,
+    hp_dtype,
+    gemm_kernel_choice,
+    cast_kernel_choice,
 ):
-    a_data, a_scale = triton_to_mxfp8_dim1(a, block_size)
+    if cast_kernel_choice == MXFP8Dim1CastKernelChoice.TRITON:
+        a_data, a_scale = triton_to_mxfp8_dim1(a, block_size)
+    elif cast_kernel_choice == MXFP8Dim1CastKernelChoice.CUDA:
+        _, a_data, _, a_scale = mxfp8_quantize_cuda(
+            a,
+            rowwise=False,
+            colwise=True,
+            scaling_mode="floor",
+        )
+    else:
+        raise ValueError(f"must be one of [CUDA, TRITON], got {cast_kernel_choice}")
+
     if isinstance(a_data, DTensor):
         assert isinstance(a_scale, DTensor)
         a_data_local = a_data.to_local()
@@ -86,15 +106,15 @@ def forward(
         grad_elem_dtype: Any,
         block_size: int,
         gemm_kernel_choice: MXGemmKernelChoice,
-        use_fp8_dim1_cast_triton_kernel: bool,
+        mxfp8_cast_kernel_choice: MXFP8Dim1CastKernelChoice,
     ):
         ctx.save_for_backward(input_hp, weight_hp)
         ctx.in_elem_dtype = in_elem_dtype
         ctx.w_elem_dtype = w_elem_dtype
         ctx.grad_elem_dtype = grad_elem_dtype
         ctx.block_size = block_size
         ctx.gemm_kernel_choice = gemm_kernel_choice
-        ctx.use_fp8_dim1_cast_triton_kernel = use_fp8_dim1_cast_triton_kernel
+        ctx.mxfp8_cast_kernel_choice = mxfp8_cast_kernel_choice

         # input @ weight_t = output
         input_orig_shape = input_hp.shape
@@ -119,7 +139,7 @@ def backward(ctx, grad_output_hp: torch.Tensor):
         grad_elem_dtype = ctx.grad_elem_dtype
         block_size = ctx.block_size
         gemm_kernel_choice = ctx.gemm_kernel_choice
-        use_fp8_dim1_cast_triton_kernel = ctx.use_fp8_dim1_cast_triton_kernel
+        mxfp8_cast_kernel_choice = ctx.mxfp8_cast_kernel_choice

         grad_output_orig_shape = grad_output_hp.shape
         grad_output_hp_r = grad_output_hp.reshape(-1, grad_output_orig_shape[-1])
@@ -135,9 +155,14 @@ def backward(ctx, grad_output_hp: torch.Tensor):
             gemm_kernel_choice=gemm_kernel_choice,
         )

-        if use_fp8_dim1_cast_triton_kernel:
-            weight_mx_dim1 = _triton_to_mxfp8_dim1_wrapper(
-                weight_hp, block_size, w_elem_dtype, weight_hp.dtype, gemm_kernel_choice
+        if mxfp8_cast_kernel_choice != MXFP8Dim1CastKernelChoice.TORCH:
+            weight_mx_dim1 = _to_mxfp8_dim1_kernel_wrapper(
+                weight_hp,
+                block_size,
+                w_elem_dtype,
+                weight_hp.dtype,
+                gemm_kernel_choice,
+                mxfp8_cast_kernel_choice,
             )
         else:
             weight_hp_t_c = weight_hp.t().contiguous()
@@ -153,13 +178,14 @@ def backward(ctx, grad_output_hp: torch.Tensor):
         )

         # input_t @ grad_output = grad_weight
-        if use_fp8_dim1_cast_triton_kernel:
-            grad_output_mx_dim1 = _triton_to_mxfp8_dim1_wrapper(
+        if mxfp8_cast_kernel_choice != MXFP8Dim1CastKernelChoice.TORCH:
+            grad_output_mx_dim1 = _to_mxfp8_dim1_kernel_wrapper(
                 grad_output_hp_r,
                 block_size,
                 grad_elem_dtype,
                 grad_output_hp_r.dtype,
                 gemm_kernel_choice,
+                mxfp8_cast_kernel_choice,
             )
         else:
             grad_output_mx_dim1 = MXTensor.to_mx(
@@ -169,13 +195,14 @@ def backward(ctx, grad_output_hp: torch.Tensor):
                 gemm_kernel_choice=gemm_kernel_choice,
             )

-        if use_fp8_dim1_cast_triton_kernel:
-            input_t_mx_dim0_tmp = _triton_to_mxfp8_dim1_wrapper(
+        if mxfp8_cast_kernel_choice != MXFP8Dim1CastKernelChoice.TORCH:
+            input_t_mx_dim0_tmp = _to_mxfp8_dim1_kernel_wrapper(
                 input_hp_r,
                 block_size,
                 in_elem_dtype,
                 input_hp_r.dtype,
                 gemm_kernel_choice,
+                mxfp8_cast_kernel_choice,
             )
             input_t_mx_dim0 = input_t_mx_dim0_tmp.t()
         else:
@@ -232,7 +259,7 @@ def forward(self, x):
             config.elem_dtype_grad_output_override or config.elem_dtype,
             config.block_size,
             config.gemm_kernel_choice,
-            config.use_fp8_dim1_cast_triton_kernel,
+            config.mxfp8_cast_kernel_choice,
         )
         if self.bias is not None:
             y = y + self.bias
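
As a quick end-to-end check of the dispatch above, the sketch below (assuming an SM100+ GPU; the recipe, shapes, and dtype are illustrative) runs a compiled forward/backward with the CUDA cast kernel selected. In backward, the weight, grad_output, and input dim1 casts all route through _to_mxfp8_dim1_kernel_wrapper -> mxfp8_quantize_cuda.

# Sketch: fw + bw through a compiled MX linear with the CUDA dim1 cast kernel.
import torch
import torch.nn as nn

from torchao.prototype.mx_formats.config import MXFP8Dim1CastKernelChoice, MXLinearConfig
from torchao.quantization import quantize_

config = MXLinearConfig.from_recipe_name("mxfp8_emulated")
config.mxfp8_cast_kernel_choice = MXFP8Dim1CastKernelChoice.CUDA

m = nn.Sequential(nn.Linear(256, 512, bias=False, device="cuda", dtype=torch.bfloat16))
quantize_(m, config)
m = torch.compile(m)

x = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16, requires_grad=True)
m(x).sum().backward()  # exercises the CUDA dim1 cast in the backward pass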
