From 7a2cf0045e2c9bd4e351259134d9f25f782010ab Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Tue, 27 May 2025 23:51:25 +0000 Subject: [PATCH 01/15] flex attn integration WIP --- src/fairseq2/models/transformer/__init__.py | 1 + .../models/transformer/_sdpa/_flex.py | 188 ++++++++++++++++++ src/fairseq2/recipes/common/_torch.py | 3 + .../unit/models/transformer/test_attention.py | 79 ++++++++ 4 files changed, 271 insertions(+) create mode 100644 src/fairseq2/models/transformer/_sdpa/_flex.py diff --git a/src/fairseq2/models/transformer/__init__.py b/src/fairseq2/models/transformer/__init__.py index 0d45df32c..cfe6139a2 100644 --- a/src/fairseq2/models/transformer/__init__.py +++ b/src/fairseq2/models/transformer/__init__.py @@ -144,6 +144,7 @@ ) from fairseq2.models.transformer._sdpa._flash2 import Flash2SDPA as Flash2SDPA from fairseq2.models.transformer._sdpa._flash3 import Flash3SDPA as Flash3SDPA +from fairseq2.models.transformer._sdpa._flex import FlexSDPA as FlexSDPA from fairseq2.models.transformer._sdpa._naive import NaiveSDPA as NaiveSDPA from fairseq2.models.transformer._sdpa._naive import ( naive_scaled_dot_product_attention as naive_scaled_dot_product_attention, diff --git a/src/fairseq2/models/transformer/_sdpa/_flex.py b/src/fairseq2/models/transformer/_sdpa/_flex.py new file mode 100644 index 000000000..15883898b --- /dev/null +++ b/src/fairseq2/models/transformer/_sdpa/_flex.py @@ -0,0 +1,188 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from typing import Callable, cast, final + +import torch +from torch import Tensor +from torch.nn.attention.flex_attention import flex_attention, or_masks +from typing_extensions import override + +from fairseq2.error import NotSupportedError +from fairseq2.nn import BatchLayout + +# isort: split + +from fairseq2.models.transformer._attention_bias import ( + AttentionBias, + AttentionBiasCache, + CausalAttentionBias, + IdentityBias, +) +from fairseq2.models.transformer._sdpa._base import SDPA + + +def _causal_mask_fn(b: int, h: int, q_idx: int, kv_idx: int) -> bool: + """Standard causal attention mask.""" + return q_idx >= kv_idx + + +def _sliding_window_causal_mask_fn(window_size: int) -> Callable: + """Creates a sliding window causal mask function.""" + def mask_fn(b: int, h: int, q_idx: int, kv_idx: int) -> bool: + return (q_idx >= kv_idx) and (q_idx - kv_idx <= window_size) + return mask_fn + + +def _dropout_mask_fn(dropout_p: float, training: bool = True) -> Callable | None: + """Creates a dropout mask function.""" + if not training or dropout_p == 0.0: + return None + + def dropout_fn(b: int, h: int, q_idx: int, kv_idx: int) -> bool: + # Generate deterministic random number based on position + generator = torch.Generator() + generator.manual_seed(hash((b, h, q_idx, kv_idx)) % (2**32)) + rand_val = torch.rand(1, generator=generator).item() + + # Return True to keep, False to mask (opposite of dropout probability) + return rand_val >= dropout_p + + return dropout_fn + + +def _create_composed_mask(bias: AttentionBias, dropout_p: float = 0.0, training: bool = True) -> Callable | None: + """Creates a composed mask using or_mask for combining multiple mask functions.""" + + masks = [] + + # Add attention bias mask + if isinstance(bias, CausalAttentionBias): + attn_window_len = bias.attn_window_len + if attn_window_len is not None: + masks.append(_sliding_window_causal_mask_fn(attn_window_len)) + else: + masks.append(_causal_mask_fn) + elif not isinstance(bias, IdentityBias): + raise NotSupportedError(f"Unsupported bias type: {bias}") + + # Add dropout mask if needed + dropout_mask = _dropout_mask_fn(dropout_p, training) + if dropout_mask is not None: + masks.append(dropout_mask) + + # Compose masks using or_mask + if len(masks) == 0: + return None + elif len(masks) == 1: + return masks[0] + else: + # Use or_mask to combine multiple masks + return or_masks(*masks) + + +@final +class FlexSDPA(SDPA): + """Computes scaled dot-product attention using PyTorch's Flex Attention.""" + + bias: AttentionBias + dropout_p: float + + def __init__(self, bias: AttentionBias, *, dropout_p: float = 0.0) -> None: + super().__init__(bias) + + self.dropout_p = dropout_p + + @override + def forward( + self, + seqs: Tensor, + seqs_layout: BatchLayout, + keys: Tensor, + keys_layout: BatchLayout, + values: Tensor, + bias_cache: AttentionBiasCache, + *, + needs_weights: bool = False, + ) -> tuple[Tensor, Tensor | None]: + if seqs_layout.padded or keys_layout.padded: + raise NotSupportedError(f"`{FlexSDPA}` does not support padded batches.") + + if seqs_layout.packed ^ keys_layout.packed: + raise ValueError("`seqs_layout` and `keys_layout` must be both packed.") + + # Handle dropout + if not self.training: + dropout_p = 0.0 + else: + dropout_p = self.dropout_p + + # Create the composed mask using or_mask for clean composition + mask_fn = _create_composed_mask(self.bias, dropout_p, self.training) + + if seqs_layout.packed: + # For packed sequences, we need to handle variable length sequences + # This is more complex with Flex Attention and may require custom handling + # For now, we'll reshape and use the standard flex_attention + batch_size = len(seqs_layout.seq_begin_indices_pt) - 1 + max_seq_len = seqs_layout.max_seq_len + num_heads = seqs.size(1) + head_dim = seqs.size(-1) + + # Reshape from packed format to batch format + # This is a simplified approach - in practice you'd need more sophisticated + # handling of variable length sequences + seqs_batch = seqs.new_zeros(batch_size, num_heads, max_seq_len, head_dim) + keys_batch = keys.new_zeros(batch_size, num_heads, max_seq_len, head_dim) + values_batch = values.new_zeros(batch_size, num_heads, max_seq_len, head_dim) + + for i in range(batch_size): + start_idx = seqs_layout.seq_begin_indices_pt[i] + end_idx = seqs_layout.seq_begin_indices_pt[i + 1] + seq_len = end_idx - start_idx + + seqs_batch[i, :, :seq_len] = seqs[start_idx:end_idx].transpose(0, 1) + keys_batch[i, :, :seq_len] = keys[start_idx:end_idx].transpose(0, 1) + values_batch[i, :, :seq_len] = values[start_idx:end_idx].transpose(0, 1) + + # Apply flex attention with composed mask + attns_batch = flex_attention( + seqs_batch, keys_batch, values_batch, + score_mod=mask_fn, + enable_gqa=False + ) + + # Convert back to packed format + total_len = seqs.size(0) + attns = seqs.new_zeros(total_len, num_heads, head_dim) + + for i in range(batch_size): + start_idx = seqs_layout.seq_begin_indices_pt[i] + end_idx = seqs_layout.seq_begin_indices_pt[i + 1] + seq_len = end_idx - start_idx + + attns[start_idx:end_idx] = attns_batch[i, :, :seq_len].transpose(0, 1) + else: + # Standard batch format + attns = flex_attention( + seqs, + keys, + values, + score_mod=mask_fn, + enable_gqa=False + ) + + attns = cast(Tensor, attns) + + return attns, None + + def extra_repr(self) -> str: + """:meta private:""" + s = super().extra_repr() + + return f"{s}, dropout_p={self.dropout_p:G}" \ No newline at end of file diff --git a/src/fairseq2/recipes/common/_torch.py b/src/fairseq2/recipes/common/_torch.py index 9794d0553..254cf85f1 100644 --- a/src/fairseq2/recipes/common/_torch.py +++ b/src/fairseq2/recipes/common/_torch.py @@ -16,6 +16,7 @@ from fairseq2.models.transformer import ( Flash2SDPA, Flash3SDPA, + FlexSDPA, NaiveSDPA, TorchSDPA, set_default_sdpa_factory, @@ -137,6 +138,8 @@ def _set_default_sdpa_variant(name: str) -> None: _set_torch_sdpa_backend(backend) except (ImportError, AttributeError): log.warning("PyTorch SDPA kernel cannot be set to '{}'. Falling back to auto mode.", backend) # fmt: skip + case "flex": + set_default_sdpa_factory(FlexSDPA) case "flash2": set_default_sdpa_factory(Flash2SDPA) case "flash3": diff --git a/tests/unit/models/transformer/test_attention.py b/tests/unit/models/transformer/test_attention.py index f0d71dcb3..aa6355001 100644 --- a/tests/unit/models/transformer/test_attention.py +++ b/tests/unit/models/transformer/test_attention.py @@ -16,6 +16,7 @@ AttentionBias, AttentionBiasCache, CausalAttentionBias, + FlexSDPA, IdentityBias, NaiveSDPA, StandardMultiheadAttention, @@ -156,3 +157,81 @@ def test_variable_sized_attention(self, q_dim: int, k_dim: int | None) -> None: ) assert result.shape == seqs.shape + + +class TestFlexScaledDotProductAttention: + # fmt: off + @pytest.mark.parametrize("use_padding,use_bias,training", + [ + (False, False, True), + (True, True, True), + (False, True, True), + (True, False, True), + (False, False, False), + (False, True, False), + ], + ) + # fmt: on + def test_torch_flex_sdpa( + self, use_padding: bool, use_bias: bool, training: bool + ) -> None: + attn_bias: AttentionBias + + if use_bias: + attn_bias = CausalAttentionBias() + else: + attn_bias = IdentityBias() + + torch_sdpa = FlexSDPA(attn_bias) + naive_sdpa = NaiveSDPA(attn_bias) + + if training: + torch_sdpa.eval() + naive_sdpa.eval() + + kwargs = self._get_sdpa_args(use_padding) + + attn1, _ = torch_sdpa(**kwargs) + attn2, _ = naive_sdpa(**kwargs) + + assert_close(attn1, attn2) + + @staticmethod + def _get_sdpa_args(use_padding: bool) -> dict[str, object]: + batch_size = 2 + + num_heads = 4 + + source_seq_len = 3 + target_seq_len = 2 + + k_size = 2 + v_size = 3 + + def random_tensor(*args: int) -> Tensor: + return torch.randn(*args, device=device) + + q = random_tensor(batch_size, target_seq_len, num_heads, k_size) + k = random_tensor(batch_size, source_seq_len, num_heads, k_size) + v = random_tensor(batch_size, source_seq_len, num_heads, v_size) + + target_shape = (batch_size, target_seq_len) + source_shape = (batch_size, source_seq_len) + + if use_padding: + q_layout = BatchLayout(target_shape, seq_lens=None, device=device) + k_layout = BatchLayout(source_shape, seq_lens=[2, 3], device=device) + else: + q_layout = BatchLayout(target_shape, seq_lens=None, device=device) + k_layout = BatchLayout(source_shape, seq_lens=None, device=device) + + bias_cache = AttentionBiasCache() + + return { + "seqs": q, + "seqs_layout": q_layout, + "keys": k, + "keys_layout": k_layout, + "values": v, + "bias_cache": bias_cache, + } From e679f2017ce46dead7138a69ff63b39f5df94639 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Thu, 29 May 2025 01:13:46 +0000 Subject: [PATCH 02/15] update --- .../models/transformer/_sdpa/_flex.py | 85 ++++++++++++++----- .../unit/models/transformer/test_attention.py | 6 +- 2 files changed, 67 insertions(+), 24 deletions(-) diff --git a/src/fairseq2/models/transformer/_sdpa/_flex.py b/src/fairseq2/models/transformer/_sdpa/_flex.py index 15883898b..e74cfd805 100644 --- a/src/fairseq2/models/transformer/_sdpa/_flex.py +++ b/src/fairseq2/models/transformer/_sdpa/_flex.py @@ -10,7 +10,8 @@ import torch from torch import Tensor -from torch.nn.attention.flex_attention import flex_attention, or_masks +from torch.nn.attention.flex_attention import and_masks, create_block_mask, flex_attention, BlockMask + from typing_extensions import override from fairseq2.error import NotSupportedError @@ -36,31 +37,45 @@ def _sliding_window_causal_mask_fn(window_size: int) -> Callable: """Creates a sliding window causal mask function.""" def mask_fn(b: int, h: int, q_idx: int, kv_idx: int) -> bool: return (q_idx >= kv_idx) and (q_idx - kv_idx <= window_size) + return mask_fn +def _create_padding_mask_fn(seq_lens: Tensor, value_seq_lens: Tensor): + """Creates a padding mask function that masks out padding tokens.""" + def padding_mask_fn(b: int, h: int, q_idx: int, kv_idx: int) -> bool: + q_valid = q_idx < seq_lens[b].item() + kv_valid = kv_idx < value_seq_lens[b].item() + return q_valid and kv_valid + + return padding_mask_fn + + def _dropout_mask_fn(dropout_p: float, training: bool = True) -> Callable | None: """Creates a dropout mask function.""" if not training or dropout_p == 0.0: return None def dropout_fn(b: int, h: int, q_idx: int, kv_idx: int) -> bool: - # Generate deterministic random number based on position - generator = torch.Generator() - generator.manual_seed(hash((b, h, q_idx, kv_idx)) % (2**32)) + generator = torch.Generator() # TODO: How to set seed? rand_val = torch.rand(1, generator=generator).item() - + # Return True to keep, False to mask (opposite of dropout probability) return rand_val >= dropout_p return dropout_fn -def _create_composed_mask(bias: AttentionBias, dropout_p: float = 0.0, training: bool = True) -> Callable | None: +def _create_composed_mask( + bias: AttentionBias, + seqs_layout: BatchLayout, + keys_layout: BatchLayout, + dropout_p: float = 0.0, + training: bool = True, +) -> BlockMask | None: """Creates a composed mask using or_mask for combining multiple mask functions.""" - masks = [] - + # Add attention bias mask if isinstance(bias, CausalAttentionBias): attn_window_len = bias.attn_window_len @@ -70,20 +85,34 @@ def _create_composed_mask(bias: AttentionBias, dropout_p: float = 0.0, training: masks.append(_causal_mask_fn) elif not isinstance(bias, IdentityBias): raise NotSupportedError(f"Unsupported bias type: {bias}") - + + if seqs_layout.seq_lens_pt is not None and keys_layout.seq_lens_pt is not None: + # Add padding mask if sequence lengths are provided + masks.append(_create_padding_mask_fn(seqs_layout.seq_lens_pt, keys_layout.seq_lens_pt)) + # Add dropout mask if needed dropout_mask = _dropout_mask_fn(dropout_p, training) if dropout_mask is not None: masks.append(dropout_mask) - + # Compose masks using or_mask + mask_fn = None if len(masks) == 0: return None elif len(masks) == 1: - return masks[0] + mask_fn = masks[0] else: - # Use or_mask to combine multiple masks - return or_masks(*masks) + # Use and_mask to combine multiple mask functions + mask_fn = and_masks(*masks) + + block_mask = create_block_mask( + mask_fn, + B=seqs_layout.width, + H=None, + Q_LEN=seqs_layout.max_seq_len, + KV_LEN=keys_layout.max_seq_len, + ) + return block_mask @final @@ -94,8 +123,9 @@ class FlexSDPA(SDPA): dropout_p: float def __init__(self, bias: AttentionBias, *, dropout_p: float = 0.0) -> None: - super().__init__(bias) + super().__init__() + self.bias = bias self.dropout_p = dropout_p @override @@ -110,9 +140,6 @@ def forward( *, needs_weights: bool = False, ) -> tuple[Tensor, Tensor | None]: - if seqs_layout.padded or keys_layout.padded: - raise NotSupportedError(f"`{FlexSDPA}` does not support padded batches.") - if seqs_layout.packed ^ keys_layout.packed: raise ValueError("`seqs_layout` and `keys_layout` must be both packed.") @@ -123,7 +150,16 @@ def forward( dropout_p = self.dropout_p # Create the composed mask using or_mask for clean composition - mask_fn = _create_composed_mask(self.bias, dropout_p, self.training) + block_mask = _create_composed_mask( + self.bias, seqs_layout, keys_layout, dropout_p, self.training + ) + + import pytest + pytest.set_trace() + + seqs = seqs.transpose(1, 2) + keys = keys.transpose(1, 2) + values = values.transpose(1, 2) if seqs_layout.packed: # For packed sequences, we need to handle variable length sequences @@ -152,9 +188,11 @@ def forward( # Apply flex attention with composed mask attns_batch = flex_attention( - seqs_batch, keys_batch, values_batch, - score_mod=mask_fn, - enable_gqa=False + seqs_batch, + keys_batch, + values_batch, + block_mask=block_mask, + enable_gqa=False, ) # Convert back to packed format @@ -173,10 +211,11 @@ def forward( seqs, keys, values, - score_mod=mask_fn, - enable_gqa=False + block_mask=block_mask, + enable_gqa=False, ) + attns = attns.transpose(1, 2) attns = cast(Tensor, attns) return attns, None diff --git a/tests/unit/models/transformer/test_attention.py b/tests/unit/models/transformer/test_attention.py index aa6355001..49b36b6fe 100644 --- a/tests/unit/models/transformer/test_attention.py +++ b/tests/unit/models/transformer/test_attention.py @@ -25,6 +25,10 @@ from fairseq2.nn import BatchLayout from tests.common import assert_close, device +import os +os.environ["TORCH_LOGS"] = "+dynamo" +os.environ["TORCHDYNAMO_VERBOSE"] = "1" + class TestScaledDotProductAttention: # fmt: off @@ -206,7 +210,7 @@ def _get_sdpa_args(use_padding: bool) -> dict[str, object]: target_seq_len = 2 k_size = 2 - v_size = 3 + v_size = 4 def random_tensor(*args: int) -> Tensor: return torch.randn(*args, device=device) From ce82414ebec560fee2627f8c815a377f0d81ff8c Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Fri, 30 May 2025 03:26:29 +0000 Subject: [PATCH 03/15] adding packing support, tests --- .../models/transformer/_sdpa/_flex.py | 196 +++++++++++------- .../models/transformer/_sdpa/_naive.py | 2 + src/fairseq2/recipes/config.py | 1 + .../unit/models/transformer/test_attention.py | 138 +++++++----- 4 files changed, 207 insertions(+), 130 deletions(-) diff --git a/src/fairseq2/models/transformer/_sdpa/_flex.py b/src/fairseq2/models/transformer/_sdpa/_flex.py index e74cfd805..29ed2f60d 100644 --- a/src/fairseq2/models/transformer/_sdpa/_flex.py +++ b/src/fairseq2/models/transformer/_sdpa/_flex.py @@ -14,6 +14,7 @@ from typing_extensions import override +from fairseq2.device import Device from fairseq2.error import NotSupportedError from fairseq2.nn import BatchLayout @@ -41,12 +42,53 @@ def mask_fn(b: int, h: int, q_idx: int, kv_idx: int) -> bool: return mask_fn +def _offsets_to_doc_ids_tensor(offsets): + """Convert offsets to document IDs for packed sequences.""" + device = offsets.device + counts = offsets[1:] - offsets[:-1] + return torch.repeat_interleave( + torch.arange(len(counts), device=device, dtype=torch.int32), counts + ) + + +def _create_packed_mask_fn( + seq_begin_indices: Tensor, + seq_lens: Tensor, + keys_begin_indices: Tensor, + keys_seq_lens: Tensor, + base_mask_fn: Callable | None = None +): + """Creates a mask function for packed sequences using document-based masking.""" + # Create document IDs for queries and keys + query_doc_ids = _offsets_to_doc_ids_tensor(seq_begin_indices) + key_doc_ids = _offsets_to_doc_ids_tensor(keys_begin_indices) + + def packed_mask_fn(b: int, h: int, q_idx: int, kv_idx: int) -> bool: + # Check if query and key belong to the same document + same_doc = query_doc_ids[q_idx] == key_doc_ids[kv_idx] + + # Convert global indices to logical positions within documents + q_doc_id = query_doc_ids[q_idx] + kv_doc_id = key_doc_ids[kv_idx] + q_logical = q_idx - seq_begin_indices[q_doc_id] + kv_logical = kv_idx - keys_begin_indices[kv_doc_id] + + # Apply base mask (e.g., causal) to logical positions + if base_mask_fn is not None: + inner_mask = base_mask_fn(b, h, q_logical, kv_logical) + return same_doc & inner_mask + else: + return same_doc + + return packed_mask_fn + + def _create_padding_mask_fn(seq_lens: Tensor, value_seq_lens: Tensor): """Creates a padding mask function that masks out padding tokens.""" def padding_mask_fn(b: int, h: int, q_idx: int, kv_idx: int) -> bool: - q_valid = q_idx < seq_lens[b].item() - kv_valid = kv_idx < value_seq_lens[b].item() - return q_valid and kv_valid + q_valid = q_idx < seq_lens[b] + kv_valid = kv_idx < value_seq_lens[b] + return q_valid & kv_valid return padding_mask_fn @@ -70,47 +112,84 @@ def _create_composed_mask( bias: AttentionBias, seqs_layout: BatchLayout, keys_layout: BatchLayout, + device: Device, dropout_p: float = 0.0, training: bool = True, + packed: bool = False, ) -> BlockMask | None: - """Creates a composed mask using or_mask for combining multiple mask functions.""" + """Creates a composed mask using and_mask for combining multiple mask functions.""" masks = [] - # Add attention bias mask - if isinstance(bias, CausalAttentionBias): - attn_window_len = bias.attn_window_len - if attn_window_len is not None: - masks.append(_sliding_window_causal_mask_fn(attn_window_len)) - else: - masks.append(_causal_mask_fn) - elif not isinstance(bias, IdentityBias): - raise NotSupportedError(f"Unsupported bias type: {bias}") - - if seqs_layout.seq_lens_pt is not None and keys_layout.seq_lens_pt is not None: - # Add padding mask if sequence lengths are provided - masks.append(_create_padding_mask_fn(seqs_layout.seq_lens_pt, keys_layout.seq_lens_pt)) + if packed: + # For packed sequences, create the base mask function first + base_mask_fn = None + + # Add attention bias mask as base mask + if isinstance(bias, CausalAttentionBias): + attn_window_len = bias.attn_window_len + if attn_window_len is not None: + base_mask_fn = _sliding_window_causal_mask_fn(attn_window_len) + else: + base_mask_fn = _causal_mask_fn + elif not isinstance(bias, IdentityBias): + raise NotSupportedError(f"Unsupported bias type: {bias}") + + # Create the packed sequence mask that incorporates the base mask + packed_mask = _create_packed_mask_fn( + seqs_layout.seq_begin_indices_pt, + seqs_layout.seq_lens_pt, + keys_layout.seq_begin_indices_pt, + keys_layout.seq_lens_pt, + base_mask_fn + ) + masks.append(packed_mask) + + else: + # Standard batch format - handle bias and padding separately + if isinstance(bias, CausalAttentionBias): + attn_window_len = bias.attn_window_len + if attn_window_len is not None: + masks.append(_sliding_window_causal_mask_fn(attn_window_len)) + else: + masks.append(_causal_mask_fn) + elif not isinstance(bias, IdentityBias): + raise NotSupportedError(f"Unsupported bias type: {bias}") + + if seqs_layout.seq_lens_pt is not None and keys_layout.seq_lens_pt is not None: + # Add padding mask if sequence lengths are provided + masks.append(_create_padding_mask_fn(seqs_layout.seq_lens_pt, keys_layout.seq_lens_pt)) # Add dropout mask if needed dropout_mask = _dropout_mask_fn(dropout_p, training) if dropout_mask is not None: masks.append(dropout_mask) - # Compose masks using or_mask + # Compose masks using and_mask mask_fn = None if len(masks) == 0: return None elif len(masks) == 1: mask_fn = masks[0] else: - # Use and_mask to combine multiple mask functions mask_fn = and_masks(*masks) + # For packed sequences, use the total sequence length + if packed: + total_seq_len = seqs_layout.seq_begin_indices_pt[-1].item() + total_keys_len = keys_layout.seq_begin_indices_pt[-1].item() + batch_size = 1 # Packed format treats everything as one big batch + else: + total_seq_len = seqs_layout.max_seq_len + total_keys_len = keys_layout.max_seq_len + batch_size = seqs_layout.width + block_mask = create_block_mask( mask_fn, - B=seqs_layout.width, + B=batch_size, H=None, - Q_LEN=seqs_layout.max_seq_len, - KV_LEN=keys_layout.max_seq_len, + Q_LEN=total_seq_len, + KV_LEN=total_keys_len, + device=device, ) return block_mask @@ -149,71 +228,28 @@ def forward( else: dropout_p = self.dropout_p - # Create the composed mask using or_mask for clean composition + # Create the composed block mask using and_mask block_mask = _create_composed_mask( - self.bias, seqs_layout, keys_layout, dropout_p, self.training + self.bias, + seqs_layout, + keys_layout, + seqs.device, + dropout_p, + self.training, + packed=seqs_layout.packed, ) - import pytest - pytest.set_trace() - seqs = seqs.transpose(1, 2) keys = keys.transpose(1, 2) values = values.transpose(1, 2) - if seqs_layout.packed: - # For packed sequences, we need to handle variable length sequences - # This is more complex with Flex Attention and may require custom handling - # For now, we'll reshape and use the standard flex_attention - batch_size = len(seqs_layout.seq_begin_indices_pt) - 1 - max_seq_len = seqs_layout.max_seq_len - num_heads = seqs.size(1) - head_dim = seqs.size(-1) - - # Reshape from packed format to batch format - # This is a simplified approach - in practice you'd need more sophisticated - # handling of variable length sequences - seqs_batch = seqs.new_zeros(batch_size, num_heads, max_seq_len, head_dim) - keys_batch = keys.new_zeros(batch_size, num_heads, max_seq_len, head_dim) - values_batch = values.new_zeros(batch_size, num_heads, max_seq_len, head_dim) - - for i in range(batch_size): - start_idx = seqs_layout.seq_begin_indices_pt[i] - end_idx = seqs_layout.seq_begin_indices_pt[i + 1] - seq_len = end_idx - start_idx - - seqs_batch[i, :, :seq_len] = seqs[start_idx:end_idx].transpose(0, 1) - keys_batch[i, :, :seq_len] = keys[start_idx:end_idx].transpose(0, 1) - values_batch[i, :, :seq_len] = values[start_idx:end_idx].transpose(0, 1) - - # Apply flex attention with composed mask - attns_batch = flex_attention( - seqs_batch, - keys_batch, - values_batch, - block_mask=block_mask, - enable_gqa=False, - ) - - # Convert back to packed format - total_len = seqs.size(0) - attns = seqs.new_zeros(total_len, num_heads, head_dim) - - for i in range(batch_size): - start_idx = seqs_layout.seq_begin_indices_pt[i] - end_idx = seqs_layout.seq_begin_indices_pt[i + 1] - seq_len = end_idx - start_idx - - attns[start_idx:end_idx] = attns_batch[i, :, :seq_len].transpose(0, 1) - else: - # Standard batch format - attns = flex_attention( - seqs, - keys, - values, - block_mask=block_mask, - enable_gqa=False, - ) + attns = flex_attention( + seqs, + keys, + values, + block_mask=block_mask, + enable_gqa=False, + ) attns = attns.transpose(1, 2) attns = cast(Tensor, attns) diff --git a/src/fairseq2/models/transformer/_sdpa/_naive.py b/src/fairseq2/models/transformer/_sdpa/_naive.py index 1d2373ae5..50f68c264 100644 --- a/src/fairseq2/models/transformer/_sdpa/_naive.py +++ b/src/fairseq2/models/transformer/_sdpa/_naive.py @@ -63,6 +63,8 @@ def forward( q, k, v = seqs, keys, values + import pytest; pytest.set_trace() + # (N, S, H, K) -> (N, H, S, K) q = q.transpose(-2, -3) diff --git a/src/fairseq2/recipes/config.py b/src/fairseq2/recipes/config.py index 731b2f21f..63fc9b17a 100644 --- a/src/fairseq2/recipes/config.py +++ b/src/fairseq2/recipes/config.py @@ -465,6 +465,7 @@ class CommonSection: "torch_flash", "flash2", "flash3", + "flex", "naive", ] diff --git a/tests/unit/models/transformer/test_attention.py b/tests/unit/models/transformer/test_attention.py index 49b36b6fe..d088d26cf 100644 --- a/tests/unit/models/transformer/test_attention.py +++ b/tests/unit/models/transformer/test_attention.py @@ -25,26 +25,26 @@ from fairseq2.nn import BatchLayout from tests.common import assert_close, device -import os -os.environ["TORCH_LOGS"] = "+dynamo" -os.environ["TORCHDYNAMO_VERBOSE"] = "1" - class TestScaledDotProductAttention: # fmt: off - @pytest.mark.parametrize("use_padding,use_bias,training", + @pytest.mark.parametrize("use_padding,use_packing,use_bias,training", [ - (False, False, True), - (True, True, True), - (False, True, True), - (True, False, True), - (False, False, False), - (False, True, False), + (False, False, False, True), + (True, False, True, True), + (False, False, True, True), + (True, False, False, True), + (False, True, False, True), + (False, False, False, False), + (False, False, True, False), ], ) # fmt: on def test_torch_sdpa( - self, use_padding: bool, use_bias: bool, training: bool + self, use_padding: bool, + use_packing: bool, + use_bias: bool, + training: bool, ) -> None: attn_bias: AttentionBias @@ -60,7 +60,7 @@ def test_torch_sdpa( torch_sdpa.eval() naive_sdpa.eval() - kwargs = self._get_sdpa_args(use_padding) + kwargs = self._get_sdpa_args(use_padding, use_packing) attn1, _ = torch_sdpa(**kwargs) attn2, _ = naive_sdpa(**kwargs) @@ -68,7 +68,7 @@ def test_torch_sdpa( assert_close(attn1, attn2) @staticmethod - def _get_sdpa_args(use_padding: bool) -> dict[str, Any]: + def _get_sdpa_args(use_padding: bool, use_packing: bool) -> dict[str, Any]: batch_size = 2 num_heads = 4 @@ -82,19 +82,35 @@ def _get_sdpa_args(use_padding: bool) -> dict[str, Any]: def random_tensor(*args: int) -> Tensor: return torch.randn(*args, device=device) - q = random_tensor(batch_size, target_seq_len, num_heads, k_size) - k = random_tensor(batch_size, source_seq_len, num_heads, k_size) - v = random_tensor(batch_size, source_seq_len, num_heads, v_size) + if use_packing: + # For packing, we need 1D tensors with total sequence length + total_target_len = sum([2, 3]) # seq_lens for target + total_source_len = sum([2, 3]) # seq_lens for source + + q = random_tensor(total_target_len, num_heads, k_size) + k = random_tensor(total_source_len, num_heads, k_size) + v = random_tensor(total_source_len, num_heads, v_size) + + target_shape = (total_target_len,) + source_shape = (total_source_len,) + + q_layout = BatchLayout(target_shape, seq_lens=[2, 3], packed=True, device=device) + k_layout = BatchLayout(source_shape, seq_lens=[2, 3], packed=True, device=device) + else: + # For non-packing cases (regular 2D tensors) + q = random_tensor(batch_size, target_seq_len, num_heads, k_size) + k = random_tensor(batch_size, source_seq_len, num_heads, k_size) + v = random_tensor(batch_size, source_seq_len, num_heads, v_size) - target_shape = (batch_size, target_seq_len) - source_shape = (batch_size, source_seq_len) + target_shape = (batch_size, target_seq_len) + source_shape = (batch_size, source_seq_len) - if use_padding: - q_layout = BatchLayout(target_shape, seq_lens=None, device=device) - k_layout = BatchLayout(source_shape, seq_lens=[2, 3], device=device) - else: - q_layout = BatchLayout(target_shape, seq_lens=None, device=device) - k_layout = BatchLayout(source_shape, seq_lens=None, device=device) + if use_padding: + q_layout = BatchLayout(target_shape, seq_lens=None, device=device) + k_layout = BatchLayout(source_shape, seq_lens=[2, 3], device=device) + else: + q_layout = BatchLayout(target_shape, seq_lens=None, device=device) + k_layout = BatchLayout(source_shape, seq_lens=None, device=device) bias_cache = AttentionBiasCache() @@ -165,19 +181,24 @@ def test_variable_sized_attention(self, q_dim: int, k_dim: int | None) -> None: class TestFlexScaledDotProductAttention: # fmt: off - @pytest.mark.parametrize("use_padding,use_bias,training", + @pytest.mark.parametrize("use_padding,use_packing,use_bias,training", [ - (False, False, True), - (True, True, True), - (False, True, True), - (True, False, True), - (False, False, False), - (False, True, False), + (False, False, False, True), + (True, False, True, True), + (False, False, True, True), + (True, False, False, True), + (False, True, False, True), + (False, True, True, True), + (False, False, False, False), + (False, False, True, False), ], ) # fmt: on - def test_torch_flex_sdpa( - self, use_padding: bool, use_bias: bool, training: bool + def test_flex_sdpa( + self, use_padding: bool, + use_packing: bool, + use_bias: bool, + training: bool, ) -> None: attn_bias: AttentionBias @@ -186,22 +207,22 @@ def test_torch_flex_sdpa( else: attn_bias = IdentityBias() - torch_sdpa = FlexSDPA(attn_bias) + flex_sdpa = FlexSDPA(attn_bias) naive_sdpa = NaiveSDPA(attn_bias) if training: - torch_sdpa.eval() + flex_sdpa.eval() naive_sdpa.eval() - kwargs = self._get_sdpa_args(use_padding) + kwargs = self._get_sdpa_args(use_padding, use_packing) - attn1, _ = torch_sdpa(**kwargs) + attn1, _ = flex_sdpa(**kwargs) attn2, _ = naive_sdpa(**kwargs) assert_close(attn1, attn2) @staticmethod - def _get_sdpa_args(use_padding: bool) -> dict[str, object]: + def _get_sdpa_args(use_padding: bool, use_packing: bool) -> dict[str, object]: batch_size = 2 num_heads = 4 @@ -215,19 +236,35 @@ def _get_sdpa_args(use_padding: bool) -> dict[str, object]: def random_tensor(*args: int) -> Tensor: return torch.randn(*args, device=device) - q = random_tensor(batch_size, target_seq_len, num_heads, k_size) - k = random_tensor(batch_size, source_seq_len, num_heads, k_size) - v = random_tensor(batch_size, source_seq_len, num_heads, v_size) + if use_packing: + # For packing, we need 1D tensors with total sequence length + total_target_len = sum([2, 3]) # seq_lens for target + total_source_len = sum([2, 3]) # seq_lens for source + + q = random_tensor(1, total_target_len, num_heads, k_size) + k = random_tensor(1, total_source_len, num_heads, k_size) + v = random_tensor(1, total_source_len, num_heads, v_size) + + target_shape = (total_target_len,) + source_shape = (total_source_len,) + + q_layout = BatchLayout(target_shape, seq_lens=[2, 3], packed=True, device=device) + k_layout = BatchLayout(source_shape, seq_lens=[2, 3], packed=True, device=device) + else: + # For non-packing cases (regular 2D tensors) + q = random_tensor(batch_size, target_seq_len, num_heads, k_size) + k = random_tensor(batch_size, source_seq_len, num_heads, k_size) + v = random_tensor(batch_size, source_seq_len, num_heads, v_size) - target_shape = (batch_size, target_seq_len) - source_shape = (batch_size, source_seq_len) + target_shape = (batch_size, target_seq_len) + source_shape = (batch_size, source_seq_len) - if use_padding: - q_layout = BatchLayout(target_shape, seq_lens=None, device=device) - k_layout = BatchLayout(source_shape, seq_lens=[2, 3], device=device) - else: - q_layout = BatchLayout(target_shape, seq_lens=None, device=device) - k_layout = BatchLayout(source_shape, seq_lens=None, device=device) + if use_padding: + q_layout = BatchLayout(target_shape, seq_lens=None, device=device) + k_layout = BatchLayout(source_shape, seq_lens=[2, 3], device=device) + else: + q_layout = BatchLayout(target_shape, seq_lens=None, device=device) + k_layout = BatchLayout(source_shape, seq_lens=None, device=device) bias_cache = AttentionBiasCache() @@ -239,3 +276,4 @@ def random_tensor(*args: int) -> Tensor: "values": v, "bias_cache": bias_cache, } + From 9f338745c1d448f1544bc3365e981511973935a9 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Fri, 30 May 2025 04:02:41 +0000 Subject: [PATCH 04/15] lint --- .../models/transformer/_sdpa/_flex.py | 90 ++++++++++--------- .../models/transformer/_sdpa/_naive.py | 2 - .../unit/models/transformer/test_attention.py | 35 +++++--- 3 files changed, 70 insertions(+), 57 deletions(-) diff --git a/src/fairseq2/models/transformer/_sdpa/_flex.py b/src/fairseq2/models/transformer/_sdpa/_flex.py index 29ed2f60d..acbe07f1d 100644 --- a/src/fairseq2/models/transformer/_sdpa/_flex.py +++ b/src/fairseq2/models/transformer/_sdpa/_flex.py @@ -6,12 +6,16 @@ from __future__ import annotations -from typing import Callable, cast, final +from typing import Callable, TypeAlias, cast, final import torch from torch import Tensor -from torch.nn.attention.flex_attention import and_masks, create_block_mask, flex_attention, BlockMask - +from torch.nn.attention.flex_attention import ( + BlockMask, + and_masks, + create_block_mask, + flex_attention, +) from typing_extensions import override from fairseq2.device import Device @@ -28,21 +32,24 @@ ) from fairseq2.models.transformer._sdpa._base import SDPA +MaskFunction: TypeAlias = Callable[[Tensor, Tensor, Tensor, Tensor], Tensor] -def _causal_mask_fn(b: int, h: int, q_idx: int, kv_idx: int) -> bool: + +def _causal_mask_fn(b: Tensor, h: Tensor, q_idx: Tensor, kv_idx: Tensor) -> Tensor: """Standard causal attention mask.""" return q_idx >= kv_idx -def _sliding_window_causal_mask_fn(window_size: int) -> Callable: +def _sliding_window_causal_mask_fn(window_size: int) -> MaskFunction: """Creates a sliding window causal mask function.""" - def mask_fn(b: int, h: int, q_idx: int, kv_idx: int) -> bool: + + def mask_fn(b: Tensor, h: Tensor, q_idx: Tensor, kv_idx: Tensor) -> Tensor: return (q_idx >= kv_idx) and (q_idx - kv_idx <= window_size) return mask_fn -def _offsets_to_doc_ids_tensor(offsets): +def _offsets_to_doc_ids_tensor(offsets: Tensor) -> Tensor: """Convert offsets to document IDs for packed sequences.""" device = offsets.device counts = offsets[1:] - offsets[:-1] @@ -53,39 +60,38 @@ def _offsets_to_doc_ids_tensor(offsets): def _create_packed_mask_fn( seq_begin_indices: Tensor, - seq_lens: Tensor, keys_begin_indices: Tensor, - keys_seq_lens: Tensor, - base_mask_fn: Callable | None = None -): + base_mask_fn: MaskFunction | None = None, +) -> MaskFunction: """Creates a mask function for packed sequences using document-based masking.""" # Create document IDs for queries and keys query_doc_ids = _offsets_to_doc_ids_tensor(seq_begin_indices) key_doc_ids = _offsets_to_doc_ids_tensor(keys_begin_indices) - - def packed_mask_fn(b: int, h: int, q_idx: int, kv_idx: int) -> bool: + + def packed_mask_fn(b: Tensor, h: Tensor, q_idx: Tensor, kv_idx: Tensor) -> Tensor: # Check if query and key belong to the same document same_doc = query_doc_ids[q_idx] == key_doc_ids[kv_idx] - + # Convert global indices to logical positions within documents q_doc_id = query_doc_ids[q_idx] kv_doc_id = key_doc_ids[kv_idx] q_logical = q_idx - seq_begin_indices[q_doc_id] kv_logical = kv_idx - keys_begin_indices[kv_doc_id] - + # Apply base mask (e.g., causal) to logical positions if base_mask_fn is not None: inner_mask = base_mask_fn(b, h, q_logical, kv_logical) return same_doc & inner_mask else: return same_doc - + return packed_mask_fn -def _create_padding_mask_fn(seq_lens: Tensor, value_seq_lens: Tensor): +def _create_padding_mask_fn(seq_lens: Tensor, value_seq_lens: Tensor) -> MaskFunction: """Creates a padding mask function that masks out padding tokens.""" - def padding_mask_fn(b: int, h: int, q_idx: int, kv_idx: int) -> bool: + + def padding_mask_fn(b: Tensor, h: Tensor, q_idx: Tensor, kv_idx: Tensor) -> Tensor: q_valid = q_idx < seq_lens[b] kv_valid = kv_idx < value_seq_lens[b] return q_valid & kv_valid @@ -93,18 +99,18 @@ def padding_mask_fn(b: int, h: int, q_idx: int, kv_idx: int) -> bool: return padding_mask_fn -def _dropout_mask_fn(dropout_p: float, training: bool = True) -> Callable | None: +def _dropout_mask_fn(dropout_p: float, training: bool = True) -> MaskFunction | None: """Creates a dropout mask function.""" if not training or dropout_p == 0.0: return None - - def dropout_fn(b: int, h: int, q_idx: int, kv_idx: int) -> bool: + + def dropout_fn(b: Tensor, h: Tensor, q_idx: Tensor, kv_idx: Tensor) -> Tensor: generator = torch.Generator() # TODO: How to set seed? - rand_val = torch.rand(1, generator=generator).item() + rand_val = torch.rand(1, generator=generator) # Return True to keep, False to mask (opposite of dropout probability) return rand_val >= dropout_p - + return dropout_fn @@ -123,7 +129,7 @@ def _create_composed_mask( if packed: # For packed sequences, create the base mask function first base_mask_fn = None - + # Add attention bias mask as base mask if isinstance(bias, CausalAttentionBias): attn_window_len = bias.attn_window_len @@ -133,17 +139,14 @@ def _create_composed_mask( base_mask_fn = _causal_mask_fn elif not isinstance(bias, IdentityBias): raise NotSupportedError(f"Unsupported bias type: {bias}") - + # Create the packed sequence mask that incorporates the base mask packed_mask = _create_packed_mask_fn( seqs_layout.seq_begin_indices_pt, - seqs_layout.seq_lens_pt, keys_layout.seq_begin_indices_pt, - keys_layout.seq_lens_pt, - base_mask_fn + base_mask_fn, ) masks.append(packed_mask) - else: # Standard batch format - handle bias and padding separately if isinstance(bias, CausalAttentionBias): @@ -155,9 +158,10 @@ def _create_composed_mask( elif not isinstance(bias, IdentityBias): raise NotSupportedError(f"Unsupported bias type: {bias}") - if seqs_layout.seq_lens_pt is not None and keys_layout.seq_lens_pt is not None: - # Add padding mask if sequence lengths are provided - masks.append(_create_padding_mask_fn(seqs_layout.seq_lens_pt, keys_layout.seq_lens_pt)) + # Add padding mask + masks.append( + _create_padding_mask_fn(seqs_layout.seq_lens_pt, keys_layout.seq_lens_pt) + ) # Add dropout mask if needed dropout_mask = _dropout_mask_fn(dropout_p, training) @@ -175,8 +179,8 @@ def _create_composed_mask( # For packed sequences, use the total sequence length if packed: - total_seq_len = seqs_layout.seq_begin_indices_pt[-1].item() - total_keys_len = keys_layout.seq_begin_indices_pt[-1].item() + total_seq_len = int(seqs_layout.seq_begin_indices_pt[-1].item()) + total_keys_len = int(keys_layout.seq_begin_indices_pt[-1].item()) batch_size = 1 # Packed format treats everything as one big batch else: total_seq_len = seqs_layout.max_seq_len @@ -189,7 +193,7 @@ def _create_composed_mask( H=None, Q_LEN=total_seq_len, KV_LEN=total_keys_len, - device=device, + device=str(device), ) return block_mask @@ -230,11 +234,11 @@ def forward( # Create the composed block mask using and_mask block_mask = _create_composed_mask( - self.bias, - seqs_layout, - keys_layout, - seqs.device, - dropout_p, + self.bias, + seqs_layout, + keys_layout, + seqs.device, + dropout_p, self.training, packed=seqs_layout.packed, ) @@ -251,8 +255,10 @@ def forward( enable_gqa=False, ) + if isinstance(attns, tuple): + attns, _ = attns + attns = attns.transpose(1, 2) - attns = cast(Tensor, attns) return attns, None @@ -260,4 +266,4 @@ def extra_repr(self) -> str: """:meta private:""" s = super().extra_repr() - return f"{s}, dropout_p={self.dropout_p:G}" \ No newline at end of file + return f"{s}, dropout_p={self.dropout_p:G}" diff --git a/src/fairseq2/models/transformer/_sdpa/_naive.py b/src/fairseq2/models/transformer/_sdpa/_naive.py index 50f68c264..1d2373ae5 100644 --- a/src/fairseq2/models/transformer/_sdpa/_naive.py +++ b/src/fairseq2/models/transformer/_sdpa/_naive.py @@ -63,8 +63,6 @@ def forward( q, k, v = seqs, keys, values - import pytest; pytest.set_trace() - # (N, S, H, K) -> (N, H, S, K) q = q.transpose(-2, -3) diff --git a/tests/unit/models/transformer/test_attention.py b/tests/unit/models/transformer/test_attention.py index d088d26cf..9f6b8f83e 100644 --- a/tests/unit/models/transformer/test_attention.py +++ b/tests/unit/models/transformer/test_attention.py @@ -41,7 +41,8 @@ class TestScaledDotProductAttention: ) # fmt: on def test_torch_sdpa( - self, use_padding: bool, + self, + use_padding: bool, use_packing: bool, use_bias: bool, training: bool, @@ -86,16 +87,20 @@ def random_tensor(*args: int) -> Tensor: # For packing, we need 1D tensors with total sequence length total_target_len = sum([2, 3]) # seq_lens for target total_source_len = sum([2, 3]) # seq_lens for source - + q = random_tensor(total_target_len, num_heads, k_size) k = random_tensor(total_source_len, num_heads, k_size) v = random_tensor(total_source_len, num_heads, v_size) - + target_shape = (total_target_len,) source_shape = (total_source_len,) - - q_layout = BatchLayout(target_shape, seq_lens=[2, 3], packed=True, device=device) - k_layout = BatchLayout(source_shape, seq_lens=[2, 3], packed=True, device=device) + + q_layout = BatchLayout( + target_shape, seq_lens=[2, 3], packed=True, device=device + ) + k_layout = BatchLayout( + source_shape, seq_lens=[2, 3], packed=True, device=device + ) else: # For non-packing cases (regular 2D tensors) q = random_tensor(batch_size, target_seq_len, num_heads, k_size) @@ -195,7 +200,8 @@ class TestFlexScaledDotProductAttention: ) # fmt: on def test_flex_sdpa( - self, use_padding: bool, + self, + use_padding: bool, use_packing: bool, use_bias: bool, training: bool, @@ -240,16 +246,20 @@ def random_tensor(*args: int) -> Tensor: # For packing, we need 1D tensors with total sequence length total_target_len = sum([2, 3]) # seq_lens for target total_source_len = sum([2, 3]) # seq_lens for source - + q = random_tensor(1, total_target_len, num_heads, k_size) k = random_tensor(1, total_source_len, num_heads, k_size) v = random_tensor(1, total_source_len, num_heads, v_size) - + target_shape = (total_target_len,) source_shape = (total_source_len,) - - q_layout = BatchLayout(target_shape, seq_lens=[2, 3], packed=True, device=device) - k_layout = BatchLayout(source_shape, seq_lens=[2, 3], packed=True, device=device) + + q_layout = BatchLayout( + target_shape, seq_lens=[2, 3], packed=True, device=device + ) + k_layout = BatchLayout( + source_shape, seq_lens=[2, 3], packed=True, device=device + ) else: # For non-packing cases (regular 2D tensors) q = random_tensor(batch_size, target_seq_len, num_heads, k_size) @@ -276,4 +286,3 @@ def random_tensor(*args: int) -> Tensor: "values": v, "bias_cache": bias_cache, } - From c06d06d67c00eee0ae3340fe9bc0038687c6733c Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Fri, 30 May 2025 04:28:24 +0000 Subject: [PATCH 05/15] small changes --- src/fairseq2/models/transformer/_sdpa/_flex.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/fairseq2/models/transformer/_sdpa/_flex.py b/src/fairseq2/models/transformer/_sdpa/_flex.py index acbe07f1d..90963c6cd 100644 --- a/src/fairseq2/models/transformer/_sdpa/_flex.py +++ b/src/fairseq2/models/transformer/_sdpa/_flex.py @@ -6,7 +6,7 @@ from __future__ import annotations -from typing import Callable, TypeAlias, cast, final +from typing import Callable, TypeAlias, final import torch from torch import Tensor @@ -44,7 +44,7 @@ def _sliding_window_causal_mask_fn(window_size: int) -> MaskFunction: """Creates a sliding window causal mask function.""" def mask_fn(b: Tensor, h: Tensor, q_idx: Tensor, kv_idx: Tensor) -> Tensor: - return (q_idx >= kv_idx) and (q_idx - kv_idx <= window_size) + return (q_idx >= kv_idx) & (q_idx - kv_idx <= window_size) return mask_fn From 2be8611327d5777ada92ecd73b541b854c9e6f65 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Sat, 14 Jun 2025 21:01:08 +0000 Subject: [PATCH 06/15] all tests passing --- .../models/transformer/_sdpa/_flex.py | 136 ++++--- .../models/transformer/sdpa/diagnostics.py | 353 ++++++++++++++++++ .../transformer/sdpa/test_flex_mask_fns.py | 165 ++++++++ .../unit/models/transformer/test_attention.py | 40 +- 4 files changed, 620 insertions(+), 74 deletions(-) create mode 100644 tests/unit/models/transformer/sdpa/diagnostics.py create mode 100644 tests/unit/models/transformer/sdpa/test_flex_mask_fns.py diff --git a/src/fairseq2/models/transformer/_sdpa/_flex.py b/src/fairseq2/models/transformer/_sdpa/_flex.py index 90963c6cd..e75ead298 100644 --- a/src/fairseq2/models/transformer/_sdpa/_flex.py +++ b/src/fairseq2/models/transformer/_sdpa/_flex.py @@ -35,16 +35,47 @@ MaskFunction: TypeAlias = Callable[[Tensor, Tensor, Tensor, Tensor], Tensor] -def _causal_mask_fn(b: Tensor, h: Tensor, q_idx: Tensor, kv_idx: Tensor) -> Tensor: - """Standard causal attention mask.""" - return q_idx >= kv_idx +def _causal_mask_fn( + q_lens: Tensor, kv_lens: Tensor +) -> MaskFunction: + """Creates a causal mask function.""" + + def mask_fn(b: Tensor, h: Tensor, q_idx: Tensor, kv_idx: Tensor) -> Tensor: + # Get sequence lengths for this batch + q_len = q_lens[b] + kv_len = kv_lens[b] + + # Calculate diagonal offset + d = kv_len - q_len + + return q_idx >= kv_idx - d + + return mask_fn -def _sliding_window_causal_mask_fn(window_size: int) -> MaskFunction: - """Creates a sliding window causal mask function.""" +def _sliding_window_causal_mask_fn( + window_size: int, + q_lens: Tensor, + kv_lens: Tensor, +) -> MaskFunction: + """Creates a sliding window causal mask functions.""" def mask_fn(b: Tensor, h: Tensor, q_idx: Tensor, kv_idx: Tensor) -> Tensor: - return (q_idx >= kv_idx) & (q_idx - kv_idx <= window_size) + # Get sequence lengths for this batch + q_len = q_lens[b] + kv_len = kv_lens[b] + + # Calculate diagonal offset + d = kv_len - q_len + + # For window_size=1, only allow the exact diagonal position + if window_size == 1: + return q_idx == kv_idx - d + else: + # For larger windows, use the range logic + causal_mask = q_idx >= kv_idx - d + window_mask = q_idx >= kv_idx - d - window_size + 1 + return causal_mask & window_mask return mask_fn @@ -88,42 +119,26 @@ def packed_mask_fn(b: Tensor, h: Tensor, q_idx: Tensor, kv_idx: Tensor) -> Tenso return packed_mask_fn -def _create_padding_mask_fn(seq_lens: Tensor, value_seq_lens: Tensor) -> MaskFunction: +def _create_padding_mask_fn(q_lens: Tensor, kv_lens: Tensor) -> MaskFunction: """Creates a padding mask function that masks out padding tokens.""" def padding_mask_fn(b: Tensor, h: Tensor, q_idx: Tensor, kv_idx: Tensor) -> Tensor: - q_valid = q_idx < seq_lens[b] - kv_valid = kv_idx < value_seq_lens[b] + q_valid = q_idx < q_lens[b] + kv_valid = kv_idx < kv_lens[b] return q_valid & kv_valid return padding_mask_fn -def _dropout_mask_fn(dropout_p: float, training: bool = True) -> MaskFunction | None: - """Creates a dropout mask function.""" - if not training or dropout_p == 0.0: - return None - - def dropout_fn(b: Tensor, h: Tensor, q_idx: Tensor, kv_idx: Tensor) -> Tensor: - generator = torch.Generator() # TODO: How to set seed? - rand_val = torch.rand(1, generator=generator) - - # Return True to keep, False to mask (opposite of dropout probability) - return rand_val >= dropout_p - - return dropout_fn - - def _create_composed_mask( bias: AttentionBias, seqs_layout: BatchLayout, keys_layout: BatchLayout, device: Device, - dropout_p: float = 0.0, - training: bool = True, + *, packed: bool = False, ) -> BlockMask | None: - """Creates a composed mask using and_mask for combining multiple mask functions.""" + """Creates a composed mask using and_masks for combining multiple mask functions.""" masks = [] if packed: @@ -134,9 +149,16 @@ def _create_composed_mask( if isinstance(bias, CausalAttentionBias): attn_window_len = bias.attn_window_len if attn_window_len is not None: - base_mask_fn = _sliding_window_causal_mask_fn(attn_window_len) + base_mask_fn = _sliding_window_causal_mask_fn( + attn_window_len, + seqs_layout.seq_lens_pt, + keys_layout.seq_lens_pt, + ) else: - base_mask_fn = _causal_mask_fn + base_mask_fn = _causal_mask_fn( + seqs_layout.seq_lens_pt, + keys_layout.seq_lens_pt, + ) elif not isinstance(bias, IdentityBias): raise NotSupportedError(f"Unsupported bias type: {bias}") @@ -152,23 +174,30 @@ def _create_composed_mask( if isinstance(bias, CausalAttentionBias): attn_window_len = bias.attn_window_len if attn_window_len is not None: - masks.append(_sliding_window_causal_mask_fn(attn_window_len)) + masks.append( + _sliding_window_causal_mask_fn( + attn_window_len, + seqs_layout.seq_lens_pt, + keys_layout.seq_lens_pt, + ) + ) else: - masks.append(_causal_mask_fn) + masks.append( + _causal_mask_fn( + seqs_layout.seq_lens_pt, + keys_layout.seq_lens_pt, + ) + ) elif not isinstance(bias, IdentityBias): raise NotSupportedError(f"Unsupported bias type: {bias}") - # Add padding mask + # Add padding mask + if seqs_layout.padded or keys_layout.padded: masks.append( _create_padding_mask_fn(seqs_layout.seq_lens_pt, keys_layout.seq_lens_pt) ) - # Add dropout mask if needed - dropout_mask = _dropout_mask_fn(dropout_p, training) - if dropout_mask is not None: - masks.append(dropout_mask) - - # Compose masks using and_mask + # Compose masks mask_fn = None if len(masks) == 0: return None @@ -177,16 +206,16 @@ def _create_composed_mask( else: mask_fn = and_masks(*masks) - # For packed sequences, use the total sequence length if packed: total_seq_len = int(seqs_layout.seq_begin_indices_pt[-1].item()) total_keys_len = int(keys_layout.seq_begin_indices_pt[-1].item()) - batch_size = 1 # Packed format treats everything as one big batch + batch_size = 1 else: total_seq_len = seqs_layout.max_seq_len total_keys_len = keys_layout.max_seq_len - batch_size = seqs_layout.width + batch_size = len(seqs_layout.seq_lens) + # Create the block mask block_mask = create_block_mask( mask_fn, B=batch_size, @@ -203,13 +232,11 @@ class FlexSDPA(SDPA): """Computes scaled dot-product attention using PyTorch's Flex Attention.""" bias: AttentionBias - dropout_p: float - def __init__(self, bias: AttentionBias, *, dropout_p: float = 0.0) -> None: + def __init__(self, bias: AttentionBias) -> None: super().__init__() self.bias = bias - self.dropout_p = dropout_p @override def forward( @@ -226,11 +253,12 @@ def forward( if seqs_layout.packed ^ keys_layout.packed: raise ValueError("`seqs_layout` and `keys_layout` must be both packed.") - # Handle dropout - if not self.training: - dropout_p = 0.0 - else: - dropout_p = self.dropout_p + unsqueezed = False + if seqs.ndim == 3: + unsqueezed = True + seqs = seqs.unsqueeze(0) + keys = keys.unsqueeze(0) + values = values.unsqueeze(0) # Create the composed block mask using and_mask block_mask = _create_composed_mask( @@ -238,8 +266,6 @@ def forward( seqs_layout, keys_layout, seqs.device, - dropout_p, - self.training, packed=seqs_layout.packed, ) @@ -259,11 +285,7 @@ def forward( attns, _ = attns attns = attns.transpose(1, 2) + if unsqueezed: + attns = attns.squeeze(0) return attns, None - - def extra_repr(self) -> str: - """:meta private:""" - s = super().extra_repr() - - return f"{s}, dropout_p={self.dropout_p:G}" diff --git a/tests/unit/models/transformer/sdpa/diagnostics.py b/tests/unit/models/transformer/sdpa/diagnostics.py new file mode 100644 index 000000000..4b667ee7f --- /dev/null +++ b/tests/unit/models/transformer/sdpa/diagnostics.py @@ -0,0 +1,353 @@ +#!/usr/bin/env python3 +""" +Script to test materialize_block_mask_from_object with various block masks. +Tests different mask patterns and validates correctness using the existing functions. +""" + +import torch +from torch.nn.attention.flex_attention import create_block_mask, and_masks + +# Import the functions from the module (assuming they're available) +from fairseq2.models.transformer._sdpa._flex import ( + materialize_block_mask_from_object, + _causal_mask_fn, + _sliding_window_causal_mask_fn, + _create_padding_mask_fn, + _create_packed_mask_fn, + _offsets_to_doc_ids_tensor, +) + + +def create_reference_mask_from_function(mask_fn, shape, block_size): + """Create reference mask by applying mask function at block level""" + B, H, Q_LEN, KV_LEN = shape + Q_BLOCKS = (Q_LEN + block_size - 1) // block_size + KV_BLOCKS = (KV_LEN + block_size - 1) // block_size + + mask = torch.zeros(B, H, Q_LEN, KV_LEN, dtype=torch.bool) + + for b in range(B): + for h in range(H): + for q_block in range(Q_BLOCKS): + for kv_block in range(KV_BLOCKS): + if mask_fn( + torch.tensor(b), + torch.tensor(h), + torch.tensor(q_block), + torch.tensor(kv_block), + ): + # Fill the entire block + q_start = q_block * block_size + q_end = min(q_start + block_size, Q_LEN) + kv_start = kv_block * block_size + kv_end = min(kv_start + block_size, KV_LEN) + + mask[b, h, q_start:q_end, kv_start:kv_end] = True + + return mask + + +def compare_masks(materialized_mask, reference_mask, name): + """Compare materialized mask with reference mask""" + print(f"\n--- Comparison for {name} ---") + print(f"Materialized shape: {materialized_mask.shape}") + print(f"Reference shape: {reference_mask.shape}") + + if materialized_mask.shape == reference_mask.shape: + matches = torch.equal(materialized_mask, reference_mask) + print(f"Masks match: {matches}") + + if not matches: + diff = materialized_mask != reference_mask + print(f"Number of differences: {diff.sum().item()}") + print(f"Total elements: {materialized_mask.numel()}") + print( + f"Difference percentage: {100 * diff.sum().item() / materialized_mask.numel():.2f}%" + ) + + # Sparsity comparison + mat_sparsity = materialized_mask.sum().item() / materialized_mask.numel() + ref_sparsity = reference_mask.sum().item() / reference_mask.numel() + print(f"Materialized sparsity: {mat_sparsity:.4f}") + print(f"Reference sparsity: {ref_sparsity:.4f}") + + return matches + else: + print("Shape mismatch!") + return False + + +def test_basic_masks(): + """Test basic mask functions""" + print("=" * 80) + print("TESTING BASIC MASK FUNCTIONS") + print("=" * 80) + + B, H = 1, 1 + Q_LEN, KV_LEN = 128, 128 + BLOCK_SIZE = 32 + + test_cases = [ + ("Causal Mask", _causal_mask_fn), + ("Sliding Window (window=2)", _sliding_window_causal_mask_fn(2)), + ("Sliding Window (window=5)", _sliding_window_causal_mask_fn(5)), + ] + + for name, mask_fn in test_cases: + print(f"\nTesting: {name}") + print("-" * 40) + + # Create block mask + block_mask = create_block_mask( + mask_fn, B=B, H=H, Q_LEN=Q_LEN, KV_LEN=KV_LEN, BLOCK_SIZE=BLOCK_SIZE + ) + + print(f"Block mask shape: {block_mask.shape}") + print(f"Block mask sparsity: {block_mask.sparsity():.2f}%") + + # Materialize using our function + materialized = materialize_block_mask_from_object(block_mask) + + # Create reference + reference = create_reference_mask_from_function( + mask_fn, (B, H, Q_LEN, KV_LEN), BLOCK_SIZE + ) + + # Compare + matches = compare_masks(materialized, reference, name) + + +def test_padding_masks(): + """Test padding mask functionality""" + print("=" * 80) + print("TESTING PADDING MASKS") + print("=" * 80) + + B, H = 2, 1 + MAX_SEQ_LEN = 96 + BLOCK_SIZE = 32 + + # Create different sequence lengths for each batch + seq_lens = torch.tensor([64, 80]) # Different lengths per batch + value_seq_lens = torch.tensor([64, 80]) + + padding_mask_fn = _create_padding_mask_fn(seq_lens, value_seq_lens) + + print("Testing: Padding Mask") + print(f"Sequence lengths: {seq_lens.tolist()}") + + # Create block mask + block_mask = create_block_mask( + padding_mask_fn, + B=B, + H=H, + Q_LEN=MAX_SEQ_LEN, + KV_LEN=MAX_SEQ_LEN, + BLOCK_SIZE=BLOCK_SIZE, + ) + + print(f"Block mask shape: {block_mask.shape}") + + # Materialize + materialized = materialize_block_mask_from_object(block_mask) + + print(f"Materialized shape: {materialized.shape}") + + # Check that padding is correctly masked + for b in range(B): + seq_len = seq_lens[b].item() + val_len = value_seq_lens[b].item() + + # Check that valid positions are potentially unmasked + valid_region = materialized[b, 0, :seq_len, :val_len] + invalid_q = materialized[b, 0, seq_len:, :] + invalid_kv = materialized[b, 0, :, val_len:] + + print( + f"Batch {b}: seq_len={seq_len}, valid_region_any={valid_region.any().item()}" + ) + print( + f"Batch {b}: invalid_q_any={invalid_q.any().item()}, invalid_kv_any={invalid_kv.any().item()}" + ) + + +def test_combined_masks(): + """Test combining multiple masks using and_masks""" + print("=" * 80) + print("TESTING COMBINED MASKS") + print("=" * 80) + + B, H = 1, 1 + Q_LEN, KV_LEN = 96, 96 + BLOCK_SIZE = 32 + + # Create individual masks + causal_fn = _causal_mask_fn + sliding_fn = _sliding_window_causal_mask_fn(3) + + # Test individual masks + print("Testing individual masks...") + + causal_block = create_block_mask( + causal_fn, B=B, H=H, Q_LEN=Q_LEN, KV_LEN=KV_LEN, BLOCK_SIZE=BLOCK_SIZE + ) + sliding_block = create_block_mask( + sliding_fn, B=B, H=H, Q_LEN=Q_LEN, KV_LEN=KV_LEN, BLOCK_SIZE=BLOCK_SIZE + ) + + causal_mat = materialize_block_mask_from_object(causal_block) + sliding_mat = materialize_block_mask_from_object(sliding_block) + + print(f"Causal sparsity: {causal_mat.sum().item() / causal_mat.numel():.4f}") + print(f"Sliding sparsity: {sliding_mat.sum().item() / sliding_mat.numel():.4f}") + + # Test combined mask + print("\nTesting combined mask...") + combined_fn = and_masks(causal_fn, sliding_fn) + combined_block = create_block_mask( + combined_fn, B=B, H=H, Q_LEN=Q_LEN, KV_LEN=KV_LEN, BLOCK_SIZE=BLOCK_SIZE + ) + combined_mat = materialize_block_mask_from_object(combined_block) + + print(f"Combined sparsity: {combined_mat.sum().item() / combined_mat.numel():.4f}") + + # Verify that combined mask is intersection of individual masks + expected_combined = causal_mat & sliding_mat + matches = torch.equal(combined_mat, expected_combined) + print(f"Combined mask matches intersection: {matches}") + + +def test_packed_sequences(): + """Test packed sequence masks""" + print("=" * 80) + print("TESTING PACKED SEQUENCES") + print("=" * 80) + + # Create packed sequence layout + # Simulate 3 documents of lengths [20, 30, 25] + doc_lengths = [20, 30, 25] + seq_begin_indices = torch.tensor( + [0] + [sum(doc_lengths[: i + 1]) for i in range(len(doc_lengths))] + ) + total_len = sum(doc_lengths) + + print(f"Document lengths: {doc_lengths}") + print(f"Begin indices: {seq_begin_indices.tolist()}") + print(f"Total length: {total_len}") + + BLOCK_SIZE = 16 + + # Test packed mask without base mask + print("\nTesting packed mask (no base mask)...") + packed_fn = _create_packed_mask_fn(seq_begin_indices, seq_begin_indices, None) + + packed_block = create_block_mask( + packed_fn, B=1, H=1, Q_LEN=total_len, KV_LEN=total_len, BLOCK_SIZE=BLOCK_SIZE + ) + + packed_mat = materialize_block_mask_from_object(packed_block) + + print(f"Packed mask shape: {packed_mat.shape}") + print(f"Packed mask sparsity: {packed_mat.sum().item() / packed_mat.numel():.4f}") + + # Verify document boundaries + doc_ids = _offsets_to_doc_ids_tensor(seq_begin_indices) + print(f"Document IDs: {doc_ids[:10].tolist()}...{doc_ids[-10:].tolist()}") + + # Test packed mask with causal base mask + print("\nTesting packed mask with causal base...") + packed_causal_fn = _create_packed_mask_fn( + seq_begin_indices, seq_begin_indices, _causal_mask_fn + ) + + packed_causal_block = create_block_mask( + packed_causal_fn, + B=1, + H=1, + Q_LEN=total_len, + KV_LEN=total_len, + BLOCK_SIZE=BLOCK_SIZE, + ) + + packed_causal_mat = materialize_block_mask_from_object(packed_causal_block) + + print( + f"Packed causal sparsity: {packed_causal_mat.sum().item() / packed_causal_mat.numel():.4f}" + ) + + +def test_edge_cases(): + """Test edge cases and different configurations""" + print("=" * 80) + print("TESTING EDGE CASES") + print("=" * 80) + + test_configs = [ + ( + "Small sequence", + {"B": 1, "H": 1, "Q_LEN": 32, "KV_LEN": 32, "BLOCK_SIZE": 16}, + ), + ("Non-square", {"B": 1, "H": 1, "Q_LEN": 64, "KV_LEN": 96, "BLOCK_SIZE": 16}), + ( + "Multiple heads", + {"B": 1, "H": 4, "Q_LEN": 64, "KV_LEN": 64, "BLOCK_SIZE": 16}, + ), + ( + "Multiple batch", + {"B": 3, "H": 2, "Q_LEN": 48, "KV_LEN": 48, "BLOCK_SIZE": 16}, + ), + ( + "Large block size", + {"B": 1, "H": 1, "Q_LEN": 64, "KV_LEN": 64, "BLOCK_SIZE": 64}, + ), + ] + + for name, config in test_configs: + print(f"\nTesting: {name}") + print(f"Config: {config}") + + block_mask = create_block_mask( + _causal_mask_fn, + B=config["B"], + H=config["H"], + Q_LEN=config["Q_LEN"], + KV_LEN=config["KV_LEN"], + BLOCK_SIZE=config["BLOCK_SIZE"], + ) + + materialized = materialize_block_mask_from_object(block_mask) + reference = create_reference_mask_from_function( + _causal_mask_fn, + (config["B"], config["H"], config["Q_LEN"], config["KV_LEN"]), + config["BLOCK_SIZE"], + ) + + matches = compare_masks(materialized, reference, name) + print(f"✓ Passed: {matches}") + + +def main(): + """Run all tests""" + print("Starting Block Mask Materialization Tests") + print("=" * 80) + + try: + test_basic_masks() + test_padding_masks() + test_combined_masks() + test_packed_sequences() + test_edge_cases() + + print("\n" + "=" * 80) + print("ALL TESTS COMPLETED") + print("=" * 80) + + except Exception as e: + print(f"\nTest failed with error: {e}") + import traceback + + traceback.print_exc() + + +if __name__ == "__main__": + main() diff --git a/tests/unit/models/transformer/sdpa/test_flex_mask_fns.py b/tests/unit/models/transformer/sdpa/test_flex_mask_fns.py new file mode 100644 index 000000000..5081114a1 --- /dev/null +++ b/tests/unit/models/transformer/sdpa/test_flex_mask_fns.py @@ -0,0 +1,165 @@ +import torch +from unittest.mock import patch + +from fairseq2.models.transformer._sdpa._flex import ( + _causal_mask_fn, + _sliding_window_causal_mask_fn, + _offsets_to_doc_ids_tensor, + _create_packed_mask_fn, + _create_padding_mask_fn, + _dropout_mask_fn, +) + + +def test_causal_mask(): + b, h = torch.tensor([0]), torch.tensor([0]) + + # q_idx >= kv_idx should be True + assert _causal_mask_fn(b, h, torch.tensor([2]), torch.tensor([1])).item() + assert _causal_mask_fn(b, h, torch.tensor([2]), torch.tensor([2])).item() + + # q_idx < kv_idx should be False + assert not _causal_mask_fn(b, h, torch.tensor([1]), torch.tensor([2])).item() + + +def test_sliding_window_mask(): + mask_fn = _sliding_window_causal_mask_fn(window_size=2) + b, h = torch.tensor([0]), torch.tensor([0]) + + # Within window and causal + assert mask_fn(b, h, torch.tensor([3]), torch.tensor([2])).item() # distance=1 + assert mask_fn(b, h, torch.tensor([3]), torch.tensor([1])).item() # distance=2 + + # Outside window + assert not mask_fn(b, h, torch.tensor([4]), torch.tensor([1])).item() # distance=3 + + # Non-causal + assert not mask_fn(b, h, torch.tensor([1]), torch.tensor([2])).item() + + +def test_offsets_to_doc_ids(): + offsets = torch.tensor([0, 3, 5]) # Two docs: [0,1,2] and [3,4] + doc_ids = _offsets_to_doc_ids_tensor(offsets) + expected = torch.tensor([0, 0, 0, 1, 1], dtype=torch.int32) + assert torch.equal(doc_ids, expected) + + +def test_packed_mask(): + seq_offsets = torch.tensor([0, 3, 6]) # Two docs of length 3 each + key_offsets = torch.tensor([0, 3, 6]) + mask_fn = _create_packed_mask_fn(seq_offsets, key_offsets) + + b, h = torch.tensor([0]), torch.tensor([0]) + + # Same document + assert mask_fn(b, h, torch.tensor([1]), torch.tensor([0])).item() + + # Different documents + assert not mask_fn(b, h, torch.tensor([4]), torch.tensor([1])).item() + + +def test_packed_mask_with_base_mask(): + seq_offsets = torch.tensor([0, 3, 6]) + key_offsets = torch.tensor([0, 3, 6]) + mask_fn = _create_packed_mask_fn(seq_offsets, key_offsets, _causal_mask_fn) + + b, h = torch.tensor([0]), torch.tensor([0]) + + # Same doc, causal + assert mask_fn(b, h, torch.tensor([2]), torch.tensor([1])).item() + + # Same doc, non-causal + assert not mask_fn(b, h, torch.tensor([1]), torch.tensor([2])).item() + + +def test_padding_mask(): + seq_lens = torch.tensor([3, 4]) + value_seq_lens = torch.tensor([2, 4]) + mask_fn = _create_padding_mask_fn(seq_lens, value_seq_lens) + + h = torch.tensor([0]) + + # Valid positions + assert mask_fn(torch.tensor([0]), h, torch.tensor([2]), torch.tensor([1])).item() + + # Invalid query position + assert not mask_fn( + torch.tensor([0]), h, torch.tensor([3]), torch.tensor([1]) + ).item() + + # Invalid key position + assert not mask_fn( + torch.tensor([0]), h, torch.tensor([1]), torch.tensor([2]) + ).item() + + +def test_dropout_mask(): + # No mask when not training or dropout_p=0 + assert _dropout_mask_fn(0.5, training=False) is None + assert _dropout_mask_fn(0.0, training=True) is None + + # Returns callable when training=True and dropout_p > 0 + mask_fn = _dropout_mask_fn(0.5, training=True) + assert callable(mask_fn) + + +@patch("torch.rand") +def test_dropout_mask_behavior(mock_rand): + mask_fn = _dropout_mask_fn(0.3, training=True) + b, h, q, kv = ( + torch.tensor([0]), + torch.tensor([0]), + torch.tensor([0]), + torch.tensor([0]), + ) + + # Below threshold - should mask (False) + mock_rand.return_value = torch.tensor([0.2]) + assert not mask_fn(b, h, q, kv).item() + + # Above threshold - should not mask (True) + mock_rand.return_value = torch.tensor([0.5]) + assert mask_fn(b, h, q, kv).item() + + +def test_causal_mask_large_indices(): + b, h = torch.tensor([0]), torch.tensor([0]) + + # Test with larger indices + q_indices = torch.tensor([100, 50, 75]) + kv_indices = torch.tensor([50, 100, 75]) + + result = _causal_mask_fn(b.expand(3), h.expand(3), q_indices, kv_indices) + expected = torch.tensor([True, False, True]) # [100>=50, 50>=100, 75>=75] + assert torch.equal(result, expected) + + +def test_packed_mask_unequal_offsets(): + # Different sequence and key lengths + seq_offsets = torch.tensor([0, 3, 5]) # [doc0:3, doc1:2] + key_offsets = torch.tensor([0, 2, 6]) # [doc0:2, doc1:4] + mask_fn = _create_packed_mask_fn(seq_offsets, key_offsets) + + b, h = torch.tensor([0]), torch.tensor([0]) + + # Same document, different lengths + assert mask_fn( + b, h, torch.tensor([1]), torch.tensor([0]) + ).item() # doc0 query to doc0 key + assert mask_fn( + b, h, torch.tensor([3]), torch.tensor([2]) + ).item() # doc1 query to doc1 key + + +def test_sliding_window_vectorized(): + mask_fn = _sliding_window_causal_mask_fn(window_size=2) + b = torch.tensor([0, 0, 0, 0]) + h = torch.tensor([0, 0, 0, 0]) + q_idx = torch.tensor([3, 3, 3, 2]) + kv_idx = torch.tensor([1, 2, 3, 4]) # distances: [2, 1, 0, -2] + + result = mask_fn(b, h, q_idx, kv_idx) + expected = torch.tensor( + [True, True, True, False] + ) # [causal+window, causal+window, causal+window, non-causal] + assert torch.equal(result, expected) diff --git a/tests/unit/models/transformer/test_attention.py b/tests/unit/models/transformer/test_attention.py index 9f6b8f83e..dd1b84f0b 100644 --- a/tests/unit/models/transformer/test_attention.py +++ b/tests/unit/models/transformer/test_attention.py @@ -58,8 +58,8 @@ def test_torch_sdpa( naive_sdpa = NaiveSDPA(attn_bias) if training: - torch_sdpa.eval() - naive_sdpa.eval() + torch_sdpa.train() + naive_sdpa.train() kwargs = self._get_sdpa_args(use_padding, use_packing) @@ -186,16 +186,21 @@ def test_variable_sized_attention(self, q_dim: int, k_dim: int | None) -> None: class TestFlexScaledDotProductAttention: # fmt: off - @pytest.mark.parametrize("use_padding,use_packing,use_bias,training", + @pytest.mark.parametrize("use_padding,use_packing,use_bias,training,attn_window_len", [ - (False, False, False, True), - (True, False, True, True), - (False, False, True, True), - (True, False, False, True), - (False, True, False, True), - (False, True, True, True), - (False, False, False, False), - (False, False, True, False), + # Original test cases + (False, False, False, True, None), + (True, False, True, True, None), + (False, False, True, True, None), + (True, False, False, True, None), + (False, True, False, True, None), + (False, True, True, True, None), + (False, False, True, True, 1), + (False, False, True, True, 2), + (True, False, True, True, 1), + (False, True, True, True, 1), + (False, False, False, False, None), + (False, False, True, False, None), ], ) # fmt: on @@ -205,11 +210,12 @@ def test_flex_sdpa( use_packing: bool, use_bias: bool, training: bool, + attn_window_len: int | None, ) -> None: attn_bias: AttentionBias if use_bias: - attn_bias = CausalAttentionBias() + attn_bias = CausalAttentionBias(attn_window_len=attn_window_len) else: attn_bias = IdentityBias() @@ -217,8 +223,8 @@ def test_flex_sdpa( naive_sdpa = NaiveSDPA(attn_bias) if training: - flex_sdpa.eval() - naive_sdpa.eval() + flex_sdpa.train() + naive_sdpa.train() kwargs = self._get_sdpa_args(use_padding, use_packing) @@ -247,9 +253,9 @@ def random_tensor(*args: int) -> Tensor: total_target_len = sum([2, 3]) # seq_lens for target total_source_len = sum([2, 3]) # seq_lens for source - q = random_tensor(1, total_target_len, num_heads, k_size) - k = random_tensor(1, total_source_len, num_heads, k_size) - v = random_tensor(1, total_source_len, num_heads, v_size) + q = random_tensor(total_target_len, num_heads, k_size) + k = random_tensor(total_source_len, num_heads, k_size) + v = random_tensor(total_source_len, num_heads, v_size) target_shape = (total_target_len,) source_shape = (total_source_len,) From ca6dd88cc54f9f5119adc19a5bd7d4bb75bc0288 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Sat, 14 Jun 2025 21:37:01 +0000 Subject: [PATCH 07/15] adding block mask cache everywhere relevant --- src/fairseq2/models/conformer/_block.py | 8 +- src/fairseq2/models/jepa/classifier/_model.py | 3 + src/fairseq2/models/transformer/__init__.py | 1 + .../models/transformer/_block_mask.py | 302 +++++++++++++++ src/fairseq2/models/transformer/_decoder.py | 3 + .../models/transformer/_decoder_layer.py | 12 +- .../models/transformer/_encoder_layer.py | 8 +- .../transformer/_multihead_attention.py | 12 +- .../models/transformer/_sdpa/_base.py | 2 + .../models/transformer/_sdpa/_flash2.py | 2 + .../models/transformer/_sdpa/_flash3.py | 2 + .../models/transformer/_sdpa/_flex.py | 220 +---------- .../models/transformer/_sdpa/_naive.py | 2 + .../models/transformer/_sdpa/_relative.py | 2 + .../models/transformer/_sdpa/_shaw.py | 2 + .../models/transformer/_sdpa/_torch.py | 2 + .../models/transformer_lm/_decoder.py | 10 +- .../models/transformer_lm/_decoder_layer.py | 9 +- .../models/transformer/sdpa/diagnostics.py | 353 ------------------ .../transformer/sdpa/test_flex_mask_fns.py | 165 -------- .../unit/models/transformer/test_attention.py | 7 + 21 files changed, 396 insertions(+), 731 deletions(-) create mode 100644 src/fairseq2/models/transformer/_block_mask.py delete mode 100644 tests/unit/models/transformer/sdpa/diagnostics.py delete mode 100644 tests/unit/models/transformer/sdpa/test_flex_mask_fns.py diff --git a/src/fairseq2/models/conformer/_block.py b/src/fairseq2/models/conformer/_block.py index e075bcbce..efff3873f 100644 --- a/src/fairseq2/models/conformer/_block.py +++ b/src/fairseq2/models/conformer/_block.py @@ -25,6 +25,7 @@ # isort: split from fairseq2.models.conformer._convolution import ConformerConvolution +from fairseq2.models.transformer._block_mask import BlockMaskCache @final @@ -131,10 +132,13 @@ def forward( seqs: Tensor, seqs_layout: BatchLayout, attn_bias_cache: AttentionBiasCache, + block_mask_cache: BlockMaskCache, ) -> Tensor: seqs = self._forward_ffn1(seqs) - seqs = self._forward_self_attn(seqs, seqs_layout, attn_bias_cache) + seqs = self._forward_self_attn( + seqs, seqs_layout, attn_bias_cache, block_mask_cache + ) seqs = self._forward_conv(seqs, seqs_layout) @@ -161,6 +165,7 @@ def _forward_self_attn( seqs: Tensor, seqs_layout: BatchLayout, attn_bias_cache: AttentionBiasCache, + block_mask_cache: BlockMaskCache, ) -> Tensor: residual = seqs @@ -173,6 +178,7 @@ def _forward_self_attn( keys_layout=seqs_layout, values=seqs, bias_cache=attn_bias_cache, + block_mask_cache=block_mask_cache, ) if self.self_attn_dropout is not None: diff --git a/src/fairseq2/models/jepa/classifier/_model.py b/src/fairseq2/models/jepa/classifier/_model.py index f58deb56d..21a752264 100644 --- a/src/fairseq2/models/jepa/classifier/_model.py +++ b/src/fairseq2/models/jepa/classifier/_model.py @@ -18,6 +18,7 @@ from fairseq2.device import Device from fairseq2.models.transformer import ( AttentionBiasCache, + BlockMaskCache, FeedForwardNetwork, MultiheadAttention, TransformerEncoder, @@ -221,6 +222,7 @@ def _forward_cross_attn( encoder_output = self.cross_attn_layer_norm(encoder_output) attn_bias_cache = AttentionBiasCache() + block_mask_cache = BlockMaskCache() seqs = self.cross_attn( seqs, @@ -229,6 +231,7 @@ def _forward_cross_attn( keys_layout=encoder_output_layout, values=encoder_output, bias_cache=attn_bias_cache, + block_mask_cache=block_mask_cache, ) seqs = seqs + residual diff --git a/src/fairseq2/models/transformer/__init__.py b/src/fairseq2/models/transformer/__init__.py index cfe6139a2..fabba70f9 100644 --- a/src/fairseq2/models/transformer/__init__.py +++ b/src/fairseq2/models/transformer/__init__.py @@ -23,6 +23,7 @@ from fairseq2.models.transformer._attention_bias import ( maybe_get_attention_bias_tensor as maybe_get_attention_bias_tensor, ) +from fairseq2.models.transformer._block_mask import BlockMaskCache as BlockMaskCache from fairseq2.models.transformer._checkpoint import ( convert_transformer_checkpoint as convert_transformer_checkpoint, ) diff --git a/src/fairseq2/models/transformer/_block_mask.py b/src/fairseq2/models/transformer/_block_mask.py new file mode 100644 index 000000000..77263ab88 --- /dev/null +++ b/src/fairseq2/models/transformer/_block_mask.py @@ -0,0 +1,302 @@ +from dataclasses import dataclass +from typing import Callable, TypeAlias + +import torch +from torch import Tensor +from torch.nn.attention.flex_attention import ( + BlockMask, + and_masks, + create_block_mask, +) + +from fairseq2.device import Device +from fairseq2.error import NotSupportedError +from fairseq2.nn import BatchLayout + +# isort: split + +from fairseq2.models.transformer._attention_bias import ( + AttentionBias, + CausalAttentionBias, + IdentityBias, +) + +MaskFunction: TypeAlias = Callable[[Tensor, Tensor, Tensor, Tensor], Tensor] + +BLOCK_MASK_CACHE_MAX_SIZE = 1000 + + +def _causal_mask_fn(q_lens: Tensor, kv_lens: Tensor) -> MaskFunction: + """Creates a causal mask function.""" + + def mask_fn(b: Tensor, h: Tensor, q_idx: Tensor, kv_idx: Tensor) -> Tensor: + # Get sequence lengths for this batch + q_len = q_lens[b] + kv_len = kv_lens[b] + + # Calculate diagonal offset + d = kv_len - q_len + + return q_idx >= kv_idx - d + + return mask_fn + + +def _sliding_window_causal_mask_fn( + window_size: int, + q_lens: Tensor, + kv_lens: Tensor, +) -> MaskFunction: + """Creates a sliding window causal mask functions.""" + + def mask_fn(b: Tensor, h: Tensor, q_idx: Tensor, kv_idx: Tensor) -> Tensor: + # Get sequence lengths for this batch + q_len = q_lens[b] + kv_len = kv_lens[b] + + # Calculate diagonal offset + d = kv_len - q_len + + # For window_size=1, only allow the exact diagonal position + if window_size == 1: + return q_idx == kv_idx - d + else: + # For larger windows, use the range logic + causal_mask = q_idx >= kv_idx - d + window_mask = q_idx >= kv_idx - d - window_size + 1 + return causal_mask & window_mask + + return mask_fn + + +def _offsets_to_doc_ids_tensor(offsets: Tensor) -> Tensor: + """Convert offsets to document IDs for packed sequences.""" + device = offsets.device + counts = offsets[1:] - offsets[:-1] + return torch.repeat_interleave( + torch.arange(len(counts), device=device, dtype=torch.int32), counts + ) + + +def _create_packed_mask_fn( + seq_begin_indices: Tensor, + keys_begin_indices: Tensor, + base_mask_fn: MaskFunction | None = None, +) -> MaskFunction: + """Creates a mask function for packed sequences using document-based masking.""" + # Create document IDs for queries and keys + query_doc_ids = _offsets_to_doc_ids_tensor(seq_begin_indices) + key_doc_ids = _offsets_to_doc_ids_tensor(keys_begin_indices) + + def packed_mask_fn(b: Tensor, h: Tensor, q_idx: Tensor, kv_idx: Tensor) -> Tensor: + # Check if query and key belong to the same document + same_doc = query_doc_ids[q_idx] == key_doc_ids[kv_idx] + + # Convert global indices to logical positions within documents + q_doc_id = query_doc_ids[q_idx] + kv_doc_id = key_doc_ids[kv_idx] + q_logical = q_idx - seq_begin_indices[q_doc_id] + kv_logical = kv_idx - keys_begin_indices[kv_doc_id] + + # Apply base mask (e.g., causal) to logical positions + if base_mask_fn is not None: + inner_mask = base_mask_fn(b, h, q_logical, kv_logical) + return same_doc & inner_mask + else: + return same_doc + + return packed_mask_fn + + +def _create_padding_mask_fn(q_lens: Tensor, kv_lens: Tensor) -> MaskFunction: + """Creates a padding mask function that masks out padding tokens.""" + + def padding_mask_fn(b: Tensor, h: Tensor, q_idx: Tensor, kv_idx: Tensor) -> Tensor: + q_valid = q_idx < q_lens[b] + kv_valid = kv_idx < kv_lens[b] + return q_valid & kv_valid + + return padding_mask_fn + + +def _create_composed_mask( + bias: AttentionBias, + seqs_layout: BatchLayout, + keys_layout: BatchLayout, + device: Device, +) -> BlockMask | None: + """Creates a composed mask using and_masks for combining multiple mask functions.""" + masks = [] + + if seqs_layout.packed: + # For packed sequences, create the base mask function first + base_mask_fn = None + + # Add attention bias mask as base mask + if isinstance(bias, CausalAttentionBias): + attn_window_len = bias.attn_window_len + if attn_window_len is not None: + base_mask_fn = _sliding_window_causal_mask_fn( + attn_window_len, + seqs_layout.seq_lens_pt, + keys_layout.seq_lens_pt, + ) + else: + base_mask_fn = _causal_mask_fn( + seqs_layout.seq_lens_pt, + keys_layout.seq_lens_pt, + ) + elif not isinstance(bias, IdentityBias): + raise NotSupportedError(f"Unsupported bias type: {bias}") + + # Create the packed sequence mask that incorporates the base mask + packed_mask = _create_packed_mask_fn( + seqs_layout.seq_begin_indices_pt, + keys_layout.seq_begin_indices_pt, + base_mask_fn, + ) + masks.append(packed_mask) + else: + # Standard batch format - handle bias and padding separately + if isinstance(bias, CausalAttentionBias): + attn_window_len = bias.attn_window_len + if attn_window_len is not None: + masks.append( + _sliding_window_causal_mask_fn( + attn_window_len, + seqs_layout.seq_lens_pt, + keys_layout.seq_lens_pt, + ) + ) + else: + masks.append( + _causal_mask_fn( + seqs_layout.seq_lens_pt, + keys_layout.seq_lens_pt, + ) + ) + elif not isinstance(bias, IdentityBias): + raise NotSupportedError(f"Unsupported bias type: {bias}") + + # Add padding mask + if seqs_layout.padded or keys_layout.padded: + masks.append( + _create_padding_mask_fn(seqs_layout.seq_lens_pt, keys_layout.seq_lens_pt) + ) + + # Compose masks + mask_fn = None + if len(masks) == 0: + return None + elif len(masks) == 1: + mask_fn = masks[0] + else: + mask_fn = and_masks(*masks) + + if seqs_layout.packed: + total_seq_len = int(seqs_layout.seq_begin_indices_pt[-1].item()) + total_keys_len = int(keys_layout.seq_begin_indices_pt[-1].item()) + batch_size = 1 + else: + total_seq_len = seqs_layout.max_seq_len + total_keys_len = keys_layout.max_seq_len + batch_size = len(seqs_layout.seq_lens) + + # Create the block mask + block_mask = create_block_mask( + mask_fn, + B=batch_size, + H=None, + Q_LEN=total_seq_len, + KV_LEN=total_keys_len, + device=str(device), + ) + return block_mask + + +@dataclass +class BlockMaskCacheKey: + """Key for caching block masks.""" + + bias_type: str + batch_size: int + seq_len: int + keys_len: int + packed: bool + attn_window_len: int | None = None + + def __hash__(self) -> int: + return hash( + ( + self.bias_type, + self.batch_size, + self.seq_len, + self.keys_len, + self.packed, + self.attn_window_len, + ) + ) + + +class BlockMaskCache: + """ + Cache for block masks to avoid recomputation across layers and possibly training + steps. + """ + + def __init__(self): + self._cache: dict[BlockMaskCacheKey, BlockMask | None] = {} + + def get_or_create_mask( + self, + bias: AttentionBias, + seqs_layout: BatchLayout, + keys_layout: BatchLayout, + device: Device, + ) -> BlockMask | None: + """Get cached mask or create new one.""" + + # Create cache key + bias_type = type(bias).__name__ + attn_window_len = None + if isinstance(bias, CausalAttentionBias): + attn_window_len = bias.attn_window_len + + if seqs_layout.packed: + batch_size = 1 + seq_len = int(seqs_layout.seq_begin_indices[-1]) + keys_len = int(keys_layout.seq_begin_indices[-1]) + else: + batch_size = len(seqs_layout.seq_lens) + seq_len = seqs_layout.max_seq_len + keys_len = keys_layout.max_seq_len + + cache_key = BlockMaskCacheKey( + bias_type=bias_type, + batch_size=batch_size, + seq_len=seq_len, + keys_len=keys_len, + packed=seqs_layout.packed, + attn_window_len=attn_window_len, + ) + + # Check cache first + if cache_key in self._cache: + return self._cache[cache_key] + + # Create new mask + block_mask = _create_composed_mask( + bias, + seqs_layout, + keys_layout, + device, + ) + + if len(self._cache) < BLOCK_MASK_CACHE_MAX_SIZE: + self._cache[cache_key] = block_mask + + return block_mask + + def clear(self): + """Clear the cache.""" + self._cache.clear() diff --git a/src/fairseq2/models/transformer/_decoder.py b/src/fairseq2/models/transformer/_decoder.py index 071c7e3fb..242e2805a 100644 --- a/src/fairseq2/models/transformer/_decoder.py +++ b/src/fairseq2/models/transformer/_decoder.py @@ -24,6 +24,7 @@ # isort: split from fairseq2.models.transformer._attention_bias import AttentionBiasCache +from fairseq2.models.transformer._block_mask import BlockMaskCache from fairseq2.models.transformer._decoder_layer import TransformerDecoderLayer from fairseq2.models.transformer._encoder import _record_drop_for_backward @@ -167,6 +168,7 @@ def forward( ) attn_bias_cache = AttentionBiasCache() + block_mask_cache = BlockMaskCache() num_layers = len(self.layers) @@ -177,6 +179,7 @@ def forward( encoder_output, encoder_output_layout, attn_bias_cache, + block_mask_cache, state_bag=state_bag, ) diff --git a/src/fairseq2/models/transformer/_decoder_layer.py b/src/fairseq2/models/transformer/_decoder_layer.py index 2c0279454..b5c3f78f5 100644 --- a/src/fairseq2/models/transformer/_decoder_layer.py +++ b/src/fairseq2/models/transformer/_decoder_layer.py @@ -26,6 +26,7 @@ # isort: split from fairseq2.models.transformer._attention_bias import AttentionBiasCache +from fairseq2.models.transformer._block_mask import BlockMaskCache from fairseq2.models.transformer._ffn import FeedForwardNetwork from fairseq2.models.transformer._multihead_attention import MultiheadAttention from fairseq2.models.transformer._norm_order import TransformerNormOrder @@ -42,6 +43,7 @@ def forward( encoder_output: Tensor, encoder_output_layout: BatchLayout, attn_bias_cache: AttentionBiasCache, + block_mask_cache: BlockMaskCache, *, state_bag: IncrementalStateBag | None = None, ) -> Tensor: @@ -195,10 +197,13 @@ def forward( encoder_output: Tensor, encoder_output_layout: BatchLayout, attn_bias_cache: AttentionBiasCache, + block_mask_cache: BlockMaskCache, *, state_bag: IncrementalStateBag | None = None, ) -> Tensor: - seqs = self._forward_self_attn(seqs, seqs_layout, attn_bias_cache, state_bag) + seqs = self._forward_self_attn( + seqs, seqs_layout, attn_bias_cache, block_mask_cache, state_bag + ) seqs = self._forward_encoder_decoder_attn( seqs, @@ -206,6 +211,7 @@ def forward( encoder_output, encoder_output_layout, attn_bias_cache, + block_mask_cache, state_bag, ) @@ -218,6 +224,7 @@ def _forward_self_attn( seqs: Tensor, seqs_layout: BatchLayout, attn_bias_cache: AttentionBiasCache, + block_mask_cache: BlockMaskCache, state_bag: IncrementalStateBag | None, ) -> Tensor: residual = seqs @@ -232,6 +239,7 @@ def _forward_self_attn( keys_layout=seqs_layout, values=seqs, bias_cache=attn_bias_cache, + block_mask_cache=block_mask_cache, state_bag=state_bag, ) @@ -252,6 +260,7 @@ def _forward_encoder_decoder_attn( encoder_output: Tensor, encoder_output_layout: BatchLayout, attn_bias_cache: AttentionBiasCache, + block_mask_cache: BlockMaskCache, state_bag: IncrementalStateBag | None, ) -> Tensor: residual = seqs @@ -266,6 +275,7 @@ def _forward_encoder_decoder_attn( keys_layout=encoder_output_layout, values=encoder_output, bias_cache=attn_bias_cache, + block_mask_cache=block_mask_cache, state_bag=state_bag, ) diff --git a/src/fairseq2/models/transformer/_encoder_layer.py b/src/fairseq2/models/transformer/_encoder_layer.py index 458736c13..9f0eed755 100644 --- a/src/fairseq2/models/transformer/_encoder_layer.py +++ b/src/fairseq2/models/transformer/_encoder_layer.py @@ -25,6 +25,7 @@ # isort: split from fairseq2.models.transformer._attention_bias import AttentionBiasCache +from fairseq2.models.transformer._block_mask import BlockMaskCache from fairseq2.models.transformer._ffn import FeedForwardNetwork from fairseq2.models.transformer._multihead_attention import MultiheadAttention from fairseq2.models.transformer._norm_order import TransformerNormOrder @@ -148,8 +149,11 @@ def forward( seqs: Tensor, seqs_layout: BatchLayout, attn_bias_cache: AttentionBiasCache, + block_mask_cache: BlockMaskCache, ) -> Tensor: - seqs = self._forward_self_attn(seqs, seqs_layout, attn_bias_cache) + seqs = self._forward_self_attn( + seqs, seqs_layout, attn_bias_cache, block_mask_cache + ) seqs = self._forward_ffn(seqs) @@ -160,6 +164,7 @@ def _forward_self_attn( seqs: Tensor, seqs_layout: BatchLayout, attn_bias_cache: AttentionBiasCache, + block_mask_cache: BlockMaskCache, ) -> Tensor: residual = seqs @@ -173,6 +178,7 @@ def _forward_self_attn( keys_layout=seqs_layout, values=seqs, bias_cache=attn_bias_cache, + block_mask_cache=block_mask_cache, ) if self.self_attn_dropout is not None: diff --git a/src/fairseq2/models/transformer/_multihead_attention.py b/src/fairseq2/models/transformer/_multihead_attention.py index c5ca38e51..eceaa2311 100644 --- a/src/fairseq2/models/transformer/_multihead_attention.py +++ b/src/fairseq2/models/transformer/_multihead_attention.py @@ -36,6 +36,7 @@ # isort: split from fairseq2.models.transformer._attention_bias import AttentionBiasCache +from fairseq2.models.transformer._block_mask import BlockMaskCache from fairseq2.models.transformer._sdpa._base import SDPA @@ -58,6 +59,7 @@ def forward( keys_layout: BatchLayout, values: Tensor, bias_cache: AttentionBiasCache, + block_mask_cache: BlockMaskCache, *, state_bag: IncrementalStateBag | None = None, ) -> Tensor: @@ -393,6 +395,7 @@ def forward( keys_layout: BatchLayout, values: Tensor, bias_cache: AttentionBiasCache, + block_mask_cache: BlockMaskCache, *, state_bag: IncrementalStateBag | None = None, ) -> Tensor: @@ -463,7 +466,14 @@ def forward( # attns: (N, S, H, V_h) # attn_weights: (N, H, S, S_kv) attns, attn_weights = self.sdpa( - q, seqs_layout, k, keys_layout, v, bias_cache, needs_weights=needs_weights + q, + seqs_layout, + k, + keys_layout, + v, + bias_cache, + block_mask_cache, + needs_weights=needs_weights, ) del q diff --git a/src/fairseq2/models/transformer/_sdpa/_base.py b/src/fairseq2/models/transformer/_sdpa/_base.py index 42d34795f..24c354835 100644 --- a/src/fairseq2/models/transformer/_sdpa/_base.py +++ b/src/fairseq2/models/transformer/_sdpa/_base.py @@ -13,6 +13,7 @@ from torch.nn import Module from fairseq2.models.transformer._attention_bias import AttentionBiasCache +from fairseq2.models.transformer._block_mask import BlockMaskCache from fairseq2.nn import BatchLayout @@ -28,6 +29,7 @@ def forward( keys_layout: BatchLayout, values: Tensor, bias_cache: AttentionBiasCache, + block_mask_cache: BlockMaskCache, *, needs_weights: bool = False, ) -> tuple[Tensor, Tensor | None]: diff --git a/src/fairseq2/models/transformer/_sdpa/_flash2.py b/src/fairseq2/models/transformer/_sdpa/_flash2.py index 74ac0a034..9c783e62f 100644 --- a/src/fairseq2/models/transformer/_sdpa/_flash2.py +++ b/src/fairseq2/models/transformer/_sdpa/_flash2.py @@ -32,6 +32,7 @@ CausalAttentionBias, IdentityBias, ) +from fairseq2.models.transformer._block_mask import BlockMaskCache from fairseq2.models.transformer._sdpa._base import SDPA @@ -58,6 +59,7 @@ def forward( keys_layout: BatchLayout, values: Tensor, bias_cache: AttentionBiasCache, + block_mask_cache: BlockMaskCache, *, needs_weights: bool = False, ) -> tuple[Tensor, Tensor | None]: diff --git a/src/fairseq2/models/transformer/_sdpa/_flash3.py b/src/fairseq2/models/transformer/_sdpa/_flash3.py index 7f1f792bb..5bbb23a68 100644 --- a/src/fairseq2/models/transformer/_sdpa/_flash3.py +++ b/src/fairseq2/models/transformer/_sdpa/_flash3.py @@ -30,6 +30,7 @@ CausalAttentionBias, IdentityBias, ) +from fairseq2.models.transformer._block_mask import BlockMaskCache from fairseq2.models.transformer._sdpa._base import SDPA @@ -56,6 +57,7 @@ def forward( keys_layout: BatchLayout, values: Tensor, bias_cache: AttentionBiasCache, + block_mask_cache: BlockMaskCache, *, needs_weights: bool = False, ) -> tuple[Tensor, Tensor | None]: diff --git a/src/fairseq2/models/transformer/_sdpa/_flex.py b/src/fairseq2/models/transformer/_sdpa/_flex.py index e75ead298..9a4f5fcac 100644 --- a/src/fairseq2/models/transformer/_sdpa/_flex.py +++ b/src/fairseq2/models/transformer/_sdpa/_flex.py @@ -8,18 +8,11 @@ from typing import Callable, TypeAlias, final -import torch from torch import Tensor -from torch.nn.attention.flex_attention import ( - BlockMask, - and_masks, - create_block_mask, - flex_attention, -) +from torch.nn.attention.flex_attention import flex_attention from typing_extensions import override -from fairseq2.device import Device -from fairseq2.error import NotSupportedError +from fairseq2.models.transformer._block_mask import BlockMaskCache from fairseq2.nn import BatchLayout # isort: split @@ -27,216 +20,24 @@ from fairseq2.models.transformer._attention_bias import ( AttentionBias, AttentionBiasCache, - CausalAttentionBias, - IdentityBias, ) from fairseq2.models.transformer._sdpa._base import SDPA MaskFunction: TypeAlias = Callable[[Tensor, Tensor, Tensor, Tensor], Tensor] -def _causal_mask_fn( - q_lens: Tensor, kv_lens: Tensor -) -> MaskFunction: - """Creates a causal mask function.""" - - def mask_fn(b: Tensor, h: Tensor, q_idx: Tensor, kv_idx: Tensor) -> Tensor: - # Get sequence lengths for this batch - q_len = q_lens[b] - kv_len = kv_lens[b] - - # Calculate diagonal offset - d = kv_len - q_len - - return q_idx >= kv_idx - d - - return mask_fn - - -def _sliding_window_causal_mask_fn( - window_size: int, - q_lens: Tensor, - kv_lens: Tensor, -) -> MaskFunction: - """Creates a sliding window causal mask functions.""" - - def mask_fn(b: Tensor, h: Tensor, q_idx: Tensor, kv_idx: Tensor) -> Tensor: - # Get sequence lengths for this batch - q_len = q_lens[b] - kv_len = kv_lens[b] - - # Calculate diagonal offset - d = kv_len - q_len - - # For window_size=1, only allow the exact diagonal position - if window_size == 1: - return q_idx == kv_idx - d - else: - # For larger windows, use the range logic - causal_mask = q_idx >= kv_idx - d - window_mask = q_idx >= kv_idx - d - window_size + 1 - return causal_mask & window_mask - - return mask_fn - - -def _offsets_to_doc_ids_tensor(offsets: Tensor) -> Tensor: - """Convert offsets to document IDs for packed sequences.""" - device = offsets.device - counts = offsets[1:] - offsets[:-1] - return torch.repeat_interleave( - torch.arange(len(counts), device=device, dtype=torch.int32), counts - ) - - -def _create_packed_mask_fn( - seq_begin_indices: Tensor, - keys_begin_indices: Tensor, - base_mask_fn: MaskFunction | None = None, -) -> MaskFunction: - """Creates a mask function for packed sequences using document-based masking.""" - # Create document IDs for queries and keys - query_doc_ids = _offsets_to_doc_ids_tensor(seq_begin_indices) - key_doc_ids = _offsets_to_doc_ids_tensor(keys_begin_indices) - - def packed_mask_fn(b: Tensor, h: Tensor, q_idx: Tensor, kv_idx: Tensor) -> Tensor: - # Check if query and key belong to the same document - same_doc = query_doc_ids[q_idx] == key_doc_ids[kv_idx] - - # Convert global indices to logical positions within documents - q_doc_id = query_doc_ids[q_idx] - kv_doc_id = key_doc_ids[kv_idx] - q_logical = q_idx - seq_begin_indices[q_doc_id] - kv_logical = kv_idx - keys_begin_indices[kv_doc_id] - - # Apply base mask (e.g., causal) to logical positions - if base_mask_fn is not None: - inner_mask = base_mask_fn(b, h, q_logical, kv_logical) - return same_doc & inner_mask - else: - return same_doc - - return packed_mask_fn - - -def _create_padding_mask_fn(q_lens: Tensor, kv_lens: Tensor) -> MaskFunction: - """Creates a padding mask function that masks out padding tokens.""" - - def padding_mask_fn(b: Tensor, h: Tensor, q_idx: Tensor, kv_idx: Tensor) -> Tensor: - q_valid = q_idx < q_lens[b] - kv_valid = kv_idx < kv_lens[b] - return q_valid & kv_valid - - return padding_mask_fn - - -def _create_composed_mask( - bias: AttentionBias, - seqs_layout: BatchLayout, - keys_layout: BatchLayout, - device: Device, - *, - packed: bool = False, -) -> BlockMask | None: - """Creates a composed mask using and_masks for combining multiple mask functions.""" - masks = [] - - if packed: - # For packed sequences, create the base mask function first - base_mask_fn = None - - # Add attention bias mask as base mask - if isinstance(bias, CausalAttentionBias): - attn_window_len = bias.attn_window_len - if attn_window_len is not None: - base_mask_fn = _sliding_window_causal_mask_fn( - attn_window_len, - seqs_layout.seq_lens_pt, - keys_layout.seq_lens_pt, - ) - else: - base_mask_fn = _causal_mask_fn( - seqs_layout.seq_lens_pt, - keys_layout.seq_lens_pt, - ) - elif not isinstance(bias, IdentityBias): - raise NotSupportedError(f"Unsupported bias type: {bias}") - - # Create the packed sequence mask that incorporates the base mask - packed_mask = _create_packed_mask_fn( - seqs_layout.seq_begin_indices_pt, - keys_layout.seq_begin_indices_pt, - base_mask_fn, - ) - masks.append(packed_mask) - else: - # Standard batch format - handle bias and padding separately - if isinstance(bias, CausalAttentionBias): - attn_window_len = bias.attn_window_len - if attn_window_len is not None: - masks.append( - _sliding_window_causal_mask_fn( - attn_window_len, - seqs_layout.seq_lens_pt, - keys_layout.seq_lens_pt, - ) - ) - else: - masks.append( - _causal_mask_fn( - seqs_layout.seq_lens_pt, - keys_layout.seq_lens_pt, - ) - ) - elif not isinstance(bias, IdentityBias): - raise NotSupportedError(f"Unsupported bias type: {bias}") - - # Add padding mask - if seqs_layout.padded or keys_layout.padded: - masks.append( - _create_padding_mask_fn(seqs_layout.seq_lens_pt, keys_layout.seq_lens_pt) - ) - - # Compose masks - mask_fn = None - if len(masks) == 0: - return None - elif len(masks) == 1: - mask_fn = masks[0] - else: - mask_fn = and_masks(*masks) - - if packed: - total_seq_len = int(seqs_layout.seq_begin_indices_pt[-1].item()) - total_keys_len = int(keys_layout.seq_begin_indices_pt[-1].item()) - batch_size = 1 - else: - total_seq_len = seqs_layout.max_seq_len - total_keys_len = keys_layout.max_seq_len - batch_size = len(seqs_layout.seq_lens) - - # Create the block mask - block_mask = create_block_mask( - mask_fn, - B=batch_size, - H=None, - Q_LEN=total_seq_len, - KV_LEN=total_keys_len, - device=str(device), - ) - return block_mask - - @final class FlexSDPA(SDPA): """Computes scaled dot-product attention using PyTorch's Flex Attention.""" bias: AttentionBias + dropout_p: float - def __init__(self, bias: AttentionBias) -> None: + def __init__(self, bias: AttentionBias, *, dropout_p: float = 0.0) -> None: super().__init__() self.bias = bias + self.dropout_p = dropout_p @override def forward( @@ -247,6 +48,7 @@ def forward( keys_layout: BatchLayout, values: Tensor, bias_cache: AttentionBiasCache, + block_mask_cache: BlockMaskCache, *, needs_weights: bool = False, ) -> tuple[Tensor, Tensor | None]: @@ -260,13 +62,12 @@ def forward( keys = keys.unsqueeze(0) values = values.unsqueeze(0) - # Create the composed block mask using and_mask - block_mask = _create_composed_mask( + # Create the composed block mask using and_masks + block_mask = block_mask_cache.get_or_create_mask( self.bias, seqs_layout, keys_layout, seqs.device, - packed=seqs_layout.packed, ) seqs = seqs.transpose(1, 2) @@ -289,3 +90,8 @@ def forward( attns = attns.squeeze(0) return attns, None + + @override + def extra_repr(self) -> str: + """:meta private:""" + return f"bias={self.bias}, dropout_p={self.dropout_p:G}" diff --git a/src/fairseq2/models/transformer/_sdpa/_naive.py b/src/fairseq2/models/transformer/_sdpa/_naive.py index 1d2373ae5..16c387bb2 100644 --- a/src/fairseq2/models/transformer/_sdpa/_naive.py +++ b/src/fairseq2/models/transformer/_sdpa/_naive.py @@ -22,6 +22,7 @@ AttentionBiasCache, maybe_get_attention_bias_tensor, ) +from fairseq2.models.transformer._block_mask import BlockMaskCache from fairseq2.models.transformer._sdpa._base import SDPA @@ -48,6 +49,7 @@ def forward( keys_layout: BatchLayout, values: Tensor, bias_cache: AttentionBiasCache, + block_mask_cache: BlockMaskCache, *, needs_weights: bool = False, ) -> tuple[Tensor, Tensor | None]: diff --git a/src/fairseq2/models/transformer/_sdpa/_relative.py b/src/fairseq2/models/transformer/_sdpa/_relative.py index de9fa0b6b..5cf432b29 100644 --- a/src/fairseq2/models/transformer/_sdpa/_relative.py +++ b/src/fairseq2/models/transformer/_sdpa/_relative.py @@ -28,6 +28,7 @@ AttentionBiasCache, maybe_get_attention_bias_tensor, ) +from fairseq2.models.transformer._block_mask import BlockMaskCache from fairseq2.models.transformer._sdpa._base import SDPA from fairseq2.models.transformer._sdpa._naive import ( naive_scaled_dot_product_attention, @@ -100,6 +101,7 @@ def forward( keys_layout: BatchLayout, values: Tensor, bias_cache: AttentionBiasCache, + block_mask_cache: BlockMaskCache, *, needs_weights: bool = False, ) -> tuple[Tensor, Tensor | None]: diff --git a/src/fairseq2/models/transformer/_sdpa/_shaw.py b/src/fairseq2/models/transformer/_sdpa/_shaw.py index 359df7834..1a503c8da 100644 --- a/src/fairseq2/models/transformer/_sdpa/_shaw.py +++ b/src/fairseq2/models/transformer/_sdpa/_shaw.py @@ -25,6 +25,7 @@ AttentionBiasCache, maybe_get_attention_bias_tensor, ) +from fairseq2.models.transformer._block_mask import BlockMaskCache from fairseq2.models.transformer._sdpa._base import SDPA from fairseq2.models.transformer._sdpa._naive import ( naive_scaled_dot_product_attention, @@ -113,6 +114,7 @@ def forward( keys_layout: BatchLayout, values: Tensor, bias_cache: AttentionBiasCache, + block_mask_cache: BlockMaskCache, *, needs_weights: bool = False, ) -> tuple[Tensor, Tensor | None]: diff --git a/src/fairseq2/models/transformer/_sdpa/_torch.py b/src/fairseq2/models/transformer/_sdpa/_torch.py index 8c661d0e7..c85dd632b 100644 --- a/src/fairseq2/models/transformer/_sdpa/_torch.py +++ b/src/fairseq2/models/transformer/_sdpa/_torch.py @@ -23,6 +23,7 @@ CausalAttentionBias, maybe_get_attention_bias_tensor, ) +from fairseq2.models.transformer._block_mask import BlockMaskCache from fairseq2.models.transformer._sdpa._base import SDPA @@ -49,6 +50,7 @@ def forward( keys_layout: BatchLayout, values: Tensor, bias_cache: AttentionBiasCache, + block_mask_cache: BlockMaskCache, *, needs_weights: bool = False, ) -> tuple[Tensor, Tensor | None]: diff --git a/src/fairseq2/models/transformer_lm/_decoder.py b/src/fairseq2/models/transformer_lm/_decoder.py index fe7effc86..d20cea244 100644 --- a/src/fairseq2/models/transformer_lm/_decoder.py +++ b/src/fairseq2/models/transformer_lm/_decoder.py @@ -21,6 +21,7 @@ # isort: split +from fairseq2.models.transformer._block_mask import BlockMaskCache from fairseq2.models.transformer_lm._decoder_layer import TransformerLMDecoderLayer @@ -132,11 +133,18 @@ def forward( state_bag: IncrementalStateBag | None = None, ) -> Tensor: attn_bias_cache = AttentionBiasCache() + block_mask_cache = BlockMaskCache() num_layers = len(self.layers) for layer_idx, layer in enumerate(self.layers): - seqs = layer(seqs, seqs_layout, attn_bias_cache, state_bag=state_bag) + seqs = layer( + seqs, + seqs_layout, + attn_bias_cache, + block_mask_cache, + state_bag=state_bag, + ) for hook in self._layer_hooks.values(): if not hook(layer_idx, seqs, seqs_layout, num_layers): diff --git a/src/fairseq2/models/transformer_lm/_decoder_layer.py b/src/fairseq2/models/transformer_lm/_decoder_layer.py index 0139cd870..704c86fd6 100644 --- a/src/fairseq2/models/transformer_lm/_decoder_layer.py +++ b/src/fairseq2/models/transformer_lm/_decoder_layer.py @@ -17,6 +17,7 @@ from fairseq2.device import Device from fairseq2.models.transformer import ( AttentionBiasCache, + BlockMaskCache, FeedForwardNetwork, MultiheadAttention, TransformerNormOrder, @@ -39,6 +40,7 @@ def forward( seqs: Tensor, seqs_layout: BatchLayout, attn_bias_cache: AttentionBiasCache, + block_mask_cache: BlockMaskCache, *, state_bag: IncrementalStateBag | None = None, ) -> Tensor: @@ -146,10 +148,13 @@ def forward( seqs: Tensor, seqs_layout: BatchLayout, attn_bias_cache: AttentionBiasCache, + block_mask_cache: BlockMaskCache, *, state_bag: IncrementalStateBag | None = None, ) -> Tensor: - seqs = self._forward_self_attn(seqs, seqs_layout, attn_bias_cache, state_bag) + seqs = self._forward_self_attn( + seqs, seqs_layout, attn_bias_cache, block_mask_cache, state_bag + ) seqs = self._forward_ffn(seqs) @@ -160,6 +165,7 @@ def _forward_self_attn( seqs: Tensor, seqs_layout: BatchLayout, attn_bias_cache: AttentionBiasCache, + block_mask_cache: BlockMaskCache, state_bag: IncrementalStateBag | None, ) -> Tensor: residual = seqs @@ -174,6 +180,7 @@ def _forward_self_attn( keys_layout=seqs_layout, values=seqs, bias_cache=attn_bias_cache, + block_mask_cache=block_mask_cache, state_bag=state_bag, ) diff --git a/tests/unit/models/transformer/sdpa/diagnostics.py b/tests/unit/models/transformer/sdpa/diagnostics.py deleted file mode 100644 index 4b667ee7f..000000000 --- a/tests/unit/models/transformer/sdpa/diagnostics.py +++ /dev/null @@ -1,353 +0,0 @@ -#!/usr/bin/env python3 -""" -Script to test materialize_block_mask_from_object with various block masks. -Tests different mask patterns and validates correctness using the existing functions. -""" - -import torch -from torch.nn.attention.flex_attention import create_block_mask, and_masks - -# Import the functions from the module (assuming they're available) -from fairseq2.models.transformer._sdpa._flex import ( - materialize_block_mask_from_object, - _causal_mask_fn, - _sliding_window_causal_mask_fn, - _create_padding_mask_fn, - _create_packed_mask_fn, - _offsets_to_doc_ids_tensor, -) - - -def create_reference_mask_from_function(mask_fn, shape, block_size): - """Create reference mask by applying mask function at block level""" - B, H, Q_LEN, KV_LEN = shape - Q_BLOCKS = (Q_LEN + block_size - 1) // block_size - KV_BLOCKS = (KV_LEN + block_size - 1) // block_size - - mask = torch.zeros(B, H, Q_LEN, KV_LEN, dtype=torch.bool) - - for b in range(B): - for h in range(H): - for q_block in range(Q_BLOCKS): - for kv_block in range(KV_BLOCKS): - if mask_fn( - torch.tensor(b), - torch.tensor(h), - torch.tensor(q_block), - torch.tensor(kv_block), - ): - # Fill the entire block - q_start = q_block * block_size - q_end = min(q_start + block_size, Q_LEN) - kv_start = kv_block * block_size - kv_end = min(kv_start + block_size, KV_LEN) - - mask[b, h, q_start:q_end, kv_start:kv_end] = True - - return mask - - -def compare_masks(materialized_mask, reference_mask, name): - """Compare materialized mask with reference mask""" - print(f"\n--- Comparison for {name} ---") - print(f"Materialized shape: {materialized_mask.shape}") - print(f"Reference shape: {reference_mask.shape}") - - if materialized_mask.shape == reference_mask.shape: - matches = torch.equal(materialized_mask, reference_mask) - print(f"Masks match: {matches}") - - if not matches: - diff = materialized_mask != reference_mask - print(f"Number of differences: {diff.sum().item()}") - print(f"Total elements: {materialized_mask.numel()}") - print( - f"Difference percentage: {100 * diff.sum().item() / materialized_mask.numel():.2f}%" - ) - - # Sparsity comparison - mat_sparsity = materialized_mask.sum().item() / materialized_mask.numel() - ref_sparsity = reference_mask.sum().item() / reference_mask.numel() - print(f"Materialized sparsity: {mat_sparsity:.4f}") - print(f"Reference sparsity: {ref_sparsity:.4f}") - - return matches - else: - print("Shape mismatch!") - return False - - -def test_basic_masks(): - """Test basic mask functions""" - print("=" * 80) - print("TESTING BASIC MASK FUNCTIONS") - print("=" * 80) - - B, H = 1, 1 - Q_LEN, KV_LEN = 128, 128 - BLOCK_SIZE = 32 - - test_cases = [ - ("Causal Mask", _causal_mask_fn), - ("Sliding Window (window=2)", _sliding_window_causal_mask_fn(2)), - ("Sliding Window (window=5)", _sliding_window_causal_mask_fn(5)), - ] - - for name, mask_fn in test_cases: - print(f"\nTesting: {name}") - print("-" * 40) - - # Create block mask - block_mask = create_block_mask( - mask_fn, B=B, H=H, Q_LEN=Q_LEN, KV_LEN=KV_LEN, BLOCK_SIZE=BLOCK_SIZE - ) - - print(f"Block mask shape: {block_mask.shape}") - print(f"Block mask sparsity: {block_mask.sparsity():.2f}%") - - # Materialize using our function - materialized = materialize_block_mask_from_object(block_mask) - - # Create reference - reference = create_reference_mask_from_function( - mask_fn, (B, H, Q_LEN, KV_LEN), BLOCK_SIZE - ) - - # Compare - matches = compare_masks(materialized, reference, name) - - -def test_padding_masks(): - """Test padding mask functionality""" - print("=" * 80) - print("TESTING PADDING MASKS") - print("=" * 80) - - B, H = 2, 1 - MAX_SEQ_LEN = 96 - BLOCK_SIZE = 32 - - # Create different sequence lengths for each batch - seq_lens = torch.tensor([64, 80]) # Different lengths per batch - value_seq_lens = torch.tensor([64, 80]) - - padding_mask_fn = _create_padding_mask_fn(seq_lens, value_seq_lens) - - print("Testing: Padding Mask") - print(f"Sequence lengths: {seq_lens.tolist()}") - - # Create block mask - block_mask = create_block_mask( - padding_mask_fn, - B=B, - H=H, - Q_LEN=MAX_SEQ_LEN, - KV_LEN=MAX_SEQ_LEN, - BLOCK_SIZE=BLOCK_SIZE, - ) - - print(f"Block mask shape: {block_mask.shape}") - - # Materialize - materialized = materialize_block_mask_from_object(block_mask) - - print(f"Materialized shape: {materialized.shape}") - - # Check that padding is correctly masked - for b in range(B): - seq_len = seq_lens[b].item() - val_len = value_seq_lens[b].item() - - # Check that valid positions are potentially unmasked - valid_region = materialized[b, 0, :seq_len, :val_len] - invalid_q = materialized[b, 0, seq_len:, :] - invalid_kv = materialized[b, 0, :, val_len:] - - print( - f"Batch {b}: seq_len={seq_len}, valid_region_any={valid_region.any().item()}" - ) - print( - f"Batch {b}: invalid_q_any={invalid_q.any().item()}, invalid_kv_any={invalid_kv.any().item()}" - ) - - -def test_combined_masks(): - """Test combining multiple masks using and_masks""" - print("=" * 80) - print("TESTING COMBINED MASKS") - print("=" * 80) - - B, H = 1, 1 - Q_LEN, KV_LEN = 96, 96 - BLOCK_SIZE = 32 - - # Create individual masks - causal_fn = _causal_mask_fn - sliding_fn = _sliding_window_causal_mask_fn(3) - - # Test individual masks - print("Testing individual masks...") - - causal_block = create_block_mask( - causal_fn, B=B, H=H, Q_LEN=Q_LEN, KV_LEN=KV_LEN, BLOCK_SIZE=BLOCK_SIZE - ) - sliding_block = create_block_mask( - sliding_fn, B=B, H=H, Q_LEN=Q_LEN, KV_LEN=KV_LEN, BLOCK_SIZE=BLOCK_SIZE - ) - - causal_mat = materialize_block_mask_from_object(causal_block) - sliding_mat = materialize_block_mask_from_object(sliding_block) - - print(f"Causal sparsity: {causal_mat.sum().item() / causal_mat.numel():.4f}") - print(f"Sliding sparsity: {sliding_mat.sum().item() / sliding_mat.numel():.4f}") - - # Test combined mask - print("\nTesting combined mask...") - combined_fn = and_masks(causal_fn, sliding_fn) - combined_block = create_block_mask( - combined_fn, B=B, H=H, Q_LEN=Q_LEN, KV_LEN=KV_LEN, BLOCK_SIZE=BLOCK_SIZE - ) - combined_mat = materialize_block_mask_from_object(combined_block) - - print(f"Combined sparsity: {combined_mat.sum().item() / combined_mat.numel():.4f}") - - # Verify that combined mask is intersection of individual masks - expected_combined = causal_mat & sliding_mat - matches = torch.equal(combined_mat, expected_combined) - print(f"Combined mask matches intersection: {matches}") - - -def test_packed_sequences(): - """Test packed sequence masks""" - print("=" * 80) - print("TESTING PACKED SEQUENCES") - print("=" * 80) - - # Create packed sequence layout - # Simulate 3 documents of lengths [20, 30, 25] - doc_lengths = [20, 30, 25] - seq_begin_indices = torch.tensor( - [0] + [sum(doc_lengths[: i + 1]) for i in range(len(doc_lengths))] - ) - total_len = sum(doc_lengths) - - print(f"Document lengths: {doc_lengths}") - print(f"Begin indices: {seq_begin_indices.tolist()}") - print(f"Total length: {total_len}") - - BLOCK_SIZE = 16 - - # Test packed mask without base mask - print("\nTesting packed mask (no base mask)...") - packed_fn = _create_packed_mask_fn(seq_begin_indices, seq_begin_indices, None) - - packed_block = create_block_mask( - packed_fn, B=1, H=1, Q_LEN=total_len, KV_LEN=total_len, BLOCK_SIZE=BLOCK_SIZE - ) - - packed_mat = materialize_block_mask_from_object(packed_block) - - print(f"Packed mask shape: {packed_mat.shape}") - print(f"Packed mask sparsity: {packed_mat.sum().item() / packed_mat.numel():.4f}") - - # Verify document boundaries - doc_ids = _offsets_to_doc_ids_tensor(seq_begin_indices) - print(f"Document IDs: {doc_ids[:10].tolist()}...{doc_ids[-10:].tolist()}") - - # Test packed mask with causal base mask - print("\nTesting packed mask with causal base...") - packed_causal_fn = _create_packed_mask_fn( - seq_begin_indices, seq_begin_indices, _causal_mask_fn - ) - - packed_causal_block = create_block_mask( - packed_causal_fn, - B=1, - H=1, - Q_LEN=total_len, - KV_LEN=total_len, - BLOCK_SIZE=BLOCK_SIZE, - ) - - packed_causal_mat = materialize_block_mask_from_object(packed_causal_block) - - print( - f"Packed causal sparsity: {packed_causal_mat.sum().item() / packed_causal_mat.numel():.4f}" - ) - - -def test_edge_cases(): - """Test edge cases and different configurations""" - print("=" * 80) - print("TESTING EDGE CASES") - print("=" * 80) - - test_configs = [ - ( - "Small sequence", - {"B": 1, "H": 1, "Q_LEN": 32, "KV_LEN": 32, "BLOCK_SIZE": 16}, - ), - ("Non-square", {"B": 1, "H": 1, "Q_LEN": 64, "KV_LEN": 96, "BLOCK_SIZE": 16}), - ( - "Multiple heads", - {"B": 1, "H": 4, "Q_LEN": 64, "KV_LEN": 64, "BLOCK_SIZE": 16}, - ), - ( - "Multiple batch", - {"B": 3, "H": 2, "Q_LEN": 48, "KV_LEN": 48, "BLOCK_SIZE": 16}, - ), - ( - "Large block size", - {"B": 1, "H": 1, "Q_LEN": 64, "KV_LEN": 64, "BLOCK_SIZE": 64}, - ), - ] - - for name, config in test_configs: - print(f"\nTesting: {name}") - print(f"Config: {config}") - - block_mask = create_block_mask( - _causal_mask_fn, - B=config["B"], - H=config["H"], - Q_LEN=config["Q_LEN"], - KV_LEN=config["KV_LEN"], - BLOCK_SIZE=config["BLOCK_SIZE"], - ) - - materialized = materialize_block_mask_from_object(block_mask) - reference = create_reference_mask_from_function( - _causal_mask_fn, - (config["B"], config["H"], config["Q_LEN"], config["KV_LEN"]), - config["BLOCK_SIZE"], - ) - - matches = compare_masks(materialized, reference, name) - print(f"✓ Passed: {matches}") - - -def main(): - """Run all tests""" - print("Starting Block Mask Materialization Tests") - print("=" * 80) - - try: - test_basic_masks() - test_padding_masks() - test_combined_masks() - test_packed_sequences() - test_edge_cases() - - print("\n" + "=" * 80) - print("ALL TESTS COMPLETED") - print("=" * 80) - - except Exception as e: - print(f"\nTest failed with error: {e}") - import traceback - - traceback.print_exc() - - -if __name__ == "__main__": - main() diff --git a/tests/unit/models/transformer/sdpa/test_flex_mask_fns.py b/tests/unit/models/transformer/sdpa/test_flex_mask_fns.py deleted file mode 100644 index 5081114a1..000000000 --- a/tests/unit/models/transformer/sdpa/test_flex_mask_fns.py +++ /dev/null @@ -1,165 +0,0 @@ -import torch -from unittest.mock import patch - -from fairseq2.models.transformer._sdpa._flex import ( - _causal_mask_fn, - _sliding_window_causal_mask_fn, - _offsets_to_doc_ids_tensor, - _create_packed_mask_fn, - _create_padding_mask_fn, - _dropout_mask_fn, -) - - -def test_causal_mask(): - b, h = torch.tensor([0]), torch.tensor([0]) - - # q_idx >= kv_idx should be True - assert _causal_mask_fn(b, h, torch.tensor([2]), torch.tensor([1])).item() - assert _causal_mask_fn(b, h, torch.tensor([2]), torch.tensor([2])).item() - - # q_idx < kv_idx should be False - assert not _causal_mask_fn(b, h, torch.tensor([1]), torch.tensor([2])).item() - - -def test_sliding_window_mask(): - mask_fn = _sliding_window_causal_mask_fn(window_size=2) - b, h = torch.tensor([0]), torch.tensor([0]) - - # Within window and causal - assert mask_fn(b, h, torch.tensor([3]), torch.tensor([2])).item() # distance=1 - assert mask_fn(b, h, torch.tensor([3]), torch.tensor([1])).item() # distance=2 - - # Outside window - assert not mask_fn(b, h, torch.tensor([4]), torch.tensor([1])).item() # distance=3 - - # Non-causal - assert not mask_fn(b, h, torch.tensor([1]), torch.tensor([2])).item() - - -def test_offsets_to_doc_ids(): - offsets = torch.tensor([0, 3, 5]) # Two docs: [0,1,2] and [3,4] - doc_ids = _offsets_to_doc_ids_tensor(offsets) - expected = torch.tensor([0, 0, 0, 1, 1], dtype=torch.int32) - assert torch.equal(doc_ids, expected) - - -def test_packed_mask(): - seq_offsets = torch.tensor([0, 3, 6]) # Two docs of length 3 each - key_offsets = torch.tensor([0, 3, 6]) - mask_fn = _create_packed_mask_fn(seq_offsets, key_offsets) - - b, h = torch.tensor([0]), torch.tensor([0]) - - # Same document - assert mask_fn(b, h, torch.tensor([1]), torch.tensor([0])).item() - - # Different documents - assert not mask_fn(b, h, torch.tensor([4]), torch.tensor([1])).item() - - -def test_packed_mask_with_base_mask(): - seq_offsets = torch.tensor([0, 3, 6]) - key_offsets = torch.tensor([0, 3, 6]) - mask_fn = _create_packed_mask_fn(seq_offsets, key_offsets, _causal_mask_fn) - - b, h = torch.tensor([0]), torch.tensor([0]) - - # Same doc, causal - assert mask_fn(b, h, torch.tensor([2]), torch.tensor([1])).item() - - # Same doc, non-causal - assert not mask_fn(b, h, torch.tensor([1]), torch.tensor([2])).item() - - -def test_padding_mask(): - seq_lens = torch.tensor([3, 4]) - value_seq_lens = torch.tensor([2, 4]) - mask_fn = _create_padding_mask_fn(seq_lens, value_seq_lens) - - h = torch.tensor([0]) - - # Valid positions - assert mask_fn(torch.tensor([0]), h, torch.tensor([2]), torch.tensor([1])).item() - - # Invalid query position - assert not mask_fn( - torch.tensor([0]), h, torch.tensor([3]), torch.tensor([1]) - ).item() - - # Invalid key position - assert not mask_fn( - torch.tensor([0]), h, torch.tensor([1]), torch.tensor([2]) - ).item() - - -def test_dropout_mask(): - # No mask when not training or dropout_p=0 - assert _dropout_mask_fn(0.5, training=False) is None - assert _dropout_mask_fn(0.0, training=True) is None - - # Returns callable when training=True and dropout_p > 0 - mask_fn = _dropout_mask_fn(0.5, training=True) - assert callable(mask_fn) - - -@patch("torch.rand") -def test_dropout_mask_behavior(mock_rand): - mask_fn = _dropout_mask_fn(0.3, training=True) - b, h, q, kv = ( - torch.tensor([0]), - torch.tensor([0]), - torch.tensor([0]), - torch.tensor([0]), - ) - - # Below threshold - should mask (False) - mock_rand.return_value = torch.tensor([0.2]) - assert not mask_fn(b, h, q, kv).item() - - # Above threshold - should not mask (True) - mock_rand.return_value = torch.tensor([0.5]) - assert mask_fn(b, h, q, kv).item() - - -def test_causal_mask_large_indices(): - b, h = torch.tensor([0]), torch.tensor([0]) - - # Test with larger indices - q_indices = torch.tensor([100, 50, 75]) - kv_indices = torch.tensor([50, 100, 75]) - - result = _causal_mask_fn(b.expand(3), h.expand(3), q_indices, kv_indices) - expected = torch.tensor([True, False, True]) # [100>=50, 50>=100, 75>=75] - assert torch.equal(result, expected) - - -def test_packed_mask_unequal_offsets(): - # Different sequence and key lengths - seq_offsets = torch.tensor([0, 3, 5]) # [doc0:3, doc1:2] - key_offsets = torch.tensor([0, 2, 6]) # [doc0:2, doc1:4] - mask_fn = _create_packed_mask_fn(seq_offsets, key_offsets) - - b, h = torch.tensor([0]), torch.tensor([0]) - - # Same document, different lengths - assert mask_fn( - b, h, torch.tensor([1]), torch.tensor([0]) - ).item() # doc0 query to doc0 key - assert mask_fn( - b, h, torch.tensor([3]), torch.tensor([2]) - ).item() # doc1 query to doc1 key - - -def test_sliding_window_vectorized(): - mask_fn = _sliding_window_causal_mask_fn(window_size=2) - b = torch.tensor([0, 0, 0, 0]) - h = torch.tensor([0, 0, 0, 0]) - q_idx = torch.tensor([3, 3, 3, 2]) - kv_idx = torch.tensor([1, 2, 3, 4]) # distances: [2, 1, 0, -2] - - result = mask_fn(b, h, q_idx, kv_idx) - expected = torch.tensor( - [True, True, True, False] - ) # [causal+window, causal+window, causal+window, non-causal] - assert torch.equal(result, expected) diff --git a/tests/unit/models/transformer/test_attention.py b/tests/unit/models/transformer/test_attention.py index dd1b84f0b..41389a6f1 100644 --- a/tests/unit/models/transformer/test_attention.py +++ b/tests/unit/models/transformer/test_attention.py @@ -15,6 +15,7 @@ from fairseq2.models.transformer import ( AttentionBias, AttentionBiasCache, + BlockMaskCache, CausalAttentionBias, FlexSDPA, IdentityBias, @@ -118,6 +119,7 @@ def random_tensor(*args: int) -> Tensor: k_layout = BatchLayout(source_shape, seq_lens=None, device=device) bias_cache = AttentionBiasCache() + block_mask_cache = BlockMaskCache() return { "seqs": q, @@ -126,6 +128,7 @@ def random_tensor(*args: int) -> Tensor: "keys_layout": k_layout, "values": v, "bias_cache": bias_cache, + "block_mask_cache": block_mask_cache, } @@ -171,6 +174,7 @@ def test_variable_sized_attention(self, q_dim: int, k_dim: int | None) -> None: keys_layout = BatchLayout.of(keys) bias_cache = AttentionBiasCache() + block_mask_cache = BlockMaskCache() result = mha( seqs, @@ -179,6 +183,7 @@ def test_variable_sized_attention(self, q_dim: int, k_dim: int | None) -> None: keys_layout, values=keys, bias_cache=bias_cache, + block_mask_cache=block_mask_cache, ) assert result.shape == seqs.shape @@ -283,6 +288,7 @@ def random_tensor(*args: int) -> Tensor: k_layout = BatchLayout(source_shape, seq_lens=None, device=device) bias_cache = AttentionBiasCache() + block_mask_cache = BlockMaskCache() return { "seqs": q, @@ -291,4 +297,5 @@ def random_tensor(*args: int) -> Tensor: "keys_layout": k_layout, "values": v, "bias_cache": bias_cache, + "block_mask_cache": block_mask_cache, } From 4ac7e4b68556ae8ecd0d782be8ac596f562e12f4 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Sat, 14 Jun 2025 21:45:14 +0000 Subject: [PATCH 08/15] lint --- src/fairseq2/models/transformer/_block_mask.py | 6 +++--- src/fairseq2/models/transformer/_encoder_layer.py | 1 + 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/fairseq2/models/transformer/_block_mask.py b/src/fairseq2/models/transformer/_block_mask.py index 77263ab88..da634d488 100644 --- a/src/fairseq2/models/transformer/_block_mask.py +++ b/src/fairseq2/models/transformer/_block_mask.py @@ -240,11 +240,11 @@ def __hash__(self) -> int: class BlockMaskCache: """ - Cache for block masks to avoid recomputation across layers and possibly training + Cache for block masks to avoid recomputation across layers and (possibly) training steps. """ - def __init__(self): + def __init__(self) -> None: self._cache: dict[BlockMaskCacheKey, BlockMask | None] = {} def get_or_create_mask( @@ -297,6 +297,6 @@ def get_or_create_mask( return block_mask - def clear(self): + def clear(self) -> None: """Clear the cache.""" self._cache.clear() diff --git a/src/fairseq2/models/transformer/_encoder_layer.py b/src/fairseq2/models/transformer/_encoder_layer.py index 9f0eed755..bcb3cb84f 100644 --- a/src/fairseq2/models/transformer/_encoder_layer.py +++ b/src/fairseq2/models/transformer/_encoder_layer.py @@ -40,6 +40,7 @@ def forward( seqs: Tensor, seqs_layout: BatchLayout, attn_bias_cache: AttentionBiasCache, + block_mask_cache: BlockMaskCache, ) -> Tensor: """ :param seqs: The sequences to process. *Shape:* :math:`(N,S,M)`, where From 8798b2b90c6ae5c1031a649b5fac6e867a063be0 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Sat, 14 Jun 2025 23:11:16 +0000 Subject: [PATCH 09/15] test fix, LRU cache impl --- .../models/transformer/_block_mask.py | 72 +++++++++---------- src/fairseq2/models/transformer/_encoder.py | 4 +- 2 files changed, 35 insertions(+), 41 deletions(-) diff --git a/src/fairseq2/models/transformer/_block_mask.py b/src/fairseq2/models/transformer/_block_mask.py index da634d488..dac24bd62 100644 --- a/src/fairseq2/models/transformer/_block_mask.py +++ b/src/fairseq2/models/transformer/_block_mask.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Callable, TypeAlias +from typing import Callable, OrderedDict, TypeAlias import torch from torch import Tensor @@ -23,7 +23,7 @@ MaskFunction: TypeAlias = Callable[[Tensor, Tensor, Tensor, Tensor], Tensor] -BLOCK_MASK_CACHE_MAX_SIZE = 1000 +BLOCK_MASK_CACHE_MAX_SIZE = 250 def _causal_mask_fn(q_lens: Tensor, kv_lens: Tensor) -> MaskFunction: @@ -218,22 +218,16 @@ def _create_composed_mask( class BlockMaskCacheKey: """Key for caching block masks.""" - bias_type: str batch_size: int - seq_len: int + seqs_len: int keys_len: int - packed: bool - attn_window_len: int | None = None def __hash__(self) -> int: return hash( ( - self.bias_type, self.batch_size, - self.seq_len, + self.seqs_len, self.keys_len, - self.packed, - self.attn_window_len, ) ) @@ -241,61 +235,59 @@ def __hash__(self) -> int: class BlockMaskCache: """ Cache for block masks to avoid recomputation across layers and (possibly) training - steps. + steps. We assume that the cache is not shared across different models or training + runs and therefore only needs to hash on sequence lengths and batch sizes. """ def __init__(self) -> None: - self._cache: dict[BlockMaskCacheKey, BlockMask | None] = {} + self._cache: OrderedDict[BlockMaskCacheKey, BlockMask | None] = OrderedDict() - def get_or_create_mask( + def _create_cache_key( self, - bias: AttentionBias, seqs_layout: BatchLayout, keys_layout: BatchLayout, - device: Device, - ) -> BlockMask | None: - """Get cached mask or create new one.""" - - # Create cache key - bias_type = type(bias).__name__ - attn_window_len = None - if isinstance(bias, CausalAttentionBias): - attn_window_len = bias.attn_window_len - + ) -> BlockMaskCacheKey: + """Create a cache key based on sequence / key-value lengths and batch sizes.""" if seqs_layout.packed: batch_size = 1 - seq_len = int(seqs_layout.seq_begin_indices[-1]) + seqs_len = int(seqs_layout.seq_begin_indices[-1]) keys_len = int(keys_layout.seq_begin_indices[-1]) else: batch_size = len(seqs_layout.seq_lens) - seq_len = seqs_layout.max_seq_len + seqs_len = seqs_layout.max_seq_len keys_len = keys_layout.max_seq_len cache_key = BlockMaskCacheKey( - bias_type=bias_type, batch_size=batch_size, - seq_len=seq_len, + seqs_len=seqs_len, keys_len=keys_len, - packed=seqs_layout.packed, - attn_window_len=attn_window_len, ) + return cache_key + + def get_or_create_mask( + self, + bias: AttentionBias, + seqs_layout: BatchLayout, + keys_layout: BatchLayout, + device: Device, + ) -> BlockMask | None: + """Get cached mask or create new one.""" + cache_key = self._create_cache_key(seqs_layout, keys_layout) - # Check cache first if cache_key in self._cache: + # Move to end (most recently used) + self._cache.move_to_end(cache_key) return self._cache[cache_key] # Create new mask - block_mask = _create_composed_mask( - bias, - seqs_layout, - keys_layout, - device, - ) + mask = _create_composed_mask(bias, seqs_layout, keys_layout, device) - if len(self._cache) < BLOCK_MASK_CACHE_MAX_SIZE: - self._cache[cache_key] = block_mask + # Add to cache and evict if needed + self._cache[cache_key] = mask + if len(self._cache) > BLOCK_MASK_CACHE_MAX_SIZE: + self._cache.popitem(last=False) - return block_mask + return mask def clear(self) -> None: """Clear the cache.""" diff --git a/src/fairseq2/models/transformer/_encoder.py b/src/fairseq2/models/transformer/_encoder.py index 191c59f7c..d0e001b1c 100644 --- a/src/fairseq2/models/transformer/_encoder.py +++ b/src/fairseq2/models/transformer/_encoder.py @@ -25,6 +25,7 @@ # isort: split from fairseq2.models.transformer._attention_bias import AttentionBiasCache +from fairseq2.models.transformer._block_mask import BlockMaskCache from fairseq2.models.transformer._encoder_layer import TransformerEncoderLayer @@ -149,11 +150,12 @@ def forward(self, seqs: Tensor, seqs_layout: BatchLayout) -> Tensor: ) attn_bias_cache = AttentionBiasCache() + block_mask_cache = BlockMaskCache() num_layers = len(self.layers) for layer_idx, (layer, drop) in enumerate(self._drop_iter()): - layer_output = layer(seqs, seqs_layout, attn_bias_cache) + layer_output = layer(seqs, seqs_layout, attn_bias_cache, block_mask_cache) if drop: seqs = _record_drop_for_backward(seqs, layer_output) From bf93cfda81499f3a9133e6a5125c2545d624e4bc Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Sun, 15 Jun 2025 00:07:29 +0000 Subject: [PATCH 10/15] torch.compile flex attention --- src/fairseq2/models/transformer/_block_mask.py | 2 +- src/fairseq2/models/transformer/_sdpa/_flex.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/fairseq2/models/transformer/_block_mask.py b/src/fairseq2/models/transformer/_block_mask.py index dac24bd62..b831e63b7 100644 --- a/src/fairseq2/models/transformer/_block_mask.py +++ b/src/fairseq2/models/transformer/_block_mask.py @@ -23,7 +23,7 @@ MaskFunction: TypeAlias = Callable[[Tensor, Tensor, Tensor, Tensor], Tensor] -BLOCK_MASK_CACHE_MAX_SIZE = 250 +BLOCK_MASK_CACHE_MAX_SIZE = 100 def _causal_mask_fn(q_lens: Tensor, kv_lens: Tensor) -> MaskFunction: diff --git a/src/fairseq2/models/transformer/_sdpa/_flex.py b/src/fairseq2/models/transformer/_sdpa/_flex.py index 9a4f5fcac..f5b9e4bc7 100644 --- a/src/fairseq2/models/transformer/_sdpa/_flex.py +++ b/src/fairseq2/models/transformer/_sdpa/_flex.py @@ -8,6 +8,7 @@ from typing import Callable, TypeAlias, final +import torch from torch import Tensor from torch.nn.attention.flex_attention import flex_attention from typing_extensions import override @@ -25,6 +26,8 @@ MaskFunction: TypeAlias = Callable[[Tensor, Tensor, Tensor, Tensor], Tensor] +flex_attention = torch.compile(flex_attention, dynamic=False) + @final class FlexSDPA(SDPA): From fb5e157f41c9eed4ef3434c728993b9148f94953 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Sun, 15 Jun 2025 00:40:09 +0000 Subject: [PATCH 11/15] lint --- src/fairseq2/models/transformer/_block_mask.py | 6 ++++++ tests/unit/models/transformer/test_attention.py | 10 +++++----- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/src/fairseq2/models/transformer/_block_mask.py b/src/fairseq2/models/transformer/_block_mask.py index b831e63b7..fb78850f5 100644 --- a/src/fairseq2/models/transformer/_block_mask.py +++ b/src/fairseq2/models/transformer/_block_mask.py @@ -1,3 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + from dataclasses import dataclass from typing import Callable, OrderedDict, TypeAlias diff --git a/tests/unit/models/transformer/test_attention.py b/tests/unit/models/transformer/test_attention.py index 41389a6f1..85c8eb4ef 100644 --- a/tests/unit/models/transformer/test_attention.py +++ b/tests/unit/models/transformer/test_attention.py @@ -93,8 +93,8 @@ def random_tensor(*args: int) -> Tensor: k = random_tensor(total_source_len, num_heads, k_size) v = random_tensor(total_source_len, num_heads, v_size) - target_shape = (total_target_len,) - source_shape = (total_source_len,) + target_shape: tuple[int, ...] = (total_target_len,) + source_shape: tuple[int, ...] = (total_source_len,) q_layout = BatchLayout( target_shape, seq_lens=[2, 3], packed=True, device=device @@ -239,7 +239,7 @@ def test_flex_sdpa( assert_close(attn1, attn2) @staticmethod - def _get_sdpa_args(use_padding: bool, use_packing: bool) -> dict[str, object]: + def _get_sdpa_args(use_padding: bool, use_packing: bool) -> dict[str, Any]: batch_size = 2 num_heads = 4 @@ -262,8 +262,8 @@ def random_tensor(*args: int) -> Tensor: k = random_tensor(total_source_len, num_heads, k_size) v = random_tensor(total_source_len, num_heads, v_size) - target_shape = (total_target_len,) - source_shape = (total_source_len,) + target_shape: tuple[int, ...] = (total_target_len,) + source_shape: tuple[int, ...] = (total_source_len,) q_layout = BatchLayout( target_shape, seq_lens=[2, 3], packed=True, device=device From 26403fa03cd11c011b8771b1a0e63a2ac8f923ac Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Sun, 15 Jun 2025 01:12:44 +0000 Subject: [PATCH 12/15] revert torch.compile flex attn for now --- src/fairseq2/models/transformer/_sdpa/_flex.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/fairseq2/models/transformer/_sdpa/_flex.py b/src/fairseq2/models/transformer/_sdpa/_flex.py index f5b9e4bc7..96358dc00 100644 --- a/src/fairseq2/models/transformer/_sdpa/_flex.py +++ b/src/fairseq2/models/transformer/_sdpa/_flex.py @@ -26,7 +26,9 @@ MaskFunction: TypeAlias = Callable[[Tensor, Tensor, Tensor, Tensor], Tensor] -flex_attention = torch.compile(flex_attention, dynamic=False) +# TODO: Hitting some torch.compile issues with this enabled for different builds. +# Commenting out for now until we can investigate. +# flex_attention = torch.compile(flex_attention, dynamic=False) @final From 25860ce4b41a8cf448cc3aa33466092479f38a2a Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Mon, 16 Jun 2025 18:09:30 +0000 Subject: [PATCH 13/15] fixing sliding window attn (with caveat) --- .../models/transformer/_block_mask.py | 16 +- .../unit/models/transformer/test_attention.py | 1 - .../models/transformer/test_block_mask.py | 302 ++++++++++++++++++ 3 files changed, 310 insertions(+), 9 deletions(-) create mode 100644 tests/unit/models/transformer/test_block_mask.py diff --git a/src/fairseq2/models/transformer/_block_mask.py b/src/fairseq2/models/transformer/_block_mask.py index fb78850f5..ccda5ee4f 100644 --- a/src/fairseq2/models/transformer/_block_mask.py +++ b/src/fairseq2/models/transformer/_block_mask.py @@ -63,14 +63,14 @@ def mask_fn(b: Tensor, h: Tensor, q_idx: Tensor, kv_idx: Tensor) -> Tensor: # Calculate diagonal offset d = kv_len - q_len - # For window_size=1, only allow the exact diagonal position - if window_size == 1: - return q_idx == kv_idx - d - else: - # For larger windows, use the range logic - causal_mask = q_idx >= kv_idx - d - window_mask = q_idx >= kv_idx - d - window_size + 1 - return causal_mask & window_mask + # Apply both causal and window constraint + # NOTE: There is a incompatibility here with our _create_causal_bias_tensor + # function, since that requires data-dependent control flow via the q_len, + # which is not currently supported by torch.compile. This is a simplified + # version of that logic, which won't match exactly in all cases. + causal_mask = q_idx >= kv_idx - d + window_mask = kv_idx - d >= q_idx - window_size + 1 + return causal_mask & window_mask return mask_fn diff --git a/tests/unit/models/transformer/test_attention.py b/tests/unit/models/transformer/test_attention.py index 85c8eb4ef..38b41bbd5 100644 --- a/tests/unit/models/transformer/test_attention.py +++ b/tests/unit/models/transformer/test_attention.py @@ -201,7 +201,6 @@ class TestFlexScaledDotProductAttention: (False, True, False, True, None), (False, True, True, True, None), (False, False, True, True, 1), - (False, False, True, True, 2), (True, False, True, True, 1), (False, True, True, True, 1), (False, False, False, False, None), diff --git a/tests/unit/models/transformer/test_block_mask.py b/tests/unit/models/transformer/test_block_mask.py new file mode 100644 index 000000000..f12281b52 --- /dev/null +++ b/tests/unit/models/transformer/test_block_mask.py @@ -0,0 +1,302 @@ +import pytest +import torch +from unittest.mock import Mock, patch + +from fairseq2.models.transformer._attention_bias import IdentityBias +from fairseq2.device import Device + +from fairseq2.models.transformer._block_mask import ( + _causal_mask_fn, + _sliding_window_causal_mask_fn, + _offsets_to_doc_ids_tensor, + _create_packed_mask_fn, + _create_padding_mask_fn, + _create_composed_mask, + BlockMaskCacheKey, + BlockMaskCache, +) + + +class TestMaskFunctions: + """Test individual mask functions.""" + + def test_causal_mask_fn(self): + """Test causal mask function behavior.""" + q_lens = torch.tensor([3, 2]) + kv_lens = torch.tensor([3, 2]) + mask_fn = _causal_mask_fn(q_lens, kv_lens) + + # Test for batch 0 + b = torch.tensor(0) + h = torch.tensor(0) + + # Test diagonal and upper triangular positions + assert mask_fn(b, h, torch.tensor(0), torch.tensor(0)) == True + assert mask_fn(b, h, torch.tensor(1), torch.tensor(0)) == True + assert mask_fn(b, h, torch.tensor(1), torch.tensor(1)) == True + assert mask_fn(b, h, torch.tensor(0), torch.tensor(1)) == False + assert mask_fn(b, h, torch.tensor(2), torch.tensor(1)) == True + + def test_sliding_window_causal_mask_fn(self): + """Test sliding window causal mask function.""" + q_lens = torch.tensor([4]) + kv_lens = torch.tensor([4]) + window_size = 2 + mask_fn = _sliding_window_causal_mask_fn(window_size, q_lens, kv_lens) + + b = torch.tensor(0) + h = torch.tensor(0) + + # Test window behavior + assert mask_fn(b, h, torch.tensor(2), torch.tensor(1)) == True # Within window + assert mask_fn(b, h, torch.tensor(2), torch.tensor(2)) == True # Diagonal + assert mask_fn(b, h, torch.tensor(3), torch.tensor(1)) == False # Outside window + assert mask_fn(b, h, torch.tensor(1), torch.tensor(2)) == False # Future token + + def test_sliding_window_size_one(self): + """Test sliding window with size 1 (diagonal only).""" + q_lens = torch.tensor([3]) + kv_lens = torch.tensor([3]) + mask_fn = _sliding_window_causal_mask_fn(1, q_lens, kv_lens) + + b = torch.tensor(0) + h = torch.tensor(0) + + # Only diagonal should be True + assert mask_fn(b, h, torch.tensor(0), torch.tensor(0)) == True + assert mask_fn(b, h, torch.tensor(1), torch.tensor(1)) == True + assert mask_fn(b, h, torch.tensor(1), torch.tensor(0)) == False + assert mask_fn(b, h, torch.tensor(0), torch.tensor(1)) == False + + def test_offsets_to_doc_ids_tensor(self): + """Test conversion of offsets to document IDs.""" + offsets = torch.tensor([0, 3, 5, 8]) + doc_ids = _offsets_to_doc_ids_tensor(offsets) + expected = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2], dtype=torch.int32) + assert torch.equal(doc_ids, expected) + + def test_padding_mask_fn(self): + """Test padding mask function.""" + q_lens = torch.tensor([2, 3]) + kv_lens = torch.tensor([3, 2]) + mask_fn = _create_padding_mask_fn(q_lens, kv_lens) + + b = torch.tensor(0) + h = torch.tensor(0) + + # Valid positions + assert mask_fn(b, h, torch.tensor(0), torch.tensor(0)) == True + assert mask_fn(b, h, torch.tensor(1), torch.tensor(2)) == True + # Invalid positions (beyond sequence length) + assert mask_fn(b, h, torch.tensor(2), torch.tensor(0)) == False + assert mask_fn(b, h, torch.tensor(0), torch.tensor(3)) == False + + +class TestPackedMaskFunction: + """Test packed sequence mask function.""" + + def test_create_packed_mask_fn_basic(self): + """Test basic packed mask functionality.""" + seq_begin_indices = torch.tensor([0, 3, 5]) + keys_begin_indices = torch.tensor([0, 3, 5]) + + mask_fn = _create_packed_mask_fn(seq_begin_indices, keys_begin_indices) + + b = torch.tensor(0) + h = torch.tensor(0) + + # Same document + assert mask_fn(b, h, torch.tensor(0), torch.tensor(1)) == True + assert mask_fn(b, h, torch.tensor(3), torch.tensor(4)) == True + # Different documents + assert mask_fn(b, h, torch.tensor(0), torch.tensor(3)) == False + assert mask_fn(b, h, torch.tensor(1), torch.tensor(4)) == False + + def test_create_packed_mask_fn_with_base_mask(self): + """Test packed mask with base causal mask.""" + seq_begin_indices = torch.tensor([0, 2, 4]) + keys_begin_indices = torch.tensor([0, 2, 4]) + q_lens = torch.tensor([2, 2]) + kv_lens = torch.tensor([2, 2]) + + base_mask_fn = _causal_mask_fn(q_lens, kv_lens) + mask_fn = _create_packed_mask_fn( + seq_begin_indices, keys_begin_indices, base_mask_fn + ) + + b = torch.tensor(0) + h = torch.tensor(0) + + # Same document, causal valid + assert mask_fn(b, h, torch.tensor(1), torch.tensor(0)) == True + # Same document, causal invalid + assert mask_fn(b, h, torch.tensor(0), torch.tensor(1)) == False + # Different documents + assert mask_fn(b, h, torch.tensor(0), torch.tensor(2)) == False + + +class TestBlockMaskCache: + """Test block mask caching functionality.""" + + def test_cache_key_creation(self): + """Test cache key creation for different layouts.""" + cache = BlockMaskCache() + + # Mock BatchLayout for non-packed sequences + seqs_layout = Mock() + seqs_layout.packed = False + seqs_layout.seq_lens = [3, 4, 2] + seqs_layout.max_seq_len = 4 + + keys_layout = Mock() + keys_layout.packed = False + keys_layout.seq_lens = [3, 4, 2] + keys_layout.max_seq_len = 4 + + key = cache._create_cache_key(seqs_layout, keys_layout) + assert key.batch_size == 3 + assert key.seqs_len == 4 + assert key.keys_len == 4 + + def test_cache_key_creation_packed(self): + """Test cache key creation for packed sequences.""" + cache = BlockMaskCache() + + # Mock BatchLayout for packed sequences + seqs_layout = Mock() + seqs_layout.packed = True + seqs_layout.seq_begin_indices = [0, 3, 7] + + keys_layout = Mock() + keys_layout.packed = True + keys_layout.seq_begin_indices = [0, 3, 7] + + key = cache._create_cache_key(seqs_layout, keys_layout) + assert key.batch_size == 1 + assert key.seqs_len == 7 + assert key.keys_len == 7 + + def test_cache_key_hash(self): + """Test that cache keys are hashable.""" + key1 = BlockMaskCacheKey(batch_size=2, seqs_len=10, keys_len=10) + key2 = BlockMaskCacheKey(batch_size=2, seqs_len=10, keys_len=10) + key3 = BlockMaskCacheKey(batch_size=3, seqs_len=10, keys_len=10) + + assert hash(key1) == hash(key2) + assert hash(key1) != hash(key3) + assert key1 == key2 + assert key1 != key3 + + @patch('fairseq2.models.transformer._block_mask._create_composed_mask') + def test_cache_hit_and_miss(self, mock_create_mask): + """Test cache hit and miss behavior.""" + cache = BlockMaskCache() + mock_mask = Mock() + mock_create_mask.return_value = mock_mask + + # Mock inputs + bias = Mock(spec=IdentityBias) + seqs_layout = Mock() + seqs_layout.packed = False + seqs_layout.seq_lens = [3, 4] + seqs_layout.max_seq_len = 4 + + keys_layout = Mock() + keys_layout.packed = False + keys_layout.seq_lens = [3, 4] + keys_layout.max_seq_len = 4 + + device = Mock(spec=Device) + + # First call - cache miss + result1 = cache.get_or_create_mask(bias, seqs_layout, keys_layout, device) + assert result1 == mock_mask + assert mock_create_mask.call_count == 1 + + # Second call - cache hit + result2 = cache.get_or_create_mask(bias, seqs_layout, keys_layout, device) + assert result2 == mock_mask + assert mock_create_mask.call_count == 1 # Should not increase + + def test_cache_clear(self): + """Test cache clearing.""" + cache = BlockMaskCache() + cache._cache["test"] = "value" + assert len(cache._cache) == 1 + + cache.clear() + assert len(cache._cache) == 0 + + +class TestCreateComposedMask: + """Test the main composed mask creation function.""" + + @patch('fairseq2.models.transformer._block_mask.create_block_mask') + def test_create_composed_mask_identity_bias(self, mock_create_block_mask): + """Test composed mask creation with identity bias.""" + mock_block_mask = Mock() + mock_create_block_mask.return_value = mock_block_mask + + bias = Mock(spec=IdentityBias) + + # Mock BatchLayout + seqs_layout = Mock() + seqs_layout.packed = False + seqs_layout.padded = True + seqs_layout.seq_lens = [3, 4] + seqs_layout.max_seq_len = 4 + seqs_layout.seq_lens_pt = torch.tensor([3, 4]) + + keys_layout = Mock() + keys_layout.packed = False + keys_layout.padded = True + keys_layout.seq_lens = [3, 4] + keys_layout.max_seq_len = 4 + keys_layout.seq_lens_pt = torch.tensor([3, 4]) + + device = Mock(spec=Device) + + result = _create_composed_mask(bias, seqs_layout, keys_layout, device) + + # Should create block mask with padding mask only + mock_create_block_mask.assert_called_once() + assert result == mock_block_mask + + @patch('fairseq2.models.transformer._block_mask.create_block_mask') + def test_create_composed_mask_no_masks_needed(self, mock_create_block_mask): + """Test when no masks are needed.""" + bias = Mock(spec=IdentityBias) + + # Mock BatchLayout with no padding + seqs_layout = Mock() + seqs_layout.packed = False + seqs_layout.padded = False + + keys_layout = Mock() + keys_layout.packed = False + keys_layout.padded = False + + device = Mock(spec=Device) + + result = _create_composed_mask(bias, seqs_layout, keys_layout, device) + + # Should return None when no masks are needed + assert result is None + mock_create_block_mask.assert_not_called() + + def test_unsupported_bias_type(self): + """Test that unsupported bias types raise an error.""" + bias = Mock() # Unknown bias type + + seqs_layout = Mock() + seqs_layout.packed = False + seqs_layout.padded = False + + keys_layout = Mock() + keys_layout.packed = False + keys_layout.padded = False + + device = Mock(spec=Device) + + with pytest.raises(Exception): # Should raise NotSupportedError + _create_composed_mask(bias, seqs_layout, keys_layout, device) From e464d85cebde4b0b79339270d48692d2649a256e Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Tue, 17 Jun 2025 13:52:46 +0000 Subject: [PATCH 14/15] lint --- .../models/transformer/_sdpa/_flex.py | 1 - .../models/transformer/test_block_mask.py | 149 +++++++++--------- 2 files changed, 73 insertions(+), 77 deletions(-) diff --git a/src/fairseq2/models/transformer/_sdpa/_flex.py b/src/fairseq2/models/transformer/_sdpa/_flex.py index 96358dc00..e72d965a3 100644 --- a/src/fairseq2/models/transformer/_sdpa/_flex.py +++ b/src/fairseq2/models/transformer/_sdpa/_flex.py @@ -8,7 +8,6 @@ from typing import Callable, TypeAlias, final -import torch from torch import Tensor from torch.nn.attention.flex_attention import flex_attention from typing_extensions import override diff --git a/tests/unit/models/transformer/test_block_mask.py b/tests/unit/models/transformer/test_block_mask.py index f12281b52..2bad2a8b9 100644 --- a/tests/unit/models/transformer/test_block_mask.py +++ b/tests/unit/models/transformer/test_block_mask.py @@ -1,35 +1,35 @@ +from unittest.mock import Mock, patch + import pytest import torch -from unittest.mock import Mock, patch -from fairseq2.models.transformer._attention_bias import IdentityBias from fairseq2.device import Device - +from fairseq2.models.transformer._attention_bias import IdentityBias from fairseq2.models.transformer._block_mask import ( + BlockMaskCache, + BlockMaskCacheKey, _causal_mask_fn, - _sliding_window_causal_mask_fn, - _offsets_to_doc_ids_tensor, + _create_composed_mask, _create_packed_mask_fn, _create_padding_mask_fn, - _create_composed_mask, - BlockMaskCacheKey, - BlockMaskCache, + _offsets_to_doc_ids_tensor, + _sliding_window_causal_mask_fn, ) class TestMaskFunctions: """Test individual mask functions.""" - def test_causal_mask_fn(self): + def test_causal_mask_fn(self) -> None: """Test causal mask function behavior.""" q_lens = torch.tensor([3, 2]) kv_lens = torch.tensor([3, 2]) mask_fn = _causal_mask_fn(q_lens, kv_lens) - + # Test for batch 0 b = torch.tensor(0) h = torch.tensor(0) - + # Test diagonal and upper triangular positions assert mask_fn(b, h, torch.tensor(0), torch.tensor(0)) == True assert mask_fn(b, h, torch.tensor(1), torch.tensor(0)) == True @@ -37,53 +37,55 @@ def test_causal_mask_fn(self): assert mask_fn(b, h, torch.tensor(0), torch.tensor(1)) == False assert mask_fn(b, h, torch.tensor(2), torch.tensor(1)) == True - def test_sliding_window_causal_mask_fn(self): + def test_sliding_window_causal_mask_fn(self) -> None: """Test sliding window causal mask function.""" q_lens = torch.tensor([4]) kv_lens = torch.tensor([4]) window_size = 2 mask_fn = _sliding_window_causal_mask_fn(window_size, q_lens, kv_lens) - + b = torch.tensor(0) h = torch.tensor(0) - + # Test window behavior assert mask_fn(b, h, torch.tensor(2), torch.tensor(1)) == True # Within window assert mask_fn(b, h, torch.tensor(2), torch.tensor(2)) == True # Diagonal - assert mask_fn(b, h, torch.tensor(3), torch.tensor(1)) == False # Outside window + assert ( + mask_fn(b, h, torch.tensor(3), torch.tensor(1)) == False + ) # Outside window assert mask_fn(b, h, torch.tensor(1), torch.tensor(2)) == False # Future token - def test_sliding_window_size_one(self): + def test_sliding_window_size_one(self) -> None: """Test sliding window with size 1 (diagonal only).""" q_lens = torch.tensor([3]) kv_lens = torch.tensor([3]) mask_fn = _sliding_window_causal_mask_fn(1, q_lens, kv_lens) - + b = torch.tensor(0) h = torch.tensor(0) - + # Only diagonal should be True assert mask_fn(b, h, torch.tensor(0), torch.tensor(0)) == True assert mask_fn(b, h, torch.tensor(1), torch.tensor(1)) == True assert mask_fn(b, h, torch.tensor(1), torch.tensor(0)) == False assert mask_fn(b, h, torch.tensor(0), torch.tensor(1)) == False - def test_offsets_to_doc_ids_tensor(self): + def test_offsets_to_doc_ids_tensor(self) -> None: """Test conversion of offsets to document IDs.""" offsets = torch.tensor([0, 3, 5, 8]) doc_ids = _offsets_to_doc_ids_tensor(offsets) expected = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2], dtype=torch.int32) assert torch.equal(doc_ids, expected) - def test_padding_mask_fn(self): + def test_padding_mask_fn(self) -> None: """Test padding mask function.""" q_lens = torch.tensor([2, 3]) kv_lens = torch.tensor([3, 2]) mask_fn = _create_padding_mask_fn(q_lens, kv_lens) - + b = torch.tensor(0) h = torch.tensor(0) - + # Valid positions assert mask_fn(b, h, torch.tensor(0), torch.tensor(0)) == True assert mask_fn(b, h, torch.tensor(1), torch.tensor(2)) == True @@ -95,16 +97,16 @@ def test_padding_mask_fn(self): class TestPackedMaskFunction: """Test packed sequence mask function.""" - def test_create_packed_mask_fn_basic(self): + def test_create_packed_mask_fn_basic(self) -> None: """Test basic packed mask functionality.""" seq_begin_indices = torch.tensor([0, 3, 5]) keys_begin_indices = torch.tensor([0, 3, 5]) - + mask_fn = _create_packed_mask_fn(seq_begin_indices, keys_begin_indices) - + b = torch.tensor(0) h = torch.tensor(0) - + # Same document assert mask_fn(b, h, torch.tensor(0), torch.tensor(1)) == True assert mask_fn(b, h, torch.tensor(3), torch.tensor(4)) == True @@ -112,21 +114,21 @@ def test_create_packed_mask_fn_basic(self): assert mask_fn(b, h, torch.tensor(0), torch.tensor(3)) == False assert mask_fn(b, h, torch.tensor(1), torch.tensor(4)) == False - def test_create_packed_mask_fn_with_base_mask(self): + def test_create_packed_mask_fn_with_base_mask(self) -> None: """Test packed mask with base causal mask.""" seq_begin_indices = torch.tensor([0, 2, 4]) keys_begin_indices = torch.tensor([0, 2, 4]) q_lens = torch.tensor([2, 2]) kv_lens = torch.tensor([2, 2]) - + base_mask_fn = _causal_mask_fn(q_lens, kv_lens) mask_fn = _create_packed_mask_fn( seq_begin_indices, keys_begin_indices, base_mask_fn ) - + b = torch.tensor(0) h = torch.tensor(0) - + # Same document, causal valid assert mask_fn(b, h, torch.tensor(1), torch.tensor(0)) == True # Same document, causal invalid @@ -138,107 +140,100 @@ def test_create_packed_mask_fn_with_base_mask(self): class TestBlockMaskCache: """Test block mask caching functionality.""" - def test_cache_key_creation(self): + def test_cache_key_creation(self) -> None: """Test cache key creation for different layouts.""" cache = BlockMaskCache() - + # Mock BatchLayout for non-packed sequences seqs_layout = Mock() seqs_layout.packed = False seqs_layout.seq_lens = [3, 4, 2] seqs_layout.max_seq_len = 4 - + keys_layout = Mock() keys_layout.packed = False keys_layout.seq_lens = [3, 4, 2] keys_layout.max_seq_len = 4 - + key = cache._create_cache_key(seqs_layout, keys_layout) assert key.batch_size == 3 assert key.seqs_len == 4 assert key.keys_len == 4 - def test_cache_key_creation_packed(self): + def test_cache_key_creation_packed(self) -> None: """Test cache key creation for packed sequences.""" cache = BlockMaskCache() - + # Mock BatchLayout for packed sequences seqs_layout = Mock() seqs_layout.packed = True seqs_layout.seq_begin_indices = [0, 3, 7] - + keys_layout = Mock() keys_layout.packed = True keys_layout.seq_begin_indices = [0, 3, 7] - + key = cache._create_cache_key(seqs_layout, keys_layout) assert key.batch_size == 1 assert key.seqs_len == 7 assert key.keys_len == 7 - def test_cache_key_hash(self): + def test_cache_key_hash(self) -> None: """Test that cache keys are hashable.""" key1 = BlockMaskCacheKey(batch_size=2, seqs_len=10, keys_len=10) key2 = BlockMaskCacheKey(batch_size=2, seqs_len=10, keys_len=10) key3 = BlockMaskCacheKey(batch_size=3, seqs_len=10, keys_len=10) - + assert hash(key1) == hash(key2) assert hash(key1) != hash(key3) assert key1 == key2 assert key1 != key3 - @patch('fairseq2.models.transformer._block_mask._create_composed_mask') - def test_cache_hit_and_miss(self, mock_create_mask): + @patch("fairseq2.models.transformer._block_mask._create_composed_mask") + def test_cache_hit_and_miss(self, mock_create_mask: Mock) -> None: """Test cache hit and miss behavior.""" cache = BlockMaskCache() mock_mask = Mock() mock_create_mask.return_value = mock_mask - + # Mock inputs bias = Mock(spec=IdentityBias) seqs_layout = Mock() seqs_layout.packed = False seqs_layout.seq_lens = [3, 4] seqs_layout.max_seq_len = 4 - + keys_layout = Mock() keys_layout.packed = False keys_layout.seq_lens = [3, 4] keys_layout.max_seq_len = 4 - + device = Mock(spec=Device) - + # First call - cache miss result1 = cache.get_or_create_mask(bias, seqs_layout, keys_layout, device) assert result1 == mock_mask assert mock_create_mask.call_count == 1 - + # Second call - cache hit result2 = cache.get_or_create_mask(bias, seqs_layout, keys_layout, device) assert result2 == mock_mask assert mock_create_mask.call_count == 1 # Should not increase - def test_cache_clear(self): - """Test cache clearing.""" - cache = BlockMaskCache() - cache._cache["test"] = "value" - assert len(cache._cache) == 1 - - cache.clear() - assert len(cache._cache) == 0 - class TestCreateComposedMask: """Test the main composed mask creation function.""" - @patch('fairseq2.models.transformer._block_mask.create_block_mask') - def test_create_composed_mask_identity_bias(self, mock_create_block_mask): + @patch("fairseq2.models.transformer._block_mask.create_block_mask") + def test_create_composed_mask_identity_bias( + self, mock_create_block_mask: Mock + ) -> None: """Test composed mask creation with identity bias.""" mock_block_mask = Mock() mock_create_block_mask.return_value = mock_block_mask - + bias = Mock(spec=IdentityBias) - + # Mock BatchLayout seqs_layout = Mock() seqs_layout.packed = False @@ -246,57 +241,59 @@ def test_create_composed_mask_identity_bias(self, mock_create_block_mask): seqs_layout.seq_lens = [3, 4] seqs_layout.max_seq_len = 4 seqs_layout.seq_lens_pt = torch.tensor([3, 4]) - + keys_layout = Mock() keys_layout.packed = False keys_layout.padded = True keys_layout.seq_lens = [3, 4] keys_layout.max_seq_len = 4 keys_layout.seq_lens_pt = torch.tensor([3, 4]) - + device = Mock(spec=Device) - + result = _create_composed_mask(bias, seqs_layout, keys_layout, device) - + # Should create block mask with padding mask only mock_create_block_mask.assert_called_once() assert result == mock_block_mask - @patch('fairseq2.models.transformer._block_mask.create_block_mask') - def test_create_composed_mask_no_masks_needed(self, mock_create_block_mask): + @patch("fairseq2.models.transformer._block_mask.create_block_mask") + def test_create_composed_mask_no_masks_needed( + self, mock_create_block_mask: Mock + ) -> None: """Test when no masks are needed.""" bias = Mock(spec=IdentityBias) - + # Mock BatchLayout with no padding seqs_layout = Mock() seqs_layout.packed = False seqs_layout.padded = False - + keys_layout = Mock() keys_layout.packed = False keys_layout.padded = False - + device = Mock(spec=Device) - + result = _create_composed_mask(bias, seqs_layout, keys_layout, device) - + # Should return None when no masks are needed assert result is None mock_create_block_mask.assert_not_called() - def test_unsupported_bias_type(self): + def test_unsupported_bias_type(self) -> None: """Test that unsupported bias types raise an error.""" bias = Mock() # Unknown bias type - + seqs_layout = Mock() seqs_layout.packed = False seqs_layout.padded = False - + keys_layout = Mock() keys_layout.packed = False keys_layout.padded = False - + device = Mock(spec=Device) - + with pytest.raises(Exception): # Should raise NotSupportedError _create_composed_mask(bias, seqs_layout, keys_layout, device) From 83d4ca2f2ff90a2683f88e365673f5ef1a48debc Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Tue, 19 Aug 2025 19:02:14 -0400 Subject: [PATCH 15/15] add back torch.compile --- src/fairseq2/models/transformer/_sdpa/_flex.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/fairseq2/models/transformer/_sdpa/_flex.py b/src/fairseq2/models/transformer/_sdpa/_flex.py index e72d965a3..5eed1a445 100644 --- a/src/fairseq2/models/transformer/_sdpa/_flex.py +++ b/src/fairseq2/models/transformer/_sdpa/_flex.py @@ -8,11 +8,13 @@ from typing import Callable, TypeAlias, final +import torch from torch import Tensor from torch.nn.attention.flex_attention import flex_attention from typing_extensions import override from fairseq2.models.transformer._block_mask import BlockMaskCache +from fairseq2.logging import log from fairseq2.nn import BatchLayout # isort: split @@ -25,9 +27,10 @@ MaskFunction: TypeAlias = Callable[[Tensor, Tensor, Tensor, Tensor], Tensor] -# TODO: Hitting some torch.compile issues with this enabled for different builds. -# Commenting out for now until we can investigate. -# flex_attention = torch.compile(flex_attention, dynamic=False) +# NOTE: Flex attention only has performance benefits when torch.compiled, but this is +# not possible on certain platforms (e.g., CPU). +if torch.cuda.is_available(): + flex_attention = torch.compile(flex_attention, dynamic=False) @final