Skip to content

Commit 126737c

Browse files
committed
Change ops::cast to get rid of cast_float_fallback
1 parent 09dc820 commit 126737c

File tree

2 files changed

+257
-202
lines changed

2 files changed

+257
-202
lines changed

include/kernel_float/unops.h

+32 -26
Original file line number | Diff line number | Diff line change
@@ -20,43 +20,49 @@ struct cast<T, T, m> {
2020
};
2121

2222
template<typename T>
23-
struct cast<T, T, RoundingMode::ANY> {
23+
struct cast<T, T> {
2424
KERNEL_FLOAT_INLINE T operator()(T input) noexcept {
2525
return input;
2626
}
2727
};
2828

29-
template<typename T, typename R, typename = void>
30-
struct cast_float_fallback;
31-
32-
template<typename T, typename R, typename>
33-
struct cast_float_fallback {
29+
template<typename T, typename R>
30+
struct cast<T, R> {
3431
KERNEL_FLOAT_INLINE R operator()(T input) noexcept {
35-
return R(input);
32+
if constexpr (
33+
detail::allow_float_fallback<T>::value || detail::allow_float_fallback<R>::value) {
34+
return cast<float, R> {}(cast<T, float> {}(input));
35+
} else {
36+
return R(input);
37+
}
3638
}
3739
};
3840

39-
// clang-format off
40-
template<typename T, typename R>
41-
struct cast_float_fallback<
42-
T,
43-
R,
44-
enable_if_t<
45-
!is_same_type<T, float> &&
46-
!is_same_type<R, float> &&
47-
(detail::allow_float_fallback<T>::value || detail::allow_float_fallback<R>::value)
48-
>
49-
> {
50-
KERNEL_FLOAT_INLINE R operator()(T input) noexcept {
51-
return cast<float, R> {}(cast<T, float> {}(input));
41+
template<>
42+
struct cast<float, float> {
43+
KERNEL_FLOAT_INLINE float operator()(float input) noexcept {
44+
return input;
5245
}
5346
};
54-
// clang-format on
5547

56-
template<typename T, typename R>
57-
struct cast<T, R, RoundingMode::ANY> {
58-
KERNEL_FLOAT_INLINE R operator()(T input) noexcept {
59-
return cast_float_fallback<T, R> {}(input);
48+
template<RoundingMode m>
49+
struct cast<float, float, m> {
50+
KERNEL_FLOAT_INLINE float operator()(float input) noexcept {
51+
return input;
52+
}
53+
};
54+
55+
template<typename T>
56+
struct cast<T, float> {
57+
KERNEL_FLOAT_INLINE float operator()(T input) noexcept {
58+
return float(input);
59+
}
60+
};
61+
62+
template<typename T>
63+
struct cast<float, T> {
64+
KERNEL_FLOAT_INLINE T operator()(float input) noexcept {
65+
return T(input);
6066
}
6167
};
6268

@@ -255,7 +261,7 @@ KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, tan, __tanf(input))
255261
}
256262

257263
KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(double, rcp, "rcp.approx.ftz.f64", "d")
258-
KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(double, rsqrt, "rsqrt.approx.f64", "d")
264+
KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(double, rsqrt, "rsqrt.approx.ftz.f64", "d")
259265

260266
KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, exp2, "ex2.approx.f32", "f")
261267
KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, sqrt, "sqrt.approx.f32", "f")

0 commit comments

Comments
 (0)