Skip to content

Commit a856a56

Browse files
committed
tests : add non-cont K,V FA tests
ggml-ci
1 parent eacdeb5 commit a856a56

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

tests/test-backend-ops.cpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4258,26 +4258,32 @@ struct test_flash_attn_ext : public test_case {
42584258
const int64_t hsk_padded = GGML_PAD(hsk, ggml_blck_size(type_KV));
42594259
const int64_t hsv_padded = GGML_PAD(hsv, ggml_blck_size(type_KV));
42604260

4261-
auto const &create_permuted = [&](ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) -> ggml_tensor * {
4261+
auto const &create_permuted = [&](ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, bool is_view) -> ggml_tensor * {
42624262
int64_t ne[4] = {ne0, ne1, ne2, ne3};
42634263
int64_t ne_perm[4];
42644264
for (int i = 0; i < 4; ++i) {
42654265
ne_perm[permute[i]] = ne[i];
42664266
}
4267-
ggml_tensor * t = ggml_new_tensor_4d(ctx, type, ne_perm[0], ne_perm[1], ne_perm[2], ne_perm[3]);
4267+
ggml_tensor * t;
4268+
if (is_view) {
4269+
ggml_tensor * t0 = ggml_new_tensor_4d(ctx, type, ne_perm[0], 2*ne_perm[1], ne_perm[2], ne_perm[3]);
4270+
t = ggml_view_4d(ctx, t0, ne_perm[0], ne_perm[1], ne_perm[2], ne_perm[3], t0->nb[1], t0->nb[2], t0->nb[3], 0);
4271+
} else {
4272+
t = ggml_new_tensor_4d(ctx, type, ne_perm[0], ne_perm[1], ne_perm[2], ne_perm[3]);
4273+
}
42684274
if (permute != std::array<int32_t, 4>{0, 1, 2, 3}) {
42694275
t = ggml_permute(ctx, t, permute[0], permute[1], permute[2], permute[3]);
42704276
}
42714277
return t;
42724278
};
42734279

4274-
ggml_tensor * q = create_permuted(GGML_TYPE_F32, hsk_padded, nb, nh*nr23[0], nr23[1]);
4280+
ggml_tensor * q = create_permuted(GGML_TYPE_F32, hsk_padded, nb, nh*nr23[0], nr23[1], false);
42754281
ggml_set_name(q, "q");
42764282

4277-
ggml_tensor * k = create_permuted(type_KV, hsk_padded, kv, nh, nr23[1]);
4283+
ggml_tensor * k = create_permuted(type_KV, hsk_padded, kv, nh, nr23[1], true); // the K tensor is usually a view of the K cache
42784284
ggml_set_name(k, "k");
42794285

4280-
ggml_tensor * v = create_permuted(type_KV, hsv_padded, kv, nh, nr23[1]);
4286+
ggml_tensor * v = create_permuted(type_KV, hsv_padded, kv, nh, nr23[1], true); // the V tensor is usually a view of the V cache
42814287
ggml_set_name(v, "v");
42824288

42834289
ggml_tensor * m = nullptr;

0 commit comments

Comments
 (0)