@@ -4258,26 +4258,32 @@ struct test_flash_attn_ext : public test_case {
4258
4258
const int64_t hsk_padded = GGML_PAD (hsk, ggml_blck_size (type_KV));
4259
4259
const int64_t hsv_padded = GGML_PAD (hsv, ggml_blck_size (type_KV));
4260
4260
4261
- auto const &create_permuted = [&](ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) -> ggml_tensor * {
4261
+ auto const &create_permuted = [&](ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, bool is_view ) -> ggml_tensor * {
4262
4262
int64_t ne[4 ] = {ne0, ne1, ne2, ne3};
4263
4263
int64_t ne_perm[4 ];
4264
4264
for (int i = 0 ; i < 4 ; ++i) {
4265
4265
ne_perm[permute[i]] = ne[i];
4266
4266
}
4267
- ggml_tensor * t = ggml_new_tensor_4d (ctx, type, ne_perm[0 ], ne_perm[1 ], ne_perm[2 ], ne_perm[3 ]);
4267
+ ggml_tensor * t;
4268
+ if (is_view) {
4269
+ ggml_tensor * t0 = ggml_new_tensor_4d (ctx, type, ne_perm[0 ], 2 *ne_perm[1 ], ne_perm[2 ], ne_perm[3 ]);
4270
+ t = ggml_view_4d (ctx, t0, ne_perm[0 ], ne_perm[1 ], ne_perm[2 ], ne_perm[3 ], t0->nb [1 ], t0->nb [2 ], t0->nb [3 ], 0 );
4271
+ } else {
4272
+ t = ggml_new_tensor_4d (ctx, type, ne_perm[0 ], ne_perm[1 ], ne_perm[2 ], ne_perm[3 ]);
4273
+ }
4268
4274
if (permute != std::array<int32_t , 4 >{0 , 1 , 2 , 3 }) {
4269
4275
t = ggml_permute (ctx, t, permute[0 ], permute[1 ], permute[2 ], permute[3 ]);
4270
4276
}
4271
4277
return t;
4272
4278
};
4273
4279
4274
- ggml_tensor * q = create_permuted (GGML_TYPE_F32, hsk_padded, nb, nh*nr23[0 ], nr23[1 ]);
4280
+ ggml_tensor * q = create_permuted (GGML_TYPE_F32, hsk_padded, nb, nh*nr23[0 ], nr23[1 ], false );
4275
4281
ggml_set_name (q, " q" );
4276
4282
4277
- ggml_tensor * k = create_permuted (type_KV, hsk_padded, kv, nh, nr23[1 ]);
4283
+ ggml_tensor * k = create_permuted (type_KV, hsk_padded, kv, nh, nr23[1 ], true ); // the K tensor is usually a view of the K cache
4278
4284
ggml_set_name (k, " k" );
4279
4285
4280
- ggml_tensor * v = create_permuted (type_KV, hsv_padded, kv, nh, nr23[1 ]);
4286
+ ggml_tensor * v = create_permuted (type_KV, hsv_padded, kv, nh, nr23[1 ], true ); // the V tensor is usually a view of the V cache
4281
4287
ggml_set_name (v, " v" );
4282
4288
4283
4289
ggml_tensor * m = nullptr ;
0 commit comments