Commit 5c859b9

kernel_float::approx::sqrt(0) now returns 0
1 parent e6c8a7c commit 5c859b9

2 files changed (+98 −37 lines)

Diff for: include/kernel_float/approx.h (+10 −5)
@@ -160,17 +160,20 @@ KERNEL_FLOAT_DEVICE half2_t rcp(half2_t x) {
 
 template<int Iter>
 KERNEL_FLOAT_DEVICE half2_t rsqrt(half2_t x) {
+    // A small number added such that rsqrt(0) does not return NaN
+    static constexpr double EPS = 0.00000768899917602539;
+
     // Set top and bottom bits for both halfs, then shift by 1, then invert
     uint32_t r = ~((uint32_t(transmute<uint32_t>(x) >> 1)) | ~uint32_t(0x3fff3fff));
-    //uint32_t r = uint32_t(~(transmute<uint32_t>(arg) | (~uint32_t(0x3ffe3ffe)))) >> 1;
 
-    // Add bias (0x199c)
-    half2_t y = transmute<half2_t>(uint32_t(r) + uint32_t(0x199c199c));
+    // Add bias
+    static constexpr uint32_t BIAS = 0x199c199c;
+    half2_t y = transmute<half2_t>(uint32_t(r) + BIAS);
 
     // Newton-Raphson iterations
 #pragma unroll
     for (int i = 0; i < Iter; i++) {
-        half2_t half_x = make_half2(-0.5) * x;
+        half2_t half_x = __hfma2(make_half2(-0.5), x, make_half2(-EPS));
         half2_t correction = __hfma2(half_x, y * y, make_half2(0.5));
         y = __hfma2(correction, y, y); // y += y * correction
     }
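
The -EPS term shifts the fixed point of the Newton-Raphson update from 1/sqrt(x) to 1/sqrt(x + 2*EPS), so rsqrt(0) settles on a large but finite value (roughly 255, well inside the fp16 range) instead of diverging to Inf/NaN. A minimal host-side sketch of that arithmetic (plain C++, not library code; the final sqrt-via-rsqrt step is an assumption based on the commit title and is not shown in this diff):

#include <cmath>
#include <cstdio>

int main() {
    const double EPS = 0.00000768899917602539;  // constant introduced above
    // Fixed point of  y += y * (0.5 + (-0.5 * x - EPS) * y * y)  is  y = 1 / sqrt(x + 2 * EPS).
    double rsqrt0 = 1.0 / std::sqrt(0.0 + 2.0 * EPS);
    std::printf("rsqrt(0) converges toward %.3f (finite, representable in fp16)\n", rsqrt0);
    // Assuming approx::sqrt(x) is evaluated as x * rsqrt(x), sqrt(0) becomes 0 * rsqrt(0) == 0
    // instead of 0 * Inf == NaN.
    std::printf("sqrt(0) ~= %.3f\n", 0.0 * rsqrt0);
    return 0;
}
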
@@ -365,7 +368,7 @@ template<int Level, typename F, typename T>
 struct apply_impl<approx_level_policy<Level>, F, 1, T, T> {
     KERNEL_FLOAT_INLINE static void call(F fun, T* output, const T* input) {
         T in2[2], out2[2];
-        out2[0] = input[0];
+        in2[0] = input[0];
         apply_impl<approx_level_policy<Level>, F, 2, T, T>::call(fun, out2, in2);
         output[0] = out2[0];
     }
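
The old line copied the scalar input into out2, the output scratch buffer, so the 2-wide call below consumed an uninitialized in2[0]; writing the input into in2[0] makes the N = 1 path actually feed its argument through the approximation.
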
@@ -396,6 +399,8 @@ KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, sqrt, 1)
 KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, rcp, 1)
 KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, exp, 0)
 KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, log, 0)
+KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, asin, 2)
+KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, acos, 2)
 #endif
 
 #if KERNEL_FLOAT_BF16_OPS_SUPPORTED
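
The two added lines register level-2 approximations of asin and acos for half_t next to the existing entries. Following the spelling used in the commit title, half-precision calls through the approx path (for example an approximate asin) should now resolve to these implementations rather than falling back to float; the exact call-site spelling is not part of this diff.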

Diff for: single_include/kernel_float.h (+88 −32)
@@ -16,8 +16,8 @@
 
 //================================================================================
 // this file has been auto-generated, do not modify its contents!
-// date: 2024-11-18 16:57:58.817191
-// git hash: 003ce3677ecb97dc1602e38a3e774c103d05aa1a
+// date: 2024-11-20 10:36:45.284577
+// git hash: 76501fda40df9e396998d11840bc8f10b11ea47b
 //================================================================================
 
 #ifndef KERNEL_FLOAT_MACROS_H
@@ -813,7 +813,7 @@ struct approx_level_policy {};
 using approx_policy = approx_level_policy<>;
 
 #ifndef KERNEL_FLOAT_POLICY
-#define KERNEL_FLOAT_POLICY accurate_policy;
+#define KERNEL_FLOAT_POLICY accurate_policy
 #endif
 
 /**
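
Dropping the trailing semicolon matters because the macro names a type that gets substituted into other declarations. A minimal illustration of the failure mode (hypothetical names; only the macro shape mirrors KERNEL_FLOAT_POLICY):

struct accurate_policy_demo {};

#define POLICY_OK accurate_policy_demo
// #define POLICY_BAD accurate_policy_demo;   // the old, semicolon-terminated form

template<typename P = POLICY_OK>
int run() {
    return 0;
}

int main() {
    return run<POLICY_OK>();     // fine
    // return run<POLICY_BAD>(); // would expand to run<accurate_policy_demo;>(), a syntax error
}
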
@@ -1448,6 +1448,9 @@ KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, rcp, "rcp.approx.f32", "f")
 KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, rsqrt, "rsqrt.approx.f32", "f")
 KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, tanh, "tanh.approx.f32;", "f")
 
+#define KERNEL_FLOAT_FAST_F32_MAP(F) \
+    F(exp) F(exp2) F(exp10) F(log) F(log2) F(log10) F(sin) F(cos) F(tan) F(rcp) F(rsqrt) F(sqrt)
+
 //KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, sin, "sin.approx.f32", "f")
 //KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, cos, "cos.approx.f32", "f")
 //KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, exp2, "ex2.approx.f32", "f")
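
KERNEL_FLOAT_FAST_F32_MAP is an X-macro listing the operations that have fast float implementations; further down in this diff it is applied to generate the fast-policy dispatch specializations for half_t and bfloat16_t (a hand-expanded example follows that hunk).
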
@@ -1724,15 +1727,15 @@ using zip_common_type = vector<
  * vec<float, 3> c = zip_common([](float x, float y){ return x + y; }, a, b); // returns [5.0f, 7.0f, 9.0f]
  * ```
  */
-template<typename F, typename L, typename R>
+template<typename Accuracy = default_policy, typename F, typename L, typename R>
 KERNEL_FLOAT_INLINE zip_common_type<F, L, R> zip_common(F fun, const L& left, const R& right) {
     using T = promoted_vector_value_type<L, R>;
     using O = result_t<F, T, T>;
     using E = broadcast_vector_extent_type<L, R>;
 
     vector_storage<O, extent_size<E>> result;
 
-    detail::default_map_impl<F, extent_size<E>, O, T, T>::call(
+    detail::map_impl<Accuracy, F, extent_size<E>, O, T, T>::call(
         fun,
         result.data(),
         detail::convert_impl<vector_value_type<L>, vector_extent_type<L>, T, E>::call(
@@ -1745,10 +1748,17 @@ KERNEL_FLOAT_INLINE zip_common_type<F, L, R> zip_common(F fun, const L& left, co
     return result;
 }
 
-#define KERNEL_FLOAT_DEFINE_BINARY_FUN(NAME) \
-    template<typename L, typename R, typename C = promoted_vector_value_type<L, R>> \
-    KERNEL_FLOAT_INLINE zip_common_type<ops::NAME<C>, L, R> NAME(L&& left, R&& right) { \
-        return zip_common(ops::NAME<C> {}, static_cast<L&&>(left), static_cast<R&&>(right)); \
+#define KERNEL_FLOAT_DEFINE_BINARY_FUN(NAME) \
+    template< \
+        typename Accuracy = default_policy, \
+        typename L, \
+        typename R, \
+        typename C = promoted_vector_value_type<L, R>> \
+    KERNEL_FLOAT_INLINE zip_common_type<ops::NAME<C>, L, R> NAME(L&& left, R&& right) { \
+        return zip_common<Accuracy>( \
+            ops::NAME<C> {}, \
+            static_cast<L&&>(left), \
+            static_cast<R&&>(right)); \
     }
 
 #define KERNEL_FLOAT_DEFINE_BINARY(NAME, EXPR, EXPR_F64, EXPR_F32) \
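
With the new Accuracy parameter, an accuracy policy can be selected per call instead of relying only on the compile-time default. A hedged usage sketch (the kf::vec and kf::fast_policy spellings and the include path are assumptions; fast_policy itself appears elsewhere in this diff):

#include "kernel_float.h"
namespace kf = kernel_float;

__global__ void zip_demo(kf::vec<float, 3>* out, const kf::vec<float, 3>* a, const kf::vec<float, 3>* b) {
    auto add = [](float x, float y) { return x + y; };
    out[0] = kf::zip_common(add, *a, *b);                   // default_policy, behaviour unchanged
    out[1] = kf::zip_common<kf::fast_policy>(add, *a, *b);  // accuracy policy chosen at the call site
}
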
@@ -3887,11 +3897,20 @@ struct vector: public S {
     }
 
     /**
-     * Returns the result of `*this + lhs * rhs`.
+     * Returns the result of `this + lhs * rhs`.
      *
     * The operation is performed using a single `kernel_float::fma` call, which may be faster then perform
     * the addition and multiplication separately.
     */
+    template<
+        typename L,
+        typename R,
+        typename T2 = promote_t<T, vector_value_type<L>, vector_value_type<R>>,
+        typename E2 = broadcast_extent<E, vector_extent_type<L>, vector_extent_type<R>>>
+    KERNEL_FLOAT_INLINE vector<T2, E2> add_mul(const L& lhs, const R& rhs) const {
+        return ::kernel_float::fma(lhs, rhs, *this);
+    }
+
     template<
         typename L,
         typename R,
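
A short usage sketch for the new add_mul member; the brace initialization and kf:: spellings are assumptions, while the equivalence with fma follows directly from the body above:

#include "kernel_float.h"
namespace kf = kernel_float;

__global__ void add_mul_demo(kf::vec<float, 4>* out) {
    kf::vec<float, 4> v = {1.0f, 2.0f, 3.0f, 4.0f};
    kf::vec<float, 4> a = {0.5f, 0.5f, 0.5f, 0.5f};
    kf::vec<float, 4> b = {2.0f, 4.0f, 6.0f, 8.0f};
    out[0] = v.add_mul(a, b);   // v + a * b, performed as a single kernel_float::fma
    out[1] = kf::fma(a, b, v);  // the spelled-out equivalent
}
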
@@ -4138,6 +4157,22 @@ struct apply_impl<accurate_policy, ops::fma<half_t>, 2, half_t, half_t, half_t,
         result[0] = r.x, result[1] = r.y;
     }
 };
+
+// clang-format off
+#define KERNEL_FLOAT_FAST_FP16_DISPATCH(OP) \
+    template<size_t N> \
+    struct apply_impl<fast_policy, ops::OP<half_t>, N, half_t, half_t> { \
+        KERNEL_FLOAT_INLINE static void \
+        call(ops::OP<half_t>, half_t* output, const half_t* input) { \
+            float v[N]; \
+            map_impl<fast_policy, ops::cast<half_t, float>, N, float, half_t>::call({}, v, input); \
+            map_impl<fast_policy, ops::OP<float>, N, float, float>::call({}, v, v); \
+            map_impl<fast_policy, ops::cast<float, half_t>, N, half_t, float>::call({}, output, v); \
+        } \
+    };
+// clang-format on
+
+KERNEL_FLOAT_FAST_F32_MAP(KERNEL_FLOAT_FAST_FP16_DISPATCH)
 } // namespace detail
 #endif
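
The macro application above stamps out one fast-policy specialization per operation listed in KERNEL_FLOAT_FAST_F32_MAP: the half inputs are upcast to float, the fast float implementation runs, and the result is cast back to half. Hand-expanding it for OP = exp gives, roughly (whitespace reflowed for illustration):

template<size_t N>
struct apply_impl<fast_policy, ops::exp<half_t>, N, half_t, half_t> {
    KERNEL_FLOAT_INLINE static void call(ops::exp<half_t>, half_t* output, const half_t* input) {
        float v[N];
        map_impl<fast_policy, ops::cast<half_t, float>, N, float, half_t>::call({}, v, input);   // half -> float
        map_impl<fast_policy, ops::exp<float>, N, float, float>::call({}, v, v);                 // fast float exp
        map_impl<fast_policy, ops::cast<float, half_t>, N, half_t, float>::call({}, output, v);  // float -> half
    }
};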

@@ -4390,6 +4425,22 @@ struct apply_impl<
         result[0] = r.x, result[1] = r.y;
     }
 };
+
+// clang-format off
+#define KERNEL_FLOAT_FAST_BF16_DISPATCH(OP) \
+    template<size_t N> \
+    struct apply_impl<fast_policy, ops::OP<bfloat16_t>, N, bfloat16_t, bfloat16_t> { \
+        KERNEL_FLOAT_INLINE static void \
+        call(ops::OP<bfloat16_t>, bfloat16_t* output, const bfloat16_t* input) { \
+            float v[N]; \
+            map_impl<fast_policy, ops::cast<bfloat16_t, float>, N, float, bfloat16_t>::call({}, v, input); \
+            map_impl<fast_policy, ops::OP<float>, N, float, float>::call({}, v, v); \
+            map_impl<fast_policy, ops::cast<float, bfloat16_t>, N, bfloat16_t, float>::call({}, output, v); \
+        } \
+    };
+// clang-format on
+
+KERNEL_FLOAT_FAST_F32_MAP(KERNEL_FLOAT_FAST_BF16_DISPATCH)
 } // namespace detail
 #endif

@@ -4631,17 +4682,20 @@ KERNEL_FLOAT_DEVICE half2_t rcp(half2_t x) {
 
 template<int Iter>
 KERNEL_FLOAT_DEVICE half2_t rsqrt(half2_t x) {
+    // A small number added such that rsqrt(0) does not return NaN
+    static constexpr double EPS = 0.00000768899917602539;
+
     // Set top and bottom bits for both halfs, then shift by 1, then invert
     uint32_t r = ~((uint32_t(transmute<uint32_t>(x) >> 1)) | ~uint32_t(0x3fff3fff));
-    //uint32_t r = uint32_t(~(transmute<uint32_t>(arg) | (~uint32_t(0x3ffe3ffe)))) >> 1;
 
-    // Add bias (0x199c)
-    half2_t y = transmute<half2_t>(uint32_t(r) + uint32_t(0x199c199c));
+    // Add bias
+    static constexpr uint32_t BIAS = 0x199c199c;
+    half2_t y = transmute<half2_t>(uint32_t(r) + BIAS);
 
     // Newton-Raphson iterations
 #pragma unroll
     for (int i = 0; i < Iter; i++) {
-        half2_t half_x = make_half2(-0.5) * x;
+        half2_t half_x = __hfma2(make_half2(-0.5), x, make_half2(-EPS));
         half2_t correction = __hfma2(half_x, y * y, make_half2(0.5));
         y = __hfma2(correction, y, y); // y += y * correction
     }
@@ -4836,7 +4890,7 @@ template<int Level, typename F, typename T>
 struct apply_impl<approx_level_policy<Level>, F, 1, T, T> {
     KERNEL_FLOAT_INLINE static void call(F fun, T* output, const T* input) {
         T in2[2], out2[2];
-        out2[0] = input[0];
+        in2[0] = input[0];
         apply_impl<approx_level_policy<Level>, F, 2, T, T>::call(fun, out2, in2);
         output[0] = out2[0];
     }
@@ -4867,6 +4921,8 @@ KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, sqrt, 1)
 KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, rcp, 1)
 KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, exp, 0)
 KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, log, 0)
+KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, asin, 2)
+KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, acos, 2)
 #endif
 
 #if KERNEL_FLOAT_BF16_OPS_SUPPORTED
@@ -4960,7 +5016,7 @@ struct allow_float_fallback<__nv_fp8_e5m2> {
 #define KERNEL_FLOAT_FP8_CAST2(T, FP8_TY, FP8_INTERP) \
     namespace detail { \
     template<> \
-    struct apply_impl<ops::cast<T, FP8_TY>, 2, FP8_TY, T> { \
+    struct apply_impl<accurate_policy, ops::cast<T, FP8_TY>, 2, FP8_TY, T> { \
        KERNEL_FLOAT_INLINE static void call(ops::cast<T, FP8_TY>, FP8_TY* result, const T* v) { \
            __half2_raw x; \
            memcpy(&x, v, 2 * sizeof(T)); \
@@ -4969,7 +5025,7 @@ struct allow_float_fallback<__nv_fp8_e5m2> {
        } \
    }; \
    template<> \
-    struct apply_impl<ops::cast<FP8_TY, T>, 2, T, FP8_TY> { \
+    struct apply_impl<accurate_policy, ops::cast<FP8_TY, T>, 2, T, FP8_TY> { \
        KERNEL_FLOAT_INLINE static void call(ops::cast<FP8_TY, T>, T* result, const FP8_TY* v) { \
            __nv_fp8x2_storage_t x; \
            memcpy(&x, v, 2 * sizeof(FP8_TY)); \
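
Both FP8 cast specializations now name accurate_policy explicitly, matching the policy-aware apply_impl signature used elsewhere in this header (compare the accurate_policy and fast_policy specializations earlier in this diff); presumably this became necessary once apply_impl gained the policy parameter.
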
@@ -4987,12 +5043,12 @@ KERNEL_FLOAT_FP8_CAST(double)
 
 
 namespace kernel_float {
-KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__half, __nv_fp8_e4m3)
-KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__half, __nv_fp8_e5m2)
+KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(half_t, __nv_fp8_e4m3)
+KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(half_t, __nv_fp8_e5m2)
 
-KERNEL_FLOAT_FP8_CAST(__half)
-KERNEL_FLOAT_FP8_CAST2(__half, __nv_fp8_e4m3, __NV_E4M3)
-KERNEL_FLOAT_FP8_CAST2(__half, __nv_fp8_e5m2, __NV_E5M2)
+KERNEL_FLOAT_FP8_CAST(half_t)
+KERNEL_FLOAT_FP8_CAST2(half_t, __nv_fp8_e4m3, __NV_E4M3)
+KERNEL_FLOAT_FP8_CAST2(half_t, __nv_fp8_e5m2, __NV_E5M2)
 
 } // namespace kernel_float
 #endif // KERNEL_FLOAT_FP16_AVAILABLE
@@ -5001,12 +5057,12 @@ KERNEL_FLOAT_FP8_CAST2(__half, __nv_fp8_e5m2, __NV_E5M2)
 
 
 namespace kernel_float {
-KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__nv_bfloat16, __nv_fp8_e4m3)
-KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__nv_bfloat16, __nv_fp8_e5m2)
+KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(bfloat16_t, __nv_fp8_e4m3)
+KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(bfloat16_t, __nv_fp8_e5m2)
 
-KERNEL_FLOAT_FP8_CAST(__nv_bfloat16)
-KERNEL_FLOAT_FP8_CAST2(__nv_bfloat16, __nv_fp8_e4m3, __NV_E4M3)
-KERNEL_FLOAT_FP8_CAST2(__nv_bfloat16, __nv_fp8_e5m2, __NV_E5M2)
+KERNEL_FLOAT_FP8_CAST(bfloat16_t)
+KERNEL_FLOAT_FP8_CAST2(bfloat16_t, __nv_fp8_e4m3, __NV_E4M3)
+KERNEL_FLOAT_FP8_CAST2(bfloat16_t, __nv_fp8_e5m2, __NV_E5M2)
 
 } // namespace kernel_float
 #endif // KERNEL_FLOAT_BF16_AVAILABLE

@@ -5075,14 +5131,14 @@ KERNEL_FLOAT_TYPE_ALIAS(f64x, double)
 KERNEL_FLOAT_TYPE_ALIAS(float64x, double)
 
 #if KERNEL_FLOAT_FP16_AVAILABLE
-KERNEL_FLOAT_TYPE_ALIAS(half, __half)
-KERNEL_FLOAT_TYPE_ALIAS(f16x, __half)
-KERNEL_FLOAT_TYPE_ALIAS(float16x, __half)
+KERNEL_FLOAT_TYPE_ALIAS(half, half_t)
+KERNEL_FLOAT_TYPE_ALIAS(f16x, half_t)
+KERNEL_FLOAT_TYPE_ALIAS(float16x, half_t)
 #endif
 
 #if KERNEL_FLOAT_BF16_AVAILABLE
-KERNEL_FLOAT_TYPE_ALIAS(bfloat16x, __bfloat16)
-KERNEL_FLOAT_TYPE_ALIAS(bf16x, __bfloat16)
+KERNEL_FLOAT_TYPE_ALIAS(bfloat16x, bfloat16_t)
+KERNEL_FLOAT_TYPE_ALIAS(bf16x, bfloat16_t)
 #endif
 
 #if KERNEL_FLOAT_BF8_AVAILABLE
