27 | 27 | from torch.fx import Node
28 | 28 |
29 | 29 | import torchao
   | 30 | +from torchao.quantization import (
   | 31 | +    Granularity,
   | 32 | +    PerAxis,
   | 33 | +    PerBlock,
   | 34 | +    PerGroup,
   | 35 | +    PerRow,
   | 36 | +    PerTensor,
   | 37 | +    PerToken,
   | 38 | +)
30 | 39 | from torchao.quantization.pt2e.utils import (
31 | 40 |     calculate_qmin_qmax,
32 | 41 |     check_min_max_valid,
33 | 42 |     is_per_channel,
34 | 43 |     is_per_tensor,
35 | 44 |     validate_qmin_qmax,
36 | 45 | )
   | 46 | +from torchao.quantization.utils import get_block_size
37 | 47 |
38 | 48 | __all__ = [
39 | 49 |     "default_affine_fixed_qparams_observer",
@@ -1622,7 +1632,6 @@ def calculate_qparams(self):
1622 | 1632 | We plan to merge the following with torchao repo after we move pt2e flow to torchao
1623 | 1633 | copied from https://github.com/pytorch/ao/blob/main/torchao/quantization/observer.py
1624 | 1634 | """
1625 |      | -from dataclasses import dataclass
1626 | 1635 | from enum import Enum, auto
1627 | 1636 |
1628 | 1637 |
@@ -1679,139 +1688,6 @@ class TorchAODType(Enum):
1679 | 1688 |     INT7 = auto()
1680 | 1689 |
1681 | 1690 |
1682 |      | -@dataclass(frozen=True)
1683 |      | -class Granularity:
1684 |      | -    """
1685 |      | -    Base class for representing the granularity of quantization.
1686 |      | -
1687 |      | -    This class serves as a parent for specific granularity types used in
1688 |      | -    quantization operations, such as per-tensor or per-axis quantization.
1689 |      | -    """
1690 |      | -
1691 |      | -
1692 |      | -@dataclass(frozen=True)
1693 |      | -class PerBlock(Granularity):
1694 |      | -    """
1695 |      | -    Represents per-block granularity in quantization. See
1696 |      | -    :func:`~torchao.quantization.quant_primitives.quantize_affine` for docs for
1697 |      | -    `block_size`
1698 |      | -
1699 |      | -    Attributes:
1700 |      | -        block_size (Tuple[int, ...]): The size of each quantization group
1701 |      | -    """
1702 |      | -
1703 |      | -    block_size: tuple[int, ...]
1704 |      | -
1705 |      | -
1706 |      | -@dataclass(frozen=True)
1707 |      | -class PerTensor(Granularity):
1708 |      | -    """
1709 |      | -    Represents per-tensor granularity in quantization.
1710 |      | -
1711 |      | -    This granularity type calculates the quantization parameters
1712 |      | -    based on the entire tensor.
1713 |      | -
1714 |      | -    """
1715 |      | -
1716 |      | -
1717 |      | -@dataclass(frozen=True)
1718 |      | -class PerAxis(Granularity):
1719 |      | -    """
1720 |      | -    Represents per-axis granularity in quantization.
1721 |      | -
1722 |      | -    This granularity type calculates different quantization parameters
1723 |      | -    along a specified axis of the tensor.
1724 |      | -
1725 |      | -    For example, if the input tensor is shape [8, 16] and axis=0, then
1726 |      | -    the quantization parameters are calculated for each row of the tensor,
1727 |      | -    giving a total of 8 quantization parameters.
1728 |      | -
1729 |      | -    Attributes:
1730 |      | -        axis (int): The axis along which reduction is performed.
1731 |      | -    """
1732 |      | -
1733 |      | -    axis: int
1734 |      | -
1735 |      | -
1736 |      | -@dataclass(frozen=True)
1737 |      | -class PerGroup(Granularity):
1738 |      | -    """
1739 |      | -    Represents per-channel group granularity in quantization.
1740 |      | -
1741 |      | -    This granularity type calculates different quantization parameters
1742 |      | -    for each group of <group_size> elements.
1743 |      | -
1744 |      | -    For example, if the input tensor is shape [8, 16] and the group size is 4, then
1745 |      | -    the input tensor is reshaped to [32, 4];
1746 |      | -    quantization parameters are calculated for each group of 4 elements,
1747 |      | -    giving a total of 32 quantization parameters.
1748 |      | -
1749 |      | -    Attributes:
1750 |      | -        group_size (int): The size of each quantization group
1751 |      | -
1752 |      | -    """
1753 |      | -
1754 |      | -    group_size: int
1755 |      | -
1756 |      | -
1757 |      | -class PerRow(Granularity):
1758 |      | -    """
1759 |      | -    Represents row-wise granularity in quantization.
1760 |      | -
1761 |      | -    This is a special case of per-axis quantization and is unique to Float8 matmuls,
1762 |      | -    where the input is quantized with a block_size of (1, ..., input.shape[-1]) and the weight
1763 |      | -    is quantized with a block_size of (1, weight.shape[1]).
1764 |      | -    """
1765 |      | -
1766 |      | -
1767 |      | -class PerToken(Granularity):
1768 |      | -    """
1769 |      | -    Represents per-token granularity in quantization.
1770 |      | -
1771 |      | -    This granularity type calculates a different set of quantization parameters
1772 |      | -    for each token, which is represented as the last dimension of the tensor.
1773 |      | -
1774 |      | -    For example, if the input tensor has shape [2, 3, 4], then there are 6 tokens
1775 |      | -    with 4 elements each, and we will calculate 6 sets of quantization parameters,
1776 |      | -    one for each token.
1777 |      | -
1778 |      | -    If the input tensor has only two dimensions, e.g. [8, 16], then this is
1779 |      | -    equivalent to `PerAxis(axis=0)`, which yields 8 sets of quantization parameters.
1780 |      | -    """
1781 |      | -
1782 |      | -
1783 |      | -def get_block_size(
1784 |      | -    input_shape: tuple[int, ...], granularity: Granularity
1785 |      | -) -> tuple[int, ...]:
1786 |      | -    """Get the block size based on the input shape and granularity type.
1787 |      | -
1788 |      | -    Args:
1789 |      | -        input_shape: The input tensor shape, possibly with more than 2 dimensions
1790 |      | -        granularity: The granularity type of the quantization
1791 |      | -    """
1792 |      | -    assert isinstance(granularity, Granularity), (
1793 |      | -        "Please provide an instance of Granularity, not subclass of it"
1794 |      | -    )
1795 |      | -    if isinstance(granularity, PerTensor):
1796 |      | -        return input_shape
1797 |      | -    elif isinstance(granularity, PerAxis):
1798 |      | -        block_size = list(input_shape)
1799 |      | -        block_size[granularity.axis] = 1
1800 |      | -        return tuple(block_size)
1801 |      | -    elif isinstance(granularity, PerRow):
1802 |      | -        return (1,) * (len(input_shape) - 1) + (input_shape[-1],)
1803 |      | -    elif isinstance(granularity, PerGroup):
1804 |      | -        assert len(input_shape) == 2, (
1805 |      | -            f"Expecting input shape dim to be 2 for per group quantization, got input shape: {input_shape}"
1806 |      | -        )
1807 |      | -        return (1, granularity.group_size)
1808 |      | -    elif isinstance(granularity, PerToken):
1809 |      | -        block_size = [1] * len(input_shape)
1810 |      | -        block_size[-1] = input_shape[-1]
1811 |      | -        return tuple(block_size)
1812 |      | -    raise ValueError(f"Unsupported Granularity: {granularity}")
1813 |      | -
1814 |      | -
1815 | 1691 | class AffineQuantizedObserverBase(ABC, torch.nn.Module):
1816 | 1692 |     """Observer module for affine quantization (https://github.com/pytorch/ao/tree/main/torchao/quantization#affine-quantization)
1817 | 1693 |
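
For reference, here is a minimal sketch of how the granularity types and `get_block_size` imported at the top of the file fit together, mirroring the behavior of the local definitions removed above. It assumes a torchao build that exposes these names exactly as the new imports indicate; the outputs noted in the comments follow the logic of the deleted `get_block_size`.

```python
import math

from torchao.quantization import PerAxis, PerGroup, PerTensor, PerToken
from torchao.quantization.utils import get_block_size

shape = (8, 16)  # example weight shape used in the docstrings above

for granularity in (PerTensor(), PerAxis(axis=0), PerGroup(group_size=4)):
    block_size = get_block_size(shape, granularity)
    # one set of quantization parameters per block of `block_size` elements
    num_qparams = math.prod(shape) // math.prod(block_size)
    print(granularity, block_size, num_qparams)
# PerTensor()            -> block_size (8, 16), 1 set of qparams
# PerAxis(axis=0)        -> block_size (1, 16), 8 sets (one per row)
# PerGroup(group_size=4) -> block_size (1, 4),  32 sets

# PerToken keys off the last dimension: a [2, 3, 4] activation has 6 tokens.
print(get_block_size((2, 3, 4), PerToken()))  # -> (1, 1, 4)
```

Since the shared torchao implementation computes the same block sizes as the local copy being deleted (which the module docstring describes as copied from torchao's observer.py), the observers in this file can switch to the new imports without a behavior change.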