Skip to content

Commit 0a038a7

Browse files
authored
Fixed runtime merging logics
Differential Revision: D78182279 Pull Request resolved: #12509
1 parent b14cb22 commit 0a038a7

File tree

3 files changed

+41
-8
lines changed

3 files changed

+41
-8
lines changed

devtools/inspector/_inspector.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1211,7 +1211,9 @@ def _get_runtime_intermediate_outputs_and_op_names(
12111211
# TODO: One debug handle can be associated with multiple op names
12121212
debug_handle_to_op_names[debug_handle] = [event.name]
12131213

1214-
merge_runtime_overlapping_debug_handles(debug_handle_to_output)
1214+
debug_handle_to_output = merge_runtime_overlapping_debug_handles(
1215+
debug_handle_to_output
1216+
)
12151217
return {
12161218
k: v[1] for k, v in debug_handle_to_output.items()
12171219
}, debug_handle_to_op_names

devtools/inspector/_inspector_utils.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -554,6 +554,7 @@ def _merge_runtime_debug_handles(
554554
Merge two DebugHandles by removing elements from debug_handle1 that are also present in debug_handle2,
555555
while preserving the relative order of elements in both modified debug_handle1 and debug_handle2.
556556
All elements from the modified debug_handle1 will appear before any elements from debug_handle2.
557+
Also removes duplicates within debug_handle2.
557558
"""
558559

559560
# Initialize a list to store unique elements in order
@@ -566,14 +567,16 @@ def _merge_runtime_debug_handles(
566567
# If the element has not been seen before, add it to the list and mark it as seen
567568
if item not in seen:
568569
unique_ordered_list.append(item)
569-
570+
seen = set(unique_ordered_list)
570571
for item in debug_handle2:
571-
unique_ordered_list.append(item)
572+
if item not in seen:
573+
unique_ordered_list.append(item)
574+
seen.add(item)
572575
return tuple(unique_ordered_list)
573576

574577

575578
def merge_runtime_overlapping_debug_handles(
576-
intermediate_outputs: Dict[DebugHandle, Tuple[int, Any]]
579+
runtime_intermediate_outputs: Dict[DebugHandle, Tuple[int, Any]]
577580
) -> Dict[DebugHandle, Tuple[int, Any]]:
578581
"""
579582
Merges runtimes with overlapping debug handles into a single key in the dict.
@@ -585,15 +588,18 @@ def merge_runtime_overlapping_debug_handles(
585588
586589
The value associated with the merged key is determined by the debug handle with the highest instruction id.
587590
"""
588-
if len(intermediate_outputs) == 0:
591+
if len(runtime_intermediate_outputs) == 0:
589592
return {}
590593
merged: Dict[DebugHandle, Tuple[int, Any]] = {}
591-
for debug_handle, (instruction_id, debug_data) in intermediate_outputs.items():
594+
for debug_handle, (
595+
instruction_id,
596+
debug_data,
597+
) in runtime_intermediate_outputs.items():
592598
curr_debug_handle, last_value = debug_handle, (instruction_id, debug_data)
593599
# Collect any existing keys that overlap with the current key
594600
to_remove = []
595601
for existing_debug_handle, existing_value in merged.items():
596-
if any(item in existing_debug_handle for item in debug_handle):
602+
if set(debug_handle) & set(existing_debug_handle):
597603
# Keep the value with the highest instruction_id
598604
# Also merge the debug handles higher instruction_id
599605
if existing_value[0] < instruction_id:
@@ -759,7 +765,11 @@ def map_runtime_aot_intermediate_outputs(
759765
# The size of runtime_list should be 1 because all AOT debug_handles are tuples with one element.
760766
# Additionally, runtime debug handles have already undergone pre-processing to merge overlapping debug_hanldes.
761767
# As a result, there shouldn't be any 1-to-n or n-to-n (AOT to runtime) mappings.
762-
assert len(runtime_list) == 1
768+
if len(runtime_list) != 1:
769+
raise ValueError(
770+
f"Expected only one runtime debug handle, but found {len(runtime_list)}: {runtime_list}"
771+
)
772+
763773
runtime_debug_handle, runtime_intermediate_output = runtime_list[0]
764774

765775
# Combine aot debug handles into a single key

devtools/inspector/tests/inspector_utils_test.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,27 @@ def test_merge_overlapping_debug_handles_non_continuous(self):
278278
actual_value = intermediate_outputs[key][1]
279279
self.assertTrue(torch.allclose(expected_value, actual_value))
280280

281+
def test_merge_overlapping_debug_handles_edge_cases(self):
282+
intermediate_outputs = {
283+
(9,): (1, "val1"),
284+
(
285+
9,
286+
9,
287+
9,
288+
): (2, "val2"),
289+
(
290+
9,
291+
9,
292+
): (3, "val3"),
293+
}
294+
intermediate_outputs = merge_runtime_overlapping_debug_handles(
295+
intermediate_outputs
296+
)
297+
expected_intermediate_outputs = {
298+
(9,): (3, "val3"),
299+
}
300+
self.assertEqual(intermediate_outputs, expected_intermediate_outputs)
301+
281302
def test_map_runtime_aot_intermediate_outputs_empty_inputs(self):
282303
# When the inputs are empty, the output should also be empty
283304
aot_intermediate_outputs = {}

0 commit comments

Comments
 (0)