#include "conv2d.cuh"
#include "convert.cuh"

- #include <mma.h>
- using namespace nvcuda;
-
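+ // Map the wmma namespace onto the matching vendor implementation:
+ // nvcuda::wmma on CUDA, mtmusa::wmma on MUSA, rocwmma on HIP.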
+ #ifdef FP16_MMA_AVAILABLE
+ #    if !defined(GGML_USE_HIP)
+ #        include <mma.h>
+ #        ifdef GGML_USE_MUSA
+ namespace wmma = mtmusa::wmma;
+ #        else
+ namespace wmma = nvcuda::wmma;
+ #        endif
+ #    else
+ #        include <rocwmma/rocwmma.hpp>
+ namespace wmma = rocwmma;
+ #    endif
+ #endif
struct conv_params {
    const int64_t IW, IH;
    const int64_t OW, OH;
@@ -111,6 +121,8 @@ class float_mma {
    __device__ __forceinline__ float * store_result() const { return buf; }
};

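+ // Tensor-core half_mma: compiled for Volta or whenever FP16 MMA is available.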
+ #if (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || defined(FP16_MMA_AVAILABLE))
+
class half_mma {
  private:
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> acc;
@@ -136,6 +148,42 @@ class half_mma {
    }
};

+ #else
+
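+ // Scalar fallback: emulates the WMMA_M x WMMA_N x WMMA_K tile product with
+ // plain FMAs when no FP16 tensor-core path is available.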
+ class half_mma {
+   public:
+     float * buf;
+
+     __device__ __forceinline__ half_mma(float * scratch) {
+         buf = scratch;
+         const int lane_id = threadIdx.x % warpSize;
+ #pragma unroll
+         for (int i = lane_id; i < WMMA_M * WMMA_N; i += warpSize) {
+             buf[i] = 0.0f;
+         }
+     }
+
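+     // Each lane accumulates every warpSize-th element of the output tile,
+     // walking the full K dimension; results stay in the float scratch buffer.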
+     __device__ __forceinline__ void mma(const half * A_sh, const half * B_sh, const int strideA, const int strideB) {
+         const int lane_id = threadIdx.x % warpSize;
+ #pragma unroll
+         for (int e = lane_id; e < (WMMA_M * WMMA_N); e += warpSize) {
+             int m = e / WMMA_N;
+             int n = e % WMMA_N;
+             float sum = buf[m * WMMA_N + n];
+ #pragma unroll
+             for (int k = 0; k < WMMA_K; k++) {
+                 const half a = A_sh[m * strideA + k];
+                 const half b = B_sh[k * strideB + n];
+                 sum = fmaf(__half2float(a), __half2float(b), sum);
+             }
+             buf[m * WMMA_N + n] = sum;
+         }
+     }
+
+     __device__ __forceinline__ float * store_result() const { return buf; }
+ };
+ #endif // (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || defined(FP16_MMA_AVAILABLE))
+
template <typename T, typename layout, typename mma>
static __global__ void conv2d_kernel(const float * IN, const T * IK, float * OUT, const conv_params P) {
    extern __shared__ unsigned char smem_raw[];