1
1
#include " conv2d.cuh"
2
2
#include " convert.cuh"
3
3
4
- #include < cstdint>
5
-
6
4
struct conv_params {
7
5
const uint IW, IH;
8
6
const uint OW, OH;
@@ -88,6 +86,9 @@ template <typename layout> __device__ class float_mma {
88
86
#pragma unroll
89
87
for (uint i = 0 ; i < num_acc; i++) {
90
88
const uint e = lane_id + i * WARP_SIZE;
89
+ if (e >= WMMA_M * WMMA_N) {
90
+ continue ;
91
+ }
91
92
const uint m = e / WMMA_N;
92
93
const uint n = e % WMMA_N;
93
94
@@ -109,6 +110,9 @@ template <typename layout> __device__ class float_mma {
109
110
#pragma unroll
110
111
for (uint i = 0 ; i < num_acc; i++) {
111
112
const uint e = lane_id + i * WARP_SIZE;
113
+ if (e >= WMMA_M * WMMA_N) {
114
+ continue ;
115
+ }
112
116
const uint m = e / WMMA_N;
113
117
const uint n = e % WMMA_N;
114
118
@@ -164,6 +168,9 @@ template <typename layout> class half_mma {
164
168
# pragma unroll
165
169
for (uint l = 0 ; l < tile_acc::ne; ++l) {
166
170
const uint e = tile_acc::get_i (l) * WMMA_N + tile_acc::get_j (l);
171
+ if (e >= WMMA_M * WMMA_N) {
172
+ continue ;
173
+ }
167
174
const uint m = e / WMMA_N;
168
175
const uint n = e % WMMA_N;
169
176
@@ -313,8 +320,8 @@ __global__ void __launch_bounds__(num_warps * WARP_SIZE) conv2d_kernel(const flo
313
320
const int in_y = calculate_input_coord (oh, kh, P.ST_Y , P.DL_Y , P.PD_Y );
314
321
const int in_x = calculate_input_coord (ow, kw, P.ST_X , P.DL_X , P.PD_X );
315
322
if (in_y >= 0 && in_y < P.IH && in_x >= 0 && in_x < P.IW ) {
316
- const int64_t in_idx = layout::input_index (n, ic, in_y, in_x, P);
317
- val = ggml_cuda_cast<T>(IN[in_idx]);
323
+ const uint64_t in_idx = layout::input_index (n, ic, in_y, in_x, P);
324
+ val = ggml_cuda_cast<T>(IN[in_idx]);
318
325
}
319
326
}
320
327
B_sh[brow * BS_NOHOW + bcol] = val;
@@ -359,7 +366,7 @@ __global__ void __launch_bounds__(num_warps * WARP_SIZE) conv2d_kernel(const flo
359
366
}
360
367
361
368
template <typename T, template <typename > class mma >
362
- static void conv2d_cuda (const float * X_D, const T * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
369
+ static void conv2d_cuda (const float * X_D, const T * K_D, float * Y_D, const conv_params & P, cudaStream_t st) {
363
370
GGML_ASSERT (BS_OC >= WMMA_M && BS_ICKHKW >= WMMA_K && BS_NOHOW >= WMMA_N);
364
371
365
372
const uint NUM_BL_OC = (P.OC + BS_OC - 1 ) / BS_OC;
0 commit comments