Skip to content

Commit 35cf888

Browse files
authored
vulkan: Implement GGML_OP_TRI (#17503)
* vulkan: Implement GGML_OP_TRI
* check types match
1 parent 15d2b46 commit 35cf888

File tree

3 files changed

+73
-1
lines changed

3 files changed

+73
-1
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -649,6 +649,7 @@ struct vk_device_struct {
649649
vk_pipeline pipeline_sin_f32;
650650
vk_pipeline pipeline_cos_f32;
651651
vk_pipeline pipeline_log[2];
652+
vk_pipeline pipeline_tri[2];
652653
vk_pipeline pipeline_clamp_f32;
653654
vk_pipeline pipeline_pad_f32;
654655
vk_pipeline pipeline_roll_f32;
@@ -3876,6 +3877,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
38763877
ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
38773878
}
38783879

3880+
ggml_vk_create_pipeline(device, device->pipeline_tri[0], "tri_f32", tri_f32_len, tri_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
3881+
ggml_vk_create_pipeline(device, device->pipeline_tri[1], "tri_f16", tri_f16_len, tri_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
3882+
38793883
ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
38803884

38813885
ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_pad_push_constants), {512, 1, 1}, {}, 1);
@@ -8290,6 +8294,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
82908294
return ctx->device->pipeline_log[dst->type == GGML_TYPE_F16];
82918295
}
82928296
return nullptr;
8297+
case GGML_OP_TRI:
8298+
if (src0->type == dst->type &&
8299+
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
8300+
return ctx->device->pipeline_tri[dst->type == GGML_TYPE_F16];
8301+
}
8302+
return nullptr;
82938303
case GGML_OP_CLAMP:
82948304
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
82958305
return ctx->device->pipeline_clamp_f32;
@@ -8991,6 +9001,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
89919001
case GGML_OP_SIN:
89929002
case GGML_OP_COS:
89939003
case GGML_OP_LOG:
9004+
case GGML_OP_TRI:
89949005
case GGML_OP_CLAMP:
89959006
case GGML_OP_PAD:
89969007
case GGML_OP_ROLL:
@@ -9671,6 +9682,13 @@ static void ggml_vk_log(ggml_backend_vk_context * ctx, vk_context& subctx, const
96719682
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LOG, vk_op_unary_push_constants_init(src0, dst));
96729683
}
96739684

9685+
// Records the GGML_OP_TRI operation for `dst` (input `src0`) by forwarding to the
// shared ggml_vk_op_f32 path with unary push constants built from src0/dst.
// Op param 0 of `dst` is fetched with the f32 accessor and stored in param1; the
// tri shader turns it back into an integer with floatBitsToInt, and
// ggml_vk_check_results_0 reads the same slot with ggml_get_op_params_i32.
// NOTE(review): this relies on the value being copied bit-for-bit through the
// float field (no arithmetic is applied to it here) -- confirm if this is ever refactored.
static void ggml_vk_tri(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
9686+
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
9687+
// Triangle selector travels in param1; the shader switches on its integer bit pattern.
p.param1 = ggml_get_op_params_f32(dst, 0);
9688+
9689+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_TRI, std::move(p));
9690+
}
9691+
96749692
static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
96759693
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
96769694
p.param1 = ggml_get_op_params_f32(dst, 0);
@@ -11794,6 +11812,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1179411812
case GGML_OP_LOG:
1179511813
ggml_vk_log(ctx, compute_ctx, src0, node);
1179611814

