
Commit 5842d73

[OMNIML-2244] enable fp8 and int8 ONNX export (#594)
## What does this PR do?

**Type of change:** Example update

**Overview:**

- Support ONNX export for fp8 and int8 precisions
- Added utility functions to check for fp8 and int8 quantization (will be used in ONNXExporter)
- Fixed a bug in the evaluation API for large batch sizes
- Added a function that replaces zero scale values with the smallest positive fp16 value

## Usage

```bash
python torch_quant_to_onnx.py \
    --quantize_mode <fp8|int8> \
    --onnx_save_path <onnx_path>
```

## Testing

Validated the accuracy and latency of the int8 and fp8 models:

| Metric | INT8 | FP8 |
|--------|------|-----|
| Top-1 Accuracy | 84.584% | 85.062% |
| Top-5 Accuracy | 97.3% | 97.534% |
| Inference Latency | 8.4825 ms | 8.15096 ms |

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes
- **Did you write any new necessary tests?**: No
- **Did you add or update any necessary documentation?**: Yes
- **Did you update [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: No

---------

Signed-off-by: ajrasane <[email protected]>
1 parent a5025a2 commit 5842d73

File tree (5 files changed: +53 −7 lines)

- examples/onnx_ptq/README.md
- examples/onnx_ptq/evaluation.py
- examples/onnx_ptq/torch_quant_to_onnx.py
- modelopt/onnx/quantization/qdq_utils.py
- modelopt/torch/_deploy/utils/torch_onnx.py

examples/onnx_ptq/README.md (5 additions, 5 deletions)

````diff
@@ -13,7 +13,7 @@ Model Optimizer enables highly performant quantization formats including NVFP4,
 | Pre-Requisites | Required & optional packages to use this technique | [Link](#pre-requisites) | |
 | Getting Started | Learn how to optimize your models using PTQ to reduce precision and improve inference efficiency | [Link](#getting-started) | [docs](https://nvidia.github.io/TensorRT-Model-Optimizer/guides/_onnx_quantization.html) |
 | Support Matrix | View the ONNX export supported LLM models | [Link](#onnx-export-supported-llm-models) | |
-| PyTorch to ONNX | Example scripts demonstrating how to quantize with PyTorch and then convert to ONNX | [Link](#torch-quantization-to-onnx-example-for-mxfp8-int4-or-nvfp4-precision) | |
+| PyTorch to ONNX | Example scripts demonstrating how to quantize with PyTorch and then convert to ONNX | [Link](#torch-quantization-to-onnx-export-example) | |
 | Advanced Features | Examples demonstrating use advanced ONNX quantization features | [Link](#advanced-features) | |
 | Pre-Quantized Checkpoints | Ready to deploy Hugging Face pre-quantized checkpoints | [Link](#pre-quantized-checkpoints) | |
 | Resources | Extra links to relevant resources | [Link](#resources) | |
@@ -80,7 +80,7 @@ python image_prep.py \
 
 The model can be quantized as an FP8, INT8 or INT4 model using either the CLI or Python API. For FP8 and INT8 quantization, you have a choice between `max` and `entropy` calibration algorithms. For INT4 quantization, [awq_clip](https://arxiv.org/abs/2306.00978) or [rtn_dq](https://ar5iv.labs.arxiv.org/html/2301.12017) algorithms can be chosen.
 
-> *For NVFP4 and MXFP8 ONNX, see the [PyTorch to ONNX section](#torch-quantization-to-onnx-example-for-mxfp8-int4-or-nvfp4-precision).*
+> *For NVFP4 and MXFP8 ONNX, see the [PyTorch to ONNX section](#torch-quantization-to-onnx-export-example).*
 
 > *Minimum opset requirements: int8 (13+), fp8 (21+), int4 (21+). ModelOpt will automatically upgrade lower opset versions to meet these requirements.*
 
@@ -129,9 +129,9 @@ The top5 accuracy of the model is <accuracy score between 0-100%>
 Inference latency of the model is <X> ms
 ```
 
-## Torch quantization to ONNX example for MXFP8, INT4 or NVFP4 precision
+## Torch quantization to ONNX export example
 
-This example demonstrates how to quantize a [timm](https://github.com/huggingface/pytorch-image-models) vision model using MXFP8, INT4 or NVFP4 precision formats, and then export it to ONNX. The script leverages the ModelOpt toolkit for both quantization and ONNX export.
+This example demonstrates how to quantize a [timm](https://github.com/huggingface/pytorch-image-models) vision model for various precision formats followed by export to ONNX. The script leverages the ModelOpt toolkit for both quantization and ONNX export.
 
 > *Opset 20 is used to export the torch models to ONNX.*
 
@@ -148,7 +148,7 @@ This example demonstrates how to quantize a [timm](https://github.com/huggingfac
 ```bash
 python torch_quant_to_onnx.py \
     --timm_model_name=vit_base_patch16_224 \
-    --quantize_mode=<mxfp8|nvfp4|int4_awq> \
+    --quantize_mode=<fp8|mxfp8|int8|nvfp4|int4_awq> \
     --onnx_save_path=<path to save the exported ONNX model>
 ```
 
````
examples/onnx_ptq/evaluation.py (2 additions, 1 deletion)

```diff
@@ -152,8 +152,9 @@ def evaluate_accuracy(
 
         # Calculate accuracy
         outputs = outputs[0] if isinstance(outputs, list) else outputs.data
-
         labels_size = labels.size(0)
+        outputs = outputs[:labels_size]
+
         total += labels_size
 
         labels = labels.to(outputs.device)
```
examples/onnx_ptq/torch_quant_to_onnx.py (1 addition, 1 deletion)

```diff
@@ -323,7 +323,7 @@ def main():
     )
     print(f"Quantized Model - Top-1 Accuracy: {top1:.2f}%, Top-5 Accuracy: {top5:.2f}%")
 
-    if args.quantize_mode in ["fp8", "int8", "auto"]:
+    if args.quantize_mode in ["auto"]:
         print(
             f"The selected quantization mode {args.quantize_mode} is not supported for ONNX export yet."
         )
```
modelopt/onnx/quantization/qdq_utils.py (15 additions, 0 deletions)

```diff
@@ -1037,6 +1037,21 @@ def remove_graph_input_q(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
     return onnx_model
 
 
+def replace_zero_scale_with_smallest_nonzero(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
+    """Replace zero scale values with smallest nonzero fp16 value in the ONNX model."""
+    graph = onnx_model.graph
+    fp16_smallest_nonzero = np.float16(6e-08)
+    scale_nodes = [node.input[1] for node in graph.node if node.op_type == "QuantizeLinear"]
+    for node in graph.node:
+        if node.op_type == "Constant" and node.output[0] in scale_nodes:
+            for attr in node.attribute:
+                if attr.name == "value":
+                    tensor = numpy_helper.to_array(attr.t)
+                    new_tensor = np.where(tensor == 0, fp16_smallest_nonzero, tensor)
+                    attr.t.CopyFrom(numpy_helper.from_array(new_tensor, attr.t.name))
+    return onnx_model
+
+
 def _cast_initializer_to_dtype(
     node: onnx.NodeProto, dtype: str, initializer_map: dict[str, onnx.TensorProto]
 ):
```
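A hedged usage sketch for the new helper: the import path follows the file shown above, while the model paths are placeholders.

```python
import onnx

from modelopt.onnx.quantization.qdq_utils import replace_zero_scale_with_smallest_nonzero

# Placeholder path to a quantized ONNX model containing QuantizeLinear nodes.
model = onnx.load("model.quant.onnx")

# Scale constants that are exactly zero get bumped to the smallest positive
# fp16 value (~6e-08), so downstream engines never see a zero scale.
model = replace_zero_scale_with_smallest_nonzero(model)
onnx.save(model, "model.quant.fixed.onnx")
```
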

modelopt/torch/_deploy/utils/torch_onnx.py (30 additions, 0 deletions)

```diff
@@ -37,6 +37,7 @@
     qdq_to_dq,
     quantize_weights_to_int4,
     quantize_weights_to_mxfp8,
+    replace_zero_scale_with_smallest_nonzero,
 )
 from modelopt.onnx.utils import (
     get_input_names,
@@ -336,6 +337,32 @@ def is_mxfp8_quantized(model: nn.Module) -> bool:
     return False
 
 
+def is_int8_quantized(model: nn.Module) -> bool:
+    """Check if the model is quantized in INT8 mode."""
+    for _, module in model.named_modules():
+        if (
+            hasattr(module, "weight_quantizer")
+            and hasattr(module, "input_quantizer")
+            and module.weight_quantizer._num_bits == 8
+            and module.input_quantizer._num_bits == 8
+        ):
+            return True
+    return False
+
+
+def is_fp8_quantized(model: nn.Module) -> bool:
+    """Check if the model is quantized in FP8 mode."""
+    for _, module in model.named_modules():
+        if (
+            hasattr(module, "weight_quantizer")
+            and hasattr(module, "input_quantizer")
+            and module.weight_quantizer._num_bits == (4, 3)
+            and module.input_quantizer._num_bits == (4, 3)
+        ):
+            return True
+    return False
+
+
 def get_onnx_bytes_and_metadata(
     model: nn.Module,
     dummy_input: Any | tuple,
@@ -510,6 +537,9 @@ def get_onnx_bytes_and_metadata(
         onnx_opt_graph, low_precision_type=weights_dtype, keep_io_types=False
     )
 
+    # TensorRT expects all scales to be positive
+    onnx_opt_graph = replace_zero_scale_with_smallest_nonzero(onnx_opt_graph)
+
     # If the onnx model contains external data store the external tensors in one file and save the onnx model
     if has_external_data(onnx_save_path):
         tensor_paths = get_external_tensor_paths(onnx_path)
```