@@ -554,6 +554,7 @@ def _merge_runtime_debug_handles(
554
554
Merge two DebugHandles by removing elements from debug_handle1 that are also present in debug_handle2,
555
555
while preserving the relative order of elements in both modified debug_handle1 and debug_handle2.
556
556
All elements from the modified debug_handle1 will appear before any elements from debug_handle2.
557
+ Also removes duplicates within debug_handle2.
557
558
"""
558
559
559
560
# Initialize a list to store unique elements in order
@@ -566,14 +567,16 @@ def _merge_runtime_debug_handles(
566
567
# If the element has not been seen before, add it to the list and mark it as seen
567
568
if item not in seen :
568
569
unique_ordered_list .append (item )
569
-
570
+ seen = set ( unique_ordered_list )
570
571
for item in debug_handle2 :
571
- unique_ordered_list .append (item )
572
+ if item not in seen :
573
+ unique_ordered_list .append (item )
574
+ seen .add (item )
572
575
return tuple (unique_ordered_list )
573
576
574
577
575
578
def merge_runtime_overlapping_debug_handles (
576
- intermediate_outputs : Dict [DebugHandle , Tuple [int , Any ]]
579
+ runtime_intermediate_outputs : Dict [DebugHandle , Tuple [int , Any ]]
577
580
) -> Dict [DebugHandle , Tuple [int , Any ]]:
578
581
"""
579
582
Merges runtimes with overlapping debug handles into a single key in the dict.
@@ -585,15 +588,18 @@ def merge_runtime_overlapping_debug_handles(
585
588
586
589
The value associated with the merged key is determined by the debug handle with the highest instruction id.
587
590
"""
588
- if len (intermediate_outputs ) == 0 :
591
+ if len (runtime_intermediate_outputs ) == 0 :
589
592
return {}
590
593
merged : Dict [DebugHandle , Tuple [int , Any ]] = {}
591
- for debug_handle , (instruction_id , debug_data ) in intermediate_outputs .items ():
594
+ for debug_handle , (
595
+ instruction_id ,
596
+ debug_data ,
597
+ ) in runtime_intermediate_outputs .items ():
592
598
curr_debug_handle , last_value = debug_handle , (instruction_id , debug_data )
593
599
# Collect any existing keys that overlap with the current key
594
600
to_remove = []
595
601
for existing_debug_handle , existing_value in merged .items ():
596
- if any ( item in existing_debug_handle for item in debug_handle ):
602
+ if set ( debug_handle ) & set ( existing_debug_handle ):
597
603
# Keep the value with the highest instruction_id
598
604
# Also merge the debug handles higher instruction_id
599
605
if existing_value [0 ] < instruction_id :
@@ -759,7 +765,11 @@ def map_runtime_aot_intermediate_outputs(
759
765
# The size of runtime_list should be 1 because all AOT debug_handles are tuples with one element.
760
766
# Additionally, runtime debug handles have already undergone pre-processing to merge overlapping debug_hanldes.
761
767
# As a result, there shouldn't be any 1-to-n or n-to-n (AOT to runtime) mappings.
762
- assert len (runtime_list ) == 1
768
+ if len (runtime_list ) != 1 :
769
+ raise ValueError (
770
+ f"Expected only one runtime debug handle, but found { len (runtime_list )} : { runtime_list } "
771
+ )
772
+
763
773
runtime_debug_handle , runtime_intermediate_output = runtime_list [0 ]
764
774
765
775
# Combine aot debug handles into a single key
0 commit comments