@@ -20,43 +20,49 @@ struct cast<T, T, m> {
20
20
};
21
21
22
22
template <typename T>
23
- struct cast <T, T, RoundingMode::ANY > {
23
+ struct cast <T, T> {
24
24
KERNEL_FLOAT_INLINE T operator ()(T input) noexcept {
25
25
return input;
26
26
}
27
27
};
28
28
29
- template <typename T, typename R, typename = void >
30
- struct cast_float_fallback ;
31
-
32
- template <typename T, typename R, typename >
33
- struct cast_float_fallback {
29
+ template <typename T, typename R>
30
+ struct cast <T, R> {
34
31
KERNEL_FLOAT_INLINE R operator ()(T input) noexcept {
35
- return R (input);
32
+ if constexpr (
33
+ detail::allow_float_fallback<T>::value || detail::allow_float_fallback<R>::value) {
34
+ return cast<float , R> {}(cast<T, float > {}(input));
35
+ } else {
36
+ return R (input);
37
+ }
36
38
}
37
39
};
38
40
39
- // clang-format off
40
- template <typename T, typename R>
41
- struct cast_float_fallback <
42
- T,
43
- R,
44
- enable_if_t <
45
- !is_same_type<T, float > &&
46
- !is_same_type<R, float > &&
47
- (detail::allow_float_fallback<T>::value || detail::allow_float_fallback<R>::value)
48
- >
49
- > {
50
- KERNEL_FLOAT_INLINE R operator ()(T input) noexcept {
51
- return cast<float , R> {}(cast<T, float > {}(input));
41
+ template <>
42
+ struct cast <float , float > {
43
+ KERNEL_FLOAT_INLINE float operator ()(float input) noexcept {
44
+ return input;
52
45
}
53
46
};
54
- // clang-format on
55
47
56
- template <typename T, typename R>
57
- struct cast <T, R, RoundingMode::ANY> {
58
- KERNEL_FLOAT_INLINE R operator ()(T input) noexcept {
59
- return cast_float_fallback<T, R> {}(input);
48
+ template <RoundingMode m>
49
+ struct cast <float , float , m> {
50
+ KERNEL_FLOAT_INLINE float operator ()(float input) noexcept {
51
+ return input;
52
+ }
53
+ };
54
+
55
+ template <typename T>
56
+ struct cast <T, float > {
57
+ KERNEL_FLOAT_INLINE float operator ()(T input) noexcept {
58
+ return float (input);
59
+ }
60
+ };
61
+
62
+ template <typename T>
63
+ struct cast <float , T> {
64
+ KERNEL_FLOAT_INLINE T operator ()(float input) noexcept {
65
+ return T (input);
60
66
}
61
67
};
62
68
@@ -255,7 +261,7 @@ KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, tan, __tanf(input))
255
261
}
256
262
257
263
KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX (double , rcp, " rcp.approx.ftz.f64" , " d" )
258
- KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(double , rsqrt, " rsqrt.approx.f64" , " d" )
264
+ KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(double , rsqrt, " rsqrt.approx.ftz. f64" , " d" )
259
265
260
266
KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float , exp2, " ex2.approx.f32" , " f" )
261
267
KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float , sqrt, " sqrt.approx.f32" , " f" )
0 commit comments