Skip to content

Commit c2d6f3d

Browse files
authored
Use same constant in runtime for unset debug handle (#12516)
In the runtime we have a contant number for unset debug handle. This diff bring that to python env for usage. Differential Revision: [D78132322](https://our.internmc.facebook.com/intern/diff/D78132322/)
1 parent 4551a56 commit c2d6f3d

File tree

3 files changed

+11
-13
lines changed

3 files changed

+11
-13
lines changed

devtools/inspector/_inspector_utils.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from executorch.exir.debug_handle_utils import (
3939
DEBUG_HANDLE_KEY,
4040
get_greatest_ancestor_node_identifier,
41+
UNSET_DEBUG_HANDLE,
4142
)
4243

4344
from executorch.exir.graph_module import bfs_trace_with_node_process
@@ -950,7 +951,7 @@ def propagate_back_debug_handle(
950951
where op1_0 is from op1, op3_0 and op3_1 are from op3, op2 is removed by to_edge pipeline (e.g. RemoveNoopPass).
951952
952953
Then debug handle of op1 should be same as op1_0, and debug handle of op3 should be same as op3_0 and op3_1.
953-
The debug handle of op2 will be a non-existing debug handle in edge dialect program for further skipping.
954+
The debug handle of op2 will be UNSET_DEBUG_HANDLE for further skipping.
954955
955956
Return: True if:
956957
a. every debug handle in the edge dialect program has a corresponding node in the exported program
@@ -971,11 +972,6 @@ def propagate_back_debug_handle(
971972
# number of nodes in the exported program that have matched entry in export_graph_node_id_to_debug_handle
972973
n_matched_node = 0
973974

974-
# debug handle for the node in the exported program but not in the edge dialect program
975-
debug_handle_for_removed_node = (
976-
max(export_graph_node_id_to_debug_handle.values()) + 1
977-
)
978-
979975
def _find_n_match_node(node: torch.fx.Node) -> None:
980976
nonlocal n_matched_node
981977
if node.name in ("output", "placeholder"):
@@ -991,7 +987,7 @@ def _equip_debug_handle(node: torch.fx.Node) -> None:
991987
if node_id in export_graph_node_id_to_debug_handle:
992988
node.meta[DEBUG_HANDLE_KEY] = export_graph_node_id_to_debug_handle[node_id]
993989
else:
994-
node.meta[DEBUG_HANDLE_KEY] = debug_handle_for_removed_node
990+
node.meta[DEBUG_HANDLE_KEY] = UNSET_DEBUG_HANDLE
995991

996992
bfs_trace_with_node_process(exported_program.graph_module, _find_n_match_node)
997993

devtools/inspector/tests/inspector_utils_test.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
)
4848
from executorch.devtools.inspector.numerical_comparator import L1Comparator
4949
from executorch.exir import to_edge
50-
from executorch.exir.debug_handle_utils import DEBUG_HANDLE_KEY
50+
from executorch.exir.debug_handle_utils import DEBUG_HANDLE_KEY, UNSET_DEBUG_HANDLE
5151
from torch.export import export
5252

5353

@@ -704,18 +704,18 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
704704
)
705705
)
706706

707-
# only two add ops in the exported program will keep in edge dialect program, so the debug handles for removed op will be three
708-
debug_handle_for_removed_node = 3
707+
n_removed_nodes = 0
709708

710709
for node in exported_program.graph.nodes:
711710
if node.name == "add":
712711
self.assertEqual(node.meta[DEBUG_HANDLE_KEY], 1)
713712
elif node.name == "add_1":
714713
self.assertEqual(node.meta[DEBUG_HANDLE_KEY], 2)
715714
elif node.op not in ("placeholder", "output"):
716-
self.assertEqual(
717-
node.meta[DEBUG_HANDLE_KEY], debug_handle_for_removed_node
718-
)
715+
n_removed_nodes += 1
716+
self.assertEqual(node.meta[DEBUG_HANDLE_KEY], UNSET_DEBUG_HANDLE)
717+
718+
self.assertEqual(n_removed_nodes, 2)
719719

720720

721721
def gen_mock_operator_graph_with_expected_map() -> (

exir/debug_handle_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
FROM_NODE_KEY = "from_node"
1010
DEBUG_HANDLE_KEY = "debug_handle"
1111

12+
UNSET_DEBUG_HANDLE = 0
13+
1214

1315
def get_greatest_ancestor_node_identifier(node: Node) -> str:
1416
"""Get the identifier of the greatest ancestor node of the given node.

0 commit comments

Comments
 (0)