Skip to content

Commit 1004a61

Browse files
Juntian777facebook-github-bot
authored andcommitted
Fixed runtime merging logics
Summary: This PR first added a missing return statement so that merge_runtime_overlapping_debug_handles now returns the updated debug_handle_to_output instead of doing nothing. Also updated the condition to use set intersection for more clearer checking of any common elements between debug_handle and existing_debug_handle. Differential Revision: D78182279
1 parent 967e3b9 commit 1004a61

File tree

3 files changed

+36
-8
lines changed

3 files changed

+36
-8
lines changed

devtools/inspector/_inspector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1211,7 +1211,7 @@ def _get_runtime_intermediate_outputs_and_op_names(
12111211
# TODO: One debug handle can be associated with multiple op names
12121212
debug_handle_to_op_names[debug_handle] = [event.name]
12131213

1214-
merge_runtime_overlapping_debug_handles(debug_handle_to_output)
1214+
debug_handle_to_output = merge_runtime_overlapping_debug_handles(debug_handle_to_output)
12151215
return {
12161216
k: v[1] for k, v in debug_handle_to_output.items()
12171217
}, debug_handle_to_op_names

devtools/inspector/_inspector_utils.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -554,6 +554,7 @@ def _merge_runtime_debug_handles(
554554
Merge two DebugHandles by removing elements from debug_handle1 that are also present in debug_handle2,
555555
while preserving the relative order of elements in both modified debug_handle1 and debug_handle2.
556556
All elements from the modified debug_handle1 will appear before any elements from debug_handle2.
557+
Also removes duplicates within debug_handle2.
557558
"""
558559

559560
# Initialize a list to store unique elements in order
@@ -566,14 +567,16 @@ def _merge_runtime_debug_handles(
566567
# If the element has not been seen before, add it to the list and mark it as seen
567568
if item not in seen:
568569
unique_ordered_list.append(item)
569-
570+
seen = set(unique_ordered_list)
570571
for item in debug_handle2:
571-
unique_ordered_list.append(item)
572+
if item not in seen:
573+
unique_ordered_list.append(item)
574+
seen.add(item)
572575
return tuple(unique_ordered_list)
573576

574577

575578
def merge_runtime_overlapping_debug_handles(
576-
intermediate_outputs: Dict[DebugHandle, Tuple[int, Any]]
579+
runtime_intermediate_outputs: Dict[DebugHandle, Tuple[int, Any]]
577580
) -> Dict[DebugHandle, Tuple[int, Any]]:
578581
"""
579582
Merges runtimes with overlapping debug handles into a single key in the dict.
@@ -585,15 +588,15 @@ def merge_runtime_overlapping_debug_handles(
585588
586589
The value associated with the merged key is determined by the debug handle with the highest instruction id.
587590
"""
588-
if len(intermediate_outputs) == 0:
591+
if len(runtime_intermediate_outputs) == 0:
589592
return {}
590593
merged: Dict[DebugHandle, Tuple[int, Any]] = {}
591-
for debug_handle, (instruction_id, debug_data) in intermediate_outputs.items():
594+
for debug_handle, (instruction_id, debug_data) in runtime_intermediate_outputs.items():
592595
curr_debug_handle, last_value = debug_handle, (instruction_id, debug_data)
593596
# Collect any existing keys that overlap with the current key
594597
to_remove = []
595598
for existing_debug_handle, existing_value in merged.items():
596-
if any(item in existing_debug_handle for item in debug_handle):
599+
if set(debug_handle) & set(existing_debug_handle):
597600
# Keep the value with the highest instruction_id
598601
# Also merge the debug handles higher instruction_id
599602
if existing_value[0] < instruction_id:
@@ -759,7 +762,11 @@ def map_runtime_aot_intermediate_outputs(
759762
# The size of runtime_list should be 1 because all AOT debug_handles are tuples with one element.
760763
# Additionally, runtime debug handles have already undergone pre-processing to merge overlapping debug_hanldes.
761764
# As a result, there shouldn't be any 1-to-n or n-to-n (AOT to runtime) mappings.
762-
assert len(runtime_list) == 1
765+
if len(runtime_list) != 1:
766+
raise ValueError(
767+
f"Expected only one runtime debug handle, but found {len(runtime_list)}: {runtime_list}"
768+
)
769+
763770
runtime_debug_handle, runtime_intermediate_output = runtime_list[0]
764771

765772
# Combine aot debug handles into a single key

devtools/inspector/tests/inspector_utils_test.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,27 @@ def test_merge_overlapping_debug_handles_non_continuous(self):
278278
actual_value = intermediate_outputs[key][1]
279279
self.assertTrue(torch.allclose(expected_value, actual_value))
280280

281+
def test_merge_overlapping_debug_handles_edge_cases(self):
282+
intermediate_outputs = {
283+
(9,): (1, "val1"),
284+
(
285+
9,
286+
9,
287+
9,
288+
): (2, "val2"),
289+
(
290+
9,
291+
9,
292+
): (3, "val3"),
293+
}
294+
intermediate_outputs = merge_runtime_overlapping_debug_handles(
295+
intermediate_outputs
296+
)
297+
expected_intermediate_outputs = {
298+
(9,): (3, "val3"),
299+
}
300+
self.assertEqual(intermediate_outputs, expected_intermediate_outputs)
301+
281302
def test_map_runtime_aot_intermediate_outputs_empty_inputs(self):
282303
# When the inputs are empty, the output should also be empty
283304
aot_intermediate_outputs = {}

0 commit comments

Comments
 (0)