Commit c21b2e5
mx: expose scaling calculation methods in training UX
Summary:

Test Plan: performance on individual cast

```bash
(pytorch_nightly) [[email protected] ~/local/ao (20250728_mx_expose_scale)]$ python benchmarks/mx_formats/cast_bench.py --mode dim0_mx_floor
M 16384 K 16384 BLOCK_SIZE 32
GPU: NVIDIA B200
torch version: 2.9.0.dev20250724+cu128
triton version: 3.4.0
mode: dim0_mx_floor
time_us 184.38400328159332
mem_bw_gbps 4413.045391781173

(pytorch_nightly) [[email protected] ~/local/ao (20250728_mx_expose_scale)]$ python benchmarks/mx_formats/cast_bench.py --mode dim0_mx_rceil
M 16384 K 16384 BLOCK_SIZE 32
GPU: NVIDIA B200
torch version: 2.9.0.dev20250724+cu128
triton version: 3.4.0
mode: dim0_mx_rceil
time_us 143.39199662208557
mem_bw_gbps 5674.619191924083
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: aec9d07
ghstack-comment-id: 3129597761
Pull-Request: #2620
1 parent d05e54f commit c21b2e5

7 files changed, +187 -39 lines changed
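For context, a minimal sketch of how the new knob is meant to surface in the training UX, based on the config and test changes below (the `quantize_` import path and the toy model are assumptions for illustration, not part of this commit):

```python
import torch
import torch.nn as nn

from torchao.prototype.mx_formats.config import (
    MXFP8Dim1CastKernelChoice,
    MXLinearConfig,
    MXLinearRecipeName,
    ScaleCalculationMode,
)
from torchao.quantization import quantize_  # assumed import path

m = nn.Sequential(nn.Linear(256, 256, bias=False)).cuda().to(torch.bfloat16)

# option 1: use the new recipe, which selects rceil scale calculation
config = MXLinearConfig.from_recipe_name(MXLinearRecipeName.MXFP8_CUBLAS_RCEIL)

# option 2: start from an existing recipe and override the cast kernel and
# scale calculation mode, mirroring what the updated tests do
config = MXLinearConfig.from_recipe_name(MXLinearRecipeName.MXFP8_CUBLAS)
config.mxfp8_cast_kernel_choice = MXFP8Dim1CastKernelChoice.CUDA
config.scale_calculation_mode = ScaleCalculationMode.RCEIL

quantize_(m, config=config)
```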

benchmarks/float8/float8_roofline.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -170,7 +170,7 @@ def get_gemm_times(
     elif float8_recipe_name in ("rowwise", "rowwise_with_gw_hp"):
         scale_a = torch.ones(M, 1, device=device)
         scale_b = torch.ones(1, N, device=device)
-    elif mx_recipe_name == "mxfp8_cublas":
+    elif mx_recipe_name in ("mxfp8_cublas", "mxfp8_cublas_rceil"):
         scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)
         scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)
     else:
```

benchmarks/mx_formats/cast_bench.py

Lines changed: 47 additions & 4 deletions
```diff
@@ -11,6 +11,7 @@
 import triton
 from triton.testing import do_bench

+from torchao.prototype.mx_formats.config import ScaleCalculationMode
 from torchao.prototype.mx_formats.kernels import (
     triton_to_mxfp8_dim1,
 )
@@ -53,14 +54,18 @@ def scale_dim0_dim1_reference(
     return x_hp_d0_normalized, x_hp_d1_normalized.t(), amax_dim0, amax_dim1


-def to_mx_dim0_reference(x_hp, block_size):
-    scale_d0, data_d0 = to_mx(x_hp, torch.float8_e4m3fn, block_size)
+def to_mx_dim0_reference(x_hp, block_size, scaling_mode=ScaleCalculationMode.FLOOR):
+    scale_d0, data_d0 = to_mx(
+        x_hp, torch.float8_e4m3fn, block_size, scaling_mode=scaling_mode
+    )
     return data_d0, scale_d0


-def to_mx_dim1_reference(x_hp, block_size):
+def to_mx_dim1_reference(x_hp, block_size, scaling_mode=ScaleCalculationMode.FLOOR):
     x_hp = x_hp.t().contiguous()
-    scale_d1, data_d1 = to_mx(x_hp, torch.float8_e4m3fn, block_size)
+    scale_d1, data_d1 = to_mx(
+        x_hp, torch.float8_e4m3fn, block_size, scaling_mode=scaling_mode
+    )
     return data_d1.t(), scale_d1

@@ -84,7 +89,9 @@ def run(
         "dim1",
         "dim0_dim1",
         "dim0_mx_floor",
+        "dim0_mx_rceil",
         "dim1_mx_floor",
+        "dim1_mx_rceil",
         "dim1_mx_triton_floor",
         "dim1_mx_cuda_floor",
         "dim1_mx_cuda_rceil",
```
```diff
@@ -165,6 +172,24 @@ def run(
         bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
         bps = (bytes_r + bytes_w) / (time_us / 1e6)

+    elif mode == "dim0_mx_rceil":
+        to_mx_dim0_reference_c = torch.compile(to_mx_dim0_reference)
+        y_d0, s_d0 = to_mx_dim0_reference_c(x, BLOCK_SIZE, ScaleCalculationMode.RCEIL)
+
+        for _ in range(2):
+            __ = to_mx_dim0_reference_c(x, BLOCK_SIZE, ScaleCalculationMode.RCEIL)
+        time_us = benchmark_cuda_function_in_microseconds(
+            lambda x, b: to_mx_dim0_reference_c(x, BLOCK_SIZE, ScaleCalculationMode.RCEIL),
+            x,
+            BLOCK_SIZE,
+        )
+
+        assert y_d0.dtype == torch.float8_e4m3fn
+        assert s_d0.dtype == torch.float8_e8m0fnu
+        bytes_r = x.numel() * bytes_per_el_bf16
+        bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
+        bps = (bytes_r + bytes_w) / (time_us / 1e6)
+
     elif mode == "dim1_mx_floor":
         to_mx_dim1_reference_c = torch.compile(to_mx_dim1_reference)
         y_d1, s_d1 = to_mx_dim1_reference_c(x, BLOCK_SIZE)
```
```diff
@@ -183,6 +208,24 @@ def run(
         bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
         bps = (bytes_r + bytes_w) / (time_us / 1e6)

+    elif mode == "dim1_mx_rceil":
+        to_mx_dim1_reference_c = torch.compile(to_mx_dim1_reference)
+        y_d1, s_d1 = to_mx_dim1_reference_c(x, BLOCK_SIZE, ScaleCalculationMode.RCEIL)
+
+        for _ in range(2):
+            __ = to_mx_dim1_reference_c(x, BLOCK_SIZE, ScaleCalculationMode.RCEIL)
+        time_us = benchmark_cuda_function_in_microseconds(
+            lambda x, b: to_mx_dim1_reference_c(x, BLOCK_SIZE, ScaleCalculationMode.RCEIL),
+            x,
+            BLOCK_SIZE,
+        )
+
+        assert y_d1.dtype == torch.float8_e4m3fn
+        assert s_d1.dtype == torch.float8_e8m0fnu
+        bytes_r = x.numel() * bytes_per_el_bf16
+        bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
+        bps = (bytes_r + bytes_w) / (time_us / 1e6)
+
     elif mode == "dim1_mx_triton_floor":
         y_d1, s_d1 = triton_to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE)

```
test/prototype/mx_formats/test_mx_linear.py

Lines changed: 52 additions & 5 deletions
```diff
@@ -14,6 +14,7 @@
     MXFP8Dim1CastKernelChoice,
     MXLinearConfig,
     MXLinearRecipeName,
+    ScaleCalculationMode,
 )
 from torchao.prototype.mx_formats.constants import (
     DTYPE_FP6_E2M3,
@@ -78,7 +79,18 @@ def run_around_tests():
         MXFP8Dim1CastKernelChoice.CUDA,
     ],
 )
-def test_linear_eager_vs_hp(elem_dtype, bias, input_shape, mxfp8_cast_kernel_choice):
+@pytest.mark.parametrize(
+    "scale_calculation_mode",
+    [
+        ScaleCalculationMode.FLOOR,
+        ScaleCalculationMode.CEIL,
+        ScaleCalculationMode.EVEN,
+        ScaleCalculationMode.RCEIL,
+    ],
+)
+def test_linear_eager_vs_hp(
+    elem_dtype, bias, input_shape, mxfp8_cast_kernel_choice, scale_calculation_mode
+):
     """
     Smoke test for training linear module with mx weight, compares the following:
     * baseline: float32
@@ -94,6 +106,16 @@ def test_linear_eager_vs_hp(elem_dtype, bias, input_shape, mxfp8_cast_kernel_cho
     elif not is_sm_at_least_89():
         pytest.skip("CUDA capability >= 8.9 required for float8 in triton")

+    if mxfp8_cast_kernel_choice == MXFP8Dim1CastKernelChoice.TRITON:
+        if scale_calculation_mode != ScaleCalculationMode.FLOOR:
+            pytest.skip("unsupported configuration")
+    elif mxfp8_cast_kernel_choice == MXFP8Dim1CastKernelChoice.CUDA:
+        if scale_calculation_mode not in (
+            ScaleCalculationMode.FLOOR,
+            ScaleCalculationMode.RCEIL,
+        ):
+            pytest.skip("unsupported configuration")
+
     # elem_dtype is a tuple of (input, weight, gradient) dtypes.
     grad_shape = list(input_shape)
     grad_shape[-1] = 256
@@ -108,6 +130,7 @@ def test_linear_eager_vs_hp(elem_dtype, bias, input_shape, mxfp8_cast_kernel_cho
         elem_dtype_weight_override=elem_dtype[1],
         elem_dtype_grad_output_override=elem_dtype[2],
         mxfp8_cast_kernel_choice=mxfp8_cast_kernel_choice,
+        scale_calculation_mode=scale_calculation_mode,
     )
     quantize_(m_mx, config)

@@ -125,9 +148,9 @@ def test_linear_eager_vs_hp(elem_dtype, bias, input_shape, mxfp8_cast_kernel_cho
     y_ref.backward(g)
     y_mx.backward(g)

-    y_sqnr = compute_error(y_ref, y_mx)
-    w_g_sqnr = compute_error(m[0].weight.grad, getattr(m_mx, "0").weight.grad)
-    x_g_sqnr = compute_error(x_ref.grad, x.grad)
+    y_sqnr = compute_error(y_ref, y_mx).item()
+    w_g_sqnr = compute_error(m[0].weight.grad, getattr(m_mx, "0").weight.grad).item()
+    x_g_sqnr = compute_error(x_ref.grad, x.grad).item()

     if elem_dtype == (torch.float8_e4m3fn, torch.float8_e4m3fn, torch.float8_e4m3fn):
         assert y_sqnr >= 18.0
@@ -229,7 +252,20 @@ def test_activation_checkpointing():
         MXFP8Dim1CastKernelChoice.CUDA,
     ],
 )
-def test_linear_compile(hp_dtype, recipe_name, bias, mxfp8_cast_kernel_choice):
+@pytest.mark.parametrize(
+    "scale_calculation_mode",
+    [
+        ScaleCalculationMode.FLOOR,
+        ScaleCalculationMode.CEIL,
+        # even + compile does not work yet:
+        # https://gist.github.com/vkuzo/1a04845cd503b1c75291aa1ea3bf79c4
+        # ScaleCalculationMode.EVEN,
+        ScaleCalculationMode.RCEIL,
+    ],
+)
+def test_linear_compile(
+    hp_dtype, recipe_name, bias, mxfp8_cast_kernel_choice, scale_calculation_mode
+):
     """
     Verify that compile does not change numerics of MX linear fw + bw
     """
@@ -255,6 +291,16 @@ def test_linear_compile(hp_dtype, recipe_name, bias, mxfp8_cast_kernel_choice):
     if hp_dtype != torch.bfloat16:
         pytest.skip("unsupported configuration")

+    if mxfp8_cast_kernel_choice == MXFP8Dim1CastKernelChoice.TRITON:
+        if scale_calculation_mode != ScaleCalculationMode.FLOOR:
+            pytest.skip("unsupported configuration")
+    elif mxfp8_cast_kernel_choice == MXFP8Dim1CastKernelChoice.CUDA:
+        if scale_calculation_mode not in (
+            ScaleCalculationMode.FLOOR,
+            ScaleCalculationMode.RCEIL,
+        ):
+            pytest.skip("unsupported configuration")
+
     if hp_dtype == torch.bfloat16 and recipe_name != "mxfp8_cublas":
         # TODO(future PR): properly enable float32 + bfloat16 for every
         # recipe, this needs a cleanup of out_dtype (needs to match in-hp-dtype, even
@@ -269,6 +315,7 @@ def test_linear_compile(hp_dtype, recipe_name, bias, mxfp8_cast_kernel_choice):
     )
     config = MXLinearConfig.from_recipe_name(recipe_name)
     config.mxfp8_cast_kernel_choice = mxfp8_cast_kernel_choice
+    config.scale_calculation_mode = scale_calculation_mode

     quantize_(m_mx, config=config)
     m_mx_c = copy.deepcopy(m_mx)
```
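To exercise the new parametrizations locally, a typical invocation would be `pytest test/prototype/mx_formats/test_mx_linear.py -k "test_linear_eager_vs_hp or test_linear_compile" -v` (standard pytest filtering by test name; not part of this commit).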

torchao/prototype/mx_formats/config.py

Lines changed: 54 additions & 0 deletions
```diff
@@ -46,10 +46,34 @@ class MXFP8Dim1CastKernelChoice(Enum):
 class MXLinearRecipeName(Enum):
     MXFP8_EMULATED = "mxfp8_emulated"
     MXFP8_CUBLAS = "mxfp8_cublas"
+    MXFP8_CUBLAS_RCEIL = "mxfp8_cublas_rceil"
     MXFP4_EMULATED = "mxfp4_emulated"
     MXFP4_CUTLASS = "mxfp4_cutlass"


+class ScaleCalculationMode(Enum):
+    """
+    Enum representing the different methods for calculating MX block scaling.
+    There are four methods available:
+    FLOOR: recommended by the OCP MX Spec 1.0, uses X = 2^floor(log2(max_abs(v)) - max_exp).
+        It can overflow for large values and is a poor fit for gradient quantization.
+    CEIL: avoids overflow, but small values may be flushed to 0 by the larger scaling factor.
+        It uses X = 2^ceil(log2(max_abs(v)) - max_exp).
+    EVEN: a trade-off between FLOOR and CEIL. It uses X = 2^(floor(log2(rounding(max_abs(v))) - max_exp)).
+        It provides better accuracy for MX4 training compared to FLOOR and CEIL.
+    RCEIL: applies ceil to the ratio of max_abs(v) and max_pos, as described in
+        https://docs.nvidia.com/cuda/cublas/index.html#d-block-quantization,
+        section "Computing scaling and conversion factors for FP8 with UE8M0 scales".
+    """
+
+    FLOOR = "floor"
+    CEIL = "ceil"
+    # Note: `even` does not work with torch.compile yet:
+    # https://gist.github.com/vkuzo/1a04845cd503b1c75291aa1ea3bf79c4
+    EVEN = "even"
+    RCEIL = "rceil"
+
+
 def _validate_elem_dtype(elem_dtype):
     assert elem_dtype in SUPPORTED_ELEM_DTYPES, (
         f"elem_dtype: expected one of {SUPPORTED_ELEM_DTYPES}, got {elem_dtype}"
```
```diff
@@ -75,6 +99,22 @@ def _validate_gemm_kernel_choice(gemm_kernel_choice, block_size, elem_dtype):
     )


+def _validate_mxfp8_cast_kernel_choice(
+    mxfp8_cast_kernel_choice, scale_calculation_mode
+):
+    if mxfp8_cast_kernel_choice == MXFP8Dim1CastKernelChoice.TRITON:
+        assert scale_calculation_mode == ScaleCalculationMode.FLOOR, (
+            f"unsupported ScaleCalculationMode value {scale_calculation_mode} for dim1 triton cast"
+        )
+    elif mxfp8_cast_kernel_choice == MXFP8Dim1CastKernelChoice.CUDA:
+        assert scale_calculation_mode in (
+            ScaleCalculationMode.FLOOR,
+            ScaleCalculationMode.RCEIL,
+        ), (
+            f"unsupported ScaleCalculationMode value {scale_calculation_mode} for dim1 cuda cast"
+        )
+
+
 @dataclass
 class MXLinearConfig(AOBaseConfig):
     # block size for scaling, default is 32 to match
```
```diff
@@ -104,6 +144,8 @@ class MXLinearConfig(AOBaseConfig):
     # If True, uses a custom triton kernel for fp4 dequantize
     use_fp4_custom_triton_dequant_kernel: bool = False

+    scale_calculation_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR
+
     def __post_init__(self):
         _validate_elem_dtype(self.elem_dtype)
         _validate_gemm_kernel_choice(
```
```diff
@@ -115,6 +157,9 @@ def __post_init__(self):
         if self.elem_dtype_grad_output_override is not None:
             _validate_elem_dtype(self.elem_dtype_grad_output_override)
             assert self.gemm_kernel_choice == MXGemmKernelChoice.EMULATED, "unsupported"
+        _validate_mxfp8_cast_kernel_choice(
+            self.mxfp8_cast_kernel_choice, self.scale_calculation_mode
+        )

     @staticmethod
     def from_recipe_name(
```
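A quick illustration of what the new validation enforces, as a hypothetical snippet (constructing the config directly, with only the two relevant fields set):

```python
from torchao.prototype.mx_formats.config import (
    MXFP8Dim1CastKernelChoice,
    MXLinearConfig,
    ScaleCalculationMode,
)

# the triton dim1 cast kernel only supports FLOOR scaling, so this should
# trip the new assert in __post_init__
try:
    MXLinearConfig(
        mxfp8_cast_kernel_choice=MXFP8Dim1CastKernelChoice.TRITON,
        scale_calculation_mode=ScaleCalculationMode.RCEIL,
    )
except AssertionError as e:
    print(e)  # unsupported ScaleCalculationMode value ... for dim1 triton cast
```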
```diff
@@ -134,7 +179,14 @@ def from_recipe_name(
         if recipe_name is MXLinearRecipeName.MXFP8_EMULATED:
             return MXLinearConfig()
         elif recipe_name is MXLinearRecipeName.MXFP8_CUBLAS:
+            # TODO(future PR): default to CUDA dim1 kernel
             return MXLinearConfig(gemm_kernel_choice=MXGemmKernelChoice.CUBLAS)
+        elif recipe_name is MXLinearRecipeName.MXFP8_CUBLAS_RCEIL:
+            return MXLinearConfig(
+                gemm_kernel_choice=MXGemmKernelChoice.CUBLAS,
+                mxfp8_cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
+                scale_calculation_mode=ScaleCalculationMode.RCEIL,
+            )
         elif recipe_name is MXLinearRecipeName.MXFP4_EMULATED:
             return MXLinearConfig(elem_dtype=torch.float4_e2m1fn_x2)
         elif recipe_name is MXLinearRecipeName.MXFP4_CUTLASS:
```
```diff
@@ -160,4 +212,6 @@ def short_str(self) -> str:
         s += f", mxfp8_cast_kernel_choice={self.mxfp8_cast_kernel_choice.value}"
         if self.use_fp4_custom_triton_dequant_kernel:
             s += ", use_fp4_custom_triton_dequant_kernel=True"
+        if self.scale_calculation_mode != ScaleCalculationMode.FLOOR:
+            s += f", scale_calculation_mode={self.scale_calculation_mode.value}"
         return s
```
