Commit 97e9a7b

integrate mxfp8 dim1 cast kernel choice enum into MXLinear
1 parent d858130 commit 97e9a7b

4 files changed: 57 additions & 21 deletions

test/prototype/mx_formats/test_mx_dtensor.py

Lines changed: 2 additions & 1 deletion

@@ -25,6 +25,7 @@
 from tqdm import tqdm

 from torchao.prototype.mx_formats import MXLinearConfig
+from torchao.prototype.mx_formats.config import MXFP8Dim1CastKernelChoice
 from torchao.prototype.mx_formats.mx_tensor import MXTensor
 from torchao.testing.training.dtensor_utils import (
     _test_lowp_mlp_tensor_parallelism_base,
@@ -82,7 +83,7 @@ def _test_mxfp8_mlp_tensor_parallelism(mesh: DeviceMesh, size=128):
 def _test_mxfp8_mlp_tensor_parallelism_dim1_triton(mesh: DeviceMesh, size=128):
     config = MXLinearConfig.from_recipe_name("mxfp8_emulated")
     config.block_size = 32
-    config.use_fp8_dim1_cast_triton_kernel = True
+    config.mxfp8_dim1_cast_kernel_choice = MXFP8Dim1CastKernelChoice.CUDA
     _test_lowp_mlp_tensor_parallelism_base(
         mesh, config, size, compile=False, allgather_in_lowp=False
     )

test/prototype/mx_formats/test_mx_linear.py

Lines changed: 16 additions & 9 deletions

@@ -12,6 +12,7 @@
 import torch.nn.functional as F

 from torchao.prototype.mx_formats.config import (
+    MXFP8Dim1CastKernelChoice,
     MXGemmKernelChoice,
     MXInferenceLinearConfig,
     MXLinearConfig,
@@ -81,16 +82,19 @@ def run_around_tests():
 @pytest.mark.parametrize("elem_dtype", elem_dtypes)
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize("input_shape", [(128, 256), (1, 128, 256), (1, 1, 128, 256)])
-@pytest.mark.parametrize("use_fp8_dim1_cast_triton_kernel", [False, True])
+@pytest.mark.parametrize(
+    "mxfp8_dim1_cast_kernel_choice",
+    [None, MXFP8Dim1CastKernelChoice.TRITON, MXFP8Dim1CastKernelChoice.CUDA],
+)
 def test_linear_eager_vs_hp(
-    elem_dtype, bias, input_shape, use_fp8_dim1_cast_triton_kernel
+    elem_dtype, bias, input_shape, mxfp8_dim1_cast_kernel_choice
 ):
     """
     Smoke test for training linear module with mx weight, compares the following:
     * baseline: float32
     * experiment: emulated MX
     """
-    if use_fp8_dim1_cast_triton_kernel:
+    if mxfp8_dim1_cast_kernel_choice is not None:
         if elem_dtype != (
             torch.float8_e4m3fn,
             torch.float8_e4m3fn,
@@ -109,11 +113,11 @@ def test_linear_eager_vs_hp(
     )
     m_mx = copy.deepcopy(m)
     config = MXLinearConfig(
-        block_size=4,
+        block_size=32,  # Only 32 is supported for now
         elem_dtype=elem_dtype[0],
         elem_dtype_weight_override=elem_dtype[1],
         elem_dtype_grad_output_override=elem_dtype[2],
-        use_fp8_dim1_cast_triton_kernel=use_fp8_dim1_cast_triton_kernel,
+        mxfp8_dim1_cast_kernel_choice=mxfp8_dim1_cast_kernel_choice,
     )
     quantize_(m_mx, config)

@@ -227,8 +231,11 @@ def test_activation_checkpointing():
 @pytest.mark.parametrize("bias", [False, True])
 # TODO(future PR): figure out why torch.compile does not match eager when
 # autocast is on
-@pytest.mark.parametrize("use_fp8_dim1_cast_triton_kernel", [False, True])
-def test_linear_compile(hp_dtype, recipe_name, bias, use_fp8_dim1_cast_triton_kernel):
+@pytest.mark.parametrize(
+    "mxfp8_dim1_cast_kernel_choice",
+    [None, MXFP8Dim1CastKernelChoice.TRITON, MXFP8Dim1CastKernelChoice.CUDA],
+)
+def test_linear_compile(hp_dtype, recipe_name, bias, mxfp8_dim1_cast_kernel_choice):
     """
     Verify that compile does not change numerics of MX linear fw + bw
     """
@@ -246,7 +253,7 @@ def test_linear_compile(hp_dtype, recipe_name, bias, use_fp8_dim1_cast_triton_ke
         # TODO(future PR): fix this, things are clearly broken with bias=True
         pytest.skip("this test is broken for non-emulated recipes with bias=True")

-    if use_fp8_dim1_cast_triton_kernel:
+    if mxfp8_dim1_cast_kernel_choice is not None:
         if recipe_name not in ("mxfp8_emulated", "mxfp8_cublas"):
             pytest.skip("unsupported configuration")
         if not is_sm_at_least_89():
@@ -267,7 +274,7 @@ def test_linear_compile(hp_dtype, recipe_name, bias, use_fp8_dim1_cast_triton_ke
         nn.Linear(K, N, bias=bias, device="cuda", dtype=hp_dtype),
     )
     config = MXLinearConfig.from_recipe_name(recipe_name)
-    config.use_fp8_dim1_cast_triton_kernel = use_fp8_dim1_cast_triton_kernel
+    config.mxfp8_dim1_cast_kernel_choice = mxfp8_dim1_cast_kernel_choice

     quantize_(m_mx, config=config)
     m_mx_c = copy.deepcopy(m_mx)

torchao/prototype/mx_formats/config.py

Lines changed: 10 additions & 4 deletions

@@ -33,6 +33,11 @@ class MXGemmKernelChoice(Enum):
     CUBLAS = "cublas"


+class MXFP8Dim1CastKernelChoice(Enum):
+    TRITON = "triton"
+    CUDA = "cuda"
+
+
 # Pre-made recipes for common configurations
 class MXLinearRecipeName(Enum):
     MXFP8_EMULATED = "mxfp8_emulated"
@@ -85,10 +90,12 @@ class MXLinearConfig(AOBaseConfig):
     # on the given hardware an exception will be thrown
     gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED

-    # If True, uses a custom triton kernel for cast to mxfp8 across dim1
+    # define which kernel to use for dim1 cast
     # TODO(1945): remove this config option once torch.compile gives us
     # a fast kernel
-    use_fp8_dim1_cast_triton_kernel: bool = False
+    mxfp8_dim1_cast_kernel_choice: Optional[MXFP8Dim1CastKernelChoice] = (
+        MXFP8Dim1CastKernelChoice.TRITON
+    )

     # If True, uses a custom triton kernel for fp4 dequantize
     use_fp4_custom_triton_dequant_kernel: bool = False
@@ -146,8 +153,7 @@ def short_str(self) -> str:
         if self.elem_dtype_grad_output_override is not None:
             s += f", lp_go_override={DTYPE_TO_SHORT_STR[self.elem_dtype_grad_output_override]}"
         s += f", kernel={self.gemm_kernel_choice.value}"
-        if self.use_fp8_dim1_cast_triton_kernel:
-            s += ", use_fp8_dim1_cast_triton_kernel=True"
+        s += f", mxfp8_dim1_cast_kernel_choice={self.mxfp8_dim1_cast_kernel_choice.value}"
         if self.use_fp4_custom_triton_dequant_kernel:
             s += ", use_fp4_custom_triton_dequant_kernel=True"
         return s
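
For reference, a minimal usage sketch of the new config knob, assembled from the test changes in this commit. It assumes a CUDA device and that quantize_ is importable from torchao.quantization, as elsewhere in the test suite; treat it as illustrative rather than as part of the commit.

# Sketch only: pick which kernel performs the mxfp8 dim1 cast.
import copy

import torch
import torch.nn as nn

from torchao.prototype.mx_formats import MXLinearConfig
from torchao.prototype.mx_formats.config import MXFP8Dim1CastKernelChoice
from torchao.quantization import quantize_  # assumed import path

config = MXLinearConfig.from_recipe_name("mxfp8_emulated")
config.block_size = 32  # only 32 is supported for now
config.mxfp8_dim1_cast_kernel_choice = MXFP8Dim1CastKernelChoice.CUDA
# ... or MXFP8Dim1CastKernelChoice.TRITON (the default), or None to fall back
# to the MXTensor.to_mx path in mx_linear.py.

m = nn.Sequential(nn.Linear(256, 256, bias=False, device="cuda", dtype=torch.bfloat16))
m_mx = copy.deepcopy(m)
quantize_(m_mx, config=config)  # swaps nn.Linear for the MX linear module

x = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16, requires_grad=True)
m_mx(x).sum().backward()  # the chosen dim1 cast kernel runs in the backward pass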

torchao/prototype/mx_formats/mx_linear.py

Lines changed: 29 additions & 7 deletions

@@ -15,6 +15,7 @@
 from torch.distributed._tensor import DTensor

 from torchao.prototype.mx_formats.config import (
+    MXFP8Dim1CastKernelChoice,
     MXGemmKernelChoice,
     MXInferenceLinearConfig,
     MXLinearConfig,
@@ -134,15 +135,15 @@ def forward(
         grad_elem_dtype: Any,
         block_size: int,
         gemm_kernel_choice: MXGemmKernelChoice,
-        use_fp8_dim1_cast_triton_kernel: bool,
+        mxfp8_dim1_cast_kernel_choice: MXFP8Dim1CastKernelChoice,
     ):
         ctx.save_for_backward(input_hp, weight_hp)
         ctx.in_elem_dtype = in_elem_dtype
         ctx.w_elem_dtype = w_elem_dtype
         ctx.grad_elem_dtype = grad_elem_dtype
         ctx.block_size = block_size
         ctx.gemm_kernel_choice = gemm_kernel_choice
-        ctx.use_fp8_dim1_cast_triton_kernel = use_fp8_dim1_cast_triton_kernel
+        ctx.mxfp8_dim1_cast_kernel_choice = mxfp8_dim1_cast_kernel_choice

         # input @ weight_t = output
         input_orig_shape = input_hp.shape
@@ -167,7 +168,7 @@ def backward(ctx, grad_output_hp: torch.Tensor):
         grad_elem_dtype = ctx.grad_elem_dtype
         block_size = ctx.block_size
         gemm_kernel_choice = ctx.gemm_kernel_choice
-        use_fp8_dim1_cast_triton_kernel = ctx.use_fp8_dim1_cast_triton_kernel
+        mxfp8_dim1_cast_kernel_choice = ctx.mxfp8_dim1_cast_kernel_choice

         grad_output_orig_shape = grad_output_hp.shape
         grad_output_hp_r = grad_output_hp.reshape(-1, grad_output_orig_shape[-1])
@@ -183,10 +184,14 @@ def backward(ctx, grad_output_hp: torch.Tensor):
             gemm_kernel_choice=gemm_kernel_choice,
         )

-        if use_fp8_dim1_cast_triton_kernel:
+        if mxfp8_dim1_cast_kernel_choice == MXFP8Dim1CastKernelChoice.TRITON:
             weight_mx_dim1 = _triton_to_mxfp8_dim1_wrapper(
                 weight_hp, block_size, w_elem_dtype, weight_hp.dtype, gemm_kernel_choice
             )
+        elif mxfp8_dim1_cast_kernel_choice == MXFP8Dim1CastKernelChoice.CUDA:
+            weight_mx_dim1 = _cuda_to_mxfp8_dim1_wrapper(
+                weight_hp, block_size, w_elem_dtype, weight_hp.dtype, gemm_kernel_choice
+            )
         else:
             weight_hp_t_c = weight_hp.t().contiguous()
             weight_mx_dim1 = MXTensor.to_mx(
@@ -201,14 +206,22 @@ def backward(ctx, grad_output_hp: torch.Tensor):
             )

         # input_t @ grad_output = grad_weight
-        if use_fp8_dim1_cast_triton_kernel:
+        if mxfp8_dim1_cast_kernel_choice == MXFP8Dim1CastKernelChoice.TRITON:
             grad_output_mx_dim1 = _triton_to_mxfp8_dim1_wrapper(
                 grad_output_hp_r,
                 block_size,
                 grad_elem_dtype,
                 grad_output_hp_r.dtype,
                 gemm_kernel_choice,
             )
+        elif mxfp8_dim1_cast_kernel_choice == MXFP8Dim1CastKernelChoice.CUDA:
+            grad_output_mx_dim1 = _cuda_to_mxfp8_dim1_wrapper(
+                grad_output_hp_r,
+                block_size,
+                grad_elem_dtype,
+                grad_output_hp_r.dtype,
+                gemm_kernel_choice,
+            )
         else:
             grad_output_mx_dim1 = MXTensor.to_mx(
                 grad_output_hp_r.t().contiguous(),
@@ -217,7 +230,7 @@ def backward(ctx, grad_output_hp: torch.Tensor):
                 gemm_kernel_choice=gemm_kernel_choice,
             )

-        if use_fp8_dim1_cast_triton_kernel:
+        if mxfp8_dim1_cast_kernel_choice == MXFP8Dim1CastKernelChoice.TRITON:
             input_t_mx_dim0_tmp = _triton_to_mxfp8_dim1_wrapper(
                 input_hp_r,
                 block_size,
@@ -226,6 +239,15 @@ def backward(ctx, grad_output_hp: torch.Tensor):
                 gemm_kernel_choice,
             )
             input_t_mx_dim0 = input_t_mx_dim0_tmp.t()
+        elif mxfp8_dim1_cast_kernel_choice == MXFP8Dim1CastKernelChoice.CUDA:
+            input_t_mx_dim0_tmp = _cuda_to_mxfp8_dim1_wrapper(
+                input_hp_r,
+                block_size,
+                in_elem_dtype,
+                input_hp_r.dtype,
+                gemm_kernel_choice,
+            )
+            input_t_mx_dim0 = input_t_mx_dim0_tmp.t()
         else:
             input_t_mx_dim0_tmp = MXTensor.to_mx(
                 input_hp_r.t().contiguous(),
@@ -280,7 +302,7 @@ def forward(self, x):
             config.elem_dtype_grad_output_override or config.elem_dtype,
             config.block_size,
             config.gemm_kernel_choice,
-            config.use_fp8_dim1_cast_triton_kernel,
+            config.mxfp8_dim1_cast_kernel_choice,
         )
         if self.bias is not None:
             y = y + self.bias
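
The three dispatch sites added to backward() above share one shape: branch on the enum, call the matching dim1 cast wrapper, or fall back to MXTensor.to_mx on a transposed contiguous tensor when the choice is None. A condensed sketch of that pattern follows; the _cast_dim1 helper is hypothetical and not part of this commit, the wrapper names are assumed to be importable from mx_linear.py where the diff uses them, and the MXTensor.to_mx argument order mirrors the surrounding code.

# Hypothetical helper, for illustration only; not part of this commit.
from torchao.prototype.mx_formats.config import MXFP8Dim1CastKernelChoice
from torchao.prototype.mx_formats.mx_linear import (  # assumed to be accessible here
    _cuda_to_mxfp8_dim1_wrapper,
    _triton_to_mxfp8_dim1_wrapper,
)
from torchao.prototype.mx_formats.mx_tensor import MXTensor


def _cast_dim1(tensor_hp, block_size, elem_dtype, gemm_kernel_choice, kernel_choice):
    # Branch on the kernel choice, mirroring the if/elif/else blocks in backward().
    if kernel_choice == MXFP8Dim1CastKernelChoice.TRITON:
        return _triton_to_mxfp8_dim1_wrapper(
            tensor_hp, block_size, elem_dtype, tensor_hp.dtype, gemm_kernel_choice
        )
    elif kernel_choice == MXFP8Dim1CastKernelChoice.CUDA:
        return _cuda_to_mxfp8_dim1_wrapper(
            tensor_hp, block_size, elem_dtype, tensor_hp.dtype, gemm_kernel_choice
        )
    else:
        # kernel_choice is None: emulate the dim1 cast via MXTensor.to_mx on a
        # transposed, contiguous copy, roughly as the existing else branches do.
        return MXTensor.to_mx(
            tensor_hp.t().contiguous(),
            elem_dtype,
            block_size,
            gemm_kernel_choice=gemm_kernel_choice,
        )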
