8 changes: 7 additions & 1 deletion src/fairseq2/models/conformer/_block.py
@@ -25,6 +25,7 @@
# isort: split

from fairseq2.models.conformer._convolution import ConformerConvolution
from fairseq2.models.transformer._block_mask import BlockMaskCache


@final
@@ -131,10 +132,13 @@ def forward(
seqs: Tensor,
seqs_layout: BatchLayout,
attn_bias_cache: AttentionBiasCache,
block_mask_cache: BlockMaskCache,
) -> Tensor:
seqs = self._forward_ffn1(seqs)

seqs = self._forward_self_attn(seqs, seqs_layout, attn_bias_cache)
seqs = self._forward_self_attn(
seqs, seqs_layout, attn_bias_cache, block_mask_cache
)

seqs = self._forward_conv(seqs, seqs_layout)

@@ -161,6 +165,7 @@ def _forward_self_attn(
seqs: Tensor,
seqs_layout: BatchLayout,
attn_bias_cache: AttentionBiasCache,
block_mask_cache: BlockMaskCache,
) -> Tensor:
residual = seqs

@@ -173,6 +178,7 @@
keys_layout=seqs_layout,
values=seqs,
bias_cache=attn_bias_cache,
block_mask_cache=block_mask_cache,
)

if self.self_attn_dropout is not None:
3 changes: 3 additions & 0 deletions src/fairseq2/models/jepa/classifier/_model.py
@@ -18,6 +18,7 @@
from fairseq2.device import Device
from fairseq2.models.transformer import (
AttentionBiasCache,
BlockMaskCache,
FeedForwardNetwork,
MultiheadAttention,
TransformerEncoder,
@@ -221,6 +222,7 @@ def _forward_cross_attn(
encoder_output = self.cross_attn_layer_norm(encoder_output)

attn_bias_cache = AttentionBiasCache()
block_mask_cache = BlockMaskCache()

seqs = self.cross_attn(
seqs,
@@ -229,6 +231,7 @@
keys_layout=encoder_output_layout,
values=encoder_output,
bias_cache=attn_bias_cache,
block_mask_cache=block_mask_cache,
)

seqs = seqs + residual
2 changes: 2 additions & 0 deletions src/fairseq2/models/transformer/__init__.py
@@ -23,6 +23,7 @@
from fairseq2.models.transformer._attention_bias import (
maybe_get_attention_bias_tensor as maybe_get_attention_bias_tensor,
)
from fairseq2.models.transformer._block_mask import BlockMaskCache as BlockMaskCache
from fairseq2.models.transformer._checkpoint import (
convert_transformer_checkpoint as convert_transformer_checkpoint,
)
@@ -144,6 +145,7 @@
)
from fairseq2.models.transformer._sdpa._flash2 import Flash2SDPA as Flash2SDPA
from fairseq2.models.transformer._sdpa._flash3 import Flash3SDPA as Flash3SDPA
from fairseq2.models.transformer._sdpa._flex import FlexSDPA as FlexSDPA
from fairseq2.models.transformer._sdpa._naive import NaiveSDPA as NaiveSDPA
from fairseq2.models.transformer._sdpa._naive import (
naive_scaled_dot_product_attention as naive_scaled_dot_product_attention,
300 changes: 300 additions & 0 deletions src/fairseq2/models/transformer/_block_mask.py
@@ -0,0 +1,300 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Callable, OrderedDict, TypeAlias

import torch
from torch import Tensor
from torch.nn.attention.flex_attention import (
BlockMask,
and_masks,
create_block_mask,
)

from fairseq2.device import Device
from fairseq2.error import NotSupportedError
from fairseq2.nn import BatchLayout

# isort: split

from fairseq2.models.transformer._attention_bias import (
AttentionBias,
CausalAttentionBias,
IdentityBias,
)

MaskFunction: TypeAlias = Callable[[Tensor, Tensor, Tensor, Tensor], Tensor]

BLOCK_MASK_CACHE_MAX_SIZE = 100


def _causal_mask_fn(q_lens: Tensor, kv_lens: Tensor) -> MaskFunction:
"""Creates a causal mask function."""

def mask_fn(b: Tensor, h: Tensor, q_idx: Tensor, kv_idx: Tensor) -> Tensor:
# Get sequence lengths for this batch
q_len = q_lens[b]
kv_len = kv_lens[b]

# Calculate diagonal offset
d = kv_len - q_len

return q_idx >= kv_idx - d

return mask_fn
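

# Illustrative example: with q_lens[b] = 2 and kv_lens[b] = 5 the offset d is 3,
# so query position 0 may attend to key positions 0..3 and query position 1 to
# key positions 0..4; the causal diagonal is aligned to the end of the key
# sequence.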


def _sliding_window_causal_mask_fn(
window_size: int,
q_lens: Tensor,
kv_lens: Tensor,
) -> MaskFunction:
"""Creates a sliding window causal mask functions."""

def mask_fn(b: Tensor, h: Tensor, q_idx: Tensor, kv_idx: Tensor) -> Tensor:
# Get sequence lengths for this batch
q_len = q_lens[b]
kv_len = kv_lens[b]

# Calculate diagonal offset
d = kv_len - q_len

        # Apply both the causal and the sliding window constraint.
        # NOTE: There is an incompatibility here with our _create_causal_bias_tensor
        # function, since that function requires data-dependent control flow via
        # q_len, which torch.compile does not currently support. This is a
        # simplified version of that logic and won't match exactly in all cases.
causal_mask = q_idx >= kv_idx - d
window_mask = kv_idx - d >= q_idx - window_size + 1
return causal_mask & window_mask

return mask_fn
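

# Illustrative example: with window_size = 3 and q_lens[b] = kv_lens[b] = 6
# (so d = 0), query position 4 may attend to key positions 2, 3 and 4, i.e.
# at most window_size keys ending at the causal diagonal.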


def _offsets_to_doc_ids_tensor(offsets: Tensor) -> Tensor:
"""Convert offsets to document IDs for packed sequences."""
device = offsets.device
counts = offsets[1:] - offsets[:-1]
return torch.repeat_interleave(
torch.arange(len(counts), device=device, dtype=torch.int32), counts
)
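

# Illustrative example: with offsets = [0, 3, 5] the per-document counts are
# [3, 2], so the result is tensor([0, 0, 0, 1, 1]), i.e. one document ID per
# packed position.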


def _create_packed_mask_fn(
seq_begin_indices: Tensor,
keys_begin_indices: Tensor,
base_mask_fn: MaskFunction | None = None,
) -> MaskFunction:
"""Creates a mask function for packed sequences using document-based masking."""
# Create document IDs for queries and keys
query_doc_ids = _offsets_to_doc_ids_tensor(seq_begin_indices)
key_doc_ids = _offsets_to_doc_ids_tensor(keys_begin_indices)

def packed_mask_fn(b: Tensor, h: Tensor, q_idx: Tensor, kv_idx: Tensor) -> Tensor:
# Check if query and key belong to the same document
same_doc = query_doc_ids[q_idx] == key_doc_ids[kv_idx]

# Convert global indices to logical positions within documents
q_doc_id = query_doc_ids[q_idx]
kv_doc_id = key_doc_ids[kv_idx]
q_logical = q_idx - seq_begin_indices[q_doc_id]
kv_logical = kv_idx - keys_begin_indices[kv_doc_id]

# Apply base mask (e.g., causal) to logical positions
if base_mask_fn is not None:
inner_mask = base_mask_fn(b, h, q_logical, kv_logical)
return same_doc & inner_mask
else:
return same_doc

return packed_mask_fn
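

# Illustrative example: with seq_begin_indices = keys_begin_indices = [0, 3, 5],
# two documents of lengths 3 and 2 are packed into a single row. Global query
# position 4 belongs to document 1 (logical position 1), so it can never attend
# to global key position 2 (document 0); for global key position 3 (document 1,
# logical position 0) the base mask, if any, is applied to the logical
# in-document positions rather than the global ones.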


def _create_padding_mask_fn(q_lens: Tensor, kv_lens: Tensor) -> MaskFunction:
"""Creates a padding mask function that masks out padding tokens."""

def padding_mask_fn(b: Tensor, h: Tensor, q_idx: Tensor, kv_idx: Tensor) -> Tensor:
q_valid = q_idx < q_lens[b]
kv_valid = kv_idx < kv_lens[b]
return q_valid & kv_valid

return padding_mask_fn


def _create_composed_mask(
bias: AttentionBias,
seqs_layout: BatchLayout,
keys_layout: BatchLayout,
device: Device,
) -> BlockMask | None:
"""Creates a composed mask using and_masks for combining multiple mask functions."""
masks = []

if seqs_layout.packed:
# For packed sequences, create the base mask function first
base_mask_fn = None

# Add attention bias mask as base mask
if isinstance(bias, CausalAttentionBias):
attn_window_len = bias.attn_window_len
if attn_window_len is not None:
base_mask_fn = _sliding_window_causal_mask_fn(
attn_window_len,
seqs_layout.seq_lens_pt,
keys_layout.seq_lens_pt,
)
else:
base_mask_fn = _causal_mask_fn(
seqs_layout.seq_lens_pt,
keys_layout.seq_lens_pt,
)
elif not isinstance(bias, IdentityBias):
raise NotSupportedError(f"Unsupported bias type: {bias}")

# Create the packed sequence mask that incorporates the base mask
packed_mask = _create_packed_mask_fn(
seqs_layout.seq_begin_indices_pt,
keys_layout.seq_begin_indices_pt,
base_mask_fn,
)
masks.append(packed_mask)
else:
# Standard batch format - handle bias and padding separately
if isinstance(bias, CausalAttentionBias):
attn_window_len = bias.attn_window_len
if attn_window_len is not None:
masks.append(
_sliding_window_causal_mask_fn(
attn_window_len,
seqs_layout.seq_lens_pt,
keys_layout.seq_lens_pt,
)
)
else:
masks.append(
_causal_mask_fn(
seqs_layout.seq_lens_pt,
keys_layout.seq_lens_pt,
)
)
elif not isinstance(bias, IdentityBias):
raise NotSupportedError(f"Unsupported bias type: {bias}")

# Add padding mask
if seqs_layout.padded or keys_layout.padded:
masks.append(
_create_padding_mask_fn(seqs_layout.seq_lens_pt, keys_layout.seq_lens_pt)
)

# Compose masks
mask_fn = None
if len(masks) == 0:
return None
elif len(masks) == 1:
mask_fn = masks[0]
else:
mask_fn = and_masks(*masks)

if seqs_layout.packed:
total_seq_len = int(seqs_layout.seq_begin_indices_pt[-1].item())
total_keys_len = int(keys_layout.seq_begin_indices_pt[-1].item())
batch_size = 1
else:
total_seq_len = seqs_layout.max_seq_len
total_keys_len = keys_layout.max_seq_len
batch_size = len(seqs_layout.seq_lens)

# Create the block mask
block_mask = create_block_mask(
mask_fn,
B=batch_size,
H=None,
Q_LEN=total_seq_len,
KV_LEN=total_keys_len,
device=str(device),
)
return block_mask


@dataclass
class BlockMaskCacheKey:
"""Key for caching block masks."""

batch_size: int
seqs_len: int
keys_len: int

def __hash__(self) -> int:
return hash(
(
self.batch_size,
self.seqs_len,
self.keys_len,
)
)


class BlockMaskCache:
"""
Cache for block masks to avoid recomputation across layers and (possibly) training
steps. We assume that the cache is not shared across different models or training
runs and therefore only needs to hash on sequence lengths and batch sizes.
"""

def __init__(self) -> None:
self._cache: OrderedDict[BlockMaskCacheKey, BlockMask | None] = OrderedDict()

def _create_cache_key(
self,
seqs_layout: BatchLayout,
keys_layout: BatchLayout,
) -> BlockMaskCacheKey:
"""Create a cache key based on sequence / key-value lengths and batch sizes."""
if seqs_layout.packed:
batch_size = 1
seqs_len = int(seqs_layout.seq_begin_indices[-1])
keys_len = int(keys_layout.seq_begin_indices[-1])
else:
batch_size = len(seqs_layout.seq_lens)
seqs_len = seqs_layout.max_seq_len
keys_len = keys_layout.max_seq_len

cache_key = BlockMaskCacheKey(
batch_size=batch_size,
seqs_len=seqs_len,
keys_len=keys_len,
)
return cache_key

def get_or_create_mask(
self,
bias: AttentionBias,
seqs_layout: BatchLayout,
keys_layout: BatchLayout,
device: Device,
) -> BlockMask | None:
"""Get cached mask or create new one."""
cache_key = self._create_cache_key(seqs_layout, keys_layout)

if cache_key in self._cache:
# Move to end (most recently used)
self._cache.move_to_end(cache_key)
return self._cache[cache_key]

# Create new mask
mask = _create_composed_mask(bias, seqs_layout, keys_layout, device)

# Add to cache and evict if needed
self._cache[cache_key] = mask
if len(self._cache) > BLOCK_MASK_CACHE_MAX_SIZE:
self._cache.popitem(last=False)

return mask

def clear(self) -> None:
"""Clear the cache."""
self._cache.clear()
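
For reference, a minimal sketch of how BlockMaskCache is meant to be consumed by a flex-attention based SDPA. The wrapper name run_flex_sdpa, its argument names, and the assumed (batch, num_heads, seq_len, head_dim) tensor shapes are illustrative assumptions; the actual FlexSDPA added under _sdpa/_flex.py is not shown here.

# Illustrative sketch only; names and shapes below are assumptions, not part of
# this diff.
from torch import Tensor
from torch.nn.attention.flex_attention import flex_attention

from fairseq2.models.transformer import BlockMaskCache
from fairseq2.models.transformer._attention_bias import AttentionBias
from fairseq2.nn import BatchLayout


def run_flex_sdpa(
    q: Tensor,  # (batch, num_heads, seq_len, head_dim)
    k: Tensor,
    v: Tensor,
    bias: AttentionBias,  # e.g. a CausalAttentionBias or an IdentityBias
    seqs_layout: BatchLayout,
    keys_layout: BatchLayout,
    cache: BlockMaskCache,
) -> Tensor:
    # Reuse a cached BlockMask when this (batch_size, seqs_len, keys_len)
    # combination has been seen before; otherwise build one and store it,
    # evicting the least recently used entry beyond BLOCK_MASK_CACHE_MAX_SIZE.
    block_mask = cache.get_or_create_mask(bias, seqs_layout, keys_layout, q.device)

    # A None mask means full attention; flex_attention accepts that as well.
    return flex_attention(q, k, v, block_mask=block_mask)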