Skip to content

Commit f8b4835

Browse files
Arm backend: Add VGF unit tests to operators (Part 2) (#13033)
- Included aten.ge to aten.rsqrt - Ops not completed: aten.index_select and aten.index_tensor Signed-off-by: Yufeng Shi <[email protected]>
1 parent 9078b49 commit f8b4835

37 files changed

+1527
-4
lines changed

backends/arm/test/ops/test_ge.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
OpNotSupportedPipeline,
1414
TosaPipelineFP,
1515
TosaPipelineINT,
16+
VgfPipeline,
1617
)
1718

1819
input_t = Tuple[torch.Tensor]
@@ -181,3 +182,55 @@ def test_ge_scalar_u85_INT(test_module):
181182
run_on_fvp=True,
182183
)
183184
pipeline.run()
185+
186+
187+
@common.parametrize("test_module", test_data_tensor)
188+
@common.SkipIfNoModelConverter
189+
def test_ge_tensor_vgf_FP(test_module):
190+
pipeline = VgfPipeline[input_t](
191+
test_module(),
192+
test_module().get_inputs(),
193+
GreaterEqual.aten_op_tensor,
194+
GreaterEqual.exir_op,
195+
tosa_version="TOSA-1.0+FP",
196+
)
197+
pipeline.run()
198+
199+
200+
@common.parametrize("test_module", test_data_tensor)
201+
@common.SkipIfNoModelConverter
202+
def test_ge_tensor_vgf_INT(test_module):
203+
pipeline = VgfPipeline[input_t](
204+
test_module(),
205+
test_module().get_inputs(),
206+
GreaterEqual.aten_op_tensor,
207+
GreaterEqual.exir_op,
208+
tosa_version="TOSA-1.0+INT",
209+
)
210+
pipeline.run()
211+
212+
213+
@common.parametrize("test_module", test_data_scalar)
214+
@common.SkipIfNoModelConverter
215+
def test_ge_scalar_vgf_FP(test_module):
216+
pipeline = VgfPipeline[input_t](
217+
test_module(),
218+
test_module().get_inputs(),
219+
GreaterEqual.aten_op_scalar,
220+
GreaterEqual.exir_op,
221+
tosa_version="TOSA-1.0+FP",
222+
)
223+
pipeline.run()
224+
225+
226+
@common.parametrize("test_module", test_data_scalar)
227+
@common.SkipIfNoModelConverter
228+
def test_ge_scalar_vgf_INT(test_module):
229+
pipeline = VgfPipeline[input_t](
230+
test_module(),
231+
test_module().get_inputs(),
232+
GreaterEqual.aten_op_tensor,
233+
GreaterEqual.exir_op,
234+
tosa_version="TOSA-1.0+INT",
235+
)
236+
pipeline.run()

backends/arm/test/ops/test_gelu.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
EthosU85PipelineINT,
1313
TosaPipelineFP,
1414
TosaPipelineINT,
15+
VgfPipeline,
1516
)
1617

1718
input_t1 = Tuple[torch.Tensor]
@@ -125,3 +126,31 @@ def test_gelu_u85_INT(test_data: input_t1):
125126
Gelu.aten_op,
126127
Gelu.exir_op,
127128
).run()
129+
130+
131+
@common.parametrize("test_data", Gelu.test_data)
132+
@common.SkipIfNoModelConverter
133+
def test_gelu_vgf_FP(test_data: input_t1):
134+
approximate, data = test_data()
135+
pipeline = VgfPipeline[input_t1](
136+
Gelu(approximate),
137+
(data,),
138+
Gelu.aten_op,
139+
Gelu.exir_op,
140+
tosa_version="TOSA-1.0+FP",
141+
)
142+
pipeline.run()
143+
144+
145+
@common.parametrize("test_data", Gelu.test_data)
146+
@common.SkipIfNoModelConverter
147+
def test_gelu_vgf_INT(test_data: input_t1):
148+
approximate, data = test_data()
149+
pipeline = VgfPipeline[input_t1](
150+
Gelu(approximate),
151+
(data,),
152+
Gelu.aten_op,
153+
Gelu.exir_op,
154+
tosa_version="TOSA-1.0+INT",
155+
)
156+
pipeline.run()

