From b26cf611757360ba11bbd7b36da461a143f8b9f2 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Tue, 12 Aug 2025 21:18:10 -0500 Subject: [PATCH 1/6] vulkan: optimize rms_norm, and allow the work to spread across multiple SMs There are really two parts to this change: (1) Some optimizations similar to what we have in soft_max, to unroll with different numbers of iterations. (2) A fusion optimization where we detect add followed by rms_norm, and make the add shader atomically accumulate the values^2 into memory. Then the rms_norm shader can just load that sum. This allows the rms_norm to be parallelized across multiple workgroups, it just becomes a simple per-element multiply. The fusion optimization is currently only applied when the rms_norm is on a single vector. This previously always ran on a single SM. It could apply more broadly, but when there are other dimensions the work can already spread across SMs, and there would be some complexity to tracking multiple atomic sums. --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 148 +++++++++++++++--- ggml/src/ggml-vulkan/vulkan-shaders/add.comp | 23 ++- .../ggml-vulkan/vulkan-shaders/rms_norm.comp | 121 +++++++++++--- .../vulkan-shaders/vulkan-shaders-gen.cpp | 8 +- tests/test-backend-ops.cpp | 3 + 5 files changed, 251 insertions(+), 52 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index c7cfb6473e37d..804aa8442902b 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -381,6 +381,10 @@ struct vk_device_struct { bool subgroup_shuffle; bool multi_add; + bool atomic_float_add; + bool add_rms_fusion; + uint32_t atomic_binding_alignment; + bool integer_dot_product; bool subgroup_size_control; @@ -460,6 +464,8 @@ struct vk_device_struct { vk_pipeline pipeline_mul_norepeat[2][2][2]; vk_pipeline pipeline_div[2][2][2]; vk_pipeline pipeline_div_norepeat[2][2][2]; + vk_pipeline pipeline_add_rms[2][2][2]; + vk_pipeline pipeline_add_rms_norepeat[2][2][2]; // indexed by num_additional_fused_ops == num_adds - 1 vk_pipeline pipeline_multi_add[MAX_FUSED_ADDS]; @@ -1208,6 +1214,12 @@ class vk_perf_logger { timings[name].push_back(time); return; } + if (node->op == GGML_OP_RMS_NORM) { + std::string name = ggml_op_name(node->op); + name += "(" + std::to_string(node->ne[0]) + "," + std::to_string(node->ne[1]) + "," + std::to_string(node->ne[2]) + "," + std::to_string(node->ne[3]) + ")"; + timings[name].push_back(time); + return; + } timings[ggml_op_name(node->op)].push_back(time); } private: @@ -1222,10 +1234,13 @@ struct ggml_backend_vk_context { size_t semaphore_idx, event_idx; ggml_vk_garbage_collector gc; - size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k; - vk_buffer prealloc_x, prealloc_y, prealloc_split_k; + size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k, prealloc_size_atomic_add, prealloc_size_atomic_add_offset; + vk_buffer prealloc_x, prealloc_y, prealloc_split_k, prealloc_atomic_add; vk::Fence fence, almost_ready_fence; bool almost_ready_fence_pending {}; + // Set before op_add and unset after op_rms_norm to indicate that the add should + // use atomics to accumulate the square of the vector components + bool do_add_rms_atomic; // Cache most recent tensor that was converted into prealloc_y, and what pipeline it used to convert. vk_pipeline_struct * prealloc_y_last_pipeline_used {}; @@ -2987,8 +3002,8 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); @@ -3058,20 +3073,22 @@ static void ggml_vk_load_shaders(vk_device& device) { }; bool rte = device->float_controls_rte_fp16; -#define CREATE_BINARY(name, namemod, spec) \ +#define CREATE_BINARY(name, namemod, spec, bindings) \ for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \ ggml_vk_create_pipeline(device, device->pipeline_ ## name ## namemod[s0][s1][d], \ #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d][rte], name ## _data[s0][s1][d][rte], \ - "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1); - - CREATE_BINARY(add, , {0}) - CREATE_BINARY(add, _norepeat, {1}) - CREATE_BINARY(sub, , {0}) - CREATE_BINARY(sub, _norepeat, {1}) - CREATE_BINARY(mul, , {0}) - CREATE_BINARY(mul, _norepeat, {1}) - CREATE_BINARY(div, , {0}) - CREATE_BINARY(div, _norepeat, {1}) + "main", (bindings), sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1); + + CREATE_BINARY(add, , {0}, 4) + CREATE_BINARY(add, _norepeat, {1}, 4) + CREATE_BINARY(sub, , {0}, 3) + CREATE_BINARY(sub, _norepeat, {1}, 3) + CREATE_BINARY(mul, , {0}, 3) + CREATE_BINARY(mul, _norepeat, {1}, 3) + CREATE_BINARY(div, , {0}, 3) + CREATE_BINARY(div, _norepeat, {1}, 3) + CREATE_BINARY(add_rms, , {0}, 4) + CREATE_BINARY(add_rms, _norepeat, {1}, 4) #undef CREATE_BINARY if (device->multi_add) { @@ -3358,6 +3375,7 @@ static vk_device ggml_vk_get_device(size_t idx) { device->coopmat_support = false; device->integer_dot_product = false; bool bfloat16_support = false; + bool atomic_float_support = false; for (const auto& properties : ext_props) { if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) { @@ -3397,6 +3415,8 @@ static vk_device ggml_vk_get_device(size_t idx) { !getenv("GGML_VK_DISABLE_BFLOAT16")) { bfloat16_support = true; #endif + } else if (strcmp("VK_EXT_shader_atomic_float", properties.extensionName) == 0) { + atomic_float_support = true; } } @@ -3613,6 +3633,14 @@ static vk_device ggml_vk_get_device(size_t idx) { device_extensions.push_back("VK_KHR_shader_integer_dot_product"); } + VkPhysicalDeviceShaderAtomicFloatFeaturesEXT atomic_float_features {}; + atomic_float_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_ATOMIC_FLOAT_FEATURES_EXT; + if (atomic_float_support) { + last_struct->pNext = (VkBaseOutStructure *)&atomic_float_features; + last_struct = (VkBaseOutStructure *)&atomic_float_features; + device_extensions.push_back("VK_EXT_shader_atomic_float"); + } + vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2); device->fp16 = device->fp16 && vk12_features.shaderFloat16; @@ -3624,6 +3652,7 @@ static vk_device ggml_vk_get_device(size_t idx) { #endif device->pipeline_robustness = pl_robustness_features.pipelineRobustness; + device->atomic_float_add = atomic_float_features.shaderBufferFloat32AtomicAdd; device->multi_add = vk12_props.shaderRoundingModeRTEFloat16 && device->properties.limits.maxPushConstantsSize >= sizeof(vk_op_multi_add_push_constants) && @@ -3944,6 +3973,12 @@ static vk_device ggml_vk_get_device(size_t idx) { device->disable_fusion = getenv("GGML_VK_DISABLE_FUSION") != nullptr; + device->add_rms_fusion = !device->disable_fusion && + device->subgroup_add && + device->atomic_float_add; + device->atomic_binding_alignment = + std::max(4u, (uint32_t)device->properties.limits.minStorageBufferOffsetAlignment); + return device; } @@ -7109,10 +7144,19 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const case GGML_OP_ADD: { if (ctx->num_additional_fused_ops > 0) { - return ctx->device->pipeline_multi_add[ctx->num_additional_fused_ops]; + if (ctx->do_add_rms_atomic) { + return ctx->device->pipeline_multi_add_rms[ctx->num_additional_fused_ops]; + } else { + return ctx->device->pipeline_multi_add[ctx->num_additional_fused_ops]; + } + } + if (ctx->do_add_rms_atomic) { + auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_rms_norepeat : ctx->device->pipeline_add_rms; + return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16]; + } else { + auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_norepeat : ctx->device->pipeline_add; + return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16]; } - auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_norepeat : ctx->device->pipeline_add; - return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16]; } case GGML_OP_SUB: { @@ -7748,7 +7792,12 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co } } break; case GGML_OP_RMS_NORM: - elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 }; + if (ctx->do_add_rms_atomic) { + // Run one element per thread, 128 threads per workgroup + elements = { (uint32_t)CEIL_DIV(ne00, 128), 1, 1 }; + } else { + elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 }; + } break; case GGML_OP_SUM: @@ -7897,7 +7946,17 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co } } - if (op == GGML_OP_GLU) { + if (op == GGML_OP_ADD || op == GGML_OP_RMS_NORM) { + vk_buffer d_A = ctx->prealloc_atomic_add ? ctx->prealloc_atomic_add : d_X; + size_t a_buf_offset = ctx->prealloc_atomic_add ? ctx->prealloc_size_atomic_add_offset : 0; + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, + { vk_subbuffer{ d_X, x_buf_offset, x_sz }, + vk_subbuffer{ d_Y, y_buf_offset, y_sz }, + vk_subbuffer{ d_D, d_buf_offset, d_sz }, + vk_subbuffer{ d_A, a_buf_offset, VK_WHOLE_SIZE }, + }, pc, elements); + } else if (op == GGML_OP_GLU) { // Empty src1 is possible in glu, but the shader needs a buffer vk_subbuffer subbuf_y; if (use_src1) { @@ -8100,7 +8159,7 @@ static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, 0, - 0.0f, 0.0f, 0, + 0.0f, 0.0f, ctx->do_add_rms_atomic, }, dryrun); } @@ -8569,8 +8628,13 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, 0, - op_params[0], 0.0f, 0, + op_params[0], 0.0f, ctx->do_add_rms_atomic, }, dryrun); + + if (ctx->do_add_rms_atomic) { + ctx->prealloc_size_atomic_add_offset += ctx->device->atomic_binding_alignment; + ctx->do_add_rms_atomic = false; + } } static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { @@ -9848,6 +9912,14 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) { } ctx->prealloc_split_k = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_split_k); } + if (ctx->prealloc_atomic_add == nullptr || (ctx->prealloc_size_atomic_add > 0 && ctx->prealloc_atomic_add->size < ctx->prealloc_size_atomic_add)) { + VK_LOG_MEMORY("ggml_vk_preallocate_buffers(atomic_add_size: " << ctx->prealloc_atomic_add << ")"); + // Resize buffer + if (ctx->prealloc_atomic_add != nullptr) { + ggml_vk_destroy_buffer(ctx->prealloc_atomic_add); + } + ctx->prealloc_atomic_add = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_atomic_add); + } } static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready); @@ -9904,10 +9976,21 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr return false; } break; + case GGML_OP_ADD: + if (node_idx + 1 < cgraph->n_nodes && + cgraph->nodes[node_idx + 1]->op == GGML_OP_RMS_NORM && + cgraph->nodes[node_idx + 1]->src[0] == cgraph->nodes[node_idx] && + ggml_nrows(cgraph->nodes[node_idx + 1]) == 1 && + ctx->device->add_rms_fusion) { + if (dryrun) { + ctx->prealloc_size_atomic_add += ctx->device->atomic_binding_alignment; + } + ctx->do_add_rms_atomic = true; + } + break; case GGML_OP_REPEAT: case GGML_OP_REPEAT_BACK: case GGML_OP_GET_ROWS: - case GGML_OP_ADD: case GGML_OP_ADD_ID: case GGML_OP_ACC: case GGML_OP_SUB: @@ -10029,6 +10112,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr // do the only thing needed for the dryrun. vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, node, node->op); ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + if (node->op == GGML_OP_RMS_NORM) { + ctx->do_add_rms_atomic = false; + } return false; } default: @@ -11098,6 +11184,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT(ctx->device->compute_queue.queue, reinterpret_cast(&dul)); } + ctx->prealloc_size_atomic_add = 0; + ctx->prealloc_size_atomic_add_offset = 0; + ctx->do_add_rms_atomic = false; + uint64_t total_mat_mul_bytes = 0; for (int i = 0; i < cgraph->n_nodes; i++) { if (!ctx->device->disable_fusion) { @@ -11166,6 +11256,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ctx->prealloc_y_last_pipeline_used = nullptr; ctx->prealloc_y_last_tensor_used = nullptr; + if (ctx->prealloc_size_atomic_add) { + if (ctx->compute_ctx.expired()) { + compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + ctx->compute_ctx = compute_ctx; + ggml_vk_ctx_begin(ctx->device, compute_ctx); + } else { + compute_ctx = ctx->compute_ctx.lock(); + } + // initialize atomic sums to zero. + ggml_vk_buffer_memset_async(compute_ctx, ctx->prealloc_atomic_add, 0, 0, ctx->prealloc_size_atomic_add); + } + // Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution. // Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB // (and scaled down based on model size, so smaller models submit earlier). diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/add.comp b/ggml/src/ggml-vulkan/vulkan-shaders/add.comp index 2b4085c4f82d5..9feb09f113912 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/add.comp @@ -1,12 +1,19 @@ #version 450 #extension GL_EXT_shader_16bit_storage : require +#if ADD_RMS +#extension GL_EXT_shader_atomic_float : enable +#extension GL_KHR_shader_subgroup_arithmetic : enable +#extension GL_KHR_shader_subgroup_basic : enable +#endif #include "types.comp" #include "generic_binary_head.comp" const uint num_threads = 256; +layout (binding = 3) buffer AtomBuf {float data_atom;}; + layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; void main() { @@ -15,6 +22,8 @@ void main() { // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation const uint num_iter = 2; + FLOAT_TYPE sum_sq = 0; + [[unroll]] for (uint i = 0; i < num_iter; ++i) { if (idx >= p.ne) { continue; @@ -22,8 +31,20 @@ void main() { uint i00, i01, i02, i03; get_indices(idx, i00, i01, i02, i03); - data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)])); + FLOAT_TYPE sum = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]); + sum_sq += sum*sum; + + data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(sum); idx += num_threads; } + +#if ADD_RMS + if (p.param3 != 0) { + sum_sq = subgroupAdd(sum_sq); + if (sum_sq != 0 && gl_SubgroupInvocationID == 0) { + atomicAdd(data_atom, sum_sq); + } + } +#endif } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp index bdd7db2d6987a..fbd51ddc1f5d8 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp @@ -4,23 +4,26 @@ #include "types.comp" #extension GL_EXT_control_flow_attributes : enable -#define BLOCK_SIZE 512 +#define BLOCK_SIZE 128 layout (constant_id = 1) const bool do_multiply = false; layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; -shared FLOAT_TYPE sum[BLOCK_SIZE]; +layout (binding = 3) readonly buffer AtomBuf {float precomputed_sum;}; -void main() { +shared FLOAT_TYPE sumsh[BLOCK_SIZE]; + +void rms_norm(uint num_iters, bool use_atomic_add) { const uint ncols = p.ne00; const uint nrows = gl_NumWorkGroups.x; const uint nchannels = gl_NumWorkGroups.y; - const uint row = gl_WorkGroupID.x; + const uint row = use_atomic_add ? 0 : gl_WorkGroupID.x; const uint channel = gl_WorkGroupID.y; const uint samp = gl_WorkGroupID.z; - const uint tid = gl_LocalInvocationID.x; + // When using atomic add, the work is split across multiple workgroups in the x dimension + const uint tid = use_atomic_add ? gl_GlobalInvocationID.x : gl_LocalInvocationID.x; const uint stride_row = p.nb01; const uint stride_channel = p.nb02; @@ -30,38 +33,106 @@ void main() { uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset(); uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset(); - sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp + FLOAT_TYPE sum = FLOAT_TYPE(0.0f); // partial sum for thread in warp - [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) { - const FLOAT_TYPE xi = FLOAT_TYPE(data_a[a_offset + col]); - sum[tid] += xi * xi; - } - - // sum up partial sums and write back result - barrier(); - [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { - if (tid < s) { - sum[tid] += sum[tid + s]; + if (use_atomic_add) { + sum = precomputed_sum; + } else { + [[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) { + FLOAT_TYPE xi = FLOAT_TYPE(0); + if (col < ncols) { + xi = FLOAT_TYPE(data_a[a_offset + col]); + } + sum += xi * xi; } + + sumsh[tid] = sum; + // sum up partial sums and write back result barrier(); + [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + sum += sumsh[tid + s]; + sumsh[tid] = sum; + } + barrier(); + } + sum = sumsh[0]; } - const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(ncols); + const FLOAT_TYPE mean = sum / FLOAT_TYPE(ncols); const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1)); - if (do_multiply) { - if (ncols > p.ne10) { - [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) { - data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + fastmod(col, p.ne10)])); + if (use_atomic_add) { + // One element per thread when using the atomic add + uint col = tid; + if (do_multiply) { + if (ncols > p.ne10) { + if (col < ncols) { + data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + fastmod(col, p.ne10)])); + } + } else { + if (col < ncols) { + data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col])); + } } } else { - [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) { - data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col])); + if (col < ncols) { + data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col])); } } } else { - [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) { - data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col])); + if (do_multiply) { + if (ncols > p.ne10) { + [[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) { + if (col >= ncols) { + continue; + } + data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + fastmod(col, p.ne10)])); + } + } else { + [[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) { + if (col >= ncols) { + continue; + } + data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col])); + } + } + } else { + [[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) { + if (col >= ncols) { + continue; + } + data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col])); + } } } } + +void main() { + const bool use_atomic_add = p.param3 != 0; + if (use_atomic_add) { + rms_norm(1, true); + return; + } + + // instantiate the rms_norm function for several different + // dimensions, to allow loop unrolling + uint num_blocks = (p.ne00 + BLOCK_SIZE - 1) / BLOCK_SIZE; + if (num_blocks > 32) { + rms_norm(num_blocks, false); + } else if (num_blocks > 16) { + rms_norm(32, false); + } else if (num_blocks > 8) { + rms_norm(16, false); + } else if (num_blocks > 4) { + rms_norm(8, false); + } else if (num_blocks == 4) { + rms_norm(4, false); + } else if (num_blocks == 3) { + rms_norm(3, false); + } else if (num_blocks == 2) { + rms_norm(2, false); + } else if (num_blocks == 1) { + rms_norm(1, false); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 123ae044914ed..342ea03810464 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -538,13 +538,15 @@ void process_shaders() { s += std::string(dst_f16 ? "_f16" : "_f32"); return s; }; - for (std::string op : {"add", "sub", "mul", "div"}) { + for (std::string op : {"add", "sub", "mul", "div", "add_rms", }) { for (auto src0_f16 : {false, true}) { for (auto src1_f16 : {false, true}) { for (auto dst_f16 : {false, true}) { for (auto rte : {false, true}) { + auto source = op == "add_rms" ? std::string("add") : op; auto name = op + get_suffix(src0_f16, src1_f16, dst_f16) + (rte ? "_rte" : ""); - string_to_spv(name.c_str(), op + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); + auto add_rms = op == "add_rms" ? "1" : "0"; + string_to_spv(name.c_str(), source + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}, {"ADD_RMS" , add_rms}}); } } } @@ -745,7 +747,7 @@ void write_output_files() { } std::string suffixes[2] = {"_f32", "_f16"}; - for (const char *op : {"add", "sub", "mul", "div"}) { + for (const char *op : {"add", "sub", "mul", "div", "add_rms"}) { fprintf(hdr, "extern unsigned char *%s_data[2][2][2][2];\n", op); fprintf(hdr, "extern uint64_t %s_len[2][2][2][2];\n", op); std::string data = "unsigned char *" + std::string(op) + "_data[2][2][2][2] = "; diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 2e53f8e21a5a2..e60c6450f77b1 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -5842,6 +5842,9 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps)); test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true)); } + for (uint32_t n : {1, 511, 1025, 8192}) { + test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {n, 1, 1, 1}, 1e-6f)); + } test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, {64, 5, 4, 3}, 1e-12f)); From 5643b4a3bf10371e7268ebf8d9aad962213d643b Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Wed, 13 Aug 2025 09:58:48 -0500 Subject: [PATCH 2/6] Change add+rms_norm optimization to write out an array of partial sums rather than using atomic add, to make it deterministic. The rms_norm shader fetches a subgroup's worth in parallel and uses subgroupAdd to add them up. --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 103 +++++++++------- ggml/src/ggml-vulkan/vulkan-shaders/add.comp | 27 +++- .../ggml-vulkan/vulkan-shaders/rms_norm.comp | 115 +++++++----------- .../vulkan-shaders/rms_norm_partials.comp | 65 ++++++++++ .../vulkan-shaders/vulkan-shaders-gen.cpp | 1 + tests/test-backend-ops.cpp | 2 +- 6 files changed, 188 insertions(+), 125 deletions(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 804aa8442902b..373c748fd144b 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -381,9 +381,8 @@ struct vk_device_struct { bool subgroup_shuffle; bool multi_add; - bool atomic_float_add; bool add_rms_fusion; - uint32_t atomic_binding_alignment; + uint32_t partials_binding_alignment; bool integer_dot_product; @@ -492,6 +491,8 @@ struct vk_device_struct { vk_pipeline pipeline_group_norm_f32; vk_pipeline pipeline_rms_norm_f32; vk_pipeline pipeline_rms_norm_mul_f32; + vk_pipeline pipeline_rms_norm_partials_f32; + vk_pipeline pipeline_rms_norm_mul_partials_f32; vk_pipeline pipeline_rms_norm_back_f32; vk_pipeline pipeline_l2_norm_f32; @@ -1234,13 +1235,13 @@ struct ggml_backend_vk_context { size_t semaphore_idx, event_idx; ggml_vk_garbage_collector gc; - size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k, prealloc_size_atomic_add, prealloc_size_atomic_add_offset; - vk_buffer prealloc_x, prealloc_y, prealloc_split_k, prealloc_atomic_add; + size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k, prealloc_size_add_rms_partials, prealloc_size_add_rms_partials_offset; + vk_buffer prealloc_x, prealloc_y, prealloc_split_k, prealloc_add_rms_partials; vk::Fence fence, almost_ready_fence; bool almost_ready_fence_pending {}; // Set before op_add and unset after op_rms_norm to indicate that the add should - // use atomics to accumulate the square of the vector components - bool do_add_rms_atomic; + // write partial sums to accumulate the square of the vector components + bool do_add_rms_partials; // Cache most recent tensor that was converted into prealloc_y, and what pipeline it used to convert. vk_pipeline_struct * prealloc_y_last_pipeline_used {}; @@ -3002,8 +3003,12 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_rms_norm_partials_f32, "rms_norm_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_partials_f32, "rms_norm_mul_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); @@ -3375,7 +3380,6 @@ static vk_device ggml_vk_get_device(size_t idx) { device->coopmat_support = false; device->integer_dot_product = false; bool bfloat16_support = false; - bool atomic_float_support = false; for (const auto& properties : ext_props) { if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) { @@ -3415,8 +3419,6 @@ static vk_device ggml_vk_get_device(size_t idx) { !getenv("GGML_VK_DISABLE_BFLOAT16")) { bfloat16_support = true; #endif - } else if (strcmp("VK_EXT_shader_atomic_float", properties.extensionName) == 0) { - atomic_float_support = true; } } @@ -3633,14 +3635,6 @@ static vk_device ggml_vk_get_device(size_t idx) { device_extensions.push_back("VK_KHR_shader_integer_dot_product"); } - VkPhysicalDeviceShaderAtomicFloatFeaturesEXT atomic_float_features {}; - atomic_float_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_ATOMIC_FLOAT_FEATURES_EXT; - if (atomic_float_support) { - last_struct->pNext = (VkBaseOutStructure *)&atomic_float_features; - last_struct = (VkBaseOutStructure *)&atomic_float_features; - device_extensions.push_back("VK_EXT_shader_atomic_float"); - } - vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2); device->fp16 = device->fp16 && vk12_features.shaderFloat16; @@ -3652,7 +3646,6 @@ static vk_device ggml_vk_get_device(size_t idx) { #endif device->pipeline_robustness = pl_robustness_features.pipelineRobustness; - device->atomic_float_add = atomic_float_features.shaderBufferFloat32AtomicAdd; device->multi_add = vk12_props.shaderRoundingModeRTEFloat16 && device->properties.limits.maxPushConstantsSize >= sizeof(vk_op_multi_add_push_constants) && @@ -3974,9 +3967,8 @@ static vk_device ggml_vk_get_device(size_t idx) { device->disable_fusion = getenv("GGML_VK_DISABLE_FUSION") != nullptr; device->add_rms_fusion = !device->disable_fusion && - device->subgroup_add && - device->atomic_float_add; - device->atomic_binding_alignment = + device->subgroup_add; + device->partials_binding_alignment = std::max(4u, (uint32_t)device->properties.limits.minStorageBufferOffsetAlignment); return device; @@ -7144,13 +7136,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const case GGML_OP_ADD: { if (ctx->num_additional_fused_ops > 0) { - if (ctx->do_add_rms_atomic) { + if (ctx->do_add_rms_partials) { return ctx->device->pipeline_multi_add_rms[ctx->num_additional_fused_ops]; } else { return ctx->device->pipeline_multi_add[ctx->num_additional_fused_ops]; } } - if (ctx->do_add_rms_atomic) { + if (ctx->do_add_rms_partials) { auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_rms_norepeat : ctx->device->pipeline_add_rms; return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16]; } else { @@ -7279,7 +7271,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return nullptr; case GGML_OP_RMS_NORM: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32; + if (ctx->do_add_rms_partials) { + return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_partials_f32 : ctx->device->pipeline_rms_norm_partials_f32; + } else { + return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32; + } } return nullptr; case GGML_OP_RMS_NORM_BACK: @@ -7792,7 +7788,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co } } break; case GGML_OP_RMS_NORM: - if (ctx->do_add_rms_atomic) { + if (ctx->do_add_rms_partials) { // Run one element per thread, 128 threads per workgroup elements = { (uint32_t)CEIL_DIV(ne00, 128), 1, 1 }; } else { @@ -7947,8 +7943,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co } if (op == GGML_OP_ADD || op == GGML_OP_RMS_NORM) { - vk_buffer d_A = ctx->prealloc_atomic_add ? ctx->prealloc_atomic_add : d_X; - size_t a_buf_offset = ctx->prealloc_atomic_add ? ctx->prealloc_size_atomic_add_offset : 0; + vk_buffer d_A = ctx->prealloc_add_rms_partials ? ctx->prealloc_add_rms_partials : d_X; + size_t a_buf_offset = ctx->prealloc_add_rms_partials ? ctx->prealloc_size_add_rms_partials_offset : 0; ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, @@ -8159,7 +8155,7 @@ static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, 0, - 0.0f, 0.0f, ctx->do_add_rms_atomic, + 0.0f, 0.0f, ctx->do_add_rms_partials, }, dryrun); } @@ -8617,23 +8613,38 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun); } +static uint32_t ggml_vk_rms_num_partials(ggml_backend_vk_context * ctx, const ggml_tensor *node) { + const uint32_t ne = (uint32_t)node->ne[0]; + const uint32_t denom = ctx->device->pipeline_add_rms[0][0][0]->wg_denoms[0]; + const uint32_t num_partials = CEIL_DIV(ne, denom); + return num_partials; +} + +static uint32_t ggml_vk_rms_partials_size(ggml_backend_vk_context * ctx, const ggml_tensor *node) { + const uint32_t num_partials = ggml_vk_rms_num_partials(ctx, node); + const uint32_t num_bytes = ROUNDUP_POW2(num_partials * sizeof(uint32_t), ctx->device->partials_binding_alignment); + return num_bytes; +} + static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, float * op_params, bool dryrun = false) { const uint32_t src0_type_size = ggml_type_size(src0->type); const uint32_t src1_type_size = ggml_type_size(src1->type); const uint32_t dst_type_size = ggml_type_size(dst->type); + uint32_t param3 = ctx->do_add_rms_partials ? ggml_vk_rms_num_partials(ctx, dst) : 0; + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)ggml_nelements(src0), (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, 0, - op_params[0], 0.0f, ctx->do_add_rms_atomic, + op_params[0], 0.0f, (int32_t)param3, }, dryrun); - if (ctx->do_add_rms_atomic) { - ctx->prealloc_size_atomic_add_offset += ctx->device->atomic_binding_alignment; - ctx->do_add_rms_atomic = false; + if (ctx->do_add_rms_partials) { + ctx->prealloc_size_add_rms_partials_offset += ggml_vk_rms_partials_size(ctx, src0); + ctx->do_add_rms_partials = false; } } @@ -9912,13 +9923,13 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) { } ctx->prealloc_split_k = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_split_k); } - if (ctx->prealloc_atomic_add == nullptr || (ctx->prealloc_size_atomic_add > 0 && ctx->prealloc_atomic_add->size < ctx->prealloc_size_atomic_add)) { - VK_LOG_MEMORY("ggml_vk_preallocate_buffers(atomic_add_size: " << ctx->prealloc_atomic_add << ")"); + if (ctx->prealloc_add_rms_partials == nullptr || (ctx->prealloc_size_add_rms_partials > 0 && ctx->prealloc_add_rms_partials->size < ctx->prealloc_size_add_rms_partials)) { + VK_LOG_MEMORY("ggml_vk_preallocate_buffers(add_partials_size: " << ctx->prealloc_add_rms_partials << ")"); // Resize buffer - if (ctx->prealloc_atomic_add != nullptr) { - ggml_vk_destroy_buffer(ctx->prealloc_atomic_add); + if (ctx->prealloc_add_rms_partials != nullptr) { + ggml_vk_destroy_buffer(ctx->prealloc_add_rms_partials); } - ctx->prealloc_atomic_add = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_atomic_add); + ctx->prealloc_add_rms_partials = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_add_rms_partials); } } @@ -9983,9 +9994,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr ggml_nrows(cgraph->nodes[node_idx + 1]) == 1 && ctx->device->add_rms_fusion) { if (dryrun) { - ctx->prealloc_size_atomic_add += ctx->device->atomic_binding_alignment; + ctx->prealloc_size_add_rms_partials += ggml_vk_rms_partials_size(ctx, cgraph->nodes[node_idx]); } - ctx->do_add_rms_atomic = true; + ctx->do_add_rms_partials = true; } break; case GGML_OP_REPEAT: @@ -10113,7 +10124,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, node, node->op); ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); if (node->op == GGML_OP_RMS_NORM) { - ctx->do_add_rms_atomic = false; + ctx->do_add_rms_partials = false; } return false; } @@ -11184,9 +11195,9 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT(ctx->device->compute_queue.queue, reinterpret_cast(&dul)); } - ctx->prealloc_size_atomic_add = 0; - ctx->prealloc_size_atomic_add_offset = 0; - ctx->do_add_rms_atomic = false; + ctx->prealloc_size_add_rms_partials = 0; + ctx->prealloc_size_add_rms_partials_offset = 0; + ctx->do_add_rms_partials = false; uint64_t total_mat_mul_bytes = 0; for (int i = 0; i < cgraph->n_nodes; i++) { @@ -11256,7 +11267,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ctx->prealloc_y_last_pipeline_used = nullptr; ctx->prealloc_y_last_tensor_used = nullptr; - if (ctx->prealloc_size_atomic_add) { + if (ctx->prealloc_size_add_rms_partials) { if (ctx->compute_ctx.expired()) { compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); ctx->compute_ctx = compute_ctx; @@ -11264,8 +11275,8 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg } else { compute_ctx = ctx->compute_ctx.lock(); } - // initialize atomic sums to zero. - ggml_vk_buffer_memset_async(compute_ctx, ctx->prealloc_atomic_add, 0, 0, ctx->prealloc_size_atomic_add); + // initialize partial sums to zero. + ggml_vk_buffer_memset_async(compute_ctx, ctx->prealloc_add_rms_partials, 0, 0, ctx->prealloc_size_add_rms_partials); } // Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution. diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/add.comp b/ggml/src/ggml-vulkan/vulkan-shaders/add.comp index 9feb09f113912..00cf2dd62fddb 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/add.comp @@ -2,7 +2,6 @@ #extension GL_EXT_shader_16bit_storage : require #if ADD_RMS -#extension GL_EXT_shader_atomic_float : enable #extension GL_KHR_shader_subgroup_arithmetic : enable #extension GL_KHR_shader_subgroup_basic : enable #endif @@ -12,12 +11,18 @@ const uint num_threads = 256; -layout (binding = 3) buffer AtomBuf {float data_atom;}; +layout (binding = 3, std430) buffer PartialBuf {float partial_sums[];}; layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; +#if ADD_RMS +// XXX TODO this could be sized based on number of subgroups, but that't not considered a constant +shared FLOAT_TYPE sumsh[num_threads]; +#endif + void main() { uint idx = get_idx(); + uint orig_idx = idx; // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation const uint num_iter = 2; @@ -41,9 +46,23 @@ void main() { #if ADD_RMS if (p.param3 != 0) { + // reduce the sum within each subgroup, then across subgroups + const uint NumSubgroups = num_threads / gl_SubgroupSize; sum_sq = subgroupAdd(sum_sq); - if (sum_sq != 0 && gl_SubgroupInvocationID == 0) { - atomicAdd(data_atom, sum_sq); + if (gl_SubgroupInvocationID == 0) { + sumsh[gl_SubgroupID] = sum_sq; + } + barrier(); + [[unroll]] for (uint s = NumSubgroups / 2; s > 0; s >>= 1) { + if (gl_SubgroupID < s && gl_SubgroupInvocationID == 0) { + sum_sq += sumsh[gl_SubgroupID + s]; + sumsh[gl_SubgroupID] = sum_sq; + } + barrier(); + } + + if (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) { + partial_sums[orig_idx / (num_iter * num_threads)] = sum_sq; } } #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp index fbd51ddc1f5d8..41197e9301ad8 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp @@ -4,26 +4,23 @@ #include "types.comp" #extension GL_EXT_control_flow_attributes : enable -#define BLOCK_SIZE 128 +#define BLOCK_SIZE 512 layout (constant_id = 1) const bool do_multiply = false; layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; -layout (binding = 3) readonly buffer AtomBuf {float precomputed_sum;}; - shared FLOAT_TYPE sumsh[BLOCK_SIZE]; -void rms_norm(uint num_iters, bool use_atomic_add) { +void rms_norm(uint num_iters) { const uint ncols = p.ne00; const uint nrows = gl_NumWorkGroups.x; const uint nchannels = gl_NumWorkGroups.y; - const uint row = use_atomic_add ? 0 : gl_WorkGroupID.x; + const uint row = gl_WorkGroupID.x; const uint channel = gl_WorkGroupID.y; const uint samp = gl_WorkGroupID.z; - // When using atomic add, the work is split across multiple workgroups in the x dimension - const uint tid = use_atomic_add ? gl_GlobalInvocationID.x : gl_LocalInvocationID.x; + const uint tid = gl_LocalInvocationID.x; const uint stride_row = p.nb01; const uint stride_channel = p.nb02; @@ -35,104 +32,74 @@ void rms_norm(uint num_iters, bool use_atomic_add) { FLOAT_TYPE sum = FLOAT_TYPE(0.0f); // partial sum for thread in warp - if (use_atomic_add) { - sum = precomputed_sum; - } else { - [[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) { - FLOAT_TYPE xi = FLOAT_TYPE(0); - if (col < ncols) { - xi = FLOAT_TYPE(data_a[a_offset + col]); - } - sum += xi * xi; + [[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) { + FLOAT_TYPE xi = FLOAT_TYPE(0); + if (col < ncols) { + xi = FLOAT_TYPE(data_a[a_offset + col]); } + sum += xi * xi; + } - sumsh[tid] = sum; - // sum up partial sums and write back result - barrier(); - [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { - if (tid < s) { - sum += sumsh[tid + s]; - sumsh[tid] = sum; - } - barrier(); + sumsh[tid] = sum; + // sum up partial sums and write back result + barrier(); + [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + sum += sumsh[tid + s]; + sumsh[tid] = sum; } - sum = sumsh[0]; + barrier(); } + sum = sumsh[0]; const FLOAT_TYPE mean = sum / FLOAT_TYPE(ncols); const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1)); - if (use_atomic_add) { - // One element per thread when using the atomic add - uint col = tid; - if (do_multiply) { - if (ncols > p.ne10) { - if (col < ncols) { - data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + fastmod(col, p.ne10)])); - } - } else { - if (col < ncols) { - data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col])); - } - } - } else { - if (col < ncols) { - data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col])); - } - } - } else { - if (do_multiply) { - if (ncols > p.ne10) { - [[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) { - if (col >= ncols) { - continue; - } - data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + fastmod(col, p.ne10)])); - } - } else { - [[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) { - if (col >= ncols) { - continue; - } - data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col])); + if (do_multiply) { + if (ncols > p.ne10) { + [[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) { + if (col >= ncols) { + continue; } + data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + fastmod(col, p.ne10)])); } } else { [[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) { if (col >= ncols) { continue; } - data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col])); + data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col])); } } + } else { + [[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) { + if (col >= ncols) { + continue; + } + data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col])); + } } } void main() { - const bool use_atomic_add = p.param3 != 0; - if (use_atomic_add) { - rms_norm(1, true); - return; - } - // instantiate the rms_norm function for several different // dimensions, to allow loop unrolling uint num_blocks = (p.ne00 + BLOCK_SIZE - 1) / BLOCK_SIZE; if (num_blocks > 32) { - rms_norm(num_blocks, false); + rms_norm(num_blocks); } else if (num_blocks > 16) { - rms_norm(32, false); + rms_norm(32); } else if (num_blocks > 8) { - rms_norm(16, false); + rms_norm(16); } else if (num_blocks > 4) { - rms_norm(8, false); + rms_norm(8); } else if (num_blocks == 4) { - rms_norm(4, false); + rms_norm(4); } else if (num_blocks == 3) { - rms_norm(3, false); + rms_norm(3); } else if (num_blocks == 2) { - rms_norm(2, false); + rms_norm(2); } else if (num_blocks == 1) { - rms_norm(1, false); + rms_norm(1); } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp new file mode 100644 index 0000000000000..ba4677c293392 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp @@ -0,0 +1,65 @@ +#version 450 + +#include "generic_binary_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_KHR_shader_subgroup_arithmetic : enable +#extension GL_KHR_shader_subgroup_basic : enable + +#define BLOCK_SIZE 128 + +layout (constant_id = 1) const bool do_multiply = false; + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 3, std430) readonly buffer PartialsBuf {float partial_sums[];}; + +shared FLOAT_TYPE sumsh[BLOCK_SIZE]; + +void main() { + const uint ncols = p.ne00; + const uint nrows = gl_NumWorkGroups.x; + const uint nchannels = gl_NumWorkGroups.y; + + const uint row = 0; + const uint channel = gl_WorkGroupID.y; + const uint samp = gl_WorkGroupID.z; + // The work is split across multiple workgroups in the x dimension. Each invocation + // processes one element + const uint tid = gl_GlobalInvocationID.x; + + const uint stride_row = p.nb01; + const uint stride_channel = p.nb02; + const uint stride_sample = p.nb03; + + uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset(); + uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset(); + uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset(); + + FLOAT_TYPE sum = FLOAT_TYPE(0.0f); // partial sum for thread in warp + + uint32_t num_partials = p.param3; + for (uint32_t i = gl_SubgroupInvocationID; i < num_partials; i += gl_SubgroupSize) { + sum += partial_sums[i]; + } + sum = subgroupAdd(sum); + + uint col = tid; + if (col >= ncols) { + return; + } + + const FLOAT_TYPE mean = sum / FLOAT_TYPE(ncols); + const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1)); + + if (do_multiply) { + if (ncols > p.ne10) { + data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + fastmod(col, p.ne10)])); + } else { + data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col])); + } + } else { + data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col])); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 342ea03810464..54beb3068f792 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -503,6 +503,7 @@ void process_shaders() { string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("rms_norm_partials_f32", "rms_norm_partials.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index e60c6450f77b1..a97d60f453fc5 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -5842,7 +5842,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps)); test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true)); } - for (uint32_t n : {1, 511, 1025, 8192}) { + for (uint32_t n : {1, 511, 1025, 8192, 33*512}) { test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {n, 1, 1, 1}, 1e-6f)); } From 7856a7a89339167c4956e3953b6444941869276a Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Sat, 16 Aug 2025 23:29:30 -0500 Subject: [PATCH 3/6] complete rebase against fused adds - multi_add shader can also compute partial sums --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 53 +++++++++++++------ .../ggml-vulkan/vulkan-shaders/multi_add.comp | 42 ++++++++++++++- .../vulkan-shaders/vulkan-shaders-gen.cpp | 3 +- tests/test-backend-ops.cpp | 14 +++-- 4 files changed, 89 insertions(+), 23 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 373c748fd144b..142c1f8bf8b6f 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -102,9 +102,9 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; } struct ggml_backend_vk_context; -#define MAX_PARAMETER_COUNT 8 +#define MAX_PARAMETER_COUNT 12 // Max number of adds that can be fused without exceeding MAX_PARAMETER_COUNT. -#define MAX_FUSED_ADDS (MAX_PARAMETER_COUNT - 2) +#define MAX_FUSED_ADDS (MAX_PARAMETER_COUNT - 3) struct vk_pipeline_struct { std::string name; @@ -468,6 +468,7 @@ struct vk_device_struct { // indexed by num_additional_fused_ops == num_adds - 1 vk_pipeline pipeline_multi_add[MAX_FUSED_ADDS]; + vk_pipeline pipeline_multi_add_rms[MAX_FUSED_ADDS]; vk_pipeline pipeline_add_id_f32; @@ -830,8 +831,13 @@ struct vk_op_multi_add_push_constants { uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23; // strides for srcs+dst - uint32_t nb[8][4]; + uint32_t nb[MAX_PARAMETER_COUNT][4]; + + uint32_t rms_partials; }; +// update multi_add.comp if this changes +static_assert(MAX_PARAMETER_COUNT == 12); +static_assert(sizeof(vk_op_multi_add_push_constants) <= 256); struct vk_op_add_id_push_constants { uint32_t ne0; @@ -3098,7 +3104,8 @@ static void ggml_vk_load_shaders(vk_device& device) { if (device->multi_add) { for (uint32_t i = 0; i < MAX_FUSED_ADDS; ++i) { - ggml_vk_create_pipeline(device, device->pipeline_multi_add[i], "multi_add_f32_" + std::to_string(i+1), multi_add_f32_len, multi_add_f32_data, "main", MAX_PARAMETER_COUNT, sizeof(vk_op_multi_add_push_constants), {512, 1, 1}, {i+2}, 1); + ggml_vk_create_pipeline(device, device->pipeline_multi_add[i], "multi_add_f32_" + std::to_string(i+1), multi_add_f32_len, multi_add_f32_data, "main", MAX_PARAMETER_COUNT, sizeof(vk_op_multi_add_push_constants), {512, 1, 1}, {i+2}, 1); + ggml_vk_create_pipeline(device, device->pipeline_multi_add_rms[i], "multi_add_rms_f32_" + std::to_string(i+1), multi_add_rms_f32_len, multi_add_rms_f32_data, "main", MAX_PARAMETER_COUNT, sizeof(vk_op_multi_add_push_constants), {512, 1, 1}, {i+2}, 1); } } @@ -7107,7 +7114,7 @@ static std::array ggml_vk_get_conv_elements(const ggml_tensor *dst) return elements; } -static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) { +static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * dst, ggml_op op) { switch (op) { case GGML_OP_GET_ROWS: GGML_ASSERT(src1->type == GGML_TYPE_I32); @@ -8053,7 +8060,7 @@ static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor *tensors[MAX_PARAMETER_COUNT]; uint32_t num_srcs = ctx->num_additional_fused_ops + 2; uint32_t num_tensors = num_srcs + 1; - GGML_ASSERT(num_tensors <= MAX_PARAMETER_COUNT); + GGML_ASSERT(num_tensors + ctx->do_add_rms_partials <= MAX_PARAMETER_COUNT); tensors[0] = first_node->src[0]; tensors[1] = first_node->src[1]; @@ -8080,8 +8087,9 @@ static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx, pc.nb[i][2] = (uint32_t)t->nb[2] / sizeof(float); pc.nb[i][3] = (uint32_t)t->nb[3] / sizeof(float); } + pc.rms_partials = ctx->do_add_rms_partials; - vk_pipeline pipeline = ctx->device->pipeline_multi_add[ctx->num_additional_fused_ops]; + vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, tensors[0], tensors[1], nullptr, dst, dst->op); if (pipeline == nullptr) { std::cerr << "ggml_vulkan: Error: Missing multi_add"; @@ -8119,6 +8127,10 @@ static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx, buf[i] = buf[0]; offset[i] = 0; } + if (ctx->do_add_rms_partials) { + buf[num_tensors] = ctx->prealloc_add_rms_partials; + offset[num_tensors] = ctx->prealloc_size_add_rms_partials_offset; + } std::array elements; @@ -8131,6 +8143,7 @@ static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx, elements = { ne, 1, 1 }; } + static_assert(MAX_PARAMETER_COUNT == 12); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ buf[0], offset[0], VK_WHOLE_SIZE }, @@ -8141,6 +8154,10 @@ static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx, vk_subbuffer{ buf[5], offset[5], VK_WHOLE_SIZE }, vk_subbuffer{ buf[6], offset[6], VK_WHOLE_SIZE }, vk_subbuffer{ buf[7], offset[7], VK_WHOLE_SIZE }, + vk_subbuffer{ buf[8], offset[8], VK_WHOLE_SIZE }, + vk_subbuffer{ buf[9], offset[9], VK_WHOLE_SIZE }, + vk_subbuffer{ buf[10], offset[10], VK_WHOLE_SIZE }, + vk_subbuffer{ buf[11], offset[11], VK_WHOLE_SIZE }, }, pc, elements); } @@ -9988,17 +10005,19 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr } break; case GGML_OP_ADD: - if (node_idx + 1 < cgraph->n_nodes && - cgraph->nodes[node_idx + 1]->op == GGML_OP_RMS_NORM && - cgraph->nodes[node_idx + 1]->src[0] == cgraph->nodes[node_idx] && - ggml_nrows(cgraph->nodes[node_idx + 1]) == 1 && - ctx->device->add_rms_fusion) { - if (dryrun) { - ctx->prealloc_size_add_rms_partials += ggml_vk_rms_partials_size(ctx, cgraph->nodes[node_idx]); + { + int next_node_idx = node_idx + 1 + ctx->num_additional_fused_ops; + if (next_node_idx < cgraph->n_nodes && + cgraph->nodes[next_node_idx]->op == GGML_OP_RMS_NORM && + cgraph->nodes[next_node_idx]->src[0] == cgraph->nodes[next_node_idx - 1] && + ggml_nrows(cgraph->nodes[next_node_idx]) == 1 && + ctx->device->add_rms_fusion) { + if (dryrun) { + ctx->prealloc_size_add_rms_partials += ggml_vk_rms_partials_size(ctx, cgraph->nodes[node_idx]); + } + ctx->do_add_rms_partials = true; } - ctx->do_add_rms_partials = true; - } - break; + } break; case GGML_OP_REPEAT: case GGML_OP_REPEAT_BACK: case GGML_OP_GET_ROWS: diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp b/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp index 0c7acb7060f07..f2f218b04ac34 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp @@ -3,6 +3,10 @@ #extension GL_EXT_shader_16bit_storage : require #extension GL_EXT_nonuniform_qualifier : enable #extension GL_EXT_control_flow_attributes : require +#if ADD_RMS +#extension GL_KHR_shader_subgroup_arithmetic : enable +#extension GL_KHR_shader_subgroup_basic : enable +#endif #include "rte.comp" #include "types.comp" @@ -14,12 +18,16 @@ layout (push_constant) uniform parameter2 uint ne20; uint ne21; uint ne22; uint ne23; // strides for srcs+dst - uint nb[8][4]; + uint nb[12][4]; + + uint rms_partials; } p; layout (binding = 0) readonly buffer A {A_TYPE data_a[];} a[]; layout (binding = 0) writeonly buffer D {D_TYPE data_d[];} d[]; +layout (binding = 0, std430) buffer PartialBuf {float partial_sums[];} partials[]; + layout(constant_id = 0) const uint num_srcs = 2; uint src_idx(uint s, uint i00, uint i01, uint i02, uint i03) { @@ -42,14 +50,22 @@ const uint num_threads = 256; layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; +#if ADD_RMS +// XXX TODO this could be sized based on number of subgroups, but that't not considered a constant +shared FLOAT_TYPE sumsh[num_threads]; +#endif + void main() { uint idx = get_idx(); + uint orig_idx = idx; uint ne = p.ne20 * p.ne21 * p.ne22 * p.ne23; // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation const uint num_iter = 2; + FLOAT_TYPE sum_sq = 0; + [[unroll]] for (uint i = 0; i < num_iter; ++i) { if (idx >= ne) { continue; @@ -61,8 +77,32 @@ void main() { [[unroll]] for (uint s = 0; s < num_srcs; ++s) { sum += FLOAT_TYPE(a[s].data_a[src_idx(s, i00, i01, i02, i03)]); } + sum_sq += sum*sum; d[num_srcs].data_d[dst_idx(i00, i01, i02, i03)] = D_TYPE(sum); idx += num_threads; } + +#if ADD_RMS + if (p.rms_partials != 0) { + // reduce the sum within each subgroup, then across subgroups + const uint NumSubgroups = num_threads / gl_SubgroupSize; + sum_sq = subgroupAdd(sum_sq); + if (gl_SubgroupInvocationID == 0) { + sumsh[gl_SubgroupID] = sum_sq; + } + barrier(); + [[unroll]] for (uint s = NumSubgroups / 2; s > 0; s >>= 1) { + if (gl_SubgroupID < s && gl_SubgroupInvocationID == 0) { + sum_sq += sumsh[gl_SubgroupID + s]; + sumsh[gl_SubgroupID] = sum_sq; + } + barrier(); + } + + if (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) { + partials[num_srcs + 1].partial_sums[orig_idx / (num_iter * num_threads)] = sum_sq; + } + } +#endif } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 54beb3068f792..50a27748317be 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -690,7 +690,8 @@ void process_shaders() { string_to_spv("add_id_f32", "add_id.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); - string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}}); + string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "0"}}); + string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "1"}}); for (auto &c : compiles) { c.wait(); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index a97d60f453fc5..1e1e43f50594d 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -2858,6 +2858,7 @@ struct test_rms_norm_mul_add : public test_case { const std::array ne; const float eps; const bool broadcast; + const bool multi_add; // test a sequence of adds feeding into rms_norm std::string op_desc(ggml_tensor * t) override { GGML_UNUSED(t); @@ -2867,13 +2868,13 @@ struct test_rms_norm_mul_add : public test_case { bool run_whole_graph() override { return true; } std::string vars() override { - return VARS_TO_STR4(type, ne, eps, broadcast); + return VARS_TO_STR5(type, ne, eps, broadcast, multi_add); } test_rms_norm_mul_add(ggml_type type = GGML_TYPE_F32, std::array ne = {64, 5, 4, 3}, - float eps = 1e-6f, bool broadcast = false) - : type(type), ne(ne), eps(eps), broadcast(broadcast) {} + float eps = 1e-6f, bool broadcast = false, bool multi_add = false) + : type(type), ne(ne), eps(eps), broadcast(broadcast), multi_add(multi_add) {} ggml_tensor * build_graph(ggml_context * ctx) override { std::array broadcast_dims = {ne[0]*2, ne[1]*3, ne[2]*3, ne[3]*4}; @@ -2891,6 +2892,9 @@ struct test_rms_norm_mul_add : public test_case { // Use a, b and c early, so we don't end up with an OP_NONE between rms_norm and mul a = ggml_add(ctx, ggml_add(ctx, a, b), c); + if (multi_add) { + a = ggml_add(ctx, ggml_add(ctx, a, b), c); + } ggml_tensor * out = ggml_add(ctx, ggml_mul(ctx, ggml_rms_norm(ctx, a, eps), b), c); ggml_set_name(out, "out"); @@ -5843,7 +5847,9 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true)); } for (uint32_t n : {1, 511, 1025, 8192, 33*512}) { - test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {n, 1, 1, 1}, 1e-6f)); + for (bool multi_add : {false, true}) { + test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {n, 1, 1, 1}, 1e-6f, false, multi_add)); + } } test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, {64, 5, 4, 3}, 1e-12f)); From a675d0c3eb3925856ea590f43c262343076b4fbb Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Sun, 17 Aug 2025 10:15:06 -0500 Subject: [PATCH 4/6] fix validation errors --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 142c1f8bf8b6f..6a7c147a0dd7f 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -7950,8 +7950,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co } if (op == GGML_OP_ADD || op == GGML_OP_RMS_NORM) { - vk_buffer d_A = ctx->prealloc_add_rms_partials ? ctx->prealloc_add_rms_partials : d_X; - size_t a_buf_offset = ctx->prealloc_add_rms_partials ? ctx->prealloc_size_add_rms_partials_offset : 0; + vk_buffer d_A = ctx->do_add_rms_partials ? ctx->prealloc_add_rms_partials : d_X; + size_t a_buf_offset = ctx->do_add_rms_partials ? ctx->prealloc_size_add_rms_partials_offset : 0; ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, From 8d382bcb762e56e0d36d08d813c0c6adb37e72e9 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Sun, 17 Aug 2025 11:54:34 -0500 Subject: [PATCH 5/6] disable add_rms_fusion for Intel due to possible driver bug --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 6a7c147a0dd7f..a80023d44d2ee 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -3974,7 +3974,8 @@ static vk_device ggml_vk_get_device(size_t idx) { device->disable_fusion = getenv("GGML_VK_DISABLE_FUSION") != nullptr; device->add_rms_fusion = !device->disable_fusion && - device->subgroup_add; + device->subgroup_add && + device->vendor_id != VK_VENDOR_ID_INTEL; device->partials_binding_alignment = std::max(4u, (uint32_t)device->properties.limits.minStorageBufferOffsetAlignment); From e97e226a327b91af5b34895187d78d2f80793459 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Sat, 23 Aug 2025 09:20:17 -0500 Subject: [PATCH 6/6] resolve against #15489, sync after clearing partial sums --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index a80023d44d2ee..2c8d9ecaa0a03 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -7953,7 +7953,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co if (op == GGML_OP_ADD || op == GGML_OP_RMS_NORM) { vk_buffer d_A = ctx->do_add_rms_partials ? ctx->prealloc_add_rms_partials : d_X; size_t a_buf_offset = ctx->do_add_rms_partials ? ctx->prealloc_size_add_rms_partials_offset : 0; - ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, @@ -11297,6 +11296,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg } // initialize partial sums to zero. ggml_vk_buffer_memset_async(compute_ctx, ctx->prealloc_add_rms_partials, 0, 0, ctx->prealloc_size_add_rms_partials); + ggml_vk_sync_buffers(ctx, compute_ctx); } // Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution.