@@ -36,25 +36,16 @@ static void k_set_rows(
36
36
const int i11 = i02 % ne11;
37
37
const int i10 = i01;
38
38
39
- const int64_t dst_row = *(const int64_t *)((const char *)src1 + i10* nb10 + i11* nb11 + i12*nb12 );
39
+ const int64_t dst_row = *(const int64_t *)((const char *)src1 + calculate_offset< 3 >({ nb10, nb11, nb12}, {i10, i11, i12}) );
40
40
41
- const char * src0_row = src0 + i01*nb01 + i02*nb02 + i03*nb03;
41
+
42
+ const char * src0_row = src0 + calculate_offset<3 >({nb01, nb02, nb03}, {i01, i02, i03});
42
43
char * dst_row_ptr = dst + dst_row*nb1 + i02*nb2 + i03*nb3;
43
- // Optimize for same-type operations: use collective memory copy
44
- if (src_type_size == dst_type_size) {
45
- // All threads in the work-group cooperatively copy the row
46
- const size_t row_bytes = ne00 * src_type_size;
47
- // Each thread copies a chunk of the row
48
- for (size_t byte_idx = item_ct1.get_local_id (0 ); byte_idx < row_bytes; byte_idx += item_ct1.get_local_range (0 )) {
49
- dst_row_ptr[byte_idx] = src0_row[byte_idx];
50
- }
51
- } else {
52
- // Type conversion required, use element-wise approach
53
- for (int col = item_ct1.get_local_id (0 ); col < ne00; col += item_ct1.get_local_range (0 )) {
54
- const char * src_elem = src0_row + col * src_type_size;
55
- char * dst_elem = dst_row_ptr + col * dst_type_size;
56
- set_rows_1 (src_elem, dst_elem);
57
- }
44
+
45
+ for (int col = item_ct1.get_local_id (0 ); col < ne00; col += item_ct1.get_local_range (0 )) {
46
+ const char * src_elem = src0_row + col * src_type_size;
47
+ char * dst_elem = dst_row_ptr + col * dst_type_size;
48
+ set_rows_1 (src_elem, dst_elem);
58
49
}
59
50
}
60
51
@@ -68,10 +59,10 @@ static void set_rows_sycl(
68
59
const size_t src_type_size, const size_t dst_type_size,
69
60
queue_ptr stream) {
70
61
71
- const int max_threads_per_row = 256 ; // KEEPING 256 for now
62
+ constexpr int max_threads_per_row = 64 ; // KEEPING 64 for now
72
63
const int threads_per_row = std::min ((int )ne00, max_threads_per_row);
73
64
74
- const int max_threads_per_block = 256 ;
65
+ constexpr int max_threads_per_block = 64 ;
75
66
const int rows_per_block = std::max (1 , max_threads_per_block / threads_per_row);
76
67
77
68
const sycl::range<3 > block_size (1 , rows_per_block, threads_per_row);
0 commit comments