|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD 3-Clause license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +import platform |
| 8 | +import sys |
| 9 | +from copy import deepcopy |
| 10 | +from dataclasses import dataclass |
| 11 | + |
| 12 | +import pytest |
| 13 | +import torch |
| 14 | +import torch.nn as nn |
| 15 | + |
| 16 | +from torchao.core.config import AOBaseConfig |
| 17 | +from torchao.prototype.parq.quant import StretchedUnifTorchaoQuantizer |
| 18 | +from torchao.prototype.parq.quant.quant_api import StretchedIntxWeightOnlyConfig |
| 19 | +from torchao.prototype.quantization.dynamic_activation_lut import ( |
| 20 | + StretchedAffineQuantizedTensor_to_Int8DynamicActivationLutTensorConfig, |
| 21 | +) |
| 22 | +from torchao.quantization import quantize_ |
| 23 | +from torchao.quantization.granularity import PerAxis, PerGroup |
| 24 | +from torchao.quantization.linear_activation_quantized_tensor import ( |
| 25 | + to_linear_activation_quantized, |
| 26 | +) |
| 27 | +from torchao.quantization.quant_api import ( |
| 28 | + _int8_asymm_per_token_quant, |
| 29 | +) |
| 30 | +from torchao.quantization.transform_module import register_quantize_module_handler |
| 31 | + |
| 32 | +is_arm64_mac = sys.platform == "darwin" and platform.machine() == "arm64" |
| 33 | + |
| 34 | + |
| 35 | +@dataclass |
| 36 | +class Int8DynamicActivationConfig(AOBaseConfig): |
| 37 | + pass |
| 38 | + |
| 39 | + |
| 40 | +@register_quantize_module_handler(Int8DynamicActivationConfig) |
| 41 | +def _int8_dynamic_activation_transform( |
| 42 | + module: nn.Module, config: Int8DynamicActivationConfig |
| 43 | +) -> nn.Module: |
| 44 | + weight = module.weight |
| 45 | + weight = to_linear_activation_quantized(weight, _int8_asymm_per_token_quant) |
| 46 | + module.weight = torch.nn.Parameter(weight, requires_grad=False) |
| 47 | + return module |
| 48 | + |
| 49 | + |
| 50 | +class ToyLinearModel(torch.nn.Module): |
| 51 | + def __init__(self, d1=512, d2=256, d3=128, d4=8): |
| 52 | + super().__init__() |
| 53 | + self.linear1 = torch.nn.Linear(d1, d2, bias=False) |
| 54 | + self.linear2 = torch.nn.Linear(d2, d3, bias=True) |
| 55 | + self.linear3 = torch.nn.Linear(d3, d4, bias=False) |
| 56 | + |
| 57 | + def example_inputs( |
| 58 | + self, |
| 59 | + lead_dim=(1,), |
| 60 | + dtype=torch.bfloat16, |
| 61 | + ): |
| 62 | + return torch.randn( |
| 63 | + *lead_dim, self.linear1.in_features, dtype=dtype, device="cpu" |
| 64 | + ) |
| 65 | + |
| 66 | + def forward(self, x): |
| 67 | + x = self.linear1(x) |
| 68 | + x = self.linear2(x) |
| 69 | + x = self.linear3(x) |
| 70 | + return x |
| 71 | + |
| 72 | + |
| 73 | +@pytest.fixture(autouse=True) |
| 74 | +def run_before_and_after_tests(): |
| 75 | + yield |
| 76 | + torch._dynamo.reset() # reset cache between tests |
| 77 | + |
| 78 | + |
| 79 | +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) |
| 80 | +@pytest.mark.parametrize("granularity", [PerGroup(32), PerAxis(0)]) |
| 81 | +@pytest.mark.parametrize("bit_width", [1, 2, 3, 4]) |
| 82 | +@pytest.mark.parametrize("lead_dim", [(5,), (2, 3)]) |
| 83 | +@pytest.mark.skipif(not is_arm64_mac, reason="requires arm64 mac") |
| 84 | +def test_parq_conversion(dtype, granularity, bit_width, lead_dim): |
| 85 | + quantizer = StretchedUnifTorchaoQuantizer(bit_width) |
| 86 | + config = StretchedIntxWeightOnlyConfig( |
| 87 | + b=bit_width, |
| 88 | + quant_min=quantizer.quant_min, |
| 89 | + quant_max=quantizer.quant_max, |
| 90 | + granularity=granularity, |
| 91 | + ) |
| 92 | + |
| 93 | + parq_model = ToyLinearModel(128, 256, 128, 1).to(dtype) |
| 94 | + activations = parq_model.example_inputs(lead_dim=lead_dim, dtype=dtype) |
| 95 | + quantize_(parq_model, config) |
| 96 | + |
| 97 | + # Apply dynamic activation to parq model. This will serve as the LUT reference |
| 98 | + parq_model_with_dyn_quant = deepcopy(parq_model) |
| 99 | + quantize_( |
| 100 | + parq_model_with_dyn_quant, |
| 101 | + Int8DynamicActivationConfig(), |
| 102 | + # We have to explicitly provide filter_fn because the default linear filter |
| 103 | + # excludes modules with AffinQUnatizedTensor weights |
| 104 | + filter_fn=lambda m, fqn: isinstance(m, torch.nn.Linear), |
| 105 | + ) |
| 106 | + |
| 107 | + # Convert PARQ model to lowbit LUT model |
| 108 | + lut_model = deepcopy(parq_model) |
| 109 | + conversion_config = ( |
| 110 | + StretchedAffineQuantizedTensor_to_Int8DynamicActivationLutTensorConfig( |
| 111 | + config.b, config.granularity |
| 112 | + ) |
| 113 | + ) |
| 114 | + quantize_(lut_model, conversion_config, filter_fn=conversion_config.get_filter_fn()) |
| 115 | + |
| 116 | + # Run both models and compare |
| 117 | + parq_out = parq_model(activations) |
| 118 | + parq_with_dyn_quant_out = parq_model_with_dyn_quant(activations) |
| 119 | + lut_out = lut_model(activations) |
| 120 | + |
| 121 | + assert torch.allclose(parq_out, parq_with_dyn_quant_out, atol=1e-1, rtol=1e-1) |
| 122 | + if dtype == torch.float32: |
| 123 | + assert torch.allclose(lut_out, parq_with_dyn_quant_out, atol=1e-4, rtol=1e-4) |
| 124 | + elif dtype == torch.bfloat16: |
| 125 | + assert torch.allclose(lut_out, parq_with_dyn_quant_out, atol=1e-2, rtol=1e-2) |
| 126 | + else: |
| 127 | + raise ValueError(f"Unsupported dtype {dtype}") |
| 128 | + |
| 129 | + |
| 130 | +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) |
| 131 | +@pytest.mark.parametrize("granularity", [PerGroup(32), PerAxis(0)]) |
| 132 | +@pytest.mark.parametrize("bit_width", [1, 2, 3, 4]) |
| 133 | +@pytest.mark.parametrize("lead_dim", [(5,), (2, 3)]) |
| 134 | +@pytest.mark.skipif(not is_arm64_mac, reason="requires arm64 mac") |
| 135 | +def test_export(dtype, granularity, bit_width, lead_dim): |
| 136 | + quantizer = StretchedUnifTorchaoQuantizer(bit_width) |
| 137 | + config = StretchedIntxWeightOnlyConfig( |
| 138 | + b=bit_width, |
| 139 | + quant_min=quantizer.quant_min, |
| 140 | + quant_max=quantizer.quant_max, |
| 141 | + granularity=granularity, |
| 142 | + ) |
| 143 | + |
| 144 | + parq_model = ToyLinearModel(128, 256, 128, 8).to(dtype) |
| 145 | + activations = parq_model.example_inputs(lead_dim=lead_dim) |
| 146 | + quantize_(parq_model, config) |
| 147 | + |
| 148 | + conversion_config = ( |
| 149 | + StretchedAffineQuantizedTensor_to_Int8DynamicActivationLutTensorConfig( |
| 150 | + config.b, config.granularity |
| 151 | + ) |
| 152 | + ) |
| 153 | + quantize_( |
| 154 | + parq_model, conversion_config, filter_fn=conversion_config.get_filter_fn() |
| 155 | + ) |
| 156 | + |
| 157 | + ep = torch.export.export(parq_model, (activations,)) |
| 158 | + assert ( |
| 159 | + f"torch.ops.torchao._linear_8bit_act_{bit_width}bit_weight.default" |
| 160 | + in ep.graph_module.code |
| 161 | + ) |
0 commit comments