@@ -111,6 +111,10 @@ def chunked_causal_mask_function(chunk_size: int) -> Callable:
def padding_mask_function(padding_mask: torch.Tensor) -> Callable:
+    """
+    This returns the mask_function corresponding to a 2D padding mask.
+    """
+
    def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
        # Note that here the mask should ALWAYS be at least of the max `kv_index` size in the dimension 1. This is because
        # we cannot pad it here in the mask_function as we don't know the final size, and we cannot try/except, as it is not
@@ -120,6 +124,17 @@ def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
    return inner_mask


+def packed_sequence_mask_function(packed_sequence_mask: torch.Tensor) -> Callable:
+    """
+    This returns the mask_function corresponding to a 2D packed sequence mask.
+    """
+
+    def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
+        return packed_sequence_mask[batch_idx, q_idx] == packed_sequence_mask[batch_idx, kv_idx]
+
+    return inner_mask
+
+
def add_offsets_to_mask_function(mask_function: Callable, q_offset: int, kv_offset: int) -> Callable:
    """
    This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
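To make the semantics of the new `packed_sequence_mask_function` concrete, here is a minimal standalone sketch (not part of the diff; `packed_predicate` is just an illustrative stand-in for the returned `inner_mask`). A query token is only allowed to attend to key tokens that carry the same sequence index:

import torch

# Toy packing: sequences of lengths 2, 3 and 1 along a single batch row
packed_sequence_mask = torch.tensor([[0, 0, 1, 1, 1, 2]])

def packed_predicate(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> torch.Tensor:
    # Same comparison as the inner_mask added above: attend only within the same sequence
    return packed_sequence_mask[batch_idx, q_idx] == packed_sequence_mask[batch_idx, kv_idx]

print(packed_predicate(0, 0, q_idx=1, kv_idx=0))  # tensor(True)  -> both tokens belong to sequence 0
print(packed_predicate(0, 0, q_idx=2, kv_idx=1))  # tensor(False) -> crosses the boundary between sequences 0 and 1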
@@ -584,12 +599,40 @@ class AttentionMaskInterface(GeneralInterface):
ALL_MASK_ATTENTION_FUNCTIONS: AttentionMaskInterface = AttentionMaskInterface()


+def find_packed_sequence_indices(position_ids: torch.Tensor) -> Optional[torch.Tensor]:
+    """
+    Find the indices of the sequence to which each new query token in the sequence belongs when using packed
+    tensor format (i.e. several sequences packed in the same batch dimension).
+
+    Args:
+        position_ids (`torch.Tensor`):
+            A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
+
+    Returns:
+        A 2D tensor where each similar integer indicates that the tokens belong to the same sequence. For example, if we
+        pack 3 sequences of 2, 3 and 1 tokens respectively along a single batch dim, this will return [[0, 0, 1, 1, 1, 2]].
+    """
+    # What separates different sequences is when 2 consecutive position_ids are separated by more than 1. So
+    # taking the diff (by prepending the first value - 1 to keep correct indexing) and applying cumsum to the result
+    # gives exactly the sequence indices
+    # Note that we assume that a single sequence cannot span several batch dimensions, i.e. 1 single sequence
+    # cannot be part of the end of the first batch dim and the start of the 2nd one for example
+    first_dummy_value = position_ids[:, :1] - 1  # We just need the diff on this first value to be 1
+    position_diff = torch.diff(position_ids, prepend=first_dummy_value, dim=-1)
+    packed_sequence_mask = (position_diff != 1).cumsum(-1)
+
+    # Here it would be nice to return None if we did not detect packed sequence format, i.e. if `packed_sequence_mask[:, -1] == 0`
+    # but it causes issues with export
+    return packed_sequence_mask
+
+
def _preprocess_mask_arguments(
    config: PretrainedConfig,
    input_embeds: torch.Tensor,
    attention_mask: Optional[Union[torch.Tensor, BlockMask]],
    cache_position: torch.Tensor,
    past_key_values: Optional[Cache],
+    position_ids: Optional[torch.Tensor],
    layer_idx: Optional[int],
) -> tuple[bool, Optional[Union[torch.Tensor, BlockMask]], int, int]:
    """
@@ -609,6 +652,8 @@ def _preprocess_mask_arguments(
            A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
        past_key_values (`Cache`, optional):
            The past key values, if we use a cache.
+        position_ids (`torch.Tensor`, optional):
+            A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
        layer_idx (`int`, optional):
            If `past_key_values` is not None, this is the layer index of the cache from which to get the key-value
            length and offset. Indeed, for hybrid caches, different layers may return different lengths.
@@ -618,22 +663,25 @@ def _preprocess_mask_arguments(
            Whether we should early exit mask creation, and return the mask as-is.
        attention_mask (`torch.Tensor` or `BlockMask` or `None`):
            The attention mask to either return immediately, or to use in downstream mask creation.
+        packed_sequence_mask (`torch.Tensor`, optional):
+            In case we detected packed sequence format, this is a tensor where each similar integer indicates that
+            the tokens belong to the same sequence.
        kv_length (`int`):
            The size that the key and value states will have during the attention computation.
        kv_offset (`int`):
            An offset to indicate at which first position the key and values states will refer to.
    """
    # If the mask is already 4D, simply return as-is (it was already prepared, or it is custom)
    if isinstance(attention_mask, (torch.Tensor, BlockMask)) and len(attention_mask.shape) == 4:
-        return True, attention_mask, None, None
+        return True, attention_mask, None, None, None

    # For TGI/vLLM backends, or other custom attention without equivalent mask creation: we don't need a mask!
    # Note: it's not ideal to check the `_global_mapping` attribute instead of the object itself, however otherwise
    # full graph dynamo tracing (i.e. torch.export or compile with `fullgraph=True`) will fail on Python<3.11
    # with `torch._dynamo.exc.Unsupported: 'inline in skipfiles:Mapping.__contains__ | __contains__, skipped
    # according trace_rules.lookup SKIP_DIRS'` -- can be removed when we require Python>=3.11
    if config._attn_implementation not in ALL_MASK_ATTENTION_FUNCTIONS._global_mapping:
-        return True, None, None, None
+        return True, None, None, None, None

    # Move the mask to correct device, and potentially switch dtype for efficiency
    if attention_mask is not None and attention_mask.ndim == 2:
@@ -646,7 +694,17 @@ def _preprocess_mask_arguments(
    else:
        kv_length, kv_offset = input_embeds.shape[1], 0

-    return False, attention_mask, kv_length, kv_offset
+    # We check the position_ids for potential packed sequence format (only if the 2D attention mask is explicitly None,
+    # and we don't have past_key_values, i.e. generally a training setup)
+    packed_sequence_mask = None
+    if position_ids is not None and attention_mask is None and past_key_values is None:
+        batch_size = input_embeds.shape[0]
+        # The position ids are sometimes just unsqueezed, without being expanded
+        if batch_size != position_ids.shape[0]:
+            position_ids = position_ids.expand(batch_size, -1)
+        packed_sequence_mask = find_packed_sequence_indices(position_ids)
+
+    return False, attention_mask, packed_sequence_mask, kv_length, kv_offset


def create_causal_mask(
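The new detection block above only fires in the mask-less, cache-less case (typically training) and first broadcasts position ids that were unsqueezed but not expanded. A standalone sketch of that logic with hypothetical toy values (`find_packed_sequence_indices` is re-implemented inline here rather than imported):

import torch

input_embeds = torch.randn(2, 6, 8)                # (batch_size, query_length, hidden_dim), dummy values
attention_mask, past_key_values = None, None
position_ids = torch.tensor([[0, 1, 0, 1, 2, 0]])  # shape (1, query_length): unsqueezed, not expanded

packed_sequence_mask = None
if position_ids is not None and attention_mask is None and past_key_values is None:
    batch_size = input_embeds.shape[0]
    # Broadcast the (1, query_length) position ids to the real batch size, as the diff does
    if batch_size != position_ids.shape[0]:
        position_ids = position_ids.expand(batch_size, -1)
    # Inline equivalent of find_packed_sequence_indices
    position_diff = torch.diff(position_ids, prepend=position_ids[:, :1] - 1, dim=-1)
    packed_sequence_mask = (position_diff != 1).cumsum(-1)

print(packed_sequence_mask)  # tensor([[0, 0, 1, 1, 1, 2], [0, 0, 1, 1, 1, 2]])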
@@ -655,6 +713,7 @@ def create_causal_mask(
    attention_mask: Optional[torch.Tensor],
    cache_position: torch.Tensor,
    past_key_values: Optional[Cache],
+    position_ids: Optional[torch.Tensor],
    or_mask_function: Optional[Callable] = None,
    and_mask_function: Optional[Callable] = None,
) -> Optional[Union[torch.Tensor, BlockMask]]:
@@ -676,6 +735,8 @@ def create_causal_mask(
            A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
        past_key_values (`Cache`, optional):
            The past key values, if we use a cache.
+        position_ids (`torch.Tensor`, optional):
+            A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
        or_mask_function (`Callable`, optional):
            An optional mask function to combine with the causal mask function (by doing the union of both). This is
            useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
@@ -689,8 +750,8 @@ def create_causal_mask(
    else:
        layer_idx = 0

-    early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments(
-        config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx
+    early_exit, attention_mask, packed_sequence_mask, kv_length, kv_offset = _preprocess_mask_arguments(
+        config, input_embeds, attention_mask, cache_position, past_key_values, position_ids, layer_idx
    )
    if early_exit:
        return attention_mask
@@ -703,6 +764,11 @@ def create_causal_mask(
    # TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
    allow_is_causal_skip = not past_key_values.is_compileable if past_key_values is not None else True

+    # If we detected packing format
+    if packed_sequence_mask is not None and _is_torch_greater_or_equal_than_2_6:
+        mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
+        allow_is_causal_skip = False
+
    # Allow slight deviations from causal mask
    if or_mask_function is not None:
        if not _is_torch_greater_or_equal_than_2_6:
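When packing is detected, the causal mask function is combined with `packed_sequence_mask_function` through `and_masks`. The sketch below is a standalone approximation, assuming `and_masks` simply requires both predicates to hold; it materializes the resulting boolean mask for a toy packing of sequences of lengths 2, 3 and 1 and shows that attention never crosses sequence boundaries:

import torch

packed_sequence_mask = torch.tensor([[0, 0, 1, 1, 1, 2]])
seq_len = packed_sequence_mask.shape[1]

def causal(batch_idx, head_idx, q_idx, kv_idx):
    return kv_idx <= q_idx

def same_sequence(batch_idx, head_idx, q_idx, kv_idx):
    return bool(packed_sequence_mask[batch_idx, q_idx] == packed_sequence_mask[batch_idx, kv_idx])

# Stand-in for and_masks(causal, same_sequence): a position is visible only if both predicates allow it
allowed = torch.tensor(
    [[causal(0, 0, q, k) and same_sequence(0, 0, q, k) for k in range(seq_len)] for q in range(seq_len)]
)
print(allowed.int())
# tensor([[1, 0, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [0, 0, 1, 0, 0, 0],
#         [0, 0, 1, 1, 0, 0],
#         [0, 0, 1, 1, 1, 0],
#         [0, 0, 0, 0, 0, 1]])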
@@ -736,6 +802,7 @@ def create_sliding_window_causal_mask(
    attention_mask: Optional[torch.Tensor],
    cache_position: torch.Tensor,
    past_key_values: Optional[Cache],
+    position_ids: Optional[torch.Tensor],
    or_mask_function: Optional[Callable] = None,
    and_mask_function: Optional[Callable] = None,
) -> Optional[Union[torch.Tensor, BlockMask]]:
@@ -758,6 +825,8 @@ def create_sliding_window_causal_mask(
            A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
        past_key_values (`Cache`, optional):
            The past key values, if we use a cache.
+        position_ids (`torch.Tensor`, optional):
+            A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
        or_mask_function (`Callable`, optional):
            An optional mask function to combine with the sliding causal mask function (by doing the union of both). This is
            useful to easily overlay another mask on top of the sliding causal one, for example for image tokens handling.
@@ -771,8 +840,8 @@ def create_sliding_window_causal_mask(
    else:
        layer_idx = 0

-    early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments(
-        config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx
+    early_exit, attention_mask, packed_sequence_mask, kv_length, kv_offset = _preprocess_mask_arguments(
+        config, input_embeds, attention_mask, cache_position, past_key_values, position_ids, layer_idx
    )
    if early_exit:
        return attention_mask
@@ -789,6 +858,11 @@ def create_sliding_window_causal_mask(
    # TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
    allow_is_causal_skip = not past_key_values.is_compileable if past_key_values is not None else True

+    # If we detected packing format
+    if packed_sequence_mask is not None and _is_torch_greater_or_equal_than_2_6:
+        mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
+        allow_is_causal_skip = False
+
    # Allow slight deviations from sliding causal mask
    if or_mask_function is not None:
        if not _is_torch_greater_or_equal_than_2_6:
@@ -823,6 +897,7 @@ def create_chunked_causal_mask(
    attention_mask: Optional[torch.Tensor],
    cache_position: torch.Tensor,
    past_key_values: Optional[Cache],
+    position_ids: Optional[torch.Tensor],
    or_mask_function: Optional[Callable] = None,
    and_mask_function: Optional[Callable] = None,
) -> Optional[Union[torch.Tensor, BlockMask]]:
@@ -845,6 +920,8 @@ def create_chunked_causal_mask(
            A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
        past_key_values (`Cache`, optional):
            The past key values, if we use a cache.
+        position_ids (`torch.Tensor`, optional):
+            A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
        or_mask_function (`Callable`, optional):
            An optional mask function to combine with the chunked causal mask function (by doing the union of both). This is
            useful to easily overlay another mask on top of the chunked causal one, for example for image tokens handling.
@@ -858,8 +935,8 @@ def create_chunked_causal_mask(
    else:
        layer_idx = 0

-    early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments(
-        config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx
+    early_exit, attention_mask, packed_sequence_mask, kv_length, kv_offset = _preprocess_mask_arguments(
+        config, input_embeds, attention_mask, cache_position, past_key_values, position_ids, layer_idx
    )
    if early_exit:
        return attention_mask
@@ -883,6 +960,11 @@ def create_chunked_causal_mask(
    # TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
    allow_is_causal_skip = not past_key_values.is_compileable if past_key_values is not None else True

+    # If we detected packing format
+    if packed_sequence_mask is not None and _is_torch_greater_or_equal_than_2_6:
+        mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
+        allow_is_causal_skip = False
+
    # Allow slight deviations from chunked causal mask
    if or_mask_function is not None:
        if not _is_torch_greater_or_equal_than_2_6:
@@ -924,6 +1006,7 @@ def create_masks_for_generate(
    attention_mask: Optional[torch.Tensor],
    cache_position: torch.Tensor,
    past_key_values: Optional[Cache],
+    position_ids: Optional[torch.Tensor],
    or_mask_function: Optional[Callable] = None,
    and_mask_function: Optional[Callable] = None,
    **kwargs,
@@ -945,6 +1028,8 @@ def create_masks_for_generate(
            A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
        past_key_values (`Cache`, optional):
            The past key values, if we use a cache.
+        position_ids (`torch.Tensor`, optional):
+            A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
        or_mask_function (`Callable`, optional):
            An optional mask function to combine with the other mask function (by doing the union of both). This is
            useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
@@ -961,6 +1046,7 @@ def create_masks_for_generate(
        "attention_mask": attention_mask,
        "cache_position": cache_position,
        "past_key_values": past_key_values,
+        "position_ids": position_ids,
        "or_mask_function": or_mask_function,
        "and_mask_function": and_mask_function,
    }