Skip to content

Commit 74f4907

Browse files
committed
CUDA: conv2d minor fixes
1 parent 51f85ff commit 74f4907

File tree

1 file changed

+12 -5 lines changed

1 file changed

+12 -5 lines changed

ggml/src/ggml-cuda/conv2d.cu

Lines changed: 12 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,6 @@
11
#include "conv2d.cuh"
22
#include "convert.cuh"
33

4-
#include <cstdint>
5-
64
struct conv_params {
75
const uint IW, IH;
86
const uint OW, OH;
@@ -88,6 +86,9 @@ template <typename layout> __device__ class float_mma {
8886
#pragma unroll
8987
for (uint i = 0; i < num_acc; i++) {
9088
const uint e = lane_id + i * WARP_SIZE;
89+
if (e >= WMMA_M * WMMA_N) {
90+
continue;
91+
}
9192
const uint m = e / WMMA_N;
9293
const uint n = e % WMMA_N;
9394

@@ -109,6 +110,9 @@ template <typename layout> __device__ class float_mma {
109110
#pragma unroll
110111
for (uint i = 0; i < num_acc; i++) {
111112
const uint e = lane_id + i * WARP_SIZE;
113+
if (e >= WMMA_M * WMMA_N) {
114+
continue;
115+
}
112116
const uint m = e / WMMA_N;
113117
const uint n = e % WMMA_N;
114118

@@ -164,6 +168,9 @@ template <typename layout> class half_mma {
164168
# pragma unroll
165169
for (uint l = 0; l < tile_acc::ne; ++l) {
166170
const uint e = tile_acc::get_i(l) * WMMA_N + tile_acc::get_j(l);
171+
if (e >= WMMA_M * WMMA_N) {
172+
continue;
173+
}
167174
const uint m = e / WMMA_N;
168175
const uint n = e % WMMA_N;
169176

@@ -313,8 +320,8 @@ __global__ void __launch_bounds__(num_warps * WARP_SIZE) conv2d_kernel(const flo
313320
const int in_y = calculate_input_coord(oh, kh, P.ST_Y, P.DL_Y, P.PD_Y);
314321
const int in_x = calculate_input_coord(ow, kw, P.ST_X, P.DL_X, P.PD_X);
315322
if (in_y >= 0 && in_y < P.IH && in_x >= 0 && in_x < P.IW) {
316-
const int64_t in_idx = layout::input_index(n, ic, in_y, in_x, P);
317-
val = ggml_cuda_cast<T>(IN[in_idx]);
323+
const uint64_t in_idx = layout::input_index(n, ic, in_y, in_x, P);
324+
val = ggml_cuda_cast<T>(IN[in_idx]);
318325
}
319326
}
320327
B_sh[brow * BS_NOHOW + bcol] = val;
@@ -359,7 +366,7 @@ __global__ void __launch_bounds__(num_warps * WARP_SIZE) conv2d_kernel(const flo
359366
}
360367

361368
template <typename T, template <typename> class mma>
362-
static void conv2d_cuda(const float * X_D, const T * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
369+
static void conv2d_cuda(const float * X_D, const T * K_D, float * Y_D, const conv_params & P, cudaStream_t st) {
363370
GGML_ASSERT(BS_OC >= WMMA_M && BS_ICKHKW >= WMMA_K && BS_NOHOW >= WMMA_N);
364371

365372
const uint NUM_BL_OC = (P.OC + BS_OC - 1) / BS_OC;

0 commit comments

Comments
 (0)