@@ -649,6 +649,7 @@ struct vk_device_struct {
649649 vk_pipeline pipeline_sin_f32;
650650 vk_pipeline pipeline_cos_f32;
651651 vk_pipeline pipeline_log[2];
652+ vk_pipeline pipeline_tri[2];
652653 vk_pipeline pipeline_clamp_f32;
653654 vk_pipeline pipeline_pad_f32;
654655 vk_pipeline pipeline_roll_f32;
@@ -3876,6 +3877,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
38763877 ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
38773878 }
38783879
3880+ ggml_vk_create_pipeline(device, device->pipeline_tri[0], "tri_f32", tri_f32_len, tri_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
3881+ ggml_vk_create_pipeline(device, device->pipeline_tri[1], "tri_f16", tri_f16_len, tri_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
3882+
38793883 ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
38803884
38813885 ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_pad_push_constants), {512, 1, 1}, {}, 1);
@@ -8290,6 +8294,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
82908294 return ctx->device->pipeline_log[dst->type == GGML_TYPE_F16];
82918295 }
82928296 return nullptr;
8297+ case GGML_OP_TRI:
8298+ if (src0->type == dst->type &&
8299+ (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
8300+ return ctx->device->pipeline_tri[dst->type == GGML_TYPE_F16];
8301+ }
8302+ return nullptr;
82938303 case GGML_OP_CLAMP:
82948304 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
82958305 return ctx->device->pipeline_clamp_f32;
@@ -8991,6 +9001,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
89919001 case GGML_OP_SIN:
89929002 case GGML_OP_COS:
89939003 case GGML_OP_LOG:
9004+ case GGML_OP_TRI:
89949005 case GGML_OP_CLAMP:
89959006 case GGML_OP_PAD:
89969007 case GGML_OP_ROLL:
@@ -9671,6 +9682,13 @@ static void ggml_vk_log(ggml_backend_vk_context * ctx, vk_context& subctx, const
96719682 ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LOG, vk_op_unary_push_constants_init(src0, dst));
96729683}
96739684
9685+ static void ggml_vk_tri(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { // Issues GGML_OP_TRI for dst through ggml_vk_op_f32, with src0 as the only source operand.
9686+ vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst); // Start from the unary push constants built from src0 and dst.
9687+ p.param1 = ggml_get_op_params_f32(dst, 0); // Op parameter 0 of dst, read through the f32 accessor. NOTE(review): the results-check path reads this same slot with ggml_get_op_params_i32; presumably the value is carried here as raw bits and reinterpreted by the shader - confirm.
9688+
9689+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_TRI, std::move(p)); // The three nullptr arguments are the unused extra source slots.
9690+ }
9691+
96749692static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
96759693 vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
96769694 p.param1 = ggml_get_op_params_f32(dst, 0);
@@ -11794,6 +11812,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1179411812 case GGML_OP_LOG:
1179511813 ggml_vk_log(ctx, compute_ctx, src0, node);
1179611814
11815+ break;
11816+ case GGML_OP_TRI:
11817+ ggml_vk_tri(ctx, compute_ctx, src0, node);
11818+
1179711819 break;
1179811820 case GGML_OP_CLAMP:
1179911821 ggml_vk_clamp(ctx, compute_ctx, src0, node);
@@ -13919,7 +13941,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1391913941 case GGML_OP_OPT_STEP_SGD:
1392013942 return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
1392113943 case GGML_OP_LOG:
13922- return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16;
13944+ case GGML_OP_TRI:
13945+ return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
13946+ op->type == op->src[0]->type;
1392313947 case GGML_OP_ARGSORT:
1392413948 {
1392513949 if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {
@@ -14510,6 +14534,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
1451014534 tensor_clone = ggml_cos(ggml_ctx, src_clone[0]);
1451114535 } else if (tensor->op == GGML_OP_LOG) {
1451214536 tensor_clone = ggml_log(ggml_ctx, src_clone[0]);
14537+ } else if (tensor->op == GGML_OP_TRI) {
14538+ tensor_clone = ggml_tri(ggml_ctx, src_clone[0], ggml_get_op_params_i32(tensor, 0));
1451314539 } else if (tensor->op == GGML_OP_CLAMP) {
1451414540 const float * params = (const float *)tensor->op_params;
1451514541 tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]);
0 commit comments