Verify intermediate output capturer on exported program #12517
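For context, the flow this PR verifies can be sketched as follows. This is a minimal, illustrative walk-through that mirrors the new test code below; the model registry key, helper functions, and capturer/propagation APIs are taken from the files changed in this PR, not from anywhere else.

from torch.export import export
from executorch.exir import EdgeCompileConfig, to_edge
from executorch.devtools.inspector._inspector_utils import propagate_back_debug_handle
from executorch.devtools.inspector._intermediate_output_capturer import (
    IntermediateOutputCapturer,
)
from executorch.devtools.inspector.tests.inspector_test_utils import (
    check_if_intermediate_outputs_match,
    model_registry,
)

model = model_registry["ConvLinearModel"]()
inputs = model.get_input()

# Export the eager model, then lower it to the edge dialect.
exported_program = export(model, (inputs,))
exported_program_graph_id = id(exported_program.graph)
edge_program_manager = to_edge(
    exported_program, compile_config=EdgeCompileConfig(_check_ir_validity=True)
)

# Propagate the edge dialect debug handles back onto the exported program so the
# capturer can key intermediate outputs the same way at both graph levels.
assert propagate_back_debug_handle(
    exported_program, exported_program_graph_id, edge_program_manager.exported_program()
)

# Capture intermediate outputs on the exported program and compare them against the
# expected values for that graph level; the edge dialect program is checked the same
# way with its own expected values.
captured = IntermediateOutputCapturer(exported_program.module()).run_and_capture(inputs)
assert check_if_intermediate_outputs_match(
    captured, model.get_exported_program_expected_intermediate_outputs()
)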

Merged
merged 1 commit into from Jul 15, 2025
7 changes: 4 additions & 3 deletions devtools/inspector/tests/inspector_test.py
@@ -45,7 +45,7 @@
)
from executorch.devtools.inspector.tests.inspector_test_utils import (
check_if_debug_handle_to_op_names_match,
check_if_final_outputs_match,
check_if_intermediate_outputs_match,
model_registry,
)
from executorch.exir import (
@@ -526,8 +526,9 @@ def test_etrecord_populates_correct_aot_intermediate_outputs(self):
inspector_instance._get_aot_intermediate_outputs_and_op_names()
)
self.assertTrue(
check_if_final_outputs_match(
"ConvLinearModel", aot_intermediate_outputs
check_if_intermediate_outputs_match(
aot_intermediate_outputs,
mod.get_edge_dialect_expected_intermediate_outputs(),
)
)

34 changes: 29 additions & 5 deletions devtools/inspector/tests/inspector_test_utils.py
@@ -10,6 +10,8 @@
import torch.nn as nn
import torch.nn.functional as F

from executorch.exir.debug_handle_utils import UNSET_DEBUG_HANDLE


class ConvlLinearModel(nn.Module):
"""
@@ -42,6 +44,7 @@ def forward(self, x):
x = self.linear_layer(x)
x = x + self.additional_bias
x = x - 0.1
x = x.to(x.dtype)
x = x * self.scale_factor
x = x / (self.scale_factor + 1.0)
x = F.relu(x)
@@ -57,9 +60,9 @@ def get_input():
return torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]], requires_grad=True)

@staticmethod
def get_expected_intermediate_outputs():
def get_edge_dialect_expected_intermediate_outputs():
"""
Returns the expected outputs of the debug handles and intermediate output mapping for this model for the given input.
Returns the expected mapping from debug handles to intermediate outputs for the edge dialect graph of this model, for the given input.
"""
return {
(1,): torch.tensor([[[[7.7000, 6.7000], [4.7000, 3.7000]]]]),
@@ -94,6 +97,26 @@ def get_expected_debug_handle_to_op_names():
(11,): ["aten_split_with_sizes_copy_default"],
}

@staticmethod
def get_exported_program_expected_intermediate_outputs():
"""
Returns the expected mapping from debug handles to intermediate outputs for the exported program of this model, for the given input.
"""
return {
(UNSET_DEBUG_HANDLE,): torch.tensor([[5.4000, 13.5200]]),
(1,): torch.tensor([[[[7.7000, 6.7000], [4.7000, 3.7000]]]]),
(2,): torch.tensor([[7.7000, 6.7000, 4.7000, 3.7000]]),
(3,): torch.tensor([[5.0000, 14.1200]]),
(4,): torch.tensor([[5.5000, 13.6200]]),
(5,): torch.tensor([[5.4000, 13.5200]]),
(6,): torch.tensor([[10.8000, 6.7600]]),
(7,): torch.tensor([3.0000, 1.5000]),
(8,): torch.tensor([[3.6000, 4.5067]]),
(9,): torch.tensor([[3.6000, 4.5067]]),
(10,): torch.tensor([[0.9734, 0.9891]]),
(11,): [torch.tensor([[0.9734]]), torch.tensor([[0.9891]])],
}


# Global model registry
model_registry = {
@@ -102,13 +125,14 @@ def get_expected_debug_handle_to_op_names():
}


def check_if_final_outputs_match(model_name, actual_outputs_with_handles):
def check_if_intermediate_outputs_match(
actual_outputs_with_handles, expected_outputs_with_handles
):
"""
Checks if the actual outputs match the expected outputs.
Returns True if all outputs match, otherwise returns False.
"""
model_instance = model_registry[model_name]
expected_outputs_with_handles = model_instance.get_expected_intermediate_outputs()

if len(actual_outputs_with_handles) != len(expected_outputs_with_handles):
return False
for debug_handle, expected_output in expected_outputs_with_handles.items():
112 changes: 70 additions & 42 deletions devtools/inspector/tests/intermediate_output_capturer_test.py
@@ -7,67 +7,95 @@
# pyre-unsafe

import unittest
from typing import Dict, Tuple, Union

import torch

from executorch.devtools.inspector._inspector_utils import (
DebugHandle,
propagate_back_debug_handle,
)
from executorch.devtools.inspector._intermediate_output_capturer import (
IntermediateOutputCapturer,
)
from executorch.devtools.inspector.tests.inspector_test_utils import (
check_if_final_outputs_match,
check_if_intermediate_outputs_match,
model_registry,
)

from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge
from torch.export import export, ExportedProgram
from torch.fx import GraphModule


class TestIntermediateOutputCapturer(unittest.TestCase):
def _set_up_model(self, model_name):
model = model_registry[model_name]()
input_tensor = model.get_input()
aten_model: ExportedProgram = export(model, (input_tensor,), strict=True)
edge_program_manager: EdgeProgramManager = to_edge(
aten_model, compile_config=EdgeCompileConfig(_check_ir_validity=True)
def _capture_intermediate_outputs_and_check(
self,
inputs: Tuple[torch.Tensor],
ep: ExportedProgram,
expected_intermediate_outputs: Dict[
DebugHandle, Union[torch.Tensor, Tuple[torch.Tensor]]
],
):
captured_intermediate_outputs = IntermediateOutputCapturer(
ep.module()
).run_and_capture(inputs)

# Test keying with debug handle tuple
for key in captured_intermediate_outputs.keys():
self.assertIsInstance(key, tuple)

# Test tensor cloning and detaching
for output in captured_intermediate_outputs.values():
if isinstance(output, torch.Tensor):
self.assertFalse(output.requires_grad)
self.assertTrue(output.is_leaf)

# Test placeholder nodes are skipped
for node in ep.graph.nodes:
if node.op == "placeholder":
self.assertNotIn(node.meta.get("debug_handle"), node.meta)

# Test multiple outputs capture
for inter_output in captured_intermediate_outputs.values():
if isinstance(inter_output, tuple):
for part in inter_output:
self.assertIsInstance(part, torch.Tensor)

# Test capture correct outputs
self.assertTrue(
check_if_intermediate_outputs_match(
captured_intermediate_outputs, expected_intermediate_outputs
)
)
graph_module: GraphModule = edge_program_manager._edge_programs[
"forward"
].module()
capturer = IntermediateOutputCapturer(graph_module)
intermediate_outputs = capturer.run_and_capture(input_tensor)
return input_tensor, graph_module, capturer, intermediate_outputs

def test_models(self):
available_models = list(model_registry.keys())
for model_name in available_models:
with self.subTest(model=model_name):
input_tensor, graph_module, capturer, intermediate_outputs = (
self._set_up_model(model_name)
model = model_registry[model_name]()
input_tensor = model.get_input()
aten_model: ExportedProgram = export(model, (input_tensor,))
aten_model_graph_id = id(aten_model.graph)

edge_program_manager: EdgeProgramManager = to_edge(
aten_model,
compile_config=EdgeCompileConfig(_check_ir_validity=True),
)

# Test keying with debug handle tuple
for key in intermediate_outputs.keys():
self.assertIsInstance(key, tuple)

# Test tensor cloning and detaching
for output in intermediate_outputs.values():
if isinstance(output, torch.Tensor):
self.assertFalse(output.requires_grad)
self.assertTrue(output.is_leaf)

# Test placeholder nodes are skipped
for node in graph_module.graph.nodes:
if node.op == "placeholder":
self.assertNotIn(node.meta.get("debug_handle"), node.meta)

# Test multiple outputs capture
outputs = capturer.run_and_capture(input_tensor)
for output in outputs.values():
if isinstance(output, tuple):
self.assertEqual(len(output), 2)
for part in output:
self.assertIsInstance(part, torch.Tensor)

# Test capture correct outputs
self.assertTrue(
check_if_final_outputs_match(model_name, intermediate_outputs)
ret = propagate_back_debug_handle(
aten_model,
aten_model_graph_id,
edge_program_manager.exported_program(),
)
assert ret is True

self._capture_intermediate_outputs_and_check(
input_tensor,
aten_model,
model.get_exported_program_expected_intermediate_outputs(),
)
self._capture_intermediate_outputs_and_check(
input_tensor,
edge_program_manager.exported_program(),
model.get_edge_dialect_expected_intermediate_outputs(),
)
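One detail worth noting, as an inference from the expected-output tables in this PR rather than anything stated explicitly: the forward pass gains a dtype-preserving x.to(x.dtype) call, and the output of that node is presumably what appears under UNSET_DEBUG_HANDLE at the exported-program level only, since no edge dialect debug handle is propagated back onto it. A small sketch under that assumption:

from executorch.exir.debug_handle_utils import UNSET_DEBUG_HANDLE
from executorch.devtools.inspector.tests.inspector_test_utils import ConvlLinearModel

# The UNSET_DEBUG_HANDLE key is present only in the exported-program expectations ...
assert (UNSET_DEBUG_HANDLE,) in ConvlLinearModel.get_exported_program_expected_intermediate_outputs()
# ... and absent from the edge dialect expectations.
assert (UNSET_DEBUG_HANDLE,) not in ConvlLinearModel.get_edge_dialect_expected_intermediate_outputs()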