
Commit 0fa6302

Add codebook support to CoreML using quantize_
1 parent 4c1673b commit 0fa6302
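
In short, this teaches the Core ML backend to lower weights palettized with torchao's CodebookWeightOnlyConfig (applied through quantize_) by converting the resulting dequantize_codebook ops into Core ML constexpr LUT ops. A minimal sketch of the flow, based on the tests added below; the toy Linear model and input shapes are illustrative and not part of this commit:

import torch
from torchao.prototype.quantization.codebook_coreml import CodebookWeightOnlyConfig
from torchao.quantization import quantize_

# Hypothetical stand-in for the test's _get_test_model(); shapes are made up.
model = torch.nn.Sequential(torch.nn.Linear(64, 32))
example_inputs = (torch.randn(2, 64),)

# 2-bit codebook weights, one look-up table per group of 16 input columns.
quantize_(model, CodebookWeightOnlyConfig(dtype=torch.uint2, block_size=[-1, 16]))

ep = torch.export.export(model, example_inputs)
# ep is then lowered with executorch.exir.to_edge_transform_and_lower(...) and
# CoreMLPartitioner, as in backends/apple/coreml/test/test_torch_ops.py.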

2 files changed, +45 -172 lines


backends/apple/coreml/compiler/torch_ops.py

Lines changed: 13 additions & 161 deletions
@@ -8,149 +8,7 @@
 # coremltools than is used by ExecuTorch. Each op registered here should have a link to a PR in coremltools that adds
 # the op to the coremltools library.
 
-import torch as _torch
-from coremltools import _logger
-from coremltools.converters.mil.frontend import _utils
-from coremltools.converters.mil.frontend.torch.ops import (
-    _get_inputs,
-    _get_kwinputs,
-    NUM_TO_NUMPY_DTYPE,
-    NUM_TO_TORCH_DTYPE,
-    split,
-    to,
-    transpose,
-    unbind,
-)
 import numpy as np
-from coremltools.converters.mil.frontend.torch.torch_op_registry import (
-    register_torch_op,
-)
-from coremltools.converters.mil.mil import types
-from executorch.exir.dim_order_utils import get_memory_format
-
-
-# https://github.com/apple/coremltools/pull/2556
-@register_torch_op(override=False)
-def transpose_copy(context, node):
-    transpose(context, node)
-
-
-# https://github.com/apple/coremltools/pull/2557
-@register_torch_op(override=False)
-def unbind_copy(context, node):
-    unbind(context, node)
-
-
-# https://github.com/apple/coremltools/pull/2563
-@register_torch_op(override=False)
-def split_copy(context, node):
-    split(context, node)
-
-
-@register_torch_op(
-    torch_alias=[
-        "dim_order_ops::_to_dim_order_copy",
-        "dim_order_ops._to_dim_order_copy",
-    ],
-    override=False,
-)
-def _to_dim_order_copy(context, node):
-    dim_order = _get_kwinputs(context, node, "dim_order", default=[None])[0]
-    node.kwinputs.pop("dim_order")
-
-    # In CoreML, dim_order.val will be an ndarray, so we convert it to a list
-    dim_order = [int(d) for d in dim_order.val]
-    memory_format = get_memory_format(dim_order)
-    assert (
-        memory_format == _torch.contiguous_format
-    ), "Only contiguous memory format is supported in CoreML"
-    to(context, node)
-
-
-# https://github.com/apple/coremltools/pull/2558
-@register_torch_op(
-    torch_alias=["torchao::dequantize_affine", "torchao.dequantize_affine"],
-    override=False,
-)
-def dequantize_affine(context, node):
-    inputs = _get_inputs(context, node, expected=[7, 8])
-    int_data = inputs[0].val
-    block_size = inputs[1].val
-    scale = inputs[2].val
-    zero_point = (
-        inputs[3].val if inputs[3] is not None and inputs[3].val is not None else None
-    )
-    # I do not think we need to worry about input_dtype b/c it gets cast to int4/int8
-    # For now, we just check that it is int8 or int32
-    input_dtype = inputs[4].val  # noqa: F841
-    assert NUM_TO_TORCH_DTYPE[input_dtype] in [
-        _torch.int8,
-        _torch.int32,
-    ], "input_dtype should be int8 or int32"
-
-    quant_min = inputs[5].val
-    quant_max = inputs[6].val
-
-    assert len(int_data.shape) == 2, "dequantize_affine only supports rank 2 inputs"
-
-    assert len(int_data.shape) == len(
-        block_size
-    ), "block_size must have the same length as int_data.shape"
-    assert block_size[0] == 1, "block_size[0] must be 1"
-    group_size = block_size[1]
-    k = int_data.shape[1]
-    assert k % group_size == 0, "k must be divisible by group_size"
-    scales_per_row = k // group_size
-    scale = scale.reshape(-1, scales_per_row)
-    if zero_point is not None:
-        zero_point = zero_point.reshape(-1, scales_per_row)
-
-    # TODO: I don't know if CoreML can make use of this
-    # We could add a cast op to the output, but I'm pretty CoreML will remove this during a later pass
-    # For now, we just log a warning
-    out_np_dtype = None
-    if len(inputs) > 7:
-        out_np_dtype = NUM_TO_NUMPY_DTYPE[inputs[7].val]
-        _logger.warning(
-            f"Core ML ignores output_dtype {out_np_dtype} on torchao.dequantize_affine and instead uses the native precision."
-        )
-
-    if quant_min == -8 and quant_max == 7:
-        quantized_np_dtype = types.nptype_from_builtin(types.string_to_builtin("int4"))
-    elif quant_min == -128 and quant_max == 127:
-        quantized_np_dtype = types.nptype_from_builtin(types.string_to_builtin("int8"))
-    else:
-        raise ValueError(
-            f"Unsupported quantization range: {quant_min} to {quant_max}. CoreML only supports 4-bit and 8-bit quantization."
-        )
-
-    output = _utils._construct_constexpr_dequant_op(
-        int_data.astype(quantized_np_dtype),
-        zero_point,
-        scale,
-        axis=-1,
-        name=node.name,
-    )
-    context.add(output, node.name)
-
-
-
-# codes: torch.Tensor,
-# codebook: torch.Tensor,
-# code_dtype: torch.dtype,
-# block_size: List[int],
-# output_dtype: torch.dtype = torch.float32,
-
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-# This file registers torch ops that are not yet in coremltools, or are in a more recent version of
-# coremltools than is used by ExecuTorch. Each op registered here should have a link to a PR in coremltools that adds
-# the op to the coremltools library.
-
 import torch as _torch
 from coremltools import _logger
 from coremltools.converters.mil.frontend import _utils
@@ -164,8 +22,6 @@ def dequantize_affine(context, node):
     transpose,
     unbind,
 )
-import numpy as np
-
 from coremltools.converters.mil.frontend.torch.torch_op_registry import (
     register_torch_op,
 )
@@ -283,24 +139,25 @@ def dequantize_affine(context, node):
     override=False,
 )
 def dequantize_codebook(context, node):
-    print("IN DEQUANTIZE CODEBOOK")
     inputs = _get_inputs(context, node, expected=[4, 5])
     codes = inputs[0].val
     codebook = inputs[1].val
-    code_dtype = inputs[2].val
-    block_size = inputs[3].val
+    nbits = inputs[2].val
 
-
-    assert len(codes.shape) == 2, "Only rank 2 inputs are supported"
+    # information in block_size is redundant with codebook.shape
+    block_size = inputs[3].val  # noqa: F841
 
-    # In TorchAO, the codebook shape is (n_lut, nbit, 1). The LUTs are for the columns.
-    # In CoreML, the expected shape is (lut_block_size, nbit, 1). 1 here is for scalar
-    # lut_block_size has the same rank as codes/idxs and tells you how many LUTs there are per block, e.g.,
-    # lut_block_size=(1, 8) means there is 1 LUT per 8 columns
-    assert len(codebook.shape) == 3, "Only rank 3 inputs are supported for codebook"
-    assert codebook.shape[-1] == 1, "we only support scalar palletization"
-    codebook = np.expand_dims(codebook, 0)
+    assert len(codes.shape) == 2, "Only rank 2 inputs are supported"
 
+    # Assert codebook is as expected. codebook.dim() = codes.dim() + 2
+    assert len(codebook.shape) == 4, "Only rank 4 inputs are supported for codebook"
+    assert codebook.shape[0] == 1, "Only grouped_channel granularity is supported"
+    n_luts = codebook.shape[1]
+    assert (
+        codes.shape[1] % n_luts == 0
+    ), "codes.shape[1] must be divisible by codebook.shape[1]"
+    assert codebook.shape[2] == 2**nbits
+    assert codebook.shape[3] == 1, "Only scalar look up values are supported"
 
     if len(inputs) > 4:
         output_dtype = inputs[4].val
@@ -309,11 +166,6 @@ def dequantize_codebook(context, node):
             f"Core ML ignores output_dtype {out_np_dtype} on torchao.dequantize_affine and instead uses the native precision."
         )
 
-    print("codes", codes.shape)
-    print("codebook", codebook.shape)
-    print("code_dtype", code_dtype)
-    print("block_size", block_size)
-
     output = _utils._construct_constexpr_lut_op(
         codes.astype(np.int8),
         codebook,
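
For intuition on the shape checks added to dequantize_codebook above: codes stay rank 2, while the codebook is now expected as a rank-4 tensor of shape (1, n_luts, 2**nbits, 1), i.e. one scalar look-up table per group of columns. A hypothetical example that would pass the new asserts (all sizes made up):

import numpy as np

nbits = 2                                  # e.g. codes produced with torch.uint2
codes = np.zeros((32, 64), dtype=np.int8)  # rank-2 indices: 32 rows x 64 columns
codebook = np.zeros((1, 4, 2**nbits, 1), dtype=np.float32)  # 4 LUTs, 4 entries each

assert len(codes.shape) == 2
assert len(codebook.shape) == 4 and codebook.shape[0] == 1
n_luts = codebook.shape[1]                 # one LUT per 64 / 4 = 16 columns
assert codes.shape[1] % n_luts == 0
assert codebook.shape[2] == 2**nbits and codebook.shape[3] == 1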

backends/apple/coreml/test/test_torch_ops.py

Lines changed: 32 additions & 11 deletions
@@ -14,11 +14,9 @@
 
 from executorch.backends.apple.coreml.compiler import CoreMLBackend
 from executorch.backends.apple.coreml.partition import CoreMLPartitioner
-from torchao.quantization import IntxWeightOnlyConfig, PerAxis, PerGroup, quantize_
 
-from torchao.prototype.quantization.codebook_coreml import (
-    CodebookWeightOnlyConfig,
-)
+from torchao.prototype.quantization.codebook_coreml import CodebookWeightOnlyConfig
+from torchao.quantization import IntxWeightOnlyConfig, PerAxis, PerGroup, quantize_
 
 
 def is_fbcode():
@@ -167,12 +165,12 @@ def test_dequantize_affine_c8w_embedding_b4w_linear(self):
                 ], f"Got unexpected node target after delegation: {node.target.__name__}"
         et_prog = delegated_program.to_executorch()
         self._compare_outputs(et_prog, model, example_inputs)
- 
+
     def test_dequantize_codebook_linear(self):
         model, example_inputs = self._get_test_model()
         quantize_(
             model,
-            CodebookWeightOnlyConfig(dtype=torch.uint8, block_size=[-1, 16]),
+            CodebookWeightOnlyConfig(dtype=torch.uint2, block_size=[-1, 16]),
         )
         ep = torch.export.export(model, example_inputs)
         print("ORIGINAL MODEL", ep)
@@ -190,12 +188,35 @@ def test_dequantize_codebook_linear(self):
         print(et_prog.exported_program())
         self._compare_outputs(et_prog, model, example_inputs)
 
+    def test_dequantize_codebook_embedding(self):
+        model, example_inputs = self._get_test_model()
+        quantize_(
+            model,
+            CodebookWeightOnlyConfig(dtype=torch.uint3, block_size=[-1, 16]),
+            lambda m, fqn: isinstance(m, torch.nn.Embedding),
+        )
+        ep = torch.export.export(model, example_inputs)
+        delegated_program = executorch.exir.to_edge_transform_and_lower(
+            ep,
+            partitioner=[self._coreml_partitioner()],
+        )
+        for node in delegated_program.exported_program().graph.nodes:
+            if node.op == "call_function":
+                assert node.target.__name__ in [
+                    "executorch_call_delegate",
+                    "getitem",
+                ], f"Got unexpected node target after delegation: {node.target.__name__}"
+        et_prog = delegated_program.to_executorch()
+        print(et_prog.exported_program())
+        self._compare_outputs(et_prog, model, example_inputs)
+
 
 if __name__ == "__main__":
     test_runner = TestTorchOps()
-    # test_runner.test_dequantize_affine_b4w_embedding()
-    # test_runner.test_dequantize_affine_b4w_linear()
-    # test_runner.test_dequantize_affine_c4w_embedding()
-    # test_runner.test_dequantize_affine_c4w_linear()
-    # test_runner.test_dequantize_affine_c8w_embedding_b4w_linear()
+    test_runner.test_dequantize_affine_b4w_embedding()
+    test_runner.test_dequantize_affine_b4w_linear()
+    test_runner.test_dequantize_affine_c4w_embedding()
+    test_runner.test_dequantize_affine_c4w_linear()
+    test_runner.test_dequantize_affine_c8w_embedding_b4w_linear()
     test_runner.test_dequantize_codebook_linear()
+    test_runner.test_dequantize_codebook_embedding()
