Commit 13da12b

feat: add test framework (#88)

* add test framework
1 parent 8bca2fb commit 13da12b

File tree

3 files changed (+113, -49 lines)


tests/__init__.py

Whitespace-only changes. (This file is the package marker that makes tests importable, so the refactored test below can do from tests.core import ...)

tests/core.py

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
import torch
import typing


__all__ = ["call_module", "call_func", "copy_to_cpu", "allclose"]


def call_module(module: torch.nn.Module, *forward_args):
    """Run a forward and backward pass and collect the gradients of the tensor inputs."""
    output_forward = module(*forward_args)
    grads = []
    if torch.is_tensor(output_forward):
        output_forward.backward(torch.ones_like(output_forward))
    elif isinstance(output_forward, (list, tuple)):
        assert torch.is_tensor(output_forward[0]), "output_forward[0] is not a tensor"
        output_forward[0].backward(torch.ones_like(output_forward[0]))
    else:
        raise RuntimeError(
            "the result of forward is not a tensor, or a list or tuple of tensors"
        )
    for arg in forward_args:
        if torch.is_tensor(arg) and arg.requires_grad:
            grads.append(arg.grad)
    return output_forward, grads


def call_func(f: typing.Callable, args: list):
    """Call f with the argument list passed as a single positional argument."""
    return f(args)


def copy_to_cpu(tensors: list[torch.Tensor], dtype=None):
    """Return detached CPU copies of the tensors (float32 by default), preserving requires_grad."""
    if dtype is None:
        dtype = torch.float32
    return [
        tensor.detach().clone().to(dtype).cpu().requires_grad_(tensor.requires_grad)
        for tensor in tensors
    ]


def allclose(expected_vals: list, real_vals: list, rtol, atol):
    """Recursively compare two collections of results element by element."""
    assert len(expected_vals) == len(real_vals), "length of outputs is not the same"
    for i in range(len(expected_vals)):
        assert type(expected_vals[i]) == type(
            real_vals[i]
        ), "the type of expected_vals[{index}] is {type1}, but real_vals[{index}] is {type2}.".format(
            index=i, type1=type(expected_vals[i]), type2=type(real_vals[i])
        )
        if isinstance(expected_vals[i], torch.Tensor):
            # Compare on CPU in float32 so fp16 GPU results can be checked
            # against fp32 CPU references.
            if not torch.allclose(
                expected_vals[i].cpu().to(torch.float32),
                real_vals[i].cpu().to(torch.float32),
                rtol,
                atol,
            ):
                return False
        elif isinstance(expected_vals[i], (tuple, list)):
            if not allclose(expected_vals[i], real_vals[i], rtol, atol):
                return False
        elif isinstance(expected_vals[i], dict):
            for key, val in expected_vals[i].items():
                assert key in real_vals[i], "key {k} not in real_vals[{i}]".format(
                    k=key, i=i
                )
                if not allclose([val], [real_vals[i][key]], rtol, atol):
                    return False
        # Primitive type
        else:
            if abs(real_vals[i] - expected_vals[i]) > atol + rtol * abs(
                expected_vals[i]
            ):
                return False
    return True
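Taken together, these helpers encode the framework's comparison pattern: copy_to_cpu builds float32 CPU twins of the GPU inputs, call_module runs a forward plus backward pass and collects the input gradients, and allclose compares the two result sets. The following is a minimal sketch of that flow; the module (torch.nn.ReLU) and the shapes are chosen purely for illustration and are not taken from the commit.

# Illustrative only (not part of this commit): compare a float32 CPU reference
# run against a float16 GPU run of the same parameter-free module.
import torch
from tests.core import call_module, copy_to_cpu, allclose

x_gpu = torch.rand([4, 8], dtype=torch.float16, requires_grad=True, device="cuda")
(x_cpu,) = copy_to_cpu([x_gpu])  # detached float32 CPU copy, still requires grad

output_cpu, grads_cpu = call_module(torch.nn.ReLU(), x_cpu)
output_gpu, grads_gpu = call_module(torch.nn.ReLU(), x_gpu)

assert allclose([output_cpu], [output_gpu], rtol=1e-3, atol=1e-3)
assert allclose(grads_cpu, grads_gpu, rtol=1e-3, atol=1e-3)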
Lines changed: 45 additions & 49 deletions
@@ -1,4 +1,5 @@
 import torch
+from tests.core import copy_to_cpu, allclose, call_module, call_func
 
 from deeplink_ext.internevo_ops.flash_attention import (
     FlashSelfAttention,
@@ -11,64 +12,59 @@
 
 
 def test_self_attention():
-    batch = 8
-    seqlen = 32
-    nheads = 16
-    headdim = 64
+    batch, seqlen, nheads, headdim = [8, 32, 16, 64]
 
-    q_ref = torch.rand([batch, seqlen, nheads, headdim], requires_grad=True)
-    k_ref = torch.rand([batch, seqlen, nheads, headdim], requires_grad=True)
-    v_ref = torch.rand([batch, seqlen, nheads, headdim], requires_grad=True)
-    qkv_ref = torch.stack([q_ref, k_ref, v_ref], 2)
-    q_ext = q_ref.clone().detach().to(torch.float16).cuda().requires_grad_()
-    k_ext = k_ref.clone().detach().to(torch.float16).cuda().requires_grad_()
-    v_ext = v_ref.clone().detach().to(torch.float16).cuda().requires_grad_()
-
-    model_ref = SelfAttention()
-    model_ext = FlashSelfAttention()
-    out_ref = model_ref(None, q_ref, k_ref, v_ref, None)
-    out_ext = model_ext(None, q_ext, k_ext, v_ext, None)
-    out_ref.backward(torch.ones_like(out_ref))
-    out_ext.backward(torch.ones_like(out_ext))
-
-    assert torch.allclose(
-        out_ext.cpu(), out_ref.to(torch.float16), rtol=1e-3, atol=1e-3
+    q_gpu = torch.rand(
+        [batch, seqlen, nheads, headdim],
+        dtype=torch.float16,
+        requires_grad=True,
+        device="cuda",
     )
-    assert torch.allclose(
-        q_ext.grad.cpu(), q_ref.grad.to(torch.float16), rtol=1e-3, atol=1e-3
+    k_gpu = torch.rand(
+        [batch, seqlen, nheads, headdim],
+        dtype=torch.float16,
+        requires_grad=True,
+        device="cuda",
     )
-    assert torch.allclose(
-        k_ext.grad.cpu(), k_ref.grad.to(torch.float16), rtol=1e-3, atol=1e-3
+    v_gpu = torch.rand(
+        [batch, seqlen, nheads, headdim],
+        dtype=torch.float16,
+        requires_grad=True,
+        device="cuda",
+    )
+
+    q_cpu, k_cpu, v_cpu = copy_to_cpu([q_gpu, k_gpu, v_gpu])
+    output_forward_cpu, grads_cpu = call_module(
+        SelfAttention(), None, q_cpu, k_cpu, v_cpu, None
     )
-    assert torch.allclose(
-        v_ext.grad.cpu(), v_ref.grad.to(torch.float16), rtol=1e-3, atol=1e-3
+    output_forward_gpu, grads_gpu = call_module(
+        FlashSelfAttention().cuda(), None, q_gpu, k_gpu, v_gpu, None
     )
+    assert allclose(output_forward_cpu, output_forward_gpu, rtol=1e-3, atol=1e-3)
+    assert allclose(grads_cpu, grads_gpu, rtol=1e-3, atol=1e-3)
 
 
 def test_cross_attention():
-    batch = 8
-    seqlen = 32
-    nheads = 16
-    headdim = 64
+    batch, seqlen, nheads, headdim = [8, 32, 16, 64]
 
-    q_ref = torch.rand([batch, seqlen, nheads, headdim], requires_grad=True)
-    kv_ref = torch.rand([batch, seqlen, 2, nheads, headdim], requires_grad=True)
-    q_ext = q_ref.clone().detach().to(torch.float16).cuda().requires_grad_()
-    kv_ext = kv_ref.clone().detach().to(torch.float16).cuda().requires_grad_()
-
-    model_ref = CrossAttention()
-    model_ext = FlashCrossAttention()
-    out_ref = model_ref(q_ref, kv_ref)
-    out_ext = model_ext(q_ext, kv_ext)
-    out_ref.backward(torch.ones_like(out_ref))
-    out_ext.backward(torch.ones_like(out_ext))
-
-    assert torch.allclose(
-        out_ext.cpu(), out_ref.to(torch.float16), rtol=1e-3, atol=1e-3
+    q_gpu = torch.rand(
+        [batch, seqlen, nheads, headdim],
+        dtype=torch.float16,
+        requires_grad=True,
+        device="cuda",
     )
-    assert torch.allclose(
-        q_ext.grad.cpu(), q_ref.grad.to(torch.float16), rtol=1e-3, atol=1e-3
+    kv_gpu = torch.rand(
+        [batch, seqlen, 2, nheads, headdim],
+        dtype=torch.float16,
+        requires_grad=True,
+        device="cuda",
     )
-    assert torch.allclose(
-        kv_ext.grad.cpu(), kv_ref.grad.to(torch.float16), rtol=1e-3, atol=1e-3
+
+    q_cpu, kv_cpu = copy_to_cpu([q_gpu, kv_gpu])
+    output_forward_cpu, grads_cpu = call_module(CrossAttention(), q_cpu, kv_cpu)
+    output_forward_gpu, grads_gpu = call_module(
+        FlashCrossAttention().cuda(), q_gpu, kv_gpu
     )
+
+    assert allclose(output_forward_cpu, output_forward_gpu, rtol=1e-3, atol=1e-3)
+    assert allclose(grads_cpu, grads_gpu, rtol=1e-3, atol=1e-3)

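Both refactored tests allocate their inputs directly on the GPU (device="cuda"), so they can only run on a machine with a CUDA device. The commit does not add any guard for that; if one were wanted, and assuming pytest is the test runner (which the tests/ layout and test_* naming suggest but the commit does not state), a module-level skip could look like this sketch.

# Hypothetical addition, not part of this commit: skip the whole test module
# on machines without a CUDA device instead of failing at tensor creation.
import pytest
import torch

pytestmark = pytest.mark.skipif(
    not torch.cuda.is_available(), reason="flash attention tests require CUDA"
)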