Skip to content

Commit 5e90c47

Browse files
authored
[CPU] Add ops for float8 linear (#3052)
* [CPU] Add ops for float8 linear
* Refine code
1 parent b47f1a3 commit 5e90c47

File tree

4 files changed

+851
-0
lines changed

4 files changed

+851
-0
lines changed

test/test_ops.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,14 @@
4040
except RuntimeError:
4141
pytest.skip("torchao.ops not available")
4242

43+
from torchao.quantization import PerGroup, PerRow, PerTensor
44+
from torchao.quantization.quant_primitives import (
45+
_choose_scale_float8,
46+
_dequantize_affine_float8,
47+
_quantize_affine_float8,
48+
)
4349
from torchao.quantization.utils import (
50+
get_block_size,
4451
get_groupwise_affine_qparams,
4552
groupwise_affine_dequantize_tensor_from_qparams,
4653
groupwise_affine_quantize_tensor_from_qparams,
@@ -905,5 +912,61 @@ def test_scaled_embedding_bag_cpu(multi_hot, batch_size, vector_size, index_type
905912
torch.testing.assert_close(refe_out, test_out, atol=1e-5, rtol=1e-5)
906913

907914

915+
# Skip unless the C++ CPU kernels for both the prepack and the linear op were
# actually built and registered for the CPU dispatch key.
@pytest.mark.skipif(
    "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_prepack_cpu")
    or "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"),
    reason="cpp kernels not built",
)
@pytest.mark.skipif(
    not torch_version_at_least("2.6.0"), reason="Test only enabled for 2.6+"
)
@pytest.mark.parametrize("shape", [(64, 64), (256, 256)])
@pytest.mark.parametrize("bs", [1, 160])
@pytest.mark.parametrize("out_dtype", [torch.float, torch.bfloat16, torch.half])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("x_granularity", [PerTensor(), PerRow(), PerGroup(128)])
@pytest.mark.parametrize("w_granularity", [PerTensor(), PerRow(), PerGroup(128)])
def test_float8_linear_cpu(shape, bs, out_dtype, bias, x_granularity, w_granularity):
    """Check torchao's fused float8 CPU linear against a dequantize-then-linear reference.

    Both activation (x) and weight (w) are quantized to float8_e4m3fn with scales
    chosen per the given granularity (per-tensor / per-row / per-group).  The
    reference output is computed by dequantizing both operands and running
    ``torch.nn.functional.linear`` in high precision, then casting to
    ``out_dtype``; the fused op must match within loose fp8 tolerances.
    """
    in_feature, out_feature = shape
    # Skip unsupported granularity combinations rather than fail:
    # - a group size must be strictly smaller than the reduction dim (in_feature);
    # - per-group activations are only exercised together with per-group weights.
    # NOTE(review): the nesting below is reconstructed from a whitespace-mangled
    # diff — confirm against the committed file that the
    # `not isinstance(w_granularity, PerGroup)` guard is inside the x-PerGroup branch.
    if isinstance(x_granularity, PerGroup):
        if x_granularity.group_size >= in_feature:
            return
        if not isinstance(w_granularity, PerGroup):
            return
    if isinstance(w_granularity, PerGroup):
        if w_granularity.group_size >= in_feature:
            return
    m = torch.nn.Linear(in_feature, out_feature, bias=bias).eval()
    b = m.bias  # None when bias=False; passed through to both reference and op
    x = torch.randn(bs, in_feature)
    # Quantize the activation: granularity -> block size -> scale -> fp8 tensor.
    x_block_size = get_block_size(x.shape, x_granularity)
    x_scale = _choose_scale_float8(
        x,
        float8_dtype=torch.float8_e4m3fn,
        block_size=x_block_size,
    )
    x_fp8 = _quantize_affine_float8(x, x_scale, torch.float8_e4m3fn)

    # Same quantization recipe for the (detached) weight.
    w = m.weight.detach()
    w_block_size = get_block_size(w.shape, w_granularity)
    w_scale = _choose_scale_float8(
        w,
        float8_dtype=torch.float8_e4m3fn,
        block_size=w_block_size,
    )
    w_fp8 = _quantize_affine_float8(w, w_scale, torch.float8_e4m3fn)

    # Reference: dequantize both operands and run a plain linear, so the
    # comparison isolates the fused kernel's matmul/epilogue, not quantization
    # error (both paths see the identical fp8 inputs).
    x_dq = _dequantize_affine_float8(x_fp8, x_scale)
    w_dq = _dequantize_affine_float8(w_fp8, w_scale)
    ref = torch.nn.functional.linear(x_dq, w_dq, b).to(out_dtype)

    # Op under test: prepack the fp8 weight/scale, then run the fused CPU linear.
    packed_w, packed_scale = torch.ops.torchao.float8_linear_prepack_cpu(w_fp8, w_scale)
    y = torch.ops.torchao.float8_linear_cpu(
        x_fp8, x_scale, packed_w, packed_scale, b, out_dtype
    )

    # Loose tolerances: fp8 inputs plus a possibly lower-precision out_dtype.
    torch.testing.assert_close(y, ref, atol=1e-2, rtol=1e-2)
969+
970+
908971
# Allow running this test module directly (forwards CLI args to pytest).
if __name__ == "__main__":
    pytest.main(sys.argv)

0 commit comments

Comments
 (0)