@@ -441,18 +441,64 @@ void main() {
 
     tensorLayoutBClamp = setTensorLayoutStrideNV(tensorLayoutBClamp, stride_b, 1);
 
-    coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum;
-    sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
-
     uint k_iters = (end_k - start_k + BK - 1) / BK;
 
     fetch_scales(ir * BM, pos_a, stride_a, start_k, tid, false);
+    store_scales(tid);
+
+#ifdef MUL_MAT_ID
+    if (enable_smaller_matrices && ic * BN + BNover2 >= _ne1) {
+        coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> sum;
+        sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(0.0);
+
+        [[dont_unroll]]
+        for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
+
+            if ((block_k % QUANT_K) == 0) {
+                store_scales(tid);
+            }
+            if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0) {
+                fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
+            }
+
+            if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) {
+                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
+
+                coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
+                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB);
+
+                sum = coopMatMulAdd(mat_a, mat_b, sum);
+            } else {
+                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
+
+                coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
+                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB);
+
+                sum = coopMatMulAdd(mat_a, mat_b, sum);
+            }
+        }
+
+        // Convert from ACC_TYPE to D_TYPE
+        coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> mat_d;
+        mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(sum);
+
+        // Call callback to store each element, remapping row through shared memory
+        coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);
+        return;
+    }
+#endif
+    coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum;
+    sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
 
     [[dont_unroll]]
     for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
 
-        store_scales(tid);
-        if (block_k + BK < end_k) {
+        if ((block_k % QUANT_K) == 0) {
+            store_scales(tid);
+        }
+        if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0) {
             fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
         }
 
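
For reference, the sketch below (not part of the patch) mirrors the new scale-staging control flow in plain C: store_scales now runs only when block_k crosses a QUANT_K boundary, and fetch_scales prefetches the next quantization block's scales one BK tile ahead, instead of both running on every iteration as before. The constants, loop bounds, and helper stubs are illustrative assumptions, and QUANT_K is taken to be a multiple of BK.

/*
 * Minimal CPU-side sketch of the hoisted scale staging, assuming
 * QUANT_K is a multiple of BK. fetch_scales/store_scales are
 * stand-ins for the shader helpers, not their real implementations.
 */
#include <stdio.h>

#define QUANT_K 256u  /* K elements covered by one set of quantization scales (assumed) */
#define BK       32u  /* K elements consumed per loop iteration (assumed) */

static void fetch_scales(unsigned k) { printf("fetch scales for quant block at k=%u\n", k); }
static void store_scales(void)       { printf("store staged scales to shared memory\n"); }

int main(void) {
    const unsigned start_k = 0u, end_k = 1024u;          /* assumed K range */
    const unsigned k_iters = (end_k - start_k + BK - 1u) / BK;

    /* Prefetch and stage the first quant block's scales before the loop. */
    fetch_scales(start_k);
    store_scales();

    unsigned block_k = start_k;
    for (unsigned i = 0u; i < k_iters; ++i, block_k += BK) {
        /* Scales only change when block_k crosses a QUANT_K boundary,
         * so the staged set is stored once per quant block ... */
        if ((block_k % QUANT_K) == 0u) {
            store_scales();
        }
        /* ... and the next quant block's scales are fetched one BK tile early. */
        if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0u) {
            fetch_scales(block_k + BK);
        }
        /* The coopMatMulAdd work on the current BK tile would go here. */
    }
    return 0;
}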