 from QEfficient.transformers.cache_utils import QEffDynamicCache
 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
+from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE


 def eager_attention_forward(module, query, key, value, attention_mask, head_mask=None, **kwargs):
@@ -30,15 +31,16 @@ def eager_attention_forward(module, query, key, value, attention_mask, head_mask
     # if only "normal" attention layer implements causal mask
     query_length, key_length = query.size(-2), key.size(-2)
     causal_mask = module.bias[:, :, key_length - query_length : key_length, :key_length]
-    mask_value = -10000.0
     # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
     # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
-    mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
+    mask_value = torch.full([], MIN_MASKED_ATTENTION_VALUE, dtype=attn_weights.dtype, device=attn_weights.device)
     attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)

     if attention_mask is not None:
         # Apply the attention mask
-        attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
+        attn_weights = torch.where(
+            attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
+        )

     attn_weights = nn.functional.softmax(attn_weights, dim=-1)
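For reference, a minimal sketch of the masking pattern this hunk touches, assuming MIN_MASKED_ATTENTION_VALUE is a large negative float standing in for -inf (the previous code hardcoded -10000.0; the actual constant is defined in QEfficient.utils.constants and is not shown in this diff):

import torch

# Hypothetical stand-in for QEfficient.utils.constants.MIN_MASKED_ATTENTION_VALUE;
# the real value lives in that module (the old code hardcoded -10000.0).
MIN_MASKED_ATTENTION_VALUE = -10000.0


def masked_softmax(attn_weights: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # Where attention_mask is True, replace the score with a large negative
    # value so softmax drives that position's probability toward zero.
    fill = torch.full([], MIN_MASKED_ATTENTION_VALUE, dtype=attn_weights.dtype, device=attn_weights.device)
    attn_weights = torch.where(attention_mask, fill, attn_weights)
    return torch.nn.functional.softmax(attn_weights, dim=-1)


scores = torch.randn(1, 1, 4, 4)  # [batch, heads, query_len, key_len]
mask = torch.triu(torch.ones(4, 4), diagonal=1).bool()  # block future key positions
probs = masked_softmax(scores, mask)
print(probs)  # masked (upper-triangular) entries come out ~0

One common reason to prefer a large finite fill value over float('-inf') is that the softmax stays well defined even when every position in a row is masked, and pulling the value into a named constant keeps it consistent across the codebase instead of repeating the literal in each model file.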