Commit d6da225
vulkan: optimize rms_norm, and allow the work to spread across multiple SMs
There are really two parts to this change: (1) some optimizations similar to what we have in soft_max, unrolling with different numbers of iterations; (2) a fusion optimization where we detect add followed by rms_norm, and make the add shader atomically accumulate the values^2 into memory. The rms_norm shader can then just load that sum, which allows the rms_norm to be parallelized across multiple workgroups; it becomes a simple per-element multiply. The fusion optimization is currently only applied when the rms_norm operates on a single vector, which previously always ran on a single SM. It could apply more broadly, but when there are other dimensions the work can already spread across SMs, and there would be some complexity to tracking multiple atomic sums.
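For intuition, the math the fusion exploits: rms_norm scales each element by 1/sqrt(mean(x^2) + eps), so the only cross-element dependency is the sum of squares. A minimal CPU-side sketch of the split (hypothetical helper names, not code from this commit): the add pass produces the elementwise sum and accumulates sum(x^2), leaving the rms_norm pass with no reduction at all, so any slice of the vector can go to any workgroup.

#include <cmath>
#include <cstddef>

// Pass 1 ("add"): dst = a + b, while accumulating sum(dst^2).
// On the GPU the accumulation is a subgroup reduction plus a float atomicAdd.
static void add_accumulate(const float * a, const float * b, float * dst,
                           size_t n, float * sum_sq) {
    for (size_t i = 0; i < n; ++i) {
        dst[i] = a[i] + b[i];
        *sum_sq += dst[i] * dst[i];
    }
}

// Pass 2 ("rms_norm"): with sum_sq precomputed, each element needs only one
// multiply, so the work parallelizes trivially across workgroups.
static void rms_norm_scaled(float * x, size_t n, float sum_sq, float eps) {
    const float scale = 1.0f / std::sqrt(sum_sq / n + eps);
    for (size_t i = 0; i < n; ++i) {
        x[i] *= scale;
    }
}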
1 parent 1fe0029 commit d6da225

File tree

5 files changed: +251 -52 lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 125 additions & 23 deletions
@@ -372,6 +372,10 @@ struct vk_device_struct {
     bool subgroup_shuffle;
     bool multi_add;
 
+    bool atomic_float_add;
+    bool add_rms_fusion;
+    uint32_t atomic_binding_alignment;
+
     bool integer_dot_product;
 
     bool subgroup_size_control;
@@ -451,6 +455,8 @@ struct vk_device_struct {
     vk_pipeline pipeline_mul_norepeat[2][2][2];
     vk_pipeline pipeline_div[2][2][2];
     vk_pipeline pipeline_div_norepeat[2][2][2];
+    vk_pipeline pipeline_add_rms[2][2][2];
+    vk_pipeline pipeline_add_rms_norepeat[2][2][2];
 
     // indexed by num_additional_fused_ops == num_adds - 1
     vk_pipeline pipeline_multi_add[MAX_FUSED_ADDS];
@@ -1165,6 +1171,12 @@ class vk_perf_logger {
             timings[name].push_back(time);
             return;
         }
+        if (node->op == GGML_OP_RMS_NORM) {
+            std::string name = ggml_op_name(node->op);
+            name += "(" + std::to_string(node->ne[0]) + "," + std::to_string(node->ne[1]) + "," + std::to_string(node->ne[2]) + "," + std::to_string(node->ne[3]) + ")";
+            timings[name].push_back(time);
+            return;
+        }
         timings[ggml_op_name(node->op)].push_back(time);
     }
   private:
@@ -1179,10 +1191,13 @@ struct ggml_backend_vk_context {
 
     size_t semaphore_idx, event_idx;
     ggml_vk_garbage_collector gc;
-    size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k;
-    vk_buffer prealloc_x, prealloc_y, prealloc_split_k;
+    size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k, prealloc_size_atomic_add, prealloc_size_atomic_add_offset;
+    vk_buffer prealloc_x, prealloc_y, prealloc_split_k, prealloc_atomic_add;
     vk::Fence fence, almost_ready_fence;
     bool almost_ready_fence_pending {};
+    // Set before op_add and unset after op_rms_norm to indicate that the add should
+    // use atomics to accumulate the square of the vector components
+    bool do_add_rms_atomic;
 
     vk_buffer buffer_pool[MAX_VK_BUFFERS];
 
@@ -2921,8 +2936,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
 
@@ -2992,20 +3007,22 @@ static void ggml_vk_load_shaders(vk_device& device) {
     };
 
     bool rte = device->float_controls_rte_fp16;
-#define CREATE_BINARY(name, namemod, spec) \
+#define CREATE_BINARY(name, namemod, spec, bindings) \
     for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \
         ggml_vk_create_pipeline(device, device->pipeline_ ## name ## namemod[s0][s1][d], \
                                 #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d][rte], name ## _data[s0][s1][d][rte], \
-                                "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1);
-
-    CREATE_BINARY(add, , {0})
-    CREATE_BINARY(add, _norepeat, {1})
-    CREATE_BINARY(sub, , {0})
-    CREATE_BINARY(sub, _norepeat, {1})
-    CREATE_BINARY(mul, , {0})
-    CREATE_BINARY(mul, _norepeat, {1})
-    CREATE_BINARY(div, , {0})
-    CREATE_BINARY(div, _norepeat, {1})
+                                "main", (bindings), sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1);
+
+    CREATE_BINARY(add, , {0}, 4)
+    CREATE_BINARY(add, _norepeat, {1}, 4)
+    CREATE_BINARY(sub, , {0}, 3)
+    CREATE_BINARY(sub, _norepeat, {1}, 3)
+    CREATE_BINARY(mul, , {0}, 3)
+    CREATE_BINARY(mul, _norepeat, {1}, 3)
+    CREATE_BINARY(div, , {0}, 3)
+    CREATE_BINARY(div, _norepeat, {1}, 3)
+    CREATE_BINARY(add_rms, , {0}, 4)
+    CREATE_BINARY(add_rms, _norepeat, {1}, 4)
 #undef CREATE_BINARY
 
     if (device->multi_add) {
@@ -3286,6 +3303,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
         device->coopmat_support = false;
         device->integer_dot_product = false;
         bool bfloat16_support = false;
+        bool atomic_float_support = false;
 
         for (const auto& properties : ext_props) {
             if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
@@ -3325,6 +3343,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
                        !getenv("GGML_VK_DISABLE_BFLOAT16")) {
                 bfloat16_support = true;
 #endif
+            } else if (strcmp("VK_EXT_shader_atomic_float", properties.extensionName) == 0) {
+                atomic_float_support = true;
             }
         }
 
@@ -3541,6 +3561,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
             device_extensions.push_back("VK_KHR_shader_integer_dot_product");
         }
 
+        VkPhysicalDeviceShaderAtomicFloatFeaturesEXT atomic_float_features {};
+        atomic_float_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_ATOMIC_FLOAT_FEATURES_EXT;
+        if (atomic_float_support) {
+            last_struct->pNext = (VkBaseOutStructure *)&atomic_float_features;
+            last_struct = (VkBaseOutStructure *)&atomic_float_features;
+            device_extensions.push_back("VK_EXT_shader_atomic_float");
+        }
+
         vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);
 
         device->fp16 = device->fp16 && vk12_features.shaderFloat16;
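For reference, the feature-query pattern used here as a self-contained sketch (standard Vulkan 1.1 API; picking the VkPhysicalDevice and verifying that the extension is advertised in ext_props are omitted):

#include <vulkan/vulkan.h>

// Chain the EXT feature struct into VkPhysicalDeviceFeatures2; the driver
// fills in every struct on the pNext chain during vkGetPhysicalDeviceFeatures2.
bool supports_buffer_float32_atomic_add(VkPhysicalDevice physical_device) {
    VkPhysicalDeviceShaderAtomicFloatFeaturesEXT atomic_float_features {};
    atomic_float_features.sType =
        VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_ATOMIC_FLOAT_FEATURES_EXT;

    VkPhysicalDeviceFeatures2 features2 {};
    features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
    features2.pNext = &atomic_float_features;

    vkGetPhysicalDeviceFeatures2(physical_device, &features2);
    return atomic_float_features.shaderBufferFloat32AtomicAdd == VK_TRUE;
}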
@@ -3552,6 +3580,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
 #endif
 
         device->pipeline_robustness = pl_robustness_features.pipelineRobustness;
+        device->atomic_float_add = atomic_float_features.shaderBufferFloat32AtomicAdd;
 
         device->multi_add = vk12_props.shaderRoundingModeRTEFloat16 &&
                             device->properties.limits.maxPushConstantsSize >= sizeof(vk_op_multi_add_push_constants) &&
@@ -3872,6 +3901,12 @@ static vk_device ggml_vk_get_device(size_t idx) {
 
         device->disable_fusion = getenv("GGML_VK_DISABLE_FUSION") != nullptr;
 
+        device->add_rms_fusion = !device->disable_fusion &&
+                                 device->subgroup_add &&
+                                 device->atomic_float_add;
+        device->atomic_binding_alignment =
+            std::max(4u, (uint32_t)device->properties.limits.minStorageBufferOffsetAlignment);
+
         return device;
 }
 
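The alignment matters because each fused add followed by rms_norm gets its own float accumulator, and that accumulator is bound as a storage buffer at a byte offset, which must respect minStorageBufferOffsetAlignment. A small sketch of the slot layout this implies (hypothetical helper, not code from the commit):

#include <algorithm>
#include <cstdint>

// One float accumulator per fused add->rms_norm pair, strided by the binding
// alignment rather than by sizeof(float), so every slot's offset is bindable.
struct atomic_slot_allocator {
    uint32_t alignment;       // max(4, minStorageBufferOffsetAlignment)
    uint32_t next_offset = 0;

    explicit atomic_slot_allocator(uint32_t min_storage_buffer_offset_alignment)
        : alignment(std::max(4u, min_storage_buffer_offset_alignment)) {}

    // Returns the byte offset of a fresh accumulator slot and grows the total.
    uint32_t allocate() {
        uint32_t off = next_offset;
        next_offset += alignment;
        return off;
    }

    uint32_t total_size() const { return next_offset; }  // buffer size to preallocate
};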
@@ -6914,10 +6949,19 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
     case GGML_OP_ADD:
         {
             if (ctx->num_additional_fused_ops > 0) {
-                return ctx->device->pipeline_multi_add[ctx->num_additional_fused_ops];
+                if (ctx->do_add_rms_atomic) {
+                    return ctx->device->pipeline_multi_add_rms[ctx->num_additional_fused_ops];
+                } else {
+                    return ctx->device->pipeline_multi_add[ctx->num_additional_fused_ops];
+                }
+            }
+            if (ctx->do_add_rms_atomic) {
+                auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_rms_norepeat : ctx->device->pipeline_add_rms;
+                return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
+            } else {
+                auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_norepeat : ctx->device->pipeline_add;
+                return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
             }
-            auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_norepeat : ctx->device->pipeline_add;
-            return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
         }
     case GGML_OP_SUB:
         {
@@ -7523,7 +7567,12 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
             }
         } break;
     case GGML_OP_RMS_NORM:
-        elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 };
+        if (ctx->do_add_rms_atomic) {
+            // Run one element per thread, 128 threads per workgroup
+            elements = { (uint32_t)CEIL_DIV(ne00, 128), 1, 1 };
+        } else {
+            elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 };
+        }
         break;
 
     case GGML_OP_SUM:
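The two dispatch shapes are worth spelling out; a small sketch using the same CEIL_DIV convention as ggml-vulkan (the ne values are arbitrary examples):

#include <cstdint>
#include <cstdio>

#define CEIL_DIV(M, N) (((M) + (N) - 1) / (N))

int main() {
    const uint32_t ne00 = 4096, ne01 = 1, ne02 = 1, ne03 = 1;

    // Fused path: one element per thread, 128 threads per workgroup, so a
    // single vector spreads across CEIL_DIV(ne00, 128) workgroups (here 32).
    std::printf("fused:     {%u, 1, 1}\n", CEIL_DIV(ne00, 128));

    // Non-fused path: one workgroup per row, so a single vector (ne01 == 1)
    // collapses to one workgroup, the case that previously occupied one SM.
    std::printf("non-fused: {%u, %u, %u}\n", ne01, ne02, ne03);
    return 0;
}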
@@ -7671,7 +7720,17 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
         }
     }
 
-    if (op == GGML_OP_GLU) {
+    if (op == GGML_OP_ADD || op == GGML_OP_RMS_NORM) {
+        vk_buffer d_A = ctx->prealloc_atomic_add ? ctx->prealloc_atomic_add : d_X;
+        size_t a_buf_offset = ctx->prealloc_atomic_add ? ctx->prealloc_size_atomic_add_offset : 0;
+        ggml_vk_sync_buffers(subctx);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
+            { vk_subbuffer{ d_X, x_buf_offset, x_sz },
+              vk_subbuffer{ d_Y, y_buf_offset, y_sz },
+              vk_subbuffer{ d_D, d_buf_offset, d_sz },
+              vk_subbuffer{ d_A, a_buf_offset, VK_WHOLE_SIZE },
+            }, pc, elements);
+    } else if (op == GGML_OP_GLU) {
         // Empty src1 is possible in glu, but the shader needs a buffer
         vk_subbuffer subbuf_y;
         if (use_src1) {
@@ -7884,7 +7943,7 @@ static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const
         (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
         (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
         0,
-        0.0f, 0.0f, 0,
+        0.0f, 0.0f, ctx->do_add_rms_atomic,
     }, dryrun);
 }
 
@@ -8353,8 +8412,13 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx,
         (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
         (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
         0,
-        op_params[0], 0.0f, 0,
+        op_params[0], 0.0f, ctx->do_add_rms_atomic,
     }, dryrun);
+
+    if (ctx->do_add_rms_atomic) {
+        ctx->prealloc_size_atomic_add_offset += ctx->device->atomic_binding_alignment;
+        ctx->do_add_rms_atomic = false;
+    }
 }
 
 static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -9632,6 +9696,14 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
         }
         ctx->prealloc_split_k = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_split_k);
     }
+    if (ctx->prealloc_atomic_add == nullptr || (ctx->prealloc_size_atomic_add > 0 && ctx->prealloc_atomic_add->size < ctx->prealloc_size_atomic_add)) {
+        VK_LOG_MEMORY("ggml_vk_preallocate_buffers(atomic_add_size: " << ctx->prealloc_size_atomic_add << ")");
+        // Resize buffer
+        if (ctx->prealloc_atomic_add != nullptr) {
+            ggml_vk_destroy_buffer(ctx->prealloc_atomic_add);
+        }
+        ctx->prealloc_atomic_add = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_atomic_add);
+    }
 }
 
 static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready);
@@ -9687,10 +9759,21 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
             return false;
         }
         break;
+    case GGML_OP_ADD:
+        if (node_idx + 1 < cgraph->n_nodes &&
+            cgraph->nodes[node_idx + 1]->op == GGML_OP_RMS_NORM &&
+            cgraph->nodes[node_idx + 1]->src[0] == cgraph->nodes[node_idx] &&
+            ggml_nrows(cgraph->nodes[node_idx + 1]) == 1 &&
+            ctx->device->add_rms_fusion) {
+            if (dryrun) {
+                ctx->prealloc_size_atomic_add += ctx->device->atomic_binding_alignment;
+            }
+            ctx->do_add_rms_atomic = true;
+        }
+        break;
     case GGML_OP_REPEAT:
     case GGML_OP_REPEAT_BACK:
     case GGML_OP_GET_ROWS:
-    case GGML_OP_ADD:
     case GGML_OP_ADD_ID:
     case GGML_OP_ACC:
     case GGML_OP_SUB:
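Stripped of the backend types, the peephole test above reduces to the following sketch (minimal hypothetical node type; ggml_nrows is modeled as a plain field):

#include <cstddef>
#include <vector>

enum op_t { OP_ADD, OP_RMS_NORM, OP_OTHER };
struct node_t {
    op_t op;
    const node_t * src0;   // first input
    long nrows;            // ggml_nrows() analogue
};

// An ADD whose result feeds the immediately following RMS_NORM, where the
// rms_norm operates on a single row/vector (the case that used one SM).
static bool can_fuse_add_rms(const std::vector<node_t> & nodes, size_t i) {
    return i + 1 < nodes.size() &&
           nodes[i].op == OP_ADD &&
           nodes[i + 1].op == OP_RMS_NORM &&
           nodes[i + 1].src0 == &nodes[i] &&
           nodes[i + 1].nrows == 1;
}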
@@ -9808,6 +9891,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         // do the only thing needed for the dryrun.
         vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, node, node->op);
         ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
+        if (node->op == GGML_OP_RMS_NORM) {
+            ctx->do_add_rms_atomic = false;
+        }
         return false;
     }
     default:
@@ -10782,6 +10868,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
         vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT(ctx->device->compute_queue.queue, reinterpret_cast<VkDebugUtilsLabelEXT*>(&dul));
     }
 
+    ctx->prealloc_size_atomic_add = 0;
+    ctx->prealloc_size_atomic_add_offset = 0;
+    ctx->do_add_rms_atomic = false;
+
     uint64_t total_mat_mul_bytes = 0;
     for (int i = 0; i < cgraph->n_nodes; i++) {
         if (!ctx->device->disable_fusion) {
@@ -10847,6 +10937,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
         compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, 0);
     }
 
+    if (ctx->prealloc_size_atomic_add) {
+        if (ctx->compute_ctx.expired()) {
+            compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
+            ctx->compute_ctx = compute_ctx;
+            ggml_vk_ctx_begin(ctx->device, compute_ctx);
+        } else {
+            compute_ctx = ctx->compute_ctx.lock();
+        }
+        // initialize atomic sums to zero.
+        ggml_vk_buffer_memset_async(compute_ctx, ctx->prealloc_atomic_add, 0, 0, ctx->prealloc_size_atomic_add);
+    }
+
     // Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution.
     // Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB
     // (and scaled down based on model size, so smaller models submit earlier).
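Since a float atomicAdd only ever accumulates, every accumulator slot has to start each graph execution at zero; the memset is recorded into the same command context so it is ordered before the fused adds run. A hedged sketch of what such a clear can come down to at the Vulkan level (an assumption about ggml_vk_buffer_memset_async's internals, which this diff does not show):

#include <vulkan/vulkan.h>

// vkCmdFillBuffer writes a repeated 32-bit pattern; 0x00000000 is also 0.0f,
// so it doubles as a float clear for the accumulator slots.
void zero_atomic_accumulators(VkCommandBuffer cmd, VkBuffer atomic_buf,
                              VkDeviceSize size_bytes) {
    vkCmdFillBuffer(cmd, atomic_buf, /*dstOffset=*/0, size_bytes, /*data=*/0u);
    // A barrier must still order this fill before the dispatches that
    // atomically accumulate into the buffer (the backend's buffer sync does this).
}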
Lines changed: 22 additions & 1 deletion
@@ -1,12 +1,19 @@
 #version 450
 
 #extension GL_EXT_shader_16bit_storage : require
+#if ADD_RMS
+#extension GL_EXT_shader_atomic_float : enable
+#extension GL_KHR_shader_subgroup_arithmetic : enable
+#extension GL_KHR_shader_subgroup_basic : enable
+#endif
 
 #include "types.comp"
 #include "generic_binary_head.comp"
 
 const uint num_threads = 256;
 
+layout (binding = 3) buffer AtomBuf {float data_atom;};
+
 layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
 
 void main() {
@@ -15,15 +22,29 @@ void main() {
     // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
     const uint num_iter = 2;
 
+    FLOAT_TYPE sum_sq = 0;
+
     [[unroll]] for (uint i = 0; i < num_iter; ++i) {
         if (idx >= p.ne) {
            continue;
         }
         uint i00, i01, i02, i03;
         get_indices(idx, i00, i01, i02, i03);
 
-        data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]));
+        FLOAT_TYPE sum = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]);
+        sum_sq += sum*sum;
+
+        data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(sum);
 
         idx += num_threads;
     }
+
+#if ADD_RMS
+    if (p.param3 != 0) {
+        sum_sq = subgroupAdd(sum_sq);
+        if (sum_sq != 0 && gl_SubgroupInvocationID == 0) {
+            atomicAdd(data_atom, sum_sq);
+        }
+    }
+#endif
 }
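A CPU analogy of the shader's two-level reduction above (deliberately not GLSL): each subgroup folds its lanes' partial sums first, then one lane per subgroup performs a single float atomicAdd, so global atomic traffic is one add per subgroup rather than one per invocation; subgroups whose partial sum is zero skip the atomic entirely, since adding zero changes nothing.

#include <atomic>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
    const int subgroup_size = 32;
    std::vector<float> partial(256, 1.0f);   // per-invocation sum_sq values
    std::atomic<float> data_atom{0.0f};      // the binding-3 accumulator

    for (size_t base = 0; base < partial.size(); base += subgroup_size) {
        // subgroupAdd analogue: reduce this subgroup's lanes locally.
        float sg_sum = std::accumulate(partial.begin() + base,
                                       partial.begin() + base + subgroup_size, 0.0f);
        if (sg_sum != 0.0f) {
            // One atomic add per subgroup (CAS loop keeps this pre-C++20 portable).
            float cur = data_atom.load();
            while (!data_atom.compare_exchange_weak(cur, cur + sg_sum)) {}
        }
    }
    std::printf("sum of squares = %f\n", data_atom.load());  // prints 256.0
    return 0;
}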
