Skip to content

Commit 0064084

Browse files
authored
[Reland] Unify get_block_size (#3059)
* [Reland] Unify get_block_size * Refine code * Refine code
1 parent 5e90c47 commit 0064084

File tree

10 files changed

+91
-169
lines changed

10 files changed

+91
-169
lines changed

torchao/quantization/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919
MultiTensorInputRecorder,
2020
)
2121
from .granularity import (
22+
Granularity,
2223
PerAxis,
24+
PerBlock,
2325
PerGroup,
2426
PerRow,
2527
PerTensor,
@@ -197,8 +199,10 @@
197199
"MappingType",
198200
"ZeroPointDomain",
199201
"TorchAODType",
202+
"Granularity",
200203
"PerTensor",
201204
"PerAxis",
205+
"PerBlock",
202206
"PerGroup",
203207
"PerRow",
204208
"PerToken",

torchao/quantization/granularity.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ class PerGroup(Granularity):
7171
group_size: int
7272

7373

74+
@dataclass(frozen=True)
7475
class PerRow(Granularity):
7576
"""
7677
Represents row-wise granularity in quantization.
@@ -83,6 +84,7 @@ class PerRow(Granularity):
8384
pass
8485

8586

87+
@dataclass(frozen=True)
8688
class PerToken(Granularity):
8789
"""
8890
Represents per-token granularity in quantization.
@@ -99,3 +101,16 @@ class PerToken(Granularity):
99101
"""
100102

101103
pass
104+
105+
106+
@dataclass(frozen=True)
107+
class PerBlock(Granularity):
108+
"""
109+
Represents per-block granularity in quantization. See
110+
:func:`~torchao.quantization.quant_primitives.quantize_affine` for docs for
111+
`block_size`
112+
Attributes:
113+
block_size (tuple[int, ...]): The size of each quantization group
114+
"""
115+
116+
block_size: tuple[int, ...]

torchao/quantization/observer.py

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
from .granularity import (
1616
Granularity,
17-
PerAxis,
1817
PerRow,
1918
PerTensor,
2019
)
@@ -24,6 +23,7 @@
2423
_get_reduction_params,
2524
choose_qparams_affine_with_min_max,
2625
)
26+
from .utils import get_block_size
2727

2828
logger = logging.getLogger(__name__)
2929

@@ -63,26 +63,6 @@ def _with_args(cls_or_self, *args, **kwargs):
6363
return r
6464

6565

66-
def get_block_size(
67-
input_shape: Tuple[int, ...], granularity: Granularity
68-
) -> Tuple[int, ...]:
69-
"""Get the block size based on the input shape and granularity type.
70-
71-
Args:
72-
input_shape: The input tensor shape possibly more than 2 dimensions
73-
granularity: The granularity type of the quantization
74-
"""
75-
if isinstance(granularity, PerTensor):
76-
return input_shape
77-
elif isinstance(granularity, PerAxis):
78-
block_size = list(input_shape)
79-
block_size[granularity.axis] = 1
80-
return tuple(block_size)
81-
elif isinstance(granularity, PerRow):
82-
return (1,) * (len(input_shape) - 1) + (input_shape[-1],)
83-
raise ValueError(f"Unsupported Granularity: {granularity}")
84-
85-
8666
ABC: Any = ABCMeta("ABC", (object,), {}) # compatible with Python 2 *and* 3:
8767

8868

torchao/quantization/pt2e/__init__.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,15 @@
55
import torch
66
from torch import Tensor
77

8+
from torchao.quantization import (
9+
Granularity,
10+
PerAxis,
11+
PerBlock,
12+
PerGroup,
13+
PerRow,
14+
PerTensor,
15+
PerToken,
16+
)
817
from torchao.quantization.pt2e._numeric_debugger import ( # noqa: F401
918
CUSTOM_KEY,
1019
FROM_NODE_KEY,
@@ -32,6 +41,7 @@
3241
get_equivalent_types,
3342
update_equivalent_types_dict,
3443
)
44+
from torchao.quantization.utils import get_block_size
3545

3646
from .fake_quantize import (
3747
FakeQuantize,
@@ -48,7 +58,6 @@
4858
from .observer import (
4959
AffineQuantizedObserverBase,
5060
FixedQParamsObserver,
51-
Granularity,
5261
HistogramObserver,
5362
MappingType,
5463
MinMaxObserver,
@@ -57,20 +66,13 @@
5766
NoopObserver,
5867
ObserverBase,
5968
PartialWrapper,
60-
PerAxis,
61-
PerBlock,
6269
PerChannelMinMaxObserver,
63-
PerGroup,
64-
PerRow,
65-
PerTensor,
66-
PerToken,
6770
PlaceholderObserver,
6871
RecordingObserver,
6972
ReuseInputObserver,
7073
TorchAODType,
7174
UniformQuantizationObserverBase,
7275
ZeroPointDomain,
73-
get_block_size,
7476
)
7577

7678
for _f in [

torchao/quantization/pt2e/_affine_quantization.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@
1313

1414
import torch
1515

16+
from torchao.quantization import Granularity
1617
from torchao.quantization.pt2e.observer import (
1718
AffineQuantizedObserverBase,
18-
Granularity,
1919
MappingType,
2020
TorchAODType,
2121
ZeroPointDomain,
22-
get_block_size,
2322
)
23+
from torchao.quantization.utils import get_block_size
2424

2525
ABC: Any = ABCMeta("ABC", (object,), {}) # compatible with Python 2 *and* 3:
2626

torchao/quantization/pt2e/observer.py

Lines changed: 10 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,23 @@
2727
from torch.fx import Node
2828

2929
import torchao
30+
from torchao.quantization import (
31+
Granularity,
32+
PerAxis,
33+
PerBlock,
34+
PerGroup,
35+
PerRow,
36+
PerTensor,
37+
PerToken,
38+
)
3039
from torchao.quantization.pt2e.utils import (
3140
calculate_qmin_qmax,
3241
check_min_max_valid,
3342
is_per_channel,
3443
is_per_tensor,
3544
validate_qmin_qmax,
3645
)
46+
from torchao.quantization.utils import get_block_size
3747

3848
__all__ = [
3949
"default_affine_fixed_qparams_observer",
@@ -1622,7 +1632,6 @@ def calculate_qparams(self):
16221632
We plan to merge the following with torchao repo after we move pt2e flow to torchao
16231633
copied from https://github.com/pytorch/ao/blob/main/torchao/quantization/observer.py
16241634
"""
1625-
from dataclasses import dataclass
16261635
from enum import Enum, auto
16271636

16281637

@@ -1679,139 +1688,6 @@ class TorchAODType(Enum):
16791688
INT7 = auto()
16801689

16811690

1682-
@dataclass(frozen=True)
1683-
class Granularity:
1684-
"""
1685-
Base class for representing the granularity of quantization.
1686-
1687-
This class serves as a parent for specific granularity types used in
1688-
quantization operations, such as per-tensor or per-axis quantization.
1689-
"""
1690-
1691-
1692-
@dataclass(frozen=True)
1693-
class PerBlock(Granularity):
1694-
"""
1695-
Represents per-block granularity in quantization. See
1696-
:func:`~torchao.quantization.quant_primitives.quantize_affine` for docs for
1697-
`block_size`
1698-
1699-
Attributes:
1700-
block_size (Tuple[int, ...]): The size of each quantization group
1701-
"""
1702-
1703-
block_size: tuple[int, ...]
1704-
1705-
1706-
@dataclass(frozen=True)
1707-
class PerTensor(Granularity):
1708-
"""
1709-
Represents per-tensor granularity in quantization.
1710-
1711-
This granularity type calculates the quantization parameters
1712-
based off the entire tensor.
1713-
1714-
"""
1715-
1716-
1717-
@dataclass(frozen=True)
1718-
class PerAxis(Granularity):
1719-
"""
1720-
Represents per-axis granularity in quantization.
1721-
1722-
This granularity type calculates different quantization parameters
1723-
along a specified axis of the tensor.
1724-
1725-
For example if the input tensor is shape [8, 16] and axis=0, then
1726-
the quantization parameters are calculated for each row of the tensor.
1727-
Giving a total of 8 quantization parameters.
1728-
1729-
Attributes:
1730-
axis (int): The axis along which reduction is performed.
1731-
"""
1732-
1733-
axis: int
1734-
1735-
1736-
@dataclass(frozen=True)
1737-
class PerGroup(Granularity):
1738-
"""
1739-
Represents per-channel group granularity in quantization.
1740-
1741-
This granularity type calculates different quantization parameters
1742-
for each group of <group_size> elements.
1743-
1744-
For example if the input tensor is shape [8, 16], and the group size is 4, then
1745-
the input tensor is reshaped to [64, 4]
1746-
quantization parameters are calculated for each group of 4 elements,
1747-
giving a total of 64 quantization parameters.
1748-
1749-
Attributes:
1750-
group_size (int): The size of each quantization group
1751-
1752-
"""
1753-
1754-
group_size: int
1755-
1756-
1757-
class PerRow(Granularity):
1758-
"""
1759-
Represents row-wise granularity in quantization.
1760-
1761-
This is a special case of per-axis quantization and is unique to Float8 matmuls
1762-
where the input is quantized with a block_size of (1, ..., input.shape[-1]). And the weight
1763-
is quantized with a block_size of (1, weight.shape[1]).
1764-
"""
1765-
1766-
1767-
class PerToken(Granularity):
1768-
"""
1769-
Represents per-token granularity in quantization.
1770-
1771-
This granularity type calculates a different set of quantization parameters
1772-
for each token, which is represented as the last dimension of the tensor.
1773-
1774-
For example, if the input tensor has shape [2, 3, 4], then there are 6 tokens
1775-
with 4 elements each, and we will calculate 6 sets of quantization parameters,
1776-
one for each token.
1777-
1778-
If the input tensor has only two dimensions, e.g. [8, 16], then this is
1779-
equivalent to `PerAxis(axis=0)`, which yields 8 sets of quantization parameters.
1780-
"""
1781-
1782-
1783-
def get_block_size(
1784-
input_shape: tuple[int, ...], granularity: Granularity
1785-
) -> tuple[int, ...]:
1786-
"""Get the block size based on the input shape and granularity type.
1787-
1788-
Args:
1789-
input_shape: The input tensor shape possibly more than 2 dimensions
1790-
granularity: The granularity type of the quantization
1791-
"""
1792-
assert isinstance(granularity, Granularity), (
1793-
"Please provide an instance of Granularity, not subclass of it"
1794-
)
1795-
if isinstance(granularity, PerTensor):
1796-
return input_shape
1797-
elif isinstance(granularity, PerAxis):
1798-
block_size = list(input_shape)
1799-
block_size[granularity.axis] = 1
1800-
return tuple(block_size)
1801-
elif isinstance(granularity, PerRow):
1802-
return (1,) * (len(input_shape) - 1) + (input_shape[-1],)
1803-
elif isinstance(granularity, PerGroup):
1804-
assert len(input_shape) == 2, (
1805-
f"Expecting input shape dim to be 2 for per group quantization, gotinput shape: {input_shape}"
1806-
)
1807-
return (1, granularity.group_size)
1808-
elif isinstance(granularity, PerToken):
1809-
block_size = [1] * len(input_shape)
1810-
block_size[-1] = input_shape[-1]
1811-
return tuple(block_size)
1812-
raise ValueError(f"Unsupported Granularity: {granularity}")
1813-
1814-
18151691
class AffineQuantizedObserverBase(ABC, torch.nn.Module):
18161692
"""Observer module for affine quantization (https://github.com/pytorch/ao/tree/main/torchao/quantization#affine-quantization)
18171693

torchao/quantization/qat/fake_quantizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
PerRow,
1515
PerToken,
1616
)
17-
from torchao.quantization.observer import get_block_size
1817
from torchao.quantization.quant_primitives import (
1918
_DTYPE_TO_BIT_WIDTH,
2019
_DTYPE_TO_QVALUE_BOUNDS,
@@ -28,6 +27,7 @@
2827
)
2928
from torchao.quantization.utils import (
3029
_get_per_token_block_size,
30+
get_block_size,
3131
get_group_qparams_symmetric,
3232
get_groupwise_affine_qparams,
3333
)

torchao/quantization/quant_api.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@
6464
from torchao.quantization.linear_activation_weight_observed_tensor import (
6565
LinearActivationWeightObservedTensor,
6666
)
67-
from torchao.quantization.observer import AffineQuantizedObserverBase, get_block_size
67+
from torchao.quantization.observer import AffineQuantizedObserverBase
6868
from torchao.quantization.quantize_.common import (
6969
KernelPreference,
7070
)
@@ -87,6 +87,7 @@
8787
_QUANTIZE_CONFIG_HANDLER,
8888
register_quantize_module_handler,
8989
)
90+
from torchao.quantization.utils import get_block_size
9091
from torchao.quantization.weight_tensor_linear_activation_quantization import (
9192
to_weight_tensor_with_linear_activation_quantization_metadata,
9293
)

torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
preprocess_scale,
2424
)
2525
from torchao.quantization.granularity import PerRow, PerTensor
26-
from torchao.quantization.observer import get_block_size
2726
from torchao.quantization.quant_primitives import (
2827
_choose_scale_float8,
2928
_dequantize_affine_float8,
@@ -34,6 +33,7 @@
3433
QuantizeTensorKwargs,
3534
_choose_quant_func_and_quantize_tensor,
3635
)
36+
from torchao.quantization.utils import get_block_size
3737
from torchao.utils import (
3838
TorchAOBaseTensor,
3939
_is_fbgemm_genai_gpu_available,

0 commit comments

Comments
 (0)