
Commit 8bca2fb

test: add test for selfatten, crossatten (#85)
Add SelfAttention and CrossAttention unit tests for InternEvo.
1 parent 5ddf327 commit 8bca2fb

3 files changed: +122 −16 lines

deeplink_ext/internevo_ops/flash_attention_fallback.py

Lines changed: 48 additions & 13 deletions
@@ -28,7 +28,23 @@ def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
         self.softmax_scale = softmax_scale
         self.drop = nn.Dropout(attention_dropout)

-    def forward(self, qkv, causal=None, key_padding_mask=None):
+    def forward(
+        self,
+        qkv=None,
+        q=None,
+        k=None,
+        v=None,
+        kv=None,
+        causal=None,
+        cu_seqlens=None,
+        max_seqlen=None,
+        cu_seqlens_q=None,
+        cu_seqlens_k=None,
+        max_seqlen_q=None,
+        max_seqlen_k=None,
+        softmax_scale=None,
+        dropout_p=0.0,
+    ):
         """Only supports the padded mode"""
         """Implements the multihead softmax attention.
         Arguments
@@ -38,29 +54,48 @@ def forward(self, qkv, causal=None, key_padding_mask=None):
             key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                 False means to mask out. (B, S)
         """
-        batch_size, seqlen = qkv.shape[0], qkv.shape[1]
+        if qkv is not None:
+            query, key, value = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
+            device = query.device
+        elif kv is not None:
+            assert q is not None, "q should not be None, when kv is not None"
+            assert q.device == kv.device, "the devices of q and kv should be same"
+            query = q
+            key, value = kv[:, :, 0], kv[:, :, 1]
+            device = query.device
+        else:
+            assert (
+                q is not None and k is not None and v is not None
+            ), "q, k, v should not be None"
+            assert (
+                q.device == k.device and k.device == v.device
+            ), "the devices of q, k and v should be same"
+            query = q
+            key, value = k, v
+            device = query.device
+
+        batch_size, seqlen = query.shape[0], query.shape[1]
         causal = self.causal if causal is None else causal
-        q, k, v = qkv.unbind(dim=2)
         softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
-        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
-        if key_padding_mask is not None:
-            padding_mask = torch.full(
-                (batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device
-            )
-            padding_mask.masked_fill_(key_padding_mask, 0.0)
-            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
-            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
+        scores = torch.einsum("bthd,bshd->bhts", query, key * softmax_scale)
+        # if key_padding_mask is not None:
+        #     padding_mask = torch.full(
+        #         (batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device
+        #     )
+        #     padding_mask.masked_fill_(key_padding_mask, 0.0)
+        #     # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
+        #     scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
         if causal:
             # "triu_tril_cuda_template" not implemented for 'BFloat16'
             # So we have to construct the mask in float
             causal_mask = torch.triu(
-                torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1
+                torch.full((seqlen, seqlen), -10000.0, device=device), 1
             )
             # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
             scores = scores + causal_mask.to(dtype=scores.dtype)
         attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
         attention_drop = self.drop(attention)
-        output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
+        output = torch.einsum("bhts,bshd->bthd", attention_drop, value)
         return output

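The reworked forward() now accepts a packed qkv tensor, a q tensor with a packed kv pair, or separate q, k and v tensors. A minimal sketch of the separate-tensor path, the one the new unit tests exercise (the SelfAttention class, its constructor arguments and the (batch, seqlen, nheads, headdim) layout come from this diff; the concrete shapes are illustrative):

import torch

from deeplink_ext.internevo_ops.flash_attention_fallback import SelfAttention

# Layout matches the new tests: (batch, seqlen, nheads, headdim).
q = torch.rand([2, 16, 4, 32])
k = torch.rand([2, 16, 4, 32])
v = torch.rand([2, 16, 4, 32])

# qkv and kv stay None, so the separate q/k/v branch of forward() runs.
attn = SelfAttention(causal=True)
out = attn(q=q, k=k, v=v)
assert out.shape == (2, 16, 4, 32)

Note that, as committed, the packed qkv and kv branches still reference q and v from the old signature further down in forward() (q.shape[-1] and v.dtype), so the separate-tensor call shown here appears to be the only path that runs end to end.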

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
+import torch
+
+from deeplink_ext.internevo_ops.flash_attention import (
+    FlashSelfAttention,
+    FlashCrossAttention,
+)
+from deeplink_ext.internevo_ops.flash_attention_fallback import (
+    SelfAttention,
+    CrossAttention,
+)
+
+
+def test_self_attention():
+    batch = 8
+    seqlen = 32
+    nheads = 16
+    headdim = 64
+
+    q_ref = torch.rand([batch, seqlen, nheads, headdim], requires_grad=True)
+    k_ref = torch.rand([batch, seqlen, nheads, headdim], requires_grad=True)
+    v_ref = torch.rand([batch, seqlen, nheads, headdim], requires_grad=True)
+    qkv_ref = torch.stack([q_ref, k_ref, v_ref], 2)
+    q_ext = q_ref.clone().detach().to(torch.float16).cuda().requires_grad_()
+    k_ext = k_ref.clone().detach().to(torch.float16).cuda().requires_grad_()
+    v_ext = v_ref.clone().detach().to(torch.float16).cuda().requires_grad_()
+
+    model_ref = SelfAttention()
+    model_ext = FlashSelfAttention()
+    out_ref = model_ref(None, q_ref, k_ref, v_ref, None)
+    out_ext = model_ext(None, q_ext, k_ext, v_ext, None)
+    out_ref.backward(torch.ones_like(out_ref))
+    out_ext.backward(torch.ones_like(out_ext))
+
+    assert torch.allclose(
+        out_ext.cpu(), out_ref.to(torch.float16), rtol=1e-3, atol=1e-3
+    )
+    assert torch.allclose(
+        q_ext.grad.cpu(), q_ref.grad.to(torch.float16), rtol=1e-3, atol=1e-3
+    )
+    assert torch.allclose(
+        k_ext.grad.cpu(), k_ref.grad.to(torch.float16), rtol=1e-3, atol=1e-3
+    )
+    assert torch.allclose(
+        v_ext.grad.cpu(), v_ref.grad.to(torch.float16), rtol=1e-3, atol=1e-3
+    )
+
+
+def test_cross_attention():
+    batch = 8
+    seqlen = 32
+    nheads = 16
+    headdim = 64
+
+    q_ref = torch.rand([batch, seqlen, nheads, headdim], requires_grad=True)
+    kv_ref = torch.rand([batch, seqlen, 2, nheads, headdim], requires_grad=True)
+    q_ext = q_ref.clone().detach().to(torch.float16).cuda().requires_grad_()
+    kv_ext = kv_ref.clone().detach().to(torch.float16).cuda().requires_grad_()
+
+    model_ref = CrossAttention()
+    model_ext = FlashCrossAttention()
+    out_ref = model_ref(q_ref, kv_ref)
+    out_ext = model_ext(q_ext, kv_ext)
+    out_ref.backward(torch.ones_like(out_ref))
+    out_ext.backward(torch.ones_like(out_ext))
+
+    assert torch.allclose(
+        out_ext.cpu(), out_ref.to(torch.float16), rtol=1e-3, atol=1e-3
+    )
+    assert torch.allclose(
+        q_ext.grad.cpu(), q_ref.grad.to(torch.float16), rtol=1e-3, atol=1e-3
+    )
+    assert torch.allclose(
+        kv_ext.grad.cpu(), kv_ref.grad.to(torch.float16), rtol=1e-3, atol=1e-3
+    )
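Both tests compare the DeepLink flash attention extension against the torch fallback and require a CUDA-capable device, since the extension-side tensors are moved with .cuda(). Assuming the new module lands under tests/internevo/ next to the renamed RMSNorm test below (its exact filename is not shown in this view), a small pytest runner sketch:

import sys

import pytest

# Collect only the new attention tests; the -k expression matches the
# test_self_attention and test_cross_attention functions added above.
sys.exit(
    pytest.main(["-q", "tests/internevo", "-k", "self_attention or cross_attention"])
)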

tests/test_rms_norm_internevo.py renamed to tests/internevo/test_rms_norm_internevo.py

Lines changed: 0 additions & 3 deletions
@@ -62,6 +62,3 @@ def test_multi_cases_for_mixed_rms_norm():
         print(
             f"When input dtype is {input_dtype} and weight dtype is {weight_dtype}, MixedRMSNorm passes the backward test!"
         )
-
-
-test_multi_cases_for_mixed_rms_norm()
