Commit e5e5eb9

updated stablelm_attention
1 parent eefed1e commit e5e5eb9

File tree

1 file changed (+2, -2 lines)

keras_hub/src/models/stablelm/stablelm_attention.py

Lines changed: 2 additions & 2 deletions
@@ -5,7 +5,7 @@
 
 from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
 from keras_hub.src.utils.keras_utils import clone_initializer
-from keras_hub.src.utils.keras_utils import has_flash_attention_support
+from keras_hub.src.utils.keras_utils import fused_attention_op_available
 
 
 class StableLMAttention(keras.layers.Layer):
@@ -203,7 +203,7 @@ def _masked_softmax(self, attention_scores, attention_mask=None):
         return self._softmax(attention_scores)
 
     def _compute_attention(self, query, key, value, attention_mask=None):
-        if has_flash_attention_support() and self.dropout == 0:
+        if fused_attention_op_available() and self.dropout == 0:
             if attention_mask is not None:
                 attention_mask = ops.expand_dims(attention_mask, axis=1)
                 attention_mask = ops.cast(attention_mask, dtype="bool")
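For context, the renamed check gates whether _compute_attention can hand the work to the backend's fused scaled-dot-product attention kernel, falling back to an explicit softmax(QK^T / sqrt(d)) V computation otherwise. Below is a minimal sketch of that dispatch pattern, assuming the (batch, seq_len, num_heads, head_dim) layout that keras.ops.dot_product_attention expects; the free function compute_attention and its fallback branch are illustrative, not the actual method body in stablelm_attention.py:

import keras
from keras import ops

from keras_hub.src.utils.keras_utils import fused_attention_op_available


def compute_attention(query, key, value, attention_mask=None, dropout=0.0):
    # Illustrative sketch of the gating in this commit, not the exact
    # keras-hub implementation.
    if fused_attention_op_available() and dropout == 0:
        if attention_mask is not None:
            # Broadcast the (batch, q_len, kv_len) mask over heads and
            # cast to bool, mirroring the branch shown in the diff.
            attention_mask = ops.expand_dims(attention_mask, axis=1)
            attention_mask = ops.cast(attention_mask, dtype="bool")
        # Fused kernel; Keras applies the 1/sqrt(head_dim) scale itself.
        return ops.dot_product_attention(query, key, value, mask=attention_mask)
    # Fallback: explicit scaled dot-product attention.
    head_dim = query.shape[-1]
    scores = ops.einsum("bqhd,bkhd->bhqk", query, key) / head_dim**0.5
    if attention_mask is not None:
        mask = ops.cast(ops.expand_dims(attention_mask, axis=1), "bool")
        scores = ops.where(mask, scores, ops.cast(-1e9, scores.dtype))
    probs = ops.softmax(scores, axis=-1)
    if dropout > 0:
        probs = keras.random.dropout(probs, rate=dropout)
    return ops.einsum("bhqk,bkhd->bqhd", probs, value)

The "and self.dropout == 0" guard in the diff fits this split: keras.ops.dot_product_attention takes no dropout argument, so attention with dropout enabled must go through the non-fused path.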
