# coremltools than is used by ExecuTorch. Each op registered here should have a link to a PR in coremltools that adds
# the op to the coremltools library.

-import torch as _torch
-from coremltools import _logger
-from coremltools.converters.mil.frontend import _utils
-from coremltools.converters.mil.frontend.torch.ops import (
-    _get_inputs,
-    _get_kwinputs,
-    NUM_TO_NUMPY_DTYPE,
-    NUM_TO_TORCH_DTYPE,
-    split,
-    to,
-    transpose,
-    unbind,
-)
import numpy as np
-from coremltools.converters.mil.frontend.torch.torch_op_registry import (
-    register_torch_op,
-)
-from coremltools.converters.mil.mil import types
-from executorch.exir.dim_order_utils import get_memory_format
-
-
-# https://github.com/apple/coremltools/pull/2556
-@register_torch_op(override=False)
-def transpose_copy(context, node):
-    transpose(context, node)
-
-
-# https://github.com/apple/coremltools/pull/2557
-@register_torch_op(override=False)
-def unbind_copy(context, node):
-    unbind(context, node)
-
-
-# https://github.com/apple/coremltools/pull/2563
-@register_torch_op(override=False)
-def split_copy(context, node):
-    split(context, node)
-
-
-@register_torch_op(
-    torch_alias=[
-        "dim_order_ops::_to_dim_order_copy",
-        "dim_order_ops._to_dim_order_copy",
-    ],
-    override=False,
-)
-def _to_dim_order_copy(context, node):
-    dim_order = _get_kwinputs(context, node, "dim_order", default=[None])[0]
-    node.kwinputs.pop("dim_order")
-
-    # In CoreML, dim_order.val will be an ndarray, so we convert it to a list
-    dim_order = [int(d) for d in dim_order.val]
-    memory_format = get_memory_format(dim_order)
-    assert (
-        memory_format == _torch.contiguous_format
-    ), "Only contiguous memory format is supported in CoreML"
-    to(context, node)
-
-
-# https://github.com/apple/coremltools/pull/2558
-@register_torch_op(
-    torch_alias=["torchao::dequantize_affine", "torchao.dequantize_affine"],
-    override=False,
-)
-def dequantize_affine(context, node):
-    inputs = _get_inputs(context, node, expected=[7, 8])
-    int_data = inputs[0].val
-    block_size = inputs[1].val
-    scale = inputs[2].val
-    zero_point = (
-        inputs[3].val if inputs[3] is not None and inputs[3].val is not None else None
-    )
-    # I do not think we need to worry about input_dtype b/c it gets cast to int4/int8
-    # For now, we just check that it is int8 or int32
-    input_dtype = inputs[4].val  # noqa: F841
-    assert NUM_TO_TORCH_DTYPE[input_dtype] in [
-        _torch.int8,
-        _torch.int32,
-    ], "input_dtype should be int8 or int32"
-
-    quant_min = inputs[5].val
-    quant_max = inputs[6].val
-
-    assert len(int_data.shape) == 2, "dequantize_affine only supports rank 2 inputs"
-
-    assert len(int_data.shape) == len(
-        block_size
-    ), "block_size must have the same length as int_data.shape"
-    assert block_size[0] == 1, "block_size[0] must be 1"
-    group_size = block_size[1]
-    k = int_data.shape[1]
-    assert k % group_size == 0, "k must be divisible by group_size"
-    scales_per_row = k // group_size
-    scale = scale.reshape(-1, scales_per_row)
-    if zero_point is not None:
-        zero_point = zero_point.reshape(-1, scales_per_row)
-
-    # TODO: I don't know if CoreML can make use of this
-    # We could add a cast op to the output, but I'm pretty sure CoreML will remove this during a later pass
-    # For now, we just log a warning
-    out_np_dtype = None
-    if len(inputs) > 7:
-        out_np_dtype = NUM_TO_NUMPY_DTYPE[inputs[7].val]
-        _logger.warning(
-            f"Core ML ignores output_dtype {out_np_dtype} on torchao.dequantize_affine and instead uses the native precision."
-        )
-
-    if quant_min == -8 and quant_max == 7:
-        quantized_np_dtype = types.nptype_from_builtin(types.string_to_builtin("int4"))
-    elif quant_min == -128 and quant_max == 127:
-        quantized_np_dtype = types.nptype_from_builtin(types.string_to_builtin("int8"))
-    else:
-        raise ValueError(
-            f"Unsupported quantization range: {quant_min} to {quant_max}. CoreML only supports 4-bit and 8-bit quantization."
-        )
-
-    output = _utils._construct_constexpr_dequant_op(
-        int_data.astype(quantized_np_dtype),
-        zero_point,
-        scale,
-        axis=-1,
-        name=node.name,
-    )
-    context.add(output, node.name)
-
-
-
-# codes: torch.Tensor,
-# codebook: torch.Tensor,
-# code_dtype: torch.dtype,
-# block_size: List[int],
-# output_dtype: torch.dtype = torch.float32,
-
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-# This file registers torch ops that are not yet in coremltools, or are in a more recent version of
-# coremltools than is used by ExecuTorch. Each op registered here should have a link to a PR in coremltools that adds
-# the op to the coremltools library.
-
import torch as _torch
from coremltools import _logger
from coremltools.converters.mil.frontend import _utils
@@ -164,8 +22,6 @@ def dequantize_affine(context, node):
    transpose,
    unbind,
)
-import numpy as np
-
from coremltools.converters.mil.frontend.torch.torch_op_registry import (
    register_torch_op,
)
@@ -283,24 +139,25 @@ def dequantize_affine(context, node):
    override=False,
)
def dequantize_codebook(context, node):
-    print("IN DEQUANTIZE CODEBOOK")
    inputs = _get_inputs(context, node, expected=[4, 5])
    codes = inputs[0].val
    codebook = inputs[1].val
-    code_dtype = inputs[2].val
-    block_size = inputs[3].val
+    nbits = inputs[2].val

-
-    assert len(codes.shape) == 2, "Only rank 2 inputs are supported"
+    # information in block_size is redundant with codebook.shape
+    block_size = inputs[3].val  # noqa: F841

-    # In TorchAO, the codebook shape is (n_lut, nbit, 1). The LUTs are for the columns.
-    # In CoreML, the expected shape is (lut_block_size, nbit, 1). 1 here is for scalar
-    # lut_block_size has the same rank as codes/idxs and tells you how many LUTs there are per block, e.g.,
-    # lut_block_size=(1, 8) means there is 1 LUT per 8 columns
-    assert len(codebook.shape) == 3, "Only rank 3 inputs are supported for codebook"
-    assert codebook.shape[-1] == 1, "we only support scalar palletization"
-    codebook = np.expand_dims(codebook, 0)
+    assert len(codes.shape) == 2, "Only rank 2 inputs are supported"

+    # Assert codebook is as expected. codebook.dim() = codes.dim() + 2
+    assert len(codebook.shape) == 4, "Only rank 4 inputs are supported for codebook"
+    assert codebook.shape[0] == 1, "Only grouped_channel granularity is supported"
+    n_luts = codebook.shape[1]
+    assert (
+        codes.shape[1] % n_luts == 0
+    ), "codes.shape[1] must be divisible by codebook.shape[1]"
+    assert codebook.shape[2] == 2**nbits
+    assert codebook.shape[3] == 1, "Only scalar look up values are supported"

    if len(inputs) > 4:
        output_dtype = inputs[4].val
@@ -309,11 +166,6 @@ def dequantize_codebook(context, node):
            f"Core ML ignores output_dtype {out_np_dtype} on torchao.dequantize_affine and instead uses the native precision."
        )

-    print("codes", codes.shape)
-    print("codebook", codebook.shape)
-    print("code_dtype", code_dtype)
-    print("block_size", block_size)
-
    output = _utils._construct_constexpr_lut_op(
        codes.astype(np.int8),
        codebook,
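
As a quick sanity check of the layout the reworked `dequantize_codebook` expects, here is a minimal sketch of `codes`/`codebook` inputs that satisfy the shape assertions added above. The sizes, dtypes, and variable names below are illustrative assumptions, not taken from this commit:

```python
import numpy as np

# Illustrative sizes (assumptions): a 16x32 quantized weight, 4-bit codes,
# and one look-up table per group of 8 columns.
n, k, nbits, group_size = 16, 32, 4, 8
n_luts = k // group_size

# codes: rank 2, one palette index per weight element
codes = np.random.randint(0, 2**nbits, size=(n, k), dtype=np.int8)

# codebook: rank 4 = codes rank + 2, laid out as (1, n_luts, 2**nbits, 1);
# axis 0 == 1 matches the grouped_channel check, last axis == 1 means scalar LUT entries
codebook = np.random.rand(1, n_luts, 2**nbits, 1).astype(np.float16)

# These mirror the assertions added in dequantize_codebook above
assert len(codes.shape) == 2
assert len(codebook.shape) == 4 and codebook.shape[0] == 1
assert codes.shape[1] % codebook.shape[1] == 0
assert codebook.shape[2] == 2**nbits and codebook.shape[3] == 1
```

Per the new checks, the LUTs are grouped over columns of `codes` (`codes.shape[1]` must divide evenly into `codebook.shape[1]` groups), and only scalar look-up values are supported, which is why the last codebook axis must be 1.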