11815+
break;
11816+
case GGML_OP_TRI:
11817+
ggml_vk_tri(ctx, compute_ctx, src0, node);
11818+
1179711819
break;
1179811820
case GGML_OP_CLAMP:
1179911821
ggml_vk_clamp(ctx, compute_ctx, src0, node);
@@ -13919,7 +13941,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1391913941
case GGML_OP_OPT_STEP_SGD:
1392013942
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
1392113943
case GGML_OP_LOG:
13922-
return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16;
13944+
case GGML_OP_TRI:
13945+
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
13946+
op->type == op->src[0]->type;
1392313947
case GGML_OP_ARGSORT:
1392413948
{
1392513949
if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {
@@ -14510,6 +14534,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
1451014534
tensor_clone = ggml_cos(ggml_ctx, src_clone[0]);
1451114535
} else if (tensor->op == GGML_OP_LOG) {
1451214536
tensor_clone = ggml_log(ggml_ctx, src_clone[0]);
14537+
} else if (tensor->op == GGML_OP_TRI) {
14538+
tensor_clone = ggml_tri(ggml_ctx, src_clone[0], ggml_get_op_params_i32(tensor, 0));
1451314539
} else if (tensor->op == GGML_OP_CLAMP) {
1451414540
const float * params = (const float *)tensor->op_params;
1451514541
tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]);
ggml/src/ggml-vulkan/vulkan-shaders/tri.comp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
#version 450
2+
3+
#include "rte.glsl"
4+
#include "types.glsl"
5+
#include "generic_unary_head.glsl"
6+
7+
#define GGML_TRI_TYPE_UPPER_DIAG 0
8+
#define GGML_TRI_TYPE_UPPER 1
9+
#define GGML_TRI_TYPE_LOWER_DIAG 2
10+
#define GGML_TRI_TYPE_LOWER 3
11+
12+
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
13+
14+
// Per-invocation entry point: each invocation handles one flat element index.
// The element is copied from data_a to data_d when its (i00, i01) position
// satisfies the comparison selected by p.param1; otherwise zero is written.
void main() {
15+
const uint idx = get_idx();
16+
17+
// Invocations past the element count p.ne have nothing to do.
if (idx >= p.ne) {
18+
return;
19+
}
20+
21+
// Split the flat index into per-dimension indices i03, i02, i01, i00.
// Each fastdiv result is multiplied back by the matching extent product
// (ne02*ne01*ne00, ne01*ne00, ne00) to get the offset removed before the next step.
// NOTE(review): the *mp / *L push constants are presumably precomputed
// division parameters for those products -- confirm against the fastdiv helper.
const uint i03 = fastdiv(idx, p.ne0_012mp, p.ne0_012L);
22+
const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
23+
const uint i02 = fastdiv(idx - i03_offset, p.ne0_01mp, p.ne0_01L);
24+
const uint i02_offset = i02*p.ne01*p.ne00;
25+
const uint i01 = fastdiv(idx - i03_offset - i02_offset, p.ne0_0mp, p.ne0_0L);
26+
// Whatever remains after removing the higher dimensions is the dimension-0 index.
const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00;
27+
28+
// param1 is a float slot; reinterpret its bits to recover the integer selector.
int param = floatBitsToInt(p.param1);
29+
// Stays false when param matches none of the cases, so the element is zeroed.
bool pass = false;
30+
// Compare the dimension-0 index (i00) with the dimension-1 index (i01);
// the *_DIAG variants also keep elements where the two are equal.
switch (param) {
31+
case GGML_TRI_TYPE_UPPER_DIAG: pass = i00 >= i01; break;
32+
case GGML_TRI_TYPE_UPPER: pass = i00 > i01; break;
33+
case GGML_TRI_TYPE_LOWER_DIAG: pass = i00 <= i01; break;
34+
case GGML_TRI_TYPE_LOWER: pass = i00 < i01; break;
35+
}
36+
37+
if (pass) {
38+
// Kept element: read the source, go through float, then convert to D_TYPE.
const float val = float(data_a[get_aoffset() + src0_idx(idx)]);
39+
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val);
40+
} else {
41+
// Element outside the selected triangle is written as zero.
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(0);
42+
}
43+
}

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -846,6 +846,9 @@ void process_shaders() {
846846
string_to_spv("abs_f16", "abs.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
847847
string_to_spv("abs_f32", "abs.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
848848

849+
string_to_spv("tri_f16", "tri.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
850+
string_to_spv("tri_f32", "tri.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
851+
849852
string_to_spv("softplus_f16", "softplus.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
850853
string_to_spv("softplus_f32", "softplus.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
851854

0 commit comments

Comments
 (0)