@@ -49,6 +49,63 @@ struct allow_float_fallback<half_t> {
4949};
5050} // namespace detail
5151
52+ #define KERNEL_FLOAT_FP16_CAST (T, TO_HALF, FROM_HALF ) \
53+ namespace ops { \
54+ template <> \
55+ struct cast <T, half_t > { \
56+ KERNEL_FLOAT_INLINE half_t operator ()(T input) { \
57+ return TO_HALF; \
58+ } \
59+ }; \
60+ template <> \
61+ struct cast <half_t , T> { \
62+ KERNEL_FLOAT_INLINE T operator ()(half_t input) { \
63+ return FROM_HALF; \
64+ } \
65+ }; \
66+ }
67+
68+ // Only CUDA has a special `__double2half` intrinsic
69+ #if KERNEL_FLOAT_IS_HIP
70+ #define KERNEL_FLOAT_FP16_CAST_FWD (T ) \
71+ KERNEL_FLOAT_FP16_CAST (T, static_cast <_Float16>(input), static_cast<T>(input))
72+
73+ KERNEL_FLOAT_FP16_CAST_FWD(double )
74+ KERNEL_FLOAT_FP16_CAST_FWD(float )
75+
76+ KERNEL_FLOAT_FP16_CAST_FWD(char )
77+ KERNEL_FLOAT_FP16_CAST_FWD(signed char )
78+ KERNEL_FLOAT_FP16_CAST_FWD(unsigned char )
79+
80+ KERNEL_FLOAT_FP16_CAST_FWD(signed short )
81+ KERNEL_FLOAT_FP16_CAST_FWD(signed int )
82+ KERNEL_FLOAT_FP16_CAST_FWD(signed long )
83+ KERNEL_FLOAT_FP16_CAST_FWD(signed long long )
84+
85+ KERNEL_FLOAT_FP16_CAST_FWD(unsigned short )
86+ KERNEL_FLOAT_FP16_CAST_FWD(unsigned int )
87+ KERNEL_FLOAT_FP16_CAST_FWD(unsigned long )
88+ KERNEL_FLOAT_FP16_CAST_FWD(unsigned long long )
89+ #else
90+ KERNEL_FLOAT_FP16_CAST (double , __double2half(input), double(__half2float(input)));
91+ KERNEL_FLOAT_FP16_CAST (float , __float2half(input), __half2float(input));
92+
93+ // there are no official char casts. Instead, cast to int and then to char
94+ KERNEL_FLOAT_FP16_CAST (char , __int2half_rn(input), (char )__half2int_rz(input));
95+ KERNEL_FLOAT_FP16_CAST (signed char , __int2half_rn(input), (signed char )__half2int_rz(input));
96+ KERNEL_FLOAT_FP16_CAST (unsigned char , __int2half_rn(input), (unsigned char )__half2int_rz(input));
97+
98+ KERNEL_FLOAT_FP16_CAST (signed short , __short2half_rn(input), __half2short_rz(input));
99+ KERNEL_FLOAT_FP16_CAST (signed int , __int2half_rn(input), __half2int_rz(input));
100+ KERNEL_FLOAT_FP16_CAST (signed long , __ll2half_rn(input), (signed long )(__half2ll_rz(input)));
101+ KERNEL_FLOAT_FP16_CAST (signed long long , __ll2half_rn(input), __half2ll_rz(input));
102+
103+ KERNEL_FLOAT_FP16_CAST (unsigned short , __ushort2half_rn(input), __half2ushort_rz(input));
104+ KERNEL_FLOAT_FP16_CAST (unsigned int , __uint2half_rn(input), __half2uint_rz(input));
105+ KERNEL_FLOAT_FP16_CAST (unsigned long , __ull2half_rn(input), (unsigned long )(__half2ull_rz(input)));
106+ KERNEL_FLOAT_FP16_CAST (unsigned long long , __ull2half_rn(input), __half2ull_rz(input));
107+ #endif
108+
52109#if KERNEL_FLOAT_IS_DEVICE
53110#define KERNEL_FLOAT_FP16_UNARY_FUN (NAME, FUN1, FUN2 ) \
54111 namespace ops { \
@@ -179,63 +236,6 @@ KERNEL_FLOAT_FAST_F32_MAP(KERNEL_FLOAT_FAST_FP16_DISPATCH)
179236#endif // KERNEL_FLOAT_IS_DEVICE
180237#endif // KERNEL_FLOAT_FP16_OPS_AVAILABLE
181238
182- #define KERNEL_FLOAT_FP16_CAST (T, TO_HALF, FROM_HALF ) \
183- namespace ops { \
184- template <> \
185- struct cast <T, half_t > { \
186- KERNEL_FLOAT_INLINE half_t operator ()(T input) { \
187- return TO_HALF; \
188- } \
189- }; \
190- template <> \
191- struct cast <half_t , T> { \
192- KERNEL_FLOAT_INLINE T operator ()(half_t input) { \
193- return FROM_HALF; \
194- } \
195- }; \
196- }
197-
198- // Only CUDA has a special `__double2half` intrinsic
199- #if KERNEL_FLOAT_IS_HIP
200- #define KERNEL_FLOAT_FP16_CAST_FWD (T ) \
201- KERNEL_FLOAT_FP16_CAST (T, static_cast <_Float16>(input), static_cast<T>(input))
202-
203- KERNEL_FLOAT_FP16_CAST_FWD(double )
204- KERNEL_FLOAT_FP16_CAST_FWD(float )
205-
206- KERNEL_FLOAT_FP16_CAST_FWD(char )
207- KERNEL_FLOAT_FP16_CAST_FWD(signed char )
208- KERNEL_FLOAT_FP16_CAST_FWD(unsigned char )
209-
210- KERNEL_FLOAT_FP16_CAST_FWD(signed short )
211- KERNEL_FLOAT_FP16_CAST_FWD(signed int )
212- KERNEL_FLOAT_FP16_CAST_FWD(signed long )
213- KERNEL_FLOAT_FP16_CAST_FWD(signed long long )
214-
215- KERNEL_FLOAT_FP16_CAST_FWD(unsigned short )
216- KERNEL_FLOAT_FP16_CAST_FWD(unsigned int )
217- KERNEL_FLOAT_FP16_CAST_FWD(unsigned long )
218- KERNEL_FLOAT_FP16_CAST_FWD(unsigned long long )
219- #else
220- KERNEL_FLOAT_FP16_CAST (double , __double2half(input), double(__half2float(input)));
221- KERNEL_FLOAT_FP16_CAST (float , __float2half(input), __half2float(input));
222-
223- // there are no official char casts. Instead, cast to int and then to char
224- KERNEL_FLOAT_FP16_CAST (char , __int2half_rn(input), (char )__half2int_rz(input));
225- KERNEL_FLOAT_FP16_CAST (signed char , __int2half_rn(input), (signed char )__half2int_rz(input));
226- KERNEL_FLOAT_FP16_CAST (unsigned char , __int2half_rn(input), (unsigned char )__half2int_rz(input));
227-
228- KERNEL_FLOAT_FP16_CAST (signed short , __short2half_rn(input), __half2short_rz(input));
229- KERNEL_FLOAT_FP16_CAST (signed int , __int2half_rn(input), __half2int_rz(input));
230- KERNEL_FLOAT_FP16_CAST (signed long , __ll2half_rn(input), (signed long )(__half2ll_rz(input)));
231- KERNEL_FLOAT_FP16_CAST (signed long long , __ll2half_rn(input), __half2ll_rz(input));
232-
233- KERNEL_FLOAT_FP16_CAST (unsigned short , __ushort2half_rn(input), __half2ushort_rz(input));
234- KERNEL_FLOAT_FP16_CAST (unsigned int , __uint2half_rn(input), __half2uint_rz(input));
235- KERNEL_FLOAT_FP16_CAST (unsigned long , __ull2half_rn(input), (unsigned long )(__half2ull_rz(input)));
236- KERNEL_FLOAT_FP16_CAST (unsigned long long , __ull2half_rn(input), __half2ull_rz(input));
237- #endif
238-
239239KERNEL_FLOAT_VECTOR_ALIAS (half, half_t )
240240// KERNEL_FLOAT_TYPE_ALIAS(float16x, half_t)
241241// KERNEL_FLOAT_TYPE_ALIAS(f16x, half_t)
0 commit comments