diff --git a/src/fairseq2/models/conformer/_block.py b/src/fairseq2/models/conformer/_block.py index e075bcbce..efff3873f 100644 --- a/src/fairseq2/models/conformer/_block.py +++ b/src/fairseq2/models/conformer/_block.py @@ -25,6 +25,7 @@ # isort: split from fairseq2.models.conformer._convolution import ConformerConvolution +from fairseq2.models.transformer._block_mask import BlockMaskCache @final @@ -131,10 +132,13 @@ def forward( seqs: Tensor, seqs_layout: BatchLayout, attn_bias_cache: AttentionBiasCache, + block_mask_cache: BlockMaskCache, ) -> Tensor: seqs = self._forward_ffn1(seqs) - seqs = self._forward_self_attn(seqs, seqs_layout, attn_bias_cache) + seqs = self._forward_self_attn( + seqs, seqs_layout, attn_bias_cache, block_mask_cache + ) seqs = self._forward_conv(seqs, seqs_layout) @@ -161,6 +165,7 @@ def _forward_self_attn( seqs: Tensor, seqs_layout: BatchLayout, attn_bias_cache: AttentionBiasCache, + block_mask_cache: BlockMaskCache, ) -> Tensor: residual = seqs @@ -173,6 +178,7 @@ def _forward_self_attn( keys_layout=seqs_layout, values=seqs, bias_cache=attn_bias_cache, + block_mask_cache=block_mask_cache, ) if self.self_attn_dropout is not None: diff --git a/src/fairseq2/models/jepa/classifier/_model.py b/src/fairseq2/models/jepa/classifier/_model.py index f58deb56d..21a752264 100644 --- a/src/fairseq2/models/jepa/classifier/_model.py +++ b/src/fairseq2/models/jepa/classifier/_model.py @@ -18,6 +18,7 @@ from fairseq2.device import Device from fairseq2.models.transformer import ( AttentionBiasCache, + BlockMaskCache, FeedForwardNetwork, MultiheadAttention, TransformerEncoder, @@ -221,6 +222,7 @@ def _forward_cross_attn( encoder_output = self.cross_attn_layer_norm(encoder_output) attn_bias_cache = AttentionBiasCache() + block_mask_cache = BlockMaskCache() seqs = self.cross_attn( seqs, @@ -229,6 +231,7 @@ def _forward_cross_attn( keys_layout=encoder_output_layout, values=encoder_output, bias_cache=attn_bias_cache, + block_mask_cache=block_mask_cache, ) seqs = seqs + residual diff --git a/src/fairseq2/models/transformer/__init__.py b/src/fairseq2/models/transformer/__init__.py index 0d45df32c..fabba70f9 100644 --- a/src/fairseq2/models/transformer/__init__.py +++ b/src/fairseq2/models/transformer/__init__.py @@ -23,6 +23,7 @@ from fairseq2.models.transformer._attention_bias import ( maybe_get_attention_bias_tensor as maybe_get_attention_bias_tensor, ) +from fairseq2.models.transformer._block_mask import BlockMaskCache as BlockMaskCache from fairseq2.models.transformer._checkpoint import ( convert_transformer_checkpoint as convert_transformer_checkpoint, ) @@ -144,6 +145,7 @@ ) from fairseq2.models.transformer._sdpa._flash2 import Flash2SDPA as Flash2SDPA from fairseq2.models.transformer._sdpa._flash3 import Flash3SDPA as Flash3SDPA +from fairseq2.models.transformer._sdpa._flex import FlexSDPA as FlexSDPA from fairseq2.models.transformer._sdpa._naive import NaiveSDPA as NaiveSDPA from fairseq2.models.transformer._sdpa._naive import ( naive_scaled_dot_product_attention as naive_scaled_dot_product_attention, diff --git a/src/fairseq2/models/transformer/_block_mask.py b/src/fairseq2/models/transformer/_block_mask.py new file mode 100644 index 000000000..ccda5ee4f --- /dev/null +++ b/src/fairseq2/models/transformer/_block_mask.py @@ -0,0 +1,300 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+
+from dataclasses import dataclass
+from typing import Callable, OrderedDict, TypeAlias
+
+import torch
+from torch import Tensor
+from torch.nn.attention.flex_attention import (
+    BlockMask,
+    and_masks,
+    create_block_mask,
+)
+
+from fairseq2.device import Device
+from fairseq2.error import NotSupportedError
+from fairseq2.nn import BatchLayout
+
+# isort: split
+
+from fairseq2.models.transformer._attention_bias import (
+    AttentionBias,
+    CausalAttentionBias,
+    IdentityBias,
+)
+
+MaskFunction: TypeAlias = Callable[[Tensor, Tensor, Tensor, Tensor], Tensor]
+
+BLOCK_MASK_CACHE_MAX_SIZE = 100
+
+
+def _causal_mask_fn(q_lens: Tensor, kv_lens: Tensor) -> MaskFunction:
+    """Creates a causal mask function."""
+
+    def mask_fn(b: Tensor, h: Tensor, q_idx: Tensor, kv_idx: Tensor) -> Tensor:
+        # Get sequence lengths for this batch
+        q_len = q_lens[b]
+        kv_len = kv_lens[b]
+
+        # Calculate diagonal offset
+        d = kv_len - q_len
+
+        return q_idx >= kv_idx - d
+
+    return mask_fn
+
+
+def _sliding_window_causal_mask_fn(
+    window_size: int,
+    q_lens: Tensor,
+    kv_lens: Tensor,
+) -> MaskFunction:
+    """Creates a sliding window causal mask function."""
+
+    def mask_fn(b: Tensor, h: Tensor, q_idx: Tensor, kv_idx: Tensor) -> Tensor:
+        # Get sequence lengths for this batch
+        q_len = q_lens[b]
+        kv_len = kv_lens[b]
+
+        # Calculate diagonal offset
+        d = kv_len - q_len
+
+        # Apply both the causal and the window constraints
+        # NOTE: There is an incompatibility here with our _create_causal_bias_tensor
+        # function, since that requires data-dependent control flow via the q_len,
+        # which is not currently supported by torch.compile. This is a simplified
+        # version of that logic, which won't match exactly in all cases.
+        causal_mask = q_idx >= kv_idx - d
+        window_mask = kv_idx - d >= q_idx - window_size + 1
+        return causal_mask & window_mask
+
+    return mask_fn
+
+
+def _offsets_to_doc_ids_tensor(offsets: Tensor) -> Tensor:
+    """Convert offsets to document IDs for packed sequences."""
+    device = offsets.device
+    counts = offsets[1:] - offsets[:-1]
+    return torch.repeat_interleave(
+        torch.arange(len(counts), device=device, dtype=torch.int32), counts
+    )
+
+
+def _create_packed_mask_fn(
+    seq_begin_indices: Tensor,
+    keys_begin_indices: Tensor,
+    base_mask_fn: MaskFunction | None = None,
+) -> MaskFunction:
+    """Creates a mask function for packed sequences using document-based masking."""
+    # Create document IDs for queries and keys
+    query_doc_ids = _offsets_to_doc_ids_tensor(seq_begin_indices)
+    key_doc_ids = _offsets_to_doc_ids_tensor(keys_begin_indices)
+
+    def packed_mask_fn(b: Tensor, h: Tensor, q_idx: Tensor, kv_idx: Tensor) -> Tensor:
+        # Check if query and key belong to the same document
+        same_doc = query_doc_ids[q_idx] == key_doc_ids[kv_idx]
+
+        # Convert global indices to logical positions within documents
+        q_doc_id = query_doc_ids[q_idx]
+        kv_doc_id = key_doc_ids[kv_idx]
+        q_logical = q_idx - seq_begin_indices[q_doc_id]
+        kv_logical = kv_idx - keys_begin_indices[kv_doc_id]
+
+        # Apply base mask (e.g., causal) to logical positions
+        if base_mask_fn is not None:
+            inner_mask = base_mask_fn(b, h, q_logical, kv_logical)
+            return same_doc & inner_mask
+        else:
+            return same_doc
+
+    return packed_mask_fn
+
+
+def _create_padding_mask_fn(q_lens: Tensor, kv_lens: Tensor) -> MaskFunction:
+    """Creates a padding mask function that masks out padding tokens."""
+
+    def padding_mask_fn(b: Tensor, h: Tensor, q_idx: Tensor, kv_idx: Tensor) -> Tensor:
+        q_valid = q_idx < q_lens[b]
+        kv_valid = kv_idx < kv_lens[b]
+        return
q_valid & kv_valid + + return padding_mask_fn + + +def _create_composed_mask( + bias: AttentionBias, + seqs_layout: BatchLayout, + keys_layout: BatchLayout, + device: Device, +) -> BlockMask | None: + """Creates a composed mask using and_masks for combining multiple mask functions.""" + masks = [] + + if seqs_layout.packed: + # For packed sequences, create the base mask function first + base_mask_fn = None + + # Add attention bias mask as base mask + if isinstance(bias, CausalAttentionBias): + attn_window_len = bias.attn_window_len + if attn_window_len is not None: + base_mask_fn = _sliding_window_causal_mask_fn( + attn_window_len, + seqs_layout.seq_lens_pt, + keys_layout.seq_lens_pt, + ) + else: + base_mask_fn = _causal_mask_fn( + seqs_layout.seq_lens_pt, + keys_layout.seq_lens_pt, + ) + elif not isinstance(bias, IdentityBias): + raise NotSupportedError(f"Unsupported bias type: {bias}") + + # Create the packed sequence mask that incorporates the base mask + packed_mask = _create_packed_mask_fn( + seqs_layout.seq_begin_indices_pt, + keys_layout.seq_begin_indices_pt, + base_mask_fn, + ) + masks.append(packed_mask) + else: + # Standard batch format - handle bias and padding separately + if isinstance(bias, CausalAttentionBias): + attn_window_len = bias.attn_window_len + if attn_window_len is not None: + masks.append( + _sliding_window_causal_mask_fn( + attn_window_len, + seqs_layout.seq_lens_pt, + keys_layout.seq_lens_pt, + ) + ) + else: + masks.append( + _causal_mask_fn( + seqs_layout.seq_lens_pt, + keys_layout.seq_lens_pt, + ) + ) + elif not isinstance(bias, IdentityBias): + raise NotSupportedError(f"Unsupported bias type: {bias}") + + # Add padding mask + if seqs_layout.padded or keys_layout.padded: + masks.append( + _create_padding_mask_fn(seqs_layout.seq_lens_pt, keys_layout.seq_lens_pt) + ) + + # Compose masks + mask_fn = None + if len(masks) == 0: + return None + elif len(masks) == 1: + mask_fn = masks[0] + else: + mask_fn = and_masks(*masks) + + if seqs_layout.packed: + total_seq_len = int(seqs_layout.seq_begin_indices_pt[-1].item()) + total_keys_len = int(keys_layout.seq_begin_indices_pt[-1].item()) + batch_size = 1 + else: + total_seq_len = seqs_layout.max_seq_len + total_keys_len = keys_layout.max_seq_len + batch_size = len(seqs_layout.seq_lens) + + # Create the block mask + block_mask = create_block_mask( + mask_fn, + B=batch_size, + H=None, + Q_LEN=total_seq_len, + KV_LEN=total_keys_len, + device=str(device), + ) + return block_mask + + +@dataclass +class BlockMaskCacheKey: + """Key for caching block masks.""" + + batch_size: int + seqs_len: int + keys_len: int + + def __hash__(self) -> int: + return hash( + ( + self.batch_size, + self.seqs_len, + self.keys_len, + ) + ) + + +class BlockMaskCache: + """ + Cache for block masks to avoid recomputation across layers and (possibly) training + steps. We assume that the cache is not shared across different models or training + runs and therefore only needs to hash on sequence lengths and batch sizes. 
+ """ + + def __init__(self) -> None: + self._cache: OrderedDict[BlockMaskCacheKey, BlockMask | None] = OrderedDict() + + def _create_cache_key( + self, + seqs_layout: BatchLayout, + keys_layout: BatchLayout, + ) -> BlockMaskCacheKey: + """Create a cache key based on sequence / key-value lengths and batch sizes.""" + if seqs_layout.packed: + batch_size = 1 + seqs_len = int(seqs_layout.seq_begin_indices[-1]) + keys_len = int(keys_layout.seq_begin_indices[-1]) + else: + batch_size = len(seqs_layout.seq_lens) + seqs_len = seqs_layout.max_seq_len + keys_len = keys_layout.max_seq_len + + cache_key = BlockMaskCacheKey( + batch_size=batch_size, + seqs_len=seqs_len, + keys_len=keys_len, + ) + return cache_key + + def get_or_create_mask( + self, + bias: AttentionBias, + seqs_layout: BatchLayout, + keys_layout: BatchLayout, + device: Device, + ) -> BlockMask | None: + """Get cached mask or create new one.""" + cache_key = self._create_cache_key(seqs_layout, keys_layout) + + if cache_key in self._cache: + # Move to end (most recently used) + self._cache.move_to_end(cache_key) + return self._cache[cache_key] + + # Create new mask + mask = _create_composed_mask(bias, seqs_layout, keys_layout, device) + + # Add to cache and evict if needed + self._cache[cache_key] = mask + if len(self._cache) > BLOCK_MASK_CACHE_MAX_SIZE: + self._cache.popitem(last=False) + + return mask + + def clear(self) -> None: + """Clear the cache.""" + self._cache.clear() diff --git a/src/fairseq2/models/transformer/_decoder.py b/src/fairseq2/models/transformer/_decoder.py index 071c7e3fb..242e2805a 100644 --- a/src/fairseq2/models/transformer/_decoder.py +++ b/src/fairseq2/models/transformer/_decoder.py @@ -24,6 +24,7 @@ # isort: split from fairseq2.models.transformer._attention_bias import AttentionBiasCache +from fairseq2.models.transformer._block_mask import BlockMaskCache from fairseq2.models.transformer._decoder_layer import TransformerDecoderLayer from fairseq2.models.transformer._encoder import _record_drop_for_backward @@ -167,6 +168,7 @@ def forward( ) attn_bias_cache = AttentionBiasCache() + block_mask_cache = BlockMaskCache() num_layers = len(self.layers) @@ -177,6 +179,7 @@ def forward( encoder_output, encoder_output_layout, attn_bias_cache, + block_mask_cache, state_bag=state_bag, ) diff --git a/src/fairseq2/models/transformer/_decoder_layer.py b/src/fairseq2/models/transformer/_decoder_layer.py index 2c0279454..b5c3f78f5 100644 --- a/src/fairseq2/models/transformer/_decoder_layer.py +++ b/src/fairseq2/models/transformer/_decoder_layer.py @@ -26,6 +26,7 @@ # isort: split from fairseq2.models.transformer._attention_bias import AttentionBiasCache +from fairseq2.models.transformer._block_mask import BlockMaskCache from fairseq2.models.transformer._ffn import FeedForwardNetwork from fairseq2.models.transformer._multihead_attention import MultiheadAttention from fairseq2.models.transformer._norm_order import TransformerNormOrder @@ -42,6 +43,7 @@ def forward( encoder_output: Tensor, encoder_output_layout: BatchLayout, attn_bias_cache: AttentionBiasCache, + block_mask_cache: BlockMaskCache, *, state_bag: IncrementalStateBag | None = None, ) -> Tensor: @@ -195,10 +197,13 @@ def forward( encoder_output: Tensor, encoder_output_layout: BatchLayout, attn_bias_cache: AttentionBiasCache, + block_mask_cache: BlockMaskCache, *, state_bag: IncrementalStateBag | None = None, ) -> Tensor: - seqs = self._forward_self_attn(seqs, seqs_layout, attn_bias_cache, state_bag) + seqs = self._forward_self_attn( + seqs, seqs_layout, 
attn_bias_cache, block_mask_cache, state_bag + ) seqs = self._forward_encoder_decoder_attn( seqs, @@ -206,6 +211,7 @@ def forward( encoder_output, encoder_output_layout, attn_bias_cache, + block_mask_cache, state_bag, ) @@ -218,6 +224,7 @@ def _forward_self_attn( seqs: Tensor, seqs_layout: BatchLayout, attn_bias_cache: AttentionBiasCache, + block_mask_cache: BlockMaskCache, state_bag: IncrementalStateBag | None, ) -> Tensor: residual = seqs @@ -232,6 +239,7 @@ def _forward_self_attn( keys_layout=seqs_layout, values=seqs, bias_cache=attn_bias_cache, + block_mask_cache=block_mask_cache, state_bag=state_bag, ) @@ -252,6 +260,7 @@ def _forward_encoder_decoder_attn( encoder_output: Tensor, encoder_output_layout: BatchLayout, attn_bias_cache: AttentionBiasCache, + block_mask_cache: BlockMaskCache, state_bag: IncrementalStateBag | None, ) -> Tensor: residual = seqs @@ -266,6 +275,7 @@ def _forward_encoder_decoder_attn( keys_layout=encoder_output_layout, values=encoder_output, bias_cache=attn_bias_cache, + block_mask_cache=block_mask_cache, state_bag=state_bag, ) diff --git a/src/fairseq2/models/transformer/_encoder.py b/src/fairseq2/models/transformer/_encoder.py index 191c59f7c..d0e001b1c 100644 --- a/src/fairseq2/models/transformer/_encoder.py +++ b/src/fairseq2/models/transformer/_encoder.py @@ -25,6 +25,7 @@ # isort: split from fairseq2.models.transformer._attention_bias import AttentionBiasCache +from fairseq2.models.transformer._block_mask import BlockMaskCache from fairseq2.models.transformer._encoder_layer import TransformerEncoderLayer @@ -149,11 +150,12 @@ def forward(self, seqs: Tensor, seqs_layout: BatchLayout) -> Tensor: ) attn_bias_cache = AttentionBiasCache() + block_mask_cache = BlockMaskCache() num_layers = len(self.layers) for layer_idx, (layer, drop) in enumerate(self._drop_iter()): - layer_output = layer(seqs, seqs_layout, attn_bias_cache) + layer_output = layer(seqs, seqs_layout, attn_bias_cache, block_mask_cache) if drop: seqs = _record_drop_for_backward(seqs, layer_output) diff --git a/src/fairseq2/models/transformer/_encoder_layer.py b/src/fairseq2/models/transformer/_encoder_layer.py index 458736c13..bcb3cb84f 100644 --- a/src/fairseq2/models/transformer/_encoder_layer.py +++ b/src/fairseq2/models/transformer/_encoder_layer.py @@ -25,6 +25,7 @@ # isort: split from fairseq2.models.transformer._attention_bias import AttentionBiasCache +from fairseq2.models.transformer._block_mask import BlockMaskCache from fairseq2.models.transformer._ffn import FeedForwardNetwork from fairseq2.models.transformer._multihead_attention import MultiheadAttention from fairseq2.models.transformer._norm_order import TransformerNormOrder @@ -39,6 +40,7 @@ def forward( seqs: Tensor, seqs_layout: BatchLayout, attn_bias_cache: AttentionBiasCache, + block_mask_cache: BlockMaskCache, ) -> Tensor: """ :param seqs: The sequences to process. 
*Shape:* :math:`(N,S,M)`, where @@ -148,8 +150,11 @@ def forward( seqs: Tensor, seqs_layout: BatchLayout, attn_bias_cache: AttentionBiasCache, + block_mask_cache: BlockMaskCache, ) -> Tensor: - seqs = self._forward_self_attn(seqs, seqs_layout, attn_bias_cache) + seqs = self._forward_self_attn( + seqs, seqs_layout, attn_bias_cache, block_mask_cache + ) seqs = self._forward_ffn(seqs) @@ -160,6 +165,7 @@ def _forward_self_attn( seqs: Tensor, seqs_layout: BatchLayout, attn_bias_cache: AttentionBiasCache, + block_mask_cache: BlockMaskCache, ) -> Tensor: residual = seqs @@ -173,6 +179,7 @@ def _forward_self_attn( keys_layout=seqs_layout, values=seqs, bias_cache=attn_bias_cache, + block_mask_cache=block_mask_cache, ) if self.self_attn_dropout is not None: diff --git a/src/fairseq2/models/transformer/_multihead_attention.py b/src/fairseq2/models/transformer/_multihead_attention.py index c5ca38e51..eceaa2311 100644 --- a/src/fairseq2/models/transformer/_multihead_attention.py +++ b/src/fairseq2/models/transformer/_multihead_attention.py @@ -36,6 +36,7 @@ # isort: split from fairseq2.models.transformer._attention_bias import AttentionBiasCache +from fairseq2.models.transformer._block_mask import BlockMaskCache from fairseq2.models.transformer._sdpa._base import SDPA @@ -58,6 +59,7 @@ def forward( keys_layout: BatchLayout, values: Tensor, bias_cache: AttentionBiasCache, + block_mask_cache: BlockMaskCache, *, state_bag: IncrementalStateBag | None = None, ) -> Tensor: @@ -393,6 +395,7 @@ def forward( keys_layout: BatchLayout, values: Tensor, bias_cache: AttentionBiasCache, + block_mask_cache: BlockMaskCache, *, state_bag: IncrementalStateBag | None = None, ) -> Tensor: @@ -463,7 +466,14 @@ def forward( # attns: (N, S, H, V_h) # attn_weights: (N, H, S, S_kv) attns, attn_weights = self.sdpa( - q, seqs_layout, k, keys_layout, v, bias_cache, needs_weights=needs_weights + q, + seqs_layout, + k, + keys_layout, + v, + bias_cache, + block_mask_cache, + needs_weights=needs_weights, ) del q diff --git a/src/fairseq2/models/transformer/_sdpa/_base.py b/src/fairseq2/models/transformer/_sdpa/_base.py index 42d34795f..24c354835 100644 --- a/src/fairseq2/models/transformer/_sdpa/_base.py +++ b/src/fairseq2/models/transformer/_sdpa/_base.py @@ -13,6 +13,7 @@ from torch.nn import Module from fairseq2.models.transformer._attention_bias import AttentionBiasCache +from fairseq2.models.transformer._block_mask import BlockMaskCache from fairseq2.nn import BatchLayout @@ -28,6 +29,7 @@ def forward( keys_layout: BatchLayout, values: Tensor, bias_cache: AttentionBiasCache, + block_mask_cache: BlockMaskCache, *, needs_weights: bool = False, ) -> tuple[Tensor, Tensor | None]: diff --git a/src/fairseq2/models/transformer/_sdpa/_flash2.py b/src/fairseq2/models/transformer/_sdpa/_flash2.py index 74ac0a034..9c783e62f 100644 --- a/src/fairseq2/models/transformer/_sdpa/_flash2.py +++ b/src/fairseq2/models/transformer/_sdpa/_flash2.py @@ -32,6 +32,7 @@ CausalAttentionBias, IdentityBias, ) +from fairseq2.models.transformer._block_mask import BlockMaskCache from fairseq2.models.transformer._sdpa._base import SDPA @@ -58,6 +59,7 @@ def forward( keys_layout: BatchLayout, values: Tensor, bias_cache: AttentionBiasCache, + block_mask_cache: BlockMaskCache, *, needs_weights: bool = False, ) -> tuple[Tensor, Tensor | None]: diff --git a/src/fairseq2/models/transformer/_sdpa/_flash3.py b/src/fairseq2/models/transformer/_sdpa/_flash3.py index 7f1f792bb..5bbb23a68 100644 --- a/src/fairseq2/models/transformer/_sdpa/_flash3.py +++ 
b/src/fairseq2/models/transformer/_sdpa/_flash3.py
@@ -30,6 +30,7 @@
     CausalAttentionBias,
     IdentityBias,
 )
+from fairseq2.models.transformer._block_mask import BlockMaskCache
 from fairseq2.models.transformer._sdpa._base import SDPA
 
 
@@ -56,6 +57,7 @@ def forward(
         keys_layout: BatchLayout,
         values: Tensor,
         bias_cache: AttentionBiasCache,
+        block_mask_cache: BlockMaskCache,
         *,
         needs_weights: bool = False,
     ) -> tuple[Tensor, Tensor | None]:
diff --git a/src/fairseq2/models/transformer/_sdpa/_flex.py b/src/fairseq2/models/transformer/_sdpa/_flex.py
new file mode 100644
index 000000000..5eed1a445
--- /dev/null
+++ b/src/fairseq2/models/transformer/_sdpa/_flex.py
@@ -0,0 +1,106 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import annotations
+
+from typing import Callable, TypeAlias, final
+
+import torch
+from torch import Tensor
+from torch.nn.attention.flex_attention import flex_attention
+from typing_extensions import override
+
+from fairseq2.models.transformer._block_mask import BlockMaskCache
+from fairseq2.logging import log
+from fairseq2.nn import BatchLayout
+
+# isort: split
+
+from fairseq2.models.transformer._attention_bias import (
+    AttentionBias,
+    AttentionBiasCache,
+)
+from fairseq2.models.transformer._sdpa._base import SDPA
+
+MaskFunction: TypeAlias = Callable[[Tensor, Tensor, Tensor, Tensor], Tensor]
+
+# NOTE: Flex attention only has performance benefits when compiled with torch.compile,
+# but compilation is not possible on certain platforms (e.g., CPU).
+if torch.cuda.is_available():
+    flex_attention = torch.compile(flex_attention, dynamic=False)
+
+
+@final
+class FlexSDPA(SDPA):
+    """Computes scaled dot-product attention using PyTorch's Flex Attention."""
+
+    bias: AttentionBias
+    dropout_p: float
+
+    def __init__(self, bias: AttentionBias, *, dropout_p: float = 0.0) -> None:
+        super().__init__()
+
+        self.bias = bias
+        self.dropout_p = dropout_p
+
+    @override
+    def forward(
+        self,
+        seqs: Tensor,
+        seqs_layout: BatchLayout,
+        keys: Tensor,
+        keys_layout: BatchLayout,
+        values: Tensor,
+        bias_cache: AttentionBiasCache,
+        block_mask_cache: BlockMaskCache,
+        *,
+        needs_weights: bool = False,
+    ) -> tuple[Tensor, Tensor | None]:
+        if seqs_layout.packed ^ keys_layout.packed:
+            raise ValueError(
+                "`seqs_layout` and `keys_layout` must be both packed or both unpacked."
+            )
+
+        unsqueezed = False
+        if seqs.ndim == 3:
+            unsqueezed = True
+            seqs = seqs.unsqueeze(0)
+            keys = keys.unsqueeze(0)
+            values = values.unsqueeze(0)
+
+        # Create the composed block mask using and_masks
+        block_mask = block_mask_cache.get_or_create_mask(
+            self.bias,
+            seqs_layout,
+            keys_layout,
+            seqs.device,
+        )
+
+        seqs = seqs.transpose(1, 2)
+        keys = keys.transpose(1, 2)
+        values = values.transpose(1, 2)
+
+        attns = flex_attention(
+            seqs,
+            keys,
+            values,
+            block_mask=block_mask,
+            enable_gqa=False,
+        )
+
+        if isinstance(attns, tuple):
+            attns, _ = attns
+
+        attns = attns.transpose(1, 2)
+        if unsqueezed:
+            attns = attns.squeeze(0)
+
+        return attns, None
+
+    @override
+    def extra_repr(self) -> str:
+        """:meta private:"""
+        return f"bias={self.bias}, dropout_p={self.dropout_p:G}"
diff --git a/src/fairseq2/models/transformer/_sdpa/_naive.py b/src/fairseq2/models/transformer/_sdpa/_naive.py
index 1d2373ae5..16c387bb2 100644
--- a/src/fairseq2/models/transformer/_sdpa/_naive.py
+++ b/src/fairseq2/models/transformer/_sdpa/_naive.py
@@ -22,6 +22,7 @@
AttentionBiasCache, maybe_get_attention_bias_tensor, ) +from fairseq2.models.transformer._block_mask import BlockMaskCache from fairseq2.models.transformer._sdpa._base import SDPA @@ -48,6 +49,7 @@ def forward( keys_layout: BatchLayout, values: Tensor, bias_cache: AttentionBiasCache, + block_mask_cache: BlockMaskCache, *, needs_weights: bool = False, ) -> tuple[Tensor, Tensor | None]: diff --git a/src/fairseq2/models/transformer/_sdpa/_relative.py b/src/fairseq2/models/transformer/_sdpa/_relative.py index de9fa0b6b..5cf432b29 100644 --- a/src/fairseq2/models/transformer/_sdpa/_relative.py +++ b/src/fairseq2/models/transformer/_sdpa/_relative.py @@ -28,6 +28,7 @@ AttentionBiasCache, maybe_get_attention_bias_tensor, ) +from fairseq2.models.transformer._block_mask import BlockMaskCache from fairseq2.models.transformer._sdpa._base import SDPA from fairseq2.models.transformer._sdpa._naive import ( naive_scaled_dot_product_attention, @@ -100,6 +101,7 @@ def forward( keys_layout: BatchLayout, values: Tensor, bias_cache: AttentionBiasCache, + block_mask_cache: BlockMaskCache, *, needs_weights: bool = False, ) -> tuple[Tensor, Tensor | None]: diff --git a/src/fairseq2/models/transformer/_sdpa/_shaw.py b/src/fairseq2/models/transformer/_sdpa/_shaw.py index 359df7834..1a503c8da 100644 --- a/src/fairseq2/models/transformer/_sdpa/_shaw.py +++ b/src/fairseq2/models/transformer/_sdpa/_shaw.py @@ -25,6 +25,7 @@ AttentionBiasCache, maybe_get_attention_bias_tensor, ) +from fairseq2.models.transformer._block_mask import BlockMaskCache from fairseq2.models.transformer._sdpa._base import SDPA from fairseq2.models.transformer._sdpa._naive import ( naive_scaled_dot_product_attention, @@ -113,6 +114,7 @@ def forward( keys_layout: BatchLayout, values: Tensor, bias_cache: AttentionBiasCache, + block_mask_cache: BlockMaskCache, *, needs_weights: bool = False, ) -> tuple[Tensor, Tensor | None]: diff --git a/src/fairseq2/models/transformer/_sdpa/_torch.py b/src/fairseq2/models/transformer/_sdpa/_torch.py index 8c661d0e7..c85dd632b 100644 --- a/src/fairseq2/models/transformer/_sdpa/_torch.py +++ b/src/fairseq2/models/transformer/_sdpa/_torch.py @@ -23,6 +23,7 @@ CausalAttentionBias, maybe_get_attention_bias_tensor, ) +from fairseq2.models.transformer._block_mask import BlockMaskCache from fairseq2.models.transformer._sdpa._base import SDPA @@ -49,6 +50,7 @@ def forward( keys_layout: BatchLayout, values: Tensor, bias_cache: AttentionBiasCache, + block_mask_cache: BlockMaskCache, *, needs_weights: bool = False, ) -> tuple[Tensor, Tensor | None]: diff --git a/src/fairseq2/models/transformer_lm/_decoder.py b/src/fairseq2/models/transformer_lm/_decoder.py index fe7effc86..d20cea244 100644 --- a/src/fairseq2/models/transformer_lm/_decoder.py +++ b/src/fairseq2/models/transformer_lm/_decoder.py @@ -21,6 +21,7 @@ # isort: split +from fairseq2.models.transformer._block_mask import BlockMaskCache from fairseq2.models.transformer_lm._decoder_layer import TransformerLMDecoderLayer @@ -132,11 +133,18 @@ def forward( state_bag: IncrementalStateBag | None = None, ) -> Tensor: attn_bias_cache = AttentionBiasCache() + block_mask_cache = BlockMaskCache() num_layers = len(self.layers) for layer_idx, layer in enumerate(self.layers): - seqs = layer(seqs, seqs_layout, attn_bias_cache, state_bag=state_bag) + seqs = layer( + seqs, + seqs_layout, + attn_bias_cache, + block_mask_cache, + state_bag=state_bag, + ) for hook in self._layer_hooks.values(): if not hook(layer_idx, seqs, seqs_layout, num_layers): diff --git 
a/src/fairseq2/models/transformer_lm/_decoder_layer.py b/src/fairseq2/models/transformer_lm/_decoder_layer.py index 0139cd870..704c86fd6 100644 --- a/src/fairseq2/models/transformer_lm/_decoder_layer.py +++ b/src/fairseq2/models/transformer_lm/_decoder_layer.py @@ -17,6 +17,7 @@ from fairseq2.device import Device from fairseq2.models.transformer import ( AttentionBiasCache, + BlockMaskCache, FeedForwardNetwork, MultiheadAttention, TransformerNormOrder, @@ -39,6 +40,7 @@ def forward( seqs: Tensor, seqs_layout: BatchLayout, attn_bias_cache: AttentionBiasCache, + block_mask_cache: BlockMaskCache, *, state_bag: IncrementalStateBag | None = None, ) -> Tensor: @@ -146,10 +148,13 @@ def forward( seqs: Tensor, seqs_layout: BatchLayout, attn_bias_cache: AttentionBiasCache, + block_mask_cache: BlockMaskCache, *, state_bag: IncrementalStateBag | None = None, ) -> Tensor: - seqs = self._forward_self_attn(seqs, seqs_layout, attn_bias_cache, state_bag) + seqs = self._forward_self_attn( + seqs, seqs_layout, attn_bias_cache, block_mask_cache, state_bag + ) seqs = self._forward_ffn(seqs) @@ -160,6 +165,7 @@ def _forward_self_attn( seqs: Tensor, seqs_layout: BatchLayout, attn_bias_cache: AttentionBiasCache, + block_mask_cache: BlockMaskCache, state_bag: IncrementalStateBag | None, ) -> Tensor: residual = seqs @@ -174,6 +180,7 @@ def _forward_self_attn( keys_layout=seqs_layout, values=seqs, bias_cache=attn_bias_cache, + block_mask_cache=block_mask_cache, state_bag=state_bag, ) diff --git a/src/fairseq2/recipes/common/_torch.py b/src/fairseq2/recipes/common/_torch.py index 9794d0553..254cf85f1 100644 --- a/src/fairseq2/recipes/common/_torch.py +++ b/src/fairseq2/recipes/common/_torch.py @@ -16,6 +16,7 @@ from fairseq2.models.transformer import ( Flash2SDPA, Flash3SDPA, + FlexSDPA, NaiveSDPA, TorchSDPA, set_default_sdpa_factory, @@ -137,6 +138,8 @@ def _set_default_sdpa_variant(name: str) -> None: _set_torch_sdpa_backend(backend) except (ImportError, AttributeError): log.warning("PyTorch SDPA kernel cannot be set to '{}'. 
Falling back to auto mode.", backend) # fmt: skip + case "flex": + set_default_sdpa_factory(FlexSDPA) case "flash2": set_default_sdpa_factory(Flash2SDPA) case "flash3": diff --git a/src/fairseq2/recipes/config.py b/src/fairseq2/recipes/config.py index 731b2f21f..63fc9b17a 100644 --- a/src/fairseq2/recipes/config.py +++ b/src/fairseq2/recipes/config.py @@ -465,6 +465,7 @@ class CommonSection: "torch_flash", "flash2", "flash3", + "flex", "naive", ] diff --git a/tests/unit/models/transformer/test_attention.py b/tests/unit/models/transformer/test_attention.py index f0d71dcb3..38b41bbd5 100644 --- a/tests/unit/models/transformer/test_attention.py +++ b/tests/unit/models/transformer/test_attention.py @@ -15,7 +15,9 @@ from fairseq2.models.transformer import ( AttentionBias, AttentionBiasCache, + BlockMaskCache, CausalAttentionBias, + FlexSDPA, IdentityBias, NaiveSDPA, StandardMultiheadAttention, @@ -27,19 +29,24 @@ class TestScaledDotProductAttention: # fmt: off - @pytest.mark.parametrize("use_padding,use_bias,training", + @pytest.mark.parametrize("use_padding,use_packing,use_bias,training", [ - (False, False, True), - (True, True, True), - (False, True, True), - (True, False, True), - (False, False, False), - (False, True, False), + (False, False, False, True), + (True, False, True, True), + (False, False, True, True), + (True, False, False, True), + (False, True, False, True), + (False, False, False, False), + (False, False, True, False), ], ) # fmt: on def test_torch_sdpa( - self, use_padding: bool, use_bias: bool, training: bool + self, + use_padding: bool, + use_packing: bool, + use_bias: bool, + training: bool, ) -> None: attn_bias: AttentionBias @@ -52,10 +59,10 @@ def test_torch_sdpa( naive_sdpa = NaiveSDPA(attn_bias) if training: - torch_sdpa.eval() - naive_sdpa.eval() + torch_sdpa.train() + naive_sdpa.train() - kwargs = self._get_sdpa_args(use_padding) + kwargs = self._get_sdpa_args(use_padding, use_packing) attn1, _ = torch_sdpa(**kwargs) attn2, _ = naive_sdpa(**kwargs) @@ -63,7 +70,7 @@ def test_torch_sdpa( assert_close(attn1, attn2) @staticmethod - def _get_sdpa_args(use_padding: bool) -> dict[str, Any]: + def _get_sdpa_args(use_padding: bool, use_packing: bool) -> dict[str, Any]: batch_size = 2 num_heads = 4 @@ -77,21 +84,42 @@ def _get_sdpa_args(use_padding: bool) -> dict[str, Any]: def random_tensor(*args: int) -> Tensor: return torch.randn(*args, device=device) - q = random_tensor(batch_size, target_seq_len, num_heads, k_size) - k = random_tensor(batch_size, source_seq_len, num_heads, k_size) - v = random_tensor(batch_size, source_seq_len, num_heads, v_size) + if use_packing: + # For packing, we need 1D tensors with total sequence length + total_target_len = sum([2, 3]) # seq_lens for target + total_source_len = sum([2, 3]) # seq_lens for source - target_shape = (batch_size, target_seq_len) - source_shape = (batch_size, source_seq_len) + q = random_tensor(total_target_len, num_heads, k_size) + k = random_tensor(total_source_len, num_heads, k_size) + v = random_tensor(total_source_len, num_heads, v_size) - if use_padding: - q_layout = BatchLayout(target_shape, seq_lens=None, device=device) - k_layout = BatchLayout(source_shape, seq_lens=[2, 3], device=device) + target_shape: tuple[int, ...] = (total_target_len,) + source_shape: tuple[int, ...] 
= (total_source_len,) + + q_layout = BatchLayout( + target_shape, seq_lens=[2, 3], packed=True, device=device + ) + k_layout = BatchLayout( + source_shape, seq_lens=[2, 3], packed=True, device=device + ) else: - q_layout = BatchLayout(target_shape, seq_lens=None, device=device) - k_layout = BatchLayout(source_shape, seq_lens=None, device=device) + # For non-packing cases (regular 2D tensors) + q = random_tensor(batch_size, target_seq_len, num_heads, k_size) + k = random_tensor(batch_size, source_seq_len, num_heads, k_size) + v = random_tensor(batch_size, source_seq_len, num_heads, v_size) + + target_shape = (batch_size, target_seq_len) + source_shape = (batch_size, source_seq_len) + + if use_padding: + q_layout = BatchLayout(target_shape, seq_lens=None, device=device) + k_layout = BatchLayout(source_shape, seq_lens=[2, 3], device=device) + else: + q_layout = BatchLayout(target_shape, seq_lens=None, device=device) + k_layout = BatchLayout(source_shape, seq_lens=None, device=device) bias_cache = AttentionBiasCache() + block_mask_cache = BlockMaskCache() return { "seqs": q, @@ -100,6 +128,7 @@ def random_tensor(*args: int) -> Tensor: "keys_layout": k_layout, "values": v, "bias_cache": bias_cache, + "block_mask_cache": block_mask_cache, } @@ -145,6 +174,7 @@ def test_variable_sized_attention(self, q_dim: int, k_dim: int | None) -> None: keys_layout = BatchLayout.of(keys) bias_cache = AttentionBiasCache() + block_mask_cache = BlockMaskCache() result = mha( seqs, @@ -153,6 +183,118 @@ def test_variable_sized_attention(self, q_dim: int, k_dim: int | None) -> None: keys_layout, values=keys, bias_cache=bias_cache, + block_mask_cache=block_mask_cache, ) assert result.shape == seqs.shape + + +class TestFlexScaledDotProductAttention: + # fmt: off + @pytest.mark.parametrize("use_padding,use_packing,use_bias,training,attn_window_len", + [ + # Original test cases + (False, False, False, True, None), + (True, False, True, True, None), + (False, False, True, True, None), + (True, False, False, True, None), + (False, True, False, True, None), + (False, True, True, True, None), + (False, False, True, True, 1), + (True, False, True, True, 1), + (False, True, True, True, 1), + (False, False, False, False, None), + (False, False, True, False, None), + ], + ) + # fmt: on + def test_flex_sdpa( + self, + use_padding: bool, + use_packing: bool, + use_bias: bool, + training: bool, + attn_window_len: int | None, + ) -> None: + attn_bias: AttentionBias + + if use_bias: + attn_bias = CausalAttentionBias(attn_window_len=attn_window_len) + else: + attn_bias = IdentityBias() + + flex_sdpa = FlexSDPA(attn_bias) + naive_sdpa = NaiveSDPA(attn_bias) + + if training: + flex_sdpa.train() + naive_sdpa.train() + + kwargs = self._get_sdpa_args(use_padding, use_packing) + + attn1, _ = flex_sdpa(**kwargs) + attn2, _ = naive_sdpa(**kwargs) + + assert_close(attn1, attn2) + + @staticmethod + def _get_sdpa_args(use_padding: bool, use_packing: bool) -> dict[str, Any]: + batch_size = 2 + + num_heads = 4 + + source_seq_len = 3 + target_seq_len = 2 + + k_size = 2 + v_size = 4 + + def random_tensor(*args: int) -> Tensor: + return torch.randn(*args, device=device) + + if use_packing: + # For packing, we need 1D tensors with total sequence length + total_target_len = sum([2, 3]) # seq_lens for target + total_source_len = sum([2, 3]) # seq_lens for source + + q = random_tensor(total_target_len, num_heads, k_size) + k = random_tensor(total_source_len, num_heads, k_size) + v = random_tensor(total_source_len, num_heads, v_size) + + target_shape: 
tuple[int, ...] = (total_target_len,) + source_shape: tuple[int, ...] = (total_source_len,) + + q_layout = BatchLayout( + target_shape, seq_lens=[2, 3], packed=True, device=device + ) + k_layout = BatchLayout( + source_shape, seq_lens=[2, 3], packed=True, device=device + ) + else: + # For non-packing cases (regular 2D tensors) + q = random_tensor(batch_size, target_seq_len, num_heads, k_size) + k = random_tensor(batch_size, source_seq_len, num_heads, k_size) + v = random_tensor(batch_size, source_seq_len, num_heads, v_size) + + target_shape = (batch_size, target_seq_len) + source_shape = (batch_size, source_seq_len) + + if use_padding: + q_layout = BatchLayout(target_shape, seq_lens=None, device=device) + k_layout = BatchLayout(source_shape, seq_lens=[2, 3], device=device) + else: + q_layout = BatchLayout(target_shape, seq_lens=None, device=device) + k_layout = BatchLayout(source_shape, seq_lens=None, device=device) + + bias_cache = AttentionBiasCache() + block_mask_cache = BlockMaskCache() + + return { + "seqs": q, + "seqs_layout": q_layout, + "keys": k, + "keys_layout": k_layout, + "values": v, + "bias_cache": bias_cache, + "block_mask_cache": block_mask_cache, + } diff --git a/tests/unit/models/transformer/test_block_mask.py b/tests/unit/models/transformer/test_block_mask.py new file mode 100644 index 000000000..2bad2a8b9 --- /dev/null +++ b/tests/unit/models/transformer/test_block_mask.py @@ -0,0 +1,299 @@ +from unittest.mock import Mock, patch + +import pytest +import torch + +from fairseq2.device import Device +from fairseq2.models.transformer._attention_bias import IdentityBias +from fairseq2.models.transformer._block_mask import ( + BlockMaskCache, + BlockMaskCacheKey, + _causal_mask_fn, + _create_composed_mask, + _create_packed_mask_fn, + _create_padding_mask_fn, + _offsets_to_doc_ids_tensor, + _sliding_window_causal_mask_fn, +) + + +class TestMaskFunctions: + """Test individual mask functions.""" + + def test_causal_mask_fn(self) -> None: + """Test causal mask function behavior.""" + q_lens = torch.tensor([3, 2]) + kv_lens = torch.tensor([3, 2]) + mask_fn = _causal_mask_fn(q_lens, kv_lens) + + # Test for batch 0 + b = torch.tensor(0) + h = torch.tensor(0) + + # Test diagonal and upper triangular positions + assert mask_fn(b, h, torch.tensor(0), torch.tensor(0)) == True + assert mask_fn(b, h, torch.tensor(1), torch.tensor(0)) == True + assert mask_fn(b, h, torch.tensor(1), torch.tensor(1)) == True + assert mask_fn(b, h, torch.tensor(0), torch.tensor(1)) == False + assert mask_fn(b, h, torch.tensor(2), torch.tensor(1)) == True + + def test_sliding_window_causal_mask_fn(self) -> None: + """Test sliding window causal mask function.""" + q_lens = torch.tensor([4]) + kv_lens = torch.tensor([4]) + window_size = 2 + mask_fn = _sliding_window_causal_mask_fn(window_size, q_lens, kv_lens) + + b = torch.tensor(0) + h = torch.tensor(0) + + # Test window behavior + assert mask_fn(b, h, torch.tensor(2), torch.tensor(1)) == True # Within window + assert mask_fn(b, h, torch.tensor(2), torch.tensor(2)) == True # Diagonal + assert ( + mask_fn(b, h, torch.tensor(3), torch.tensor(1)) == False + ) # Outside window + assert mask_fn(b, h, torch.tensor(1), torch.tensor(2)) == False # Future token + + def test_sliding_window_size_one(self) -> None: + """Test sliding window with size 1 (diagonal only).""" + q_lens = torch.tensor([3]) + kv_lens = torch.tensor([3]) + mask_fn = _sliding_window_causal_mask_fn(1, q_lens, kv_lens) + + b = torch.tensor(0) + h = torch.tensor(0) + + # Only diagonal should be True + 
assert mask_fn(b, h, torch.tensor(0), torch.tensor(0)) == True + assert mask_fn(b, h, torch.tensor(1), torch.tensor(1)) == True + assert mask_fn(b, h, torch.tensor(1), torch.tensor(0)) == False + assert mask_fn(b, h, torch.tensor(0), torch.tensor(1)) == False + + def test_offsets_to_doc_ids_tensor(self) -> None: + """Test conversion of offsets to document IDs.""" + offsets = torch.tensor([0, 3, 5, 8]) + doc_ids = _offsets_to_doc_ids_tensor(offsets) + expected = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2], dtype=torch.int32) + assert torch.equal(doc_ids, expected) + + def test_padding_mask_fn(self) -> None: + """Test padding mask function.""" + q_lens = torch.tensor([2, 3]) + kv_lens = torch.tensor([3, 2]) + mask_fn = _create_padding_mask_fn(q_lens, kv_lens) + + b = torch.tensor(0) + h = torch.tensor(0) + + # Valid positions + assert mask_fn(b, h, torch.tensor(0), torch.tensor(0)) == True + assert mask_fn(b, h, torch.tensor(1), torch.tensor(2)) == True + # Invalid positions (beyond sequence length) + assert mask_fn(b, h, torch.tensor(2), torch.tensor(0)) == False + assert mask_fn(b, h, torch.tensor(0), torch.tensor(3)) == False + + +class TestPackedMaskFunction: + """Test packed sequence mask function.""" + + def test_create_packed_mask_fn_basic(self) -> None: + """Test basic packed mask functionality.""" + seq_begin_indices = torch.tensor([0, 3, 5]) + keys_begin_indices = torch.tensor([0, 3, 5]) + + mask_fn = _create_packed_mask_fn(seq_begin_indices, keys_begin_indices) + + b = torch.tensor(0) + h = torch.tensor(0) + + # Same document + assert mask_fn(b, h, torch.tensor(0), torch.tensor(1)) == True + assert mask_fn(b, h, torch.tensor(3), torch.tensor(4)) == True + # Different documents + assert mask_fn(b, h, torch.tensor(0), torch.tensor(3)) == False + assert mask_fn(b, h, torch.tensor(1), torch.tensor(4)) == False + + def test_create_packed_mask_fn_with_base_mask(self) -> None: + """Test packed mask with base causal mask.""" + seq_begin_indices = torch.tensor([0, 2, 4]) + keys_begin_indices = torch.tensor([0, 2, 4]) + q_lens = torch.tensor([2, 2]) + kv_lens = torch.tensor([2, 2]) + + base_mask_fn = _causal_mask_fn(q_lens, kv_lens) + mask_fn = _create_packed_mask_fn( + seq_begin_indices, keys_begin_indices, base_mask_fn + ) + + b = torch.tensor(0) + h = torch.tensor(0) + + # Same document, causal valid + assert mask_fn(b, h, torch.tensor(1), torch.tensor(0)) == True + # Same document, causal invalid + assert mask_fn(b, h, torch.tensor(0), torch.tensor(1)) == False + # Different documents + assert mask_fn(b, h, torch.tensor(0), torch.tensor(2)) == False + + +class TestBlockMaskCache: + """Test block mask caching functionality.""" + + def test_cache_key_creation(self) -> None: + """Test cache key creation for different layouts.""" + cache = BlockMaskCache() + + # Mock BatchLayout for non-packed sequences + seqs_layout = Mock() + seqs_layout.packed = False + seqs_layout.seq_lens = [3, 4, 2] + seqs_layout.max_seq_len = 4 + + keys_layout = Mock() + keys_layout.packed = False + keys_layout.seq_lens = [3, 4, 2] + keys_layout.max_seq_len = 4 + + key = cache._create_cache_key(seqs_layout, keys_layout) + assert key.batch_size == 3 + assert key.seqs_len == 4 + assert key.keys_len == 4 + + def test_cache_key_creation_packed(self) -> None: + """Test cache key creation for packed sequences.""" + cache = BlockMaskCache() + + # Mock BatchLayout for packed sequences + seqs_layout = Mock() + seqs_layout.packed = True + seqs_layout.seq_begin_indices = [0, 3, 7] + + keys_layout = Mock() + keys_layout.packed = True 
+ keys_layout.seq_begin_indices = [0, 3, 7] + + key = cache._create_cache_key(seqs_layout, keys_layout) + assert key.batch_size == 1 + assert key.seqs_len == 7 + assert key.keys_len == 7 + + def test_cache_key_hash(self) -> None: + """Test that cache keys are hashable.""" + key1 = BlockMaskCacheKey(batch_size=2, seqs_len=10, keys_len=10) + key2 = BlockMaskCacheKey(batch_size=2, seqs_len=10, keys_len=10) + key3 = BlockMaskCacheKey(batch_size=3, seqs_len=10, keys_len=10) + + assert hash(key1) == hash(key2) + assert hash(key1) != hash(key3) + assert key1 == key2 + assert key1 != key3 + + @patch("fairseq2.models.transformer._block_mask._create_composed_mask") + def test_cache_hit_and_miss(self, mock_create_mask: Mock) -> None: + """Test cache hit and miss behavior.""" + cache = BlockMaskCache() + mock_mask = Mock() + mock_create_mask.return_value = mock_mask + + # Mock inputs + bias = Mock(spec=IdentityBias) + seqs_layout = Mock() + seqs_layout.packed = False + seqs_layout.seq_lens = [3, 4] + seqs_layout.max_seq_len = 4 + + keys_layout = Mock() + keys_layout.packed = False + keys_layout.seq_lens = [3, 4] + keys_layout.max_seq_len = 4 + + device = Mock(spec=Device) + + # First call - cache miss + result1 = cache.get_or_create_mask(bias, seqs_layout, keys_layout, device) + assert result1 == mock_mask + assert mock_create_mask.call_count == 1 + + # Second call - cache hit + result2 = cache.get_or_create_mask(bias, seqs_layout, keys_layout, device) + assert result2 == mock_mask + assert mock_create_mask.call_count == 1 # Should not increase + + +class TestCreateComposedMask: + """Test the main composed mask creation function.""" + + @patch("fairseq2.models.transformer._block_mask.create_block_mask") + def test_create_composed_mask_identity_bias( + self, mock_create_block_mask: Mock + ) -> None: + """Test composed mask creation with identity bias.""" + mock_block_mask = Mock() + mock_create_block_mask.return_value = mock_block_mask + + bias = Mock(spec=IdentityBias) + + # Mock BatchLayout + seqs_layout = Mock() + seqs_layout.packed = False + seqs_layout.padded = True + seqs_layout.seq_lens = [3, 4] + seqs_layout.max_seq_len = 4 + seqs_layout.seq_lens_pt = torch.tensor([3, 4]) + + keys_layout = Mock() + keys_layout.packed = False + keys_layout.padded = True + keys_layout.seq_lens = [3, 4] + keys_layout.max_seq_len = 4 + keys_layout.seq_lens_pt = torch.tensor([3, 4]) + + device = Mock(spec=Device) + + result = _create_composed_mask(bias, seqs_layout, keys_layout, device) + + # Should create block mask with padding mask only + mock_create_block_mask.assert_called_once() + assert result == mock_block_mask + + @patch("fairseq2.models.transformer._block_mask.create_block_mask") + def test_create_composed_mask_no_masks_needed( + self, mock_create_block_mask: Mock + ) -> None: + """Test when no masks are needed.""" + bias = Mock(spec=IdentityBias) + + # Mock BatchLayout with no padding + seqs_layout = Mock() + seqs_layout.packed = False + seqs_layout.padded = False + + keys_layout = Mock() + keys_layout.packed = False + keys_layout.padded = False + + device = Mock(spec=Device) + + result = _create_composed_mask(bias, seqs_layout, keys_layout, device) + + # Should return None when no masks are needed + assert result is None + mock_create_block_mask.assert_not_called() + + def test_unsupported_bias_type(self) -> None: + """Test that unsupported bias types raise an error.""" + bias = Mock() # Unknown bias type + + seqs_layout = Mock() + seqs_layout.packed = False + seqs_layout.padded = False + + keys_layout = 
Mock() + keys_layout.packed = False + keys_layout.padded = False + + device = Mock(spec=Device) + + with pytest.raises(Exception): # Should raise NotSupportedError + _create_composed_mask(bias, seqs_layout, keys_layout, device)
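
A minimal usage sketch follows, mirroring the new unit tests above: it wires a FlexSDPA to a per-forward BlockMaskCache on a packed batch of two sequences. The CPU device, tensor shapes, and sequence lengths are illustrative assumptions, not values taken from the patch; off-GPU, flex_attention runs uncompiled (eager), as noted in _flex.py.

import torch

from fairseq2.models.transformer import (
    AttentionBiasCache,
    BlockMaskCache,
    CausalAttentionBias,
    FlexSDPA,
)
from fairseq2.nn import BatchLayout

device = torch.device("cpu")  # assumption: CPU for illustration; CUDA enables torch.compile

num_heads, k_size, v_size = 4, 2, 4

# Two sequences of lengths 2 and 3, flattened into a single packed dimension of 5.
q = torch.randn(5, num_heads, k_size, device=device)
k = torch.randn(5, num_heads, k_size, device=device)
v = torch.randn(5, num_heads, v_size, device=device)

q_layout = BatchLayout((5,), seq_lens=[2, 3], packed=True, device=device)
k_layout = BatchLayout((5,), seq_lens=[2, 3], packed=True, device=device)

sdpa = FlexSDPA(CausalAttentionBias(attn_window_len=None))

# Both caches are created once per forward pass and shared across layers, so the
# block mask is computed for the first layer and reused by the rest.
bias_cache = AttentionBiasCache()
block_mask_cache = BlockMaskCache()

attns, _ = sdpa(q, q_layout, k, k_layout, v, bias_cache, block_mask_cache)

print(attns.shape)  # (5, num_heads, v_size), i.e. the packed layout of the inputs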