Skip to content

Commit d81e831

Browse files
committed
add softmax op
1 parent 26511b2 commit d81e831

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

keras_hub/src/models/smollm3/smollm3_utils.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from keras import layers
12
from keras import ops
23
from keras import random
34

@@ -38,6 +39,12 @@ def eager_attention_forward(
3839
dropout=0.0,
3940
training=False,
4041
):
42+
softmax_op = layers.Softmax(
43+
axis=-1,
44+
dtype="float32",
45+
name="attention_softmax",
46+
)
47+
4148
key_states = repeat_kv(key, module.num_key_value_groups)
4249
value_states = repeat_kv(value, module.num_key_value_groups)
4350

@@ -47,10 +54,9 @@ def eager_attention_forward(
4754
)
4855

4956
if attention_mask is not None:
50-
causal_mask = attention_mask[:, :, :, : ops.shape(key_states)[-2]]
51-
attn_weights = ops.add(attn_weights, causal_mask)
52-
53-
attn_weights = ops.softmax(attn_weights, axis=-1)
57+
attn_weights = softmax_op(attn_weights, attention_mask[:, None, :, :])
58+
else:
59+
attn_weights = softmax_op(attn_weights)
5460

5561
if training:
5662
attn_weights = random.dropout(attn_weights, rate=dropout)

0 commit comments

Comments (0)