
Commit 5adbce8

mx: expose scaling calculation methods in training UX
Summary:

Test Plan: performance on individual cast

```bash
(pytorch_nightly) [[email protected] ~/local/ao (20250728_mx_expose_scale)]$ python benchmarks/mx_formats/cast_bench.py --mode dim0_mx_floor
M 16384 K 16384 BLOCK_SIZE 32
GPU: NVIDIA B200
torch version: 2.9.0.dev20250724+cu128
triton version: 3.4.0
mode: dim0_mx_floor
time_us 184.38400328159332
mem_bw_gbps 4413.045391781173

(pytorch_nightly) [[email protected] ~/local/ao (20250728_mx_expose_scale)]$ python benchmarks/mx_formats/cast_bench.py --mode dim0_mx_rceil
M 16384 K 16384 BLOCK_SIZE 32
GPU: NVIDIA B200
torch version: 2.9.0.dev20250724+cu128
triton version: 3.4.0
mode: dim0_mx_rceil
time_us 143.39199662208557
mem_bw_gbps 5674.619191924083
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 0531997
ghstack-comment-id: 3129597761
Pull-Request: #2620
1 parent d05e54f commit 5adbce8
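As a sanity check, the mem_bw_gbps numbers above follow directly from the bandwidth model in cast_bench.py (read a bf16 input, write fp8 data plus one e8m0 scale per 32-element block). A quick reproduction of the dim0_mx_floor figure:

```python
# Reproduces the reported bandwidth from the byte counts used in cast_bench.py:
# bytes_read = bf16 input, bytes_written = fp8 data + e8m0 scales (1 byte each).
M, K, BLOCK_SIZE = 16384, 16384, 32

bytes_r = M * K * 2                      # bf16 input
bytes_w = M * K + (M * K // BLOCK_SIZE)  # fp8 data + e8m0 scales

time_us = 184.38400328159332             # dim0_mx_floor, from the log above
gbps = (bytes_r + bytes_w) / (time_us / 1e6) / 1e9
print(round(gbps))  # ~4413, matching mem_bw_gbps; 143.39 us gives ~5675 for rceil
```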

File tree

8 files changed: +201 / -42 lines


benchmarks/float8/float8_roofline.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -170,7 +170,7 @@ def get_gemm_times(
     elif float8_recipe_name in ("rowwise", "rowwise_with_gw_hp"):
         scale_a = torch.ones(M, 1, device=device)
         scale_b = torch.ones(1, N, device=device)
-    elif mx_recipe_name == "mxfp8_cublas":
+    elif mx_recipe_name in ("mxfp8_cublas", "mxfp8_cublas_rceil"):
         scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)
         scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)
     else:
```
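The change above extends the existing mxfp8_cublas branch to the new mxfp8_cublas_rceil recipe; both use the same scale layout, one e8m0 scale per 32-element block along the contraction dimension K. A small sketch of the resulting sizes (illustrative arithmetic only, not part of the diff):

```python
# Scale tensor shapes for block-scaled mxfp8 GEMM operands, mirroring get_gemm_times():
# an (M, K) operand scaled along K with block size 32 carries an (M, K // 32) e8m0 scale.
M, K, N, BLOCK_SIZE = 16384, 16384, 16384, 32

scale_a_shape = (M, K // BLOCK_SIZE)   # (16384, 512)
scale_b_shape = (N, K // BLOCK_SIZE)   # (16384, 512)

# Scale storage is 1/32 of the fp8 data bytes, ~3% overhead per operand.
scale_bytes = M * (K // BLOCK_SIZE)    # 8,388,608 bytes vs 268,435,456 bytes of fp8 data
```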

benchmarks/mx_formats/cast_bench.py

Lines changed: 47 additions & 4 deletions

```diff
@@ -11,6 +11,7 @@
 import triton
 from triton.testing import do_bench
 
+from torchao.prototype.mx_formats.config import ScaleCalculationMode
 from torchao.prototype.mx_formats.kernels import (
     triton_to_mxfp8_dim1,
 )
@@ -53,14 +54,18 @@ def scale_dim0_dim1_reference(
     return x_hp_d0_normalized, x_hp_d1_normalized.t(), amax_dim0, amax_dim1
 
 
-def to_mx_dim0_reference(x_hp, block_size):
-    scale_d0, data_d0 = to_mx(x_hp, torch.float8_e4m3fn, block_size)
+def to_mx_dim0_reference(x_hp, block_size, scaling_mode=ScaleCalculationMode.FLOOR):
+    scale_d0, data_d0 = to_mx(
+        x_hp, torch.float8_e4m3fn, block_size, scaling_mode=scaling_mode
+    )
     return data_d0, scale_d0
 
 
-def to_mx_dim1_reference(x_hp, block_size):
+def to_mx_dim1_reference(x_hp, block_size, scaling_mode=ScaleCalculationMode.FLOOR):
     x_hp = x_hp.t().contiguous()
-    scale_d1, data_d1 = to_mx(x_hp, torch.float8_e4m3fn, block_size)
+    scale_d1, data_d1 = to_mx(
+        x_hp, torch.float8_e4m3fn, block_size, scaling_mode=scaling_mode
+    )
     return data_d1.t(), scale_d1
 
 
@@ -84,7 +89,9 @@ def run(
         "dim1",
         "dim0_dim1",
         "dim0_mx_floor",
+        "dim0_mx_rceil",
         "dim1_mx_floor",
+        "dim1_mx_rceil",
         "dim1_mx_triton_floor",
         "dim1_mx_cuda_floor",
         "dim1_mx_cuda_rceil",
@@ -165,6 +172,24 @@ def run(
         bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
         bps = (bytes_r + bytes_w) / (time_us / 1e6)
 
+    elif mode == "dim0_mx_rceil":
+        to_mx_dim0_reference_c = torch.compile(to_mx_dim0_reference)
+        y_d0, s_d0 = to_mx_dim0_reference_c(x, BLOCK_SIZE, ScaleCalculationMode.RCEIL)
+
+        for _ in range(2):
+            __ = to_mx_dim0_reference_c(x, BLOCK_SIZE)
+        time_us = benchmark_cuda_function_in_microseconds(
+            lambda x, b: to_mx_dim0_reference_c(x, BLOCK_SIZE),
+            x,
+            BLOCK_SIZE,
+        )
+
+        assert y_d0.dtype == torch.float8_e4m3fn
+        assert s_d0.dtype == torch.float8_e8m0fnu
+        bytes_r = x.numel() * bytes_per_el_bf16
+        bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
+        bps = (bytes_r + bytes_w) / (time_us / 1e6)
+
     elif mode == "dim1_mx_floor":
         to_mx_dim1_reference_c = torch.compile(to_mx_dim1_reference)
         y_d1, s_d1 = to_mx_dim1_reference_c(x, BLOCK_SIZE)
@@ -183,6 +208,24 @@ def run(
         bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
         bps = (bytes_r + bytes_w) / (time_us / 1e6)
 
+    elif mode == "dim1_mx_rceil":
+        to_mx_dim1_reference_c = torch.compile(to_mx_dim1_reference)
+        y_d1, s_d1 = to_mx_dim1_reference_c(x, BLOCK_SIZE, ScaleCalculationMode.RCEIL)
+
+        for _ in range(2):
+            __ = to_mx_dim1_reference_c(x, BLOCK_SIZE)
+        time_us = benchmark_cuda_function_in_microseconds(
+            lambda x, b: to_mx_dim1_reference_c(x, BLOCK_SIZE),
+            x,
+            BLOCK_SIZE,
+        )
+
+        assert y_d1.dtype == torch.float8_e4m3fn
+        assert s_d1.dtype == torch.float8_e8m0fnu
+        bytes_r = x.numel() * bytes_per_el_bf16
+        bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
+        bps = (bytes_r + bytes_w) / (time_us / 1e6)
+
     elif mode == "dim1_mx_triton_floor":
         y_d1, s_d1 = triton_to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE)
```
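With these additions, the rceil variants are exercised the same way as the floor ones (see the --mode dim0_mx_rceil invocation in the test plan above). A minimal sketch of calling the reference cast directly with the new scaling_mode argument; the import path for to_mx is an assumption here, the argument order follows the benchmark code above:

```python
import torch

from torchao.prototype.mx_formats.config import ScaleCalculationMode
# assumed import location for the reference cast used by cast_bench.py
from torchao.prototype.mx_formats.mx_tensor import to_mx

x = torch.randn(16384, 16384, dtype=torch.bfloat16, device="cuda")

# dim0 cast with rceil scale calculation, mirroring to_mx_dim0_reference above;
# to_mx returns the e8m0 scales first, then the fp8 data
scale_d0, data_d0 = to_mx(
    x, torch.float8_e4m3fn, 32, scaling_mode=ScaleCalculationMode.RCEIL
)
print(data_d0.dtype, scale_d0.dtype)  # torch.float8_e4m3fn torch.float8_e8m0fnu
```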

test/prototype/mx_formats/test_mx_linear.py

Lines changed: 52 additions & 5 deletions

```diff
@@ -14,6 +14,7 @@
     MXFP8Dim1CastKernelChoice,
     MXLinearConfig,
     MXLinearRecipeName,
+    ScaleCalculationMode,
 )
 from torchao.prototype.mx_formats.constants import (
     DTYPE_FP6_E2M3,
@@ -78,7 +79,18 @@ def run_around_tests():
         MXFP8Dim1CastKernelChoice.CUDA,
     ],
 )
-def test_linear_eager_vs_hp(elem_dtype, bias, input_shape, mxfp8_cast_kernel_choice):
+@pytest.mark.parametrize(
+    "scale_calculation_mode",
+    [
+        ScaleCalculationMode.FLOOR,
+        ScaleCalculationMode.CEIL,
+        ScaleCalculationMode.EVEN,
+        ScaleCalculationMode.RCEIL,
+    ],
+)
+def test_linear_eager_vs_hp(
+    elem_dtype, bias, input_shape, mxfp8_cast_kernel_choice, scale_calculation_mode
+):
     """
     Smoke test for training linear module with mx weight, compares the following:
     * baseline: float32
@@ -94,6 +106,16 @@ def test_linear_eager_vs_hp(elem_dtype, bias, input_shape, mxfp8_cast_kernel_cho
     elif not is_sm_at_least_89():
         pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
 
+    if mxfp8_cast_kernel_choice == MXFP8Dim1CastKernelChoice.TRITON:
+        if scale_calculation_mode != ScaleCalculationMode.FLOOR:
+            pytest.skip("unsupported configuration")
+    elif mxfp8_cast_kernel_choice == MXFP8Dim1CastKernelChoice.CUDA:
+        if scale_calculation_mode not in (
+            ScaleCalculationMode.FLOOR,
+            ScaleCalculationMode.RCEIL,
+        ):
+            pytest.skip("unsupported configuration")
+
     # elem_dtype is a tuple of (input, weight, gradient) dtypes.
     grad_shape = list(input_shape)
     grad_shape[-1] = 256
@@ -108,6 +130,7 @@ def test_linear_eager_vs_hp(elem_dtype, bias, input_shape, mxfp8_cast_kernel_cho
         elem_dtype_weight_override=elem_dtype[1],
         elem_dtype_grad_output_override=elem_dtype[2],
         mxfp8_cast_kernel_choice=mxfp8_cast_kernel_choice,
+        scale_calculation_mode=scale_calculation_mode,
     )
     quantize_(m_mx, config)
 
@@ -125,9 +148,9 @@ def test_linear_eager_vs_hp(elem_dtype, bias, input_shape, mxfp8_cast_kernel_cho
     y_ref.backward(g)
     y_mx.backward(g)
 
-    y_sqnr = compute_error(y_ref, y_mx)
-    w_g_sqnr = compute_error(m[0].weight.grad, getattr(m_mx, "0").weight.grad)
-    x_g_sqnr = compute_error(x_ref.grad, x.grad)
+    y_sqnr = compute_error(y_ref, y_mx).item()
+    w_g_sqnr = compute_error(m[0].weight.grad, getattr(m_mx, "0").weight.grad).item()
+    x_g_sqnr = compute_error(x_ref.grad, x.grad).item()
 
     if elem_dtype == (torch.float8_e4m3fn, torch.float8_e4m3fn, torch.float8_e4m3fn):
         assert y_sqnr >= 18.0
@@ -229,7 +252,20 @@ def test_activation_checkpointing():
         MXFP8Dim1CastKernelChoice.CUDA,
     ],
 )
-def test_linear_compile(hp_dtype, recipe_name, bias, mxfp8_cast_kernel_choice):
+@pytest.mark.parametrize(
+    "scale_calculation_mode",
+    [
+        ScaleCalculationMode.FLOOR,
+        ScaleCalculationMode.CEIL,
+        # even + compile does not work yet:
+        # https://gist.github.com/vkuzo/1a04845cd503b1c75291aa1ea3bf79c4
+        # ScaleCalculationMode.EVEN,
+        ScaleCalculationMode.RCEIL,
+    ],
+)
+def test_linear_compile(
+    hp_dtype, recipe_name, bias, mxfp8_cast_kernel_choice, scale_calculation_mode
+):
     """
     Verify that compile does not change numerics of MX linear fw + bw
     """
@@ -255,6 +291,16 @@ def test_linear_compile(hp_dtype, recipe_name, bias, mxfp8_cast_kernel_choice):
     if hp_dtype != torch.bfloat16:
         pytest.skip("unsupported configuration")
 
+    if mxfp8_cast_kernel_choice == MXFP8Dim1CastKernelChoice.TRITON:
+        if scale_calculation_mode != ScaleCalculationMode.FLOOR:
+            pytest.skip("unsupported configuration")
+    elif mxfp8_cast_kernel_choice == MXFP8Dim1CastKernelChoice.CUDA:
+        if scale_calculation_mode not in (
+            ScaleCalculationMode.FLOOR,
+            ScaleCalculationMode.RCEIL,
+        ):
+            pytest.skip("unsupported configuration")
+
     if hp_dtype == torch.bfloat16 and recipe_name != "mxfp8_cublas":
         # TODO(future PR): properly enable float32 + bfloat16 for every
         # recipe, this needs a cleanup of out_dtype (needs to match in-hp-dtype, even
@@ -269,6 +315,7 @@ def test_linear_compile(hp_dtype, recipe_name, bias, mxfp8_cast_kernel_choice):
     )
     config = MXLinearConfig.from_recipe_name(recipe_name)
     config.mxfp8_cast_kernel_choice = mxfp8_cast_kernel_choice
+    config.scale_calculation_mode = scale_calculation_mode
 
     quantize_(m_mx, config=config)
     m_mx_c = copy.deepcopy(m_mx)
```
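The skip logic above encodes which scale calculation modes each dim1 cast kernel currently supports. Summarized as a small sketch (only the TRITON and CUDA restrictions are asserted by this PR; other kernel choices are not constrained by the validation shown):

```python
from torchao.prototype.mx_formats.config import (
    MXFP8Dim1CastKernelChoice,
    ScaleCalculationMode,
)

# Mirrors the pytest.skip conditions and _validate_mxfp8_cast_kernel_choice below:
# the triton dim1 cast only supports FLOOR, the CUDA dim1 cast supports FLOOR and RCEIL.
SUPPORTED_SCALE_MODES = {
    MXFP8Dim1CastKernelChoice.TRITON: {ScaleCalculationMode.FLOOR},
    MXFP8Dim1CastKernelChoice.CUDA: {
        ScaleCalculationMode.FLOOR,
        ScaleCalculationMode.RCEIL,
    },
}
```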

torchao/prototype/mx_formats/README.md

Lines changed: 5 additions & 2 deletions

````diff
@@ -23,20 +23,23 @@ We plan to add the following features in the near future:
 ```python
 import torch
 from torchao.quantization import quantize_
-from torchao.prototype.mx_formats import MXLinearConfig, MXGemmKernelChoice
+from torchao.prototype.mx_formats import MXLinearConfig, MXGemmKernelChoice, ScaleCalculationMode
 
 # on NVIDIA Blackwell GPUs, you can use cuBLAS or CUTLASS mxfp8 kernels
 gemm_kernel_choice = MXGemmKernelChoice.CUBLAS
 # gemm_kernel_choice = MXGemmKernelChoice.CUTLASS
-
 # on older NVIDIA gpus, you can run training with emulated MX gemm
 # gemm_kernel_choice = MXGemmKernelChoice.EMULATED
 
+scale_calculation_mode = ScaleCalculationMode.FLOOR
+# other supported modes: RCEIL, CEIL, EVEN
+
 m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda()
 config = MXLinearConfig(
     elem_dtype=torch.float8_e4m3fn,
     block_size=32,
     gemm_kernel_choice=gemm_kernel_choice,
+    scale_calculation_mode=scale_calculation_mode,
 )
 quantize_(m, config)
````
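Beyond setting scale_calculation_mode explicitly, the new recipe added in config.py below provides a one-liner for the rceil configuration; a small sketch, assuming from_recipe_name accepts the recipe's string value:

```python
import torch

from torchao.quantization import quantize_
from torchao.prototype.mx_formats import MXLinearConfig

m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda()

# "mxfp8_cublas_rceil" maps to MXLinearRecipeName.MXFP8_CUBLAS_RCEIL, which selects the
# cuBLAS gemm kernel, the CUDA dim1 cast kernel, and ScaleCalculationMode.RCEIL
config = MXLinearConfig.from_recipe_name("mxfp8_cublas_rceil")
quantize_(m, config)
```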

torchao/prototype/mx_formats/config.py

Lines changed: 63 additions & 1 deletion

```diff
@@ -46,10 +46,39 @@ class MXFP8Dim1CastKernelChoice(Enum):
 class MXLinearRecipeName(Enum):
     MXFP8_EMULATED = "mxfp8_emulated"
     MXFP8_CUBLAS = "mxfp8_cublas"
+    MXFP8_CUBLAS_RCEIL = "mxfp8_cublas_rceil"
     MXFP4_EMULATED = "mxfp4_emulated"
     MXFP4_CUTLASS = "mxfp4_cutlass"
 
 
+class ScaleCalculationMode(Enum):
+    """
+    Enum representing the different methods for calculating MX block scaling.
+    There are four methods available:
+
+    FLOOR: This method is recommended by the OCP MX Spec 1.0 and uses X = 2^floor(log2(max_abs(v)) - max_exp).
+    It can result in overflow issues for large values and is a poor fit for gradient quantization.
+
+    RCEIL: This method applies ceil to the ratio of max_abs(v) and max_pos.
+    Its details are described in https://docs.nvidia.com/cuda/cublas/index.html#d-block-quantization,
+    section "Computing scaling and conversion factors for FP8 with UE8M0 scales".
+
+    CEIL: This method avoids overflow issues, but small values may shift to 0 due to a large scaling factor.
+    It uses X = 2^ceil(log2(max_abs(v)) - max_exp).
+
+    EVEN: This method is a trade-off between FLOOR and CEIL. It uses X = 2^(floor(log2(rounding(max_abs(v))) - max_exp)).
+    It provides better accuracy for MX4 training compared to FLOOR and CEIL.
+    Note: EVEN does not work with torch.compile yet:
+    https://gist.github.com/vkuzo/1a04845cd503b1c75291aa1ea3bf79c4
+    """
+
+    FLOOR = "floor"
+    RCEIL = "rceil"
+    CEIL = "ceil"
+    EVEN = "even"
+
+
 def _validate_elem_dtype(elem_dtype):
     assert elem_dtype in SUPPORTED_ELEM_DTYPES, (
         f"elem_dtype: expected one of {SUPPORTED_ELEM_DTYPES}, got {elem_dtype}"
@@ -75,6 +104,22 @@ def _validate_gemm_kernel_choice(gemm_kernel_choice, block_size, elem_dtype):
     )
 
 
+def _validate_mxfp8_cast_kernel_choice(
+    mxfp8_cast_kernel_choice, scale_calculation_mode
+):
+    if mxfp8_cast_kernel_choice == MXFP8Dim1CastKernelChoice.TRITON:
+        assert scale_calculation_mode == ScaleCalculationMode.FLOOR, (
+            f"unsupported ScaleCalculationMode value {scale_calculation_mode} for dim1 triton cast"
+        )
+    elif mxfp8_cast_kernel_choice == MXFP8Dim1CastKernelChoice.CUDA:
+        assert scale_calculation_mode in (
+            ScaleCalculationMode.FLOOR,
+            ScaleCalculationMode.RCEIL,
+        ), (
+            f"unsupported ScaleCalculationMode value {scale_calculation_mode} for dim1 cuda cast"
+        )
+
+
 @dataclass
 class MXLinearConfig(AOBaseConfig):
     # block size for scaling, default is 32 to match
@@ -104,6 +149,8 @@ class MXLinearConfig(AOBaseConfig):
     # If True, uses a custom triton kernel for fp4 dequantize
     use_fp4_custom_triton_dequant_kernel: bool = False
 
+    scale_calculation_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR
+
     def __post_init__(self):
         _validate_elem_dtype(self.elem_dtype)
         _validate_gemm_kernel_choice(
@@ -115,6 +162,9 @@ def __post_init__(self):
         if self.elem_dtype_grad_output_override is not None:
             _validate_elem_dtype(self.elem_dtype_grad_output_override)
             assert self.gemm_kernel_choice == MXGemmKernelChoice.EMULATED, "unsupported"
+        _validate_mxfp8_cast_kernel_choice(
+            self.mxfp8_cast_kernel_choice, self.scale_calculation_mode
+        )
 
     @staticmethod
     def from_recipe_name(
@@ -134,7 +184,17 @@ def from_recipe_name(
         if recipe_name is MXLinearRecipeName.MXFP8_EMULATED:
             return MXLinearConfig()
         elif recipe_name is MXLinearRecipeName.MXFP8_CUBLAS:
-            return MXLinearConfig(gemm_kernel_choice=MXGemmKernelChoice.CUBLAS)
+            # TODO(future PR): default to CUDA dim1 kernel
+            return MXLinearConfig(
+                gemm_kernel_choice=MXGemmKernelChoice.CUBLAS,
+                mxfp8_cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
+            )
+        elif recipe_name is MXLinearRecipeName.MXFP8_CUBLAS_RCEIL:
+            return MXLinearConfig(
+                gemm_kernel_choice=MXGemmKernelChoice.CUBLAS,
+                mxfp8_cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
+                scale_calculation_mode=ScaleCalculationMode.RCEIL,
+            )
         elif recipe_name is MXLinearRecipeName.MXFP4_EMULATED:
             return MXLinearConfig(elem_dtype=torch.float4_e2m1fn_x2)
         elif recipe_name is MXLinearRecipeName.MXFP4_CUTLASS:
@@ -160,4 +220,6 @@ def short_str(self) -> str:
         s += f", mxfp8_cast_kernel_choice={self.mxfp8_cast_kernel_choice.value}"
         if self.use_fp4_custom_triton_dequant_kernel:
             s += ", use_fp4_custom_triton_dequant_kernel=True"
+        if self.scale_calculation_mode != ScaleCalculationMode.FLOOR:
+            s += f", scale_calculation_mode={self.scale_calculation_mode}"
         return s
```
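To make the docstring formulas concrete, here is a plain-math illustration of the scale exponents the three non-EVEN modes would pick for one block, using float8_e4m3fn as the element type (max_pos = 448, max_exp = 8). This paraphrases the formulas above; it is not the library's bit-level implementation:

```python
import math

block = [0.0123, -3.7, 120.0, 0.5]            # one block of high-precision values
max_abs = max(abs(v) for v in block)          # 120.0

MAX_POS = 448.0                               # largest normal float8_e4m3fn value
MAX_EXP = 8                                   # floor(log2(448))

# FLOOR (OCP MX spec): 2^(floor(log2(max_abs)) - max_exp)
e_floor = math.floor(math.log2(max_abs)) - MAX_EXP   # 6 - 8 = -2
# CEIL: never overflows, but may flush small blocks to zero
e_ceil = math.ceil(math.log2(max_abs)) - MAX_EXP     # 7 - 8 = -1
# RCEIL (cuBLAS UE8M0 style): ceil of max_abs / max_pos, expressed as a power-of-two exponent
e_rceil = math.ceil(math.log2(max_abs / MAX_POS))    # ceil(-1.90) = -1

print(e_floor, e_ceil, e_rceil)  # -2 -1 -1
```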
