23 changes: 14 additions & 9 deletions src/transformers/integrations/flash_attention.py
@@ -11,6 +11,19 @@
_use_top_left_mask = flash_attn_supports_top_left_mask()


def get_target_dtype(query: torch.Tensor, module: torch.nn.Module) -> torch.dtype:
"""If the query is in float32, return a target dtype compatible with flash attention. Return None otherwise."""
if query.dtype == torch.float32:
if torch.is_autocast_enabled():
return torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(module.config, "_pre_quantization_dtype"):
return module.config._pre_quantization_dtype
else:
return next(layer for layer in module.modules() if isinstance(layer, torch.nn.Linear)).weight.dtype
return None


def flash_attention_forward(
module: torch.nn.Module,
query: torch.Tensor,
@@ -48,15 +61,7 @@ def flash_attention_forward(
# cast them back in the correct dtype just to be sure everything works as expected.
# This might slow down training & inference, so it is recommended to not cast the LayerNorms
# in fp32. (usually our RMSNorm modules handle it correctly)
target_dtype = None
if query.dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(module.config, "_pre_quantization_dtype"):
target_dtype = module.config._pre_quantization_dtype
else:
target_dtype = next(layer for layer in module.modules() if isinstance(layer, torch.nn.Linear)).weight.dtype
target_dtype = get_target_dtype(query, module)

# Instead of relying on the value set in the module directly, we use the is_causal passed in kwargs if it is presented
is_causal = kwargs.pop("is_causal", None)
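A minimal sketch (not part of the diff) of how the new get_target_dtype helper resolves the cast dtype for a float32 query. It assumes the import path used by the bark hunk below (transformers.integrations.flash_attention); the toy module, tensor shapes, and SimpleNamespace config are illustrative only. The config attribute just needs to exist so the hasattr check on module.config can run, which is also why bark now stores self.config = config.

import torch
from types import SimpleNamespace

from transformers.integrations.flash_attention import get_target_dtype


class ToyAttention(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # fallback source for the dtype: the first nn.Linear found in module.modules()
        self.q_proj = torch.nn.Linear(8, 8, dtype=torch.bfloat16)
        # no _pre_quantization_dtype here, so the quantization branch is skipped
        self.config = SimpleNamespace()


module = ToyAttention()
query_fp32 = torch.randn(1, 2, 4, 8, dtype=torch.float32)  # e.g. silently upcast by an fp32 LayerNorm

print(get_target_dtype(query_fp32, module))                     # torch.bfloat16 (weight dtype of q_proj)
print(get_target_dtype(query_fp32.to(torch.bfloat16), module))  # None: query is already FA-compatible

Under autocast, the helper returns torch.get_autocast_gpu_dtype() instead, and for quantized models the stored _pre_quantization_dtype takes precedence over the Linear-weight fallback.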
5 changes: 5 additions & 0 deletions src/transformers/models/bark/modeling_bark.py
@@ -57,6 +57,7 @@


if is_flash_attn_available():
from ...integrations.flash_attention import get_target_dtype
from ...modeling_flash_attention_utils import _flash_attention_forward


@@ -78,6 +79,7 @@ def __init__(self, config, is_causal=False, layer_idx=None):
self.embed_dim = config.hidden_size
self.num_heads = config.num_heads
self.head_dim = self.embed_dim // self.num_heads
self.config = config

if config.hidden_size % config.num_heads != 0:
raise ValueError(
@@ -228,6 +230,8 @@ def forward(
if past_key_values is not None:
key, value = past_key_values.update(key, value, self.layer_idx, {"cache_position": cache_position})

target_dtype = get_target_dtype(query, self) # if the query is in float32, this is the dtype to cast to for FA

attn_output = _flash_attention_forward(
query,
key,
@@ -237,6 +241,7 @@ def forward(
dropout=self.dropout if self.training else 0.0,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
target_dtype=target_dtype,
)

attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
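The model-side pattern above is the same one the stablelm hunk below follows: resolve the dtype once per forward and let the shared utility do the actual cast. A hedged restatement with the intent spelled out; attention_mask and q_len are placeholder names for arguments elided in the hunks, and the cast itself is assumed to happen inside _flash_attention_forward when target_dtype is set, as the diff passes it through.

target_dtype = get_target_dtype(query, self)  # None unless query is fp32

attn_output = _flash_attention_forward(
    query,
    key,
    value,
    attention_mask,  # placeholder: the model's own mask argument
    q_len,           # placeholder: the query length argument
    dropout=self.dropout if self.training else 0.0,
    use_top_left_mask=self._flash_attn_uses_top_left_mask,
    is_causal=self.is_causal,
    target_dtype=target_dtype,  # fp32 q/k/v get cast to this before the flash-attention kernel runs
)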
2 changes: 1 addition & 1 deletion src/transformers/models/blt/modeling_blt.py
@@ -280,12 +280,12 @@ def __init__(self, config: BltConfig, layer_idx: int):
self.scaling = self.head_dim**-0.5
self.rope_theta = config.rope_theta
self.layer_idx = layer_idx
self.is_causal = True

self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
self.is_causal = True

def forward(
self,
1 change: 1 addition & 0 deletions src/transformers/models/kosmos2/modeling_kosmos2.py
@@ -680,6 +680,7 @@ def __init__(
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.is_causal = True

if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
1 change: 1 addition & 0 deletions src/transformers/models/mllama/modeling_mllama.py
@@ -519,6 +519,7 @@ def __init__(self, config: MllamaTextConfig, layer_idx: int):
self.scaling = self.head_dim**-0.5
self.rope_theta = config.rope_theta
self.layer_idx = layer_idx
self.is_causal = True

self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
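The single-line is_causal additions in kosmos2 and mllama, and the relocated line in blt, matter because the shared integration only prefers the kwarg when it is provided and otherwise reads the attribute off the module (see the comment in flash_attention.py above). A hedged sketch of that fallback; the exact expression is assumed, since this diff only shows the surrounding comment:

is_causal = kwargs.pop("is_causal", None)
if is_causal is None:
    is_causal = module.is_causal  # the attribute these __init__ methods now set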
4 changes: 4 additions & 0 deletions src/transformers/models/stablelm/modeling_stablelm.py
@@ -52,6 +52,7 @@


if is_flash_attn_available():
from ...integrations.flash_attention import get_target_dtype
from ...modeling_flash_attention_utils import _flash_attention_forward


@@ -495,6 +496,8 @@ def forward(

dropout_rate = self.attention_dropout.p if self.training else 0.0

target_dtype = get_target_dtype(query_states, self)

attn_output = _flash_attention_forward(
query_states,
key_states,
Expand All @@ -505,6 +508,7 @@ def forward(
dropout=dropout_rate,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
target_dtype=target_dtype,
)

attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()