
Commit af07feb

Add quantization and partitioner flow in the qualcomm doc (#12387)
Summary: Add a section describing how to lower a model to HTP, including the quantization step. Differential Revision: D78117959
1 parent b7ae183 commit af07feb

3 files changed: +155 -13 lines

docs/source/backends-qualcomm.md

Lines changed: 109 additions & 0 deletions
@@ -365,6 +365,115 @@ The model, inputs, and output location are passed to `qnn_executorch_runner` by

Please refer to `$EXECUTORCH_ROOT/examples/qualcomm/scripts/` and `EXECUTORCH_ROOT/examples/qualcomm/oss_scripts/` to the list of supported models.

## How to Support a Custom Model in HTP Backend

### Step-by-Step Implementation Guide

See [the simple example](https://github.com/pytorch/executorch/blob/main/examples/qualcomm/scripts/export_example.py) and the [more complicated examples](https://github.com/pytorch/executorch/tree/main/examples/qualcomm/scripts) for reference.

#### Step 1: Prepare Your Model
```python
import torch

# Initialize your custom model
model = YourModelClass().eval()  # Your custom PyTorch model

# Create example inputs (adjust shape as needed)
example_inputs = (torch.randn(1, 3, 224, 224),)  # Example input tensor
```

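If you do not have a model on hand, a minimal stand-in such as the following is enough to walk through the rest of this guide (`SimpleConv` is a hypothetical example class, not part of ExecuTorch):
```python
import torch

class SimpleConv(torch.nn.Module):
    """Tiny stand-in for YourModelClass, matching the (1, 3, 224, 224) example input."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))

model = SimpleConv().eval()
example_inputs = (torch.randn(1, 3, 224, 224),)
```
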
#### Step 2: [Optional] Quantize Your Model
Choose between two quantization approaches: post-training quantization (PTQ) or quantization-aware training (QAT):
```python
from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, prepare_qat_pt2e, convert_pt2e

# quantization_type ("ptq" or "qat") and training_steps are placeholders for you to define
quantizer = QnnQuantizer()
m = torch.export.export(model, example_inputs, strict=True).module()

# PTQ (Post-Training Quantization)
if quantization_type == "ptq":
    prepared_model = prepare_pt2e(m, quantizer)
    # Calibration loop would go here
    prepared_model(*example_inputs)

# QAT (Quantization-Aware Training)
elif quantization_type == "qat":
    prepared_model = prepare_qat_pt2e(m, quantizer)
    # Training loop would go here
    for _ in range(training_steps):
        prepared_model(*example_inputs)

# Convert to quantized model
quantized_model = convert_pt2e(prepared_model)
```

The `QnnQuantizer` is configurable, with the default setting being **8a8w**. For advanced users, refer to the [`QnnQuantizer`](https://github.com/pytorch/executorch/blob/main/backends/qualcomm/quantizer/quantizer.py) source for details.

##### Supported Quantization Schemes
- **8a8w** (default)
- **16a16w**
- **16a8w**
- **16a4w**
- **16a4w_block**

##### Customization Options
- **Per-node annotation**: Use `custom_quant_annotations`.
- **Per-module (`nn.Module`) annotation**: Use `submodule_qconfig_list`.

##### Additional Features
- **Node exclusion**: Discard specific nodes via `discard_nodes`.
- **Blockwise quantization**: Configure block sizes with `block_size_map`.

For practical examples, see [`test_qnn_delegate.py`](https://github.com/pytorch/executorch/blob/main/backends/qualcomm/tests/test_qnn_delegate.py).

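Selecting a non-default scheme might look like the sketch below. The `QuantDtype` enum and the `set_default_quant_config` setter are assumptions based on the quantizer module at the time of writing; check `quantizer.py` for the exact names and signatures before relying on them.
```python
# A sketch only: verify QuantDtype and set_default_quant_config against quantizer.py.
from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer, QuantDtype

quantizer = QnnQuantizer()
quantizer.set_default_quant_config(
    quant_dtype=QuantDtype.use_16a8w,  # one of the schemes listed above
    is_qat=False,                      # set True when pairing with prepare_qat_pt2e
)
```
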
#### Step 3: Configure Compile Specs
In this step, specify the target SoC, data type, and other QNN compiler options.
```python
from executorch.backends.qualcomm.utils.utils import (
    generate_htp_compiler_spec,
    generate_qnn_executorch_compiler_spec,
)
from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset

quantized = True  # Set to False if you skipped Step 2

# HTP Compiler Configuration
backend_options = generate_htp_compiler_spec(
    use_fp16=not quantized,  # False for quantized models
)

# QNN Compiler Spec
compile_spec = generate_qnn_executorch_compiler_spec(
    soc_model=QcomChipset.SM8650,  # Your target SoC
    backend_options=backend_options,
)
```
#### Step 4: Lower and Export the Model
```python
from executorch.backends.qualcomm.utils.utils import (
    to_edge_transform_and_lower_to_qnn,
)
from executorch.exir import ExecutorchBackendConfig

# Lower to QNN backend
delegated_program = to_edge_transform_and_lower_to_qnn(
    quantized_model if quantized else model,
    example_inputs,
    compile_spec,
)

# Export to ExecuTorch format
executorch_program = delegated_program.to_executorch(
    config=ExecutorchBackendConfig(extract_delegate_segments=False)
)

# Save the compiled model
model_name = "custom_model_qnn.pte"
with open(model_name, "wb") as f:
    f.write(executorch_program.buffer)
print(f"Model successfully exported to {model_name}")
```

## What is coming?

- Improve the performance for llama3-8B-Instruct and support batch prefill.

docs/source/quantization-overview.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ Not all quantization options are supported by all backends. Consult backend-spec
3030

3131
* [XNNPACK quantization](backends-xnnpack.md#quantization)
3232
* [CoreML quantization](backends-coreml.md#quantization)
33+
* [QNN quantization](backends-qualcomm.md#step-2-optional-quantize-your-model)
3334

3435

3536

examples/qualcomm/scripts/export_example.py

Lines changed: 45 additions & 13 deletions
@@ -4,10 +4,10 @@
 
 import torch
 from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer
-from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset
 from executorch.backends.qualcomm.utils.utils import (
     generate_htp_compiler_spec,
     generate_qnn_executorch_compiler_spec,
+    get_soc_to_chipset_map,
     to_edge_transform_and_lower_to_qnn,
 )
 from executorch.devtools import generate_etrecord
@@ -16,7 +16,11 @@
 from executorch.exir.capture._config import ExecutorchBackendConfig
 from executorch.extension.export_util.utils import save_pte_program
 
-from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
+from torchao.quantization.pt2e.quantize_pt2e import (
+    convert_pt2e,
+    prepare_pt2e,
+    prepare_qat_pt2e,
+)
 
 
 def main() -> None:
@@ -43,6 +47,20 @@ def main() -> None:
         help="The folder to store the exported program",
     )
 
+    parser.add_argument(
+        "--soc",
+        type=str,
+        default="SM8650",
+        help="Specify the SoC model.",
+    )
+
+    parser.add_argument(
+        "-q",
+        "--quantization",
+        choices=["ptq", "qat"],
+        help="Run post-training quantization (ptq) or quantization-aware training (qat).",
+    )
+
     args = parser.parse_args()
 
     if args.model_name not in MODEL_NAME_TO_MODEL:
@@ -51,27 +69,41 @@ def main() -> None:
             f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}."
         )
 
+    # Get model and example inputs
     model, example_inputs, _, _ = EagerModelFactory.create_model(
         *MODEL_NAME_TO_MODEL[args.model_name]
     )
 
     # Get quantizer
-    quantizer = QnnQuantizer()
-
-    # Typical pytorch 2.0 quantization flow
-    m = torch.export.export(model.eval(), example_inputs, strict=True).module()
-    m = prepare_pt2e(m, quantizer)
-    # Calibration
-    m(*example_inputs)
-    # Get the quantized model
-    m = convert_pt2e(m)
+    if args.quantization:
+        print("Quantizing model...")
+        # Model quantization path
+        quantizer = QnnQuantizer()
+        # Typical pytorch 2.0 quantization flow
+        m = torch.export.export(model.eval(), example_inputs, strict=True).module()
+        if args.quantization == "qat":
+            m = prepare_qat_pt2e(m, quantizer)
+            # Training loop
+            m(*example_inputs)
+        elif args.quantization == "ptq":
+            m = prepare_pt2e(m, quantizer)
+            # Calibration
+            m(*example_inputs)
+        else:
+            raise RuntimeError(f"Unknown quantization type {args.quantization}")
+        # Get the quantized model
+        m = convert_pt2e(m)
+    else:
+        # Floating-point model path
+        m = model
 
     # Capture program for edge IR and delegate to QNN backend
+    use_fp16 = True if args.quantization is None else False
     backend_options = generate_htp_compiler_spec(
-        use_fp16=False,
+        use_fp16=use_fp16,
     )
     compile_spec = generate_qnn_executorch_compiler_spec(
-        soc_model=QcomChipset.SM8550,
+        soc_model=get_soc_to_chipset_map()[args.soc],
         backend_options=backend_options,
     )
     delegated_program = to_edge_transform_and_lower_to_qnn(
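
With these changes the export script can optionally quantize the model and target a specific SoC: passing `--soc SM8650 -q ptq` (or `-q qat`) in addition to the existing model-name argument produces a quantized program for SM8650, while omitting `-q` keeps the fp16 path.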
