From 19596b1793918df7c62be28fa8a682ffc5bf66cb Mon Sep 17 00:00:00 2001 From: mnehete32 Date: Fri, 5 Sep 2025 11:32:58 +0530 Subject: [PATCH 1/8] CUDA: cov2d with tensor core --- ggml/src/ggml-cuda/conv2d.cu | 328 ++++++++++++++++++++++++++-------- ggml/src/ggml-cuda/conv2d.cuh | 9 +- 2 files changed, 265 insertions(+), 72 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d.cu b/ggml/src/ggml-cuda/conv2d.cu index 142dd66903aaa..4914393acab2f 100644 --- a/ggml/src/ggml-cuda/conv2d.cu +++ b/ggml/src/ggml-cuda/conv2d.cu @@ -1,6 +1,9 @@ #include "conv2d.cuh" #include "convert.cuh" +#include +using namespace nvcuda; + struct conv_params { const int64_t IW, IH; const int64_t OW, OH; @@ -11,112 +14,292 @@ struct conv_params { const int64_t IC, OC; const int64_t B; const int64_t TOTAL; + // helpers + const int64_t IC_KH_KW, N_OH_OW; }; -struct kernel_bounds { - int64_t y_min, y_max; - int64_t x_min, x_max; +auto ceil_div = [](int a, int b) { + return (a + b - 1) / b; }; -__device__ __forceinline__ int64_t max64(int64_t a, int64_t b) { - return (a > b) ? a : b; -} - -__device__ __forceinline__ int64_t min64(int64_t a, int64_t b) { - return (a < b) ? a : b; -} - -__device__ __forceinline__ kernel_bounds calculate_kernel_bounds(int64_t out_x, int64_t out_y, const conv_params & P) { - kernel_bounds bounds; - bounds.y_min = max64(0, (P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y); - bounds.y_max = min64(P.KH, (P.IH + P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y); - bounds.x_min = max64(0, (P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X); - bounds.x_max = min64(P.KW, (P.IW + P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X); - return bounds; -} - -__device__ __forceinline__ int calculate_input_coord(int64_t out_coord, - int64_t kern_coord, - int64_t stride, - int64_t dilation, - int64_t padding) { +__device__ __forceinline__ static int calculate_input_coord(int64_t out_coord, + int64_t kern_coord, + int64_t stride, + int64_t dilation, + int64_t padding) { return out_coord * stride + kern_coord * dilation - padding; } struct whcn_layout { - __device__ static int64_t input_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) { + __device__ __forceinline__ static int64_t input_index(int64_t n, + int64_t c, + int64_t y, + int64_t x, + const conv_params & P) { return n * (P.IC * P.IW * P.IH) + c * P.IW * P.IH + y * P.IW + x; } - __device__ static int64_t kernel_index(int64_t c_out, int64_t c_in, int64_t ky, int64_t kx, const conv_params & P) { + __device__ __forceinline__ static int64_t kernel_index(int64_t c_out, + int64_t c_in, + int64_t ky, + int64_t kx, + const conv_params & P) { return c_out * (P.IC * P.KH * P.KW) + c_in * (P.KH * P.KW) + ky * P.KW + kx; } - __device__ static int64_t output_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) { + __device__ __forceinline__ static int64_t output_index(int64_t n, + int64_t c, + int64_t y, + int64_t x, + const conv_params & P) { return n * (P.OC * P.OW * P.OH) + c * P.OW * P.OH + y * P.OW + x; } - __device__ static void unpack_indices(int64_t global_idx, - const conv_params & P, - int64_t & n, - int64_t & c, - int64_t & out_y, - int64_t & out_x) { - out_x = global_idx % P.OW; - out_y = (global_idx / P.OW) % P.OH; - c = (global_idx / (P.OW * P.OH)) % P.OC; - n = global_idx / (P.OW * P.OH * P.OC); + __device__ __forceinline__ static void unpack_ickhkw(int64_t idx, + int64_t & ic, + int64_t & kh, + int64_t & kw, + const conv_params & P) { + ic = idx / (P.KW * P.KH); + int64_t r = idx - ic * (P.KW * P.KH); + kh = r / 
P.KW; + kw = r - kh * P.KW; + } + + __device__ __forceinline__ static void unpack_nohow(int64_t idx, + int64_t & n, + int64_t & oh, + int64_t & ow, + const conv_params & P) { + n = idx / (P.OH * P.OW); + int64_t r = idx - n * (P.OH * P.OW); + oh = r / P.OW; + ow = r - oh * P.OW; + } +}; + +class float_mma { + public: + float * buf; + + __device__ __forceinline__ float_mma(float * scratch) { + buf = scratch; + const int lane_id = threadIdx.x % warpSize; +#pragma unroll + for (int i = lane_id; i < WMMA_M * WMMA_N; i += warpSize) { + buf[i] = 0.0f; + } + } + + __device__ __forceinline__ void mma(const float * A_sh, const float * B_sh, const int strideA, const int strideB) { + const int lane_id = threadIdx.x % warpSize; +#pragma unroll + for (int e = lane_id; e < (WMMA_M * WMMA_N); e += warpSize) { + int m = e / WMMA_N; + int n = e % WMMA_N; + float sum = buf[m * WMMA_N + n]; +#pragma unroll + for (int k = 0; k < WMMA_K; k++) { + float a = A_sh[m * strideA + k]; + float b = B_sh[k * strideB + n]; + sum = fmaf(a, b, sum); + } + buf[m * WMMA_N + n] = sum; + } } + + __device__ __forceinline__ float * store_result() const { return buf; } }; -template -static __global__ void conv2d_kernel(const float * __restrict__ input, - const T * __restrict__ kernel, - float * __restrict__ output, - const conv_params P) { - const int64_t global_idx = blockIdx.x * blockDim.x + threadIdx.x; +class half_mma { + private: + wmma::fragment acc; + wmma::fragment a_frag; + wmma::fragment b_frag; + public: + float * buf; + + __device__ __forceinline__ half_mma(float * scratch) { + buf = scratch; + wmma::fill_fragment(acc, 0.0f); + } + + __device__ __forceinline__ void mma(const half * A_sh, const half * B_sh, const int strideA, const int strideB) { + wmma::load_matrix_sync(a_frag, A_sh, strideA); + wmma::load_matrix_sync(b_frag, B_sh, strideB); + wmma::mma_sync(acc, a_frag, b_frag, acc); + } - if (global_idx >= P.TOTAL) { - return; + __device__ __forceinline__ float * store_result() const { + wmma::store_matrix_sync(buf, acc, WMMA_N, wmma::mem_row_major); + return buf; } +}; + +template +static __global__ void conv2d_kernel(const float * IN, const T * IK, float * OUT, const conv_params P) { + extern __shared__ unsigned char smem_raw[]; + + const int64_t OUTPUT_NUMEL = WMMA_M * WMMA_N; + const int64_t NUM_IC_TILES = (P.IC_KH_KW + BS_ICKHKW - 1) / BS_ICKHKW; + + const int64_t WARPS_PER_NOHOW = max(1, BS_NOHOW / WMMA_N); + + const int64_t NUM_BL_NOHOW = (P.N_OH_OW + BS_NOHOW - 1) / BS_NOHOW; + const int64_t tile_id = blockIdx.x; + const int64_t tile_oc = tile_id / NUM_BL_NOHOW; + const int64_t tile_nohow = tile_id % NUM_BL_NOHOW; + const int64_t BLOCK_OC_BASE = tile_oc * BS_OC; + const int64_t BLOCK_NOHOW_BASE = tile_nohow * BS_NOHOW; + + const int64_t laneId = threadIdx.x % WARP_SIZE; + const int64_t warpId = threadIdx.x / WARP_SIZE; + + const int64_t WARP_OC = warpId / WARPS_PER_NOHOW; + const int64_t WARP_NOHOW = warpId % WARPS_PER_NOHOW; - int64_t n, c_out, out_y, out_x; - Layout::unpack_indices(global_idx, P, n, c_out, out_y, out_x); + const int64_t OC_BASE = BLOCK_OC_BASE + WARP_OC * WMMA_M; + const int64_t NOHOW_BASE = BLOCK_NOHOW_BASE + WARP_NOHOW * WMMA_N; - float acc = 0.0f; + unsigned char * ptr = smem_raw; + T * A_sh = reinterpret_cast(ptr); - for (int64_t c_in = 0; c_in < P.IC; ++c_in) { - kernel_bounds bounds = calculate_kernel_bounds(out_x, out_y, P); + size_t offsetA = BS_OC * BS_ICKHKW * sizeof(T); + ptr += offsetA; - for (int64_t ky = bounds.y_min; ky < bounds.y_max; ++ky) { - const int64_t in_y = 
calculate_input_coord(out_y, ky, P.ST_Y, P.DL_Y, P.PD_Y); + T * B_sh = reinterpret_cast(ptr); + ptr += BS_ICKHKW * BS_NOHOW * sizeof(T); - for (int64_t kx = bounds.x_min; kx < bounds.x_max; ++kx) { - const int64_t in_x = calculate_input_coord(out_x, kx, P.ST_X, P.DL_X, P.PD_X); + float * shared_scratch = reinterpret_cast(ptr); + float * warp_scratch = shared_scratch + warpId * (WMMA_M * WMMA_N); - const float input_val = input[Layout::input_index(n, c_in, in_y, in_x, P)]; - const T kernel_val = kernel[Layout::kernel_index(c_out, c_in, ky, kx, P)]; - acc += (input_val * ggml_cuda_cast(kernel_val)); + const T * A_warp_base = A_sh + WARP_OC * WMMA_M * BS_ICKHKW; + const T * B_warp_base = B_sh + WARP_NOHOW * WMMA_N; + + mma acc(warp_scratch); + + const int64_t A_total = BS_OC * BS_ICKHKW; + const int64_t B_total = BS_ICKHKW * BS_NOHOW; + +#pragma unroll + for (int64_t t = 0; t < NUM_IC_TILES; ++t) { +#pragma unroll + for (int64_t tid = (threadIdx.x); tid < A_total; tid += blockDim.x) { + const int row = tid / BS_ICKHKW; + const int col = tid % BS_ICKHKW; + + int64_t shared_oc = BLOCK_OC_BASE + row; + int64_t shared_ickhkw = t * BS_ICKHKW + col; + + T val = ggml_cuda_cast(0); + if (shared_oc < P.OC && shared_ickhkw < P.IC_KH_KW) { + int64_t ic, kh, kw; + layout::unpack_ickhkw(shared_ickhkw, ic, kh, kw, P); + + const int64_t kidx = layout::kernel_index(shared_oc, ic, kh, kw, P); + val = IK[kidx]; } + A_sh[row * BS_ICKHKW + col] = val; + } + +#pragma unroll + for (int64_t tid = (threadIdx.x); tid < B_total; tid += blockDim.x) { + const int brow = tid / BS_NOHOW; + const int bcol = tid % BS_NOHOW; + + int64_t IC_KH_KW_IDX = t * BS_ICKHKW + brow; + int64_t N_OH_OW_IDX = BLOCK_NOHOW_BASE + bcol; + + T val = ggml_cuda_cast(0); + if (N_OH_OW_IDX < P.N_OH_OW && IC_KH_KW_IDX < P.IC_KH_KW) { + int64_t n, oh, ow; + layout::unpack_nohow(N_OH_OW_IDX, n, oh, ow, P); + int64_t ic, kh, kw; + layout::unpack_ickhkw(IC_KH_KW_IDX, ic, kh, kw, P); + int in_y = calculate_input_coord(oh, kh, P.ST_Y, P.DL_Y, P.PD_Y); + int in_x = calculate_input_coord(ow, kw, P.ST_X, P.DL_X, P.PD_X); + if (in_y >= 0 && in_y < P.IH && in_x >= 0 && in_x < P.IW) { + const int64_t in_idx = layout::input_index(n, ic, in_y, in_x, P); + val = ggml_cuda_cast(IN[in_idx]); + } + } + B_sh[brow * BS_NOHOW + bcol] = val; + } + + __syncthreads(); + +#pragma unroll + for (int k_tile = 0; k_tile < BS_ICKHKW; k_tile += WMMA_K) { + const T * A_k_ptr = A_warp_base + k_tile; + const T * B_k_ptr = B_warp_base + k_tile * BS_NOHOW; + + acc.mma(A_k_ptr, B_k_ptr, BS_ICKHKW, BS_NOHOW); } + __syncthreads(); } - // [N, OC, OH, OW] - output[Layout::output_index(n, c_out, out_y, out_x, P)] = acc; + const float * out_buf = acc.store_result(); +#pragma unroll + for (int e = laneId; e < OUTPUT_NUMEL; e += WARP_SIZE) { + const int m = e / WMMA_N; + const int n = e % WMMA_N; + + const int64_t oc = OC_BASE + m; + const int64_t nohow = NOHOW_BASE + n; + + if (oc < P.OC && nohow < (P.N_OH_OW)) { + int64_t n, oh, ow; + layout::unpack_nohow(nohow, n, oh, ow, P); + const int64_t out_idx = layout::output_index(n, oc, oh, ow, P); + OUT[out_idx] = out_buf[e]; + } + } } -template -static void conv2d_cuda(const float * X_D, const T * K_D, float * Y_D, const conv_params P, cudaStream_t st) { - const int blocks = (P.TOTAL + CUDA_CONV2D_BLOCK_SIZE - 1) / CUDA_CONV2D_BLOCK_SIZE; - conv2d_kernel<<>>(X_D, K_D, Y_D, P); +template +static void conv2d_cuda(const float * X_D, const T * K_D, float * Y_D, conv_params P, cudaStream_t st) + +{ + const int64_t NUM_BL_OC = (P.OC + BS_OC - 1) / 
BS_OC; + const int64_t NUM_BL_NOHOW = (P.N_OH_OW + BS_NOHOW - 1) / BS_NOHOW; + + int64_t TOTAL_TILES = NUM_BL_OC * NUM_BL_NOHOW; + TOTAL_TILES = std::min(TOTAL_TILES, (int64_t) INT_MAX); + + const int WARPS_PER_OC = std::max(1, ceil_div(BS_OC, WMMA_M)); + const int WARPS_PER_NOHOW = std::max(1, ceil_div(BS_NOHOW, WMMA_N)); + const int EXPECTED_WARPS = WARPS_PER_OC * WARPS_PER_NOHOW; + int N_THREADS = EXPECTED_WARPS * WARP_SIZE; + + const int MAX_TPB = 1024; + if (N_THREADS > MAX_TPB) { + N_THREADS = (MAX_TPB / WARP_SIZE) * WARP_SIZE; + } + + if (N_THREADS < WARP_SIZE) { + N_THREADS = WARP_SIZE; + } + + const int N_WARPS = N_THREADS / WARP_SIZE; + + // scratch_buff to store output, can't store directly using wmma, + // output mapping is unknown + const int64_t scratch_bytes = N_WARPS * (WMMA_M * WMMA_N) * sizeof(float); + + const int64_t A_bytes = BS_OC * BS_ICKHKW * sizeof(T); + const int64_t B_bytes = BS_ICKHKW * BS_NOHOW * sizeof(T); + const int64_t shared_bytes = A_bytes + B_bytes + scratch_bytes; + + dim3 grid(TOTAL_TILES, 1, 1); + conv2d_kernel<<>>(X_D, K_D, Y_D, P); } -static void conv2d_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const conv_params P, cudaStream_t st) { - conv2d_cuda(X_D, K_D, Y_D, P, st); +static void conv2d_cuda_f16(const float * X_D, const half * K_D, float * Y_D, conv_params & P, cudaStream_t st) { + conv2d_cuda(X_D, K_D, Y_D, P, st); } -static void conv2d_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const conv_params P, cudaStream_t st) { - conv2d_cuda(X_D, K_D, Y_D, P, st); +static void conv2d_cuda_f32(const float * X_D, const float * K_D, float * Y_D, conv_params & P, cudaStream_t st) { + conv2d_cuda(X_D, K_D, Y_D, P, st); } void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { @@ -155,11 +338,14 @@ void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const int OC = kernel->ne[3]; // ouptut_chanles const int B = input->ne[3]; // n_batches - const int64_t total = B * OC * OH * OW; - conv_params params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, IC, OC, B, total }; + const int64_t TOTAL = B * OC * OH * OW; + const int64_t IC_KH_KW = IC * KH * KW; + const int64_t N_OH_OW = B * OH * OW; + conv_params params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, + PD_Y, DL_X, DL_Y, IC, OC, B, TOTAL, IC_KH_KW, N_OH_OW }; if (kernel->type == GGML_TYPE_F16) { - conv2d_cuda_f16(X_D, (half *) K_D, Y_D, params, st); + conv2d_cuda_f16(X_D, (const half *) K_D, Y_D, params, st); } else { conv2d_cuda_f32(X_D, K_D, Y_D, params, st); } diff --git a/ggml/src/ggml-cuda/conv2d.cuh b/ggml/src/ggml-cuda/conv2d.cuh index ce4802c7ed797..ccf5b6192ed08 100644 --- a/ggml/src/ggml-cuda/conv2d.cuh +++ b/ggml/src/ggml-cuda/conv2d.cuh @@ -1,5 +1,12 @@ #pragma once #include "common.cuh" -#define CUDA_CONV2D_BLOCK_SIZE 256 +constexpr int BS_OC = 128; +constexpr int BS_ICKHKW = 16; +constexpr int BS_NOHOW = 128; + +constexpr int WMMA_M = 16; +constexpr int WMMA_N = 16; +constexpr int WMMA_K = 16; + void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst); From 96db6275398f6a9301c206d1ffdf302dae853a90 Mon Sep 17 00:00:00 2001 From: mnehete32 Date: Fri, 5 Sep 2025 11:50:52 +0530 Subject: [PATCH 2/8] CUDA: conv2d added comment --- ggml/src/ggml-cuda/conv2d.cuh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ggml/src/ggml-cuda/conv2d.cuh b/ggml/src/ggml-cuda/conv2d.cuh index ccf5b6192ed08..28c8b9bab6f98 100644 --- a/ggml/src/ggml-cuda/conv2d.cuh +++ b/ggml/src/ggml-cuda/conv2d.cuh @@ 
-5,6 +5,8 @@ constexpr int BS_OC = 128; constexpr int BS_ICKHKW = 16; constexpr int BS_NOHOW = 128; +// supported configuration +// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#element-types-and-matrix-sizes constexpr int WMMA_M = 16; constexpr int WMMA_N = 16; constexpr int WMMA_K = 16; From 2cd9fb0f56441fff24d67a48c60e1f9432612d31 Mon Sep 17 00:00:00 2001 From: mnehete32 Date: Fri, 5 Sep 2025 13:31:01 +0530 Subject: [PATCH 3/8] CUDA: conv2d support fp16 without wmma * removed flash-attenion definition --- ggml/src/ggml-cuda/conv2d.cu | 54 ++++++++++++++++++++++++++++++++++-- 1 file changed, 51 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d.cu b/ggml/src/ggml-cuda/conv2d.cu index 4914393acab2f..9802883f5aeda 100644 --- a/ggml/src/ggml-cuda/conv2d.cu +++ b/ggml/src/ggml-cuda/conv2d.cu @@ -1,9 +1,19 @@ #include "conv2d.cuh" #include "convert.cuh" -#include -using namespace nvcuda; - +#ifdef FP16_MMA_AVAILABLE +# if !defined(GGML_USE_HIP) +# include +# ifdef GGML_USE_MUSA +namespace wmma = mtmusa::wmma; +# else +namespace wmma = nvcuda::wmma; +# endif +# else +# include +namespace wmma = rocwmma; +# endif +#endif struct conv_params { const int64_t IW, IH; const int64_t OW, OH; @@ -111,6 +121,8 @@ class float_mma { __device__ __forceinline__ float * store_result() const { return buf; } }; +#if (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(FP16_MMA_AVAILABLE))) + class half_mma { private: wmma::fragment acc; @@ -136,6 +148,42 @@ class half_mma { } }; +#else + +class half_mma { + public: + float * buf; + + __device__ __forceinline__ half_mma(float * scratch) { + buf = scratch; + const int lane_id = threadIdx.x % warpSize; +# pragma unroll + for (int i = lane_id; i < WMMA_M * WMMA_N; i += warpSize) { + buf[i] = 0.0f; + } + } + + __device__ __forceinline__ void mma(const half * A_sh, const half * B_sh, const int strideA, const int strideB) { + const int lane_id = threadIdx.x % warpSize; +# pragma unroll + for (int e = lane_id; e < (WMMA_M * WMMA_N); e += warpSize) { + int m = e / WMMA_N; + int n = e % WMMA_N; + float sum = buf[m * WMMA_N + n]; +# pragma unroll + for (int k = 0; k < WMMA_K; k++) { + float a = A_sh[m * strideA + k]; + float b = B_sh[k * strideB + n]; + sum = fmaf(__half2float(a), __half2float(b), sum); + } + buf[m * WMMA_N + n] = sum; + } + } + + __device__ __forceinline__ float * store_result() const { return buf; } +}; +#endif // defined((__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || defined(FP16_MMA_AVAILABLE)) + template static __global__ void conv2d_kernel(const float * IN, const T * IK, float * OUT, const conv_params P) { extern __shared__ unsigned char smem_raw[]; From d633cee19ced956a088ca3313dbc31b8698a0e16 Mon Sep 17 00:00:00 2001 From: mnehete32 Date: Fri, 12 Sep 2025 09:52:02 +0530 Subject: [PATCH 4/8] CUDA: conv2d using mma.cuh --- ggml/src/ggml-cuda/conv2d.cu | 381 ++++++++++++++++++++-------------- ggml/src/ggml-cuda/conv2d.cuh | 14 +- 2 files changed, 230 insertions(+), 165 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d.cu b/ggml/src/ggml-cuda/conv2d.cu index 9802883f5aeda..99799ac6db6f8 100644 --- a/ggml/src/ggml-cuda/conv2d.cu +++ b/ggml/src/ggml-cuda/conv2d.cu @@ -1,19 +1,6 @@ #include "conv2d.cuh" #include "convert.cuh" -#ifdef FP16_MMA_AVAILABLE -# if !defined(GGML_USE_HIP) -# include -# ifdef GGML_USE_MUSA -namespace wmma = mtmusa::wmma; -# else -namespace wmma = nvcuda::wmma; -# endif -# else -# include -namespace wmma = rocwmma; -# endif -#endif struct conv_params { const int64_t IW, IH; const int64_t OW, OH; @@ 
-28,10 +15,6 @@ struct conv_params { const int64_t IC_KH_KW, N_OH_OW; }; -auto ceil_div = [](int a, int b) { - return (a + b - 1) / b; -}; - __device__ __forceinline__ static int calculate_input_coord(int64_t out_coord, int64_t kern_coord, int64_t stride, @@ -88,151 +71,227 @@ struct whcn_layout { } }; -class float_mma { +template class float_mma { public: - float * buf; + static constexpr int num_acc = (WMMA_M * WMMA_N + WARP_SIZE - 1) / WARP_SIZE; - __device__ __forceinline__ float_mma(float * scratch) { - buf = scratch; - const int lane_id = threadIdx.x % warpSize; + float acc[num_acc]; + + __device__ __forceinline__ float_mma() { #pragma unroll - for (int i = lane_id; i < WMMA_M * WMMA_N; i += warpSize) { - buf[i] = 0.0f; + for (int i = 0; i < num_acc; i++) { + acc[i] = 0.0f; } } - __device__ __forceinline__ void mma(const float * A_sh, const float * B_sh, const int strideA, const int strideB) { - const int lane_id = threadIdx.x % warpSize; + __device__ __forceinline__ void clear() { #pragma unroll - for (int e = lane_id; e < (WMMA_M * WMMA_N); e += warpSize) { - int m = e / WMMA_N; - int n = e % WMMA_N; - float sum = buf[m * WMMA_N + n]; + for (int i = 0; i < num_acc; i++) { + acc[i] = 0.0f; + } + } + + __device__ __forceinline__ void mma(const float * __restrict__ A_sh, + const float * __restrict__ B_sh, + const int strideA, + const int strideB) { + const int lane_id = threadIdx.x % WARP_SIZE; + +#pragma unroll + for (int e = lane_id, i = 0; e < WMMA_M * WMMA_N; e += WARP_SIZE, i++) { + const int m = e / WMMA_N; + const int n = e % WMMA_N; + #pragma unroll for (int k = 0; k < WMMA_K; k++) { - float a = A_sh[m * strideA + k]; - float b = B_sh[k * strideB + n]; - sum = fmaf(a, b, sum); + const float a = A_sh[m * strideA + k]; + const float b = B_sh[k * strideB + n]; + acc[i] = fmaf(a, b, acc[i]); } - buf[m * WMMA_N + n] = sum; } } - __device__ __forceinline__ float * store_result() const { return buf; } + __device__ __forceinline__ void store_result(const int64_t OC_BASE, + const int64_t NOHOW_BASE, + float * __restrict__ OUT, + const conv_params & P) const { + const int lane_id = threadIdx.x % WARP_SIZE; + +#pragma unroll + for (int e = lane_id, i = 0; e < WMMA_M * WMMA_N; e += WARP_SIZE, i++) { + const int m = e / WMMA_N; + const int n = e % WMMA_N; + + const int64_t oc = OC_BASE + m; + const int64_t nohow = NOHOW_BASE + n; + + if (oc < P.OC && nohow < P.N_OH_OW) { + int64_t n_, oh, ow; + layout::unpack_nohow(nohow, n_, oh, ow, P); + const int64_t out_idx = layout::output_index(n_, oc, oh, ow, P); + OUT[out_idx] = acc[i]; + } + } + } }; #if (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(FP16_MMA_AVAILABLE))) +# include "mma.cuh" +using namespace ggml_cuda_mma; + +typedef ggml_cuda_mma::tile tile_a; +typedef ggml_cuda_mma::tile tile_b; +typedef ggml_cuda_mma::tile tile_acc; -class half_mma { +template class half_mma { private: - wmma::fragment acc; - wmma::fragment a_frag; - wmma::fragment b_frag; + tile_a a_frag; + tile_b b_frag; + tile_acc c_frag; public: - float * buf; + __device__ __forceinline__ half_mma() {} - __device__ __forceinline__ half_mma(float * scratch) { - buf = scratch; - wmma::fill_fragment(acc, 0.0f); + __device__ __forceinline__ void clear() { +# pragma unroll + for (int l = 0; l < c_frag.ne; ++l) { + c_frag.x[l] = 0.0f; + } } __device__ __forceinline__ void mma(const half * A_sh, const half * B_sh, const int strideA, const int strideB) { - wmma::load_matrix_sync(a_frag, A_sh, strideA); - wmma::load_matrix_sync(b_frag, B_sh, strideB); - wmma::mma_sync(acc, 
a_frag, b_frag, acc); + ggml_cuda_mma::load_ldmatrix(a_frag, (const half2 *) A_sh, strideA / 2); + ggml_cuda_mma::load_ldmatrix_trans(b_frag, (const half2 *) B_sh, strideB / 2); + ggml_cuda_mma::mma(c_frag, a_frag, b_frag); } - __device__ __forceinline__ float * store_result() const { - wmma::store_matrix_sync(buf, acc, WMMA_N, wmma::mem_row_major); - return buf; + __device__ __forceinline__ void store_result(const int64_t OC_BASE, + const int64_t NOHOW_BASE, + float * OUT, + const conv_params & P) const { +# pragma unroll + for (int l = 0; l < tile_acc::ne; ++l) { + const int64_t e = tile_acc::get_i(l) * WMMA_N + tile_acc::get_j(l); + const int m = e / WMMA_N; + const int n = e % WMMA_N; + + const int64_t oc = OC_BASE + m; + const int64_t nohow = NOHOW_BASE + n; + + if (oc < P.OC && nohow < (P.N_OH_OW)) { + int64_t n, oh, ow; + layout::unpack_nohow(nohow, n, oh, ow, P); + OUT[layout::output_index(n, oc, oh, ow, P)] = c_frag.x[l]; + } + } } }; #else -class half_mma { +template class half_mma { public: - float * buf; + static constexpr int num_acc = (WMMA_M * WMMA_N + WARP_SIZE - 1) / WARP_SIZE; + + float acc[num_acc]; - __device__ __forceinline__ half_mma(float * scratch) { - buf = scratch; - const int lane_id = threadIdx.x % warpSize; + __device__ __forceinline__ half_mma() { # pragma unroll - for (int i = lane_id; i < WMMA_M * WMMA_N; i += warpSize) { - buf[i] = 0.0f; + for (int i = 0; i < num_acc; i++) { + acc[i] = 0.0f; } } - __device__ __forceinline__ void mma(const half * A_sh, const half * B_sh, const int strideA, const int strideB) { - const int lane_id = threadIdx.x % warpSize; + __device__ __forceinline__ void clear() { +# pragma unroll + for (int i = 0; i < num_acc; i++) { + acc[i] = 0.0f; + } + } + + __device__ __forceinline__ void mma(const half * __restrict__ A_sh, + const half * __restrict__ B_sh, + const int strideA, + const int strideB) { + const int lane_id = threadIdx.x % WARP_SIZE; + # pragma unroll - for (int e = lane_id; e < (WMMA_M * WMMA_N); e += warpSize) { - int m = e / WMMA_N; - int n = e % WMMA_N; - float sum = buf[m * WMMA_N + n]; + for (int e = lane_id, i = 0; e < WMMA_M * WMMA_N; e += WARP_SIZE, i++) { + const int m = e / WMMA_N; + const int n = e % WMMA_N; + # pragma unroll for (int k = 0; k < WMMA_K; k++) { - float a = A_sh[m * strideA + k]; - float b = B_sh[k * strideB + n]; - sum = fmaf(__half2float(a), __half2float(b), sum); + const half a = A_sh[m * strideA + k]; + const half b = B_sh[k * strideB + n]; + acc[i] = fmaf(__half2float(a), __half2float(b), acc[i]); } - buf[m * WMMA_N + n] = sum; } } - __device__ __forceinline__ float * store_result() const { return buf; } + __device__ __forceinline__ void store_result(const int64_t OC_BASE, + const int64_t NOHOW_BASE, + float * __restrict__ OUT, + const conv_params & P) const { + const int lane_id = threadIdx.x % WARP_SIZE; + +# pragma unroll + for (int e = lane_id, i = 0; e < WMMA_M * WMMA_N; e += WARP_SIZE, i++) { + const int m = e / WMMA_N; + const int n = e % WMMA_N; + + const int64_t oc = OC_BASE + m; + const int64_t nohow = NOHOW_BASE + n; + + if (oc < P.OC && nohow < P.N_OH_OW) { + int64_t n_, oh, ow; + layout::unpack_nohow(nohow, n_, oh, ow, P); + const int64_t out_idx = layout::output_index(n_, oc, oh, ow, P); + OUT[out_idx] = acc[i]; + } + } + } }; + #endif // defined((__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || defined(FP16_MMA_AVAILABLE)) -template -static __global__ void conv2d_kernel(const float * IN, const T * IK, float * OUT, const conv_params P) { +template +__global__ void conv2d_kernel(const 
float * IN, const T * IK, float * Out, const conv_params P) { extern __shared__ unsigned char smem_raw[]; - const int64_t OUTPUT_NUMEL = WMMA_M * WMMA_N; const int64_t NUM_IC_TILES = (P.IC_KH_KW + BS_ICKHKW - 1) / BS_ICKHKW; + const int64_t warpId = threadIdx.y; - const int64_t WARPS_PER_NOHOW = max(1, BS_NOHOW / WMMA_N); + const int64_t WARPS_PER_NOHOW = max(1, BS_NOHOW / WMMA_N); + const int64_t total_warps_need = (((BS_OC * BS_NOHOW) + (WMMA_M * WMMA_N) - 1) / (WMMA_M * WMMA_N)); + const int64_t num_work_per_warps = (total_warps_need + num_warps - 1) / num_warps; - const int64_t NUM_BL_NOHOW = (P.N_OH_OW + BS_NOHOW - 1) / BS_NOHOW; - const int64_t tile_id = blockIdx.x; - const int64_t tile_oc = tile_id / NUM_BL_NOHOW; - const int64_t tile_nohow = tile_id % NUM_BL_NOHOW; - const int64_t BLOCK_OC_BASE = tile_oc * BS_OC; - const int64_t BLOCK_NOHOW_BASE = tile_nohow * BS_NOHOW; + mma acc[num_work_per_warps]; - const int64_t laneId = threadIdx.x % WARP_SIZE; - const int64_t warpId = threadIdx.x / WARP_SIZE; + const int64_t num_block_nohow = (P.N_OH_OW + BS_NOHOW - 1) / BS_NOHOW; + const int64_t BL_IDX_OC = blockIdx.x / num_block_nohow; + const int64_t BL_IDX_NOHOW = blockIdx.x % num_block_nohow; - const int64_t WARP_OC = warpId / WARPS_PER_NOHOW; - const int64_t WARP_NOHOW = warpId % WARPS_PER_NOHOW; + const int64_t BLOCK_OC_BASE = BL_IDX_OC * BS_OC; + const int64_t BLOCK_NOHOW_BASE = BL_IDX_NOHOW * BS_NOHOW; - const int64_t OC_BASE = BLOCK_OC_BASE + WARP_OC * WMMA_M; - const int64_t NOHOW_BASE = BLOCK_NOHOW_BASE + WARP_NOHOW * WMMA_N; + unsigned char * ptr = smem_raw; - unsigned char * ptr = smem_raw; - T * A_sh = reinterpret_cast(ptr); + const int64_t A_total = BS_OC * BS_ICKHKW; + const int64_t B_total = BS_ICKHKW * BS_NOHOW; - size_t offsetA = BS_OC * BS_ICKHKW * sizeof(T); + size_t offsetA = (size_t) A_total * sizeof(T); + T * A_sh = reinterpret_cast(ptr); ptr += offsetA; - T * B_sh = reinterpret_cast(ptr); - ptr += BS_ICKHKW * BS_NOHOW * sizeof(T); - - float * shared_scratch = reinterpret_cast(ptr); - float * warp_scratch = shared_scratch + warpId * (WMMA_M * WMMA_N); - - const T * A_warp_base = A_sh + WARP_OC * WMMA_M * BS_ICKHKW; - const T * B_warp_base = B_sh + WARP_NOHOW * WMMA_N; + size_t offsetB = (size_t) B_total * sizeof(T); + T * B_sh = reinterpret_cast(ptr); + ptr += offsetB; - mma acc(warp_scratch); - - const int64_t A_total = BS_OC * BS_ICKHKW; - const int64_t B_total = BS_ICKHKW * BS_NOHOW; - -#pragma unroll + int64_t ic, kh, kw; + int64_t n, oh, ow; for (int64_t t = 0; t < NUM_IC_TILES; ++t) { #pragma unroll - for (int64_t tid = (threadIdx.x); tid < A_total; tid += blockDim.x) { + for (int64_t tid = (threadIdx.y * blockDim.x + threadIdx.x); tid < A_total; tid += (blockDim.x * blockDim.y)) { const int row = tid / BS_ICKHKW; const int col = tid % BS_ICKHKW; @@ -241,7 +300,6 @@ static __global__ void conv2d_kernel(const float * IN, const T * IK, float * OUT T val = ggml_cuda_cast(0); if (shared_oc < P.OC && shared_ickhkw < P.IC_KH_KW) { - int64_t ic, kh, kw; layout::unpack_ickhkw(shared_ickhkw, ic, kh, kw, P); const int64_t kidx = layout::kernel_index(shared_oc, ic, kh, kw, P); @@ -249,9 +307,8 @@ static __global__ void conv2d_kernel(const float * IN, const T * IK, float * OUT } A_sh[row * BS_ICKHKW + col] = val; } - #pragma unroll - for (int64_t tid = (threadIdx.x); tid < B_total; tid += blockDim.x) { + for (int64_t tid = (threadIdx.y * blockDim.x + threadIdx.x); tid < B_total; tid += (blockDim.x * blockDim.y)) { const int brow = tid / BS_NOHOW; const int bcol = tid % 
BS_NOHOW; @@ -260,9 +317,7 @@ static __global__ void conv2d_kernel(const float * IN, const T * IK, float * OUT T val = ggml_cuda_cast(0); if (N_OH_OW_IDX < P.N_OH_OW && IC_KH_KW_IDX < P.IC_KH_KW) { - int64_t n, oh, ow; layout::unpack_nohow(N_OH_OW_IDX, n, oh, ow, P); - int64_t ic, kh, kw; layout::unpack_ickhkw(IC_KH_KW_IDX, ic, kh, kw, P); int in_y = calculate_input_coord(oh, kh, P.ST_Y, P.DL_Y, P.PD_Y); int in_x = calculate_input_coord(ow, kw, P.ST_X, P.DL_X, P.PD_X); @@ -277,76 +332,88 @@ static __global__ void conv2d_kernel(const float * IN, const T * IK, float * OUT __syncthreads(); #pragma unroll - for (int k_tile = 0; k_tile < BS_ICKHKW; k_tile += WMMA_K) { - const T * A_k_ptr = A_warp_base + k_tile; - const T * B_k_ptr = B_warp_base + k_tile * BS_NOHOW; - - acc.mma(A_k_ptr, B_k_ptr, BS_ICKHKW, BS_NOHOW); + for (int warp = warpId, i = 0; warp < total_warps_need; warp += num_warps, i++) { + const int64_t WARP_OC = warp / WARPS_PER_NOHOW; + const int64_t WARP_NOHOW = warp % WARPS_PER_NOHOW; + const T * A_warp_base = A_sh + WARP_OC * WMMA_M * BS_ICKHKW; + const T * B_warp_base = B_sh + WARP_NOHOW * WMMA_N; +#pragma unroll + for (int k_tile = 0; k_tile < BS_ICKHKW; k_tile += WMMA_K) { + const T * A_k_ptr = A_warp_base + k_tile; + const T * B_k_ptr = B_warp_base + k_tile * BS_NOHOW; + acc[i].mma(A_k_ptr, B_k_ptr, BS_ICKHKW, BS_NOHOW); + } } __syncthreads(); } - const float * out_buf = acc.store_result(); #pragma unroll - for (int e = laneId; e < OUTPUT_NUMEL; e += WARP_SIZE) { - const int m = e / WMMA_N; - const int n = e % WMMA_N; - - const int64_t oc = OC_BASE + m; - const int64_t nohow = NOHOW_BASE + n; - - if (oc < P.OC && nohow < (P.N_OH_OW)) { - int64_t n, oh, ow; - layout::unpack_nohow(nohow, n, oh, ow, P); - const int64_t out_idx = layout::output_index(n, oc, oh, ow, P); - OUT[out_idx] = out_buf[e]; - } + for (int warp = warpId, i = 0; warp < total_warps_need; warp += num_warps, i++) { + const int64_t WARP_OC = warp / WARPS_PER_NOHOW; + const int64_t WARP_NOHOW = warp % WARPS_PER_NOHOW; + const int64_t OC_BASE = BLOCK_OC_BASE + WARP_OC * WMMA_M; + const int64_t NOHOW_BASE = BLOCK_NOHOW_BASE + WARP_NOHOW * WMMA_N; + acc[i].store_result(OC_BASE, NOHOW_BASE, Out, P); } } -template -static void conv2d_cuda(const float * X_D, const T * K_D, float * Y_D, conv_params P, cudaStream_t st) - -{ - const int64_t NUM_BL_OC = (P.OC + BS_OC - 1) / BS_OC; - const int64_t NUM_BL_NOHOW = (P.N_OH_OW + BS_NOHOW - 1) / BS_NOHOW; +template class mma> +static void conv2d_cuda(const float * X_D, const T * K_D, float * Y_D, const conv_params P, cudaStream_t st) { + const int warp_size = 32; + const int max_block_size = 256; - int64_t TOTAL_TILES = NUM_BL_OC * NUM_BL_NOHOW; - TOTAL_TILES = std::min(TOTAL_TILES, (int64_t) INT_MAX); + GGML_ASSERT(BS_OC >= WMMA_M && BS_ICKHKW >= WMMA_K && BS_NOHOW >= WMMA_N); - const int WARPS_PER_OC = std::max(1, ceil_div(BS_OC, WMMA_M)); - const int WARPS_PER_NOHOW = std::max(1, ceil_div(BS_NOHOW, WMMA_N)); - const int EXPECTED_WARPS = WARPS_PER_OC * WARPS_PER_NOHOW; - int N_THREADS = EXPECTED_WARPS * WARP_SIZE; + const int num_block_oc = (P.OC + BS_OC - 1) / BS_OC; + const int num_block_nohow = (P.N_OH_OW + BS_NOHOW - 1) / BS_NOHOW; + const int num_blocks = num_block_oc * num_block_nohow; - const int MAX_TPB = 1024; - if (N_THREADS > MAX_TPB) { - N_THREADS = (MAX_TPB / WARP_SIZE) * WARP_SIZE; + int nwarps_best = 1; + int niter_best = (BS_OC * BS_NOHOW + warp_size - 1) / (warp_size); + for (int nwarps = 2; nwarps <= max_block_size / warp_size; ++nwarps) { + const int niter 
= (BS_OC * BS_NOHOW + nwarps * warp_size - 1) / (nwarps * warp_size); + if (niter < niter_best) { + niter_best = niter; + nwarps_best = nwarps; + } } - if (N_THREADS < WARP_SIZE) { - N_THREADS = WARP_SIZE; + const size_t A_bytes = BS_OC * BS_ICKHKW * sizeof(T); + const size_t B_bytes = BS_ICKHKW * BS_NOHOW * sizeof(T); + const size_t shared_bytes = A_bytes + B_bytes; + + dim3 grid(num_blocks, 1, 1); + dim3 block(warp_size, nwarps_best); + + switch (nwarps_best) { + case 1: + conv2d_kernel, 1><<>>(X_D, K_D, Y_D, P); + break; + case 2: + conv2d_kernel, 2><<>>(X_D, K_D, Y_D, P); + break; + case 4: + conv2d_kernel, 4><<>>(X_D, K_D, Y_D, P); + break; + case 8: + conv2d_kernel, 8><<>>(X_D, K_D, Y_D, P); + break; + case 16: + conv2d_kernel, 16><<>>(X_D, K_D, Y_D, P); + break; + case 32: + conv2d_kernel, 32><<>>(X_D, K_D, Y_D, P); + break; + default: + GGML_ABORT("UNSUPPROTED NWARPS_BEST"); } - - const int N_WARPS = N_THREADS / WARP_SIZE; - - // scratch_buff to store output, can't store directly using wmma, - // output mapping is unknown - const int64_t scratch_bytes = N_WARPS * (WMMA_M * WMMA_N) * sizeof(float); - - const int64_t A_bytes = BS_OC * BS_ICKHKW * sizeof(T); - const int64_t B_bytes = BS_ICKHKW * BS_NOHOW * sizeof(T); - const int64_t shared_bytes = A_bytes + B_bytes + scratch_bytes; - - dim3 grid(TOTAL_TILES, 1, 1); - conv2d_kernel<<>>(X_D, K_D, Y_D, P); } -static void conv2d_cuda_f16(const float * X_D, const half * K_D, float * Y_D, conv_params & P, cudaStream_t st) { +static void conv2d_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const conv_params & P, cudaStream_t st) { conv2d_cuda(X_D, K_D, Y_D, P, st); } -static void conv2d_cuda_f32(const float * X_D, const float * K_D, float * Y_D, conv_params & P, cudaStream_t st) { +static void conv2d_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const conv_params & P, cudaStream_t st) { conv2d_cuda(X_D, K_D, Y_D, P, st); } diff --git a/ggml/src/ggml-cuda/conv2d.cuh b/ggml/src/ggml-cuda/conv2d.cuh index 28c8b9bab6f98..a1de712b54a66 100644 --- a/ggml/src/ggml-cuda/conv2d.cuh +++ b/ggml/src/ggml-cuda/conv2d.cuh @@ -1,14 +1,12 @@ #pragma once #include "common.cuh" -constexpr int BS_OC = 128; -constexpr int BS_ICKHKW = 16; -constexpr int BS_NOHOW = 128; +#define BS_OC 64 +#define BS_ICKHKW 16 +#define BS_NOHOW 64 -// supported configuration -// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#element-types-and-matrix-sizes -constexpr int WMMA_M = 16; -constexpr int WMMA_N = 16; -constexpr int WMMA_K = 16; +#define WMMA_M 16 +#define WMMA_N 16 +#define WMMA_K 16 void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst); From ac5e0c023c5ff96c9946f82ef33204bcf6c3a45e Mon Sep 17 00:00:00 2001 From: mnehete32 Date: Fri, 12 Sep 2025 13:30:04 +0530 Subject: [PATCH 5/8] CUDA: conv2d convert int64_t to int --- ggml/src/ggml-cuda/conv2d.cu | 173 ++++++++++++++++------------------- 1 file changed, 81 insertions(+), 92 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d.cu b/ggml/src/ggml-cuda/conv2d.cu index 99799ac6db6f8..f9cdd7786069d 100644 --- a/ggml/src/ggml-cuda/conv2d.cu +++ b/ggml/src/ggml-cuda/conv2d.cu @@ -2,56 +2,44 @@ #include "convert.cuh" struct conv_params { - const int64_t IW, IH; - const int64_t OW, OH; - const int64_t KW, KH; - const int64_t ST_X, ST_Y; - const int64_t PD_X, PD_Y; - const int64_t DL_X, DL_Y; - const int64_t IC, OC; - const int64_t B; + const int IW, IH; + const int OW, OH; + const int KW, KH; + const int ST_X, ST_Y; + const int PD_X, PD_Y; + const int DL_X, 
DL_Y; + const int IC, OC; + const int B; const int64_t TOTAL; // helpers - const int64_t IC_KH_KW, N_OH_OW; + const int IC_KH_KW, N_OH_OW; }; -__device__ __forceinline__ static int calculate_input_coord(int64_t out_coord, - int64_t kern_coord, - int64_t stride, - int64_t dilation, - int64_t padding) { +__device__ __forceinline__ static int calculate_input_coord(int out_coord, + int kern_coord, + int stride, + int dilation, + int padding) { return out_coord * stride + kern_coord * dilation - padding; } struct whcn_layout { - __device__ __forceinline__ static int64_t input_index(int64_t n, - int64_t c, - int64_t y, - int64_t x, - const conv_params & P) { + __device__ __forceinline__ static int64_t input_index(int n, int c, int y, int x, const conv_params & P) { return n * (P.IC * P.IW * P.IH) + c * P.IW * P.IH + y * P.IW + x; } - __device__ __forceinline__ static int64_t kernel_index(int64_t c_out, - int64_t c_in, - int64_t ky, - int64_t kx, - const conv_params & P) { + __device__ __forceinline__ static int64_t kernel_index(int c_out, int c_in, int ky, int kx, const conv_params & P) { return c_out * (P.IC * P.KH * P.KW) + c_in * (P.KH * P.KW) + ky * P.KW + kx; } - __device__ __forceinline__ static int64_t output_index(int64_t n, - int64_t c, - int64_t y, - int64_t x, - const conv_params & P) { + __device__ __forceinline__ static int64_t output_index(int n, int c, int y, int x, const conv_params & P) { return n * (P.OC * P.OW * P.OH) + c * P.OW * P.OH + y * P.OW + x; } __device__ __forceinline__ static void unpack_ickhkw(int64_t idx, - int64_t & ic, - int64_t & kh, - int64_t & kw, + int & ic, + int & kh, + int & kw, const conv_params & P) { ic = idx / (P.KW * P.KH); int64_t r = idx - ic * (P.KW * P.KH); @@ -60,9 +48,9 @@ struct whcn_layout { } __device__ __forceinline__ static void unpack_nohow(int64_t idx, - int64_t & n, - int64_t & oh, - int64_t & ow, + int & n, + int & oh, + int & ow, const conv_params & P) { n = idx / (P.OH * P.OW); int64_t r = idx - n * (P.OH * P.OW); @@ -111,8 +99,8 @@ template class float_mma { } } - __device__ __forceinline__ void store_result(const int64_t OC_BASE, - const int64_t NOHOW_BASE, + __device__ __forceinline__ void store_result(const int OC_BASE, + const int NOHOW_BASE, float * __restrict__ OUT, const conv_params & P) const { const int lane_id = threadIdx.x % WARP_SIZE; @@ -122,14 +110,13 @@ template class float_mma { const int m = e / WMMA_N; const int n = e % WMMA_N; - const int64_t oc = OC_BASE + m; - const int64_t nohow = NOHOW_BASE + n; + const int oc = OC_BASE + m; + const int nohow = NOHOW_BASE + n; if (oc < P.OC && nohow < P.N_OH_OW) { - int64_t n_, oh, ow; + int n_, oh, ow; layout::unpack_nohow(nohow, n_, oh, ow, P); - const int64_t out_idx = layout::output_index(n_, oc, oh, ow, P); - OUT[out_idx] = acc[i]; + OUT[layout::output_index(n_, oc, oh, ow, P)] = acc[i]; } } } @@ -158,27 +145,30 @@ template class half_mma { } } - __device__ __forceinline__ void mma(const half * A_sh, const half * B_sh, const int strideA, const int strideB) { + __device__ __forceinline__ void mma(const half * __restrict__ A_sh, + const half * __restrict__ B_sh, + const int strideA, + const int strideB) { ggml_cuda_mma::load_ldmatrix(a_frag, (const half2 *) A_sh, strideA / 2); ggml_cuda_mma::load_ldmatrix_trans(b_frag, (const half2 *) B_sh, strideB / 2); ggml_cuda_mma::mma(c_frag, a_frag, b_frag); } - __device__ __forceinline__ void store_result(const int64_t OC_BASE, - const int64_t NOHOW_BASE, - float * OUT, + __device__ __forceinline__ void store_result(const int 
OC_BASE, + const int NOHOW_BASE, + float * __restrict__ OUT, const conv_params & P) const { # pragma unroll for (int l = 0; l < tile_acc::ne; ++l) { - const int64_t e = tile_acc::get_i(l) * WMMA_N + tile_acc::get_j(l); - const int m = e / WMMA_N; - const int n = e % WMMA_N; + const int e = tile_acc::get_i(l) * WMMA_N + tile_acc::get_j(l); + const int m = e / WMMA_N; + const int n = e % WMMA_N; - const int64_t oc = OC_BASE + m; - const int64_t nohow = NOHOW_BASE + n; + const int oc = OC_BASE + m; + const int nohow = NOHOW_BASE + n; if (oc < P.OC && nohow < (P.N_OH_OW)) { - int64_t n, oh, ow; + int n, oh, ow; layout::unpack_nohow(nohow, n, oh, ow, P); OUT[layout::output_index(n, oc, oh, ow, P)] = c_frag.x[l]; } @@ -228,8 +218,8 @@ template class half_mma { } } - __device__ __forceinline__ void store_result(const int64_t OC_BASE, - const int64_t NOHOW_BASE, + __device__ __forceinline__ void store_result(const int OC_BASE, + const int NOHOW_BASE, float * __restrict__ OUT, const conv_params & P) const { const int lane_id = threadIdx.x % WARP_SIZE; @@ -239,14 +229,13 @@ template class half_mma { const int m = e / WMMA_N; const int n = e % WMMA_N; - const int64_t oc = OC_BASE + m; - const int64_t nohow = NOHOW_BASE + n; + const int oc = OC_BASE + m; + const int nohow = NOHOW_BASE + n; if (oc < P.OC && nohow < P.N_OH_OW) { - int64_t n_, oh, ow; + int n_, oh, ow; layout::unpack_nohow(nohow, n_, oh, ow, P); - const int64_t out_idx = layout::output_index(n_, oc, oh, ow, P); - OUT[out_idx] = acc[i]; + OUT[layout::output_index(n_, oc, oh, ow, P)] = acc[i]; } } } @@ -258,26 +247,26 @@ template __global__ void conv2d_kernel(const float * IN, const T * IK, float * Out, const conv_params P) { extern __shared__ unsigned char smem_raw[]; - const int64_t NUM_IC_TILES = (P.IC_KH_KW + BS_ICKHKW - 1) / BS_ICKHKW; - const int64_t warpId = threadIdx.y; + const int NUM_IC_TILES = (P.IC_KH_KW + BS_ICKHKW - 1) / BS_ICKHKW; + const int warpId = threadIdx.y; - const int64_t WARPS_PER_NOHOW = max(1, BS_NOHOW / WMMA_N); - const int64_t total_warps_need = (((BS_OC * BS_NOHOW) + (WMMA_M * WMMA_N) - 1) / (WMMA_M * WMMA_N)); - const int64_t num_work_per_warps = (total_warps_need + num_warps - 1) / num_warps; + const int WARPS_PER_NOHOW = max(1, BS_NOHOW / WMMA_N); + const int total_warps_need = (((BS_OC * BS_NOHOW) + (WMMA_M * WMMA_N) - 1) / (WMMA_M * WMMA_N)); + const int num_work_per_warps = (total_warps_need + num_warps - 1) / num_warps; mma acc[num_work_per_warps]; - const int64_t num_block_nohow = (P.N_OH_OW + BS_NOHOW - 1) / BS_NOHOW; - const int64_t BL_IDX_OC = blockIdx.x / num_block_nohow; - const int64_t BL_IDX_NOHOW = blockIdx.x % num_block_nohow; + const int num_block_nohow = (P.N_OH_OW + BS_NOHOW - 1) / BS_NOHOW; + const int BL_IDX_OC = blockIdx.x / num_block_nohow; + const int BL_IDX_NOHOW = blockIdx.x % num_block_nohow; - const int64_t BLOCK_OC_BASE = BL_IDX_OC * BS_OC; - const int64_t BLOCK_NOHOW_BASE = BL_IDX_NOHOW * BS_NOHOW; + const int BLOCK_OC_BASE = BL_IDX_OC * BS_OC; + const int BLOCK_NOHOW_BASE = BL_IDX_NOHOW * BS_NOHOW; unsigned char * ptr = smem_raw; - const int64_t A_total = BS_OC * BS_ICKHKW; - const int64_t B_total = BS_ICKHKW * BS_NOHOW; + const int A_total = BS_OC * BS_ICKHKW; + const int B_total = BS_ICKHKW * BS_NOHOW; size_t offsetA = (size_t) A_total * sizeof(T); T * A_sh = reinterpret_cast(ptr); @@ -287,33 +276,33 @@ __global__ void conv2d_kernel(const float * IN, const T * IK, float * Out, const T * B_sh = reinterpret_cast(ptr); ptr += offsetB; - int64_t ic, kh, kw; - int64_t n, oh, ow; - 
for (int64_t t = 0; t < NUM_IC_TILES; ++t) { + int ic, kh, kw; + int n, oh, ow; + for (int t = 0; t < NUM_IC_TILES; ++t) { #pragma unroll - for (int64_t tid = (threadIdx.y * blockDim.x + threadIdx.x); tid < A_total; tid += (blockDim.x * blockDim.y)) { + for (int tid = (threadIdx.y * blockDim.x + threadIdx.x); tid < A_total; tid += (blockDim.x * blockDim.y)) { const int row = tid / BS_ICKHKW; const int col = tid % BS_ICKHKW; - int64_t shared_oc = BLOCK_OC_BASE + row; - int64_t shared_ickhkw = t * BS_ICKHKW + col; + int shared_oc = BLOCK_OC_BASE + row; + int shared_ickhkw = t * BS_ICKHKW + col; T val = ggml_cuda_cast(0); if (shared_oc < P.OC && shared_ickhkw < P.IC_KH_KW) { layout::unpack_ickhkw(shared_ickhkw, ic, kh, kw, P); - const int64_t kidx = layout::kernel_index(shared_oc, ic, kh, kw, P); - val = IK[kidx]; + const int kidx = layout::kernel_index(shared_oc, ic, kh, kw, P); + val = IK[kidx]; } A_sh[row * BS_ICKHKW + col] = val; } #pragma unroll - for (int64_t tid = (threadIdx.y * blockDim.x + threadIdx.x); tid < B_total; tid += (blockDim.x * blockDim.y)) { + for (int tid = (threadIdx.y * blockDim.x + threadIdx.x); tid < B_total; tid += (blockDim.x * blockDim.y)) { const int brow = tid / BS_NOHOW; const int bcol = tid % BS_NOHOW; - int64_t IC_KH_KW_IDX = t * BS_ICKHKW + brow; - int64_t N_OH_OW_IDX = BLOCK_NOHOW_BASE + bcol; + int IC_KH_KW_IDX = t * BS_ICKHKW + brow; + int N_OH_OW_IDX = BLOCK_NOHOW_BASE + bcol; T val = ggml_cuda_cast(0); if (N_OH_OW_IDX < P.N_OH_OW && IC_KH_KW_IDX < P.IC_KH_KW) { @@ -333,10 +322,10 @@ __global__ void conv2d_kernel(const float * IN, const T * IK, float * Out, const #pragma unroll for (int warp = warpId, i = 0; warp < total_warps_need; warp += num_warps, i++) { - const int64_t WARP_OC = warp / WARPS_PER_NOHOW; - const int64_t WARP_NOHOW = warp % WARPS_PER_NOHOW; - const T * A_warp_base = A_sh + WARP_OC * WMMA_M * BS_ICKHKW; - const T * B_warp_base = B_sh + WARP_NOHOW * WMMA_N; + const int WARP_OC = warp / WARPS_PER_NOHOW; + const int WARP_NOHOW = warp % WARPS_PER_NOHOW; + const T * A_warp_base = A_sh + WARP_OC * WMMA_M * BS_ICKHKW; + const T * B_warp_base = B_sh + WARP_NOHOW * WMMA_N; #pragma unroll for (int k_tile = 0; k_tile < BS_ICKHKW; k_tile += WMMA_K) { const T * A_k_ptr = A_warp_base + k_tile; @@ -349,10 +338,10 @@ __global__ void conv2d_kernel(const float * IN, const T * IK, float * Out, const #pragma unroll for (int warp = warpId, i = 0; warp < total_warps_need; warp += num_warps, i++) { - const int64_t WARP_OC = warp / WARPS_PER_NOHOW; - const int64_t WARP_NOHOW = warp % WARPS_PER_NOHOW; - const int64_t OC_BASE = BLOCK_OC_BASE + WARP_OC * WMMA_M; - const int64_t NOHOW_BASE = BLOCK_NOHOW_BASE + WARP_NOHOW * WMMA_N; + const int WARP_OC = warp / WARPS_PER_NOHOW; + const int WARP_NOHOW = warp % WARPS_PER_NOHOW; + const int OC_BASE = BLOCK_OC_BASE + WARP_OC * WMMA_M; + const int NOHOW_BASE = BLOCK_NOHOW_BASE + WARP_NOHOW * WMMA_N; acc[i].store_result(OC_BASE, NOHOW_BASE, Out, P); } } @@ -454,8 +443,8 @@ void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const int B = input->ne[3]; // n_batches const int64_t TOTAL = B * OC * OH * OW; - const int64_t IC_KH_KW = IC * KH * KW; - const int64_t N_OH_OW = B * OH * OW; + const int IC_KH_KW = IC * KH * KW; + const int N_OH_OW = B * OH * OW; conv_params params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, IC, OC, B, TOTAL, IC_KH_KW, N_OH_OW }; From 410171ae113962a1cf687f85052fe855ac6cc551 Mon Sep 17 00:00:00 2001 From: mnehete32 Date: Sat, 13 Sep 2025 15:02:07 +0530 
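The patch below shrinks the OC tile (64 -> 16) while widening the NOHOW tile (64 -> 128). Per-block shared memory and grid size follow directly from these constants; a minimal host-side sketch of that arithmetic, mirroring the A_bytes/B_bytes and NUM_BL_* computations in conv2d_cuda (fp16 path assumed, problem sizes made up for illustration, sizeof(unsigned short) standing in for sizeof(half) so it compiles without CUDA headers):

    #include <cstdio>

    // shared-memory footprint and block count implied by a given tile configuration
    static void report_tiles(int bs_oc, int bs_ickhkw, int bs_nohow, int OC, long long n_oh_ow) {
        const size_t    smem = (size_t) (bs_oc * bs_ickhkw + bs_ickhkw * bs_nohow) * sizeof(unsigned short);
        const long long nbl  = ((OC + bs_oc - 1) / bs_oc) * ((n_oh_ow + bs_nohow - 1) / bs_nohow);
        printf("BS_OC=%-3d BS_NOHOW=%-3d -> %zu B smem/block, %lld blocks\n", bs_oc, bs_nohow, smem, nbl);
    }

    int main() {
        const int       OC      = 256;            // hypothetical output channels
        const long long N_OH_OW = 1LL * 64 * 64;  // hypothetical B*OH*OW
        report_tiles(64, 16, 64,  OC, N_OH_OW);   // tile sizes before the change below
        report_tiles(16, 16, 128, OC, N_OH_OW);   // tile sizes after the change below
        return 0;
    }

With these illustrative sizes the new configuration roughly doubles the number of blocks while keeping the shared-memory requirement per block modest, which is the tradeoff the change is tuning.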
Subject: [PATCH 6/8] CUDA: conv2d update block size --- ggml/src/ggml-cuda/conv2d.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d.cuh b/ggml/src/ggml-cuda/conv2d.cuh index a1de712b54a66..3dcce2b4a2e3b 100644 --- a/ggml/src/ggml-cuda/conv2d.cuh +++ b/ggml/src/ggml-cuda/conv2d.cuh @@ -1,9 +1,9 @@ #pragma once #include "common.cuh" -#define BS_OC 64 +#define BS_OC 16 #define BS_ICKHKW 16 -#define BS_NOHOW 64 +#define BS_NOHOW 128 #define WMMA_M 16 #define WMMA_N 16 From 51f85ff57ad1c82fb69d0cd1faac92bfa95b8763 Mon Sep 17 00:00:00 2001 From: mnehete32 Date: Tue, 16 Sep 2025 03:22:36 +0530 Subject: [PATCH 7/8] CUDA: conv2d performance optimization --- ggml/src/ggml-cuda/conv2d.cu | 383 ++++++++++++++++------------------ ggml/src/ggml-cuda/conv2d.cuh | 6 +- 2 files changed, 185 insertions(+), 204 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d.cu b/ggml/src/ggml-cuda/conv2d.cu index f9cdd7786069d..db92a40cd0dd0 100644 --- a/ggml/src/ggml-cuda/conv2d.cu +++ b/ggml/src/ggml-cuda/conv2d.cu @@ -1,97 +1,98 @@ #include "conv2d.cuh" #include "convert.cuh" +#include + struct conv_params { - const int IW, IH; - const int OW, OH; - const int KW, KH; - const int ST_X, ST_Y; - const int PD_X, PD_Y; - const int DL_X, DL_Y; - const int IC, OC; - const int B; - const int64_t TOTAL; + const uint IW, IH; + const uint OW, OH; + const uint KW, KH; + const uint ST_X, ST_Y; + const uint PD_X, PD_Y; + const uint DL_X, DL_Y; + const uint IC, OC; + const uint B; // helpers - const int IC_KH_KW, N_OH_OW; + const uint IC_KH_KW, N_OH_OW; }; -__device__ __forceinline__ static int calculate_input_coord(int out_coord, - int kern_coord, - int stride, - int dilation, - int padding) { +__device__ __forceinline__ static uint64_t calculate_input_coord(uint out_coord, + uint kern_coord, + uint stride, + uint dilation, + uint padding) { return out_coord * stride + kern_coord * dilation - padding; } struct whcn_layout { - __device__ __forceinline__ static int64_t input_index(int n, int c, int y, int x, const conv_params & P) { + __device__ __forceinline__ static uint64_t input_index(int n, int c, int y, int x, const conv_params & P) { return n * (P.IC * P.IW * P.IH) + c * P.IW * P.IH + y * P.IW + x; } - __device__ __forceinline__ static int64_t kernel_index(int c_out, int c_in, int ky, int kx, const conv_params & P) { + __device__ __forceinline__ static uint64_t kernel_index(uint c_out, + uint c_in, + uint ky, + uint kx, + const conv_params & P) { return c_out * (P.IC * P.KH * P.KW) + c_in * (P.KH * P.KW) + ky * P.KW + kx; } - __device__ __forceinline__ static int64_t output_index(int n, int c, int y, int x, const conv_params & P) { + __device__ __forceinline__ static uint64_t output_index(uint n, uint c, uint y, uint x, const conv_params & P) { return n * (P.OC * P.OW * P.OH) + c * P.OW * P.OH + y * P.OW + x; } - __device__ __forceinline__ static void unpack_ickhkw(int64_t idx, - int & ic, - int & kh, - int & kw, + __device__ __forceinline__ static void unpack_ickhkw(uint64_t idx, + uint & ic, + uint & kh, + uint & kw, const conv_params & P) { - ic = idx / (P.KW * P.KH); - int64_t r = idx - ic * (P.KW * P.KH); - kh = r / P.KW; - kw = r - kh * P.KW; + ic = idx / (P.KW * P.KH); + uint r = idx - ic * (P.KW * P.KH); + kh = r / P.KW; + kw = r - kh * P.KW; } - __device__ __forceinline__ static void unpack_nohow(int64_t idx, - int & n, - int & oh, - int & ow, + __device__ __forceinline__ static void unpack_nohow(uint64_t idx, + uint & n, + uint & oh, + uint & ow, const conv_params & 
P) { - n = idx / (P.OH * P.OW); - int64_t r = idx - n * (P.OH * P.OW); - oh = r / P.OW; - ow = r - oh * P.OW; + n = idx / (P.OH * P.OW); + uint r = idx - n * (P.OH * P.OW); + oh = r / P.OW; + ow = r - oh * P.OW; } }; -template class float_mma { - public: - static constexpr int num_acc = (WMMA_M * WMMA_N + WARP_SIZE - 1) / WARP_SIZE; - - float acc[num_acc]; +template __device__ class float_mma { + private: + static constexpr uint num_acc = (WMMA_M * WMMA_N + WARP_SIZE - 1) / WARP_SIZE; + // for tile [16,16], lane 0 will store and compute for [0,0], [2,0], [4,0] ... [14,0] + // lane 1 will store and compute for [0,1], [2,1], [4,1] ... [14,1] + float acc[num_acc]; + public: __device__ __forceinline__ float_mma() { #pragma unroll - for (int i = 0; i < num_acc; i++) { - acc[i] = 0.0f; - } - } - - __device__ __forceinline__ void clear() { -#pragma unroll - for (int i = 0; i < num_acc; i++) { + for (uint i = 0; i < num_acc; i++) { acc[i] = 0.0f; } } __device__ __forceinline__ void mma(const float * __restrict__ A_sh, const float * __restrict__ B_sh, - const int strideA, - const int strideB) { - const int lane_id = threadIdx.x % WARP_SIZE; + const uint strideA, + const uint strideB) { + const uint lane_id = threadIdx.x % WARP_SIZE; #pragma unroll - for (int e = lane_id, i = 0; e < WMMA_M * WMMA_N; e += WARP_SIZE, i++) { - const int m = e / WMMA_N; - const int n = e % WMMA_N; + for (uint i = 0; i < num_acc; i++) { + const uint e = lane_id + i * WARP_SIZE; + const uint m = e / WMMA_N; + const uint n = e % WMMA_N; #pragma unroll - for (int k = 0; k < WMMA_K; k++) { + for (uint k = 0; k < WMMA_K; k++) { const float a = A_sh[m * strideA + k]; const float b = B_sh[k * strideB + n]; acc[i] = fmaf(a, b, acc[i]); @@ -99,22 +100,23 @@ template class float_mma { } } - __device__ __forceinline__ void store_result(const int OC_BASE, - const int NOHOW_BASE, + __device__ __forceinline__ void store_result(const uint OC_BASE, + const uint NOHOW_BASE, float * __restrict__ OUT, const conv_params & P) const { - const int lane_id = threadIdx.x % WARP_SIZE; + const uint lane_id = threadIdx.x % WARP_SIZE; #pragma unroll - for (int e = lane_id, i = 0; e < WMMA_M * WMMA_N; e += WARP_SIZE, i++) { - const int m = e / WMMA_N; - const int n = e % WMMA_N; + for (uint i = 0; i < num_acc; i++) { + const uint e = lane_id + i * WARP_SIZE; + const uint m = e / WMMA_N; + const uint n = e % WMMA_N; - const int oc = OC_BASE + m; - const int nohow = NOHOW_BASE + n; + const uint oc = OC_BASE + m; + const uint nohow = NOHOW_BASE + n; if (oc < P.OC && nohow < P.N_OH_OW) { - int n_, oh, ow; + uint n_, oh, ow; layout::unpack_nohow(nohow, n_, oh, ow, P); OUT[layout::output_index(n_, oc, oh, ow, P)] = acc[i]; } @@ -126,15 +128,16 @@ template class float_mma { # include "mma.cuh" using namespace ggml_cuda_mma; -typedef ggml_cuda_mma::tile tile_a; -typedef ggml_cuda_mma::tile tile_b; -typedef ggml_cuda_mma::tile tile_acc; +typedef tile tile_a; +typedef tile tile_b; +typedef tile tile_acc; template class half_mma { private: tile_a a_frag; tile_b b_frag; tile_acc c_frag; + public: __device__ __forceinline__ half_mma() {} @@ -147,30 +150,30 @@ template class half_mma { __device__ __forceinline__ void mma(const half * __restrict__ A_sh, const half * __restrict__ B_sh, - const int strideA, - const int strideB) { - ggml_cuda_mma::load_ldmatrix(a_frag, (const half2 *) A_sh, strideA / 2); - ggml_cuda_mma::load_ldmatrix_trans(b_frag, (const half2 *) B_sh, strideB / 2); + const uint strideA, + const uint strideB) { + load_ldmatrix(a_frag, (const half2 *) 
A_sh, strideA / 2); + load_ldmatrix_trans(b_frag, (const half2 *) B_sh, strideB / 2); ggml_cuda_mma::mma(c_frag, a_frag, b_frag); } - __device__ __forceinline__ void store_result(const int OC_BASE, - const int NOHOW_BASE, + __device__ __forceinline__ void store_result(const uint OC_BASE, + const uint NOHOW_BASE, float * __restrict__ OUT, const conv_params & P) const { # pragma unroll - for (int l = 0; l < tile_acc::ne; ++l) { - const int e = tile_acc::get_i(l) * WMMA_N + tile_acc::get_j(l); - const int m = e / WMMA_N; - const int n = e % WMMA_N; + for (uint l = 0; l < tile_acc::ne; ++l) { + const uint e = tile_acc::get_i(l) * WMMA_N + tile_acc::get_j(l); + const uint m = e / WMMA_N; + const uint n = e % WMMA_N; - const int oc = OC_BASE + m; - const int nohow = NOHOW_BASE + n; + const uint oc = OC_BASE + m; + const uint nohow = NOHOW_BASE + n; if (oc < P.OC && nohow < (P.N_OH_OW)) { - int n, oh, ow; - layout::unpack_nohow(nohow, n, oh, ow, P); - OUT[layout::output_index(n, oc, oh, ow, P)] = c_frag.x[l]; + uint n_, oh, ow; + layout::unpack_nohow(nohow, n_, oh, ow, P); + OUT[layout::output_index(n_, oc, oh, ow, P)] = c_frag.x[l]; } } } @@ -181,8 +184,8 @@ template class half_mma { template class half_mma { public: static constexpr int num_acc = (WMMA_M * WMMA_N + WARP_SIZE - 1) / WARP_SIZE; - - float acc[num_acc]; + // eg. for tile [16,16], lane 0 will store and compute for [0,0], [2,0], [4,0] .. [14,0] + float acc[num_acc]; __device__ __forceinline__ half_mma() { # pragma unroll @@ -191,13 +194,6 @@ template class half_mma { } } - __device__ __forceinline__ void clear() { -# pragma unroll - for (int i = 0; i < num_acc; i++) { - acc[i] = 0.0f; - } - } - __device__ __forceinline__ void mma(const half * __restrict__ A_sh, const half * __restrict__ B_sh, const int strideA, @@ -218,22 +214,22 @@ template class half_mma { } } - __device__ __forceinline__ void store_result(const int OC_BASE, - const int NOHOW_BASE, + __device__ __forceinline__ void store_result(const uint OC_BASE, + const uint NOHOW_BASE, float * __restrict__ OUT, const conv_params & P) const { - const int lane_id = threadIdx.x % WARP_SIZE; + const uint lane_id = threadIdx.x % WARP_SIZE; # pragma unroll - for (int e = lane_id, i = 0; e < WMMA_M * WMMA_N; e += WARP_SIZE, i++) { - const int m = e / WMMA_N; - const int n = e % WMMA_N; + for (uint e = lane_id, i = 0; e < WMMA_M * WMMA_N; e += WARP_SIZE, i++) { + const uint m = e / WMMA_N; + const uint n = e % WMMA_N; - const int oc = OC_BASE + m; - const int nohow = NOHOW_BASE + n; + const uint oc = OC_BASE + m; + const uint nohow = NOHOW_BASE + n; if (oc < P.OC && nohow < P.N_OH_OW) { - int n_, oh, ow; + uint n_, oh, ow; layout::unpack_nohow(nohow, n_, oh, ow, P); OUT[layout::output_index(n_, oc, oh, ow, P)] = acc[i]; } @@ -243,30 +239,35 @@ template class half_mma { #endif // defined((__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || defined(FP16_MMA_AVAILABLE)) -template -__global__ void conv2d_kernel(const float * IN, const T * IK, float * Out, const conv_params P) { +template class mma, int num_warps> +__global__ void __launch_bounds__(num_warps * WARP_SIZE) conv2d_kernel(const float * __restrict__ IN, + const T * __restrict__ IK, + float * __restrict__ Out, + const conv_params P) { extern __shared__ unsigned char smem_raw[]; - const int NUM_IC_TILES = (P.IC_KH_KW + BS_ICKHKW - 1) / BS_ICKHKW; - const int warpId = threadIdx.y; + const uint warpId = threadIdx.y; + const uint linear_tid = threadIdx.y * blockDim.x + threadIdx.x; - const int WARPS_PER_NOHOW = max(1, BS_NOHOW / WMMA_N); - const 
int total_warps_need = (((BS_OC * BS_NOHOW) + (WMMA_M * WMMA_N) - 1) / (WMMA_M * WMMA_N)); - const int num_work_per_warps = (total_warps_need + num_warps - 1) / num_warps; + const uint NUM_IC_TILES = (P.IC_KH_KW + BS_ICKHKW - 1) / BS_ICKHKW; + const uint NUM_WARPS_NOHOW = max(1, BS_NOHOW / WMMA_N); + const uint NUM_WARPS_NEED = (((BS_OC * BS_NOHOW) + (WMMA_M * WMMA_N) - 1) / (WMMA_M * WMMA_N)); - mma acc[num_work_per_warps]; + const uint NUM_TILES_PER_WARP = (NUM_WARPS_NEED + num_warps - 1) / num_warps; - const int num_block_nohow = (P.N_OH_OW + BS_NOHOW - 1) / BS_NOHOW; - const int BL_IDX_OC = blockIdx.x / num_block_nohow; - const int BL_IDX_NOHOW = blockIdx.x % num_block_nohow; + mma acc[NUM_TILES_PER_WARP]; - const int BLOCK_OC_BASE = BL_IDX_OC * BS_OC; - const int BLOCK_NOHOW_BASE = BL_IDX_NOHOW * BS_NOHOW; + const uint NUM_BL_NOHOW = (P.N_OH_OW + BS_NOHOW - 1) / BS_NOHOW; + const uint BL_IDX_OC = blockIdx.x / NUM_BL_NOHOW; + const uint BL_IDX_NOHOW = blockIdx.x % NUM_BL_NOHOW; + + const uint BLOCK_OC_BASE = BL_IDX_OC * BS_OC; + const uint BLOCK_NOHOW_BASE = BL_IDX_NOHOW * BS_NOHOW; unsigned char * ptr = smem_raw; - const int A_total = BS_OC * BS_ICKHKW; - const int B_total = BS_ICKHKW * BS_NOHOW; + const uint A_total = BS_OC * BS_ICKHKW; + const uint B_total = BS_ICKHKW * BS_NOHOW; size_t offsetA = (size_t) A_total * sizeof(T); T * A_sh = reinterpret_cast(ptr); @@ -276,40 +277,41 @@ __global__ void conv2d_kernel(const float * IN, const T * IK, float * Out, const T * B_sh = reinterpret_cast(ptr); ptr += offsetB; - int ic, kh, kw; - int n, oh, ow; - for (int t = 0; t < NUM_IC_TILES; ++t) { + for (uint t = 0; t < NUM_IC_TILES; ++t) { #pragma unroll - for (int tid = (threadIdx.y * blockDim.x + threadIdx.x); tid < A_total; tid += (blockDim.x * blockDim.y)) { - const int row = tid / BS_ICKHKW; - const int col = tid % BS_ICKHKW; + for (uint tid = linear_tid; tid < A_total; tid += (blockDim.x * blockDim.y)) { + const uint row = tid / BS_ICKHKW; + const uint col = tid % BS_ICKHKW; - int shared_oc = BLOCK_OC_BASE + row; - int shared_ickhkw = t * BS_ICKHKW + col; + const uint shared_oc = BLOCK_OC_BASE + row; + const uint shared_ickhkw = t * BS_ICKHKW + col; T val = ggml_cuda_cast(0); if (shared_oc < P.OC && shared_ickhkw < P.IC_KH_KW) { + uint ic, kh, kw; layout::unpack_ickhkw(shared_ickhkw, ic, kh, kw, P); - const int kidx = layout::kernel_index(shared_oc, ic, kh, kw, P); - val = IK[kidx]; + const uint kidx = layout::kernel_index(shared_oc, ic, kh, kw, P); + val = IK[kidx]; } A_sh[row * BS_ICKHKW + col] = val; } #pragma unroll - for (int tid = (threadIdx.y * blockDim.x + threadIdx.x); tid < B_total; tid += (blockDim.x * blockDim.y)) { - const int brow = tid / BS_NOHOW; - const int bcol = tid % BS_NOHOW; + for (uint tid = linear_tid; tid < B_total; tid += (blockDim.x * blockDim.y)) { + const uint brow = tid / BS_NOHOW; + const uint bcol = tid % BS_NOHOW; - int IC_KH_KW_IDX = t * BS_ICKHKW + brow; - int N_OH_OW_IDX = BLOCK_NOHOW_BASE + bcol; + const uint IC_KH_KW_IDX = t * BS_ICKHKW + brow; + const uint N_OH_OW_IDX = BLOCK_NOHOW_BASE + bcol; T val = ggml_cuda_cast(0); if (N_OH_OW_IDX < P.N_OH_OW && IC_KH_KW_IDX < P.IC_KH_KW) { + uint n, oh, ow; + uint ic, kh, kw; layout::unpack_nohow(N_OH_OW_IDX, n, oh, ow, P); layout::unpack_ickhkw(IC_KH_KW_IDX, ic, kh, kw, P); - int in_y = calculate_input_coord(oh, kh, P.ST_Y, P.DL_Y, P.PD_Y); - int in_x = calculate_input_coord(ow, kw, P.ST_X, P.DL_X, P.PD_X); + const int in_y = calculate_input_coord(oh, kh, P.ST_Y, P.DL_Y, P.PD_Y); + const int in_x = 
calculate_input_coord(ow, kw, P.ST_X, P.DL_X, P.PD_X); if (in_y >= 0 && in_y < P.IH && in_x >= 0 && in_x < P.IW) { const int64_t in_idx = layout::input_index(n, ic, in_y, in_x, P); val = ggml_cuda_cast(IN[in_idx]); @@ -321,13 +323,19 @@ __global__ void conv2d_kernel(const float * IN, const T * IK, float * Out, const __syncthreads(); #pragma unroll - for (int warp = warpId, i = 0; warp < total_warps_need; warp += num_warps, i++) { - const int WARP_OC = warp / WARPS_PER_NOHOW; - const int WARP_NOHOW = warp % WARPS_PER_NOHOW; + for (uint i = 0; i < NUM_TILES_PER_WARP; i++) { + const uint warp = warpId + i * num_warps; + if (warp >= NUM_WARPS_NEED) { + continue; + } + const uint WARP_OC = warp / NUM_WARPS_NOHOW; + const uint WARP_NOHOW = warp % NUM_WARPS_NOHOW; + const T * A_warp_base = A_sh + WARP_OC * WMMA_M * BS_ICKHKW; const T * B_warp_base = B_sh + WARP_NOHOW * WMMA_N; + #pragma unroll - for (int k_tile = 0; k_tile < BS_ICKHKW; k_tile += WMMA_K) { + for (uint k_tile = 0; k_tile < BS_ICKHKW; k_tile += WMMA_K) { const T * A_k_ptr = A_warp_base + k_tile; const T * B_k_ptr = B_warp_base + k_tile * BS_NOHOW; acc[i].mma(A_k_ptr, B_k_ptr, BS_ICKHKW, BS_NOHOW); @@ -337,65 +345,37 @@ __global__ void conv2d_kernel(const float * IN, const T * IK, float * Out, const } #pragma unroll - for (int warp = warpId, i = 0; warp < total_warps_need; warp += num_warps, i++) { - const int WARP_OC = warp / WARPS_PER_NOHOW; - const int WARP_NOHOW = warp % WARPS_PER_NOHOW; - const int OC_BASE = BLOCK_OC_BASE + WARP_OC * WMMA_M; - const int NOHOW_BASE = BLOCK_NOHOW_BASE + WARP_NOHOW * WMMA_N; + for (uint i = 0; i < NUM_TILES_PER_WARP; i++) { + const uint warp = warpId + i * num_warps; + if (warp >= NUM_WARPS_NEED) { + continue; + } + const uint WARP_OC = warp / NUM_WARPS_NOHOW; + const uint WARP_NOHOW = warp % NUM_WARPS_NOHOW; + const uint OC_BASE = BLOCK_OC_BASE + WARP_OC * WMMA_M; + const uint NOHOW_BASE = BLOCK_NOHOW_BASE + WARP_NOHOW * WMMA_N; acc[i].store_result(OC_BASE, NOHOW_BASE, Out, P); } } template class mma> static void conv2d_cuda(const float * X_D, const T * K_D, float * Y_D, const conv_params P, cudaStream_t st) { - const int warp_size = 32; - const int max_block_size = 256; - GGML_ASSERT(BS_OC >= WMMA_M && BS_ICKHKW >= WMMA_K && BS_NOHOW >= WMMA_N); - const int num_block_oc = (P.OC + BS_OC - 1) / BS_OC; - const int num_block_nohow = (P.N_OH_OW + BS_NOHOW - 1) / BS_NOHOW; - const int num_blocks = num_block_oc * num_block_nohow; - - int nwarps_best = 1; - int niter_best = (BS_OC * BS_NOHOW + warp_size - 1) / (warp_size); - for (int nwarps = 2; nwarps <= max_block_size / warp_size; ++nwarps) { - const int niter = (BS_OC * BS_NOHOW + nwarps * warp_size - 1) / (nwarps * warp_size); - if (niter < niter_best) { - niter_best = niter; - nwarps_best = nwarps; - } - } + const uint NUM_BL_OC = (P.OC + BS_OC - 1) / BS_OC; + const uint NUM_BL_NOHOW = (P.N_OH_OW + BS_NOHOW - 1) / BS_NOHOW; + const uint NUM_BL = NUM_BL_OC * NUM_BL_NOHOW; + + constexpr uint NUM_WARPS = (CUDA_CONV2D_BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE; const size_t A_bytes = BS_OC * BS_ICKHKW * sizeof(T); const size_t B_bytes = BS_ICKHKW * BS_NOHOW * sizeof(T); const size_t shared_bytes = A_bytes + B_bytes; - dim3 grid(num_blocks, 1, 1); - dim3 block(warp_size, nwarps_best); - - switch (nwarps_best) { - case 1: - conv2d_kernel, 1><<>>(X_D, K_D, Y_D, P); - break; - case 2: - conv2d_kernel, 2><<>>(X_D, K_D, Y_D, P); - break; - case 4: - conv2d_kernel, 4><<>>(X_D, K_D, Y_D, P); - break; - case 8: - conv2d_kernel, 8><<>>(X_D, K_D, Y_D, P); - break; - 
case 16: - conv2d_kernel, 16><<>>(X_D, K_D, Y_D, P); - break; - case 32: - conv2d_kernel, 32><<>>(X_D, K_D, Y_D, P); - break; - default: - GGML_ABORT("UNSUPPROTED NWARPS_BEST"); - } + dim3 grid(NUM_BL, 1, 1); + dim3 block(WARP_SIZE, NUM_WARPS, 1); + + conv2d_kernel<<>>(X_D, K_D, Y_D, P); } static void conv2d_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const conv_params & P, cudaStream_t st) { @@ -422,31 +402,30 @@ void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { cudaStream_t st = ctx.stream(); const int32_t * p = (const int32_t *) dst->op_params; - const int ST_X = p[0]; // stride_x - const int ST_Y = p[1]; // stride_y - const int PD_X = p[2]; // padding_x - const int PD_Y = p[3]; // padding_y - const int DL_X = p[4]; // dilation_x - const int DL_Y = p[5]; // dilation_y + const uint ST_X = p[0]; // stride_x + const uint ST_Y = p[1]; // stride_y + const uint PD_X = p[2]; // padding_x + const uint PD_Y = p[3]; // padding_y + const uint DL_X = p[4]; // dilation_x + const uint DL_Y = p[5]; // dilation_y // No cwhn GGML_ASSERT(p[6] == false); - const int IW = input->ne[0]; // input_w - const int IH = input->ne[1]; // input_h - const int OW = dst->ne[0]; // output_w - const int OH = dst->ne[1]; // output_h - const int KW = kernel->ne[0]; // kernel_w - const int KH = kernel->ne[1]; // kernel_h - const int IC = input->ne[2]; // input_channels - const int OC = kernel->ne[3]; // ouptut_chanles - const int B = input->ne[3]; // n_batches - - const int64_t TOTAL = B * OC * OH * OW; - const int IC_KH_KW = IC * KH * KW; - const int N_OH_OW = B * OH * OW; - conv_params params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, - PD_Y, DL_X, DL_Y, IC, OC, B, TOTAL, IC_KH_KW, N_OH_OW }; + const uint IW = input->ne[0]; // input_w + const uint IH = input->ne[1]; // input_h + const uint OW = dst->ne[0]; // output_w + const uint OH = dst->ne[1]; // output_h + const uint KW = kernel->ne[0]; // kernel_w + const uint KH = kernel->ne[1]; // kernel_h + const uint IC = input->ne[2]; // input_channels + const uint OC = kernel->ne[3]; // ouptut_chanles + const uint B = input->ne[3]; // n_batches + + const uint IC_KH_KW = IC * KH * KW; + const uint N_OH_OW = B * OH * OW; + const conv_params params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, + PD_Y, DL_X, DL_Y, IC, OC, B, IC_KH_KW, N_OH_OW }; if (kernel->type == GGML_TYPE_F16) { conv2d_cuda_f16(X_D, (const half *) K_D, Y_D, params, st); diff --git a/ggml/src/ggml-cuda/conv2d.cuh b/ggml/src/ggml-cuda/conv2d.cuh index 3dcce2b4a2e3b..3a1a5f28b572c 100644 --- a/ggml/src/ggml-cuda/conv2d.cuh +++ b/ggml/src/ggml-cuda/conv2d.cuh @@ -1,12 +1,14 @@ #pragma once #include "common.cuh" -#define BS_OC 16 +#define BS_OC 32 #define BS_ICKHKW 16 -#define BS_NOHOW 128 +#define BS_NOHOW 32 #define WMMA_M 16 #define WMMA_N 16 #define WMMA_K 16 +#define CUDA_CONV2D_BLOCK_SIZE 128 + void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst); From 604957644fe5751c961ca1222ffed6b759751830 Mon Sep 17 00:00:00 2001 From: mnehete32 Date: Tue, 16 Sep 2025 04:58:58 +0530 Subject: [PATCH 8/8] CUDA: conv2d minor fixes CUDA: uint to int and added assertion --- ggml/src/ggml-cuda/conv2d.cu | 297 +++++++++++++++++------------------ 1 file changed, 148 insertions(+), 149 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d.cu b/ggml/src/ggml-cuda/conv2d.cu index db92a40cd0dd0..deaca3d648d5a 100644 --- a/ggml/src/ggml-cuda/conv2d.cu +++ b/ggml/src/ggml-cuda/conv2d.cu @@ -4,95 +4,94 @@ #include struct conv_params { - const uint IW, IH; - const uint 
OW, OH; - const uint KW, KH; - const uint ST_X, ST_Y; - const uint PD_X, PD_Y; - const uint DL_X, DL_Y; - const uint IC, OC; - const uint B; + const int IW, IH; + const int OW, OH; + const int KW, KH; + const int ST_X, ST_Y; + const int PD_X, PD_Y; + const int DL_X, DL_Y; + const int IC, OC; + const int B; // helpers - const uint IC_KH_KW, N_OH_OW; + const int IC_KH_KW, N_OH_OW; }; -__device__ __forceinline__ static uint64_t calculate_input_coord(uint out_coord, - uint kern_coord, - uint stride, - uint dilation, - uint padding) { +__device__ __forceinline__ static int calculate_input_coord(int out_coord, + int kern_coord, + int stride, + int dilation, + int padding) { return out_coord * stride + kern_coord * dilation - padding; } struct whcn_layout { - __device__ __forceinline__ static uint64_t input_index(int n, int c, int y, int x, const conv_params & P) { + __device__ __forceinline__ static int64_t input_index(int n, int c, int y, int x, const conv_params & P) { return n * (P.IC * P.IW * P.IH) + c * P.IW * P.IH + y * P.IW + x; } - __device__ __forceinline__ static uint64_t kernel_index(uint c_out, - uint c_in, - uint ky, - uint kx, - const conv_params & P) { + __device__ __forceinline__ static int64_t kernel_index(int c_out, int c_in, int ky, int kx, const conv_params & P) { return c_out * (P.IC * P.KH * P.KW) + c_in * (P.KH * P.KW) + ky * P.KW + kx; } - __device__ __forceinline__ static uint64_t output_index(uint n, uint c, uint y, uint x, const conv_params & P) { + __device__ __forceinline__ static int64_t output_index(int n, int c, int y, int x, const conv_params & P) { return n * (P.OC * P.OW * P.OH) + c * P.OW * P.OH + y * P.OW + x; } - __device__ __forceinline__ static void unpack_ickhkw(uint64_t idx, - uint & ic, - uint & kh, - uint & kw, + __device__ __forceinline__ static void unpack_ickhkw(int64_t idx, + int & ic, + int & kh, + int & kw, const conv_params & P) { - ic = idx / (P.KW * P.KH); - uint r = idx - ic * (P.KW * P.KH); - kh = r / P.KW; - kw = r - kh * P.KW; + ic = idx / (P.KW * P.KH); + int r = idx - ic * (P.KW * P.KH); + kh = r / P.KW; + kw = r - kh * P.KW; } - __device__ __forceinline__ static void unpack_nohow(uint64_t idx, - uint & n, - uint & oh, - uint & ow, + __device__ __forceinline__ static void unpack_nohow(int64_t idx, + int & n, + int & oh, + int & ow, const conv_params & P) { - n = idx / (P.OH * P.OW); - uint r = idx - n * (P.OH * P.OW); - oh = r / P.OW; - ow = r - oh * P.OW; + n = idx / (P.OH * P.OW); + int r = idx - n * (P.OH * P.OW); + oh = r / P.OW; + ow = r - oh * P.OW; } }; -template __device__ class float_mma { +template class float_mma { private: - static constexpr uint num_acc = (WMMA_M * WMMA_N + WARP_SIZE - 1) / WARP_SIZE; + static constexpr int num_acc = (WMMA_M * WMMA_N + WARP_SIZE - 1) / WARP_SIZE; // for tile [16,16], lane 0 will store and compute for [0,0], [2,0], [4,0] ... [14,0] // lane 1 will store and compute for [0,1], [2,1], [4,1] ... 
[14,1] - float acc[num_acc]; + float acc[num_acc]; public: __device__ __forceinline__ float_mma() { #pragma unroll - for (uint i = 0; i < num_acc; i++) { + for (int i = 0; i < num_acc; i++) { acc[i] = 0.0f; } } __device__ __forceinline__ void mma(const float * __restrict__ A_sh, const float * __restrict__ B_sh, - const uint strideA, - const uint strideB) { - const uint lane_id = threadIdx.x % WARP_SIZE; + const int strideA, + const int strideB) { + const int lane_id = threadIdx.x % WARP_SIZE; #pragma unroll - for (uint i = 0; i < num_acc; i++) { - const uint e = lane_id + i * WARP_SIZE; - const uint m = e / WMMA_N; - const uint n = e % WMMA_N; + for (int i = 0; i < num_acc; i++) { + const int e = lane_id + i * WARP_SIZE; + if (e >= WMMA_M * WMMA_N) { + continue; + } + const int m = e / WMMA_N; + const int n = e % WMMA_N; #pragma unroll - for (uint k = 0; k < WMMA_K; k++) { + for (int k = 0; k < WMMA_K; k++) { const float a = A_sh[m * strideA + k]; const float b = B_sh[k * strideB + n]; acc[i] = fmaf(a, b, acc[i]); @@ -100,23 +99,26 @@ template __device__ class float_mma { } } - __device__ __forceinline__ void store_result(const uint OC_BASE, - const uint NOHOW_BASE, + __device__ __forceinline__ void store_result(const int OC_BASE, + const int NOHOW_BASE, float * __restrict__ OUT, const conv_params & P) const { - const uint lane_id = threadIdx.x % WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; #pragma unroll - for (uint i = 0; i < num_acc; i++) { - const uint e = lane_id + i * WARP_SIZE; - const uint m = e / WMMA_N; - const uint n = e % WMMA_N; + for (int i = 0; i < num_acc; i++) { + const int e = lane_id + i * WARP_SIZE; + if (e >= WMMA_M * WMMA_N) { + continue; + } + const int m = e / WMMA_N; + const int n = e % WMMA_N; - const uint oc = OC_BASE + m; - const uint nohow = NOHOW_BASE + n; + const int oc = OC_BASE + m; + const int nohow = NOHOW_BASE + n; if (oc < P.OC && nohow < P.N_OH_OW) { - uint n_, oh, ow; + int n_, oh, ow; layout::unpack_nohow(nohow, n_, oh, ow, P); OUT[layout::output_index(n_, oc, oh, ow, P)] = acc[i]; } @@ -141,37 +143,33 @@ template class half_mma { public: __device__ __forceinline__ half_mma() {} - __device__ __forceinline__ void clear() { -# pragma unroll - for (int l = 0; l < c_frag.ne; ++l) { - c_frag.x[l] = 0.0f; - } - } - __device__ __forceinline__ void mma(const half * __restrict__ A_sh, const half * __restrict__ B_sh, - const uint strideA, - const uint strideB) { + const int strideA, + const int strideB) { load_ldmatrix(a_frag, (const half2 *) A_sh, strideA / 2); load_ldmatrix_trans(b_frag, (const half2 *) B_sh, strideB / 2); ggml_cuda_mma::mma(c_frag, a_frag, b_frag); } - __device__ __forceinline__ void store_result(const uint OC_BASE, - const uint NOHOW_BASE, + __device__ __forceinline__ void store_result(const int OC_BASE, + const int NOHOW_BASE, float * __restrict__ OUT, const conv_params & P) const { # pragma unroll - for (uint l = 0; l < tile_acc::ne; ++l) { - const uint e = tile_acc::get_i(l) * WMMA_N + tile_acc::get_j(l); - const uint m = e / WMMA_N; - const uint n = e % WMMA_N; + for (int l = 0; l < tile_acc::ne; ++l) { + const int e = tile_acc::get_i(l) * WMMA_N + tile_acc::get_j(l); + if (e >= WMMA_M * WMMA_N) { + continue; + } + const int m = e / WMMA_N; + const int n = e % WMMA_N; - const uint oc = OC_BASE + m; - const uint nohow = NOHOW_BASE + n; + const int oc = OC_BASE + m; + const int nohow = NOHOW_BASE + n; if (oc < P.OC && nohow < (P.N_OH_OW)) { - uint n_, oh, ow; + int n_, oh, ow; layout::unpack_nohow(nohow, n_, oh, ow, P); 
OUT[layout::output_index(n_, oc, oh, ow, P)] = c_frag.x[l]; } @@ -214,22 +212,22 @@ template class half_mma { } } - __device__ __forceinline__ void store_result(const uint OC_BASE, - const uint NOHOW_BASE, + __device__ __forceinline__ void store_result(const int OC_BASE, + const int NOHOW_BASE, float * __restrict__ OUT, const conv_params & P) const { - const uint lane_id = threadIdx.x % WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; # pragma unroll - for (uint e = lane_id, i = 0; e < WMMA_M * WMMA_N; e += WARP_SIZE, i++) { - const uint m = e / WMMA_N; - const uint n = e % WMMA_N; + for (int e = lane_id, i = 0; e < WMMA_M * WMMA_N; e += WARP_SIZE, i++) { + const int m = e / WMMA_N; + const int n = e % WMMA_N; - const uint oc = OC_BASE + m; - const uint nohow = NOHOW_BASE + n; + const int oc = OC_BASE + m; + const int nohow = NOHOW_BASE + n; if (oc < P.OC && nohow < P.N_OH_OW) { - uint n_, oh, ow; + int n_, oh, ow; layout::unpack_nohow(nohow, n_, oh, ow, P); OUT[layout::output_index(n_, oc, oh, ow, P)] = acc[i]; } @@ -246,28 +244,28 @@ __global__ void __launch_bounds__(num_warps * WARP_SIZE) conv2d_kernel(const flo const conv_params P) { extern __shared__ unsigned char smem_raw[]; - const uint warpId = threadIdx.y; - const uint linear_tid = threadIdx.y * blockDim.x + threadIdx.x; + const int warpId = threadIdx.y; + const int linear_tid = threadIdx.y * blockDim.x + threadIdx.x; - const uint NUM_IC_TILES = (P.IC_KH_KW + BS_ICKHKW - 1) / BS_ICKHKW; - const uint NUM_WARPS_NOHOW = max(1, BS_NOHOW / WMMA_N); - const uint NUM_WARPS_NEED = (((BS_OC * BS_NOHOW) + (WMMA_M * WMMA_N) - 1) / (WMMA_M * WMMA_N)); + const int NUM_IC_TILES = (P.IC_KH_KW + BS_ICKHKW - 1) / BS_ICKHKW; + const int NUM_WARPS_NOHOW = max(1, BS_NOHOW / WMMA_N); + const int NUM_WARPS_NEED = (((BS_OC * BS_NOHOW) + (WMMA_M * WMMA_N) - 1) / (WMMA_M * WMMA_N)); - const uint NUM_TILES_PER_WARP = (NUM_WARPS_NEED + num_warps - 1) / num_warps; + const int NUM_TILES_PER_WARP = (NUM_WARPS_NEED + num_warps - 1) / num_warps; mma acc[NUM_TILES_PER_WARP]; - const uint NUM_BL_NOHOW = (P.N_OH_OW + BS_NOHOW - 1) / BS_NOHOW; - const uint BL_IDX_OC = blockIdx.x / NUM_BL_NOHOW; - const uint BL_IDX_NOHOW = blockIdx.x % NUM_BL_NOHOW; + const int NUM_BL_NOHOW = (P.N_OH_OW + BS_NOHOW - 1) / BS_NOHOW; + const int BL_IDX_OC = blockIdx.x / NUM_BL_NOHOW; + const int BL_IDX_NOHOW = blockIdx.x % NUM_BL_NOHOW; - const uint BLOCK_OC_BASE = BL_IDX_OC * BS_OC; - const uint BLOCK_NOHOW_BASE = BL_IDX_NOHOW * BS_NOHOW; + const int BLOCK_OC_BASE = BL_IDX_OC * BS_OC; + const int BLOCK_NOHOW_BASE = BL_IDX_NOHOW * BS_NOHOW; unsigned char * ptr = smem_raw; - const uint A_total = BS_OC * BS_ICKHKW; - const uint B_total = BS_ICKHKW * BS_NOHOW; + const int A_total = BS_OC * BS_ICKHKW; + const int B_total = BS_ICKHKW * BS_NOHOW; size_t offsetA = (size_t) A_total * sizeof(T); T * A_sh = reinterpret_cast(ptr); @@ -277,37 +275,37 @@ __global__ void __launch_bounds__(num_warps * WARP_SIZE) conv2d_kernel(const flo T * B_sh = reinterpret_cast(ptr); ptr += offsetB; - for (uint t = 0; t < NUM_IC_TILES; ++t) { + for (int t = 0; t < NUM_IC_TILES; ++t) { #pragma unroll - for (uint tid = linear_tid; tid < A_total; tid += (blockDim.x * blockDim.y)) { - const uint row = tid / BS_ICKHKW; - const uint col = tid % BS_ICKHKW; + for (int tid = linear_tid; tid < A_total; tid += (blockDim.x * blockDim.y)) { + const int row = tid / BS_ICKHKW; + const int col = tid % BS_ICKHKW; - const uint shared_oc = BLOCK_OC_BASE + row; - const uint shared_ickhkw = t * BS_ICKHKW + col; + const int 
shared_oc = BLOCK_OC_BASE + row; + const int shared_ickhkw = t * BS_ICKHKW + col; T val = ggml_cuda_cast(0); if (shared_oc < P.OC && shared_ickhkw < P.IC_KH_KW) { - uint ic, kh, kw; + int ic, kh, kw; layout::unpack_ickhkw(shared_ickhkw, ic, kh, kw, P); - const uint kidx = layout::kernel_index(shared_oc, ic, kh, kw, P); - val = IK[kidx]; + const int kidx = layout::kernel_index(shared_oc, ic, kh, kw, P); + val = IK[kidx]; } A_sh[row * BS_ICKHKW + col] = val; } #pragma unroll - for (uint tid = linear_tid; tid < B_total; tid += (blockDim.x * blockDim.y)) { - const uint brow = tid / BS_NOHOW; - const uint bcol = tid % BS_NOHOW; + for (int tid = linear_tid; tid < B_total; tid += (blockDim.x * blockDim.y)) { + const int brow = tid / BS_NOHOW; + const int bcol = tid % BS_NOHOW; - const uint IC_KH_KW_IDX = t * BS_ICKHKW + brow; - const uint N_OH_OW_IDX = BLOCK_NOHOW_BASE + bcol; + const int IC_KH_KW_IDX = t * BS_ICKHKW + brow; + const int N_OH_OW_IDX = BLOCK_NOHOW_BASE + bcol; T val = ggml_cuda_cast(0); if (N_OH_OW_IDX < P.N_OH_OW && IC_KH_KW_IDX < P.IC_KH_KW) { - uint n, oh, ow; - uint ic, kh, kw; + int n, oh, ow; + int ic, kh, kw; layout::unpack_nohow(N_OH_OW_IDX, n, oh, ow, P); layout::unpack_ickhkw(IC_KH_KW_IDX, ic, kh, kw, P); const int in_y = calculate_input_coord(oh, kh, P.ST_Y, P.DL_Y, P.PD_Y); @@ -323,19 +321,19 @@ __global__ void __launch_bounds__(num_warps * WARP_SIZE) conv2d_kernel(const flo __syncthreads(); #pragma unroll - for (uint i = 0; i < NUM_TILES_PER_WARP; i++) { - const uint warp = warpId + i * num_warps; + for (int i = 0; i < NUM_TILES_PER_WARP; i++) { + const int warp = warpId + i * num_warps; if (warp >= NUM_WARPS_NEED) { continue; } - const uint WARP_OC = warp / NUM_WARPS_NOHOW; - const uint WARP_NOHOW = warp % NUM_WARPS_NOHOW; + const int WARP_OC = warp / NUM_WARPS_NOHOW; + const int WARP_NOHOW = warp % NUM_WARPS_NOHOW; const T * A_warp_base = A_sh + WARP_OC * WMMA_M * BS_ICKHKW; const T * B_warp_base = B_sh + WARP_NOHOW * WMMA_N; #pragma unroll - for (uint k_tile = 0; k_tile < BS_ICKHKW; k_tile += WMMA_K) { + for (int k_tile = 0; k_tile < BS_ICKHKW; k_tile += WMMA_K) { const T * A_k_ptr = A_warp_base + k_tile; const T * B_k_ptr = B_warp_base + k_tile * BS_NOHOW; acc[i].mma(A_k_ptr, B_k_ptr, BS_ICKHKW, BS_NOHOW); @@ -345,28 +343,29 @@ __global__ void __launch_bounds__(num_warps * WARP_SIZE) conv2d_kernel(const flo } #pragma unroll - for (uint i = 0; i < NUM_TILES_PER_WARP; i++) { - const uint warp = warpId + i * num_warps; + for (int i = 0; i < NUM_TILES_PER_WARP; i++) { + const int warp = warpId + i * num_warps; if (warp >= NUM_WARPS_NEED) { continue; } - const uint WARP_OC = warp / NUM_WARPS_NOHOW; - const uint WARP_NOHOW = warp % NUM_WARPS_NOHOW; - const uint OC_BASE = BLOCK_OC_BASE + WARP_OC * WMMA_M; - const uint NOHOW_BASE = BLOCK_NOHOW_BASE + WARP_NOHOW * WMMA_N; + const int WARP_OC = warp / NUM_WARPS_NOHOW; + const int WARP_NOHOW = warp % NUM_WARPS_NOHOW; + const int OC_BASE = BLOCK_OC_BASE + WARP_OC * WMMA_M; + const int NOHOW_BASE = BLOCK_NOHOW_BASE + WARP_NOHOW * WMMA_N; acc[i].store_result(OC_BASE, NOHOW_BASE, Out, P); } } template class mma> -static void conv2d_cuda(const float * X_D, const T * K_D, float * Y_D, const conv_params P, cudaStream_t st) { +static void conv2d_cuda(const float * X_D, const T * K_D, float * Y_D, const conv_params & P, cudaStream_t st) { GGML_ASSERT(BS_OC >= WMMA_M && BS_ICKHKW >= WMMA_K && BS_NOHOW >= WMMA_N); + GGML_ASSERT(BS_ICKHKW % WMMA_K == 0); - const uint NUM_BL_OC = (P.OC + BS_OC - 1) / BS_OC; - const uint NUM_BL_NOHOW = 
(P.N_OH_OW + BS_NOHOW - 1) / BS_NOHOW;
-    const uint NUM_BL       = NUM_BL_OC * NUM_BL_NOHOW;
+    const int NUM_BL_OC    = (P.OC + BS_OC - 1) / BS_OC;
+    const int NUM_BL_NOHOW = (P.N_OH_OW + BS_NOHOW - 1) / BS_NOHOW;
+    const int NUM_BL       = NUM_BL_OC * NUM_BL_NOHOW;
-    constexpr uint NUM_WARPS = (CUDA_CONV2D_BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE;
+    constexpr int NUM_WARPS = (CUDA_CONV2D_BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE;
     const size_t A_bytes = BS_OC * BS_ICKHKW * sizeof(T);
     const size_t B_bytes = BS_ICKHKW * BS_NOHOW * sizeof(T);
@@ -402,28 +401,28 @@ void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     cudaStream_t st = ctx.stream();
     const int32_t * p = (const int32_t *) dst->op_params;
-    const uint ST_X = p[0];  // stride_x
-    const uint ST_Y = p[1];  // stride_y
-    const uint PD_X = p[2];  // padding_x
-    const uint PD_Y = p[3];  // padding_y
-    const uint DL_X = p[4];  // dilation_x
-    const uint DL_Y = p[5];  // dilation_y
+    const int ST_X = p[0];  // stride_x
+    const int ST_Y = p[1];  // stride_y
+    const int PD_X = p[2];  // padding_x
+    const int PD_Y = p[3];  // padding_y
+    const int DL_X = p[4];  // dilation_x
+    const int DL_Y = p[5];  // dilation_y
     // No cwhn
     GGML_ASSERT(p[6] == false);
-    const uint IW = input->ne[0];   // input_w
-    const uint IH = input->ne[1];   // input_h
-    const uint OW = dst->ne[0];     // output_w
-    const uint OH = dst->ne[1];     // output_h
-    const uint KW = kernel->ne[0];  // kernel_w
-    const uint KH = kernel->ne[1];  // kernel_h
-    const uint IC = input->ne[2];   // input_channels
-    const uint OC = kernel->ne[3];  // ouptut_chanles
-    const uint B  = input->ne[3];   // n_batches
-
-    const uint IC_KH_KW = IC * KH * KW;
-    const uint N_OH_OW  = B * OH * OW;
+    const int IW = input->ne[0];   // input_w
+    const int IH = input->ne[1];   // input_h
+    const int OW = dst->ne[0];     // output_w
+    const int OH = dst->ne[1];     // output_h
+    const int KW = kernel->ne[0];  // kernel_w
+    const int KH = kernel->ne[1];  // kernel_h
+    const int IC = input->ne[2];   // input_channels
+    const int OC = kernel->ne[3];  // output_channels
+    const int B  = input->ne[3];   // n_batches
+
+    const int IC_KH_KW = IC * KH * KW;
+    const int N_OH_OW  = B * OH * OW;
     const conv_params params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X,
                                  PD_Y, DL_X, DL_Y, IC, OC, B, IC_KH_KW, N_OH_OW };