backends/arm/test/ops/test_group_norm.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
EthosU85PipelineINT,
1111
TosaPipelineFP,
1212
TosaPipelineINT,
13+
VgfPipeline,
1314
)
1415

1516

@@ -143,3 +144,56 @@ def test_native_group_norm_u85_INT(test_data):
143144
)
144145
pipeline.change_args("run_method_and_compare_outputs", atol=1, qtol=1)
145146
pipeline.run()
147+
148+
149+
@common.parametrize(
150+
"test_data",
151+
test_data_suite,
152+
xfails={
153+
"randn_1_12_8_6_groups_12": "MLETORCH-925: Fix numerical issue",
154+
"rand_6_8_10_12_groups_1": "MLETORCH-925: Fix numerical issue",
155+
"rand_6_8_10_12_groups_4_no_affine": "MLETORCH-925: Fix numerical issue",
156+
"rand_6_8_10_12_groups_8": "MLETORCH-925: Fix numerical issue",
157+
},
158+
strict=False,
159+
)
160+
@common.SkipIfNoModelConverter
161+
def test_native_group_norm_vgf_FP(test_data):
162+
aten_op = "torch.ops.aten.group_norm.default"
163+
exir_op = "executorch_exir_dialects_edge__ops_aten_native_group_norm_default"
164+
model, inp = test_data
165+
pipeline = VgfPipeline[input_t](
166+
inp,
167+
model,
168+
aten_op=aten_op,
169+
exir_op=exir_op,
170+
tosa_version="TOSA-1.0+FP",
171+
)
172+
pipeline.run()
173+
174+
175+
@common.parametrize(
176+
"test_data",
177+
test_data_suite,
178+
xfails={
179+
"randn_1_12_8_6_groups_12": "MLETORCH-925: Fix numerical issue",
180+
"rand_6_8_10_12_groups_1": "MLETORCH-925: Fix numerical issue",
181+
"rand_6_8_10_12_groups_4_no_affine": "MLETORCH-925: Fix numerical issue",
182+
"rand_6_8_10_12_groups_8": "MLETORCH-925: Fix numerical issue",
183+
},
184+
strict=False,
185+
)
186+
@common.SkipIfNoModelConverter
187+
def test_native_group_norm_vgf_INT(test_data):
188+
aten_op = "torch.ops.aten.sub.Tensor"
189+
exir_op = "executorch_exir_dialects_edge__ops_aten_native_group_norm_default"
190+
model, inp = test_data
191+
pipeline = VgfPipeline[input_t](
192+
inp,
193+
model,
194+
aten_op=aten_op,
195+
exir_op=exir_op,
196+
tosa_version="TOSA-1.0+INT",
197+
atol=0.1, # TODO: "MLETORCH-925: Fix numerical issue for aten.native_group_norm"
198+
)
199+
pipeline.run()

backends/arm/test/ops/test_gt.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
OpNotSupportedPipeline,
1414
TosaPipelineFP,
1515
TosaPipelineINT,
16+
VgfPipeline,
1617
)
1718

1819

