Add Int8Tensor for clearer interface #3038
base: main
Conversation
Introduce a new tensor subclass API for int8 quantization with a clearer interface. The main change can be summarized as follows:
- Old: Complex affine transform (AffineQuantizedTensor) with separate layout handling
- New: Direct int8 tensor with qdata, scale, and zero_point attributes

Test plan: test/quantization/quantize_/workflows/int8/test_int8_tensor.py
Future plan: Implement block-wise quantization using the `block_size` parameter
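For readers skimming the PR, here is a minimal sketch of what the new interface looks like from the outside; apart from the qdata, scale, zero_point, and block_size attributes named above, everything else is illustrative rather than copied from the implementation.

```python
import torch
from typing import Optional

# Illustrative sketch only: the real class derives from TorchAOBaseTensor and
# participates in torch dispatch; this just shows the attributes named above.
class Int8TensorSketch:
    def __init__(
        self,
        qdata: torch.Tensor,       # int8 payload
        scale: torch.Tensor,       # dequantization scale
        zero_point: torch.Tensor,  # quantization zero point
        block_size: Optional[list[int]] = None,  # future: block-wise granularity
    ):
        self.qdata = qdata
        self.scale = scale
        self.zero_point = zero_point
        self.block_size = block_size

    def dequantize(self) -> torch.Tensor:
        # plain affine dequantization: (q - zero_point) * scale
        return (self.qdata.to(torch.float32) - self.zero_point) * self.scale
```

The linear override then reads qdata and scale directly instead of going through a separate layout object, which is the main simplification over AffineQuantizedTensor.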
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3038
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes.
# TODO: Implement block-wise quantization using block_size
class Int8PlainInt8Tensor(TorchAOBaseTensor):
nit: we can just use Int8Tensor if it's plain, since that's the default
can you also add a version 2 and expose this tensor through ao/torchao/quantization/quant_api.py (Line 1497 and Line 1752 in 8525185)?
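If it helps, this is roughly the shape of what is being asked for, assuming the int8 config grows a version field the way the float8 config did; the helper names below (_int8_weight_quantize, _legacy_int8_quantize, from_hp) are placeholders, not the actual quant_api.py symbols.

```python
# Hypothetical sketch: route version 2 of the int8 config to the new tensor
# inside the quant_api transform, keeping version 1 on the existing
# AffineQuantizedTensor path. None of these helper names are the actual
# quant_api.py symbols.
def _int8_weight_quantize(weight, config):
    if getattr(config, "version", 1) == 2:
        # new path: build the plain int8 tensor subclass directly
        return Int8PlainInt8Tensor.from_hp(weight, block_size=config.block_size)
    # old path: existing AffineQuantizedTensor-based flow
    return _legacy_int8_quantize(weight, config)
```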
    args[2] if len(args) > 2 else None,
)

if isinstance(input_tensor, Int8PlainInt8Tensor):
we also need to quantize input_tensor in this function now, please check
ao/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
Lines 263 to 266 in 9d88c16
if act_quant_kwargs is not None:
    input_tensor = _choose_quant_func_and_quantize_tensor(
        input_tensor, act_quant_kwargs
    )
This ticket looks good, but since I need to learn more about granularity, it might be delayed. May I split this into separate PRs?
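To make the reviewer's suggestion concrete, the float8 pattern above could carry over to the int8 path roughly as sketched below; whether the int8 tensor stores act_quant_kwargs and can reuse _choose_quant_func_and_quantize_tensor is an assumption here, not something this PR already does.

```python
# Sketch of porting the float8 pattern above into the Int8PlainInt8Tensor
# linear override. Assumes act_quant_kwargs lives on the weight tensor and
# that _choose_quant_func_and_quantize_tensor is reusable for int8.
def _int8_linear_impl(input_tensor, weight_tensor, bias):
    act_quant_kwargs = getattr(weight_tensor, "act_quant_kwargs", None)
    if act_quant_kwargs is not None and not isinstance(
        input_tensor, Int8PlainInt8Tensor
    ):
        # dynamically quantize the high-precision activation before the
        # int8 x int8 matmul
        input_tensor = _choose_quant_func_and_quantize_tensor(
            input_tensor, act_quant_kwargs
        )
    ...  # existing int8 x int8 matmul and rescale follows
```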
x_int32 = input_tensor.qdata.to(torch.int32)
w_int32 = weight_tensor.qdata.to(torch.int32).t()

result = torch.mm(x_int32.view(-1, x_int32.size(-1)), w_int32)
scale = input_tensor.scale.view(-1, 1) * weight_tensor.scale.unsqueeze(0)
result = result.to(scale.dtype) * scale
result = result.view(*input_tensor.shape[:-1], -1)
this is not the same as
ao/torchao/dtypes/uintx/plain_layout.py
Lines 269 to 315 in 122b307

def _linear_int8_act_int8_weight_check(input_tensor, weight_tensor, bias):
    return (
        isinstance(input_tensor, AffineQuantizedTensor)
        and _aqt_is_int8_reduced_range(input_tensor)
        and isinstance(weight_tensor, AffineQuantizedTensor)
        and _aqt_is_int8(weight_tensor)
        and input_tensor.dtype == weight_tensor.dtype
        and isinstance(input_tensor._layout, PlainLayout)
        and isinstance(weight_tensor._layout, PlainLayout)
    )

def _linear_int8_act_int8_weight_impl(input_tensor, weight_tensor, bias):
    #
    # 1. do the matrix form of dot(X_i, W_j)
    #
    #
    # 2. rescale the output
    #
    # in cases with large matrices, y_dot_int32 can grow sufficiently
    # large that y_dot_int32 * a float16 scale is greater than the maximum
    # value of a float 16, (which results in a value of inf even if multiplying
    # by the other scale would bring it within the expected range)
    x_vals_int8 = input_tensor.tensor_impl.int_data
    x_scales = input_tensor.tensor_impl.scale
    w_vals_int8_t = weight_tensor.tensor_impl.int_data.contiguous().t()
    w_scales = weight_tensor.tensor_impl.scale
    tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
    x_scales_dtype = x_scales.dtype
    # Cast fp16 scale to float to avoid overflow in int_scaled_matmul
    intermediate_dtype = torch.float if x_scales_dtype == torch.half else x_scales_dtype
    y_dot_scaled = int_scaled_matmul(
        tmp, w_vals_int8_t, x_scales.reshape(-1, 1).to(intermediate_dtype)
    )
    y_dot_scaled = y_dot_scaled.to(x_scales_dtype)
    y = (y_dot_scaled * w_scales).reshape(
        *x_vals_int8.shape[:-1], y_dot_scaled.shape[-1]
    )
    # can downcast only at the very end
    output_dtype = input_tensor.dtype
    y = y.to(output_dtype)
    if bias is not None:
        y += bias
    return y
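Ported onto the new qdata/scale attributes, the overflow-safe version of the matmul above would look roughly like the following; this is a sketch of the reference logic, not a claim about which kernel the final implementation should call.

```python
import torch
from torchao.kernel import int_scaled_matmul  # same op the reference impl above uses

# Sketch: the reference logic expressed against the new qdata/scale attributes.
# The key point is upcasting fp16 activation scales to fp32 so that large
# int32 dot products do not overflow float16.
def _int8_act_int8_weight_impl_sketch(input_tensor, weight_tensor, bias):
    x_int8 = input_tensor.qdata
    x_scales = input_tensor.scale
    w_int8_t = weight_tensor.qdata.contiguous().t()
    w_scales = weight_tensor.scale

    tmp = x_int8.reshape(-1, x_int8.shape[-1])
    intermediate_dtype = (
        torch.float if x_scales.dtype == torch.half else x_scales.dtype
    )
    y = int_scaled_matmul(tmp, w_int8_t, x_scales.reshape(-1, 1).to(intermediate_dtype))
    y = (y.to(x_scales.dtype) * w_scales).reshape(*x_int8.shape[:-1], y.shape[-1])
    y = y.to(input_tensor.dtype)  # downcast only at the very end
    if bias is not None:
        y += bias
    return y
```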
can you add a test to check the kernel that's used, similar to
def test_expected_gpu_kernel_fbgemm(self):
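A possible shape for such a test, in the spirit of test_expected_gpu_kernel_fbgemm; the config spelling (version=2) and the string being checked for ("_int_mm") are assumptions about what the new path will dispatch to.

```python
import torch
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck
from torchao.quantization import Int8DynamicActivationInt8WeightConfig, quantize_

# Method to add to TestInt8Tensor. The config version and the checked kernel
# name are placeholders for whatever the new Int8Tensor path lowers to.
def test_expected_gpu_kernel_int8(self):
    M, K, N = 32, 128, 256
    m = torch.nn.Sequential(
        torch.nn.Linear(K, N, device="cuda", dtype=torch.bfloat16)
    )
    quantize_(m, Int8DynamicActivationInt8WeightConfig(version=2))
    m = torch.compile(m)
    x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
    _, code = run_and_get_code(m, x)
    FileCheck().check("_int_mm").run(code[0])
```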
    result = result.to(scale.dtype) * scale
    result = result.view(*input_tensor.shape[:-1], -1)
else:
    # FP × INT8 (static)
also this is the code for weight only quant I think:
ao/torchao/dtypes/uintx/plain_layout.py
Line 250 in 122b307
def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias):
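For reference, that weight-only path is essentially an fp matmul against the upcast int8 weight followed by a per-output-channel rescale; adapted to the new qdata/scale attributes it would look roughly like the sketch below, assuming symmetric weight quantization (zero_point of zero) as in the reference path.

```python
import torch

# Sketch of the fp-activation x int8-weight (weight-only) branch using the new
# attributes: upcast the int8 weight to the activation dtype, run a regular
# matmul, then rescale per output channel.
def _fp_act_int8_weight_impl_sketch(input_tensor, weight_tensor, bias):
    orig_dtype = input_tensor.dtype
    w_t = weight_tensor.qdata.t().to(orig_dtype)
    x = input_tensor.reshape(-1, input_tensor.shape[-1])
    y = torch.mm(x, w_t) * weight_tensor.scale.to(orig_dtype)
    y = y.reshape(*input_tensor.shape[:-1], y.shape[-1])
    if bias is not None:
        y += bias
    return y
```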
    block_size (Optional[list[int]]): block size for quantization granularity
    """

    kernel_preference: KernelPreference = KernelPreference.AUTO
seems like there are no multiple kernel preferences right now, right? if so, we can remove this for now
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
class TestInt8Tensor(TorchAOIntegrationTestCase):
for test, maybe try to follow https://github.com/pytorch/ao/blob/main/test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py for now and also add some tests for slicing?
ao/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
Lines 158 to 216 in 8e2ca35
def test_slice(self, granularity):
    config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
    dtype = torch.bfloat16
    device = "cuda"
    dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device)
    dummy1 = torch.nn.Linear(256, 64, bias=False, dtype=dtype, device=device)
    dummy1.weight = torch.nn.Parameter(
        dummy.weight.narrow(0, 0, 64), requires_grad=False
    )
    dummy2 = torch.nn.Linear(128, 256, dtype=dtype, device=device)
    dummy2.weight = torch.nn.Parameter(
        dummy.weight.narrow(1, 0, 128), requires_grad=False
    )
    quantize_(dummy, config)
    weight1 = dummy.weight.clone().narrow(0, 0, 64)
    weight2 = dummy.weight.clone().narrow(1, 0, 128)
    self.assertEqual(
        weight1.qdata,
        dummy.weight.qdata.narrow(0, 0, 64),
    )
    self.assertEqual(
        weight2.qdata,
        dummy.weight.qdata.narrow(1, 0, 128),
    )
    if isinstance(granularity, PerRow):
        self.assertEqual(
            weight1.scale,
            dummy.weight.scale.narrow(0, 0, 64),
        )
        self.assertEqual(
            weight2.scale,
            dummy.weight.scale,
        )
    else:
        self.assertEqual(
            weight1.scale,
            dummy.weight.scale,
        )
        self.assertEqual(
            weight2.scale,
            dummy.weight.scale,
        )
    # check for sliced weight, before and after float8 quantization
    # does not differ too much
    input = torch.randn(2, 256, dtype=dtype, device=device)
    res_ref = dummy1(input)
    dummy.weight = torch.nn.Parameter(weight1.contiguous(), requires_grad=False)
    res = dummy(input)
    sqnr = compute_error(res, res_ref)
    self.assertTrue(sqnr > 25, f"sqnr: {sqnr}")
    input = torch.randn(2, 128, dtype=dtype, device=device)
    res_ref = dummy2(input)
    dummy.weight = torch.nn.Parameter(weight2.contiguous(), requires_grad=False)
    res = dummy(input)
    sqnr = compute_error(res, res_ref)
    self.assertTrue(sqnr > 15, f"sqnr: {sqnr}")

def test_slice_preserves_aliasing(self, granularity):
Summary
Introduce a new tensor subclass API for int8 quantization with a clearer interface.
The main change can be summarized as follows:
- Old: Complex affine transform (AffineQuantizedTensor) with separate layout handling
- New: Direct int8 tensor with qdata, scale, and zero_point attributes

Related Issue/PR: #3012 (comment) #2752

Test plan
test/quantization/quantize_/workflows/int8/test_int8_tensor.py

Future plan
Implement block-wise quantization using `block_size`
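Since `block_size` is only a future plan, here is an illustration of what block-wise granularity would mean for the scale layout; the helper below is purely explanatory and not part of this PR.

```python
import torch

# Purely explanatory: what block-wise granularity via block_size could mean.
# With block_size = [1, 128], each row of an [N, K] weight gets one scale per
# 128-column block instead of a single per-row scale.
def blockwise_scales(w: torch.Tensor, block_size: list[int]) -> torch.Tensor:
    N, K = w.shape
    bn, bk = block_size
    blocks = w.reshape(N // bn, bn, K // bk, bk)
    # symmetric int8: scale = max |w| within each block / 127
    return blocks.abs().amax(dim=(1, 3)) / 127.0
```

For example, blockwise_scales(torch.randn(256, 512), [1, 128]) returns a [256, 4] scale tensor, one scale per 128-column block of each row; the current per-row scheme corresponds to block_size = [1, K].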