Skip to content

Conversation

namgyu-youn
Copy link
Contributor

@namgyu-youn namgyu-youn commented Sep 21, 2025

Summary

Introduce new tensor subclass API for int8 quantization with clearer interface.

The main change can be summarized to the following:

  • Old: Complex affine transform (AffineQuantizedTensor) with separate layout handling
  • New: Direct int8 tensor with scaling factor and zero point

Related Issue/PR: #3012 (comment) #2752

Test plan

test/quantization/quantize_/workflows/int8/test_int8_tensor.py

Future plan

Implement block-wise quantization using block_size

Introduce new tensor subclass API for int8 quantization with clearer interface.

The main change can be summarized to the following:
- Old: Complex affine transform (AffineQuantizedTensor) with separate layout handling
- New: Direct int8 tensor with qdata, scale, and zero_point attributes

Test plan:
test/quantization/quantize_/workflows/int8/test_int8_tensor.py

Future plan:
Implement block-wise quantization using `block_size` parameter
Copy link

pytorch-bot bot commented Sep 21, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3038

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Sep 21, 2025


# TODO: Implement block-wise quantization using block_size
class Int8PlainInt8Tensor(TorchAOBaseTensor):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: we can just use Int8Tensor if it's plain, since that's the default

@jerryzh168
Copy link
Contributor

can you add a version 2 and expose this tensor through

class Int8DynamicActivationInt8WeightConfig(AOBaseConfig):
? similar to
class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):

args[2] if len(args) > 2 else None,
)

if isinstance(input_tensor, Int8PlainInt8Tensor):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we also need to quantize input_tensor in this function now, please check

if act_quant_kwargs is not None:
input_tensor = _choose_quant_func_and_quantize_tensor(
input_tensor, act_quant_kwargs
)

@namgyu-youn namgyu-youn changed the title Add Int8PlainInt8Tensor for clearer interface Add Int8Tensor for clearer interface Sep 23, 2025
@namgyu-youn
Copy link
Contributor Author

can you add a version 2 and expose this tensor through

class Int8DynamicActivationInt8WeightConfig(AOBaseConfig):

? similar to

class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):

This ticket looks good, but since I need to learn more about granularity, it might be delayed. May I split this into separate PRs?

Comment on lines +176 to +182
x_int32 = input_tensor.qdata.to(torch.int32)
w_int32 = weight_tensor.qdata.to(torch.int32).t()

result = torch.mm(x_int32.view(-1, x_int32.size(-1)), w_int32)
scale = input_tensor.scale.view(-1, 1) * weight_tensor.scale.unsqueeze(0)
result = result.to(scale.dtype) * scale
result = result.view(*input_tensor.shape[:-1], -1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is not the same as

def _linear_int8_act_int8_weight_check(input_tensor, weight_tensor, bias):
return (
isinstance(input_tensor, AffineQuantizedTensor)
and _aqt_is_int8_reduced_range(input_tensor)
and isinstance(weight_tensor, AffineQuantizedTensor)
and _aqt_is_int8(weight_tensor)
and input_tensor.dtype == weight_tensor.dtype
and isinstance(input_tensor._layout, PlainLayout)
and isinstance(weight_tensor._layout, PlainLayout)
)
def _linear_int8_act_int8_weight_impl(input_tensor, weight_tensor, bias):
#
# 1. do the matrix form of dot(X_i, W_j)
#
#
# 2. rescale the output
#
# in cases with large matrices, y_dot_int32 can grow sufficiently
# large that y_dot_int32 * a float16 scale is greater than the maximum
# value of a float 16, (which results in a value of inf even if multiplying
# by the other scale would bring it within the expected range)
x_vals_int8 = input_tensor.tensor_impl.int_data
x_scales = input_tensor.tensor_impl.scale
w_vals_int8_t = weight_tensor.tensor_impl.int_data.contiguous().t()
w_scales = weight_tensor.tensor_impl.scale
tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
x_scales_dtype = x_scales.dtype
# Cast fp16 scale to float to avoid overflow in int_scaled_matmul
intermediate_dtype = torch.float if x_scales_dtype == torch.half else x_scales_dtype
y_dot_scaled = int_scaled_matmul(
tmp, w_vals_int8_t, x_scales.reshape(-1, 1).to(intermediate_dtype)
)
y_dot_scaled = y_dot_scaled.to(x_scales_dtype)
y = (y_dot_scaled * w_scales).reshape(
*x_vals_int8.shape[:-1], y_dot_scaled.shape[-1]
)
# can downcast only at the very end
output_dtype = input_tensor.dtype
y = y.to(output_dtype)
if bias is not None:
y += bias
return y
?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add a test to check the kernel that's used similar to

def test_expected_gpu_kernel_fbgemm(self):
as well?

result = result.to(scale.dtype) * scale
result = result.view(*input_tensor.shape[:-1], -1)
else:
# FP × INT8 (static)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also this is the code for weight only quant I think:

def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias):

block_size (Optional[list[int]]): block size for quantization granularity
"""

kernel_preference: KernelPreference = KernelPreference.AUTO
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems like no multiple kernel preferences right now right? if so, we can remove this for now



@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
class TestInt8Tensor(TorchAOIntegrationTestCase):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for test, maybe try to follow https://github.com/pytorch/ao/blob/main/test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py for now and also add some tests for slicing?

def test_slice(self, granularity):
config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
dtype = torch.bfloat16
device = "cuda"
dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device)
dummy1 = torch.nn.Linear(256, 64, bias=False, dtype=dtype, device=device)
dummy1.weight = torch.nn.Parameter(
dummy.weight.narrow(0, 0, 64), requires_grad=False
)
dummy2 = torch.nn.Linear(128, 256, dtype=dtype, device=device)
dummy2.weight = torch.nn.Parameter(
dummy.weight.narrow(1, 0, 128), requires_grad=False
)
quantize_(dummy, config)
weight1 = dummy.weight.clone().narrow(0, 0, 64)
weight2 = dummy.weight.clone().narrow(1, 0, 128)
self.assertEqual(
weight1.qdata,
dummy.weight.qdata.narrow(0, 0, 64),
)
self.assertEqual(
weight2.qdata,
dummy.weight.qdata.narrow(1, 0, 128),
)
if isinstance(granularity, PerRow):
self.assertEqual(
weight1.scale,
dummy.weight.scale.narrow(0, 0, 64),
)
self.assertEqual(
weight2.scale,
dummy.weight.scale,
)
else:
self.assertEqual(
weight1.scale,
dummy.weight.scale,
)
self.assertEqual(
weight2.scale,
dummy.weight.scale,
)
# check for sliced weight, before and after float8 quantization
# does not differ too much
input = torch.randn(2, 256, dtype=dtype, device=device)
res_ref = dummy1(input)
dummy.weight = torch.nn.Parameter(weight1.contiguous(), requires_grad=False)
res = dummy(input)
sqnr = compute_error(res, res_ref)
self.assertTrue(sqnr > 25, f"sqnr: {sqnr}")
input = torch.randn(2, 128, dtype=dtype, device=device)
res_ref = dummy2(input)
dummy.weight = torch.nn.Parameter(weight2.contiguous(), requires_grad=False)
res = dummy(input)
sqnr = compute_error(res, res_ref)
self.assertTrue(sqnr > 15, f"sqnr: {sqnr}")
and
def test_slice_preserves_aliasing(self, granularity):

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants