@@ -554,6 +554,7 @@ def _merge_runtime_debug_handles(
554
554
Merge two DebugHandles by removing elements from debug_handle1 that are also present in debug_handle2,
555
555
while preserving the relative order of elements in both modified debug_handle1 and debug_handle2.
556
556
All elements from the modified debug_handle1 will appear before any elements from debug_handle2.
557
+ Also removes duplicates within debug_handle2.
557
558
"""
558
559
559
560
# Initialize a list to store unique elements in order
@@ -566,14 +567,16 @@ def _merge_runtime_debug_handles(
566
567
# If the element has not been seen before, add it to the list and mark it as seen
567
568
if item not in seen :
568
569
unique_ordered_list .append (item )
569
-
570
+ seen = set ( unique_ordered_list )
570
571
for item in debug_handle2 :
571
- unique_ordered_list .append (item )
572
+ if item not in seen :
573
+ unique_ordered_list .append (item )
574
+ seen .add (item )
572
575
return tuple (unique_ordered_list )
573
576
574
577
575
578
def merge_runtime_overlapping_debug_handles (
576
- intermediate_outputs : Dict [DebugHandle , Tuple [int , Any ]]
579
+ runtime_intermediate_outputs : Dict [DebugHandle , Tuple [int , Any ]]
577
580
) -> Dict [DebugHandle , Tuple [int , Any ]]:
578
581
"""
579
582
Merges runtimes with overlapping debug handles into a single key in the dict.
@@ -585,15 +588,15 @@ def merge_runtime_overlapping_debug_handles(
585
588
586
589
The value associated with the merged key is determined by the debug handle with the highest instruction id.
587
590
"""
588
- if len (intermediate_outputs ) == 0 :
591
+ if len (runtime_intermediate_outputs ) == 0 :
589
592
return {}
590
593
merged : Dict [DebugHandle , Tuple [int , Any ]] = {}
591
- for debug_handle , (instruction_id , debug_data ) in intermediate_outputs .items ():
594
+ for debug_handle , (instruction_id , debug_data ) in runtime_intermediate_outputs .items ():
592
595
curr_debug_handle , last_value = debug_handle , (instruction_id , debug_data )
593
596
# Collect any existing keys that overlap with the current key
594
597
to_remove = []
595
598
for existing_debug_handle , existing_value in merged .items ():
596
- if any ( item in existing_debug_handle for item in debug_handle ):
599
+ if set ( debug_handle ) & set ( existing_debug_handle ):
597
600
# Keep the value with the highest instruction_id
598
601
# Also merge the debug handles higher instruction_id
599
602
if existing_value [0 ] < instruction_id :
@@ -759,7 +762,11 @@ def map_runtime_aot_intermediate_outputs(
759
762
# The size of runtime_list should be 1 because all AOT debug_handles are tuples with one element.
760
763
# Additionally, runtime debug handles have already undergone pre-processing to merge overlapping debug_hanldes.
761
764
# As a result, there shouldn't be any 1-to-n or n-to-n (AOT to runtime) mappings.
762
- assert len (runtime_list ) == 1
765
+ if len (runtime_list ) != 1 :
766
+ raise ValueError (
767
+ f"Expected only one runtime debug handle, but found { len (runtime_list )} : { runtime_list } "
768
+ )
769
+
763
770
runtime_debug_handle , runtime_intermediate_output = runtime_list [0 ]
764
771
765
772
# Combine aot debug handles into a single key
0 commit comments