Commit f15d515

model : avoid ggml_cont_3d for fused QKV weights

ggml-ci

1 parent 009b709

2 files changed (+54, -94 lines)

src/llama-kv-cache.cpp

Lines changed: 18 additions & 2 deletions
```diff
@@ -1020,9 +1020,14 @@ ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggm
 
     auto * k = layers[ikv].k;
 
+    const int64_t n_embd_k_gqa = k_cur->ne[0]*k_cur->ne[1];
     const int64_t n_tokens = k_cur->ne[2];
 
-    k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
+    // we can merge dims 0 and 1
+    assert(k_cur->nb[0]*k_cur->ne[0] == k_cur->nb[1]);
+
+    //k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
+    k_cur = ggml_view_2d(ctx, k_cur, n_embd_k_gqa, n_tokens, k_cur->nb[2], 0);
 
     if (k->ne[2] > 1) {
         k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]);
```
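
For orientation, a minimal sketch of the situation this hunk targets (not taken from the commit): with fused QKV weights, K is a strided 3D slice of the fused activation, so it is not contiguous as a whole and ggml_reshape_2d() would assert; a ggml_cont_3d() copy used to be the workaround. Because dims 0 and 1 are still packed, a ggml_view_2d() with the per-token stride nb[2] merges them without any copy. The helper and parameter names below (view_k_2d, qkv, n_embd_head, n_head_kv, off_k) are hypothetical; only the ggml calls are real.

```c
// Illustrative sketch (not part of the commit): merge dims 0 and 1 of a
// non-contiguous K slice with a strided 2D view instead of ggml_cont_3d.
#include <assert.h>
#include "ggml.h"

// qkv: fused activation of shape [n_embd_q + 2*n_embd_kv, n_tokens] (hypothetical)
static struct ggml_tensor * view_k_2d(struct ggml_context * ctx,
                                      struct ggml_tensor  * qkv,
                                      int64_t n_embd_head,
                                      int64_t n_head_kv,
                                      size_t  off_k) { // byte offset of K within the fused row
    const int64_t n_tokens = qkv->ne[1];

    // K as a 3D slice [n_embd_head, n_head_kv, n_tokens]: heads are packed
    // back to back (nb1 == n_embd_head*nb0), but the stride between tokens is
    // the full fused row (nb2 == qkv->nb[1]), so the slice is not contiguous
    // and ggml_reshape_2d() would assert
    struct ggml_tensor * k_cur = ggml_view_3d(ctx, qkv,
            n_embd_head, n_head_kv, n_tokens,
            n_embd_head*ggml_element_size(qkv), // nb1
            qkv->nb[1],                         // nb2
            off_k);

    // dims 0 and 1 are mergeable: rows within one token are packed
    assert(k_cur->nb[0]*k_cur->ne[0] == k_cur->nb[1]);

    // merge dims 0 and 1 with a strided 2D view, avoiding the copy entirely
    return ggml_view_2d(ctx, k_cur,
            n_embd_head*n_head_kv, n_tokens,
            k_cur->nb[2], 0);
}
```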
```diff
@@ -1041,16 +1046,27 @@ ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggm
     const int64_t n_embd_v_gqa = v_cur->ne[0]*v_cur->ne[1];
     const int64_t n_tokens = v_cur->ne[2];
 
-    v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
+    // we can merge dims 0 and 1
+    assert(v_cur->nb[0]*v_cur->ne[0] == v_cur->nb[1]);
+
+    //v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
 
     if (!v_trans) {
+        v_cur = ggml_view_2d(ctx, v_cur, n_embd_v_gqa, n_tokens, v_cur->nb[2], 0);
+
         if (v->ne[2] > 1) {
             v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]);
         }
 
         return ggml_set_rows(ctx, v, v_cur, v_idxs);
     }
 
+    if (v_cur->nb[1]*v_cur->ne[1] != v_cur->nb[2]) {
+        v_cur = ggml_cont_2d   (ctx, v_cur, n_embd_v_gqa, n_tokens);
+    } else {
+        v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
+    }
+
     // [TAG_V_CACHE_VARIABLE]
     if (n_embd_v_gqa < v->ne[0]) {
         v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_v_gqa, 0, 0, 0);
```
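
In the transposed-V branch (v_trans), the hunk above still produces a plain 2D tensor, but now copies only when the token stride is not packed (nb[1]*ne[1] != nb[2]); otherwise a free reshape suffices. A minimal sketch of that decision on a generic [ne0, ne1, n_tokens] input follows; flatten_2d is a hypothetical helper, not a llama.cpp function.

```c
// Illustrative sketch (not part of the commit): collapse [ne0, ne1, n_tokens]
// into [ne0*ne1, n_tokens], copying only when the layout forces it.
#include <assert.h>
#include "ggml.h"

static struct ggml_tensor * flatten_2d(struct ggml_context * ctx, struct ggml_tensor * cur) {
    const int64_t ne01     = cur->ne[0]*cur->ne[1];
    const int64_t n_tokens = cur->ne[2];

    // dims 0 and 1 must be packed back to back
    assert(cur->nb[0]*cur->ne[0] == cur->nb[1]);

    if (cur->nb[1]*cur->ne[1] != cur->nb[2]) {
        // gap between tokens (e.g. a slice of a fused QKV tensor):
        // a reshape is not allowed, so materialize a contiguous copy
        return ggml_cont_2d(ctx, cur, ne01, n_tokens);
    }

    // fully contiguous: reshaping is free
    return ggml_reshape_2d(ctx, cur, ne01, n_tokens);
}
```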