Skip to content

Fixed runtime merging logics #12509

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion devtools/inspector/_inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -1211,7 +1211,9 @@ def _get_runtime_intermediate_outputs_and_op_names(
# TODO: One debug handle can be associated with multiple op names
debug_handle_to_op_names[debug_handle] = [event.name]

merge_runtime_overlapping_debug_handles(debug_handle_to_output)
debug_handle_to_output = merge_runtime_overlapping_debug_handles(
debug_handle_to_output
)
return {
k: v[1] for k, v in debug_handle_to_output.items()
}, debug_handle_to_op_names
Expand Down
24 changes: 17 additions & 7 deletions devtools/inspector/_inspector_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,7 @@ def _merge_runtime_debug_handles(
Merge two DebugHandles by removing elements from debug_handle1 that are also present in debug_handle2,
while preserving the relative order of elements in both modified debug_handle1 and debug_handle2.
All elements from the modified debug_handle1 will appear before any elements from debug_handle2.
Also removes duplicates within debug_handle2.
"""

# Initialize a list to store unique elements in order
Expand All @@ -566,14 +567,16 @@ def _merge_runtime_debug_handles(
# If the element has not been seen before, add it to the list and mark it as seen
if item not in seen:
unique_ordered_list.append(item)

seen = set(unique_ordered_list)
for item in debug_handle2:
unique_ordered_list.append(item)
if item not in seen:
unique_ordered_list.append(item)
seen.add(item)
return tuple(unique_ordered_list)


def merge_runtime_overlapping_debug_handles(
intermediate_outputs: Dict[DebugHandle, Tuple[int, Any]]
runtime_intermediate_outputs: Dict[DebugHandle, Tuple[int, Any]]
) -> Dict[DebugHandle, Tuple[int, Any]]:
"""
Merges runtimes with overlapping debug handles into a single key in the dict.
Expand All @@ -585,15 +588,18 @@ def merge_runtime_overlapping_debug_handles(

The value associated with the merged key is determined by the debug handle with the highest instruction id.
"""
if len(intermediate_outputs) == 0:
if len(runtime_intermediate_outputs) == 0:
return {}
merged: Dict[DebugHandle, Tuple[int, Any]] = {}
for debug_handle, (instruction_id, debug_data) in intermediate_outputs.items():
for debug_handle, (
instruction_id,
debug_data,
) in runtime_intermediate_outputs.items():
curr_debug_handle, last_value = debug_handle, (instruction_id, debug_data)
# Collect any existing keys that overlap with the current key
to_remove = []
for existing_debug_handle, existing_value in merged.items():
if any(item in existing_debug_handle for item in debug_handle):
if set(debug_handle) & set(existing_debug_handle):
# Keep the value with the highest instruction_id
# Also merge the debug handles higher instruction_id
if existing_value[0] < instruction_id:
Expand Down Expand Up @@ -759,7 +765,11 @@ def map_runtime_aot_intermediate_outputs(
# The size of runtime_list should be 1 because all AOT debug_handles are tuples with one element.
# Additionally, runtime debug handles have already undergone pre-processing to merge overlapping debug_hanldes.
# As a result, there shouldn't be any 1-to-n or n-to-n (AOT to runtime) mappings.
assert len(runtime_list) == 1
if len(runtime_list) != 1:
raise ValueError(
f"Expected only one runtime debug handle, but found {len(runtime_list)}: {runtime_list}"
)

runtime_debug_handle, runtime_intermediate_output = runtime_list[0]

# Combine aot debug handles into a single key
Expand Down
21 changes: 21 additions & 0 deletions devtools/inspector/tests/inspector_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,27 @@ def test_merge_overlapping_debug_handles_non_continuous(self):
actual_value = intermediate_outputs[key][1]
self.assertTrue(torch.allclose(expected_value, actual_value))

def test_merge_overlapping_debug_handles_edge_cases(self):
intermediate_outputs = {
(9,): (1, "val1"),
(
9,
9,
9,
): (2, "val2"),
(
9,
9,
): (3, "val3"),
}
intermediate_outputs = merge_runtime_overlapping_debug_handles(
intermediate_outputs
)
expected_intermediate_outputs = {
(9,): (3, "val3"),
}
self.assertEqual(intermediate_outputs, expected_intermediate_outputs)

def test_map_runtime_aot_intermediate_outputs_empty_inputs(self):
# When the inputs are empty, the output should also be empty
aot_intermediate_outputs = {}
Expand Down
Loading