Skip to content

Commit 9b41485

Browse files
committed
Move cast definitions to resolve NVRTC error 'explicit specialization of class must precede its first use'
1 parent 81efb0f commit 9b41485

File tree

3 files changed

+216
-216
lines changed

3 files changed

+216
-216
lines changed

include/kernel_float/bf16.h

Lines changed: 50 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,56 @@ struct allow_float_fallback<bfloat16_t> {
5656
};
5757
}; // namespace detail
5858

59+
// Generates the two `ops::cast` specializations that convert between `T` and `bfloat16_t`.
//   T         - the other type taking part in the conversion.
//   TO_HALF   - expression over `input` (a `T`) returned as the `bfloat16_t` result
//               (despite the name, the target here is bfloat16, not half).
//   FROM_HALF - expression over `input` (a `bfloat16_t`) returned as the `T` result.
// Per the commit message, these definitions are placed this early in the header to avoid the
// NVRTC error "explicit specialization of class must precede its first use".
#define KERNEL_FLOAT_BF16_CAST(T, TO_HALF, FROM_HALF) \
60+
namespace ops { \
61+
template<> \
62+
struct cast<T, bfloat16_t> { \
63+
KERNEL_FLOAT_INLINE bfloat16_t operator()(T input) { \
64+
return TO_HALF; \
65+
} \
66+
}; \
67+
template<> \
68+
struct cast<bfloat16_t, T> { \
69+
KERNEL_FLOAT_INLINE T operator()(bfloat16_t input) { \
70+
return FROM_HALF; \
71+
} \
72+
}; \
73+
}
74+
75+
// Floating-point casts. These are not guarded by KERNEL_FLOAT_BF16_OPS_AVAILABLE.
// The `double` instantiation reuses `__bfloat162float` for the bfloat16 -> double direction;
// the value is converted to `double` by the `return` of the generated `operator()`.
KERNEL_FLOAT_BF16_CAST(float, __float2bfloat16(input), __bfloat162float(input))
76+
KERNEL_FLOAT_BF16_CAST(double, __double2bfloat16(input), __bfloat162float(input))
77+
78+
// Integer <-> bfloat16 casts, compiled only when KERNEL_FLOAT_BF16_OPS_AVAILABLE is set.
// Integer -> bfloat16 uses the `*_rn` intrinsics and bfloat16 -> integer uses the `*_rz`
// intrinsics (in CUDA naming: round-to-nearest-even and round-toward-zero respectively).
// `long` / `unsigned long` are routed through the `long long` intrinsics and cast back.
#if KERNEL_FLOAT_BF16_OPS_AVAILABLE
79+
// clang-format off
80+
// There are no official char intrinsics, so the char casts go through int and are then
// narrowed to the requested char type.
81+
KERNEL_FLOAT_BF16_CAST(char, __int2bfloat16_rn(input), (char)__bfloat162int_rz(input));
82+
KERNEL_FLOAT_BF16_CAST(signed char, __int2bfloat16_rn(input), (signed char)__bfloat162int_rz(input));
83+
KERNEL_FLOAT_BF16_CAST(unsigned char, __int2bfloat16_rn(input), (unsigned char)__bfloat162int_rz(input));
84+
85+
KERNEL_FLOAT_BF16_CAST(signed short, __short2bfloat16_rn(input), __bfloat162short_rz(input));
86+
KERNEL_FLOAT_BF16_CAST(signed int, __int2bfloat16_rn(input), __bfloat162int_rz(input));
87+
KERNEL_FLOAT_BF16_CAST(signed long, __ll2bfloat16_rn(input), (signed long)(__bfloat162ll_rz(input)));
88+
KERNEL_FLOAT_BF16_CAST(signed long long, __ll2bfloat16_rn(input), __bfloat162ll_rz(input));
89+
90+
KERNEL_FLOAT_BF16_CAST(unsigned short, __ushort2bfloat16_rn(input), __bfloat162ushort_rz(input));
91+
KERNEL_FLOAT_BF16_CAST(unsigned int, __uint2bfloat16_rn(input), __bfloat162uint_rz(input));
92+
KERNEL_FLOAT_BF16_CAST(unsigned long, __ull2bfloat16_rn(input), (unsigned long)(__bfloat162ull_rz(input)));
93+
KERNEL_FLOAT_BF16_CAST(unsigned long long, __ull2bfloat16_rn(input), __bfloat162ull_rz(input));
94+
// clang-format on
95+
#endif
96+
97+
// Conversions between `bool` and `bfloat16_t`, expressed on the raw 16-bit pattern.
//   bool -> bfloat16_t: `true` becomes 1.0 (0x3F80: sign 0, biased exponent 127, mantissa 0)
//                       and `false` becomes +0.0 (0x0000).
//   bfloat16_t -> bool: masks off the sign bit (0x7FFF) so that both +0.0 and -0.0 give
//                       `false`; every other bit pattern gives `true`.
// Fix: the earlier expression had the ternary branches swapped (`true` produced 0) and used
// 0x3C00, which is 1.0 in IEEE half precision but 2^-7 when read as bfloat16.
#if KERNEL_FLOAT_IS_CUDA
98+
// The CUDA variant is left disabled, as before; only its constants are corrected.
//KERNEL_FLOAT_BF16_CAST(
99+
// bool,
100+
// __nv_bfloat16_raw {input ? (unsigned short)0x3F80 : (unsigned short)0},
101+
// (__nv_bfloat16_raw(input).x & 0x7FFF) != 0);
102+
#elif KERNEL_FLOAT_IS_HIP
103+
KERNEL_FLOAT_BF16_CAST(
104+
bool,
105+
__ushort_as_bfloat16(input ? (unsigned short)0x3F80 : (unsigned short)0),
106+
(__bfloat16_as_ushort(input) & 0x7FFF) != 0);
107+
#endif
108+
59109
#define KERNEL_FLOAT_BF16_UNARY_FUN(NAME, FUN1, FUN2) \
60110
namespace ops { \
61111
template<> \
@@ -220,56 +270,6 @@ KERNEL_FLOAT_FAST_F32_MAP(KERNEL_FLOAT_FAST_BF16_DISPATCH)
220270
} // namespace detail
221271
#endif
222272

