Merged
Changes from 2 commits
9 changes: 6 additions & 3 deletions test/quantization/pt2e/test_quantize_pt2e.py
@@ -2948,10 +2948,11 @@ def has_inplace_ops(graph_module: torch.fx.GraphModule) -> bool:
 @unittest.skipIf(not torch_version_at_least("2.7.0"), "Requires torch 2.7+")
 class TestQuantizePT2EAffineQuantization(PT2EQuantizationTestCase):
     def test_channel_group_quantization(self):
+        from torchao.quantization import PerGroup, PerToken
         from torchao.quantization.pt2e._affine_quantization import (
             AffineQuantizedMinMaxObserver,
         )
-        from torchao.quantization.pt2e.observer import MappingType, PerGroup, PerToken
+        from torchao.quantization.pt2e.observer import MappingType

         class BackendAQuantizer(Quantizer):
             def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
@@ -3031,13 +3032,13 @@ def forward(self, x):
     def test_dynamic_affine_act_per_channel_weights(self):
         import operator

+        from torchao.quantization import PerToken
         from torchao.quantization.pt2e._affine_quantization import (
             AffineQuantizedMovingAverageMinMaxObserver,
         )
         from torchao.quantization.pt2e.observer import (
             MappingType,
             PerChannelMinMaxObserver,
-            PerToken,
         )

         class BackendAQuantizer(Quantizer):
@@ -3122,12 +3123,14 @@ def forward(self, x):
     def test_dynamic_per_tok_act_per_group_weights(self):
         import operator

+        from torchao.quantization import PerGroup, PerToken
+
         # TODO: merge into torchao observer
         from torchao.quantization.pt2e._affine_quantization import (
             AffineQuantizedMinMaxObserver,
             AffineQuantizedPlaceholderObserver,
         )
-        from torchao.quantization.pt2e.observer import MappingType, PerGroup, PerToken
+        from torchao.quantization.pt2e.observer import MappingType

         class BackendAQuantizer(Quantizer):
             def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
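Net effect of the three test changes above: the granularity types (PerGroup, PerToken) are now imported from the public torchao.quantization namespace, while MappingType continues to come from torchao.quantization.pt2e.observer. A minimal sketch of the new import pattern (the group_size value and the MappingType member shown are illustrative, not taken from the tests):

    # After this PR: granularity types come from the top-level package;
    # MappingType is unchanged and stays in the pt2e observer module.
    from torchao.quantization import PerGroup, PerToken
    from torchao.quantization.pt2e.observer import MappingType

    weight_granularity = PerGroup(group_size=128)  # qparams per 128-element group
    act_granularity = PerToken()                   # qparams per token (last dim)
    mapping_type = MappingType.SYMMETRIC           # illustrative choice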
2 changes: 2 additions & 0 deletions torchao/quantization/__init__.py
@@ -19,6 +19,7 @@
     MultiTensorInputRecorder,
 )
 from .granularity import (
+    Granularity,
     PerAxis,
     PerGroup,
     PerRow,
@@ -197,6 +198,7 @@
     "MappingType",
     "ZeroPointDomain",
     "TorchAODType",
+    "Granularity",
     "PerTensor",
     "PerAxis",
     "PerGroup",
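torchao/quantization/__init__.py now also re-exports the Granularity base class, so downstream code can annotate and dispatch on the abstract type without importing from private modules. A small sketch of what this enables; the describe helper is hypothetical, not torchao API:

    from torchao.quantization import Granularity, PerAxis, PerTensor

    def describe(g: Granularity) -> str:
        # Every concrete granularity subclasses Granularity, so the
        # annotation above covers all of them.
        if isinstance(g, PerTensor):
            return "one set of qparams for the whole tensor"
        if isinstance(g, PerAxis):
            return f"one set of qparams per slice along axis {g.axis}"
        return f"other granularity: {g}"

    print(describe(PerTensor()))      # whole-tensor qparams
    print(describe(PerAxis(axis=0)))  # per-slice qparams along axis 0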
22 changes: 1 addition & 21 deletions torchao/quantization/observer.py
@@ -14,7 +14,6 @@

 from .granularity import (
     Granularity,
-    PerAxis,
     PerRow,
     PerTensor,
 )
@@ -24,6 +23,7 @@
     _get_reduction_params,
     choose_qparams_affine_with_min_max,
 )
+from .utils import get_block_size

 logger = logging.getLogger(__name__)

@@ -63,26 +63,6 @@ def _with_args(cls_or_self, *args, **kwargs):
     return r


-def get_block_size(
-    input_shape: Tuple[int, ...], granularity: Granularity
-) -> Tuple[int, ...]:
-    """Get the block size based on the input shape and granularity type.
-
-    Args:
-        input_shape: The input tensor shape possibly more than 2 dimensions
-        granularity: The granularity type of the quantization
-    """
-    if isinstance(granularity, PerTensor):
-        return input_shape
-    elif isinstance(granularity, PerAxis):
-        block_size = list(input_shape)
-        block_size[granularity.axis] = 1
-        return tuple(block_size)
-    elif isinstance(granularity, PerRow):
-        return (1,) * (len(input_shape) - 1) + (input_shape[-1],)
-    raise ValueError(f"Unsupported Granularity: {granularity}")
-
-
 ABC: Any = ABCMeta("ABC", (object,), {})  # compatible with Python 2 *and* 3:

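The get_block_size copy deleted from torchao/quantization/observer.py only handled PerTensor, PerAxis, and PerRow; the module now imports the shared helper from torchao.quantization.utils instead. A sketch of the behavior being consolidated, assuming the shared helper keeps the deleted copy's signature and semantics:

    from torchao.quantization import PerAxis, PerRow, PerTensor
    from torchao.quantization.utils import get_block_size

    shape = (8, 16)
    assert get_block_size(shape, PerTensor()) == (8, 16)      # whole tensor is one block
    assert get_block_size(shape, PerAxis(axis=0)) == (1, 16)  # chosen axis collapses to 1
    assert get_block_size(shape, PerRow()) == (1, 16)         # (1, ..., 1, last_dim)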
16 changes: 0 additions & 16 deletions torchao/quantization/pt2e/__init__.py
@@ -48,7 +48,6 @@
 from .observer import (
     AffineQuantizedObserverBase,
     FixedQParamsObserver,
-    Granularity,
     HistogramObserver,
     MappingType,
     MinMaxObserver,
@@ -57,20 +56,13 @@
     NoopObserver,
     ObserverBase,
     PartialWrapper,
-    PerAxis,
-    PerBlock,
     PerChannelMinMaxObserver,
-    PerGroup,
-    PerRow,
-    PerTensor,
-    PerToken,
     PlaceholderObserver,
     RecordingObserver,
     ReuseInputObserver,
     TorchAODType,
     UniformQuantizationObserverBase,
     ZeroPointDomain,
-    get_block_size,
 )

 for _f in [
@@ -139,17 +131,9 @@
     "compare_results",
     # should be merged with torchao/quantization/observer.py in the future
     "AffineQuantizedObserverBase",
-    "Granularity",
     "MappingType",
-    "PerAxis",
-    "PerBlock",
-    "PerGroup",
-    "PerRow",
-    "PerTensor",
-    "PerToken",
     "TorchAODType",
     "ZeroPointDomain",
-    "get_block_size",
     "default_fake_quant",
     "default_dynamic_fake_quant",
 ]
2 changes: 1 addition & 1 deletion torchao/quantization/pt2e/_affine_quantization.py
@@ -19,8 +19,8 @@
     MappingType,
     TorchAODType,
     ZeroPointDomain,
-    get_block_size,
 )
+from torchao.quantization.utils import get_block_size

 ABC: Any = ABCMeta("ABC", (object,), {})  # compatible with Python 2 *and* 3:

143 changes: 1 addition & 142 deletions torchao/quantization/pt2e/observer.py
@@ -27,6 +27,7 @@
 from torch.fx import Node

 import torchao
+from torchao.quantization import Granularity
 from torchao.quantization.pt2e.utils import (
     calculate_qmin_qmax,
     check_min_max_valid,
@@ -67,17 +68,9 @@
     "ReuseInputObserver",
     "UniformQuantizationObserverBase",
     "AffineQuantizedObserverBase",
-    "Granularity",
     "MappingType",
-    "PerAxis",
-    "PerBlock",
-    "PerGroup",
-    "PerRow",
-    "PerTensor",
-    "PerToken",
     "TorchAODType",
     "ZeroPointDomain",
-    "get_block_size",
 ]

@@ -1622,7 +1615,6 @@ def calculate_qparams(self):
 We plan to merge the following with torchao repo after we move pt2e flow to torchao
 copied from https://github.com/pytorch/ao/blob/main/torchao/quantization/observer.py
 """
-from dataclasses import dataclass
 from enum import Enum, auto

@@ -1679,139 +1671,6 @@ class TorchAODType(Enum):
     INT7 = auto()


-@dataclass(frozen=True)
-class Granularity:
-    """
-    Base class for representing the granularity of quantization.
-
-    This class serves as a parent for specific granularity types used in
-    quantization operations, such as per-tensor or per-axis quantization.
-    """
-
-
-@dataclass(frozen=True)
-class PerBlock(Granularity):
-    """
-    Represents per-block granularity in quantization. See
-    :func:`~torchao.quantization.quant_primitives.quantize_affine` for docs for
-    `block_size`
-
-    Attributes:
-        block_size (Tuple[int, ...]): The size of each quantization group
-    """
-
-    block_size: tuple[int, ...]
-
-
-@dataclass(frozen=True)
-class PerTensor(Granularity):
-    """
-    Represents per-tensor granularity in quantization.
-
-    This granularity type calculates the quantization parameters
-    based off the entire tensor.
-
-    """
-
-
-@dataclass(frozen=True)
-class PerAxis(Granularity):
-    """
-    Represents per-axis granularity in quantization.
-
-    This granularity type calculates different quantization parameters
-    along a specified axis of the tensor.
-
-    For example if the input tensor is shape [8, 16] and axis=0, then
-    the quantization parameters are calculated for each row of the tensor.
-    Giving a total of 8 quantization parameters.
-
-    Attributes:
-        axis (int): The axis along which reduction is performed.
-    """
-
-    axis: int
-
-
-@dataclass(frozen=True)
-class PerGroup(Granularity):
-    """
-    Represents per-channel group granularity in quantization.
-
-    This granularity type calculates different quantization parameters
-    for each group of <group_size> elements.
-
-    For example if the input tensor is shape [8, 16], and the group size is 4, then
-    the input tensor is reshaped to [64, 4]
-    quantization parameters are calculated for each group of 4 elements,
-    giving a total of 64 quantization parameters.
-
-    Attributes:
-        group_size (int): The size of each quantization group
-
-    """
-
-    group_size: int
-
-
-class PerRow(Granularity):
-    """
-    Represents row-wise granularity in quantization.
-
-    This is a special case of per-axis quantization and is unique to Float8 matmuls
-    where the input is quantized with a block_size of (1, ..., input.shape[-1]). And the weight
-    is quantized with a block_size of (1, weight.shape[1]).
-    """
-
-
-class PerToken(Granularity):
-    """
-    Represents per-token granularity in quantization.
-
-    This granularity type calculates a different set of quantization parameters
-    for each token, which is represented as the last dimension of the tensor.
-
-    For example, if the input tensor has shape [2, 3, 4], then there are 6 tokens
-    with 4 elements each, and we will calculate 6 sets of quantization parameters,
-    one for each token.
-
-    If the input tensor has only two dimensions, e.g. [8, 16], then this is
-    equivalent to `PerAxis(axis=0)`, which yields 8 sets of quantization parameters.
-    """
-
-
-def get_block_size(
-    input_shape: tuple[int, ...], granularity: Granularity
-) -> tuple[int, ...]:
-    """Get the block size based on the input shape and granularity type.
-
-    Args:
-        input_shape: The input tensor shape possibly more than 2 dimensions
-        granularity: The granularity type of the quantization
-    """
-    assert isinstance(granularity, Granularity), (
-        "Please provide an instance of Granularity, not subclass of it"
-    )
-    if isinstance(granularity, PerTensor):
-        return input_shape
-    elif isinstance(granularity, PerAxis):
-        block_size = list(input_shape)
-        block_size[granularity.axis] = 1
-        return tuple(block_size)
-    elif isinstance(granularity, PerRow):
-        return (1,) * (len(input_shape) - 1) + (input_shape[-1],)
-    elif isinstance(granularity, PerGroup):
-        assert len(input_shape) == 2, (
-            f"Expecting input shape dim to be 2 for per group quantization, gotinput shape: {input_shape}"
-        )
-        return (1, granularity.group_size)
-    elif isinstance(granularity, PerToken):
-        block_size = [1] * len(input_shape)
-        block_size[-1] = input_shape[-1]
-        return tuple(block_size)
-    raise ValueError(f"Unsupported Granularity: {granularity}")
-
-
 class AffineQuantizedObserverBase(ABC, torch.nn.Module):
     """Observer module for affine quantization (https://github.com/pytorch/ao/tree/main/torchao/quantization#affine-quantization)

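This pt2e copy of get_block_size was the broader of the two duplicates being removed: it also covered PerGroup and PerToken, which the copy in torchao/quantization/observer.py did not. Assuming the shared helper in torchao.quantization.utils preserves these semantics, the examples from the deleted docstrings work out as follows (note that an [8, 16] tensor with group_size=4 gives 8 * 16 / 4 = 32 groups):

    from torchao.quantization import PerGroup, PerToken
    from torchao.quantization.utils import get_block_size

    # PerGroup: each block is a (1, group_size) slice of a 2-D tensor.
    assert get_block_size((8, 16), PerGroup(group_size=4)) == (1, 4)

    # PerToken: one block per slice of the last dimension; a [2, 3, 4]
    # tensor has 2 * 3 = 6 tokens of 4 elements each.
    assert get_block_size((2, 3, 4), PerToken()) == (1, 1, 4)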
2 changes: 1 addition & 1 deletion torchao/quantization/qat/fake_quantizer.py
@@ -14,7 +14,6 @@
     PerRow,
     PerToken,
 )
-from torchao.quantization.observer import get_block_size
 from torchao.quantization.quant_primitives import (
     _DTYPE_TO_BIT_WIDTH,
     _DTYPE_TO_QVALUE_BOUNDS,
@@ -28,6 +27,7 @@
 )
 from torchao.quantization.utils import (
     _get_per_token_block_size,
+    get_block_size,
     get_group_qparams_symmetric,
     get_groupwise_affine_qparams,
 )
3 changes: 2 additions & 1 deletion torchao/quantization/quant_api.py
@@ -64,7 +64,7 @@
 from torchao.quantization.linear_activation_weight_observed_tensor import (
     LinearActivationWeightObservedTensor,
 )
-from torchao.quantization.observer import AffineQuantizedObserverBase, get_block_size
+from torchao.quantization.observer import AffineQuantizedObserverBase
 from torchao.quantization.quantize_.common import (
     KernelPreference,
 )
@@ -87,6 +87,7 @@
     _QUANTIZE_CONFIG_HANDLER,
     register_quantize_module_handler,
 )
+from torchao.quantization.utils import get_block_size
 from torchao.quantization.weight_tensor_linear_activation_quantization import (
     to_weight_tensor_with_linear_activation_quantization_metadata,
 )
@@ -23,7 +23,6 @@
     preprocess_scale,
 )
 from torchao.quantization.granularity import PerRow, PerTensor
-from torchao.quantization.observer import get_block_size
 from torchao.quantization.quant_primitives import (
     _choose_scale_float8,
     _dequantize_affine_float8,
@@ -34,6 +33,7 @@
     QuantizeTensorKwargs,
     _choose_quant_func_and_quantize_tensor,
 )
+from torchao.quantization.utils import get_block_size
 from torchao.utils import (
     TorchAOBaseTensor,
     _is_fbgemm_genai_gpu_available,