Commit f321317

Update base for Update on "Remove old QAT APIs"
**Summary:** As a follow-up to #2641, which deprecated the old QAT APIs in 0.13.0, we now remove them in the next release, 0.15.0. Fixes #2630.

**Test Plan:** CI

[ghstack-poisoned]
2 parents b0a4f39 + 6c24a7a commit f321317

12 files changed: +982 -27 lines

.github/workflows/doc_build.yml

Lines changed: 1 addition & 0 deletions
@@ -43,6 +43,7 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install torch
+          python -m pip install setuptools==78.1.1 --force-reinstall
           python -m pip install -e .
           pip install -r dev-requirements.txt
           python -m pip install -r docs/requirements.txt

benchmarks/float8/float8_roofline.py

Lines changed: 7 additions & 1 deletion
@@ -245,8 +245,12 @@ def run(
     bf16_gemm_time_sympy = get_gemm_time_sympy(
         M, K, N, torch.bfloat16, None, None, None
     )
+    lowp_input_dtype = torch.float8_e4m3fn
+    if mx_recipe_name == "mxfp4_cutlass":
+        lowp_input_dtype = torch.float4_e2m1fn_x2
+
     fp8_gemm_time_sympy = get_gemm_time_sympy(
-        M, K, N, torch.float8_e4m3fn, float8_recipe_name, mx_recipe_name, None
+        M, K, N, lowp_input_dtype, float8_recipe_name, mx_recipe_name, None
     )
     print("bf16_gemm_time_sympy", bf16_gemm_time_sympy)
     print("fp8_gemm_time_sympy", fp8_gemm_time_sympy)

@@ -304,6 +308,8 @@ def run(
         rb_fp8_gemm_ratio = -1

     if do_benchmarks:
+        assert mx_recipe_name != "mxfp4_cutlass", "unsupported"
+
         # TODO(future): make the bf16 gemm times exactly match the e2e
         # benchmarks, there is a slight deviation, probably related to gemm
         # operand memory formats/transpositions below not exactly matching

test/dtypes/test_affine_quantized_tensor_parallel.py

Lines changed: 1 addition & 1 deletion
@@ -115,7 +115,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         dn_quant(up_quant(example_input))

         mesh = self.build_device_mesh()
-        mesh.device_type = "cuda"
+        mesh._device_type = "cuda"

         # Shard the models
         up_dist = self.colwise_shard(up_quant, mesh)

test/prototype/mx_formats/test_inference_workflow.py

Lines changed: 27 additions & 1 deletion
@@ -6,10 +6,12 @@

 import copy
 import tempfile
+from contextlib import contextmanager

 import pytest
 import torch
 import torch.nn as nn
+from torch.profiler import ProfilerActivity, profile

 from torchao.prototype.mx_formats.config import (
     MXGemmKernelChoice,

@@ -44,6 +46,23 @@ def run_around_tests():
     torch._dynamo.reset()


+@contextmanager
+def cuda_kernel_profiler(kernel_pattern):
+    """Context manager for profiling CUDA kernels."""
+    result = {"found": False, "kernel_names": []}
+
+    with profile(activities=[ProfilerActivity.CUDA]) as prof:
+        yield result
+
+    kernel_names = [
+        evt.name
+        for evt in prof.events()
+        if evt.device_type == torch.autograd.DeviceType.CUDA and evt.name
+    ]
+    result["kernel_names"] = kernel_names
+    result["found"] = any(kernel_pattern in name for name in kernel_names)
+
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(
     not torch_version_at_least("2.8.0"), reason="torch.compile requires PyTorch 2.8+"

@@ -178,7 +197,14 @@ def test_inference_workflow_nvfp4(

     x = torch.randn(batch_size, in_features, device="cuda", dtype=inpt_dtype)
     y_ref = m(x)
-    y_mx = m_mx(x)
+
+    if use_triton_kernel and mm_config != NVFP4MMConfig.WEIGHT_ONLY:
+        with cuda_kernel_profiler("quantize_nvfp4_triton_kernel") as result:
+            y_mx = m_mx(x)
+        assert result["found"], "Expected quantize_nvfp4 kernel to be found"
+    else:
+        y_mx = m_mx(x)
+
     sqnr = compute_error(y_ref, y_mx)

     if mm_config == NVFP4MMConfig.WEIGHT_ONLY:
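
Note: the cuda_kernel_profiler helper added above records the CUDA kernels launched inside its body and reports whether any kernel name contains a given substring; that is how the test asserts the Triton NVFP4 quantization kernel actually ran. A minimal standalone sketch of the same pattern (the matmul workload and the local import path are illustrative assumptions, and it needs a CUDA-enabled PyTorch build):

import torch

# Assumes the helper from test/prototype/mx_formats/test_inference_workflow.py
# is importable locally; the workload below is only for illustration.
from test_inference_workflow import cuda_kernel_profiler

a = torch.randn(512, 512, device="cuda", dtype=torch.bfloat16)
b = torch.randn(512, 512, device="cuda", dtype=torch.bfloat16)

# The dict yielded by the context manager is filled in after the block exits.
with cuda_kernel_profiler("gemm") as result:
    _ = a @ b

print(result["found"])         # True if any launched kernel name contains "gemm"
print(result["kernel_names"])  # every CUDA kernel name seen by the profiler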

test/quantization/pt2e/test_quantize_pt2e_qat.py

Lines changed: 0 additions & 6 deletions
@@ -686,12 +686,6 @@ def get_source_fn(node: torch.fx.Node):
         self.assertNotEqual(get_source_fn(second_conv), get_source_fn(second_relu))
         self.assertNotEqual(get_source_fn(first_relu), get_source_fn(second_relu))

-        # Assert that "backbone" exists only in the second set of conv and relu's partition
-        self.assertTrue("backbone" not in get_source_fn(first_conv))
-        self.assertTrue("backbone" not in get_source_fn(first_relu))
-        self.assertTrue("backbone" in get_source_fn(second_conv))
-        self.assertTrue("backbone" in get_source_fn(second_relu))
-
     def test_qat_conv_bn_bias_derived_qspec(self):
         m = self._get_conv_bn_model()
         example_inputs = self.example_inputs

test/test_ops.py

Lines changed: 93 additions & 0 deletions
@@ -40,7 +40,14 @@
 except RuntimeError:
     pytest.skip("torchao.ops not available")

+from torchao.quantization import PerGroup, PerRow, PerTensor
+from torchao.quantization.quant_primitives import (
+    _choose_scale_float8,
+    _dequantize_affine_float8,
+    _quantize_affine_float8,
+)
 from torchao.quantization.utils import (
+    get_block_size,
     get_groupwise_affine_qparams,
     groupwise_affine_dequantize_tensor_from_qparams,
     groupwise_affine_quantize_tensor_from_qparams,

@@ -901,5 +908,91 @@ def _test_scaled_embedding_bag_cpu_helper(
     torch.testing.assert_close(refe_out, test_out, atol=1e-5, rtol=1e-5)


+@pytest.mark.skipif(
+    "CPU" not in torch._C._dispatch_dump("torchao::_scaled_embedding_bag"),
+    reason="cpp kernels not built",
+)
+@pytest.mark.parametrize(
+    "multi_hot, batch_size, vector_size, index_type",
+    EMBEDINGBAG_TEST_PARAMS,
+    ids=str,
+)
+def test_scaled_embedding_bag_int8_cpu(multi_hot, batch_size, vector_size, index_type):
+    _test_scaled_embedding_bag_cpu_helper(
+        multi_hot, batch_size, vector_size, index_type, torch.int8
+    )
+
+
+@pytest.mark.skipif(
+    "CPU" not in torch._C._dispatch_dump("torchao::_scaled_embedding_bag"),
+    reason="cpp kernels not built",
+)
+@pytest.mark.parametrize(
+    "multi_hot, batch_size, vector_size, index_type",
+    EMBEDINGBAG_TEST_PARAMS,
+    ids=str,
+)
+def test_scaled_embedding_bag_fp8_cpu(multi_hot, batch_size, vector_size, index_type):
+    _test_scaled_embedding_bag_cpu_helper(
+        multi_hot, batch_size, vector_size, index_type, torch.float8_e4m3fn
+    )
+
+
+@pytest.mark.skipif(
+    "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_prepack_cpu")
+    or "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"),
+    reason="cpp kernels not built",
+)
+@pytest.mark.skipif(
+    not torch_version_at_least("2.6.0"), reason="Test only enabled for 2.6+"
+)
+@pytest.mark.parametrize("shape", [(64, 64), (256, 256)])
+@pytest.mark.parametrize("bs", [1, 160])
+@pytest.mark.parametrize("out_dtype", [torch.float, torch.bfloat16, torch.half])
+@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("x_granularity", [PerTensor(), PerRow(), PerGroup(128)])
+@pytest.mark.parametrize("w_granularity", [PerTensor(), PerRow(), PerGroup(128)])
+def test_float8_linear_cpu(shape, bs, out_dtype, bias, x_granularity, w_granularity):
+    in_feature, out_feature = shape
+    if isinstance(x_granularity, PerGroup):
+        if x_granularity.group_size >= in_feature:
+            return
+        if not isinstance(w_granularity, PerGroup):
+            return
+    if isinstance(w_granularity, PerGroup):
+        if w_granularity.group_size >= in_feature:
+            return
+    m = torch.nn.Linear(in_feature, out_feature, bias=bias).eval()
+    b = m.bias
+    x = torch.randn(bs, in_feature)
+    x_block_size = get_block_size(x.shape, x_granularity)
+    x_scale = _choose_scale_float8(
+        x,
+        float8_dtype=torch.float8_e4m3fn,
+        block_size=x_block_size,
+    )
+    x_fp8 = _quantize_affine_float8(x, x_scale, torch.float8_e4m3fn)
+
+    w = m.weight.detach()
+    w_block_size = get_block_size(w.shape, w_granularity)
+    w_scale = _choose_scale_float8(
+        w,
+        float8_dtype=torch.float8_e4m3fn,
+        block_size=w_block_size,
+    )
+    w_fp8 = _quantize_affine_float8(w, w_scale, torch.float8_e4m3fn)
+
+    x_dq = _dequantize_affine_float8(x_fp8, x_scale)
+    w_dq = _dequantize_affine_float8(w_fp8, w_scale)
+    ref = torch.nn.functional.linear(x_dq, w_dq, b).to(out_dtype)
+
+    packed_w, packed_scale = torch.ops.torchao.float8_linear_prepack_cpu(w_fp8, w_scale)
+    y = torch.ops.torchao.float8_linear_cpu(
+        x_fp8, x_scale, packed_w, packed_scale, b, out_dtype
+    )
+
+    torch.testing.assert_close(y, ref, atol=1e-2, rtol=1e-2)
+
+
 if __name__ == "__main__":
     pytest.main(sys.argv)
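
Note: the new test_float8_linear_cpu exercises the CPU float8 linear path end to end: choose a scale for the requested granularity, quantize activation and weight to torch.float8_e4m3fn, prepack the weight, run the fused op, and compare against a dequantize-then-linear reference. A condensed sketch of that flow, reusing only the calls that appear in the diff (per-row granularity and the shapes are illustrative; the torchao C++ CPU kernels must be built):

import torch

from torchao.quantization import PerRow
from torchao.quantization.quant_primitives import (
    _choose_scale_float8,
    _dequantize_affine_float8,
    _quantize_affine_float8,
)
from torchao.quantization.utils import get_block_size

bs, in_feature, out_feature = 4, 64, 64
out_dtype = torch.float32
x = torch.randn(bs, in_feature)
w = torch.randn(out_feature, in_feature)
b = torch.randn(out_feature)

# Quantize activation and weight to float8 with per-row scales.
x_scale = _choose_scale_float8(
    x, float8_dtype=torch.float8_e4m3fn, block_size=get_block_size(x.shape, PerRow())
)
x_fp8 = _quantize_affine_float8(x, x_scale, torch.float8_e4m3fn)
w_scale = _choose_scale_float8(
    w, float8_dtype=torch.float8_e4m3fn, block_size=get_block_size(w.shape, PerRow())
)
w_fp8 = _quantize_affine_float8(w, w_scale, torch.float8_e4m3fn)

# Reference: dequantize, then run a regular linear.
ref = torch.nn.functional.linear(
    _dequantize_affine_float8(x_fp8, x_scale),
    _dequantize_affine_float8(w_fp8, w_scale),
    b,
).to(out_dtype)

# Fused CPU path: prepack the weight, then run the float8 linear op.
packed_w, packed_scale = torch.ops.torchao.float8_linear_prepack_cpu(w_fp8, w_scale)
y = torch.ops.torchao.float8_linear_cpu(
    x_fp8, x_scale, packed_w, packed_scale, b, out_dtype
)

torch.testing.assert_close(y, ref, atol=1e-2, rtol=1e-2)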
