@@ -28,7 +28,23 @@ def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
         self.softmax_scale = softmax_scale
         self.drop = nn.Dropout(attention_dropout)

-    def forward(self, qkv, causal=None, key_padding_mask=None):
+    def forward(
+        self,
+        qkv=None,
+        q=None,
+        k=None,
+        v=None,
+        kv=None,
+        causal=None,
+        cu_seqlens=None,
+        max_seqlen=None,
+        cu_seqlens_q=None,
+        cu_seqlens_k=None,
+        max_seqlen_q=None,
+        max_seqlen_k=None,
+        softmax_scale=None,
+        dropout_p=0.0,
+    ):
         """Only supports the padded mode"""
         """Implements the multihead softmax attention.
         Arguments
@@ -38,29 +54,48 @@ def forward(self, qkv, causal=None, key_padding_mask=None):
             key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                 False means to mask out. (B, S)
         """
-        batch_size, seqlen = qkv.shape[0], qkv.shape[1]
+        if qkv is not None:
+            query, key, value = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
+            device = query.device
+        elif kv is not None:
+            assert q is not None, "q should not be None when kv is not None"
+            assert q.device == kv.device, "the devices of q and kv should be the same"
+            query = q
+            key, value = kv[:, :, 0], kv[:, :, 1]
+            device = query.device
+        else:
+            assert (
+                q is not None and k is not None and v is not None
+            ), "q, k, v should not be None"
+            assert (
+                q.device == k.device and k.device == v.device
+            ), "the devices of q, k and v should be the same"
+            query = q
+            key, value = k, v
+            device = query.device
+
+        batch_size, seqlen = query.shape[0], query.shape[1]
         causal = self.causal if causal is None else causal
-        q, k, v = qkv.unbind(dim=2)
-        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
-        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
-        if key_padding_mask is not None:
-            padding_mask = torch.full(
-                (batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device
-            )
-            padding_mask.masked_fill_(key_padding_mask, 0.0)
-            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
-            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
+        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(query.shape[-1])
+        scores = torch.einsum("bthd,bshd->bhts", query, key * softmax_scale)
+        # if key_padding_mask is not None:
+        #     padding_mask = torch.full(
+        #         (batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device
+        #     )
+        #     padding_mask.masked_fill_(key_padding_mask, 0.0)
+        #     # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
+        #     scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
         if causal:
             # "triu_tril_cuda_template" not implemented for 'BFloat16'
             # So we have to construct the mask in float
             causal_mask = torch.triu(
-                torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1
+                torch.full((seqlen, seqlen), -10000.0, device=device), 1
             )
             # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
             scores = scores + causal_mask.to(dtype=scores.dtype)
-        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
+        attention = torch.softmax(scores, dim=-1, dtype=value.dtype)
         attention_drop = self.drop(attention)
-        output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
+        output = torch.einsum("bhts,bshd->bthd", attention_drop, value)
         return output
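
For context, here is a minimal usage sketch of the refactored forward, covering its three input paths (packed qkv, q plus packed kv, and fully separate q/k/v). The enclosing class is not shown in this hunk, so the name SelfAttention and the (B, S, 3, H, D) packed layout implied by qkv[:, :, 0] are assumptions.

import torch

# Assumed: SelfAttention is the nn.Module this diff modifies, importable from the model code.
attn = SelfAttention(causal=True, attention_dropout=0.1)

# Packed self-attention path: qkv of shape (B, S, 3, H, D).
qkv = torch.randn(2, 16, 3, 8, 64)
out = attn(qkv=qkv)            # -> (B, S, H, D)

# q with packed kv of shape (B, S, 2, H, D).
q = torch.randn(2, 16, 8, 64)
kv = torch.randn(2, 16, 2, 8, 64)
out = attn(q=q, kv=kv)

# Fully separate q, k, v tensors, each (B, S, H, D).
k = torch.randn(2, 16, 8, 64)
v = torch.randn(2, 16, 8, 64)
out = attn(q=q, k=k, v=v)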