
Commit b6ef500

Integrate PARQ with lowbit Arm CPU kernels (#2622)
* Integrate PARQ with lowbit Arm CPU kernels
* up
1 parent 3515cb6 commit b6ef500

File tree

7 files changed: +493 -4 lines changed
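
This commit connects PARQ-style quantization (StretchedAffineQuantizedTensor weights) to torchao's lowbit dynamic-activation LUT kernels for Arm CPUs. Below is a minimal sketch of the end-to-end flow, distilled from the new test added in this commit; the Sequential model, bit width, and group size are illustrative placeholders rather than part of the commit itself.

# Minimal sketch of the PARQ -> lowbit LUT conversion flow exercised by the new
# test below; module shape, bit width, and group size are illustrative only.
import torch
from torchao.prototype.parq.quant import StretchedUnifTorchaoQuantizer
from torchao.prototype.parq.quant.quant_api import StretchedIntxWeightOnlyConfig
from torchao.prototype.quantization.dynamic_activation_lut import (
    StretchedAffineQuantizedTensor_to_Int8DynamicActivationLutTensorConfig,
)
from torchao.quantization import quantize_
from torchao.quantization.granularity import PerGroup

bit_width = 4
quantizer = StretchedUnifTorchaoQuantizer(bit_width)

model = torch.nn.Sequential(torch.nn.Linear(128, 256, bias=False))

# 1) PARQ-style weight-only quantization (produces StretchedAffineQuantizedTensor weights)
quantize_(
    model,
    StretchedIntxWeightOnlyConfig(
        b=bit_width,
        quant_min=quantizer.quant_min,
        quant_max=quantizer.quant_max,
        granularity=PerGroup(32),
    ),
)

# 2) Convert to the int8-dynamic-activation LUT representation backed by the
#    lowbit Arm CPU kernels
conversion_config = (
    StretchedAffineQuantizedTensor_to_Int8DynamicActivationLutTensorConfig(
        bit_width, PerGroup(32)
    )
)
quantize_(model, conversion_config, filter_fn=conversion_config.get_filter_fn())

After conversion, the linear layers dispatch to the torch.ops.torchao._linear_8bit_act_{b}bit_weight operators on arm64 macOS, which the export test below asserts.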

.github/workflows/torchao_experimental_test.yml

Lines changed: 2 additions & 1 deletion

@@ -53,6 +53,7 @@ jobs:
           pytest torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py
           python torchao/experimental/tests/test_embedding_xbit_quantizer.py
           python torchao/experimental/tests/test_quant_passes.py
+          pytest -s test/prototype/test_dynamic_activation_lut.py
       - name: Run kernels/cpu/aarch64/tests
         if: runner.os == 'macOS'
         run: |

@@ -106,7 +107,7 @@ jobs:
       #     conda run -n test-mps-ops-env pip install torch --index-url "https://download.pytorch.org/whl/nightly/cpu"
       # - name: Print torch version
       #   run: |
-
+
       #     conda run -n test-mps-ops-env python -c "import torch; print(torch.__version__)"
       # - name: Install requirements
       #   run: |
test/prototype/test_dynamic_activation_lut.py

Lines changed: 161 additions & 0 deletions (new file)

@@ -0,0 +1,161 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import platform
import sys
from copy import deepcopy
from dataclasses import dataclass

import pytest
import torch
import torch.nn as nn

from torchao.core.config import AOBaseConfig
from torchao.prototype.parq.quant import StretchedUnifTorchaoQuantizer
from torchao.prototype.parq.quant.quant_api import StretchedIntxWeightOnlyConfig
from torchao.prototype.quantization.dynamic_activation_lut import (
    StretchedAffineQuantizedTensor_to_Int8DynamicActivationLutTensorConfig,
)
from torchao.quantization import quantize_
from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.linear_activation_quantized_tensor import (
    to_linear_activation_quantized,
)
from torchao.quantization.quant_api import (
    _int8_asymm_per_token_quant,
)
from torchao.quantization.transform_module import register_quantize_module_handler

is_arm64_mac = sys.platform == "darwin" and platform.machine() == "arm64"


@dataclass
class Int8DynamicActivationConfig(AOBaseConfig):
    pass


@register_quantize_module_handler(Int8DynamicActivationConfig)
def _int8_dynamic_activation_transform(
    module: nn.Module, config: Int8DynamicActivationConfig
) -> nn.Module:
    weight = module.weight
    weight = to_linear_activation_quantized(weight, _int8_asymm_per_token_quant)
    module.weight = torch.nn.Parameter(weight, requires_grad=False)
    return module


class ToyLinearModel(torch.nn.Module):
    def __init__(self, d1=512, d2=256, d3=128, d4=8):
        super().__init__()
        self.linear1 = torch.nn.Linear(d1, d2, bias=False)
        self.linear2 = torch.nn.Linear(d2, d3, bias=True)
        self.linear3 = torch.nn.Linear(d3, d4, bias=False)

    def example_inputs(
        self,
        lead_dim=(1,),
        dtype=torch.bfloat16,
    ):
        return torch.randn(
            *lead_dim, self.linear1.in_features, dtype=dtype, device="cpu"
        )

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        x = self.linear3(x)
        return x


@pytest.fixture(autouse=True)
def run_before_and_after_tests():
    yield
    torch._dynamo.reset()  # reset cache between tests


@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("granularity", [PerGroup(32), PerAxis(0)])
@pytest.mark.parametrize("bit_width", [1, 2, 3, 4])
@pytest.mark.parametrize("lead_dim", [(5,), (2, 3)])
@pytest.mark.skipif(not is_arm64_mac, reason="requires arm64 mac")
def test_parq_conversion(dtype, granularity, bit_width, lead_dim):
    quantizer = StretchedUnifTorchaoQuantizer(bit_width)
    config = StretchedIntxWeightOnlyConfig(
        b=bit_width,
        quant_min=quantizer.quant_min,
        quant_max=quantizer.quant_max,
        granularity=granularity,
    )

    parq_model = ToyLinearModel(128, 256, 128, 1).to(dtype)
    activations = parq_model.example_inputs(lead_dim=lead_dim, dtype=dtype)
    quantize_(parq_model, config)

    # Apply int8 dynamic activation quantization to the PARQ model.
    # This will serve as the LUT reference.
    parq_model_with_dyn_quant = deepcopy(parq_model)
    quantize_(
        parq_model_with_dyn_quant,
        Int8DynamicActivationConfig(),
        # We have to explicitly provide filter_fn because the default linear filter
        # excludes modules with AffineQuantizedTensor weights
        filter_fn=lambda m, fqn: isinstance(m, torch.nn.Linear),
    )

    # Convert the PARQ model to a lowbit LUT model
    lut_model = deepcopy(parq_model)
    conversion_config = (
        StretchedAffineQuantizedTensor_to_Int8DynamicActivationLutTensorConfig(
            config.b, config.granularity
        )
    )
    quantize_(lut_model, conversion_config, filter_fn=conversion_config.get_filter_fn())

    # Run both models and compare
    parq_out = parq_model(activations)
    parq_with_dyn_quant_out = parq_model_with_dyn_quant(activations)
    lut_out = lut_model(activations)

    assert torch.allclose(parq_out, parq_with_dyn_quant_out, atol=1e-1, rtol=1e-1)
    if dtype == torch.float32:
        assert torch.allclose(lut_out, parq_with_dyn_quant_out, atol=1e-4, rtol=1e-4)
    elif dtype == torch.bfloat16:
        assert torch.allclose(lut_out, parq_with_dyn_quant_out, atol=1e-2, rtol=1e-2)
    else:
        raise ValueError(f"Unsupported dtype {dtype}")


@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("granularity", [PerGroup(32), PerAxis(0)])
@pytest.mark.parametrize("bit_width", [1, 2, 3, 4])
@pytest.mark.parametrize("lead_dim", [(5,), (2, 3)])
@pytest.mark.skipif(not is_arm64_mac, reason="requires arm64 mac")
def test_export(dtype, granularity, bit_width, lead_dim):
    quantizer = StretchedUnifTorchaoQuantizer(bit_width)
    config = StretchedIntxWeightOnlyConfig(
        b=bit_width,
        quant_min=quantizer.quant_min,
        quant_max=quantizer.quant_max,
        granularity=granularity,
    )

    parq_model = ToyLinearModel(128, 256, 128, 8).to(dtype)
    activations = parq_model.example_inputs(lead_dim=lead_dim)
    quantize_(parq_model, config)

    conversion_config = (
        StretchedAffineQuantizedTensor_to_Int8DynamicActivationLutTensorConfig(
            config.b, config.granularity
        )
    )
    quantize_(
        parq_model, conversion_config, filter_fn=conversion_config.get_filter_fn()
    )

    ep = torch.export.export(parq_model, (activations,))
    assert (
        f"torch.ops.torchao._linear_8bit_act_{bit_width}bit_weight.default"
        in ep.graph_module.code
    )

torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h

Lines changed: 1 addition & 1 deletion

@@ -184,7 +184,7 @@ void register_ukernel_config_lut(
   namespace kernel = torchao::kernels::cpu::aarch64::linear::
       channelwise_8bit_activation_groupwise_lowbit_weight;

-  if (cpuinfo_has_arm_neon_dot()) {
+  if (!cpuinfo_has_arm_neon_dot()) {
     return;
   }
   if (format.has_weight_zeros) {

torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h

Lines changed: 3 additions & 2 deletions

@@ -251,6 +251,7 @@ Tensor pack_weights_with_lut_cpu(
       "weight_scales must be float32");
   TORCHAO_CHECK(weight_scales.dim() == 1, "weight_scales must be 1D");
   TORCHAO_CHECK(group_size >= 1, "group_size must be >= 1");
+  TORCHAO_CHECK(group_size % 16 == 0, "group_size must be a multiple of 16");
   TORCHAO_CHECK(
       weight_scales.size(0) == ((n * k) / group_size),
       "expected 1 scale per group");

@@ -285,8 +286,8 @@ Tensor pack_weights_with_lut_cpu(
       weight_nbit>(target, has_weight_zeros, has_bias);
   TORCHAO_CHECK(packed_weights_format.nr == 8, "nr must be 8");
   TORCHAO_CHECK(
-      lut_channel_group_size % 8 == 0,
-      "the lut_channel_group_size must be a multiple of nr (8)");
+      lut_channel_group_size == n || lut_channel_group_size % 8 == 0,
+      "the lut_channel_group_size must be n or a multiple of nr (8)");

   auto packed_weights_header = packed_weights_format.to_packed_weights_header();
   auto uk = torchao::ops::linear_8bit_act_xbit_weight::select_ukernel_config<
torchao/prototype/quantization/dynamic_activation_lut/__init__.py

Lines changed: 7 additions & 0 deletions (new file)

@@ -0,0 +1,7 @@
from .api import StretchedAffineQuantizedTensor_to_Int8DynamicActivationLutTensorConfig
from .int8_dynamic_activation_lut_tensor import Int8DynamicActivationLutTensor

__all__ = [
    "StretchedAffineQuantizedTensor_to_Int8DynamicActivationLutTensorConfig",
    "Int8DynamicActivationLutTensor",
]
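
Both public entry points can be imported straight from the new subpackage; a minimal import line (the test above pulls in the config this way) would look like:

from torchao.prototype.quantization.dynamic_activation_lut import (
    Int8DynamicActivationLutTensor,
    StretchedAffineQuantizedTensor_to_Int8DynamicActivationLutTensorConfig,
)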
torchao/prototype/quantization/dynamic_activation_lut/api.py

Lines changed: 83 additions & 0 deletions (new file)

@@ -0,0 +1,83 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Callable

import torch
import torch.nn as nn

from torchao.core.config import AOBaseConfig
from torchao.prototype.parq.quant.quant_api import StretchedAffineQuantizedTensor
from torchao.prototype.quantization.dynamic_activation_lut.int8_dynamic_activation_lut_tensor import (
    Int8DynamicActivationLutTensor,
)
from torchao.quantization.granularity import Granularity, PerAxis, PerGroup
from torchao.quantization.quant_primitives import _DTYPE_TO_QVALUE_BOUNDS
from torchao.quantization.transform_module import register_quantize_module_handler


@dataclass
class StretchedAffineQuantizedTensor_to_Int8DynamicActivationLutTensorConfig(
    AOBaseConfig
):
    bit_width: int
    granularity: Granularity

    def get_filter_fn(self) -> Callable[[nn.Module, str], bool]:
        return lambda m, fqn: isinstance(m, torch.nn.Linear) and isinstance(
            m.weight, StretchedAffineQuantizedTensor
        )


@register_quantize_module_handler(
    StretchedAffineQuantizedTensor_to_Int8DynamicActivationLutTensorConfig
)
def _(
    module: nn.Module,
    config: StretchedAffineQuantizedTensor_to_Int8DynamicActivationLutTensorConfig,
) -> nn.Module:
    weight = module.weight
    bias = module.bias
    assert isinstance(weight, StretchedAffineQuantizedTensor)

    b = config.bit_width
    granularity = config.granularity
    if isinstance(granularity, PerGroup):
        group_size = granularity.group_size
    elif isinstance(granularity, PerAxis):
        assert granularity.axis == 0, (
            f"axis must be 0 with PerAxis, but got {granularity.axis}"
        )
        group_size = weight.shape[-1]
    else:
        raise ValueError(f"granularity must be PerGroup or PerAxis, got {granularity}")

    int_data, scale, zero_point = weight.tensor_impl.get_plain()
    q_min, q_max = _DTYPE_TO_QVALUE_BOUNDS[getattr(torch, f"int{b}")]

    # Construct LUT as 2 * ([q_min, q_max] - zero_point), i.e. 2 * q + 1 since zero_point == -0.5
    assert torch.all(zero_point == -0.5)
    lut = torch.arange(q_min, q_max + 1)
    lut = 2 * lut + 1

    # Construct idx values
    qval_idx = int_data - q_min

    # Construct scale
    scale = scale.reshape(-1).to(torch.float32)
    scale = 0.5 * scale  # since we multiply LUT values by 2

    weight_tensor = Int8DynamicActivationLutTensor.from_plain(
        qval_idx,
        lut,
        scale,
        group_size,
        bias.to(torch.float32) if bias is not None else None,
    )
    module.weight = torch.nn.Parameter(weight_tensor, requires_grad=False)
    module.bias = None
    return module
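
The transform above folds the stretched affine dequantization (q - zero_point) * s, with zero_point fixed at -0.5, into an integer LUT of odd values 2 * q + 1 plus a halved per-group scale. Below is a quick standalone check of that identity; it is a sketch, and the int3 bounds of [-4, 3] are an assumption matching _DTYPE_TO_QVALUE_BOUNDS for b=3.

# Illustrative sanity check that the LUT construction preserves dequantized values:
# (q - zero_point) * s == lut[q - q_min] * (0.5 * s)
# when zero_point == -0.5 and lut[i] == 2 * (q_min + i) + 1.
import torch

q_min, q_max = -4, 3   # assumed int3 bounds for b = 3
zero_point = -0.5
s = 0.017              # arbitrary per-group scale

q = torch.arange(q_min, q_max + 1)
reference = (q - zero_point) * s              # stretched affine dequantization
lut = 2 * torch.arange(q_min, q_max + 1) + 1  # [-7, -5, ..., 5, 7]
via_lut = lut[q - q_min] * (0.5 * s)          # LUT entry times the halved scale

assert torch.allclose(reference, via_lut)

Doubling keeps the LUT entries integral (odd integers), while the halved per-group scale restores the original magnitudes.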
