@@ -372,6 +372,10 @@ struct vk_device_struct {
     bool subgroup_shuffle;
     bool multi_add;
 
+    bool atomic_float_add;
+    bool add_rms_fusion;
+    uint32_t atomic_binding_alignment;
+
     bool integer_dot_product;
 
     bool subgroup_size_control;
@@ -451,6 +455,8 @@ struct vk_device_struct {
     vk_pipeline pipeline_mul_norepeat[2][2][2];
     vk_pipeline pipeline_div[2][2][2];
     vk_pipeline pipeline_div_norepeat[2][2][2];
+    vk_pipeline pipeline_add_rms[2][2][2];
+    vk_pipeline pipeline_add_rms_norepeat[2][2][2];
 
     // indexed by num_additional_fused_ops == num_adds - 1
     vk_pipeline pipeline_multi_add[MAX_FUSED_ADDS];
@@ -1165,6 +1171,12 @@ class vk_perf_logger {
             timings[name].push_back(time);
             return;
         }
+        if (node->op == GGML_OP_RMS_NORM) {
+            std::string name = ggml_op_name(node->op);
+            name += "(" + std::to_string(node->ne[0]) + "," + std::to_string(node->ne[1]) + "," + std::to_string(node->ne[2]) + "," + std::to_string(node->ne[3]) + ")";
+            timings[name].push_back(time);
+            return;
+        }
         timings[ggml_op_name(node->op)].push_back(time);
     }
 private:
@@ -1179,10 +1191,13 @@ struct ggml_backend_vk_context {
 
     size_t semaphore_idx, event_idx;
     ggml_vk_garbage_collector gc;
-    size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k;
-    vk_buffer prealloc_x, prealloc_y, prealloc_split_k;
+    size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k, prealloc_size_atomic_add, prealloc_size_atomic_add_offset;
+    vk_buffer prealloc_x, prealloc_y, prealloc_split_k, prealloc_atomic_add;
     vk::Fence fence, almost_ready_fence;
     bool almost_ready_fence_pending {};
+    // Set before op_add and unset after op_rms_norm to indicate that the add should
+    // use atomics to accumulate the square of the vector components
+    bool do_add_rms_atomic;
 
     vk_buffer buffer_pool[MAX_VK_BUFFERS];
 
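For orientation (not part of the diff): each fused add+rms_norm pair gets its own slot in the `prealloc_atomic_add` scratch buffer. A minimal sketch of that bookkeeping, with a hypothetical helper name, assuming one float of partial sums per pair padded up to `atomic_binding_alignment`:

```cpp
// Hypothetical illustration of the scratch-buffer bookkeeping behind the fusion:
// the dryrun grows prealloc_size_atomic_add by one aligned slot per fused pair,
// and the real pass binds each pair's slot at its own offset.
#include <cstddef>
#include <cstdint>

struct atomic_scratch_plan {
    uint32_t alignment;       // device->atomic_binding_alignment
    size_t   total_size = 0;  // ends up as ctx->prealloc_size_atomic_add after the dryrun

    size_t reserve_slot() {   // called once per fused add+rms_norm pair
        size_t offset = total_size; // this pair's binding offset into prealloc_atomic_add
        total_size += alignment;    // pad to the next legal storage-buffer offset
        return offset;
    }
};
```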
@@ -2921,8 +2936,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
 
@@ -2992,20 +3007,22 @@ static void ggml_vk_load_shaders(vk_device& device) {
     };
 
     bool rte = device->float_controls_rte_fp16;
-#define CREATE_BINARY(name, namemod, spec) \
+#define CREATE_BINARY(name, namemod, spec, bindings) \
     for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \
         ggml_vk_create_pipeline(device, device->pipeline_ ## name ## namemod[s0][s1][d], \
             #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d][rte], name ## _data[s0][s1][d][rte], \
-            "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1);
-
-    CREATE_BINARY(add, , {0})
-    CREATE_BINARY(add, _norepeat, {1})
-    CREATE_BINARY(sub, , {0})
-    CREATE_BINARY(sub, _norepeat, {1})
-    CREATE_BINARY(mul, , {0})
-    CREATE_BINARY(mul, _norepeat, {1})
-    CREATE_BINARY(div, , {0})
-    CREATE_BINARY(div, _norepeat, {1})
+            "main", (bindings), sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1);
+
+    CREATE_BINARY(add, , {0}, 4)
+    CREATE_BINARY(add, _norepeat, {1}, 4)
+    CREATE_BINARY(sub, , {0}, 3)
+    CREATE_BINARY(sub, _norepeat, {1}, 3)
+    CREATE_BINARY(mul, , {0}, 3)
+    CREATE_BINARY(mul, _norepeat, {1}, 3)
+    CREATE_BINARY(div, , {0}, 3)
+    CREATE_BINARY(div, _norepeat, {1}, 3)
+    CREATE_BINARY(add_rms, , {0}, 4)
+    CREATE_BINARY(add_rms, _norepeat, {1}, 4)
 #undef CREATE_BINARY
 
     if (device->multi_add) {
@@ -3286,6 +3303,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
         device->coopmat_support = false;
         device->integer_dot_product = false;
         bool bfloat16_support = false;
+        bool atomic_float_support = false;
 
         for (const auto& properties : ext_props) {
             if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
@@ -3325,6 +3343,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
                        !getenv("GGML_VK_DISABLE_BFLOAT16")) {
                 bfloat16_support = true;
 #endif
+            } else if (strcmp("VK_EXT_shader_atomic_float", properties.extensionName) == 0) {
+                atomic_float_support = true;
             }
         }
 
@@ -3541,6 +3561,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
             device_extensions.push_back("VK_KHR_shader_integer_dot_product");
         }
 
+        VkPhysicalDeviceShaderAtomicFloatFeaturesEXT atomic_float_features {};
+        atomic_float_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_ATOMIC_FLOAT_FEATURES_EXT;
+        if (atomic_float_support) {
+            last_struct->pNext = (VkBaseOutStructure *)&atomic_float_features;
+            last_struct = (VkBaseOutStructure *)&atomic_float_features;
+            device_extensions.push_back("VK_EXT_shader_atomic_float");
+        }
+
         vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);
 
         device->fp16 = device->fp16 && vk12_features.shaderFloat16;
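As a standalone reference (not from the commit), the same pNext-chaining pattern can query the feature in isolation. A sketch assuming a `VkPhysicalDevice` handle `phys_dev`, and that `VK_EXT_shader_atomic_float` has already been found in the device's extension list:

```cpp
#include <vulkan/vulkan.h>

// Returns true if the device advertises shaderBufferFloat32AtomicAdd, the feature
// the diff stores into device->atomic_float_add.
bool supports_buffer_float32_atomic_add(VkPhysicalDevice phys_dev) {
    VkPhysicalDeviceShaderAtomicFloatFeaturesEXT atomic_float_features {};
    atomic_float_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_ATOMIC_FLOAT_FEATURES_EXT;

    VkPhysicalDeviceFeatures2 features2 {};
    features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
    features2.pNext = &atomic_float_features;   // chain the extension struct

    vkGetPhysicalDeviceFeatures2(phys_dev, &features2);
    return atomic_float_features.shaderBufferFloat32AtomicAdd == VK_TRUE;
}
```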
@@ -3552,6 +3580,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
 #endif
 
         device->pipeline_robustness = pl_robustness_features.pipelineRobustness;
+        device->atomic_float_add = atomic_float_features.shaderBufferFloat32AtomicAdd;
 
         device->multi_add = vk12_props.shaderRoundingModeRTEFloat16 &&
                             device->properties.limits.maxPushConstantsSize >= sizeof(vk_op_multi_add_push_constants) &&
@@ -3872,6 +3901,12 @@ static vk_device ggml_vk_get_device(size_t idx) {
 
         device->disable_fusion = getenv("GGML_VK_DISABLE_FUSION") != nullptr;
 
+        device->add_rms_fusion = !device->disable_fusion &&
+                                 device->subgroup_add &&
+                                 device->atomic_float_add;
+        device->atomic_binding_alignment =
+            std::max(4u, (uint32_t)device->properties.limits.minStorageBufferOffsetAlignment);
+
         return device;
     }
 
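A small worked example (hypothetical values) of what `atomic_binding_alignment` buys: each fused pair only needs one 32-bit float of scratch, but descriptor offsets into a storage buffer must respect `minStorageBufferOffsetAlignment`, so slots are padded to that stride:

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical device limit; common values are 16, 64 or 256 bytes.
    uint32_t min_storage_align = 64;
    uint32_t slot_stride = std::max(4u, min_storage_align); // never smaller than sizeof(float)
    uint32_t fused_pairs = 32; // add+rms_norm pairs counted during the dryrun
    std::printf("atomic scratch buffer: %u bytes\n", slot_stride * fused_pairs); // 2048
}
```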
@@ -6914,10 +6949,19 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
     case GGML_OP_ADD:
         {
             if (ctx->num_additional_fused_ops > 0) {
-                return ctx->device->pipeline_multi_add[ctx->num_additional_fused_ops];
+                if (ctx->do_add_rms_atomic) {
+                    return ctx->device->pipeline_multi_add_rms[ctx->num_additional_fused_ops];
+                } else {
+                    return ctx->device->pipeline_multi_add[ctx->num_additional_fused_ops];
+                }
+            }
+            if (ctx->do_add_rms_atomic) {
+                auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_rms_norepeat : ctx->device->pipeline_add_rms;
+                return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
+            } else {
+                auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_norepeat : ctx->device->pipeline_add;
+                return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
             }
-            auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_norepeat : ctx->device->pipeline_add;
-            return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
         }
     case GGML_OP_SUB:
         {
@@ -7523,7 +7567,12 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
         }
     } break;
     case GGML_OP_RMS_NORM:
-        elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 };
+        if (ctx->do_add_rms_atomic) {
+            // Run one element per thread, 128 threads per workgroup
+            elements = { (uint32_t)CEIL_DIV(ne00, 128), 1, 1 };
+        } else {
+            elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 };
+        }
         break;
 
     case GGML_OP_SUM:
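To make the dispatch sizing concrete (hypothetical row width): with 128 threads per workgroup and one element per thread, a 4096-wide row dispatches 32 workgroups along x:

```cpp
#include <array>
#include <cstdint>

constexpr uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

constexpr uint32_t ne00 = 4096; // hypothetical row width
constexpr std::array<uint32_t, 3> elements = { ceil_div(ne00, 128), 1, 1 };
static_assert(elements[0] == 32, "4096 elements / 128 threads per workgroup");
```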
@@ -7671,7 +7720,17 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
         }
     }
 
-    if (op == GGML_OP_GLU) {
+    if (op == GGML_OP_ADD || op == GGML_OP_RMS_NORM) {
+        vk_buffer d_A = ctx->prealloc_atomic_add ? ctx->prealloc_atomic_add : d_X;
+        size_t a_buf_offset = ctx->prealloc_atomic_add ? ctx->prealloc_size_atomic_add_offset : 0;
+        ggml_vk_sync_buffers(subctx);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
+            { vk_subbuffer{ d_X, x_buf_offset, x_sz },
+              vk_subbuffer{ d_Y, y_buf_offset, y_sz },
+              vk_subbuffer{ d_D, d_buf_offset, d_sz },
+              vk_subbuffer{ d_A, a_buf_offset, VK_WHOLE_SIZE },
+            }, pc, elements);
+    } else if (op == GGML_OP_GLU) {
         // Empty src1 is possible in glu, but the shader needs a buffer
         vk_subbuffer subbuf_y;
         if (use_src1) {
@@ -7884,7 +7943,7 @@ static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const
         (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
         (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
         0,
-        0.0f, 0.0f, 0,
+        0.0f, 0.0f, ctx->do_add_rms_atomic,
     }, dryrun);
 }
 
@@ -8353,8 +8412,13 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx,
         (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
         (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
         0,
-        op_params[0], 0.0f, 0,
+        op_params[0], 0.0f, ctx->do_add_rms_atomic,
     }, dryrun);
+
+    if (ctx->do_add_rms_atomic) {
+        ctx->prealloc_size_atomic_add_offset += ctx->device->atomic_binding_alignment;
+        ctx->do_add_rms_atomic = false;
+    }
 }
 
 static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -9632,6 +9696,14 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
         }
         ctx->prealloc_split_k = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_split_k);
     }
+    if (ctx->prealloc_atomic_add == nullptr || (ctx->prealloc_size_atomic_add > 0 && ctx->prealloc_atomic_add->size < ctx->prealloc_size_atomic_add)) {
+        VK_LOG_MEMORY("ggml_vk_preallocate_buffers(atomic_add_size: " << ctx->prealloc_size_atomic_add << ")");
+        // Resize buffer
+        if (ctx->prealloc_atomic_add != nullptr) {
+            ggml_vk_destroy_buffer(ctx->prealloc_atomic_add);
+        }
+        ctx->prealloc_atomic_add = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_atomic_add);
+    }
 }
 
 static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready);
@@ -9687,10 +9759,21 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
             return false;
         }
         break;
+    case GGML_OP_ADD:
+        if (node_idx + 1 < cgraph->n_nodes &&
+            cgraph->nodes[node_idx + 1]->op == GGML_OP_RMS_NORM &&
+            cgraph->nodes[node_idx + 1]->src[0] == cgraph->nodes[node_idx] &&
+            ggml_nrows(cgraph->nodes[node_idx + 1]) == 1 &&
+            ctx->device->add_rms_fusion) {
+            if (dryrun) {
+                ctx->prealloc_size_atomic_add += ctx->device->atomic_binding_alignment;
+            }
+            ctx->do_add_rms_atomic = true;
+        }
+        break;
     case GGML_OP_REPEAT:
     case GGML_OP_REPEAT_BACK:
     case GGML_OP_GET_ROWS:
-    case GGML_OP_ADD:
     case GGML_OP_ADD_ID:
     case GGML_OP_ACC:
     case GGML_OP_SUB:
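Restated as a hypothetical standalone predicate (not a helper the commit adds), the check above marks an ADD for the fused path only when the very next node is an RMS_NORM that consumes this add's result over a single row, on a device that advertises the atomic-float path:

```cpp
// Hypothetical helper mirroring the case above; not part of the diff.
static bool can_fuse_add_rms(const ggml_cgraph * cgraph, int node_idx, bool add_rms_fusion) {
    if (node_idx + 1 >= cgraph->n_nodes) {
        return false;
    }
    const ggml_tensor * add = cgraph->nodes[node_idx];
    const ggml_tensor * rms = cgraph->nodes[node_idx + 1];
    return rms->op == GGML_OP_RMS_NORM &&  // next op is RMS_NORM...
           rms->src[0] == add &&           // ...fed directly by this add...
           ggml_nrows(rms) == 1 &&         // ...normalizing a single row
           add_rms_fusion;                 // device has subgroup_add + atomic float add
}
```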
@@ -9808,6 +9891,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
             // do the only thing needed for the dryrun.
             vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, node, node->op);
             ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
+            if (node->op == GGML_OP_RMS_NORM) {
+                ctx->do_add_rms_atomic = false;
+            }
             return false;
         }
     default:
@@ -10782,6 +10868,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
         vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT(ctx->device->compute_queue.queue, reinterpret_cast<VkDebugUtilsLabelEXT*>(&dul));
     }
 
+    ctx->prealloc_size_atomic_add = 0;
+    ctx->prealloc_size_atomic_add_offset = 0;
+    ctx->do_add_rms_atomic = false;
+
     uint64_t total_mat_mul_bytes = 0;
     for (int i = 0; i < cgraph->n_nodes; i++) {
         if (!ctx->device->disable_fusion) {
@@ -10847,6 +10937,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
         compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, 0);
     }
 
+    if (ctx->prealloc_size_atomic_add) {
+        if (ctx->compute_ctx.expired()) {
+            compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
+            ctx->compute_ctx = compute_ctx;
+            ggml_vk_ctx_begin(ctx->device, compute_ctx);
+        } else {
+            compute_ctx = ctx->compute_ctx.lock();
+        }
+        // initialize atomic sums to zero.
+        ggml_vk_buffer_memset_async(compute_ctx, ctx->prealloc_atomic_add, 0, 0, ctx->prealloc_size_atomic_add);
+    }
+
     // Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution.
     // Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB
     // (and scaled down based on model size, so smaller models submit earlier).
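Why the memset matters (a CPU analogue, not the shader): the fused add only ever atomic-adds partial sums of squares into its scratch slot, so each graph must start the slot at zero; the following rms_norm then turns the accumulated sum into a per-row scale. A sketch assuming C++20 floating-point atomics:

```cpp
#include <atomic>
#include <cmath>
#include <vector>

// CPU analogue of the GPU flow: zero the slot, accumulate x*x atomically, derive the scale.
float rms_scale(const std::vector<float> & row, float eps) {
    std::atomic<float> sum_sq{0.0f};  // ggml_vk_buffer_memset_async(..., 0, ...) plays this role
    for (float v : row) {             // on the GPU, one thread per element contributes its square
        sum_sq.fetch_add(v * v, std::memory_order_relaxed); // C++20 float fetch_add
    }
    float mean_sq = sum_sq.load() / (float) row.size();
    return 1.0f / std::sqrt(mean_sq + eps);
}
```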