|
27 | 27 | from torch.fx import Node
|
28 | 28 |
|
29 | 29 | import torchao
|
30 |
| -from torchao.quantization import Granularity |
31 | 30 | from torchao.quantization.pt2e.utils import (
|
32 | 31 | calculate_qmin_qmax,
|
33 | 32 | check_min_max_valid,
|
|
68 | 67 | "ReuseInputObserver",
|
69 | 68 | "UniformQuantizationObserverBase",
|
70 | 69 | "AffineQuantizedObserverBase",
|
| 70 | + "Granularity", |
71 | 71 | "MappingType",
|
| 72 | + "PerAxis", |
| 73 | + "PerBlock", |
| 74 | + "PerGroup", |
| 75 | + "PerRow", |
| 76 | + "PerTensor", |
| 77 | + "PerToken", |
72 | 78 | "TorchAODType",
|
73 | 79 | "ZeroPointDomain",
|
| 80 | + "get_block_size", |
74 | 81 | ]
|
75 | 82 |
|
76 | 83 |
|
@@ -1615,6 +1622,7 @@ def calculate_qparams(self):
|
1615 | 1622 | We plan to merge the following with the torchao repo after we move the pt2e flow to torchao.
1616 | 1623 | Copied from https://github.com/pytorch/ao/blob/main/torchao/quantization/observer.py
|
1617 | 1624 | """
|
| 1625 | +from dataclasses import dataclass |
1618 | 1626 | from enum import Enum, auto
|
1619 | 1627 |
|
1620 | 1628 |
|
@@ -1671,6 +1679,139 @@ class TorchAODType(Enum):
|
1671 | 1679 | INT7 = auto()
|
1672 | 1680 |
|
1673 | 1681 |
|
@dataclass(frozen=True)
class Granularity:
    """
    Common ancestor for all quantization-granularity descriptors.

    Concrete subclasses (e.g. per-tensor, per-axis) describe how
    quantization parameters are grouped across a tensor's elements.
    """
| 1690 | + |
| 1691 | + |
@dataclass(frozen=True)
class PerBlock(Granularity):
    """
    Per-block quantization granularity.

    See :func:`~torchao.quantization.quant_primitives.quantize_affine`
    for the documentation of `block_size`.

    Attributes:
        block_size (tuple[int, ...]): The size of each quantization group
    """

    block_size: tuple[int, ...]
| 1704 | + |
| 1705 | + |
@dataclass(frozen=True)
class PerTensor(Granularity):
    """
    Per-tensor quantization granularity.

    A single set of quantization parameters is derived from the
    entire tensor.
    """
| 1715 | + |
| 1716 | + |
@dataclass(frozen=True)
class PerAxis(Granularity):
    """
    Per-axis quantization granularity.

    A separate set of quantization parameters is computed for each
    slice of the tensor along the specified axis.

    For example, with an input tensor of shape [8, 16] and axis=0,
    parameters are computed per row, giving 8 sets in total.

    Attributes:
        axis (int): The axis along which reduction is performed.
    """

    axis: int
| 1734 | + |
| 1735 | + |
@dataclass(frozen=True)
class PerGroup(Granularity):
    """
    Represents per-channel group granularity in quantization.

    This granularity type calculates different quantization parameters
    for each group of <group_size> elements.

    For example if the input tensor is shape [8, 16] (128 elements), and the
    group size is 4, then the input tensor is reshaped to [32, 4] and
    quantization parameters are calculated for each group of 4 elements,
    giving a total of 32 quantization parameters.

    Attributes:
        group_size (int): The size of each quantization group

    """

    group_size: int
| 1755 | + |
| 1756 | + |
# NOTE(review): decorated as a frozen dataclass for consistency with the other
# Granularity subclasses; the generated __eq__/__repr__/frozen __setattr__ are
# equivalent to those already inherited from the frozen base, so behavior is
# unchanged.
@dataclass(frozen=True)
class PerRow(Granularity):
    """
    Represents row-wise granularity in quantization.

    This is a special case of per-axis quantization and is unique to Float8
    matmuls where the input is quantized with a block_size of
    (1, ..., input.shape[-1]). And the weight is quantized with a
    block_size of (1, weight.shape[1]).
    """
| 1765 | + |
| 1766 | + |
# NOTE(review): decorated as a frozen dataclass for consistency with the other
# Granularity subclasses; behavior is unchanged since the generated methods
# match those inherited from the frozen base.
@dataclass(frozen=True)
class PerToken(Granularity):
    """
    Represents per-token granularity in quantization.

    This granularity type calculates a different set of quantization
    parameters for each token, which is represented as the last dimension
    of the tensor.

    For example, if the input tensor has shape [2, 3, 4], then there are
    6 tokens with 4 elements each, and we will calculate 6 sets of
    quantization parameters, one for each token.

    If the input tensor has only two dimensions, e.g. [8, 16], then this is
    equivalent to `PerAxis(axis=0)`, which yields 8 sets of quantization
    parameters.
    """
| 1781 | + |
| 1782 | + |
def get_block_size(
    input_shape: tuple[int, ...], granularity: Granularity
) -> tuple[int, ...]:
    """Get the block size based on the input shape and granularity type.

    Args:
        input_shape: The input tensor shape possibly more than 2 dimensions
        granularity: The granularity type of the quantization

    Returns:
        The block size tuple matching `input_shape`'s rank, suitable for
        the affine quantization primitives.

    Raises:
        ValueError: If the granularity type is not one of the supported
            subclasses handled below.
    """
    assert isinstance(granularity, Granularity), (
        "Please provide an instance of Granularity, not subclass of it"
    )
    if isinstance(granularity, PerTensor):
        # One set of qparams for the whole tensor: the block is the tensor.
        return input_shape
    elif isinstance(granularity, PerAxis):
        # Reduce along every dimension except the chosen axis.
        block_size = list(input_shape)
        block_size[granularity.axis] = 1
        return tuple(block_size)
    elif isinstance(granularity, PerRow):
        # One set of qparams per last-dimension slice: (1, ..., 1, last_dim).
        return (1,) * (len(input_shape) - 1) + (input_shape[-1],)
    elif isinstance(granularity, PerGroup):
        assert len(input_shape) == 2, (
            f"Expecting input shape dim to be 2 for per group quantization, got input shape: {input_shape}"
        )
        return (1, granularity.group_size)
    elif isinstance(granularity, PerToken):
        # A token is the last dimension; one set of qparams per token.
        block_size = [1] * len(input_shape)
        block_size[-1] = input_shape[-1]
        return tuple(block_size)
    raise ValueError(f"Unsupported Granularity: {granularity}")
| 1813 | + |
| 1814 | + |
1674 | 1815 | class AffineQuantizedObserverBase(ABC, torch.nn.Module):
|
1675 | 1816 | """Observer module for affine quantization (https://github.com/pytorch/ao/tree/main/torchao/quantization#affine-quantization)
|
1676 | 1817 |
|
|
0 commit comments