223-
// Old placement of the cast macro (the `-` side of this diff; the commit moves it earlier in
// the header). Generates the two `ops::cast` specializations between `T` and `bfloat16_t`.
//   T         - the other type taking part in the conversion.
//   TO_HALF   - expression over `input` (a `T`) returned as the `bfloat16_t` result.
//   FROM_HALF - expression over `input` (a `bfloat16_t`) returned as the `T` result.
#define KERNEL_FLOAT_BF16_CAST(T, TO_HALF, FROM_HALF) \
224-
namespace ops { \
225-
template<> \
226-
struct cast<T, bfloat16_t> { \
227-
KERNEL_FLOAT_INLINE bfloat16_t operator()(T input) { \
228-
return TO_HALF; \
229-
} \
230-
}; \
231-
template<> \
232-
struct cast<bfloat16_t, T> { \
233-
KERNEL_FLOAT_INLINE T operator()(bfloat16_t input) { \
234-
return FROM_HALF; \
235-
} \
236-
}; \
237-
}
238-
239-
// Floating-point casts (old placement, `-` side of this diff).
// The `double` instantiation reuses `__bfloat162float` for the bfloat16 -> double direction.
KERNEL_FLOAT_BF16_CAST(float, __float2bfloat16(input), __bfloat162float(input))
240-
KERNEL_FLOAT_BF16_CAST(double, __double2bfloat16(input), __bfloat162float(input))
241-
242-
// Integer <-> bfloat16 casts (old placement, `-` side of this diff), compiled only when
// KERNEL_FLOAT_BF16_OPS_AVAILABLE is set. Integer -> bfloat16 uses the `*_rn` intrinsics,
// bfloat16 -> integer uses the `*_rz` intrinsics; `long` goes through the `long long` variants.
#if KERNEL_FLOAT_BF16_OPS_AVAILABLE
243-
// clang-format off
244-
// There are no official char intrinsics, so the char casts go through int and are then
// narrowed to the requested char type.
245-
KERNEL_FLOAT_BF16_CAST(char, __int2bfloat16_rn(input), (char)__bfloat162int_rz(input));
246-
KERNEL_FLOAT_BF16_CAST(signed char, __int2bfloat16_rn(input), (signed char)__bfloat162int_rz(input));
247-
KERNEL_FLOAT_BF16_CAST(unsigned char, __int2bfloat16_rn(input), (unsigned char)__bfloat162int_rz(input));
248-
249-
KERNEL_FLOAT_BF16_CAST(signed short, __short2bfloat16_rn(input), __bfloat162short_rz(input));
250-
KERNEL_FLOAT_BF16_CAST(signed int, __int2bfloat16_rn(input), __bfloat162int_rz(input));
251-
KERNEL_FLOAT_BF16_CAST(signed long, __ll2bfloat16_rn(input), (signed long)(__bfloat162ll_rz(input)));
252-
KERNEL_FLOAT_BF16_CAST(signed long long, __ll2bfloat16_rn(input), __bfloat162ll_rz(input));
253-
254-
KERNEL_FLOAT_BF16_CAST(unsigned short, __ushort2bfloat16_rn(input), __bfloat162ushort_rz(input));
255-
KERNEL_FLOAT_BF16_CAST(unsigned int, __uint2bfloat16_rn(input), __bfloat162uint_rz(input));
256-
KERNEL_FLOAT_BF16_CAST(unsigned long, __ull2bfloat16_rn(input), (unsigned long)(__bfloat162ull_rz(input)));
257-
KERNEL_FLOAT_BF16_CAST(unsigned long long, __ull2bfloat16_rn(input), __bfloat162ull_rz(input));
258-
// clang-format on
259-
#endif
260-
261-
// Conversions between `bool` and `bfloat16_t` (old placement, `-` side of this diff),
// expressed on the raw 16-bit pattern.
//   bool -> bfloat16_t: `true` becomes 1.0 (0x3F80: sign 0, biased exponent 127, mantissa 0)
//                       and `false` becomes +0.0 (0x0000).
//   bfloat16_t -> bool: masks off the sign bit (0x7FFF) so that both +0.0 and -0.0 give
//                       `false`; every other bit pattern gives `true`.
// Fix: the earlier expression had the ternary branches swapped (`true` produced 0) and used
// 0x3C00, which is 1.0 in IEEE half precision but 2^-7 when read as bfloat16.
#if KERNEL_FLOAT_IS_CUDA
262-
// The CUDA variant is left disabled, as before; only its constants are corrected.
//KERNEL_FLOAT_BF16_CAST(
263-
// bool,
264-
// __nv_bfloat16_raw {input ? (unsigned short)0x3F80 : (unsigned short)0},
265-
// (__nv_bfloat16_raw(input).x & 0x7FFF) != 0);
266-
#elif KERNEL_FLOAT_IS_HIP
267-
KERNEL_FLOAT_BF16_CAST(
268-
bool,
269-
__ushort_as_bfloat16(input ? (unsigned short)0x3F80 : (unsigned short)0),
270-
(__bfloat16_as_ushort(input) & 0x7FFF) != 0);
271-
#endif
272-
273273
// Invokes KERNEL_FLOAT_VECTOR_ALIAS with the name `bfloat16x` for element type `bfloat16_t`.
// The macro is defined elsewhere; presumably it declares the vector type aliases - confirm there.
KERNEL_FLOAT_VECTOR_ALIAS(bfloat16x, bfloat16_t)
274274
// Disabled aliases.
// NOTE(review): the names `float16x` / `f16x` look copied from the fp16 header - confirm
// before re-enabling for bfloat16_t.
//KERNEL_FLOAT_TYPE_ALIAS(float16x, bfloat16_t)
275275
//KERNEL_FLOAT_TYPE_ALIAS(f16x, bfloat16_t)

