
Commit 2cd9fb0

CUDA: conv2d support fp16 without wmma
* removed flash-attention definition
1 parent 96db627 commit 2cd9fb0

1 file changed: +51 -3 lines changed


ggml/src/ggml-cuda/conv2d.cu

Lines changed: 51 additions & 3 deletions
@@ -1,9 +1,19 @@
 #include "conv2d.cuh"
 #include "convert.cuh"
 
-#include <mma.h>
-using namespace nvcuda;
-
+#ifdef FP16_MMA_AVAILABLE
+#    if !defined(GGML_USE_HIP)
+#        include <mma.h>
+#        ifdef GGML_USE_MUSA
+namespace wmma = mtmusa::wmma;
+#        else
+namespace wmma = nvcuda::wmma;
+#        endif
+#    else
+#        include <rocwmma/rocwmma.hpp>
+namespace wmma = rocwmma;
+#    endif
+#endif
 struct conv_params {
     const int64_t IW, IH;
     const int64_t OW, OH;
@@ -111,6 +121,8 @@ class float_mma {
     __device__ __forceinline__ float * store_result() const { return buf; }
 };
 
+#if (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(FP16_MMA_AVAILABLE)))
+
 class half_mma {
   private:
     wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> acc;
@@ -136,6 +148,42 @@ class half_mma {
     }
 };
 
+#else
+
+class half_mma {
+  public:
+    float * buf;
+
+    __device__ __forceinline__ half_mma(float * scratch) {
+        buf = scratch;
+        const int lane_id = threadIdx.x % warpSize;
+#pragma unroll
+        for (int i = lane_id; i < WMMA_M * WMMA_N; i += warpSize) {
+            buf[i] = 0.0f;
+        }
+    }
+
+    __device__ __forceinline__ void mma(const half * A_sh, const half * B_sh, const int strideA, const int strideB) {
+        const int lane_id = threadIdx.x % warpSize;
+#pragma unroll
+        for (int e = lane_id; e < (WMMA_M * WMMA_N); e += warpSize) {
+            int m = e / WMMA_N;
+            int n = e % WMMA_N;
+            float sum = buf[m * WMMA_N + n];
+#pragma unroll
+            for (int k = 0; k < WMMA_K; k++) {
+                float a = A_sh[m * strideA + k];
+                float b = B_sh[k * strideB + n];
+                sum = fmaf(__half2float(a), __half2float(b), sum);
+            }
+            buf[m * WMMA_N + n] = sum;
+        }
+    }
+
+    __device__ __forceinline__ float * store_result() const { return buf; }
+};
+#endif // defined((__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || defined(FP16_MMA_AVAILABLE))
+
 template <typename T, typename layout, typename mma>
 static __global__ void conv2d_kernel(const float * IN, const T * IK, float * OUT, const conv_params P) {
     extern __shared__ unsigned char smem_raw[];
0 commit comments
