introduce new int8 quantization API #3241
base: main
Changes from 1 commit
The commit adds a single new test file (+266 lines; the rendered diff below is truncated where the page cut off):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import copy
import unittest

import torch
from torch.testing._internal import common_utils

from torchao.quantization import (
    Int8DynamicActivationInt8WeightConfig,
    Int8WeightOnlyConfig,
    PerRow,
    PerTensor,
    quantize_,
)
from torchao.quantization.quant_primitives import MappingType, choose_qparams_affine
from torchao.quantization.quantize_.workflows.int8.int8_tensor import (
    Int8Tensor,
    QuantizeTensorToInt8Kwargs,
)
from torchao.quantization.utils import compute_error
from torchao.testing.utils import TorchAOIntegrationTestCase


# TODO: Refactor after https://github.com/pytorch/ao/pull/2729 is merged
class ToyTwoLinearModel(torch.nn.Module):
    def __init__(
        self,
        input_dim,
        hidden_dim,
        output_dim,
        has_bias=False,
        dtype=None,
        device=None,
    ):
        super().__init__()
        self.dtype = dtype
        self.device = device
        self.linear1 = torch.nn.Linear(
            input_dim, hidden_dim, bias=has_bias, dtype=dtype, device=device
        )
        self.linear2 = torch.nn.Linear(
            hidden_dim, output_dim, bias=has_bias, dtype=dtype, device=device
        )

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x


@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@common_utils.instantiate_parametrized_tests
class TestInt8Tensor(TorchAOIntegrationTestCase):
    def setUp(self):
        super().setUp()

        self.test_shape = (4, 3)
        self.dtype = torch.bfloat16
        self.batch_size = 32
        self.int8_min = -128
        self.int8_max = 127

        torch.manual_seed(42)
        self.weight_fp = torch.randn(*self.test_shape, dtype=self.dtype)
        self.input_fp = torch.randn(*self.test_shape, dtype=self.dtype)
        self.bias = torch.randn(self.test_shape[0], dtype=self.dtype)
        self.block_size = list(self.test_shape)

    def test_creation_and_attributes(self):
        """Test tensor creation, dtypes, and ranges"""
        tensor = Int8Tensor.from_hp(self.weight_fp, self.block_size)

        self.assertEqual(tensor.shape, self.test_shape)
        self.assertEqual(tensor.qdata.dtype, torch.int8)
        self.assertTrue(
            torch.all(tensor.qdata >= self.int8_min)
            and torch.all(tensor.qdata <= self.int8_max)
        )

    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
    @common_utils.parametrize(
        "sizes",
        [
            ((128,), 256, 128),
        ],
    )
    @common_utils.parametrize(
        "config",
        [
            Int8DynamicActivationInt8WeightConfig(version=2),
            Int8WeightOnlyConfig(version=2),
        ],
    )
    def test_int8_linear_quantization_accuracy(
        self,
        dtype: torch.dtype,
        sizes: tuple,
        config,
    ):
        """Test quantization preserves reasonable accuracy"""
        M, N, K = sizes
        input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")

        # Create a toy model with two linear layers
        m = ToyTwoLinearModel(K, N, K).eval().to(dtype).to("cuda")
        m_q = copy.deepcopy(m)

        # Quantize
        quantize_(m_q, config)

        output_original = m(input_tensor)
        output_quantized = m_q(input_tensor)

        error = compute_error(output_original, output_quantized)
        assert error > 20, (
            f"Quantization quality is too low, SQNR: {error}dB (expected > 20dB)"
        )

    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
    def test_quantization_shapes(self, dtype):
        # ... (diff truncated in the rendered page)

    @common_utils.parametrize("mode", ["dynamic", "weight-only"])
    # ... (diff truncated in the rendered page)
```
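For context, here is a minimal sketch of what `Int8Tensor.from_hp` does conceptually, built from the primitives the test file itself imports (`choose_qparams_affine`, `MappingType`); the PR's actual implementation may differ, and `quantize_affine` is assumed to be available from the same `quant_primitives` module:

```python
# Sketch (not the PR's implementation): per-tensor symmetric int8 quantization
# using torchao's affine-quantization primitives.
import torch
from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_affine,
    quantize_affine,  # assumed counterpart to choose_qparams_affine
)

weight_fp = torch.randn(4, 3, dtype=torch.bfloat16)
block_size = [4, 3]  # one block covering the whole tensor, i.e. per-tensor scale

# Derive scale/zero_point for symmetric int8 mapping, then quantize.
scale, zero_point = choose_qparams_affine(
    weight_fp, MappingType.SYMMETRIC, block_size, torch.int8
)
qdata = quantize_affine(weight_fp, block_size, scale, zero_point, torch.int8)
assert qdata.dtype == torch.int8  # values land in [-128, 127] by construction
```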
Reviewer comment on the `"mode"` parametrization: also I feel it might be better to not add static quant in this PR, and instead add both the tensor support and the config support for static quant in a separate PR.

Author reply: Okay, I wasn't sure about removing the static flags before (although they're not fully implemented), but I feel a smaller PR is always better. I will remove `static_scale` and all the related support.
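For readers following the thread, the distinction being deferred, sketched under the assumption that `static_scale` named a pre-calibrated activation scale (variable names here are hypothetical):

```python
import torch

x = torch.randn(32, 128)  # an activation batch arriving at a linear layer

# Dynamic activation quant: the scale is recomputed from each batch at runtime.
dynamic_scale = x.abs().amax() / 127.0
x_int8 = torch.clamp(torch.round(x / dynamic_scale), -128, 127).to(torch.int8)

# Static activation quant (deferred to a later PR): a fixed, pre-calibrated
# scale would be reused for every batch instead, e.g.
# static_scale = calibrated_abs_max / 127.0  # hypothetical calibration output
```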
Reviewer comment: it should be per row, I think? Compare ao/torchao/quantization/quant_api.py, line 1343 (at 6815e57):

```python
group_size = weight.shape[-1]
```
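A short illustration of the per-row point (a sketch, not code from the PR): with `group_size = weight.shape[-1]`, each output row of the weight gets its own scale, which is the granularity the imported `PerRow` expresses, i.e. a block size of `(1, weight.shape[-1])`:

```python
import torch

weight = torch.randn(256, 128)  # (out_features, in_features)

# group_size = weight.shape[-1] puts one quantization group on each row:
per_row_block_size = (1, weight.shape[-1])

# Per-row symmetric scales: one scale per output channel, shape (256, 1).
per_row_scale = weight.abs().amax(dim=-1, keepdim=True) / 127.0

# Contrast with per-tensor: a single scalar scale for the whole weight.
per_tensor_scale = weight.abs().amax() / 127.0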
Reviewer comment: maybe check SQNR with `compute_error` instead of a hardcoded absolute value?
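For reference, `compute_error` in `torchao.quantization.utils` reports the signal-to-quantization-noise ratio (SQNR) in decibels between a reference output and its quantized counterpart; a self-contained equivalent (behavior assumed to match, modulo numerical details):

```python
import torch

def sqnr_db(ref: torch.Tensor, test: torch.Tensor) -> torch.Tensor:
    # Ratio of signal power to quantization-noise power, in dB. Higher means
    # the quantized output tracks the reference more closely; the test above
    # asserts > 20 dB.
    return 20 * torch.log10(torch.linalg.norm(ref) / torch.linalg.norm(ref - test))
```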