Skip to content

Commit d613f9a

Browse files
authored
Revert "Unify get_block_size (#3039)"
This reverts commit 8e2ca35.
1 parent 8e2ca35 commit d613f9a

File tree

11 files changed

+188
-50
lines changed

11 files changed

+188
-50
lines changed

test/quantization/pt2e/test_quantize_pt2e.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2948,11 +2948,10 @@ def has_inplace_ops(graph_module: torch.fx.GraphModule) -> bool:
29482948
@unittest.skipIf(not torch_version_at_least("2.7.0"), "Requires torch 2.7+")
29492949
class TestQuantizePT2EAffineQuantization(PT2EQuantizationTestCase):
29502950
def test_channel_group_quantization(self):
2951-
from torchao.quantization import PerGroup, PerToken
29522951
from torchao.quantization.pt2e._affine_quantization import (
29532952
AffineQuantizedMinMaxObserver,
29542953
)
2955-
from torchao.quantization.pt2e.observer import MappingType
2954+
from torchao.quantization.pt2e.observer import MappingType, PerGroup, PerToken
29562955

29572956
class BackendAQuantizer(Quantizer):
29582957
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
@@ -3032,13 +3031,13 @@ def forward(self, x):
30323031
def test_dynamic_affine_act_per_channel_weights(self):
30333032
import operator
30343033

3035-
from torchao.quantization import PerToken
30363034
from torchao.quantization.pt2e._affine_quantization import (
30373035
AffineQuantizedMovingAverageMinMaxObserver,
30383036
)
30393037
from torchao.quantization.pt2e.observer import (
30403038
MappingType,
30413039
PerChannelMinMaxObserver,
3040+
PerToken,
30423041
)
30433042

30443043
class BackendAQuantizer(Quantizer):
@@ -3123,14 +3122,12 @@ def forward(self, x):
31233122
def test_dynamic_per_tok_act_per_group_weights(self):
31243123
import operator
31253124

3126-
from torchao.quantization import PerGroup, PerToken
3127-
31283125
# TODO: merge into torchao observer
31293126
from torchao.quantization.pt2e._affine_quantization import (
31303127
AffineQuantizedMinMaxObserver,
31313128
AffineQuantizedPlaceholderObserver,
31323129
)
3133-
from torchao.quantization.pt2e.observer import MappingType
3130+
from torchao.quantization.pt2e.observer import MappingType, PerGroup, PerToken
31343131

31353132
class BackendAQuantizer(Quantizer):
31363133
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:

torchao/quantization/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
MultiTensorInputRecorder,
2020
)
2121
from .granularity import (
22-
Granularity,
2322
PerAxis,
2423
PerGroup,
2524
PerRow,
@@ -198,7 +197,6 @@
198197
"MappingType",
199198
"ZeroPointDomain",
200199
"TorchAODType",
201-
"Granularity",
202200
"PerTensor",
203201
"PerAxis",
204202
"PerGroup",

torchao/quantization/linear_activation_quantized_tensor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ def _same_metadata(
133133

134134
@implements([torch.nn.functional.linear, aten.linear.default])
135135
def _(func, types, args, kwargs):
136+
136137
input_tensor = kwargs.get("input", args[0] if len(args) > 0 else None)
137138
weight_tensor = kwargs.get("weight", args[1] if len(args) > 1 else None)
138139
bias = kwargs.get("bias", args[2] if len(args) > 2 else None)

torchao/quantization/observer.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from .granularity import (
1616
Granularity,
17+
PerAxis,
1718
PerRow,
1819
PerTensor,
1920
)
@@ -23,7 +24,6 @@
2324
_get_reduction_params,
2425
choose_qparams_affine_with_min_max,
2526
)
26-
from .utils import get_block_size
2727

2828
logger = logging.getLogger(__name__)
2929

@@ -63,6 +63,26 @@ def _with_args(cls_or_self, *args, **kwargs):
6363
return r
6464

6565

66+
def get_block_size(
67+
input_shape: Tuple[int, ...], granularity: Granularity
68+
) -> Tuple[int, ...]:
69+
"""Get the block size based on the input shape and granularity type.
70+
71+
Args:
72+
input_shape: The input tensor shape possibly more than 2 dimensions
73+
granularity: The granularity type of the quantization
74+
"""
75+
if isinstance(granularity, PerTensor):
76+
return input_shape
77+
elif isinstance(granularity, PerAxis):
78+
block_size = list(input_shape)
79+
block_size[granularity.axis] = 1
80+
return tuple(block_size)
81+
elif isinstance(granularity, PerRow):
82+
return (1,) * (len(input_shape) - 1) + (input_shape[-1],)
83+
raise ValueError(f"Unsupported Granularity: {granularity}")
84+
85+
6686
ABC: Any = ABCMeta("ABC", (object,), {}) # compatible with Python 2 *and* 3:
6787

6888

torchao/quantization/pt2e/__init__.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from .observer import (
4949
AffineQuantizedObserverBase,
5050
FixedQParamsObserver,
51+
Granularity,
5152
HistogramObserver,
5253
MappingType,
5354
MinMaxObserver,
@@ -56,13 +57,20 @@
5657
NoopObserver,
5758
ObserverBase,
5859
PartialWrapper,
60+
PerAxis,
61+
PerBlock,
5962
PerChannelMinMaxObserver,
63+
PerGroup,
64+
PerRow,
65+
PerTensor,
66+
PerToken,
6067
PlaceholderObserver,
6168
RecordingObserver,
6269
ReuseInputObserver,
6370
TorchAODType,
6471
UniformQuantizationObserverBase,
6572
ZeroPointDomain,
73+
get_block_size,
6674
)
6775

6876
for _f in [
@@ -131,9 +139,17 @@
131139
"compare_results",
132140
# should be merged with torchao/quantization/observer.py in the future
133141
"AffineQuantizedObserverBase",
142+
"Granularity",
134143
"MappingType",
144+
"PerAxis",
145+
"PerBlock",
146+
"PerGroup",
147+
"PerRow",
148+
"PerTensor",
149+
"PerToken",
135150
"TorchAODType",
136151
"ZeroPointDomain",
152+
"get_block_size",
137153
"default_fake_quant",
138154
"default_dynamic_fake_quant",
139155
]

torchao/quantization/pt2e/_affine_quantization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
MappingType,
2020
TorchAODType,
2121
ZeroPointDomain,
22+
get_block_size,
2223
)
23-
from torchao.quantization.utils import get_block_size
2424

