|
4 | 4 | # This source code is licensed under the BSD-style license found in the
|
5 | 5 | # LICENSE file in the root directory of this source tree.
|
6 | 6 |
|
| 7 | +import tempfile |
7 | 8 | import unittest
|
8 | 9 |
|
| 10 | +from typing import Tuple |
| 11 | + |
9 | 12 | import executorch.exir as exir
|
10 | 13 |
|
11 | 14 | import torch
|
|
20 | 23 | # import the xnnpack backend implementation
|
21 | 24 | from executorch.backends.xnnpack.xnnpack_preprocess import XnnpackBackend
|
22 | 25 |
|
23 |
| -from executorch.exir import CaptureConfig |
| 26 | +from executorch.exir import ( |
| 27 | + CaptureConfig, |
| 28 | + EdgeCompileConfig, |
| 29 | + EdgeProgramManager, |
| 30 | + to_edge_transform_and_lower, |
| 31 | +) |
| 32 | + |
24 | 33 | from executorch.exir.backend.backend_api import to_backend, validation_disabled
|
25 | 34 | from executorch.exir.passes.spec_prop_pass import SpecPropPass
|
26 | 35 |
|
|
41 | 50 | prepare_fx,
|
42 | 51 | )
|
43 | 52 |
|
44 |
| - |
45 | 53 | class TestXnnQnnBackends(unittest.TestCase):
|
46 | 54 | def test_add_xnnpack_and_dqlinear_qnn(self):
|
47 | 55 | qconfig_mapping = QConfigMapping().set_object_type(
|
@@ -132,3 +140,49 @@ def forward(self, x, y):
|
132 | 140 | self.assertTrue(
|
133 | 141 | torch.allclose(model_output[0], ref_output, atol=1e-03, rtol=1e-03)
|
134 | 142 | )
|
| 143 | + |
| 144 | + def test_serde(self): |
| 145 | + # The module with blank_logprobs() function |
| 146 | + class BlankLogProbsModule(torch.nn.Module): |
| 147 | + def __init__(self) -> None: |
| 148 | + super().__init__() |
| 149 | + self.linear = torch.nn.Linear(768, 1) |
| 150 | + self.log_sigmoid = torch.nn.LogSigmoid() |
| 151 | + |
| 152 | + def forward(self, joint_encodings: torch.Tensor) -> torch.Tensor: |
| 153 | + tanh_out = torch.tanh(joint_encodings) |
| 154 | + linear_out = self.linear(tanh_out) |
| 155 | + blank_output = self.log_sigmoid(linear_out) |
| 156 | + return blank_output |
| 157 | + |
| 158 | + def get_blank_logprobs_inputs_fn() -> Tuple[torch.Tensor, ...]: |
| 159 | + """ |
| 160 | + Get the input to the blank_logprobs() and nonblank_logprobs() functions. |
| 161 | + """ |
| 162 | + return (torch.randn(1, 1, 1, 768),) |
| 163 | + |
| 164 | + model = BlankLogProbsModule() |
| 165 | + # Get the inputs for the logprobs function |
| 166 | + logprobs_fake_inputs = get_blank_logprobs_inputs_fn() |
| 167 | + |
| 168 | + # Export and partition |
| 169 | + aten_prog = torch.export.export(model, logprobs_fake_inputs, strict=True) |
| 170 | + partitioned_prog: EdgeProgramManager = to_edge_transform_and_lower( |
| 171 | + aten_prog, |
| 172 | + partitioner=[XnnpackFloatingPointPartitioner()], |
| 173 | + compile_config=EdgeCompileConfig( |
| 174 | + _check_ir_validity=False, _use_edge_ops=True, |
| 175 | + ), |
| 176 | + ) |
| 177 | + |
| 178 | + with tempfile.NamedTemporaryFile(suffix=".pt2") as f: |
| 179 | + exir.save(partitioned_prog.exported_program(), f.name) |
| 180 | + f.seek(0) |
| 181 | + loaded_model = exir.load(f.name) |
| 182 | + |
| 183 | + self.assertTrue( |
| 184 | + torch.allclose( |
| 185 | + model(*logprobs_fake_inputs), |
| 186 | + loaded_model.module()(*logprobs_fake_inputs), |
| 187 | + ) |
| 188 | + ) |
0 commit comments