@@ -18,18 +18,19 @@
 from executorch.examples.apple.coreml.llama.utils import (
     replace_linear_with_split_linear,
 )
-from executorch.examples.models.llama.source_transformation.quantize import (
-    EmbeddingQuantHandler,
-)
 
 from executorch.exir.backend.utils import format_delegated_graph
 from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig
 from executorch.exir.passes import MemoryPlanningPass
 from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
 from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
-from executorch.exir.program._program import to_edge, to_edge_transform_and_lower
+from executorch.exir.program._program import to_edge_transform_and_lower
 from executorch.extension.export_util.utils import save_pte_program
 
+from torchao.quantization.granularity import PerAxis, PerGroup
+from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_
+from torchao.utils import unwrap_tensor_subclass
+
 
 def main() -> None:
     parser = argparse.ArgumentParser()
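For context, the torchao API pulled in by the new imports quantizes a module's weights in place. A minimal sketch of how it is used, mirroring the calls in this diff (the toy model and filter below are illustrative, not from the script):

```python
import torch
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_

# Toy model for illustration; the real script quantizes the llama model.
model = torch.nn.Sequential(torch.nn.Embedding(128, 64), torch.nn.Linear(64, 64))

# quantize_ mutates the model in place: weights of modules matching the
# optional filter are swapped for quantized tensor subclasses per the config.
quantize_(
    model,
    IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerGroup(32)),
    lambda m, fqn: isinstance(m, torch.nn.Linear),
)
```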
@@ -115,19 +116,8 @@ def main() -> None:
         export_args.dtype
     ]  # dtype for model/inputs
 
-    if export_args.embedding_quantize:
-        bitwidth, group_size = export_args.embedding_quantize.split(",")
-        if group_size == "none" or group_size == "None" or group_size == "0":
-            group_size = None
-        else:
-            group_size = int(group_size)
-        bitwidth = int(bitwidth)
-        model = EmbeddingQuantHandler(
-            model,
-            bitwidth=bitwidth,
-            group_size=group_size,
-            packed=(bitwidth in [2, 4]),
-        ).quantized_model()
+    model.eval()
+    model.to(float_dtype)
 
     if export_args.target_split_size is not None:
         replace_linear_with_split_linear(
@@ -140,24 +130,49 @@ def main() -> None:
             in_max_splits=1,
         )
 
-    model.eval()
-    model.to(float_dtype)
+    # Quantization
+    if export_args.embedding_quantize:
+        bitwidth, group_size = export_args.embedding_quantize.split(",")
+        bitwidth = int(bitwidth)
+        assert bitwidth in [4, 8], "CoreML only supports 4-bit and 8-bit quantization"
+        group_size = int(group_size)
+        if group_size == 0:
+            granularity = PerAxis(0)
+        else:
+            granularity = PerGroup(group_size)
+        weight_dtype = getattr(torch, f"int{bitwidth}")
 
+        quantize_(
+            model,
+            IntxWeightOnlyConfig(weight_dtype=weight_dtype, granularity=granularity),
+            lambda m, fqn: isinstance(m, torch.nn.Embedding),
+        )
+
+    # CoreML's op_linear_quantizer_config appears to have a bug where the quantization
+    # quality is subpar.  We use torchao APIs instead, which are now supported by CoreML.
     op_linear_quantizer_config = None
+    # op_linear_quantizer_config = {
+    #     "mode": "linear_symmetric",
+    #     "dtype": "int4",
+    #     "granularity": "per_channel",
+    # }
+
     if export_args.coreml_quantize == "b4w":
-        op_linear_quantizer_config = {
-            "mode": "linear_symmetric",
-            "dtype": "int4",
-            "granularity": "per_block",
-            "block_size": 32,
-            "weight_threshold": 512,
-        }
+        quantize_(
+            model,
+            IntxWeightOnlyConfig(
+                weight_dtype=torch.int4,
+                granularity=PerGroup(32),
+            ),
+        )
     elif export_args.coreml_quantize == "c4w":
-        op_linear_quantizer_config = {
-            "mode": "linear_symmetric",
-            "dtype": "int4",
-            "granularity": "per_channel",
-        }
+        quantize_(
+            model,
+            IntxWeightOnlyConfig(
+                weight_dtype=torch.int4,
+                granularity=PerAxis(0),
+            ),
+        )
 
     compile_specs = CoreMLBackend.generate_compile_specs(  # pyre-fixme[16]
         minimum_deployment_target=ct.target.iOS18,
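The `bitwidth,group_size` string taken by `--embedding_quantize` maps onto torchao granularities exactly as in the hunk above; `"b4w"` likewise corresponds to blockwise `PerGroup(32)` and `"c4w"` to per-channel `PerAxis(0)`, matching the removed CoreML `per_block`/`per_channel` modes. A quick sketch of the mapping (the flag values are examples):

```python
import torch
from torchao.quantization.granularity import PerAxis, PerGroup

# "4,32" -> int4 weights, one scale per group of 32 elements (PerGroup)
# "8,0"  -> int8 weights, one scale per embedding row (PerAxis(0), per-channel)
for arg in ["4,32", "8,0"]:
    bitwidth, group_size = (int(v) for v in arg.split(","))
    granularity = PerAxis(0) if group_size == 0 else PerGroup(group_size)
    weight_dtype = getattr(torch, f"int{bitwidth}")
    print(arg, "->", weight_dtype, granularity)
```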
@@ -172,10 +187,7 @@ def main() -> None:
     partitioner = CoreMLPartitioner(  # pyre-fixme[16]
         compile_specs=compile_specs,
         take_over_mutable_buffer=False,
-        skip_ops_for_coreml_delegation=[
-            "quantized_decomposed.embedding_4bit.dtype",
-            "aten.embedding.default",
-        ],
+        skip_ops_for_coreml_delegation=[],
     )
 
     input_manager = InputManager(
@@ -192,31 +204,12 @@ def main() -> None:
     )
     example_inputs = input_manager.get_inputs(tokens=[0])
 
+    model = unwrap_tensor_subclass(model)
+
     ep = torch.export.export(model, example_inputs, strict=True)
     print("Exported program")
     print(ep)
 
-    # edge_manager = to_edge(
-    #     ep,
-    #     compile_config=EdgeCompileConfig(
-    #         _check_ir_validity=False,
-    #         _skip_dim_order=True,
-    #         preserve_ops=[
-    #             torch.ops.aten.scaled_dot_product_attention.default,
-    #             # preserve norm op for numerical stability
-    #             torch.ops.aten.linalg_vector_norm.default,
-    #             torch.ops.aten.reciprocal.default,
-    #         ],
-    #     ),
-    # )
-    # print("Edge program")
-    # print(edge_manager.exported_program())
-
-    # for node in edge_manager.exported_program().graph_module.graph.nodes:
-    #     print(node.name, node.target, node.args, node.kwargs)
-
-    # edge_manager = edge_manager.to_backend(partitioner)
-
     edge_manager = to_edge_transform_and_lower(
         ep,
         partitioner=[partitioner],
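`unwrap_tensor_subclass` is inserted before `torch.export.export` because, after `quantize_`, the model's weights are torchao tensor subclasses; unwrapping flattens them into plain tensors so that export can trace the model. A minimal sketch of the required ordering, with a toy Linear standing in for the real model:

```python
import torch
from torchao.quantization.granularity import PerAxis
from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_
from torchao.utils import unwrap_tensor_subclass

# Toy stand-in for the llama model.
model = torch.nn.Linear(64, 64).eval()
quantize_(model, IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerAxis(0)))

# Flatten the quantized tensor subclasses before export; exporting without
# this step can fail on the subclass weights.
model = unwrap_tensor_subclass(model)
ep = torch.export.export(model, (torch.randn(1, 64),), strict=True)
```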
@@ -227,7 +220,6 @@ def main() -> None:
     )
 
     print("Delegated program")
-
     print(format_delegated_graph(edge_manager.exported_program().graph_module))
 
     executorch_program = edge_manager.to_executorch(