2525
ABC: Any = ABCMeta("ABC", (object,), {}) # compatible with Python 2 *and* 3:
2626

torchao/quantization/pt2e/observer.py

Lines changed: 142 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
from torch.fx import Node
2828

2929
import torchao
30-
from torchao.quantization import Granularity
3130
from torchao.quantization.pt2e.utils import (
3231
calculate_qmin_qmax,
3332
check_min_max_valid,
@@ -68,9 +67,17 @@
6867
"ReuseInputObserver",
6968
"UniformQuantizationObserverBase",
7069
"AffineQuantizedObserverBase",
70+
"Granularity",
7171
"MappingType",
72+
"PerAxis",
73+
"PerBlock",
74+
"PerGroup",
75+
"PerRow",
76+
"PerTensor",
77+
"PerToken",
7278
"TorchAODType",
7379
"ZeroPointDomain",
80+
"get_block_size",
7481
]
7582

7683

@@ -1615,6 +1622,7 @@ def calculate_qparams(self):
16151622
We plan to merge the following with torchao repo after we move pt2e flow to torchao
16161623
copied from https://github.com/pytorch/ao/blob/main/torchao/quantization/observer.py
16171624
"""
1625+
from dataclasses import dataclass
16181626
from enum import Enum, auto
16191627

16201628

@@ -1671,6 +1679,139 @@ class TorchAODType(Enum):
16711679
INT7 = auto()
16721680

16731681

1682+
@dataclass(frozen=True)
1683+
class Granularity:
1684+
"""
1685+
Base class for representing the granularity of quantization.
1686+
1687+
This class serves as a parent for specific granularity types used in
1688+
quantization operations, such as per-tensor or per-axis quantization.
1689+
"""
1690+
1691+
1692+
@dataclass(frozen=True)
1693+
class PerBlock(Granularity):
1694+
"""
1695+
Represents per-block granularity in quantization. See
1696+
:func:`~torchao.quantization.quant_primitives.quantize_affine` for docs for
1697+
`block_size`
1698+
1699+
Attributes:
1700+
block_size (Tuple[int, ...]): The size of each quantization group
1701+
"""
1702+
1703+
block_size: tuple[int, ...]
1704+
1705+
1706+
@dataclass(frozen=True)
1707+
class PerTensor(Granularity):
1708+
"""
1709+
Represents per-tensor granularity in quantization.
1710+
1711+
This granularity type calculates the quantization parameters
1712+
based off the entire tensor.
1713+
1714+
"""
1715+
1716+
1717+
@dataclass(frozen=True)
1718+
class PerAxis(Granularity):
1719+
"""
1720+
Represents per-axis granularity in quantization.
1721+
1722+
This granularity type calculates different quantization parameters
1723+
along a specified axis of the tensor.
1724+
1725+
For example if the input tensor is shape [8, 16] and axis=0, then
1726+
the quantization parameters are calculated for each row of the tensor.
1727+
Giving a total of 8 quantization parameters.
1728+
1729+
Attributes:
1730+
axis (int): The axis along which reduction is performed.
1731+
"""
1732+
1733+
axis: int
1734+
1735+
1736+
@dataclass(frozen=True)
1737+
class PerGroup(Granularity):
1738+
"""
1739+
Represents per-channel group granularity in quantization.
1740+
1741+
This granularity type calculates different quantization parameters
1742+
for each group of <group_size> elements.
1743+
1744+
For example if the input tensor is shape [8, 16], and the group size is 4, then
1745+
the input tensor is reshaped to [64, 4]
1746+
quantization parameters are calculated for each group of 4 elements,
1747+
giving a total of 64 quantization parameters.
1748+
1749+
Attributes:
1750+
group_size (int): The size of each quantization group
1751+
1752+
"""
1753+
1754+
group_size: int
1755+
1756+
1757+
class PerRow(Granularity):
1758+
"""
1759+
Represents row-wise granularity in quantization.
1760+
1761+
This is a special case of per-axis quantization and is unique to Float8 matmuls
1762+
where the input is quantized with a block_size of (1, ..., input.shape[-1]). And the weight
1763+
is quantized with a block_size of (1, weight.shape[1]).
1764+
"""
1765+
1766+
1767+
class PerToken(Granularity):
1768+
"""
1769+
Represents per-token granularity in quantization.
1770+
1771+
This granularity type calculates a different set of quantization parameters
1772+
for each token, which is represented as the last dimension of the tensor.
1773+
1774+
For example, if the input tensor has shape [2, 3, 4], then there are 6 tokens
1775+
with 4 elements each, and we will calculate 6 sets of quantization parameters,
1776+
one for each token.
1777+
1778+
If the input tensor has only two dimensions, e.g. [8, 16], then this is
1779+
equivalent to `PerAxis(axis=0)`, which yields 8 sets of quantization parameters.
1780+
"""
1781+
1782+
1783+
def get_block_size(
1784+
input_shape: tuple[int, ...], granularity: Granularity
1785+
) -> tuple[int, ...]:
1786+
"""Get the block size based on the input shape and granularity type.
1787+
1788+
Args:
1789+
input_shape: The input tensor shape possibly more than 2 dimensions
1790+
granularity: The granularity type of the quantization
1791+
"""
1792+
assert isinstance(granularity, Granularity), (
1793+
"Please provide an instance of Granularity, not subclass of it"
1794+
)
1795+
if isinstance(granularity, PerTensor):
1796+
return input_shape
1797+
elif isinstance(granularity, PerAxis):
1798+
block_size = list(input_shape)
1799+
block_size[granularity.axis] = 1
1800+
return tuple(block_size)
1801+
elif isinstance(granularity, PerRow):
1802+
return (1,) * (len(input_shape) - 1) + (input_shape[-1],)
1803+
elif isinstance(granularity, PerGroup):
1804+
assert len(input_shape) == 2, (
1805+
f"Expecting input shape dim to be 2 for per group quantization, gotinput shape: {input_shape}"
1806+
)
1807+
return (1, granularity.group_size)
1808+
elif isinstance(granularity, PerToken):
1809+
block_size = [1] * len(input_shape)
1810+
block_size[-1] = input_shape[-1]
1811+
return tuple(block_size)
1812+
raise ValueError(f"Unsupported Granularity: {granularity}")
1813+
1814+
16741815
class AffineQuantizedObserverBase(ABC, torch.nn.Module):
16751816
"""Observer module for affine quantization (https://github.com/pytorch/ao/tree/main/torchao/quantization#affine-quantization)
16761817

torchao/quantization/qat/fake_quantizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
PerRow,
1515
PerToken,
1616
)
17+
from torchao.quantization.observer import get_block_size
1718
from torchao.quantization.quant_primitives import (
1819
_DTYPE_TO_BIT_WIDTH,
1920
_DTYPE_TO_QVALUE_BOUNDS,
@@ -27,7 +28,6 @@
2728
)
2829
from torchao.quantization.utils import (
2930
_get_per_token_block_size,
30-
get_block_size,
3131
get_group_qparams_symmetric,
3232
get_groupwise_affine_qparams,
3333
)

