|
12 | 12 |
|
13 | 13 | import executorch.backends.cadence.aot.ops_registrations # noqa
|
14 | 14 | import torch
|
15 |
| -from executorch.backends.cadence.aot import compiler |
16 | 15 | from executorch.backends.cadence.aot.fuse_ops import (
|
17 | 16 | FuseCascadedTransposeOrPermuteOps,
|
18 | 17 | FuseCascadedViewOps,
|
|
30 | 29 | from executorch.exir.dialects._ops import ops as exir_ops
|
31 | 30 | from executorch.exir.dialects.edge._ops import EdgeOpOverload
|
32 | 31 | from executorch.exir.pass_base import PassResult, ProxyValue
|
33 |
| -from torch import nn |
34 | 32 |
|
35 | 33 |
|
36 | 34 | class TestFusionPassesBase(unittest.TestCase):
|
@@ -178,43 +176,6 @@ def test_keep_mm_add_with_multiple_users(self) -> None:
|
178 | 176 | self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.mm.default), 1)
|
179 | 177 | self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.add.Tensor), 3)
|
180 | 178 |
|
181 |
| - # TODO(matthiascremon) -> None: enable that pass with new flow |
182 |
| - @torch.no_grad() |
183 |
| - @unittest.expectedFailure |
184 |
| - def test_legacy_conv_bn_fusion(self) -> None: |
185 |
| - class ModelConvBN(torch.nn.Module): |
186 |
| - def __init__( |
187 |
| - self, in_features: int, out_features: int, kernel_size: int |
188 |
| - ) -> None: |
189 |
| - super().__init__() |
190 |
| - self.conv1d = nn.Conv1d(in_features, out_features, kernel_size) |
191 |
| - self.bn = nn.BatchNorm1d(out_features) |
192 |
| - |
193 |
| - def forward(self, x: torch.Tensor) -> torch.Tensor: |
194 |
| - y = self.conv1d(x) |
195 |
| - return self.bn(y) |
196 |
| - |
197 |
| - model = ModelConvBN(64, 1, 2) |
198 |
| - x = torch.randn(1, 64, 4) |
199 |
| - |
200 |
| - graph_module = ( |
201 |
| - compiler.export_to_executorch_gen_etrecord(model.eval(), (x,)) |
202 |
| - .exported_program() |
203 |
| - .graph_module |
204 |
| - ) |
205 |
| - # Assert that after running the fusion passes, batchnorm was fused with conv1d |
206 |
| - self.assertEqual( |
207 |
| - count_node(graph_module, torch.ops.aten.linear.out) |
208 |
| - + count_node(graph_module, torch.ops.cadence.convolution.out), |
209 |
| - 1, |
210 |
| - ) |
211 |
| - self.assertEqual( |
212 |
| - count_node( |
213 |
| - graph_module, torch.ops.aten._native_batch_norm_legit_no_training.out |
214 |
| - ), |
215 |
| - 0, |
216 |
| - ) |
217 |
| - |
218 | 179 | def test_permute_transpose_fusion(self) -> None:
|
219 | 180 | builder = GraphBuilder()
|
220 | 181 | x = builder.placeholder("x", torch.randn(3, 1, 3, 1, 4, dtype=torch.float32))
|
|
0 commit comments