Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion docs/source/quantization_overview.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ First we want to lay out the torchao stack::

Quantization Algorithms/Flows: weight only/dynamic/static quantization, hqq, awq, gptq etc.
---------------------------------------------------------------------------------------------
Quantized Tensors (derived dtypes): Int4Tensor, Int4PreshuffledTensor, Float8Tensor
Quantized Tensors (derived dtypes): Int4Tensor, Int4PreshuffledTensor, Int8Tensor, Float8Tensor
---------------------------------------------------------------------------------------------
Quantization Primitive Ops/Efficient Kernels: matmul, quantize, dequantize
---------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -88,6 +88,8 @@ So in general we structure Tensor subclasses by dervied dtpype and packing forma
- scaled int4
- preshuffled (special format to optimize for loading)
- float8 act + int4 weight dynamic quantization and int4 weight only quantization
* - Int8Tensor
- plain

.. note::
We don't have granularity specific tensor subclasses, i.e. no Float8RowwiseTensor or Float8BlockwiseTensor, all granularities are implemented in the same Tensor, we typically use a general `block_size` attribute to distinguish between different granularities, and each Tensor is allowed to support only a subset of all possible granularity options.
Expand Down
266 changes: 266 additions & 0 deletions test/quantization/quantize_/workflows/int8/test_int8_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,266 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import copy
import unittest

import torch
from torch.testing._internal import common_utils

from torchao.quantization import (
Int8DynamicActivationInt8WeightConfig,
Int8WeightOnlyConfig,
PerRow,
PerTensor,
quantize_,
)
from torchao.quantization.quant_primitives import MappingType, choose_qparams_affine
from torchao.quantization.quantize_.workflows.int8.int8_tensor import (
Int8Tensor,
QuantizeTensorToInt8Kwargs,
)
from torchao.quantization.utils import compute_error
from torchao.testing.utils import TorchAOIntegrationTestCase


# TODO: Refactor after https://github.com/pytorch/ao/pull/2729 is merged
class ToyTwoLinearModel(torch.nn.Module):
def __init__(
self,
input_dim,
hidden_dim,
output_dim,
has_bias=False,
dtype=None,
device=None,
):
super().__init__()
self.dtype = dtype
self.device = device
self.linear1 = torch.nn.Linear(
input_dim, hidden_dim, bias=has_bias, dtype=dtype, device=device
)
self.linear2 = torch.nn.Linear(
hidden_dim, output_dim, bias=has_bias, dtype=dtype, device=device
)

def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
return x


@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@common_utils.instantiate_parametrized_tests
class TestInt8Tensor(TorchAOIntegrationTestCase):
def setUp(self):
super().setUp()

self.test_shape = (4, 3)
self.dtype = torch.bfloat16
self.batch_size = 32
self.int8_min = -128
self.int8_max = 127

torch.manual_seed(42)
self.weight_fp = torch.randn(*self.test_shape, dtype=self.dtype)
self.input_fp = torch.randn(*self.test_shape, dtype=self.dtype)
self.bias = torch.randn(self.test_shape[0], dtype=self.dtype)
self.block_size = list(self.test_shape)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel we probably don't need these, it's also easier for people to follow to define everything / most of things in the test itself


def test_creation_and_attributes(self):
"""Test tensor creation, dtypes, and ranges"""
tensor = Int8Tensor.from_hp(self.weight_fp, self.block_size)

self.assertEqual(tensor.shape, self.test_shape)
self.assertEqual(tensor.qdata.dtype, torch.int8)
self.assertTrue(
torch.all(tensor.qdata >= self.int8_min)
and torch.all(tensor.qdata <= self.int8_max)
)

@common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
@common_utils.parametrize(
"sizes",
[
((128,), 256, 128),
],
)
@common_utils.parametrize(
"config",
[
Int8DynamicActivationInt8WeightConfig(version=2),
Int8WeightOnlyConfig(version=2),
],
)
def test_int8_linear_quantization_accuracy(
self,
dtype: torch.dtype,
sizes: tuple,
config,
):
"""Test quantization preserves reasonable accuracy"""
M, N, K = sizes
input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")

# Create a linear layer
m = ToyTwoLinearModel(K, N, K).eval().to(dtype).to("cuda")
m_q = copy.deepcopy(m)

# Quantize
quantize_(m_q, config)

output_original = m(input_tensor)
output_quantized = m_q(input_tensor)

error = compute_error(output_original, output_quantized)
assert error > 20, (
f"Quantization quality is too low, SQNR: {error}dB (expected > {20}dB)"
)

@common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
def test_quantization_shapes(self, dtype):
Copy link
Contributor

@jerryzh168 jerryzh168 Oct 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this seems to be a combination of two tests, one for dynamic quant one for static quant, can you use something like this:

@common_utils.parametrize("mode", ["dynamic", "weight-only"])

also I feel it might be better to not add static quant in this PR, and in a separate PR add both the tensor support and config support for static quant

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, not sure to remove static flags (although its not fully implemented) before, but small PR should be always better I feel. I will remove static_scale and all those supports.

"""Test static and dynamic quantization output shapes"""
K, N = 128, 64
weight = torch.randn(N, K, dtype=dtype, device="cuda")
input_tensor = torch.randn(self.batch_size, K, dtype=dtype, device="cuda")

# Dynamic quantization (runtime scale computation)
dynamic_tensor = Int8Tensor.from_hp(weight, block_size=[N, K])

# Static quantization (pre-computed scale)
act_scale, _ = choose_qparams_affine(
input=input_tensor,
mapping_type=MappingType.SYMMETRIC,
block_size=(input_tensor.shape[0], K),
target_dtype=torch.int8,
quant_min=self.int8_min,
quant_max=self.int8_max,
scale_dtype=dtype,
zero_point_dtype=torch.int8,
)

