Commit 2935a1b

Fix fp32_ln for various models (#41605)
* Add is_causal to KosmosTextAttention
* Move get_target_dtype so it can be imported elsewhere
* Fix fp32 flash attention bug in bark
* Fix is_causal in mllama
* Fix fp32 issue on StableLM
* Fix repo-consistency
1 parent b9bd8c4 commit 2935a1b

File tree: 6 files changed, +26, -10 lines changed
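Note: the fp32_ln scenario in the title is the one described in the flash_attention.py comment below, i.e. a half-precision model whose LayerNorm modules are kept in float32, so float32 activations can reach the attention layers. A minimal sketch of that setup follows (the toy nn.Sequential is purely illustrative and not part of the commit):

import torch
import torch.nn as nn

def keep_layernorms_in_fp32(model: nn.Module) -> nn.Module:
    # Keep only the LayerNorm parameters in float32 while the rest of the model
    # stays in half precision -- the mixed-dtype setup the fixes below deal with.
    for module in model.modules():
        if isinstance(module, nn.LayerNorm):
            module.float()
    return model

toy = keep_layernorms_in_fp32(nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8)).half())
print({name: p.dtype for name, p in toy.named_parameters()})
# {'0.weight': torch.float16, '0.bias': torch.float16, '1.weight': torch.float32, '1.bias': torch.float32}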

src/transformers/integrations/flash_attention.py

Lines changed: 14 additions & 9 deletions
@@ -11,6 +11,19 @@
 _use_top_left_mask = flash_attn_supports_top_left_mask()


+def get_target_dtype(query: torch.Tensor, module: torch.nn.Module) -> torch.dtype:
+    """If the query is in float32, return a target dtype compatible with flash attention. Return None otherwise."""
+    if query.dtype == torch.float32:
+        if torch.is_autocast_enabled():
+            return torch.get_autocast_gpu_dtype()
+        # Handle the case where the model is quantized
+        elif hasattr(module.config, "_pre_quantization_dtype"):
+            return module.config._pre_quantization_dtype
+        else:
+            return next(layer for layer in module.modules() if isinstance(layer, torch.nn.Linear)).weight.dtype
+    return None
+
+
 def flash_attention_forward(
     module: torch.nn.Module,
     query: torch.Tensor,

@@ -48,15 +61,7 @@ def flash_attention_forward(
     # cast them back in the correct dtype just to be sure everything works as expected.
     # This might slowdown training & inference so it is recommended to not cast the LayerNorms
     # in fp32. (usually our RMSNorm modules handle it correctly)
-    target_dtype = None
-    if query.dtype == torch.float32:
-        if torch.is_autocast_enabled():
-            target_dtype = torch.get_autocast_gpu_dtype()
-        # Handle the case where the model is quantized
-        elif hasattr(module.config, "_pre_quantization_dtype"):
-            target_dtype = module.config._pre_quantization_dtype
-        else:
-            target_dtype = next(layer for layer in module.modules() if isinstance(layer, torch.nn.Linear)).weight.dtype
+    target_dtype = get_target_dtype(query, module)

     # Instead of relying on the value set in the module directly, we use the is_causal passed in kwargs if it is presented
     is_causal = kwargs.pop("is_causal", None)
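A standalone usage sketch of the new helper (ToyAttention and the fp32 query are illustrative and not part of the commit; the helper body mirrors the diff above): it returns None when the query is already in a flash-attention-compatible dtype, otherwise the dtype to cast to before the kernel call.

import torch
import torch.nn as nn

class ToyAttention(nn.Module):
    # Minimal stand-in for an attention module: it has a .config and at least one
    # nn.Linear, which is all the helper inspects outside of autocast/quantization.
    def __init__(self):
        super().__init__()
        self.config = type("Cfg", (), {})()  # no _pre_quantization_dtype -> falls back to Linear weights
        self.q_proj = nn.Linear(8, 8).to(torch.bfloat16)

def get_target_dtype(query: torch.Tensor, module: nn.Module):
    # Same logic as the helper added in flash_attention.py above.
    if query.dtype == torch.float32:
        if torch.is_autocast_enabled():
            return torch.get_autocast_gpu_dtype()
        elif hasattr(module.config, "_pre_quantization_dtype"):
            return module.config._pre_quantization_dtype
        else:
            return next(l for l in module.modules() if isinstance(l, nn.Linear)).weight.dtype
    return None

module = ToyAttention()
fp32_query = torch.randn(1, 2, 4, 8)                             # e.g. produced downstream of an fp32 LayerNorm
print(get_target_dtype(fp32_query, module))                      # torch.bfloat16 -> cast before the kernel
print(get_target_dtype(fp32_query.to(torch.bfloat16), module))   # None -> already FA-compatible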

src/transformers/models/bark/modeling_bark.py

Lines changed: 5 additions & 0 deletions
@@ -57,6 +57,7 @@


 if is_flash_attn_available():
+    from ...integrations.flash_attention import get_target_dtype
     from ...modeling_flash_attention_utils import _flash_attention_forward


@@ -78,6 +79,7 @@ def __init__(self, config, is_causal=False, layer_idx=None):
         self.embed_dim = config.hidden_size
         self.num_heads = config.num_heads
         self.head_dim = self.embed_dim // self.num_heads
+        self.config = config

         if config.hidden_size % config.num_heads != 0:
             raise ValueError(

@@ -228,6 +230,8 @@ def forward(
         if past_key_values is not None:
             key, value = past_key_values.update(key, value, self.layer_idx, {"cache_position": cache_position})

+        target_dtype = get_target_dtype(query, self)  # if the query is in float32, this is the dtype to cast to for FA
+
         attn_output = _flash_attention_forward(
             query,
             key,

@@ -237,6 +241,7 @@ def forward(
             dropout=self.dropout if self.training else 0.0,
             use_top_left_mask=self._flash_attn_uses_top_left_mask,
             is_causal=self.is_causal,
+            target_dtype=target_dtype,
         )

         attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
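For orientation, the effect of threading target_dtype into _flash_attention_forward can be pictured with the hedged sketch below; the real helper does much more (padding, kernel dispatch), but the part relevant to this fix is the downcast of float32 states before the kernel call.

import torch

def cast_states_for_flash_attention(query, key, value, target_dtype):
    # Sketch only: if the states arrived in float32 (e.g. because of fp32 LayerNorms)
    # and a compatible half-precision dtype was resolved, cast them before the kernel.
    if target_dtype is not None and query.dtype == torch.float32:
        query, key, value = (t.to(target_dtype) for t in (query, key, value))
    return query, key, value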

src/transformers/models/blt/modeling_blt.py

Lines changed: 1 addition & 1 deletion
@@ -280,12 +280,12 @@ def __init__(self, config: BltConfig, layer_idx: int):
         self.scaling = self.head_dim**-0.5
         self.rope_theta = config.rope_theta
         self.layer_idx = layer_idx
+        self.is_causal = True

         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
-        self.is_causal = True

     def forward(
         self,

src/transformers/models/kosmos2/modeling_kosmos2.py

Lines changed: 1 addition & 0 deletions
@@ -680,6 +680,7 @@ def __init__(
         self.num_heads = num_heads
         self.dropout = dropout
         self.head_dim = embed_dim // num_heads
+        self.is_causal = True

         if (self.head_dim * num_heads) != self.embed_dim:
             raise ValueError(

src/transformers/models/mllama/modeling_mllama.py

Lines changed: 1 addition & 0 deletions
@@ -519,6 +519,7 @@ def __init__(self, config: MllamaTextConfig, layer_idx: int):
         self.scaling = self.head_dim**-0.5
         self.rope_theta = config.rope_theta
         self.layer_idx = layer_idx
+        self.is_causal = True

         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
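The self.is_causal = True additions in kosmos2 and mllama (and its repositioning in blt) matter because the shared flash-attention integration only uses an explicitly passed is_causal when present and otherwise reads the attribute off the module, roughly as in this hedged sketch (not the exact dispatch code in flash_attention.py):

def resolve_is_causal(module, **kwargs):
    # Sketch: prefer the is_causal kwarg if the caller provided one, otherwise fall
    # back to the attribute the attention classes above now define in __init__.
    # Without self.is_causal set, this fallback would raise AttributeError.
    is_causal = kwargs.pop("is_causal", None)
    if is_causal is None:
        is_causal = module.is_causal
    return is_causal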

src/transformers/models/stablelm/modeling_stablelm.py

Lines changed: 4 additions & 0 deletions
@@ -52,6 +52,7 @@


 if is_flash_attn_available():
+    from ...integrations.flash_attention import get_target_dtype
     from ...modeling_flash_attention_utils import _flash_attention_forward


@@ -495,6 +496,8 @@ def forward(

         dropout_rate = self.attention_dropout.p if self.training else 0.0

+        target_dtype = get_target_dtype(query_states, self)
+
         attn_output = _flash_attention_forward(
             query_states,
             key_states,

@@ -505,6 +508,7 @@ def forward(
             dropout=dropout_rate,
             use_top_left_mask=self._flash_attn_uses_top_left_mask,
             is_causal=self.is_causal,
+            target_dtype=target_dtype,
         )

         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
