From 1d80ae296e7d8516adeed73b9044125936a9ddfc Mon Sep 17 00:00:00 2001 From: Songhao Jia Date: Tue, 15 Jul 2025 14:42:13 -0700 Subject: [PATCH] Verify intermediate output capturer on exported program (#12517) Summary: as title. Reviewed By: Juntian777 Differential Revision: D78141930 --- devtools/inspector/tests/inspector_test.py | 7 +- .../inspector/tests/inspector_test_utils.py | 34 +++++- .../intermediate_output_capturer_test.py | 112 +++++++++++------- 3 files changed, 103 insertions(+), 50 deletions(-) diff --git a/devtools/inspector/tests/inspector_test.py b/devtools/inspector/tests/inspector_test.py index 826c93bbed9..c36311afeab 100644 --- a/devtools/inspector/tests/inspector_test.py +++ b/devtools/inspector/tests/inspector_test.py @@ -45,7 +45,7 @@ ) from executorch.devtools.inspector.tests.inspector_test_utils import ( check_if_debug_handle_to_op_names_match, - check_if_final_outputs_match, + check_if_intermediate_outputs_match, model_registry, ) from executorch.exir import ( @@ -526,8 +526,9 @@ def test_etrecord_populates_correct_aot_intermediate_outputs(self): inspector_instance._get_aot_intermediate_outputs_and_op_names() ) self.assertTrue( - check_if_final_outputs_match( - "ConvLinearModel", aot_intermediate_outputs + check_if_intermediate_outputs_match( + aot_intermediate_outputs, + mod.get_edge_dialect_expected_intermediate_outputs(), ) ) diff --git a/devtools/inspector/tests/inspector_test_utils.py b/devtools/inspector/tests/inspector_test_utils.py index 0369b7b26d7..da426377564 100644 --- a/devtools/inspector/tests/inspector_test_utils.py +++ b/devtools/inspector/tests/inspector_test_utils.py @@ -10,6 +10,8 @@ import torch.nn as nn import torch.nn.functional as F +from executorch.exir.debug_handle_utils import UNSET_DEBUG_HANDLE + class ConvlLinearModel(nn.Module): """ @@ -42,6 +44,7 @@ def forward(self, x): x = self.linear_layer(x) x = x + self.additional_bias x = x - 0.1 + x = x.to(x.dtype) x = x * self.scale_factor x = x / (self.scale_factor + 1.0) x = F.relu(x) @@ -57,9 +60,9 @@ def get_input(): return torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]], requires_grad=True) @staticmethod - def get_expected_intermediate_outputs(): + def get_edge_dialect_expected_intermediate_outputs(): """ - Returns the expected outputs of the debug handles and intermediate output mapping for this model for the given input. + Returns the expected outputs of the debug handles and intermediate output mapping for edge dialect graph of this model for the given input. """ return { (1,): torch.tensor([[[[7.7000, 6.7000], [4.7000, 3.7000]]]]), @@ -94,6 +97,26 @@ def get_expected_debug_handle_to_op_names(): (11,): ["aten_split_with_sizes_copy_default"], } + @staticmethod + def get_exported_program_expected_intermediate_outputs(): + """ + Returns the expected outputs of the debug handles and intermediate output mapping for edge dialect graph of this model for the given input. + """ + return { + (UNSET_DEBUG_HANDLE,): torch.tensor([[5.4000, 13.5200]]), + (1,): torch.tensor([[[[7.7000, 6.7000], [4.7000, 3.7000]]]]), + (2,): torch.tensor([[7.7000, 6.7000, 4.7000, 3.7000]]), + (3,): torch.tensor([[5.0000, 14.1200]]), + (4,): torch.tensor([[5.5000, 13.6200]]), + (5,): torch.tensor([[5.4000, 13.5200]]), + (6,): torch.tensor([[10.8000, 6.7600]]), + (7,): torch.tensor([3.0000, 1.5000]), + (8,): torch.tensor([[3.6000, 4.5067]]), + (9,): torch.tensor([[3.6000, 4.5067]]), + (10,): torch.tensor([[0.9734, 0.9891]]), + (11,): [torch.tensor([[0.9734]]), torch.tensor([[0.9891]])], + } + # Global model registry model_registry = { @@ -102,13 +125,14 @@ def get_expected_debug_handle_to_op_names(): } -def check_if_final_outputs_match(model_name, actual_outputs_with_handles): +def check_if_intermediate_outputs_match( + actual_outputs_with_handles, expected_outputs_with_handles +): """ Checks if the actual outputs match the expected outputs for the specified model. Returns True if all outputs match, otherwise returns False. """ - model_instance = model_registry[model_name] - expected_outputs_with_handles = model_instance.get_expected_intermediate_outputs() + if len(actual_outputs_with_handles) != len(expected_outputs_with_handles): return False for debug_handle, expected_output in expected_outputs_with_handles.items(): diff --git a/devtools/inspector/tests/intermediate_output_capturer_test.py b/devtools/inspector/tests/intermediate_output_capturer_test.py index 3c8d2487e70..40834146c74 100644 --- a/devtools/inspector/tests/intermediate_output_capturer_test.py +++ b/devtools/inspector/tests/intermediate_output_capturer_test.py @@ -7,67 +7,95 @@ # pyre-unsafe import unittest +from typing import Dict, Tuple, Union import torch + +from executorch.devtools.inspector._inspector_utils import ( + DebugHandle, + propagate_back_debug_handle, +) from executorch.devtools.inspector._intermediate_output_capturer import ( IntermediateOutputCapturer, ) from executorch.devtools.inspector.tests.inspector_test_utils import ( - check_if_final_outputs_match, + check_if_intermediate_outputs_match, model_registry, ) + from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge from torch.export import export, ExportedProgram -from torch.fx import GraphModule class TestIntermediateOutputCapturer(unittest.TestCase): - def _set_up_model(self, model_name): - model = model_registry[model_name]() - input_tensor = model.get_input() - aten_model: ExportedProgram = export(model, (input_tensor,), strict=True) - edge_program_manager: EdgeProgramManager = to_edge( - aten_model, compile_config=EdgeCompileConfig(_check_ir_validity=True) + def _capture_intermediate_outputs_and_check( + self, + inputs: Tuple[torch.Tensor], + ep: ExportedProgram, + expected_intermediate_outputs: Dict[ + DebugHandle, Union[torch.Tensor, Tuple[torch.Tensor]] + ], + ): + captured_intermediate_outputs = IntermediateOutputCapturer( + ep.module() + ).run_and_capture(inputs) + + # Test keying with debug handle tuple + for key in captured_intermediate_outputs.keys(): + self.assertIsInstance(key, tuple) + + # Test tensor cloning and detaching + for output in captured_intermediate_outputs.values(): + if isinstance(output, torch.Tensor): + self.assertFalse(output.requires_grad) + self.assertTrue(output.is_leaf) + + # Test placeholder nodes are skipped + for node in ep.graph.nodes: + if node.op == "placeholder": + self.assertNotIn(node.meta.get("debug_handle"), node.meta) + + # Test multiple outputs capture + for inter_output in captured_intermediate_outputs.values(): + if isinstance(inter_output, tuple): + for part in output: + self.assertIsInstance(part, torch.Tensor) + + # Test capture correct outputs + self.assertTrue( + check_if_intermediate_outputs_match( + captured_intermediate_outputs, expected_intermediate_outputs + ) ) - graph_module: GraphModule = edge_program_manager._edge_programs[ - "forward" - ].module() - capturer = IntermediateOutputCapturer(graph_module) - intermediate_outputs = capturer.run_and_capture(input_tensor) - return input_tensor, graph_module, capturer, intermediate_outputs def test_models(self): available_models = list(model_registry.keys()) for model_name in available_models: with self.subTest(model=model_name): - input_tensor, graph_module, capturer, intermediate_outputs = ( - self._set_up_model(model_name) + model = model_registry[model_name]() + input_tensor = model.get_input() + aten_model: ExportedProgram = export(model, (input_tensor,)) + aten_model_graph_id = id(aten_model.graph) + + edge_program_manager: EdgeProgramManager = to_edge( + aten_model, + compile_config=EdgeCompileConfig(_check_ir_validity=True), ) - # Test keying with debug handle tuple - for key in intermediate_outputs.keys(): - self.assertIsInstance(key, tuple) - - # Test tensor cloning and detaching - for output in intermediate_outputs.values(): - if isinstance(output, torch.Tensor): - self.assertFalse(output.requires_grad) - self.assertTrue(output.is_leaf) - - # Test placeholder nodes are skipped - for node in graph_module.graph.nodes: - if node.op == "placeholder": - self.assertNotIn(node.meta.get("debug_handle"), node.meta) - - # Test multiple outputs capture - outputs = capturer.run_and_capture(input_tensor) - for output in outputs.values(): - if isinstance(output, tuple): - self.assertEqual(len(output), 2) - for part in output: - self.assertIsInstance(part, torch.Tensor) - - # Test capture correct outputs - self.assertTrue( - check_if_final_outputs_match(model_name, intermediate_outputs) + ret = propagate_back_debug_handle( + aten_model, + aten_model_graph_id, + edge_program_manager.exported_program(), + ) + assert ret is True + + self._capture_intermediate_outputs_and_check( + input_tensor, + aten_model, + model.get_exported_program_expected_intermediate_outputs(), + ) + self._capture_intermediate_outputs_and_check( + input_tensor, + edge_program_manager.exported_program(), + model.get_edge_dialect_expected_intermediate_outputs(), )