torchao/quantization/quant_api.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@
6464
from torchao.quantization.linear_activation_weight_observed_tensor import (
6565
LinearActivationWeightObservedTensor,
6666
)
67-
from torchao.quantization.observer import AffineQuantizedObserverBase
67+
from torchao.quantization.observer import AffineQuantizedObserverBase, get_block_size
6868
from torchao.quantization.quantize_.common import (
6969
KernelPreference,
7070
)
@@ -87,7 +87,6 @@
8787
_QUANTIZE_CONFIG_HANDLER,
8888
register_quantize_module_handler,
8989
)
90-
from torchao.quantization.utils import get_block_size
9190
from torchao.quantization.weight_tensor_linear_activation_quantization import (
9291
to_weight_tensor_with_linear_activation_quantization_metadata,
9392
)

torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
preprocess_scale,
2424
)
2525
from torchao.quantization.granularity import PerRow, PerTensor
26+
from torchao.quantization.observer import get_block_size
2627
from torchao.quantization.quant_primitives import (
2728
_choose_scale_float8,
2829
_dequantize_affine_float8,
@@ -33,7 +34,6 @@
3334
QuantizeTensorKwargs,
3435
_choose_quant_func_and_quantize_tensor,
3536
)
36-
from torchao.quantization.utils import get_block_size
3737
from torchao.utils import (
3838
TorchAOBaseTensor,
3939
_is_fbgemm_genai_gpu_available,

0 commit comments

Comments
 (0)