Skip to content

Commit be95023

Browse files
committed
up
1 parent 723d990 commit be95023

File tree

2 files changed

+50
-58
lines changed

2 files changed

+50
-58
lines changed

.ci/scripts/test_ane_static_llama.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,6 @@ pushd $EXECUTORCH_ROOT/examples/apple/coreml/llama
2828
# Download stories llama110m artifacts
2929
download_stories_model_artifacts
3030

31-
python export.py -n model.pte -p params.json -c stories110M.pt --seq_length 32 --max_seq_length 64 --dtype fp16 --coreml-quantize c4w
31+
python export.py -n model.pte -p params.json -c stories110M.pt --seq_length 32 --max_seq_length 64 --dtype fp16 --coreml-quantize c4w --embedding-quantize 4,32
3232

3333
popd

examples/apple/coreml/llama/export.py

Lines changed: 49 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,19 @@
1818
from executorch.examples.apple.coreml.llama.utils import (
1919
replace_linear_with_split_linear,
2020
)
21-
from executorch.examples.models.llama.source_transformation.quantize import (
22-
EmbeddingQuantHandler,
23-
)
2421

2522
from executorch.exir.backend.utils import format_delegated_graph
2623
from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig
2724
from executorch.exir.passes import MemoryPlanningPass
2825
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
2926
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
30-
from executorch.exir.program._program import to_edge, to_edge_transform_and_lower
27+
from executorch.exir.program._program import to_edge_transform_and_lower
3128
from executorch.extension.export_util.utils import save_pte_program
3229

30+
from torchao.quantization.granularity import PerAxis, PerGroup
31+
from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_
32+
from torchao.utils import unwrap_tensor_subclass
33+
3334

3435
def main() -> None:
3536
parser = argparse.ArgumentParser()
@@ -115,19 +116,8 @@ def main() -> None:
115116
export_args.dtype
116117
] # dtype for model/inputs
117118

118-
if export_args.embedding_quantize:
119-
bitwidth, group_size = export_args.embedding_quantize.split(",")
120-
if group_size == "none" or group_size == "None" or group_size == "0":
121-
group_size = None
122-
else:
123-
group_size = int(group_size)
124-
bitwidth = int(bitwidth)
125-
model = EmbeddingQuantHandler(
126-
model,
127-
bitwidth=bitwidth,
128-
group_size=group_size,
129-
packed=(bitwidth in [2, 4]),
130-
).quantized_model()
119+
model.eval()
120+
model.to(float_dtype)
131121

132122
if export_args.target_split_size is not None:
133123
replace_linear_with_split_linear(
@@ -140,24 +130,49 @@ def main() -> None:
140130
in_max_splits=1,
141131
)
142132

143-
model.eval()
144-
model.to(float_dtype)
133+
# Quantization
134+
if export_args.embedding_quantize:
135+
bitwidth, group_size = export_args.embedding_quantize.split(",")
136+
bitwidth = int(bitwidth)
137+
assert bitwidth in [4, 8], "CoreML only supports 4-bit and 8-bit quantization"
138+
group_size = int(group_size)
139+
if group_size == 0:
140+
granularity = PerAxis(0)
141+
else:
142+
granularity = PerGroup(group_size)
143+
weight_dtype = getattr(torch, f"int{bitwidth}")
145144

145+
quantize_(
146+
model,
147+
IntxWeightOnlyConfig(weight_dtype=weight_dtype, granularity=granularity),
148+
lambda m, fqn: isinstance(m, torch.nn.Embedding),
149+
)
150+
151+
# CoreML's op_linear_quantizer_config appears to have a bug where the quantization
152+
# quality is subpar. We use torchao APIs instead, which are now supported by CoreML
146153
op_linear_quantizer_config = None
154+
# op_linear_quantizer_config = {
155+
# "mode": "linear_symmetric",
156+
# "dtype": "int4",
157+
# "granularity": "per_channel",
158+
# }
159+
147160
if export_args.coreml_quantize == "b4w":
148-
op_linear_quantizer_config = {
149-
"mode": "linear_symmetric",
150-
"dtype": "int4",
151-
"granularity": "per_block",
152-
"block_size": 32,
153-
"weight_threshold": 512,
154-
}
161+
quantize_(
162+
model,
163+
IntxWeightOnlyConfig(
164+
weight_dtype=torch.int4,
165+
granularity=PerGroup(32),
166+
),
167+
)
155168
elif export_args.coreml_quantize == "c4w":
156-
op_linear_quantizer_config = {
157-
"mode": "linear_symmetric",
158-
"dtype": "int4",
159-
"granularity": "per_channel",
160-
}
169+
quantize_(
170+
model,
171+
IntxWeightOnlyConfig(
172+
weight_dtype=torch.int4,
173+
granularity=PerAxis(0),
174+
),
175+
)
161176

162177
compile_specs = CoreMLBackend.generate_compile_specs( # pyre-fixme[16]
163178
minimum_deployment_target=ct.target.iOS18,
@@ -172,10 +187,7 @@ def main() -> None:
172187
partitioner = CoreMLPartitioner( # pyre-fixme[16]
173188
compile_specs=compile_specs,
174189
take_over_mutable_buffer=False,
175-
skip_ops_for_coreml_delegation=[
176-
"quantized_decomposed.embedding_4bit.dtype",
177-
"aten.embedding.default",
178-
],
190+
skip_ops_for_coreml_delegation=[],
179191
)
180192

181193
input_manager = InputManager(
@@ -192,31 +204,12 @@ def main() -> None:
192204
)
193205
example_inputs = input_manager.get_inputs(tokens=[0])
194206

207+
model = unwrap_tensor_subclass(model)
208+
195209
ep = torch.export.export(model, example_inputs, strict=True)
196210
print("Exported program")
197211
print(ep)
198212

199-
# edge_manager = to_edge(
200-
# ep,
201-
# compile_config=EdgeCompileConfig(
202-
# _check_ir_validity=False,
203-
# _skip_dim_order=True,
204-
# preserve_ops=[
205-
# torch.ops.aten.scaled_dot_product_attention.default,
206-
# # preserve norm op for numerical stability
207-
# torch.ops.aten.linalg_vector_norm.default,
208-
# torch.ops.aten.reciprocal.default,
209-
# ],
210-
# ),
211-
# )
212-
# print("Edge program")
213-
# print(edge_manager.exported_program())
214-
215-
# for node in edge_manager.exported_program().graph_module.graph.nodes:
216-
# print(node.name, node.target, node.args, node.kwargs)
217-
218-
# edge_manager = edge_manager.to_backend(partitioner)
219-
220213
edge_manager = to_edge_transform_and_lower(
221214
ep,
222215
partitioner=[partitioner],
@@ -227,7 +220,6 @@ def main() -> None:
227220
)
228221

229222
print("Delegated program")
230-
231223
print(format_delegated_graph(edge_manager.exported_program().graph_module))
232224

233225
executorch_program = edge_manager.to_executorch(

0 commit comments

Comments
 (0)