include/kernel_float/fp16.h

Lines changed: 57 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,63 @@ struct allow_float_fallback<half_t> {
4949
};
5050
} // namespace detail
5151

52+
// Generates the two `ops::cast` specializations that convert between `T` and `half_t`.
//   T         - the other type taking part in the conversion.
//   TO_HALF   - expression over `input` (a `T`) returned as the `half_t` result.
//   FROM_HALF - expression over `input` (a `half_t`) returned as the `T` result.
// Per the commit message, these definitions are placed this early in the header to avoid the
// NVRTC error "explicit specialization of class must precede its first use".
#define KERNEL_FLOAT_FP16_CAST(T, TO_HALF, FROM_HALF) \
53+
namespace ops { \
54+
template<> \
55+
struct cast<T, half_t> { \
56+
KERNEL_FLOAT_INLINE half_t operator()(T input) { \
57+
return TO_HALF; \
58+
} \
59+
}; \
60+
template<> \
61+
struct cast<half_t, T> { \
62+
KERNEL_FLOAT_INLINE T operator()(half_t input) { \
63+
return FROM_HALF; \
64+
} \
65+
}; \
66+
}
67+
68+
// Instantiates the half_t casts for double, float and the integer types.
// Two implementations are selected by KERNEL_FLOAT_IS_HIP:
//   - HIP: every cast is a plain `static_cast` (to `_Float16` for T -> half_t, to `T` for
//     half_t -> T), generated through the helper macro KERNEL_FLOAT_FP16_CAST_FWD.
//   - otherwise: the `__*2half*` / `__half2*` intrinsics are used.
// Only CUDA has a special `__double2half` intrinsic
69+
#if KERNEL_FLOAT_IS_HIP
70+
// Helper: forwards `T` to KERNEL_FLOAT_FP16_CAST with static_cast in both directions.
// Note that this helper macro is not #undef'd within the lines shown here.
#define KERNEL_FLOAT_FP16_CAST_FWD(T) \
71+
KERNEL_FLOAT_FP16_CAST(T, static_cast<_Float16>(input), static_cast<T>(input))
72+
73+
KERNEL_FLOAT_FP16_CAST_FWD(double)
74+
KERNEL_FLOAT_FP16_CAST_FWD(float)
75+
76+
KERNEL_FLOAT_FP16_CAST_FWD(char)
77+
KERNEL_FLOAT_FP16_CAST_FWD(signed char)
78+
KERNEL_FLOAT_FP16_CAST_FWD(unsigned char)
79+
80+
KERNEL_FLOAT_FP16_CAST_FWD(signed short)
81+
KERNEL_FLOAT_FP16_CAST_FWD(signed int)
82+
KERNEL_FLOAT_FP16_CAST_FWD(signed long)
83+
KERNEL_FLOAT_FP16_CAST_FWD(signed long long)
84+
85+
KERNEL_FLOAT_FP16_CAST_FWD(unsigned short)
86+
KERNEL_FLOAT_FP16_CAST_FWD(unsigned int)
87+
KERNEL_FLOAT_FP16_CAST_FWD(unsigned long)
88+
KERNEL_FLOAT_FP16_CAST_FWD(unsigned long long)
89+
#else
90+
// half_t -> double goes through `__half2float` and is widened explicitly.
KERNEL_FLOAT_FP16_CAST(double, __double2half(input), double(__half2float(input)));
91+
KERNEL_FLOAT_FP16_CAST(float, __float2half(input), __half2float(input));
92+
93+
// There are no official char intrinsics, so the char casts go through int and are then
// narrowed to the requested char type.
94+
KERNEL_FLOAT_FP16_CAST(char, __int2half_rn(input), (char)__half2int_rz(input));
95+
KERNEL_FLOAT_FP16_CAST(signed char, __int2half_rn(input), (signed char)__half2int_rz(input));
96+
KERNEL_FLOAT_FP16_CAST(unsigned char, __int2half_rn(input), (unsigned char)__half2int_rz(input));
97+
98+
KERNEL_FLOAT_FP16_CAST(signed short, __short2half_rn(input), __half2short_rz(input));
99+
KERNEL_FLOAT_FP16_CAST(signed int, __int2half_rn(input), __half2int_rz(input));
100+
KERNEL_FLOAT_FP16_CAST(signed long, __ll2half_rn(input), (signed long)(__half2ll_rz(input)));
101+
KERNEL_FLOAT_FP16_CAST(signed long long, __ll2half_rn(input), __half2ll_rz(input));
102+
103+
KERNEL_FLOAT_FP16_CAST(unsigned short, __ushort2half_rn(input), __half2ushort_rz(input));
104+
KERNEL_FLOAT_FP16_CAST(unsigned int, __uint2half_rn(input), __half2uint_rz(input));
105+
KERNEL_FLOAT_FP16_CAST(unsigned long, __ull2half_rn(input), (unsigned long)(__half2ull_rz(input)));
106+
KERNEL_FLOAT_FP16_CAST(unsigned long long, __ull2half_rn(input), __half2ull_rz(input));
107+
#endif
108+
52109
#if KERNEL_FLOAT_IS_DEVICE
53110
#define KERNEL_FLOAT_FP16_UNARY_FUN(NAME, FUN1, FUN2) \
54111
namespace ops { \
@@ -179,63 +236,6 @@ KERNEL_FLOAT_FAST_F32_MAP(KERNEL_FLOAT_FAST_FP16_DISPATCH)
179236
#endif // KERNEL_FLOAT_IS_DEVICE
180237
#endif //KERNEL_FLOAT_FP16_OPS_AVAILABLE
181238

182-
// Old placement of the cast macro (the `-` side of this diff; the commit moves it earlier in
// the header). Generates the two `ops::cast` specializations between `T` and `half_t`.
//   T         - the other type taking part in the conversion.
//   TO_HALF   - expression over `input` (a `T`) returned as the `half_t` result.
//   FROM_HALF - expression over `input` (a `half_t`) returned as the `T` result.
#define KERNEL_FLOAT_FP16_CAST(T, TO_HALF, FROM_HALF) \
183-
namespace ops { \
184-
template<> \
185-
struct cast<T, half_t> { \
186-
KERNEL_FLOAT_INLINE half_t operator()(T input) { \
187-
return TO_HALF; \
188-
} \
189-
}; \
190-
template<> \
191-
struct cast<half_t, T> { \
192-
KERNEL_FLOAT_INLINE T operator()(half_t input) { \
193-
return FROM_HALF; \
194-
} \
195-
}; \
196-
}
197-
198-
// Old placement of the half_t cast instantiations (the `-` side of this diff).
// Two implementations are selected by KERNEL_FLOAT_IS_HIP:
//   - HIP: every cast is a plain `static_cast`, generated through KERNEL_FLOAT_FP16_CAST_FWD.
//   - otherwise: the `__*2half*` / `__half2*` intrinsics are used.
// Only CUDA has a special `__double2half` intrinsic
199-
#if KERNEL_FLOAT_IS_HIP
200-
// Helper: forwards `T` to KERNEL_FLOAT_FP16_CAST with static_cast in both directions.
#define KERNEL_FLOAT_FP16_CAST_FWD(T) \
201-
KERNEL_FLOAT_FP16_CAST(T, static_cast<_Float16>(input), static_cast<T>(input))
202-
203-
KERNEL_FLOAT_FP16_CAST_FWD(double)
204-
KERNEL_FLOAT_FP16_CAST_FWD(float)
205-
206-
KERNEL_FLOAT_FP16_CAST_FWD(char)
207-
KERNEL_FLOAT_FP16_CAST_FWD(signed char)
208-
KERNEL_FLOAT_FP16_CAST_FWD(unsigned char)
209-
210-
KERNEL_FLOAT_FP16_CAST_FWD(signed short)
211-
KERNEL_FLOAT_FP16_CAST_FWD(signed int)
212-
KERNEL_FLOAT_FP16_CAST_FWD(signed long)
213-
KERNEL_FLOAT_FP16_CAST_FWD(signed long long)
214-
215-
KERNEL_FLOAT_FP16_CAST_FWD(unsigned short)
216-
KERNEL_FLOAT_FP16_CAST_FWD(unsigned int)
217-
KERNEL_FLOAT_FP16_CAST_FWD(unsigned long)
218-
KERNEL_FLOAT_FP16_CAST_FWD(unsigned long long)
219-
#else
220-
// half_t -> double goes through `__half2float` and is widened explicitly.
KERNEL_FLOAT_FP16_CAST(double, __double2half(input), double(__half2float(input)));
221-
KERNEL_FLOAT_FP16_CAST(float, __float2half(input), __half2float(input));
222-
223-
// There are no official char intrinsics, so the char casts go through int and are then
// narrowed to the requested char type.
224-
KERNEL_FLOAT_FP16_CAST(char, __int2half_rn(input), (char)__half2int_rz(input));
225-
KERNEL_FLOAT_FP16_CAST(signed char, __int2half_rn(input), (signed char)__half2int_rz(input));
226-
KERNEL_FLOAT_FP16_CAST(unsigned char, __int2half_rn(input), (unsigned char)__half2int_rz(input));
227-
228-
KERNEL_FLOAT_FP16_CAST(signed short, __short2half_rn(input), __half2short_rz(input));
229-
KERNEL_FLOAT_FP16_CAST(signed int, __int2half_rn(input), __half2int_rz(input));
230-
KERNEL_FLOAT_FP16_CAST(signed long, __ll2half_rn(input), (signed long)(__half2ll_rz(input)));
231-
KERNEL_FLOAT_FP16_CAST(signed long long, __ll2half_rn(input), __half2ll_rz(input));
232-
233-
KERNEL_FLOAT_FP16_CAST(unsigned short, __ushort2half_rn(input), __half2ushort_rz(input));
234-
KERNEL_FLOAT_FP16_CAST(unsigned int, __uint2half_rn(input), __half2uint_rz(input));
235-
KERNEL_FLOAT_FP16_CAST(unsigned long, __ull2half_rn(input), (unsigned long)(__half2ull_rz(input)));
236-
KERNEL_FLOAT_FP16_CAST(unsigned long long, __ull2half_rn(input), __half2ull_rz(input));
237-
#endif
238-
239239
// Invokes KERNEL_FLOAT_VECTOR_ALIAS with the name `half` for element type `half_t`.
// The macro is defined elsewhere; presumably it declares the vector type aliases - confirm there.
KERNEL_FLOAT_VECTOR_ALIAS(half, half_t)
240240
// Disabled alternative alias names for half_t.
//KERNEL_FLOAT_TYPE_ALIAS(float16x, half_t)
241241
//KERNEL_FLOAT_TYPE_ALIAS(f16x, half_t)

0 commit comments

Comments
 (0)