
Commit 4c1673b

Add palettization support
1 parent d4c78ab commit 4c1673b

2 files changed: +220 -5 lines changed


backends/apple/coreml/compiler/torch_ops.py

Lines changed: 188 additions & 0 deletions
@@ -21,6 +21,150 @@
    transpose,
    unbind,
)
+import numpy as np
+from coremltools.converters.mil.frontend.torch.torch_op_registry import (
+    register_torch_op,
+)
+from coremltools.converters.mil.mil import types
+from executorch.exir.dim_order_utils import get_memory_format
+
+
+# https://github.com/apple/coremltools/pull/2556
+@register_torch_op(override=False)
+def transpose_copy(context, node):
+    transpose(context, node)
+
+
+# https://github.com/apple/coremltools/pull/2557
+@register_torch_op(override=False)
+def unbind_copy(context, node):
+    unbind(context, node)
+
+
+# https://github.com/apple/coremltools/pull/2563
+@register_torch_op(override=False)
+def split_copy(context, node):
+    split(context, node)
+
+
+@register_torch_op(
+    torch_alias=[
+        "dim_order_ops::_to_dim_order_copy",
+        "dim_order_ops._to_dim_order_copy",
+    ],
+    override=False,
+)
+def _to_dim_order_copy(context, node):
+    dim_order = _get_kwinputs(context, node, "dim_order", default=[None])[0]
+    node.kwinputs.pop("dim_order")
+
+    # In CoreML, dim_order.val will be an ndarray, so we convert it to a list
+    dim_order = [int(d) for d in dim_order.val]
+    memory_format = get_memory_format(dim_order)
+    assert (
+        memory_format == _torch.contiguous_format
+    ), "Only contiguous memory format is supported in CoreML"
+    to(context, node)
+
+
+# https://github.com/apple/coremltools/pull/2558
+@register_torch_op(
+    torch_alias=["torchao::dequantize_affine", "torchao.dequantize_affine"],
+    override=False,
+)
+def dequantize_affine(context, node):
+    inputs = _get_inputs(context, node, expected=[7, 8])
+    int_data = inputs[0].val
+    block_size = inputs[1].val
+    scale = inputs[2].val
+    zero_point = (
+        inputs[3].val if inputs[3] is not None and inputs[3].val is not None else None
+    )
+    # I do not think we need to worry about input_dtype b/c it gets cast to int4/int8
+    # For now, we just check that it is int8 or int32
+    input_dtype = inputs[4].val  # noqa: F841
+    assert NUM_TO_TORCH_DTYPE[input_dtype] in [
+        _torch.int8,
+        _torch.int32,
+    ], "input_dtype should be int8 or int32"
+
+    quant_min = inputs[5].val
+    quant_max = inputs[6].val
+
+    assert len(int_data.shape) == 2, "dequantize_affine only supports rank 2 inputs"
+
+    assert len(int_data.shape) == len(
+        block_size
+    ), "block_size must have the same length as int_data.shape"
+    assert block_size[0] == 1, "block_size[0] must be 1"
+    group_size = block_size[1]
+    k = int_data.shape[1]
+    assert k % group_size == 0, "k must be divisible by group_size"
+    scales_per_row = k // group_size
+    scale = scale.reshape(-1, scales_per_row)
+    if zero_point is not None:
+        zero_point = zero_point.reshape(-1, scales_per_row)
+
+    # TODO: I don't know if CoreML can make use of this
+    # We could add a cast op to the output, but I'm pretty sure CoreML will remove this during a later pass
+    # For now, we just log a warning
+    out_np_dtype = None
+    if len(inputs) > 7:
+        out_np_dtype = NUM_TO_NUMPY_DTYPE[inputs[7].val]
+        _logger.warning(
+            f"Core ML ignores output_dtype {out_np_dtype} on torchao.dequantize_affine and instead uses the native precision."
+        )
+
+    if quant_min == -8 and quant_max == 7:
+        quantized_np_dtype = types.nptype_from_builtin(types.string_to_builtin("int4"))
+    elif quant_min == -128 and quant_max == 127:
+        quantized_np_dtype = types.nptype_from_builtin(types.string_to_builtin("int8"))
+    else:
+        raise ValueError(
+            f"Unsupported quantization range: {quant_min} to {quant_max}. CoreML only supports 4-bit and 8-bit quantization."
+        )
+
+    output = _utils._construct_constexpr_dequant_op(
+        int_data.astype(quantized_np_dtype),
+        zero_point,
+        scale,
+        axis=-1,
+        name=node.name,
+    )
+    context.add(output, node.name)
+
+
+
+# codes: torch.Tensor,
+# codebook: torch.Tensor,
+# code_dtype: torch.dtype,
+# block_size: List[int],
+# output_dtype: torch.dtype = torch.float32,
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This file registers torch ops that are not yet in coremltools, or are in a more recent version of
+# coremltools than is used by ExecuTorch. Each op registered here should have a link to a PR in coremltools that adds
+# the op to the coremltools library.
+
+import torch as _torch
+from coremltools import _logger
+from coremltools.converters.mil.frontend import _utils
+from coremltools.converters.mil.frontend.torch.ops import (
+    _get_inputs,
+    _get_kwinputs,
+    NUM_TO_NUMPY_DTYPE,
+    NUM_TO_TORCH_DTYPE,
+    split,
+    to,
+    transpose,
+    unbind,
+)
+import numpy as np

from coremltools.converters.mil.frontend.torch.torch_op_registry import (
    register_torch_op,
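For context on the grouping logic in dequantize_affine: with a rank-2 weight and block_size == (1, group_size), each row carries k // group_size scales, which is why scale is reshaped to (-1, scales_per_row) before being handed to the constexpr dequant op. A minimal numpy sketch, not part of this commit and using made-up shapes (zero_point omitted for brevity), of the dequantization this encodes:

# Illustrative sketch only: how per-group scales map onto a rank-2 int weight.
import numpy as np

n, k, group_size = 4, 8, 4                 # block_size would be (1, group_size)
int_data = np.random.randint(-8, 8, size=(n, k)).astype(np.int8)  # 4-bit range [-8, 7]
scales_per_row = k // group_size           # number of column groups per row
scale = np.linspace(0.01, 0.08, n * scales_per_row).astype(np.float32)
scale = scale.reshape(-1, scales_per_row)  # (n, scales_per_row), same reshape as above

# Each group's scale is broadcast across its group_size columns of the row
dequant = (int_data.reshape(n, scales_per_row, group_size) * scale[..., None]).reshape(n, k)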
@@ -132,3 +276,47 @@ def dequantize_affine(context, node):
        name=node.name,
    )
    context.add(output, node.name)
+
+
+@register_torch_op(
+    torch_alias=["quant::dequantize_codebook", "quant.dequantize_codebook"],
+    override=False,
+)
+def dequantize_codebook(context, node):
+    print("IN DEQUANTIZE CODEBOOK")
+    inputs = _get_inputs(context, node, expected=[4, 5])
+    codes = inputs[0].val
+    codebook = inputs[1].val
+    code_dtype = inputs[2].val
+    block_size = inputs[3].val
+
+
+    assert len(codes.shape) == 2, "Only rank 2 inputs are supported"
+
+    # In TorchAO, the codebook shape is (n_lut, nbit, 1). The LUTs are for the columns.
+    # In CoreML, the expected shape is (lut_block_size, nbit, 1). 1 here is for scalar palettization.
+    # lut_block_size has the same rank as codes/idxs and tells you how many LUTs there are per block, e.g.,
+    # lut_block_size=(1, 8) means there is 1 LUT per 8 columns
+    assert len(codebook.shape) == 3, "Only rank 3 inputs are supported for codebook"
+    assert codebook.shape[-1] == 1, "we only support scalar palettization"
+    codebook = np.expand_dims(codebook, 0)
+
+
+    if len(inputs) > 4:
+        output_dtype = inputs[4].val
+        out_np_dtype = NUM_TO_NUMPY_DTYPE[output_dtype]
+        _logger.warning(
+            f"Core ML ignores output_dtype {out_np_dtype} on quant.dequantize_codebook and instead uses the native precision."
+        )
+
+    print("codes", codes.shape)
+    print("codebook", codebook.shape)
+    print("code_dtype", code_dtype)
+    print("block_size", block_size)
+
+    output = _utils._construct_constexpr_lut_op(
+        codes.astype(np.int8),
+        codebook,
+        name=node.name,
+    )
+    context.add(output, node.name)
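To make the codebook handling above concrete: TorchAO emits a rank-3 LUT, the np.expand_dims call adds the extra leading block dimension before the codebook is passed to _construct_constexpr_lut_op, and dequantization is then a plain table lookup. A minimal numpy sketch, not part of this commit, with assumed shapes (a single LUT shared by all columns, 4-bit codes):

# Illustrative sketch only: scalar palettization as a lookup-table dequantization.
import numpy as np

n, k, nbit = 4, 16, 4
codes = np.random.randint(0, 2**nbit, size=(n, k)).astype(np.int8)   # palettized indices
codebook = np.random.randn(1, 2**nbit, 1).astype(np.float32)         # rank-3 LUT from TorchAO

codebook = np.expand_dims(codebook, 0)   # add the leading block dim, as in the op above
# Each code simply selects its entry from the (single) LUT
dequant = codebook[0, 0, codes, 0]       # (n, k) float weight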

backends/apple/coreml/test/test_torch_ops.py

Lines changed: 32 additions & 5 deletions
@@ -16,6 +16,10 @@
from executorch.backends.apple.coreml.partition import CoreMLPartitioner
from torchao.quantization import IntxWeightOnlyConfig, PerAxis, PerGroup, quantize_

+from torchao.prototype.quantization.codebook_coreml import (
+    CodebookWeightOnlyConfig,
+)
+

def is_fbcode():
    return not hasattr(torch.version, "git_version")
@@ -163,12 +167,35 @@ def test_dequantize_affine_c8w_embedding_b4w_linear(self):
                ], f"Got unexpected node target after delegation: {node.target.__name__}"
        et_prog = delegated_program.to_executorch()
        self._compare_outputs(et_prog, model, example_inputs)
+
+    def test_dequantize_codebook_linear(self):
+        model, example_inputs = self._get_test_model()
+        quantize_(
+            model,
+            CodebookWeightOnlyConfig(dtype=torch.uint8, block_size=[-1, 16]),
+        )
+        ep = torch.export.export(model, example_inputs)
+        print("ORIGINAL MODEL", ep)
+        delegated_program = executorch.exir.to_edge_transform_and_lower(
+            ep,
+            partitioner=[self._coreml_partitioner()],
+        )
+        for node in delegated_program.exported_program().graph.nodes:
+            if node.op == "call_function":
+                assert node.target.__name__ in [
+                    "executorch_call_delegate",
+                    "getitem",
+                ], f"Got unexpected node target after delegation: {node.target.__name__}"
+        et_prog = delegated_program.to_executorch()
+        print(et_prog.exported_program())
+        self._compare_outputs(et_prog, model, example_inputs)


if __name__ == "__main__":
    test_runner = TestTorchOps()
-    test_runner.test_dequantize_affine_b4w_embedding()
-    test_runner.test_dequantize_affine_b4w_linear()
-    test_runner.test_dequantize_affine_c4w_embedding()
-    test_runner.test_dequantize_affine_c4w_linear()
-    test_runner.test_dequantize_affine_c8w_embedding_b4w_linear()
+    # test_runner.test_dequantize_affine_b4w_embedding()
+    # test_runner.test_dequantize_affine_b4w_linear()
+    # test_runner.test_dequantize_affine_c4w_embedding()
+    # test_runner.test_dequantize_affine_c4w_linear()
+    # test_runner.test_dequantize_affine_c8w_embedding_b4w_linear()
+    test_runner.test_dequantize_codebook_linear()
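As a quick sanity check alongside the new test, one can confirm the codebook dequant op actually appears in the exported graph before lowering, which is the op the new torch_ops.py registration handles. A small sketch, not part of this commit, assuming ep is the ExportedProgram built in test_dequantize_codebook_linear:

# Illustrative sketch only: inspect the exported FX graph for the codebook dequant op.
dequant_nodes = [
    node
    for node in ep.graph.nodes
    if node.op == "call_function" and "dequantize_codebook" in str(node.target)
]
assert dequant_nodes, "expected a quant::dequantize_codebook call in the exported graph"
print("found codebook dequant nodes:", [node.name for node in dequant_nodes])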
