
Commit bb930b6

integration of new mxfp8 casting cuda kernel
stack-info: PR: #2564, branch: danielvegamyhre/stack/13
1 parent 95d13d5 commit bb930b6

File tree: 5 files changed (+164, -27 lines)


test/prototype/mx_formats/test_mx_dtensor.py

Lines changed: 12 additions & 1 deletion
@@ -25,6 +25,7 @@
 from tqdm import tqdm
 
 from torchao.prototype.mx_formats import MXLinearConfig
+from torchao.prototype.mx_formats.config import MXFP8CastKernelChoice
 from torchao.prototype.mx_formats.mx_tensor import MXTensor
 from torchao.testing.training.dtensor_utils import (
     _test_lowp_mlp_tensor_parallelism_base,
@@ -82,7 +83,7 @@ def _test_mxfp8_mlp_tensor_parallelism(mesh: DeviceMesh, size=128):
 def _test_mxfp8_mlp_tensor_parallelism_dim1_triton(mesh: DeviceMesh, size=128):
     config = MXLinearConfig.from_recipe_name("mxfp8_emulated")
     config.block_size = 32
-    config.use_fp8_dim1_cast_triton_kernel = True
+    config.mxfp8_cast_kernel_choice = MXFP8CastKernelChoice.TRITON
     _test_lowp_mlp_tensor_parallelism_base(
         mesh, config, size, compile=False, allgather_in_lowp=False
     )
@@ -93,12 +94,22 @@ def _test_mxfp8_mlp_tensor_parallelism_dim1_triton(mesh: DeviceMesh, size=128):
 # )
 
 
+def _test_mxfp8_mlp_tensor_parallelism_dim1_cuda(mesh: DeviceMesh, size=128):
+    config = MXLinearConfig.from_recipe_name("mxfp8_emulated")
+    config.block_size = 32
+    config.mxfp8_cast_kernel_choice = MXFP8CastKernelChoice.CUDA
+    _test_lowp_mlp_tensor_parallelism_base(
+        mesh, config, size, compile=False, allgather_in_lowp=False
+    )
+
+
 if __name__ == "__main__":
     device_mesh = setup_distributed()
     tests = [
         _test_dtensor_cast_to_mxfp8,
         _test_mxfp8_mlp_tensor_parallelism,
         _test_mxfp8_mlp_tensor_parallelism_dim1_triton,
+        _test_mxfp8_mlp_tensor_parallelism_dim1_cuda,
     ]
 
     for test in tqdm(tests, desc="Running tests"):

test/prototype/mx_formats/test_mx_linear.py

Lines changed: 24 additions & 11 deletions
@@ -12,6 +12,7 @@
 import torch.nn.functional as F
 
 from torchao.prototype.mx_formats.config import (
+    MXFP8CastKernelChoice,
     MXGemmKernelChoice,
     MXInferenceLinearConfig,
     MXLinearConfig,
@@ -81,16 +82,21 @@ def run_around_tests():
 @pytest.mark.parametrize("elem_dtype", elem_dtypes)
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize("input_shape", [(128, 256), (1, 128, 256), (1, 1, 128, 256)])
-@pytest.mark.parametrize("use_fp8_dim1_cast_triton_kernel", [False, True])
-def test_linear_eager_vs_hp(
-    elem_dtype, bias, input_shape, use_fp8_dim1_cast_triton_kernel
-):
+@pytest.mark.parametrize(
+    "mxfp8_cast_kernel_choice",
+    [
+        MXFP8CastKernelChoice.TORCH,
+        MXFP8CastKernelChoice.TRITON,
+        MXFP8CastKernelChoice.CUDA,
+    ],
+)
+def test_linear_eager_vs_hp(elem_dtype, bias, input_shape, mxfp8_cast_kernel_choice):
     """
     Smoke test for training linear module with mx weight, compares the following:
     * baseline: float32
     * experiment: emulated MX
     """
-    if use_fp8_dim1_cast_triton_kernel:
+    if mxfp8_cast_kernel_choice != MXFP8CastKernelChoice.TORCH:
         if elem_dtype != (
             torch.float8_e4m3fn,
             torch.float8_e4m3fn,
@@ -109,11 +115,11 @@ def test_linear_eager_vs_hp(
     )
     m_mx = copy.deepcopy(m)
     config = MXLinearConfig(
-        block_size=4,
+        block_size=32,  # Only 32 is supported for now
         elem_dtype=elem_dtype[0],
         elem_dtype_weight_override=elem_dtype[1],
         elem_dtype_grad_output_override=elem_dtype[2],
-        use_fp8_dim1_cast_triton_kernel=use_fp8_dim1_cast_triton_kernel,
+        mxfp8_cast_kernel_choice=mxfp8_cast_kernel_choice,
     )
     quantize_(m_mx, config)
 
@@ -227,8 +233,15 @@ def test_activation_checkpointing():
 @pytest.mark.parametrize("bias", [False, True])
 # TODO(future PR): figure out why torch.compile does not match eager when
 # autocast is on
-@pytest.mark.parametrize("use_fp8_dim1_cast_triton_kernel", [False, True])
-def test_linear_compile(hp_dtype, recipe_name, bias, use_fp8_dim1_cast_triton_kernel):
+@pytest.mark.parametrize(
+    "mxfp8_cast_kernel_choice",
+    [
+        MXFP8CastKernelChoice.TORCH,
+        MXFP8CastKernelChoice.TRITON,
+        MXFP8CastKernelChoice.CUDA,
+    ],
+)
+def test_linear_compile(hp_dtype, recipe_name, bias, mxfp8_cast_kernel_choice):
     """
     Verify that compile does not change numerics of MX linear fw + bw
     """
@@ -246,7 +259,7 @@ def test_linear_compile(hp_dtype, recipe_name, bias, use_fp8_dim1_cast_triton_ke
         # TODO(future PR): fix this, things are clearly broken with bias=True
         pytest.skip("this test is broken for non-emulated recipes with bias=True")
 
-    if use_fp8_dim1_cast_triton_kernel:
+    if mxfp8_cast_kernel_choice != MXFP8CastKernelChoice.TORCH:
         if recipe_name not in ("mxfp8_emulated", "mxfp8_cublas"):
             pytest.skip("unsupported configuration")
         if not is_sm_at_least_89():
@@ -267,7 +280,7 @@ def test_linear_compile(hp_dtype, recipe_name, bias, use_fp8_dim1_cast_triton_ke
         nn.Linear(K, N, bias=bias, device="cuda", dtype=hp_dtype),
     )
     config = MXLinearConfig.from_recipe_name(recipe_name)
-    config.use_fp8_dim1_cast_triton_kernel = use_fp8_dim1_cast_triton_kernel
+    config.mxfp8_cast_kernel_choice = mxfp8_cast_kernel_choice
 
     quantize_(m_mx, config=config)
     m_mx_c = copy.deepcopy(m_mx)
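
For context on what the new parametrization exercises: test_linear_compile checks that compiling an MX-quantized linear module does not change its forward/backward numerics, now for each of the three cast kernel choices. Below is a condensed sketch of that property, not code from this commit; the module shape, dtype, and recipe are illustrative, and it assumes a CUDA GPU that supports the chosen cast kernel.

    import copy

    import torch
    import torch.nn as nn

    from torchao.prototype.mx_formats.config import MXFP8CastKernelChoice, MXLinearConfig
    from torchao.quantization import quantize_

    # Pick a recipe and a cast kernel; CUDA is the new option added by this commit.
    config = MXLinearConfig.from_recipe_name("mxfp8_emulated")
    config.mxfp8_cast_kernel_choice = MXFP8CastKernelChoice.CUDA  # or TRITON / TORCH

    # Quantize a toy module, then compile a copy and compare numerics.
    m = nn.Sequential(nn.Linear(256, 256, bias=False, device="cuda", dtype=torch.bfloat16))
    quantize_(m, config=config)
    m_c = torch.compile(copy.deepcopy(m))

    x = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
    torch.testing.assert_close(m(x), m_c(x))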

torchao/prototype/mx_formats/config.py

Lines changed: 9 additions & 4 deletions
@@ -33,6 +33,12 @@ class MXGemmKernelChoice(Enum):
     CUBLAS = "cublas"
 
 
+class MXFP8CastKernelChoice(Enum):
+    TRITON = "triton"
+    CUDA = "cuda"
+    TORCH = "torch"
+
+
 # Pre-made recipes for common configurations
 class MXLinearRecipeName(Enum):
     MXFP8_EMULATED = "mxfp8_emulated"
@@ -85,10 +91,10 @@ class MXLinearConfig(AOBaseConfig):
     # on the given hardware an exception will be thrown
     gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED
 
-    # If True, uses a custom triton kernel for cast to mxfp8 across dim1
+    # define which kernel to use for mxfp8 casting
     # TODO(1945): remove this config option once torch.compile gives us
     # a fast kernel
-    use_fp8_dim1_cast_triton_kernel: bool = False
+    mxfp8_cast_kernel_choice: MXFP8CastKernelChoice = MXFP8CastKernelChoice.TRITON
 
     # If True, uses a custom triton kernel for fp4 dequantize
     use_fp4_custom_triton_dequant_kernel: bool = False
@@ -146,8 +152,7 @@ def short_str(self) -> str:
         if self.elem_dtype_grad_output_override is not None:
             s += f", lp_go_override={DTYPE_TO_SHORT_STR[self.elem_dtype_grad_output_override]}"
         s += f", kernel={self.gemm_kernel_choice.value}"
-        if self.use_fp8_dim1_cast_triton_kernel:
-            s += ", use_fp8_dim1_cast_triton_kernel=True"
+        s += f", mxfp8_cast_kernel_choice={self.mxfp8_cast_kernel_choice.value}"
         if self.use_fp4_custom_triton_dequant_kernel:
             s += ", use_fp4_custom_triton_dequant_kernel=True"
         return s
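
For downstream users of MXLinearConfig, the boolean use_fp8_dim1_cast_triton_kernel flag is replaced by the MXFP8CastKernelChoice enum, the stated default moves from False to MXFP8CastKernelChoice.TRITON, and short_str now always reports the selected choice. A minimal migration sketch, assuming only the public names shown in this diff:

    from torchao.prototype.mx_formats.config import MXFP8CastKernelChoice, MXLinearConfig

    config = MXLinearConfig.from_recipe_name("mxfp8_cublas")

    # Before this commit:
    #   config.use_fp8_dim1_cast_triton_kernel = True
    # After this commit (TORCH keeps the plain torch / torch.compile cast path):
    config.mxfp8_cast_kernel_choice = MXFP8CastKernelChoice.TRITON  # or TORCH / CUDA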

torchao/prototype/mx_formats/kernels.py

Lines changed: 41 additions & 3 deletions
@@ -1404,6 +1404,7 @@ def triton_scale_swizzle(
     scale_cols,
     output_ptr,
     input_row_stride,
+    input_col_stride,
     output_block_stride,
     BLOCK_ROWS: tl.constexpr,
     BLOCK_COLS: tl.constexpr,
@@ -1423,7 +1424,7 @@
     mask = (global_rows < scale_rows) & (global_cols < scale_cols)
 
     input_scales = tl.load(
-        scale_ptr + global_rows * input_row_stride + global_cols,
+        scale_ptr + global_rows * input_row_stride + global_cols * input_col_stride,
         mask=mask,
         other=0.0,
     )
@@ -1463,7 +1464,6 @@ def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
     assert scale_tensor.element_size() == 1, (
         "Expected element size to be 1 byte (8 bits)"
     )
-    assert scale_tensor.is_contiguous(), "Input tensor must be contiguous"
 
     rows, cols = scale_tensor.shape
 
@@ -1476,7 +1476,8 @@ def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
     out = scale_tensor.new_empty((padded_rows, padded_cols))
 
     # Input stride (for row-major format)
-    input_row_stride = cols
+    input_row_stride = scale_tensor.stride()[0]
+    input_col_stride = scale_tensor.stride()[1]
 
     # We probably want handle multiple blocks per tile but for now keep it simple
     BLOCK_ROWS, BLOCK_COLS = 128, 4
@@ -1495,6 +1496,7 @@ def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
         cols,
         out.view(torch.uint8),
         input_row_stride,
+        input_col_stride,
         output_block_stride,
         BLOCK_ROWS=BLOCK_ROWS,
         BLOCK_COLS=BLOCK_COLS,
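
The swizzle-kernel changes above let triton_mx_block_rearrange accept non-contiguous scale tensors: instead of assuming a row-major layout (offset = row * cols + col), the kernel now receives both strides and indexes with row * input_row_stride + col * input_col_stride, which is why the contiguity assert can be dropped. A standalone illustration of that addressing (plain torch, not code from this commit):

    import torch

    scales = torch.arange(32 * 4, dtype=torch.uint8).reshape(32, 4)  # contiguous, strides (4, 1)
    scales_t = scales.t()  # shape (4, 32), strides (1, 4), not contiguous

    r, c = 2, 3
    offset = r * scales_t.stride(0) + c * scales_t.stride(1)
    assert scales_t[r, c] == scales.flatten()[offset]  # strided indexing into the same storage
    assert not scales_t.is_contiguous()  # would have tripped the removed assert
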
@@ -1812,6 +1814,42 @@ def _(
 
         return output_rowwise, output_colwise, scales_rowwise, scales_colwise
 
+    @register_sharding(torch.ops.torchao.mxfp8_quantize_cuda.default)
+    def custom_mxfp8_quantize_cuda_dim1_sharding(
+        x: torch.Tensor,
+        rowwise: bool = False,
+        colwise: bool = True,
+        scaling_mode: str = "floor",
+    ):
+        # This function signature can be used to understand the shardings:
+        # _, colwise_data, _, colwise_scales = mxfp8_quantize_cuda(x, rowwise=False, colwise=True)
+
+        # When inputs and scale are replicated, we return a quantized output tensor (replicated).
+        inputs_replicated = [None, Replicate(), None, Replicate()]
+        outputs_replicated = [None, Replicate(), None, None]
+        rule_for_input_replicated = (
+            inputs_replicated,
+            outputs_replicated,
+        )
+
+        # When inputs and scale are sharded along dim 0,
+        # we return a quantized output tensor (sharded along dim1 due to transpose).
+        inputs_sharded_dim0 = [None, Shard(0), None, Shard(0)]
+        outputs_sharded_dim1 = [None, Shard(1), None, None]
+        rule_for_input_sharded_dim0 = (inputs_sharded_dim0, outputs_sharded_dim1)
+
+        # When inputs and scale are sharded along dim 1,
+        # we return a quantized output tensor (sharded along dim0 due to transpose).
+        inputs_sharded_dim1 = [None, Shard(1), None, Shard(1)]
+        outputs_sharded_dim0 = [None, Shard(0), None, None]
+        rule_for_input_sharded_dim1 = (inputs_sharded_dim1, outputs_sharded_dim0)
+
+        acceptable_shardings = [
+            rule_for_input_replicated,
+            rule_for_input_sharded_dim0,
+            rule_for_input_sharded_dim1,
+        ]
+        return acceptable_shardings
 else:
 
     def mxfp8_quantize_cuda(
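
The register_sharding rules above encode how the colwise mxfp8 quantization interacts with DTensor sharding: the colwise output is laid out like the transpose of the input, so a dim-0 shard of x corresponds to a dim-1 shard of the colwise output (and vice versa), while replicated inputs produce replicated outputs. A standalone sketch of that shape reasoning with plain tensors, no DTensor and no quantization (the transpose here only stands in for the colwise output layout):

    import torch

    x = torch.randn(8, 4)
    colwise_like = x.t().contiguous()  # stand-in for the layout of the colwise-quantized output

    # Shard x along dim 0 across 2 "ranks"; the matching output shards sit along dim 1.
    x_shards = torch.chunk(x, 2, dim=0)
    out_shards = torch.chunk(colwise_like, 2, dim=1)

    for x_shard, out_shard in zip(x_shards, out_shards):
        torch.testing.assert_close(out_shard, x_shard.t().contiguous())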