@@ -186,3 +187,55 @@ def test_gt_scalar_u85_INT(test_module):
186187
run_on_fvp=True,
187188
)
188189
pipeline.run()
190+
191+
192+
@common.parametrize("test_module", test_data_tensor)
193+
@common.SkipIfNoModelConverter
194+
def test_gt_tensor_vgf_FP(test_module):
195+
pipeline = VgfPipeline[input_t](
196+
test_module(),
197+
test_module().get_inputs(),
198+
Greater.aten_op_tensor,
199+
Greater.exir_op,
200+
tosa_version="TOSA-1.0+FP",
201+
)
202+
pipeline.run()
203+
204+
205+
@common.parametrize("test_module", test_data_scalar)
206+
@common.SkipIfNoModelConverter
207+
def test_gt_scalar_vgf_FP(test_module):
208+
pipeline = VgfPipeline[input_t](
209+
test_module(),
210+
test_module().get_inputs(),
211+
Greater.aten_op_scalar,
212+
Greater.exir_op,
213+
tosa_version="TOSA-1.0+FP",
214+
)
215+
pipeline.run()
216+
217+
218+
@common.parametrize("test_module", test_data_tensor)
219+
@common.SkipIfNoModelConverter
220+
def test_gt_tensor_vgf_INT(test_module):
221+
pipeline = VgfPipeline[input_t](
222+
test_module(),
223+
test_module().get_inputs(),
224+
Greater.aten_op_tensor,
225+
Greater.exir_op,
226+
tosa_version="TOSA-1.0+INT",
227+
)
228+
pipeline.run()
229+
230+
231+
@common.parametrize("test_module", test_data_scalar)
232+
@common.SkipIfNoModelConverter
233+
def test_gt_scalar_vgf_INT(test_module):
234+
pipeline = VgfPipeline[input_t](
235+
test_module(),
236+
test_module().get_inputs(),
237+
Greater.aten_op_tensor,
238+
Greater.exir_op,
239+
tosa_version="TOSA-1.0+INT",
240+
)
241+
pipeline.run()

backends/arm/test/ops/test_hardsigmoid.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
EthosU85PipelineINT,
1515
TosaPipelineFP,
1616
TosaPipelineINT,
17+
VgfPipeline,
1718
)
1819

1920
aten_op = "torch.ops.aten.hardsigmoid.default"
@@ -87,3 +88,25 @@ def test_hardsigmoid_u85_INT(test_data: torch.Tensor):
8788
use_to_edge_transform_and_lower=True,
8889
)
8990
pipeline.run()
91+
92+
93+
@common.parametrize("test_data", test_data_suite)
94+
@common.SkipIfNoModelConverter
95+
def test_hardsigmoid_vgf_FP(test_data: torch.Tensor):
96+
pipeline = VgfPipeline[input_t1](
97+
Hardsigmoid(), (test_data(),), aten_op, exir_op=[], tosa_version="TOSA-1.0+FP"
98+
)
99+
pipeline.run()
100+
101+
102+
@common.parametrize("test_data", test_data_suite)
103+
@common.SkipIfNoModelConverter
104+
def test_hardsigmoid_vgf_INT(test_data: torch.Tensor):
105+
pipeline = VgfPipeline[input_t1](
106+
Hardsigmoid(),
107+
(test_data(),),
108+
aten_op,
109+
exir_op=[],
110+
tosa_version="TOSA-1.0+INT",
111+
)
112+
pipeline.run()

backends/arm/test/ops/test_hardswish.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
EthosU85PipelineINT,
1515
TosaPipelineFP,
1616
TosaPipelineINT,
17+
VgfPipeline,
1718
)
1819

1920
aten_op = "torch.ops.aten.hardswish.default"
@@ -77,3 +78,25 @@ def test_hardswish_u85_INT(test_data):
7778
run_on_fvp=True,
7879
use_to_edge_transform_and_lower=True,
7980
).run()
81+
82+
83+
@common.parametrize("test_data", test_data_suite)
84+
@common.SkipIfNoModelConverter
85+
def test_hardswish_vgf_FP(test_data):
86+
pipeline = VgfPipeline[input_t1](
87+
Hardswish(), (test_data(),), aten_op, exir_op, tosa_version="TOSA-1.0+FP"
88+
)
89+
pipeline.run()
90+
91+
92+
@common.parametrize("test_data", test_data_suite)
93+
@common.SkipIfNoModelConverter
94+
def test_hardswish_vgf_INT(test_data):
95+
pipeline = VgfPipeline[input_t1](
96+
Hardswish(),
97+
(test_data(),),
98+
aten_op,
99+
exir_op,
100+
tosa_version="TOSA-1.0+INT",
101+
)
102+
pipeline.run()

