Commit d05e54f

reorganize MX inference code (#2616)
* Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned]
1 parent 344d201 commit d05e54f

5 files changed: +692, -246 lines changed
Lines changed: 170 additions & 0 deletions
@@ -0,0 +1,170 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import copy

import pytest
import torch
import torch.nn as nn

from torchao.prototype.mx_formats.config import (
    MXGemmKernelChoice,
)
from torchao.prototype.mx_formats.inference_workflow import (
    MXFPInferenceConfig,
    NVFP4InferenceConfig,
    NVFP4MMConfig,
)
from torchao.quantization import quantize_
from torchao.quantization.utils import compute_error
from torchao.testing.utils import skip_if_rocm
from torchao.utils import (
    TORCH_VERSION_AT_LEAST_2_8,
    is_sm_at_least_89,
    is_sm_at_least_100,
)

torch.manual_seed(2)

if not TORCH_VERSION_AT_LEAST_2_8:
    pytest.skip("Unsupported PyTorch version", allow_module_level=True)


# source: https://stackoverflow.com/a/22638709
@pytest.fixture(autouse=True)
def run_around_tests():
    # 1. before test - set up (currently do nothing)
    # 2. run test
    yield
    # 3. after test - teardown
    torch._dynamo.reset()


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
    not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
)
@pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn, torch.float4_e2m1fn_x2])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("compile", [True, False])
@torch.no_grad()
@skip_if_rocm(
    "ROCm float4 gemm require gfx950"
)  # TODO(future): deploy gfx950 in ROCM CI
@pytest.mark.skipif(not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required")
def test_inference_workflow(elem_dtype, bias: bool, compile: bool):
    """
    Smoke test for inference compile
    """
    # TODO(future): figure out why these CUDA capability conditions are not properly
    # applied when inside `pytest.mark.skipif` for this test
    if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        if not is_sm_at_least_89():
            pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
    elif elem_dtype == torch.float4_e2m1fn_x2:
        if not is_sm_at_least_100():
            pytest.skip("CUDA capability >= 10.0 required for float4 gemm")

    m = nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16, device="cuda")
    m_mx = copy.deepcopy(m)
    kernel_choice = (
        MXGemmKernelChoice.CUTLASS
        if elem_dtype == torch.float4_e2m1fn_x2
        else MXGemmKernelChoice.CUBLAS
    )
    config = MXFPInferenceConfig(
        activation_dtype=elem_dtype,
        weight_dtype=elem_dtype,
        gemm_kernel_choice=kernel_choice,
    )
    quantize_(m_mx, config=config)
    if compile:
        m_mx = torch.compile(m_mx, fullgraph=True)

    x = torch.randn(128, 32, device="cuda", dtype=torch.bfloat16)
    y_ref = m(x)
    y_mx = m_mx(x)
    sqnr = compute_error(y_ref, y_mx)
    SQNR_THRESHOLD = 25.0 if elem_dtype == torch.float8_e4m3fn else 15.0
    assert sqnr >= SQNR_THRESHOLD, (
        f"Got a sqnr of {sqnr} for {elem_dtype} and bias={bias}"
    )


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
    not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
)
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("compile", [True, False])
@pytest.mark.parametrize(
    "mm_config", [NVFP4MMConfig.DYNAMIC, NVFP4MMConfig.WEIGHT_ONLY]
)
@pytest.mark.parametrize("inpt_dtype", [torch.bfloat16, torch.float32])
@pytest.mark.parametrize("use_triton_kernel", [True, False])
@pytest.mark.parametrize(
    "shapes",
    [
        (128, 64, 256),
        (256, 128, 512),
        (145, 64, 256),
        (128, 96, 256),
        (128, 160, 256),
        (64, 64, 256),
        (200, 192, 256),
    ],
    ids=lambda s: f"{s[0]}x{s[1]}x{s[2]}",
)
@torch.no_grad()
@skip_if_rocm("ROCm float4 gemm require gfx950")
def test_inference_workflow_nvfp4(
    bias: bool,
    compile: bool,
    mm_config: NVFP4MMConfig,
    inpt_dtype: torch.dtype,
    use_triton_kernel: bool,
    shapes: tuple,
):
    """
    Test NVFP4 recipe with scale_dtype=float8_e4m3fn and block_size=16
    Tests both DYNAMIC and WEIGHT_ONLY mm_config modes
    """
    # DYNAMIC mode requires SM100+, but WEIGHT_ONLY works on older GPUs
    if mm_config == NVFP4MMConfig.DYNAMIC and not is_sm_at_least_100():
        pytest.skip("CUDA capability >= 10.0 required for DYNAMIC float4 gemm")

    if bias and inpt_dtype == torch.float32:
        pytest.xfail("Bias is not supported when module weight is in fp32")

    if mm_config == NVFP4MMConfig.WEIGHT_ONLY and compile:
        pytest.skip("TODO: NVFP4MMConfig.WEIGHT_ONLY currently errors w/ compile")
    batch_size, in_features, out_features = shapes

    m = nn.Linear(in_features, out_features, bias=bias, dtype=inpt_dtype, device="cuda")
    m_mx = copy.deepcopy(m)

    config = NVFP4InferenceConfig(
        mm_config=mm_config, use_triton_kernel=use_triton_kernel
    )
    quantize_(m_mx, config=config)

    if compile:
        m_mx = torch.compile(m_mx, fullgraph=True, backend="aot_eager")

    x = torch.randn(batch_size, in_features, device="cuda", dtype=inpt_dtype)
    y_ref = m(x)
    y_mx = m_mx(x)
    sqnr = compute_error(y_ref, y_mx)

    if mm_config == NVFP4MMConfig.WEIGHT_ONLY:
        SQNR_THRESHOLD = 18.0
    else:
        SQNR_THRESHOLD = 15.0

    assert y_mx.dtype == inpt_dtype, f"Got {y_mx.dtype} for inpt_dtype={inpt_dtype}"
    assert sqnr >= SQNR_THRESHOLD, (
        f"Got a sqnr of {sqnr} for NVFP4 recipe with bias={bias}, mm_config={mm_config}"
    )
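
For orientation, the snippet below is a minimal usage sketch distilled from the tests above; it is not part of the commit itself. It shows how the configs now exposed by torchao.prototype.mx_formats.inference_workflow are applied to a module via quantize_. The model shape and dtype choices are illustrative assumptions, and a CUDA device satisfying the capability checks the tests gate on (SM 10.0 for the MX gemm path) is assumed.

# Minimal usage sketch (illustrative, not part of the diff above). Assumes a CUDA
# device that satisfies the capability checks exercised by the tests.
import torch
import torch.nn as nn

from torchao.prototype.mx_formats.config import MXGemmKernelChoice
from torchao.prototype.mx_formats.inference_workflow import MXFPInferenceConfig
from torchao.quantization import quantize_

model = nn.Linear(32, 128, bias=False, dtype=torch.bfloat16, device="cuda")

# Configure MXFP8 inference with the cuBLAS MX gemm, mirroring the first test.
config = MXFPInferenceConfig(
    activation_dtype=torch.float8_e4m3fn,
    weight_dtype=torch.float8_e4m3fn,
    gemm_kernel_choice=MXGemmKernelChoice.CUBLAS,
)
quantize_(model, config=config)  # quantizes the module's weights in place
model = torch.compile(model, fullgraph=True)  # optional, as exercised by the test

with torch.no_grad():
    y = model(torch.randn(128, 32, device="cuda", dtype=torch.bfloat16))

The NVFP4 path is analogous: construct an NVFP4InferenceConfig(mm_config=..., use_triton_kernel=...) and pass it to quantize_ in the same way, as the second test does.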
