@@ -483,6 +483,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_rwkv_wkv6_f32;
     vk_pipeline pipeline_rwkv_wkv7_f32;
     vk_pipeline pipeline_opt_step_adamw_f32;
+    vk_pipeline pipeline_opt_step_sgd_f32;
     vk_pipeline pipeline_conv2d_f32;
     vk_pipeline pipeline_conv2d_dw_whcn_f32;
     vk_pipeline pipeline_conv2d_dw_cwhn_f32;
@@ -3046,6 +3047,8 @@ static void ggml_vk_load_shaders(vk_device& device) {

     ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);

+    ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+
     // conv2d
     uint32_t conv2d_WG_SIZE = 256;
     uint32_t conv2d_BS_K = 128;
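Editor's note: the new pipeline is registered with 3 buffer bindings (weights, gradients, optimizer params) and reuses vk_op_push_constants, dispatched with a 512-wide workgroup like opt_step_adamw. The shader source is not part of this diff; as a hedged scalar reference only, the per-element update it is expected to perform, assuming it mirrors the CPU backend's SGD formula w = w*(1 - lr*wd) - lr*g with lr and wd read from the params tensor, could be sketched in C++ as:

#include <cstddef>

// Hypothetical scalar reference for the opt_step_sgd_f32 shader (not the actual GLSL).
// w: weights (src0, updated in place), g: gradients (src1), params: {lr, wd} (src2).
static void opt_step_sgd_reference(float * w, const float * g, const float * params, size_t n) {
    const float lr   = params[0];
    const float keep = 1.0f - lr * params[1]; // 1 - lr * weight_decay
    for (size_t i = 0; i < n; ++i) {
        w[i] = w[i] * keep - lr * g[i];       // equivalent to w -= lr * (g + wd * w)
    }
}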
@@ -6954,7 +6957,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         return nullptr;
     case GGML_OP_OPT_STEP_SGD:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            // TODO
+            return ctx->device->pipeline_opt_step_sgd_f32;
         }
         return nullptr;
     case GGML_OP_LEAKY_RELU:
@@ -7430,6 +7433,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
         ggml_vk_buffer_memset_async(subctx, d_D, d_buf_offset, 0, d_sz);
         ggml_vk_sync_buffers(subctx);
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
+    } else if (op == GGML_OP_OPT_STEP_SGD) {
+        // OPT_STEP_SGD works on src0, it does not need dst
+        ggml_vk_sync_buffers(subctx);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz } }, pc, elements);
     } else if (use_src2) {
         ggml_vk_sync_buffers(subctx);
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
@@ -7768,18 +7775,10 @@ static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& su
     );
 }

-static void ggml_vk_op_f32_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_push_constants&& pc, bool dryrun = false) {
-    GGML_ASSERT(0 && "SGD vulkan unimplemented"); // TODO
-}
-
-static void ggml_vk_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
+static void ggml_vk_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
     const size_t n = ggml_nelements(dst->src[0]);

-    ggml_vk_op_f32_opt_step_sgd(
-        ctx, subctx, dst,
-        { (uint32_t)n, 0, 0.0f, 0.0f },
-        dryrun
-    );
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f }, dryrun);
 }

 static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
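Editor's note: ggml_vk_opt_step_sgd now routes through the generic ggml_vk_op_f32 template with a vk_op_push_constants payload of { n, 0, 0.0f, 0.0f }. For reference, that struct (as declared earlier in ggml-vulkan.cpp; reproduced here, with the per-field comments being interpretation for this op) carries only KX, the element count; the learning rate and weight decay are presumably read by the shader from the bound src2 params buffer rather than from push constants:

#include <cstdint>

// Push-constant layout used by ggml_vk_op_f32 dispatches.
struct vk_op_push_constants {
    uint32_t KX;     // number of elements (n) for OPT_STEP_SGD
    uint32_t KY;     // unused for this op
    float    param1; // unused for this op
    float    param2; // unused for this op
};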
@@ -9313,6 +9312,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_LEAKY_RELU:
     case GGML_OP_FLASH_ATTN_EXT:
     case GGML_OP_OPT_STEP_ADAMW:
+    case GGML_OP_OPT_STEP_SGD:
         break;
     default:
         std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
@@ -9377,6 +9377,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_CONV_2D:
     case GGML_OP_CONV_2D_DW:
     case GGML_OP_LEAKY_RELU:
+    case GGML_OP_OPT_STEP_SGD:
         {
             // These operations all go through ggml_vk_op_f32, so short-circuit and
             // do the only thing needed for the dryrun.
@@ -9624,8 +9625,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         break;

     case GGML_OP_OPT_STEP_SGD:
-        return false; // TODO
-        ggml_vk_opt_step_sgd(ctx, compute_ctx, node, dryrun);
+        ggml_vk_opt_step_sgd(ctx, compute_ctx, src0, src1, src2, node, dryrun);

         break;
     default:
@@ -9729,10 +9729,9 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
     case GGML_OP_REPEAT:
     case GGML_OP_REPEAT_BACK:
     case GGML_OP_OPT_STEP_ADAMW:
+    case GGML_OP_OPT_STEP_SGD:
         buf = tensor->buffer;
         break;
-    case GGML_OP_OPT_STEP_SGD:
-        return false;
     case GGML_OP_UNARY:
         switch (ggml_get_unary_op(tensor)) {
         case GGML_UNARY_OP_SILU:
@@ -10860,6 +10859,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
     case GGML_OP_SIN:
     case GGML_OP_COS:
     case GGML_OP_CLAMP:
+    case GGML_OP_LEAKY_RELU:
+    case GGML_OP_OPT_STEP_ADAMW:
+    case GGML_OP_OPT_STEP_SGD:
         return op->src[0]->type == GGML_TYPE_F32;
     case GGML_OP_UPSCALE:
     case GGML_OP_ACC:
@@ -10881,11 +10883,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
     case GGML_OP_POOL_2D:
     case GGML_OP_RWKV_WKV6:
     case GGML_OP_RWKV_WKV7:
-    case GGML_OP_LEAKY_RELU:
-    case GGML_OP_OPT_STEP_ADAMW:
         return true;
-    case GGML_OP_OPT_STEP_SGD:
-        return false;
     case GGML_OP_CONV_TRANSPOSE_1D:
         return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
     case GGML_OP_CONV_2D:
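Editor's note: with GGML_OP_LEAKY_RELU, GGML_OP_OPT_STEP_ADAMW and GGML_OP_OPT_STEP_SGD moved into the F32-only branch, the Vulkan device now reports the SGD step as supported whenever src0 is F32 instead of unconditionally rejecting it. A minimal sketch of probing that from user code follows; the graph-side constructor ggml_opt_step_sgd(ctx, a, grad, params) with a two-element {lr, wd} params tensor is an assumption modeled on the CPU backend and is not shown in this diff, and device_supports_sgd_step is a hypothetical helper:

#include "ggml.h"
#include "ggml-backend.h"

// Hedged sketch: build an F32 OPT_STEP_SGD node and ask a device whether it supports it.
// ggml_opt_step_sgd(ctx, a, grad, params) is assumed here; adjust to the actual API.
static bool device_supports_sgd_step(ggml_backend_dev_t dev) {
    struct ggml_init_params ip = { /*mem_size*/ 1024*1024, /*mem_buffer*/ nullptr, /*no_alloc*/ true };
    struct ggml_context * ctx = ggml_init(ip);

    struct ggml_tensor * w      = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 16);
    struct ggml_tensor * g      = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 16);
    struct ggml_tensor * params = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 2); // {lr, wd}
    struct ggml_tensor * step   = ggml_opt_step_sgd(ctx, w, g, params);      // assumed signature

    const bool ok = ggml_backend_dev_supports_op(dev, step);
    ggml_free(ctx);
    return ok;
}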