Skip to content

Commit 1d80ae2

Browse files
Gasoonjiafacebook-github-bot
authored andcommitted
Verify intermediate output capturer on exported program (#12517)
Summary: as title. Reviewed By: Juntian777 Differential Revision: D78141930
1 parent 953fa0e commit 1d80ae2

File tree

3 files changed

+103
-50
lines changed

3 files changed

+103
-50
lines changed

devtools/inspector/tests/inspector_test.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
)
4646
from executorch.devtools.inspector.tests.inspector_test_utils import (
4747
check_if_debug_handle_to_op_names_match,
48-
check_if_final_outputs_match,
48+
check_if_intermediate_outputs_match,
4949
model_registry,
5050
)
5151
from executorch.exir import (
@@ -526,8 +526,9 @@ def test_etrecord_populates_correct_aot_intermediate_outputs(self):
526526
inspector_instance._get_aot_intermediate_outputs_and_op_names()
527527
)
528528
self.assertTrue(
529-
check_if_final_outputs_match(
530-
"ConvLinearModel", aot_intermediate_outputs
529+
check_if_intermediate_outputs_match(
530+
aot_intermediate_outputs,
531+
mod.get_edge_dialect_expected_intermediate_outputs(),
531532
)
532533
)
533534

devtools/inspector/tests/inspector_test_utils.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
import torch.nn as nn
1111
import torch.nn.functional as F
1212

13+
from executorch.exir.debug_handle_utils import UNSET_DEBUG_HANDLE
14+
1315

1416
class ConvlLinearModel(nn.Module):
1517
"""
@@ -42,6 +44,7 @@ def forward(self, x):
4244
x = self.linear_layer(x)
4345
x = x + self.additional_bias
4446
x = x - 0.1
47+
x = x.to(x.dtype)
4548
x = x * self.scale_factor
4649
x = x / (self.scale_factor + 1.0)
4750
x = F.relu(x)
@@ -57,9 +60,9 @@ def get_input():
5760
return torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]], requires_grad=True)
5861

5962
@staticmethod
60-
def get_expected_intermediate_outputs():
63+
def get_edge_dialect_expected_intermediate_outputs():
6164
"""
62-
Returns the expected outputs of the debug handles and intermediate output mapping for this model for the given input.
65+
Returns the expected outputs of the debug handles and intermediate output mapping for edge dialect graph of this model for the given input.
6366
"""
6467
return {
6568
(1,): torch.tensor([[[[7.7000, 6.7000], [4.7000, 3.7000]]]]),
@@ -94,6 +97,26 @@ def get_expected_debug_handle_to_op_names():
9497
(11,): ["aten_split_with_sizes_copy_default"],
9598
}
9699

100+
@staticmethod
101+
def get_exported_program_expected_intermediate_outputs():
102+
"""
103+
Returns the expected outputs of the debug handles and intermediate output mapping for edge dialect graph of this model for the given input.
104+
"""
105+
return {
106+
(UNSET_DEBUG_HANDLE,): torch.tensor([[5.4000, 13.5200]]),
107+
(1,): torch.tensor([[[[7.7000, 6.7000], [4.7000, 3.7000]]]]),
108+
(2,): torch.tensor([[7.7000, 6.7000, 4.7000, 3.7000]]),
109+
(3,): torch.tensor([[5.0000, 14.1200]]),
110+
(4,): torch.tensor([[5.5000, 13.6200]]),
111+
(5,): torch.tensor([[5.4000, 13.5200]]),
112+
(6,): torch.tensor([[10.8000, 6.7600]]),
113+
(7,): torch.tensor([3.0000, 1.5000]),
114+
(8,): torch.tensor([[3.6000, 4.5067]]),
115+
(9,): torch.tensor([[3.6000, 4.5067]]),
116+
(10,): torch.tensor([[0.9734, 0.9891]]),
117+
(11,): [torch.tensor([[0.9734]]), torch.tensor([[0.9891]])],
118+
}
119+
97120

98121
# Global model registry
99122
model_registry = {
@@ -102,13 +125,14 @@ def get_expected_debug_handle_to_op_names():
102125
}
103126

104127

105-
def check_if_final_outputs_match(model_name, actual_outputs_with_handles):
128+
def check_if_intermediate_outputs_match(
129+
actual_outputs_with_handles, expected_outputs_with_handles
130+
):
106131
"""
107132
Checks if the actual outputs match the expected outputs for the specified model.
108133
Returns True if all outputs match, otherwise returns False.
109134
"""
110-
model_instance = model_registry[model_name]
111-
expected_outputs_with_handles = model_instance.get_expected_intermediate_outputs()
135+
112136
if len(actual_outputs_with_handles) != len(expected_outputs_with_handles):
113137
return False
114138
for debug_handle, expected_output in expected_outputs_with_handles.items():

devtools/inspector/tests/intermediate_output_capturer_test.py

Lines changed: 70 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -7,67 +7,95 @@
77
# pyre-unsafe
88

99
import unittest
10+
from typing import Dict, Tuple, Union
1011

1112
import torch
13+
14+
from executorch.devtools.inspector._inspector_utils import (
15+
DebugHandle,
16+
propagate_back_debug_handle,
17+
)
1218
from executorch.devtools.inspector._intermediate_output_capturer import (
1319
IntermediateOutputCapturer,
1420
)
1521
from executorch.devtools.inspector.tests.inspector_test_utils import (
16-
check_if_final_outputs_match,
22+
check_if_intermediate_outputs_match,
1723
model_registry,
1824
)
25+
1926
from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge
2027
from torch.export import export, ExportedProgram
21-
from torch.fx import GraphModule
2228

2329

2430
class TestIntermediateOutputCapturer(unittest.TestCase):
25-
def _set_up_model(self, model_name):
26-
model = model_registry[model_name]()
27-
input_tensor = model.get_input()
28-
aten_model: ExportedProgram = export(model, (input_tensor,), strict=True)
29-
edge_program_manager: EdgeProgramManager = to_edge(
30-
aten_model, compile_config=EdgeCompileConfig(_check_ir_validity=True)
31+
def _capture_intermediate_outputs_and_check(
32+
self,
33+
inputs: Tuple[torch.Tensor],
34+
ep: ExportedProgram,
35+
expected_intermediate_outputs: Dict[
36+
DebugHandle, Union[torch.Tensor, Tuple[torch.Tensor]]
37+
],
38+
):
39+
captured_intermediate_outputs = IntermediateOutputCapturer(
40+
ep.module()
41+
).run_and_capture(inputs)
42+
43+
# Test keying with debug handle tuple
44+
for key in captured_intermediate_outputs.keys():
45+
self.assertIsInstance(key, tuple)
46+
47+
# Test tensor cloning and detaching
48+
for output in captured_intermediate_outputs.values():
49+
if isinstance(output, torch.Tensor):
50+
self.assertFalse(output.requires_grad)
51+
self.assertTrue(output.is_leaf)
52+
53+
# Test placeholder nodes are skipped
54+
for node in ep.graph.nodes:
55+
if node.op == "placeholder":
56+
self.assertNotIn(node.meta.get("debug_handle"), node.meta)
57+
58+
# Test multiple outputs capture
59+
for inter_output in captured_intermediate_outputs.values():
60+
if isinstance(inter_output, tuple):
61+
for part in output:
62+
self.assertIsInstance(part, torch.Tensor)
63+
64+
# Test capture correct outputs
65+
self.assertTrue(
66+
check_if_intermediate_outputs_match(
67+
captured_intermediate_outputs, expected_intermediate_outputs
68+
)
3169
)
32-
graph_module: GraphModule = edge_program_manager._edge_programs[
33-
"forward"
34-
].module()
35-
capturer = IntermediateOutputCapturer(graph_module)
36-
intermediate_outputs = capturer.run_and_capture(input_tensor)
37-
return input_tensor, graph_module, capturer, intermediate_outputs
3870

3971
def test_models(self):
4072
available_models = list(model_registry.keys())
4173
for model_name in available_models:
4274
with self.subTest(model=model_name):
43-
input_tensor, graph_module, capturer, intermediate_outputs = (
44-
self._set_up_model(model_name)
75+
model = model_registry[model_name]()
76+
input_tensor = model.get_input()
77+
aten_model: ExportedProgram = export(model, (input_tensor,))
78+
aten_model_graph_id = id(aten_model.graph)
79+
80+
edge_program_manager: EdgeProgramManager = to_edge(
81+
aten_model,
82+
compile_config=EdgeCompileConfig(_check_ir_validity=True),
4583
)
4684

47-
# Test keying with debug handle tuple
48-
for key in intermediate_outputs.keys():
49-
self.assertIsInstance(key, tuple)
50-
51-
# Test tensor cloning and detaching
52-
for output in intermediate_outputs.values():
53-
if isinstance(output, torch.Tensor):
54-
self.assertFalse(output.requires_grad)
55-
self.assertTrue(output.is_leaf)
56-
57-
# Test placeholder nodes are skipped
58-
for node in graph_module.graph.nodes:
59-
if node.op == "placeholder":
60-
self.assertNotIn(node.meta.get("debug_handle"), node.meta)
61-
62-
# Test multiple outputs capture
63-
outputs = capturer.run_and_capture(input_tensor)
64-
for output in outputs.values():
65-
if isinstance(output, tuple):
66-
self.assertEqual(len(output), 2)
67-
for part in output:
68-
self.assertIsInstance(part, torch.Tensor)
69-
70-
# Test capture correct outputs
71-
self.assertTrue(
72-
check_if_final_outputs_match(model_name, intermediate_outputs)
85+
ret = propagate_back_debug_handle(
86+
aten_model,
87+
aten_model_graph_id,
88+
edge_program_manager.exported_program(),
89+
)
90+
assert ret is True
91+
92+
self._capture_intermediate_outputs_and_check(
93+
input_tensor,
94+
aten_model,
95+
model.get_exported_program_expected_intermediate_outputs(),
96+
)
97+
self._capture_intermediate_outputs_and_check(
98+
input_tensor,
99+
edge_program_manager.exported_program(),
100+
model.get_edge_dialect_expected_intermediate_outputs(),
73101
)

0 commit comments

Comments
 (0)