
Commit b753596

vulkan: mul_mat_id coopmat2 optimizations
Add a path for when the tile fits in BN/2, similar to what we have for mul_mat. Only call fetch_scales/store_scales once per QUANT_K block, and once at the beginning in case start_k is not aligned.
Parent: b730706
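The scale-handling change is the subtle part of this commit. Below is a minimal host-side sketch of the new schedule, in plain C++ rather than the shader's GLSL, assuming BK evenly divides QUANT_K and using print stubs in place of the shader's fetch_scales/store_scales helpers (the sizes are illustrative stand-ins, not taken from the commit):

// Sketch only: models the loop structure of mul_mm_cm2.comp on the host.
#include <cstdio>

constexpr unsigned BK      = 16;  // K tile consumed per iteration (from the warptiles)
constexpr unsigned QUANT_K = 256; // quant superblock size, e.g. for the K-quants

// Stand-ins for the shader's shared-memory scale helpers.
void fetch_scales(unsigned k) { std::printf("fetch scales at k=%u\n", k); }
void store_scales(unsigned k) { std::printf("store scales at k=%u\n", k); }

int main() {
    unsigned start_k = 0, end_k = 512;
    unsigned k_iters = (end_k - start_k + BK - 1) / BK;

    // One fetch/store pair up front, in case start_k is not QUANT_K-aligned.
    fetch_scales(start_k);
    store_scales(start_k);

    for (unsigned block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
        // Publish scales only when crossing into a new QUANT_K superblock...
        if (block_k % QUANT_K == 0) store_scales(block_k);
        // ...and prefetch the next superblock's scales one iteration ahead.
        if (block_k + BK < end_k && (block_k + BK) % QUANT_K == 0)
            fetch_scales(block_k + BK);
        // (the coopmat loads and coopMatMulAdd happen here in the shader)
    }
    return 0;
}

The point is that scales move through shared memory once per QUANT_K superblock instead of once per BK iteration, with the next superblock's scales fetched one iteration before they are needed.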

2 files changed: +52 −6 lines

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 1 addition & 1 deletion
@@ -2214,7 +2214,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         s_mmq_wg_denoms_k = { 32, 64, 1 };
 
         // spec constants and tile sizes for quant matmul_id
-        l_warptile_mmqid = { 256, 128, 128, 16, 0, device->subgroup_size };
+        l_warptile_mmqid = { 256, 128, 128, 16, 1, device->subgroup_size };
         m_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size };
         s_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size };
         l_mmqid_wg_denoms = { 128, 128, 1 };
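Note on this host-side change: the only edit is flipping the fifth entry of l_warptile_mmqid from 0 to 1. Judging from the paired shader diff below, that slot appears to feed the enable_smaller_matrices specialization constant, so the BN/2 path is turned on only for the large quant matmul_id tile while the m and s warptiles leave it disabled. This mapping is inferred from the two changes appearing together; the commit does not state it explicitly.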

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp

Lines changed: 51 additions & 5 deletions
@@ -441,18 +441,64 @@ void main() {
 
         tensorLayoutBClamp = setTensorLayoutStrideNV(tensorLayoutBClamp, stride_b, 1);
 
-    coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum;
-    sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
-
     uint k_iters = (end_k - start_k + BK - 1) / BK;
 
     fetch_scales(ir * BM, pos_a, stride_a, start_k, tid, false);
+    store_scales(tid);
+
+#ifdef MUL_MAT_ID
+    if (enable_smaller_matrices && ic * BN + BNover2 >= _ne1) {
+        coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> sum;
+        sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(0.0);
+
+        [[dont_unroll]]
+        for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
+
+            if ((block_k % QUANT_K) == 0) {
+                store_scales(tid);
+            }
+            if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0) {
+                fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
+            }
+
+            if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) {
+                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
+
+                coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
+                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB);
+
+                sum = coopMatMulAdd(mat_a, mat_b, sum);
+            } else {
+                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
+
+                coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
+                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB);
+
+                sum = coopMatMulAdd(mat_a, mat_b, sum);
+            }
+        }
+
+        // Convert from ACC_TYPE to D_TYPE
+        coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> mat_d;
+        mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(sum);
+
+        // Call callback to store each element, remapping row through shared memory
+        coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);
+        return;
+    }
+#endif
+    coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum;
+    sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
 
     [[dont_unroll]]
     for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
 
-        store_scales(tid);
-        if (block_k + BK < end_k) {
+        if ((block_k % QUANT_K) == 0) {
+            store_scales(tid);
+        }
+        if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0) {
             fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
         }
