@@ -1020,9 +1020,14 @@ ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggm
 
     auto * k = layers[ikv].k;
 
+    const int64_t n_embd_k_gqa = k_cur->ne[0]*k_cur->ne[1];
     const int64_t n_tokens = k_cur->ne[2];
 
-    k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
+    // we can merge dims 0 and 1
+    assert(k_cur->nb[0]*k_cur->ne[0] == k_cur->nb[1]);
+
+    // k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
+    k_cur = ggml_view_2d(ctx, k_cur, n_embd_k_gqa, n_tokens, k_cur->nb[2], 0);
 
     if (k->ne[2] > 1) {
         k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]);
@@ -1041,16 +1046,27 @@ ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggm
     const int64_t n_embd_v_gqa = v_cur->ne[0]*v_cur->ne[1];
     const int64_t n_tokens = v_cur->ne[2];
 
-    v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
+    // we can merge dims 0 and 1
+    assert(v_cur->nb[0]*v_cur->ne[0] == v_cur->nb[1]);
+
+    // v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
 
     if (!v_trans) {
+        v_cur = ggml_view_2d(ctx, v_cur, n_embd_v_gqa, n_tokens, v_cur->nb[2], 0);
+
         if (v->ne[2] > 1) {
             v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]);
         }
 
         return ggml_set_rows(ctx, v, v_cur, v_idxs);
     }
 
+    if (v_cur->nb[1]*v_cur->ne[1] != v_cur->nb[2]) {
+        v_cur = ggml_cont_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
+    } else {
+        v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
+    }
+
     // [TAG_V_CACHE_VARIABLE]
     if (n_embd_v_gqa < v->ne[0]) {
         v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_v_gqa, 0, 0, 0);
0 commit comments