Skip to content

Commit 3f944dd

Browse files
committed
Deduplicate conversion function
1 parent 10a80a1 commit 3f944dd

File tree

1 file changed

+11
-19
lines changed

1 file changed

+11
-19
lines changed

ggml/src/ggml-sycl/set_rows.cpp

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,13 @@
11
#include "set_rows.hpp"
22

3-
typedef void (*set_rows_kernel_t)(const char * src, char * dst);
4-
5-
static void set_rows_1_f32_f32(const char * src, char * dst) {
6-
const float * src_f = (const float *) src;
7-
float * dst_f = (float *) dst;
8-
*dst_f = *src_f;
9-
}
10-
11-
static void set_rows_1_f32_f16(const char * src, char * dst) {
12-
const float * src_f = (const float *) src;
13-
sycl::half * dst_h = (sycl::half *) dst;
14-
*dst_h = sycl::vec<float, 1>(*src_f).convert<sycl::half, sycl::rounding_mode::automatic>()[0];
3+
template<typename TIn, typename TOut>
4+
static inline void convert(const char* src, char* dst) {
5+
auto src_val = *reinterpret_cast<const TIn*>(src);
6+
auto dst_val = sycl::vec<TIn, 1>(src_val).template convert<TOut>()[0];
7+
*reinterpret_cast<TOut*>(dst) = dst_val;
158
}
169

17-
template<set_rows_kernel_t set_rows_1>
10+
template<typename TIn, typename TOut>
1811
static void k_set_rows(
1912
const char * __restrict__ src0, const int64_t * __restrict__ src1, char * __restrict__ dst,
2013
const int64_t ne00, const int64_t ne01, const int64_t ne11, const int64_t ne12,
@@ -38,18 +31,17 @@ static void k_set_rows(
3831

3932
const int64_t dst_row = *(const int64_t *)((const char *)src1 + calculate_offset<3>({nb10, nb11, nb12}, {i10, i11, i12}));
4033

41-
4234
const char * src0_row = src0 + calculate_offset<3>({nb01, nb02, nb03}, {i01, i02, i03});
4335
char * dst_row_ptr = dst + dst_row*nb1 + i02*nb2 + i03*nb3;
4436

4537
for (int col = item_ct1.get_local_id(0); col < ne00; col += item_ct1.get_local_range(0)) {
4638
const char * src_elem = src0_row + col * src_type_size;
4739
char * dst_elem = dst_row_ptr + col * dst_type_size;
48-
set_rows_1(src_elem, dst_elem);
40+
convert<TIn, TOut>(src_elem, dst_elem);
4941
}
5042
}
5143

52-
template<set_rows_kernel_t set_rows_1>
44+
template<typename TIn, typename TOut>
5345
static void set_rows_sycl(
5446
const char * src0_d, const int64_t * src1_d, char * dst_d,
5547
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
@@ -73,7 +65,7 @@ static void set_rows_sycl(
7365
stream,
7466
sycl::nd_range<3>(grid_size * block_size, block_size),
7567
[=](sycl::nd_item<3> item_ct1) {
76-
k_set_rows<set_rows_1>(
68+
k_set_rows<TIn, TOut>(
7769
src0_d, src1_d, dst_d,
7870
ne00, ne01, ne11, ne12,
7971
nb01, nb02, nb03,
@@ -103,7 +95,7 @@ void ggml_sycl_op_set_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
10395
dpct::queue_ptr stream = ctx.stream();
10496
switch (dst->type) {
10597
case GGML_TYPE_F32:
106-
set_rows_sycl<set_rows_1_f32_f32>(
98+
set_rows_sycl<float, float>(
10799
(const char *)dst->src[0]->data, src1_dd, (char *)dst->data,
108100
ne00, ne01, ne02, ne03,
109101
ne11, ne12,
@@ -116,7 +108,7 @@ void ggml_sycl_op_set_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
116108
break;
117109
case GGML_TYPE_F16:
118110
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
119-
set_rows_sycl<set_rows_1_f32_f16>(
111+
set_rows_sycl<float, sycl::half>(
120112
(const char *)dst->src[0]->data, src1_dd, (char *)dst->data,
121113
ne00, ne01, ne02, ne03,
122114
ne11, ne12,

0 commit comments

Comments
 (0)