
Commit bc0a47e

Add packed tensor format support for flex/sdpa/eager through the mask! (#39194)
* Add the necesary logic to mask_utils
* add it everywhere
* Update masking_utils.py
* style
* Update masking_utils.py
* Update modeling_mimi.py
* Update masking_utils.py
* add support for more than batch size 1
* Update masking_utils.py
* add test
* style
* Update test_masking_utils.py
* Update masking_utils.py
* add require_token
* fix tests
* fix
1 parent 63af3d7 commit bc0a47e
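
The gist of the change: when a batch row contains several concatenated (packed) sequences, the `position_ids` are now enough for the mask utilities to rebuild the sequence boundaries, so no block-diagonal 2D attention mask has to be passed for flex/sdpa/eager attention. A rough usage sketch (the checkpoint name and prompts are illustrative, not part of this commit; packing detection only kicks in when `attention_mask` and `past_key_values` are both `None`):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen2.5-0.5B"  # any decoder-only checkpoint; illustrative choice
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="sdpa")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Pack two prompts back-to-back in one row; the position ids restart at 0 at the
# second sequence, which is what the new mask logic uses to find the boundary
ids_a = tokenizer("The first packed sequence.", return_tensors="pt").input_ids[0]
ids_b = tokenizer("And the second one.", return_tensors="pt").input_ids[0]
input_ids = torch.cat([ids_a, ids_b]).unsqueeze(0)
position_ids = torch.cat([torch.arange(len(ids_a)), torch.arange(len(ids_b))]).unsqueeze(0)

# No attention_mask and no cache: the block-diagonal causal mask is built internally
outputs = model(input_ids=input_ids, position_ids=position_ids)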

File tree: 65 files changed, +303 / -9 lines

Large commits have some content hidden by default, so only a subset of the 65 changed files is shown below.

src/transformers/generation/utils.py

Lines changed: 2 additions & 0 deletions
@@ -656,6 +656,7 @@ def prepare_inputs_for_generation(
         # If it's not defined, it means the model uses the new general mask API
         if causal_mask_creation_function is None: # can't be found
             token_type_ids = getattr(model_input, "token_type_ids", None)
+            position_ids = getattr(model_input, position_ids_key, None)
             # Some models may overwrite the general one
             causal_mask_creation_function = getattr(self, "create_masks_for_generate", create_masks_for_generate)
             attention_mask = causal_mask_creation_function(
@@ -665,6 +666,7 @@
                 attention_mask=attention_mask,
                 cache_position=cache_position,
                 past_key_values=past_key_values,
+                position_ids=position_ids,
                 token_type_ids=token_type_ids,
             )
         else:

src/transformers/masking_utils.py

Lines changed: 95 additions & 9 deletions
@@ -111,6 +111,10 @@ def chunked_causal_mask_function(chunk_size: int) -> Callable:


 def padding_mask_function(padding_mask: torch.Tensor) -> Callable:
+    """
+    This return the mask_function function corresponding to a 2D padding mask.
+    """
+
     def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
         # Note that here the mask should ALWAYS be at least of the max `kv_index` size in the dimension 1. This is because
         # we cannot pad it here in the mask_function as we don't know the final size, and we cannot try/except, as it is not
@@ -120,6 +124,17 @@ def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
     return inner_mask


+def packed_sequence_mask_function(packed_sequence_mask: torch.Tensor) -> Callable:
+    """
+    This return the mask_function function corresponding to a 2D packed sequence mask.
+    """
+
+    def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
+        return packed_sequence_mask[batch_idx, q_idx] == packed_sequence_mask[batch_idx, kv_idx]
+
+    return inner_mask
+
+
 def add_offsets_to_mask_function(mask_function: Callable, q_offset: int, kv_offset: int) -> Callable:
     """
     This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
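
For intuition, `packed_sequence_mask_function` only lets a query attend to key/value positions that carry the same sequence index. A minimal standalone sketch of that behavior (the sequence-id tensor below is hand-built for illustration):

import torch

# Hypothetical sequence-id tensor for one batch row holding three packed
# sequences of lengths 2, 3 and 1 (token position -> sequence index)
packed_sequence_mask = torch.tensor([[0, 0, 1, 1, 1, 2]])

# Re-create the inner mask exactly as packed_sequence_mask_function does above
def inner_mask(batch_idx, head_idx, q_idx, kv_idx):
    return packed_sequence_mask[batch_idx, q_idx] == packed_sequence_mask[batch_idx, kv_idx]

print(inner_mask(0, 0, 3, 2))  # tensor(True): both tokens belong to sequence 1
print(inner_mask(0, 0, 3, 1))  # tensor(False): token 1 belongs to sequence 0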
@@ -584,12 +599,40 @@ class AttentionMaskInterface(GeneralInterface):
 ALL_MASK_ATTENTION_FUNCTIONS: AttentionMaskInterface = AttentionMaskInterface()


+def find_packed_sequence_indices(position_ids: torch.Tensor) -> Optional[torch.Tensor]:
+    """
+    Find the indices of the sequence to which each new query token in the sequence belongs when using packed
+    tensor format (i.e. several sequences packed in the same batch dimension).
+
+    Args:
+        position_ids (`torch.Tensor`)
+            A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
+
+    Returns:
+        A 2D tensor where each similar integer indicates that the tokens belong to the same sequence. For example, if we
+        pack 3 sequences of 2, 3 and 1 tokens respectively along a single batch dim, this will return [[0, 0, 1, 1, 1, 2]].
+    """
+    # What separate different sequences is when 2 consecutive positions_ids are separated by more than 1. So
+    # taking the diff (by prepending the first value - 1 to keep correct indexing) and applying cumsum to the result
+    # gives exactly the sequence indices
+    # Note that we assume that a single sequence cannot span several batch dimensions, i.e. 1 single sequence
+    # cannot be part of the end of the first batch dim and the start of the 2nd one for example
+    first_dummy_value = position_ids[:, :1] - 1  # We just need the diff on this first value to be 1
+    position_diff = torch.diff(position_ids, prepend=first_dummy_value, dim=-1)
+    packed_sequence_mask = (position_diff != 1).cumsum(-1)
+
+    # Here it would be nice to return None if we did not detect packed sequence format, i.e. if `packed_sequence_mask[:, -1] == 0`
+    # but it causes issues with export
+    return packed_sequence_mask
+
+
 def _preprocess_mask_arguments(
     config: PretrainedConfig,
     input_embeds: torch.Tensor,
     attention_mask: Optional[Union[torch.Tensor, BlockMask]],
     cache_position: torch.Tensor,
     past_key_values: Optional[Cache],
+    position_ids: Optional[torch.Tensor],
     layer_idx: Optional[int],
 ) -> tuple[bool, Optional[Union[torch.Tensor, BlockMask]], int, int]:
     """
@@ -609,6 +652,8 @@
             A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
         past_key_values (`Cache`, optional):
             The past key values, if we use a cache.
+        position_ids (`torch.Tensor`, optional)
+            A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
         layer_idx (`int`, optional):
             If `past_key_values` is not None, this is the layer index of the cache from which to get the key-value
             length and offset. Indeed, for hybrid caches, different layers may return different lengths.
@@ -618,22 +663,25 @@
             Whether we should early exit mask creation, and return the mask as-is.
         attention_mask (`torch.Tensor` or `BlockMask` or `None`):
             The attention mask to either return immediately, or to use in downstream mask creation.
+        packed_sequence_mask (`torch.Tensor`, optional):
+            In case we detected packed sequence format, this is a tensor where each similar integer indicates that
+            the tokens belong to the same sequence.
         kv_length (`int`):
             The size that the key and value states will have during the attention computation.
         kv_offset (`int`):
             An offset to indicate at which first position the key and values states will refer to.
     """
     # If the mask is already 4D, simply return as-is (it was already prepared, or it is custom)
     if isinstance(attention_mask, (torch.Tensor, BlockMask)) and len(attention_mask.shape) == 4:
-        return True, attention_mask, None, None
+        return True, attention_mask, None, None, None

     # For TGI/vLLM backends, or other custom attention without equivalent mask creation: we don't need a mask!
     # Note: it's not ideal to check the `_global_mapping` attribute instead of the object itself, however otherwise
     # full graph dynamo tracing (i.e. torch.export or compile with `fullgraph=True`) will fail on Python<3.11
     # with `torch._dynamo.exc.Unsupported: 'inline in skipfiles:Mapping.__contains__ | __contains__, skipped
     # according trace_rules.lookup SKIP_DIRS'` -- can be removed when we require Python>=3.11
     if config._attn_implementation not in ALL_MASK_ATTENTION_FUNCTIONS._global_mapping:
-        return True, None, None, None
+        return True, None, None, None, None

     # Move the mask to correct device, and potentially switch dtype for efficiency
     if attention_mask is not None and attention_mask.ndim == 2:
@@ -646,7 +694,17 @@
     else:
         kv_length, kv_offset = input_embeds.shape[1], 0

-    return False, attention_mask, kv_length, kv_offset
+    # We check the position_ids for potential packed sequence format (only if the 2D attention mask is explicitly None,
+    # and we don't have past_key_values, i.e. generally a training setup)
+    packed_sequence_mask = None
+    if position_ids is not None and attention_mask is None and past_key_values is None:
+        batch_size = input_embeds.shape[0]
+        # The position ids are sometimes just unsqueezed, without being expanded
+        if batch_size != position_ids.shape[0]:
+            position_ids = position_ids.expand(batch_size, -1)
+        packed_sequence_mask = find_packed_sequence_indices(position_ids)
+
+    return False, attention_mask, packed_sequence_mask, kv_length, kv_offset


 def create_causal_mask(
@@ -655,6 +713,7 @@ def create_causal_mask(
     attention_mask: Optional[torch.Tensor],
     cache_position: torch.Tensor,
     past_key_values: Optional[Cache],
+    position_ids: Optional[torch.Tensor],
     or_mask_function: Optional[Callable] = None,
     and_mask_function: Optional[Callable] = None,
 ) -> Optional[Union[torch.Tensor, BlockMask]]:
@@ -676,6 +735,8 @@
             A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
         past_key_values (`Cache`, optional):
             The past key values, if we use a cache.
+        position_ids (`torch.Tensor`, optional)
+            A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
         or_mask_function (`Callable`, optional):
             An optional mask function to combine with the causal mask function (by doing the union of both). This is
             useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
@@ -689,8 +750,8 @@
     else:
         layer_idx = 0

-    early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments(
-        config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx
+    early_exit, attention_mask, packed_sequence_mask, kv_length, kv_offset = _preprocess_mask_arguments(
+        config, input_embeds, attention_mask, cache_position, past_key_values, position_ids, layer_idx
     )
     if early_exit:
         return attention_mask
@@ -703,6 +764,11 @@
     # TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
     allow_is_causal_skip = not past_key_values.is_compileable if past_key_values is not None else True

+    # If we detected packing format
+    if packed_sequence_mask is not None and _is_torch_greater_or_equal_than_2_6:
+        mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
+        allow_is_causal_skip = False
+
     # Allow slight deviations from causal mask
     if or_mask_function is not None:
         if not _is_torch_greater_or_equal_than_2_6:
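
The `and_masks` combination above turns the causal mask into a block-diagonal causal mask: a query may only look back at earlier tokens of its own packed sequence. A rough standalone sketch of that effect (the helpers below only mimic the library's causal and packed mask functions; they are not imports):

import torch

def causal(batch_idx, head_idx, q_idx, kv_idx):
    # Plain causal constraint: keys must not be in the future of the query
    return kv_idx <= q_idx

def packed(packed_sequence_mask):
    def inner(batch_idx, head_idx, q_idx, kv_idx):
        return packed_sequence_mask[batch_idx, q_idx] == packed_sequence_mask[batch_idx, kv_idx]
    return inner

def and_(f, g):
    # Minimal stand-in for `and_masks`: a position is kept only if both functions keep it
    return lambda b, h, q, kv: f(b, h, q, kv) & g(b, h, q, kv)

packed_sequence_mask = torch.tensor([[0, 0, 1, 1, 1, 2]])  # three packed sequences
mask_fn = and_(causal, packed(packed_sequence_mask))

dense = torch.tensor([[bool(mask_fn(0, 0, q, kv)) for kv in range(6)] for q in range(6)])
print(dense.int())
# Lower-triangular blocks of sizes 2, 3 and 1 along the diagonal: queries only
# attend to earlier tokens of their own packed sequence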
@@ -736,6 +802,7 @@ def create_sliding_window_causal_mask(
     attention_mask: Optional[torch.Tensor],
     cache_position: torch.Tensor,
     past_key_values: Optional[Cache],
+    position_ids: Optional[torch.Tensor],
     or_mask_function: Optional[Callable] = None,
     and_mask_function: Optional[Callable] = None,
 ) -> Optional[Union[torch.Tensor, BlockMask]]:
@@ -758,6 +825,8 @@
             A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
         past_key_values (`Cache`, optional):
             The past key values, if we use a cache.
+        position_ids (`torch.Tensor`, optional)
+            A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
         or_mask_function (`Callable`, optional):
             An optional mask function to combine with the sliding causal mask function (by doing the union of both). This is
             useful to easily overlay another mask on top of the sliding causal one, for example for image tokens handling.
@@ -771,8 +840,8 @@
     else:
         layer_idx = 0

-    early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments(
-        config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx
+    early_exit, attention_mask, packed_sequence_mask, kv_length, kv_offset = _preprocess_mask_arguments(
+        config, input_embeds, attention_mask, cache_position, past_key_values, position_ids, layer_idx
     )
     if early_exit:
         return attention_mask
@@ -789,6 +858,11 @@
     # TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
     allow_is_causal_skip = not past_key_values.is_compileable if past_key_values is not None else True

+    # If we detected packing format
+    if packed_sequence_mask is not None and _is_torch_greater_or_equal_than_2_6:
+        mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
+        allow_is_causal_skip = False
+
     # Allow slight deviations from sliding causal mask
     if or_mask_function is not None:
         if not _is_torch_greater_or_equal_than_2_6:
@@ -823,6 +897,7 @@ def create_chunked_causal_mask(
     attention_mask: Optional[torch.Tensor],
     cache_position: torch.Tensor,
     past_key_values: Optional[Cache],
+    position_ids: Optional[torch.Tensor],
     or_mask_function: Optional[Callable] = None,
     and_mask_function: Optional[Callable] = None,
 ) -> Optional[Union[torch.Tensor, BlockMask]]:
@@ -845,6 +920,8 @@
             A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
         past_key_values (`Cache`, optional):
             The past key values, if we use a cache.
+        position_ids (`torch.Tensor`, optional)
+            A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
         or_mask_function (`Callable`, optional):
             An optional mask function to combine with the chunked causal mask function (by doing the union of both). This is
             useful to easily overlay another mask on top of the chunked causal one, for example for image tokens handling.
@@ -858,8 +935,8 @@
     else:
         layer_idx = 0

-    early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments(
-        config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx
+    early_exit, attention_mask, packed_sequence_mask, kv_length, kv_offset = _preprocess_mask_arguments(
+        config, input_embeds, attention_mask, cache_position, past_key_values, position_ids, layer_idx
     )
     if early_exit:
         return attention_mask
@@ -883,6 +960,11 @@
     # TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
     allow_is_causal_skip = not past_key_values.is_compileable if past_key_values is not None else True

+    # If we detected packing format
+    if packed_sequence_mask is not None and _is_torch_greater_or_equal_than_2_6:
+        mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
+        allow_is_causal_skip = False
+
     # Allow slight deviations from chunked causal mask
     if or_mask_function is not None:
         if not _is_torch_greater_or_equal_than_2_6:
@@ -924,6 +1006,7 @@ def create_masks_for_generate(
     attention_mask: Optional[torch.Tensor],
     cache_position: torch.Tensor,
     past_key_values: Optional[Cache],
+    position_ids: Optional[torch.Tensor],
    or_mask_function: Optional[Callable] = None,
     and_mask_function: Optional[Callable] = None,
     **kwargs,
@@ -945,6 +1028,8 @@
             A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
         past_key_values (`Cache`, optional):
             The past key values, if we use a cache.
+        position_ids (`torch.Tensor`, optional)
+            A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
         or_mask_function (`Callable`, optional):
             An optional mask function to combine with the other mask function (by doing the union of both). This is
             useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
@@ -961,6 +1046,7 @@
         "attention_mask": attention_mask,
         "cache_position": cache_position,
         "past_key_values": past_key_values,
+        "position_ids": position_ids,
         "or_mask_function": or_mask_function,
         "and_mask_function": and_mask_function,
     }
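
Putting the masking_utils changes together, a packed-format mask can now be requested directly from `create_causal_mask` by passing `position_ids` while leaving `attention_mask` and `past_key_values` as `None`. A tentative usage sketch (config values are arbitrary, and packing handling assumes torch >= 2.6, as gated in the code above):

import torch
from transformers import LlamaConfig
from transformers.masking_utils import create_causal_mask

config = LlamaConfig(hidden_size=64, num_attention_heads=4, num_hidden_layers=2)
config._attn_implementation = "sdpa"  # assuming the attribute can be set directly on the config

# Two sequences of lengths 4 and 2 packed in a single batch row
position_ids = torch.tensor([[0, 1, 2, 3, 0, 1]])
input_embeds = torch.zeros(1, 6, config.hidden_size)
cache_position = torch.arange(6)

mask = create_causal_mask(
    config=config,
    input_embeds=input_embeds,
    attention_mask=None,       # must be None for packing detection
    cache_position=cache_position,
    past_key_values=None,      # no cache, i.e. a training-style forward
    position_ids=position_ids,
)
# Expected: a 4D mask in which the last two tokens cannot attend to the first four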

src/transformers/models/arcee/modeling_arcee.py

Lines changed: 1 addition & 0 deletions
@@ -423,6 +423,7 @@ def forward(
             attention_mask=attention_mask,
             cache_position=cache_position,
             past_key_values=past_key_values,
+            position_ids=position_ids,
         )

         hidden_states = inputs_embeds

src/transformers/models/aria/modeling_aria.py

Lines changed: 1 addition & 0 deletions
@@ -806,6 +806,7 @@ def forward(
             attention_mask=attention_mask,
             cache_position=cache_position,
             past_key_values=past_key_values,
+            position_ids=position_ids,
         )

         hidden_states = inputs_embeds

src/transformers/models/bitnet/modeling_bitnet.py

Lines changed: 1 addition & 0 deletions
@@ -420,6 +420,7 @@ def forward(
             attention_mask=attention_mask,
             cache_position=cache_position,
             past_key_values=past_key_values,
+            position_ids=position_ids,
         )

         hidden_states = inputs_embeds

src/transformers/models/cohere/modeling_cohere.py

Lines changed: 1 addition & 0 deletions
@@ -457,6 +457,7 @@ def forward(
             attention_mask=attention_mask,
             cache_position=cache_position,
             past_key_values=past_key_values,
+            position_ids=position_ids,
         )

         hidden_states = inputs_embeds

src/transformers/models/cohere2/modeling_cohere2.py

Lines changed: 1 addition & 0 deletions
@@ -434,6 +434,7 @@ def forward(
                 "attention_mask": attention_mask,
                 "cache_position": cache_position,
                 "past_key_values": past_key_values,
+                "position_ids": position_ids,
             }
             # Create the masks
             causal_mask_mapping = {

src/transformers/models/cohere2/modular_cohere2.py

Lines changed: 1 addition & 0 deletions
@@ -455,6 +455,7 @@ def forward(
                 "attention_mask": attention_mask,
                 "cache_position": cache_position,
                 "past_key_values": past_key_values,
+                "position_ids": position_ids,
             }
             # Create the masks
             causal_mask_mapping = {

src/transformers/models/csm/modeling_csm.py

Lines changed: 2 additions & 0 deletions
@@ -500,6 +500,7 @@ def forward(
             attention_mask=attention_mask,
             cache_position=cache_position,
             past_key_values=past_key_values,
+            position_ids=position_ids,
         )

         hidden_states = inputs_embeds
@@ -811,6 +812,7 @@ def forward(
             attention_mask=attention_mask,
             cache_position=cache_position,
             past_key_values=past_key_values,
+            position_ids=position_ids,
         )

         hidden_states = inputs_embeds

src/transformers/models/csm/modular_csm.py

Lines changed: 1 addition & 0 deletions
@@ -238,6 +238,7 @@ def forward(
             attention_mask=attention_mask,
             cache_position=cache_position,
             past_key_values=past_key_values,
+            position_ids=position_ids,
         )

         hidden_states = inputs_embeds