# Static quantization (with pre-computed scale)
static_tensor = Int8Tensor.from_hp(
weight,
block_size=[N, K],
act_quant_kwargs=QuantizeTensorToInt8Kwargs(
block_size=[input_tensor.shape[0], K],
static_scale=act_scale,
),
)

dynamic_output = torch.nn.functional.linear(input_tensor, dynamic_tensor)
static_output = torch.nn.functional.linear(input_tensor, static_tensor)

expected_shape = (self.batch_size, N)
self.assertEqual(dynamic_output.shape, expected_shape)
self.assertEqual(static_output.shape, expected_shape)
self.assertEqual(dynamic_output.dtype, dtype)
self.assertEqual(static_output.dtype, dtype)

@unittest.skip("granularity parameter not supported in current API")
@common_utils.parametrize("granularity", [PerTensor(), PerRow()])
def test_slice_preserves_aliasing(self, granularity):
slice_size = 512
tensor_size = 1024

config = Int8DynamicActivationInt8WeightConfig(
granularity=granularity, version=2
)
l = torch.nn.Linear(tensor_size, tensor_size).to("cuda").to(torch.bfloat16)
l.weight = torch.nn.Parameter(
torch.zeros(tensor_size, tensor_size, dtype=torch.bfloat16, device="cuda")
)
quantize_(l, config)
param = l.weight
param_data = param.data
param_data = param_data.narrow(0, 0, slice_size)
# Making sure the aliasing is preserved in sliced quantized Tensor
assert param.data.qdata.data_ptr() == param_data.qdata.data_ptr()
assert param.data.scale.data_ptr() == param_data.scale.data_ptr()

@common_utils.parametrize(
"config",
[
Int8DynamicActivationInt8WeightConfig(version=2),
Int8WeightOnlyConfig(version=2),
],
)
@common_utils.parametrize("device", ["cpu", "cuda"])
@common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
def test_slice(self, config, device, dtype):
"""Test tensor slicing"""
tensor_size = 256
slice_sizes = (64, 128)

dummy = torch.nn.Linear(
tensor_size, tensor_size, bias=False, dtype=dtype, device=device
)
quantize_(dummy, config)

weight1 = dummy.weight.clone().narrow(0, 0, slice_sizes[0])
weight2 = dummy.weight.clone().narrow(1, 0, slice_sizes[1])

self.assertEqual(weight1.qdata, dummy.weight.qdata.narrow(0, 0, slice_sizes[0]))
self.assertEqual(weight2.qdata, dummy.weight.qdata.narrow(1, 0, slice_sizes[1]))

# Int8DynamicActivationInt8WeightConfig uses per-row (PerRow)
# Int8WeightOnlyConfig uses per-tensor (PerTensor)
Copy link
Contributor

@jerryzh168 jerryzh168 Nov 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it should be per row I think?

group_size = weight.shape[-1]

if isinstance(config, Int8DynamicActivationInt8WeightConfig):
# PerRow: dim 0 slicing affects scale, dim 1 doesn't
self.assertEqual(
weight1.scale, dummy.weight.scale.narrow(0, 0, slice_sizes[0])
)
self.assertEqual(weight2.scale, dummy.weight.scale)
else:
# PerTensor: scale unchanged by slicing
self.assertEqual(weight1.scale, dummy.weight.scale)
self.assertEqual(weight2.scale, dummy.weight.scale)
with self.assertRaises(NotImplementedError):
_ = dummy.weight[::2]

def test_index_select(self):
"""test that `x_0 = x[0]` works when `x` is a 2D `Int8Tensor`."""
N, K = 256, 512
x = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)
x_int8 = Int8Tensor.from_hp(x, block_size=[N, K])
x_int8_0 = x_int8[0]
torch.testing.assert_close(
x_int8.dequantize()[0], x_int8_0.dequantize(), atol=0, rtol=0
)

def test_invalid_input_handling(self):
"""Test input validation with specific error types"""
invalid_tensor = torch.randn(5)
incompatible_block_size = [1]

with self.assertRaises(
ValueError, msg="Should reject incompatible tensor dimensions"
):
Int8Tensor.from_hp(invalid_tensor, incompatible_block_size)

with self.assertRaises(
ValueError, msg="Should reject mismatched block size dimensions"
):
Int8Tensor.from_hp(self.weight_fp, [1])

def test_dequantization_accuracy(self):
"""Test dequantization accuracy separately"""
test_data = torch.tensor([[1.0, -1.0]], dtype=torch.bfloat16)
tensor = Int8Tensor.from_hp(test_data, [1, 2])

dequantized = tensor.dequantize()
self.assertEqual(dequantized.shape, test_data.shape)
self.assertLess(
torch.abs(dequantized - test_data).max().item(),
0.1,
msg=f"Dequantization error exceeds tolerance of {0.1}",
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe check sqnr with compute_error instead of hardcoded absolute value?



if __name__ == "__main__":
common_utils.run_tests()
2 changes: 2 additions & 0 deletions torchao/quantization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@
Int4PreshuffledTensor,
Int4Tensor,
Int4TilePackedTo4dTensor,
Int8Tensor,
IntxOpaqueTensor,
IntxUnpackedToInt8Tensor,
)
Expand Down Expand Up @@ -168,6 +169,7 @@
"IntxOpaqueTensor",
"IntxUnpackedToInt8Tensor",
"Int4TilePackedTo4dTensor",
"Int8Tensor",
"Float8Tensor",
"Int4OpaqueTensor",
# smooth quant - subject to change
Expand Down
Loading