diff --git a/devtools/inspector/_inspector.py b/devtools/inspector/_inspector.py index 03636fe1823..5a083655f2a 100644 --- a/devtools/inspector/_inspector.py +++ b/devtools/inspector/_inspector.py @@ -1211,7 +1211,9 @@ def _get_runtime_intermediate_outputs_and_op_names( # TODO: One debug handle can be associated with multiple op names debug_handle_to_op_names[debug_handle] = [event.name] - merge_runtime_overlapping_debug_handles(debug_handle_to_output) + debug_handle_to_output = merge_runtime_overlapping_debug_handles( + debug_handle_to_output + ) return { k: v[1] for k, v in debug_handle_to_output.items() }, debug_handle_to_op_names diff --git a/devtools/inspector/_inspector_utils.py b/devtools/inspector/_inspector_utils.py index 040d664a808..56bf075dad0 100644 --- a/devtools/inspector/_inspector_utils.py +++ b/devtools/inspector/_inspector_utils.py @@ -554,6 +554,7 @@ def _merge_runtime_debug_handles( Merge two DebugHandles by removing elements from debug_handle1 that are also present in debug_handle2, while preserving the relative order of elements in both modified debug_handle1 and debug_handle2. All elements from the modified debug_handle1 will appear before any elements from debug_handle2. + Also removes duplicates within debug_handle2. """ # Initialize a list to store unique elements in order @@ -566,14 +567,16 @@ def _merge_runtime_debug_handles( # If the element has not been seen before, add it to the list and mark it as seen if item not in seen: unique_ordered_list.append(item) - + seen = set(unique_ordered_list) for item in debug_handle2: - unique_ordered_list.append(item) + if item not in seen: + unique_ordered_list.append(item) + seen.add(item) return tuple(unique_ordered_list) def merge_runtime_overlapping_debug_handles( - intermediate_outputs: Dict[DebugHandle, Tuple[int, Any]] + runtime_intermediate_outputs: Dict[DebugHandle, Tuple[int, Any]] ) -> Dict[DebugHandle, Tuple[int, Any]]: """ Merges runtimes with overlapping debug handles into a single key in the dict. @@ -585,15 +588,18 @@ def merge_runtime_overlapping_debug_handles( The value associated with the merged key is determined by the debug handle with the highest instruction id. """ - if len(intermediate_outputs) == 0: + if len(runtime_intermediate_outputs) == 0: return {} merged: Dict[DebugHandle, Tuple[int, Any]] = {} - for debug_handle, (instruction_id, debug_data) in intermediate_outputs.items(): + for debug_handle, ( + instruction_id, + debug_data, + ) in runtime_intermediate_outputs.items(): curr_debug_handle, last_value = debug_handle, (instruction_id, debug_data) # Collect any existing keys that overlap with the current key to_remove = [] for existing_debug_handle, existing_value in merged.items(): - if any(item in existing_debug_handle for item in debug_handle): + if set(debug_handle) & set(existing_debug_handle): # Keep the value with the highest instruction_id # Also merge the debug handles higher instruction_id if existing_value[0] < instruction_id: @@ -759,7 +765,11 @@ def map_runtime_aot_intermediate_outputs( # The size of runtime_list should be 1 because all AOT debug_handles are tuples with one element. # Additionally, runtime debug handles have already undergone pre-processing to merge overlapping debug_hanldes. # As a result, there shouldn't be any 1-to-n or n-to-n (AOT to runtime) mappings. - assert len(runtime_list) == 1 + if len(runtime_list) != 1: + raise ValueError( + f"Expected only one runtime debug handle, but found {len(runtime_list)}: {runtime_list}" + ) + runtime_debug_handle, runtime_intermediate_output = runtime_list[0] # Combine aot debug handles into a single key diff --git a/devtools/inspector/tests/inspector_utils_test.py b/devtools/inspector/tests/inspector_utils_test.py index 2d5ff242e22..d24d4b993dc 100644 --- a/devtools/inspector/tests/inspector_utils_test.py +++ b/devtools/inspector/tests/inspector_utils_test.py @@ -278,6 +278,27 @@ def test_merge_overlapping_debug_handles_non_continuous(self): actual_value = intermediate_outputs[key][1] self.assertTrue(torch.allclose(expected_value, actual_value)) + def test_merge_overlapping_debug_handles_edge_cases(self): + intermediate_outputs = { + (9,): (1, "val1"), + ( + 9, + 9, + 9, + ): (2, "val2"), + ( + 9, + 9, + ): (3, "val3"), + } + intermediate_outputs = merge_runtime_overlapping_debug_handles( + intermediate_outputs + ) + expected_intermediate_outputs = { + (9,): (3, "val3"), + } + self.assertEqual(intermediate_outputs, expected_intermediate_outputs) + def test_map_runtime_aot_intermediate_outputs_empty_inputs(self): # When the inputs are empty, the output should also be empty aot_intermediate_outputs = {}