From 703caf80b00894fc19924c32b4749e34591d13b6 Mon Sep 17 00:00:00 2001 From: take-cheeze Date: Sun, 13 Jul 2025 14:04:01 +0900 Subject: [PATCH] Add tests for ONNX function Signed-off-by: take-cheeze --- onnxoptimizer/test/function_test.py | 34 +++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 onnxoptimizer/test/function_test.py diff --git a/onnxoptimizer/test/function_test.py b/onnxoptimizer/test/function_test.py new file mode 100644 index 000000000..77b1dc9b5 --- /dev/null +++ b/onnxoptimizer/test/function_test.py @@ -0,0 +1,34 @@ +import io +import onnx +import onnxoptimizer +import pytest +import unittest + +try: + import torch + import torchvision as tv + + has_tv = True +except: + has_tv = False + + +@pytest.mark.skipif(not has_tv, reason="This test needs torchvision") +def test_function_preserved(): + with io.BytesIO() as f: + module = tv.models.resnet18() + torch.onnx.export( + module, + (torch.ones([1, 3, 224, 224], dtype=torch.float32), ), + f, + opset_version=15, + export_modules_as_functions={ + torch.nn.BatchNorm2d, + torch.nn.Conv2d, + } + ) + + model = onnx.load_model_from_string(f.getvalue()) + opt_model = onnxoptimizer.optimize(model) + assert len(model.functions) > 0 + assert len(model.functions) == len(opt_model.functions)