
Commit 3f88897

integration of new mxfp8 casting cuda kernel
stack-info: PR: #2564, branch: danielvegamyhre/stack/13
1 parent 95d13d5 commit 3f88897
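Summary: the boolean MXLinearConfig.use_fp8_dim1_cast_triton_kernel is replaced by an MXFP8CastKernelChoice enum (TRITON, CUDA, TORCH), and a CUDA mxfp8 dim1 cast kernel, plus a DTensor sharding rule for it, is wired into the MX linear backward pass. A minimal usage sketch of the new option, mirroring the updated tests below (assumes a CUDA device and a torchao build that ships the new kernel; quantize_ is imported from torchao.quantization; module shape and recipe name are illustrative):

# Sketch only: shapes and recipe name are taken from the tests below, not from this diff's API surface.
import torch
import torch.nn as nn

from torchao.prototype.mx_formats.config import MXFP8CastKernelChoice, MXLinearConfig
from torchao.quantization import quantize_

m = nn.Sequential(nn.Linear(256, 512, bias=False, device="cuda", dtype=torch.bfloat16))
config = MXLinearConfig.from_recipe_name("mxfp8_emulated")
config.block_size = 32
# New in this commit: select the dim1 (column-wise) cast kernel explicitly.
config.mxfp8_cast_kernel_choice = MXFP8CastKernelChoice.CUDA
quantize_(m, config=config)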

File tree: 5 files changed, +129 / -27 lines

test/prototype/mx_formats/test_mx_dtensor.py

Lines changed: 2 additions & 1 deletion
@@ -25,6 +25,7 @@
 from tqdm import tqdm
 
 from torchao.prototype.mx_formats import MXLinearConfig
+from torchao.prototype.mx_formats.config import MXFP8CastKernelChoice
 from torchao.prototype.mx_formats.mx_tensor import MXTensor
 from torchao.testing.training.dtensor_utils import (
     _test_lowp_mlp_tensor_parallelism_base,
@@ -82,7 +83,7 @@ def _test_mxfp8_mlp_tensor_parallelism(mesh: DeviceMesh, size=128):
 def _test_mxfp8_mlp_tensor_parallelism_dim1_triton(mesh: DeviceMesh, size=128):
     config = MXLinearConfig.from_recipe_name("mxfp8_emulated")
     config.block_size = 32
-    config.use_fp8_dim1_cast_triton_kernel = True
+    config.mxfp8_cast_kernel_choice = MXFP8CastKernelChoice.CUDA
     _test_lowp_mlp_tensor_parallelism_base(
         mesh, config, size, compile=False, allgather_in_lowp=False
     )

test/prototype/mx_formats/test_mx_linear.py

Lines changed: 16 additions & 11 deletions
@@ -12,6 +12,7 @@
 import torch.nn.functional as F
 
 from torchao.prototype.mx_formats.config import (
+    MXFP8CastKernelChoice,
     MXGemmKernelChoice,
     MXInferenceLinearConfig,
     MXLinearConfig,
@@ -81,16 +82,17 @@ def run_around_tests():
 @pytest.mark.parametrize("elem_dtype", elem_dtypes)
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize("input_shape", [(128, 256), (1, 128, 256), (1, 1, 128, 256)])
-@pytest.mark.parametrize("use_fp8_dim1_cast_triton_kernel", [False, True])
-def test_linear_eager_vs_hp(
-    elem_dtype, bias, input_shape, use_fp8_dim1_cast_triton_kernel
-):
+@pytest.mark.parametrize(
+    "mxfp8_cast_kernel_choice",
+    [None, MXFP8CastKernelChoice.TRITON, MXFP8CastKernelChoice.CUDA],
+)
+def test_linear_eager_vs_hp(elem_dtype, bias, input_shape, mxfp8_cast_kernel_choice):
     """
     Smoke test for training linear module with mx weight, compares the following:
     * baseline: float32
     * experiment: emulated MX
     """
-    if use_fp8_dim1_cast_triton_kernel:
+    if mxfp8_cast_kernel_choice is not None:
         if elem_dtype != (
             torch.float8_e4m3fn,
             torch.float8_e4m3fn,
@@ -109,11 +111,11 @@ def test_linear_eager_vs_hp(
     )
     m_mx = copy.deepcopy(m)
     config = MXLinearConfig(
-        block_size=4,
+        block_size=32,  # Only 32 is supported for now
         elem_dtype=elem_dtype[0],
         elem_dtype_weight_override=elem_dtype[1],
         elem_dtype_grad_output_override=elem_dtype[2],
-        use_fp8_dim1_cast_triton_kernel=use_fp8_dim1_cast_triton_kernel,
+        mxfp8_cast_kernel_choice=mxfp8_cast_kernel_choice,
     )
     quantize_(m_mx, config)
 
@@ -227,8 +229,11 @@ def test_activation_checkpointing():
 @pytest.mark.parametrize("bias", [False, True])
 # TODO(future PR): figure out why torch.compile does not match eager when
 # autocast is on
-@pytest.mark.parametrize("use_fp8_dim1_cast_triton_kernel", [False, True])
-def test_linear_compile(hp_dtype, recipe_name, bias, use_fp8_dim1_cast_triton_kernel):
+@pytest.mark.parametrize(
+    "mxfp8_cast_kernel_choice",
+    [None, MXFP8CastKernelChoice.TRITON, MXFP8CastKernelChoice.CUDA],
+)
+def test_linear_compile(hp_dtype, recipe_name, bias, mxfp8_cast_kernel_choice):
     """
     Verify that compile does not change numerics of MX linear fw + bw
     """
@@ -246,7 +251,7 @@ def test_linear_compile(hp_dtype, recipe_name, bias, use_fp8_dim1_cast_triton_ke
         # TODO(future PR): fix this, things are clearly broken with bias=True
         pytest.skip("this test is broken for non-emulated recipes with bias=True")
 
-    if use_fp8_dim1_cast_triton_kernel:
+    if mxfp8_cast_kernel_choice is not None:
         if recipe_name not in ("mxfp8_emulated", "mxfp8_cublas"):
             pytest.skip("unsupported configuration")
         if not is_sm_at_least_89():
@@ -267,7 +272,7 @@ def test_linear_compile(hp_dtype, recipe_name, bias, use_fp8_dim1_cast_triton_ke
         nn.Linear(K, N, bias=bias, device="cuda", dtype=hp_dtype),
     )
     config = MXLinearConfig.from_recipe_name(recipe_name)
-    config.use_fp8_dim1_cast_triton_kernel = use_fp8_dim1_cast_triton_kernel
+    config.mxfp8_cast_kernel_choice = mxfp8_cast_kernel_choice
 
     quantize_(m_mx, config=config)
     m_mx_c = copy.deepcopy(m_mx)

torchao/prototype/mx_formats/config.py

Lines changed: 9 additions & 4 deletions
@@ -33,6 +33,12 @@ class MXGemmKernelChoice(Enum):
     CUBLAS = "cublas"
 
 
+class MXFP8CastKernelChoice(Enum):
+    TRITON = "triton"
+    CUDA = "cuda"
+    TORCH = "torch"
+
+
 # Pre-made recipes for common configurations
 class MXLinearRecipeName(Enum):
     MXFP8_EMULATED = "mxfp8_emulated"
@@ -85,10 +91,10 @@ class MXLinearConfig(AOBaseConfig):
     # on the given hardware an exception will be thrown
     gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED
 
-    # If True, uses a custom triton kernel for cast to mxfp8 across dim1
+    # define which kernel to use for dim1 cast
     # TODO(1945): remove this config option once torch.compile gives us
     # a fast kernel
-    use_fp8_dim1_cast_triton_kernel: bool = False
+    mxfp8_cast_kernel_choice: MXFP8CastKernelChoice = MXFP8CastKernelChoice.TRITON
 
     # If True, uses a custom triton kernel for fp4 dequantize
     use_fp4_custom_triton_dequant_kernel: bool = False
@@ -146,8 +152,7 @@ def short_str(self) -> str:
         if self.elem_dtype_grad_output_override is not None:
             s += f", lp_go_override={DTYPE_TO_SHORT_STR[self.elem_dtype_grad_output_override]}"
         s += f", kernel={self.gemm_kernel_choice.value}"
-        if self.use_fp8_dim1_cast_triton_kernel:
-            s += ", use_fp8_dim1_cast_triton_kernel=True"
+        s += f", mxfp8_cast_kernel_choice={self.mxfp8_cast_kernel_choice.value}"
         if self.use_fp4_custom_triton_dequant_kernel:
             s += ", use_fp4_custom_triton_dequant_kernel=True"
         return s
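Because short_str() now always reports the selected cast kernel instead of printing the old boolean flag only when it was set, every config summary changes. A quick sketch, assuming the defaults defined above (exact surrounding fields come from unchanged code):

from torchao.prototype.mx_formats.config import MXLinearConfig

config = MXLinearConfig(block_size=32)
# Expected to contain something like ", kernel=emulated, mxfp8_cast_kernel_choice=triton"
# given the new TRITON default.
print(config.short_str())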

torchao/prototype/mx_formats/kernels.py

Lines changed: 24 additions & 3 deletions
@@ -1404,6 +1404,7 @@ def triton_scale_swizzle(
     scale_cols,
     output_ptr,
     input_row_stride,
+    input_col_stride,
     output_block_stride,
     BLOCK_ROWS: tl.constexpr,
     BLOCK_COLS: tl.constexpr,
@@ -1423,7 +1424,7 @@
     mask = (global_rows < scale_rows) & (global_cols < scale_cols)
 
     input_scales = tl.load(
-        scale_ptr + global_rows * input_row_stride + global_cols,
+        scale_ptr + global_rows * input_row_stride + global_cols * input_col_stride,
         mask=mask,
         other=0.0,
     )
@@ -1463,7 +1464,6 @@ def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
     assert scale_tensor.element_size() == 1, (
         "Expected element size to be 1 byte (8 bits)"
     )
-    assert scale_tensor.is_contiguous(), "Input tensor must be contiguous"
 
     rows, cols = scale_tensor.shape
 
@@ -1476,7 +1476,8 @@ def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
     out = scale_tensor.new_empty((padded_rows, padded_cols))
 
     # Input stride (for row-major format)
-    input_row_stride = cols
+    input_row_stride = scale_tensor.stride()[0]
+    input_col_stride = scale_tensor.stride()[1]
 
     # We probably want handle multiple blocks per tile but for now keep it simple
     BLOCK_ROWS, BLOCK_COLS = 128, 4
@@ -1495,6 +1496,7 @@ def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
         cols,
         out.view(torch.uint8),
         input_row_stride,
+        input_col_stride,
         output_block_stride,
         BLOCK_ROWS=BLOCK_ROWS,
         BLOCK_COLS=BLOCK_COLS,
@@ -1812,6 +1814,25 @@ def _(
 
         return output_rowwise, output_colwise, scales_rowwise, scales_colwise
 
+    @register_sharding(torch.ops.torchao.mxfp8_quantize_cuda.default)
+    def custom_mxfp8_quantize_cuda_dim1_sharding(
+        x: torch.Tensor,
+        rowwise: bool = False,
+        colwise: bool = True,
+        scaling_mode: str = "floor",
+    ):
+        # This function signature can be used to understand the shardings:
+        # _, colwise_data, _, colwise_scales = mxfp8_quantize_cuda(x, rowwise=False, colwise=True)
+        replicate = (
+            [None, Replicate(), None, Replicate()],
+            [None, Replicate(), None, None],
+        )
+        # Note that the data is returned transposed, which is why
+        # we flip the sharding dim below
+        shard_dim0 = ([None, Shard(1), None, Shard(1)], [None, Shard(0), None, None])
+        shard_dim1 = ([None, Shard(0), None, Shard(0)], [None, Shard(1), None, None])
+        acceptable_shardings = [replicate, shard_dim0, shard_dim1]
+        return acceptable_shardings
 else:
 
     def mxfp8_quantize_cuda(
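The stride handling added earlier in this file means triton_mx_block_rearrange no longer requires (or asserts) a contiguous input. A minimal sketch, assuming a CUDA device with Triton available and using a uint8 tensor to stand in for e8m0 scales:

import torch

from torchao.prototype.mx_formats.kernels import triton_mx_block_rearrange

scales = torch.randint(0, 255, (512, 256), dtype=torch.uint8, device="cuda")
# A transposed view is non-contiguous; the kernel now reads the real row/col strides,
# so no explicit .contiguous() call is needed before the swizzle.
swizzled = triton_mx_block_rearrange(scales.t())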

torchao/prototype/mx_formats/mx_linear.py

Lines changed: 78 additions & 8 deletions
@@ -15,11 +15,15 @@
 from torch.distributed._tensor import DTensor
 
 from torchao.prototype.mx_formats.config import (
+    MXFP8CastKernelChoice,
     MXGemmKernelChoice,
     MXInferenceLinearConfig,
     MXLinearConfig,
 )
-from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim1
+from torchao.prototype.mx_formats.kernels import (
+    mxfp8_quantize_cuda,
+    triton_to_mxfp8_dim1,
+)
 from torchao.prototype.mx_formats.mx_tensor import MXTensor
 from torchao.quantization.transform_module import (
     register_quantize_module_handler,
@@ -66,6 +70,51 @@ def _triton_to_mxfp8_dim1_wrapper(
     return mx_tensor
 
 
+def _cuda_to_mxfp8_dim1_wrapper(
+    a, block_size, elem_dtype, hp_dtype, gemm_kernel_choice
+):
+    _, a_data, _, a_scale = mxfp8_quantize_cuda(
+        a,
+        rowwise=False,
+        colwise=True,
+        scaling_mode="floor",
+    )
+    if isinstance(a_data, DTensor):
+        assert isinstance(a_scale, DTensor)
+        a_data_local = a_data.to_local()
+        a_scale_local = a_scale.to_local()
+        inner = MXTensor(
+            a_scale_local,
+            a_data_local.t(),
+            elem_dtype,
+            block_size,
+            hp_dtype,
+            False,
+            gemm_kernel_choice,
+            False,
+        )
+        mx_tensor = DTensor.from_local(
+            inner,
+            a_data.device_mesh,
+            a_data.placements,
+            run_check=False,
+            shape=a_data.t().size(),
+            stride=a_data.t().stride(),
+        )
+    else:
+        mx_tensor = MXTensor(
+            a_scale,
+            a_data.t(),
+            elem_dtype,
+            block_size,
+            hp_dtype,
+            False,
+            gemm_kernel_choice,
+            False,
+        )
+    return mx_tensor
+
+
 @torch._dynamo.allow_in_graph
 class mx_mm(torch.autograd.Function):
     # There are three gemms in a forward + backward of a Linear layer:
@@ -86,15 +135,15 @@ def forward(
         grad_elem_dtype: Any,
         block_size: int,
         gemm_kernel_choice: MXGemmKernelChoice,
-        use_fp8_dim1_cast_triton_kernel: bool,
+        mxfp8_cast_kernel_choice: MXFP8CastKernelChoice,
     ):
         ctx.save_for_backward(input_hp, weight_hp)
         ctx.in_elem_dtype = in_elem_dtype
         ctx.w_elem_dtype = w_elem_dtype
         ctx.grad_elem_dtype = grad_elem_dtype
         ctx.block_size = block_size
         ctx.gemm_kernel_choice = gemm_kernel_choice
-        ctx.use_fp8_dim1_cast_triton_kernel = use_fp8_dim1_cast_triton_kernel
+        ctx.mxfp8_cast_kernel_choice = mxfp8_cast_kernel_choice
 
         # input @ weight_t = output
         input_orig_shape = input_hp.shape
@@ -119,7 +168,7 @@ def backward(ctx, grad_output_hp: torch.Tensor):
         grad_elem_dtype = ctx.grad_elem_dtype
         block_size = ctx.block_size
         gemm_kernel_choice = ctx.gemm_kernel_choice
-        use_fp8_dim1_cast_triton_kernel = ctx.use_fp8_dim1_cast_triton_kernel
+        mxfp8_cast_kernel_choice = ctx.mxfp8_cast_kernel_choice
 
         grad_output_orig_shape = grad_output_hp.shape
         grad_output_hp_r = grad_output_hp.reshape(-1, grad_output_orig_shape[-1])
@@ -135,10 +184,14 @@ def backward(ctx, grad_output_hp: torch.Tensor):
             gemm_kernel_choice=gemm_kernel_choice,
         )
 
-        if use_fp8_dim1_cast_triton_kernel:
+        if mxfp8_cast_kernel_choice == MXFP8CastKernelChoice.TRITON:
             weight_mx_dim1 = _triton_to_mxfp8_dim1_wrapper(
                 weight_hp, block_size, w_elem_dtype, weight_hp.dtype, gemm_kernel_choice
             )
+        elif mxfp8_cast_kernel_choice == MXFP8CastKernelChoice.CUDA:
+            weight_mx_dim1 = _cuda_to_mxfp8_dim1_wrapper(
+                weight_hp, block_size, w_elem_dtype, weight_hp.dtype, gemm_kernel_choice
+            )
         else:
             weight_hp_t_c = weight_hp.t().contiguous()
             weight_mx_dim1 = MXTensor.to_mx(
@@ -153,14 +206,22 @@
         )
 
         # input_t @ grad_output = grad_weight
-        if use_fp8_dim1_cast_triton_kernel:
+        if mxfp8_cast_kernel_choice == MXFP8CastKernelChoice.TRITON:
             grad_output_mx_dim1 = _triton_to_mxfp8_dim1_wrapper(
                 grad_output_hp_r,
                 block_size,
                 grad_elem_dtype,
                 grad_output_hp_r.dtype,
                 gemm_kernel_choice,
             )
+        elif mxfp8_cast_kernel_choice == MXFP8CastKernelChoice.CUDA:
+            grad_output_mx_dim1 = _cuda_to_mxfp8_dim1_wrapper(
+                grad_output_hp_r,
+                block_size,
+                grad_elem_dtype,
+                grad_output_hp_r.dtype,
+                gemm_kernel_choice,
+            )
         else:
             grad_output_mx_dim1 = MXTensor.to_mx(
                 grad_output_hp_r.t().contiguous(),
@@ -169,7 +230,7 @@ def backward(ctx, grad_output_hp: torch.Tensor):
                 gemm_kernel_choice=gemm_kernel_choice,
             )
 
-        if use_fp8_dim1_cast_triton_kernel:
+        if mxfp8_cast_kernel_choice == MXFP8CastKernelChoice.TRITON:
             input_t_mx_dim0_tmp = _triton_to_mxfp8_dim1_wrapper(
                 input_hp_r,
                 block_size,
@@ -178,6 +239,15 @@ def backward(ctx, grad_output_hp: torch.Tensor):
                 gemm_kernel_choice,
             )
             input_t_mx_dim0 = input_t_mx_dim0_tmp.t()
+        elif mxfp8_cast_kernel_choice == MXFP8CastKernelChoice.CUDA:
+            input_t_mx_dim0_tmp = _cuda_to_mxfp8_dim1_wrapper(
+                input_hp_r,
+                block_size,
+                in_elem_dtype,
+                input_hp_r.dtype,
+                gemm_kernel_choice,
+            )
+            input_t_mx_dim0 = input_t_mx_dim0_tmp.t()
         else:
             input_t_mx_dim0_tmp = MXTensor.to_mx(
                 input_hp_r.t().contiguous(),
@@ -232,7 +302,7 @@ def forward(self, x):
             config.elem_dtype_grad_output_override or config.elem_dtype,
             config.block_size,
             config.gemm_kernel_choice,
-            config.use_fp8_dim1_cast_triton_kernel,
+            config.mxfp8_cast_kernel_choice,
         )
         if self.bias is not None:
             y = y + self.bias
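For reference, the call pattern that the new _cuda_to_mxfp8_dim1_wrapper builds on, sketched standalone (assumes a torchao build where the mxfp8_quantize_cuda op is available; the output names follow the op's return shown in kernels.py above):

import torch

from torchao.prototype.mx_formats.kernels import mxfp8_quantize_cuda

x = torch.randn(256, 512, device="cuda", dtype=torch.bfloat16)
# Only the dim1 (column-wise) outputs are requested, exactly as in the wrapper above.
_, colwise_data, _, colwise_scales = mxfp8_quantize_cuda(
    x, rowwise=False, colwise=True, scaling_mode="floor"
)
# Per the sharding comment in kernels.py, the column-wise data is returned transposed,
# which is why the wrapper calls .t() on it before constructing an MXTensor.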
