1
1
#include " set_rows.hpp"
2
2
3
- typedef void (*set_rows_kernel_t )(const char * src, char * dst);
4
-
5
- static void set_rows_1_f32_f32 (const char * src, char * dst) {
6
- const float * src_f = (const float *) src;
7
- float * dst_f = (float *) dst;
8
- *dst_f = *src_f;
9
- }
10
-
11
- static void set_rows_1_f32_f16 (const char * src, char * dst) {
12
- const float * src_f = (const float *) src;
13
- sycl::half * dst_h = (sycl::half *) dst;
14
- *dst_h = sycl::vec<float , 1 >(*src_f).convert <sycl::half, sycl::rounding_mode::automatic>()[0 ];
3
+ template <typename TIn, typename TOut>
4
+ static inline void convert (const char * src, char * dst) {
5
+ auto src_val = *reinterpret_cast <const TIn*>(src);
6
+ auto dst_val = sycl::vec<TIn, 1 >(src_val).template convert <TOut>()[0 ];
7
+ *reinterpret_cast <TOut*>(dst) = dst_val;
15
8
}
16
9
17
- template <set_rows_kernel_t set_rows_1 >
10
+ template <typename TIn, typename TOut >
18
11
static void k_set_rows (
19
12
const char * __restrict__ src0, const int64_t * __restrict__ src1, char * __restrict__ dst,
20
13
const int64_t ne00, const int64_t ne01, const int64_t ne11, const int64_t ne12,
@@ -38,18 +31,17 @@ static void k_set_rows(
38
31
39
32
const int64_t dst_row = *(const int64_t *)((const char *)src1 + calculate_offset<3 >({nb10, nb11, nb12}, {i10, i11, i12}));
40
33
41
-
42
34
const char * src0_row = src0 + calculate_offset<3 >({nb01, nb02, nb03}, {i01, i02, i03});
43
35
char * dst_row_ptr = dst + dst_row*nb1 + i02*nb2 + i03*nb3;
44
36
45
37
for (int col = item_ct1.get_local_id (0 ); col < ne00; col += item_ct1.get_local_range (0 )) {
46
38
const char * src_elem = src0_row + col * src_type_size;
47
39
char * dst_elem = dst_row_ptr + col * dst_type_size;
48
- set_rows_1 (src_elem, dst_elem);
40
+ convert<TIn, TOut> (src_elem, dst_elem);
49
41
}
50
42
}
51
43
52
- template <set_rows_kernel_t set_rows_1 >
44
+ template <typename TIn, typename TOut >
53
45
static void set_rows_sycl (
54
46
const char * src0_d, const int64_t * src1_d, char * dst_d,
55
47
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
@@ -73,7 +65,7 @@ static void set_rows_sycl(
73
65
stream,
74
66
sycl::nd_range<3 >(grid_size * block_size, block_size),
75
67
[=](sycl::nd_item<3 > item_ct1) {
76
- k_set_rows<set_rows_1 >(
68
+ k_set_rows<TIn, TOut >(
77
69
src0_d, src1_d, dst_d,
78
70
ne00, ne01, ne11, ne12,
79
71
nb01, nb02, nb03,
@@ -103,7 +95,7 @@ void ggml_sycl_op_set_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
103
95
dpct::queue_ptr stream = ctx.stream ();
104
96
switch (dst->type ) {
105
97
case GGML_TYPE_F32:
106
- set_rows_sycl<set_rows_1_f32_f32 >(
98
+ set_rows_sycl<float , float >(
107
99
(const char *)dst->src [0 ]->data , src1_dd, (char *)dst->data ,
108
100
ne00, ne01, ne02, ne03,
109
101
ne11, ne12,
@@ -116,7 +108,7 @@ void ggml_sycl_op_set_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
116
108
break ;
117
109
case GGML_TYPE_F16:
118
110
dpct::has_capability_or_fail (stream->get_device (), { sycl::aspect::fp16 });
119
- set_rows_sycl<set_rows_1_f32_f16 >(
111
+ set_rows_sycl<float , sycl::half >(
120
112
(const char *)dst->src [0 ]->data , src1_dd, (char *)dst->data ,
121
113
ne00, ne01, ne02, ne03,
122
114
ne11, ne12,
0 commit comments