backends/arm/test/ops/test_hardtanh.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
EthosU85PipelineINT,
1717
TosaPipelineFP,
1818
TosaPipelineINT,
19+
VgfPipeline,
1920
)
2021

2122
test_data_suite = {
@@ -86,3 +87,25 @@ def test_hardtanh_u85_INT(test_data: torch.Tensor):
8687
run_on_fvp=True,
8788
)
8889
pipeline.run()
90+
91+
92+
@common.parametrize("test_data", test_data_suite)
93+
@common.SkipIfNoModelConverter
94+
def test_hardtanh_vgf_FP(test_data: torch.Tensor):
95+
pipeline = VgfPipeline[input_t](
96+
HardTanh(), (test_data(),), aten_op, exir_op, tosa_version="TOSA-1.0+FP"
97+
)
98+
pipeline.run()
99+
100+
101+
@common.parametrize("test_data", test_data_suite)
102+
@common.SkipIfNoModelConverter
103+
def test_hardtanh_vgf_INT(test_data: torch.Tensor):
104+
pipeline = VgfPipeline[input_t](
105+
HardTanh(),
106+
(test_data(),),
107+
aten_op,
108+
exir_op,
109+
tosa_version="TOSA-1.0+INT",
110+
)
111+
pipeline.run()

backends/arm/test/ops/test_index_select.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,12 @@
99
import pytest
1010

1111
import torch
12+
13+
from executorch.backends.arm.test import common
1214
from executorch.backends.arm.test.tester.test_pipeline import (
1315
TosaPipelineFP,
1416
TosaPipelineINT,
17+
VgfPipeline,
1518
)
1619

1720

@@ -115,3 +118,49 @@ def test_index_select_tosa_INT_rand(test_data: input_params):
115118
"run_method_and_compare_outputs", inputs=test_input, atol=0.9, rtol=0.2, qtol=1
116119
)
117120
pipeline.run()
121+
122+
123+
@pytest.mark.parametrize("test_data", list(test_data.values()))
124+
@common.SkipIfNoModelConverter
125+
def test_index_select_vgf_FP(test_data: input_params):
126+
op, inp = test_data
127+
pipeline = VgfPipeline[input_params](
128+
op,
129+
inp,
130+
op.aten_op,
131+
op.exir_op,
132+
tosa_version="TOSA-1.0+FP",
133+
)
134+
pipeline.run()
135+
136+
137+
@pytest.mark.parametrize("test_data", list(test_data.values())[:-1])
138+
@common.SkipIfNoModelConverter
139+
def test_index_select_vgf_INT(test_data: input_params):
140+
op, inp = test_data
141+
pipeline = VgfPipeline[input_params](
142+
op,
143+
inp,
144+
op.aten_op,
145+
op.exir_op,
146+
tosa_version="TOSA-1.0+INT",
147+
)
148+
pipeline.run()
149+
150+
151+
@pytest.mark.parametrize("test_data", list(test_data.values())[-1:])
152+
@common.SkipIfNoModelConverter
153+
def test_index_select_vgf_INT_rand(test_data: input_params):
154+
op, inp = test_data
155+
pipeline = VgfPipeline[input_params](
156+
op,
157+
inp,
158+
op.aten_op,
159+
op.exir_op,
160+
tosa_version="TOSA-1.0+INT",
161+
)
162+
# TODO: MLETORCH-1136 Change args of run_method_and_compare_outputs of the vgf tests
163+
# pipeline.change_args(
164+
# "run_method_and_compare_outputs", inputs=test_input, atol=0.9, rtol=0.2, qtol=1
165+
# )
166+
pipeline.run()

0 commit comments

Comments
 (0)