|
40 | 40 | except RuntimeError:
|
41 | 41 | pytest.skip("torchao.ops not available")
|
42 | 42 |
|
| 43 | +from torchao.quantization import PerGroup, PerRow, PerTensor |
| 44 | +from torchao.quantization.quant_primitives import ( |
| 45 | + _choose_scale_float8, |
| 46 | + _dequantize_affine_float8, |
| 47 | + _quantize_affine_float8, |
| 48 | +) |
43 | 49 | from torchao.quantization.utils import (
|
| 50 | + get_block_size, |
44 | 51 | get_groupwise_affine_qparams,
|
45 | 52 | groupwise_affine_dequantize_tensor_from_qparams,
|
46 | 53 | groupwise_affine_quantize_tensor_from_qparams,
|
@@ -905,5 +912,61 @@ def test_scaled_embedding_bag_cpu(multi_hot, batch_size, vector_size, index_type
|
905 | 912 | torch.testing.assert_close(refe_out, test_out, atol=1e-5, rtol=1e-5)
|
906 | 913 |
|
907 | 914 |
|
@pytest.mark.skipif(
    "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_prepack_cpu")
    or "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"),
    reason="cpp kernels not built",
)
@pytest.mark.skipif(
    not torch_version_at_least("2.6.0"), reason="Test only enabled for 2.6+"
)
@pytest.mark.parametrize("shape", [(64, 64), (256, 256)])
@pytest.mark.parametrize("bs", [1, 160])
@pytest.mark.parametrize("out_dtype", [torch.float, torch.bfloat16, torch.half])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("x_granularity", [PerTensor(), PerRow(), PerGroup(128)])
@pytest.mark.parametrize("w_granularity", [PerTensor(), PerRow(), PerGroup(128)])
def test_float8_linear_cpu(shape, bs, out_dtype, bias, x_granularity, w_granularity):
    """Check the torchao CPU float8 linear kernel against a dequant+linear reference.

    Quantizes activations and weights to ``float8_e4m3fn`` with the given
    granularities, runs the prepacked fused CPU op, and compares the result to
    ``F.linear`` applied to the dequantized operands.
    """
    in_feature, out_feature = shape
    # Unsupported parameter combinations are reported as skips; a bare
    # `return` here would make pytest count them as PASSED, silently
    # inflating the pass count for configurations that were never exercised.
    if isinstance(x_granularity, PerGroup):
        if x_granularity.group_size >= in_feature:
            pytest.skip("group size must be smaller than in_features")
        if not isinstance(w_granularity, PerGroup):
            pytest.skip("per-group activations require per-group weights")
    if isinstance(w_granularity, PerGroup) and w_granularity.group_size >= in_feature:
        pytest.skip("group size must be smaller than in_features")

    m = torch.nn.Linear(in_feature, out_feature, bias=bias).eval()
    b = m.bias
    x = torch.randn(bs, in_feature)

    # Quantize activations to fp8 with scales chosen per the x granularity.
    x_block_size = get_block_size(x.shape, x_granularity)
    x_scale = _choose_scale_float8(
        x,
        float8_dtype=torch.float8_e4m3fn,
        block_size=x_block_size,
    )
    x_fp8 = _quantize_affine_float8(x, x_scale, torch.float8_e4m3fn)

    # Quantize weights the same way with the w granularity.
    w = m.weight.detach()
    w_block_size = get_block_size(w.shape, w_granularity)
    w_scale = _choose_scale_float8(
        w,
        float8_dtype=torch.float8_e4m3fn,
        block_size=w_block_size,
    )
    w_fp8 = _quantize_affine_float8(w, w_scale, torch.float8_e4m3fn)

    # Reference path: dequantize both operands and run a plain linear, so the
    # only difference from the op under test is the fused fp8 compute path.
    x_dq = _dequantize_affine_float8(x_fp8, x_scale)
    w_dq = _dequantize_affine_float8(w_fp8, w_scale)
    ref = torch.nn.functional.linear(x_dq, w_dq, b).to(out_dtype)

    # Path under test: prepack the fp8 weight, then run the fused CPU kernel.
    packed_w, packed_scale = torch.ops.torchao.float8_linear_prepack_cpu(w_fp8, w_scale)
    y = torch.ops.torchao.float8_linear_cpu(
        x_fp8, x_scale, packed_w, packed_scale, b, out_dtype
    )

    torch.testing.assert_close(y, ref, atol=1e-2, rtol=1e-2)
| 970 | + |
# Allow running this test file directly (python <file>.py) in addition to
# normal pytest collection; sys.argv is forwarded so pytest flags still work.
if __name__ == "__main__":
    pytest.main(sys.argv)
|
0 commit comments