16
16
17
17
// ================================================================================
18
18
// this file has been auto-generated, do not modify its contents!
19
- // date: 2024-11-18 16:57:58.817191
20
- // git hash: 003ce3677ecb97dc1602e38a3e774c103d05aa1a
19
+ // date: 2024-11-20 10:36:45.284577
20
+ // git hash: 76501fda40df9e396998d11840bc8f10b11ea47b
21
21
// ================================================================================
22
22
23
23
#ifndef KERNEL_FLOAT_MACROS_H
@@ -813,7 +813,7 @@ struct approx_level_policy {};
813
813
using approx_policy = approx_level_policy<>;
814
814
815
815
#ifndef KERNEL_FLOAT_POLICY
816
- #define KERNEL_FLOAT_POLICY accurate_policy;
816
+ #define KERNEL_FLOAT_POLICY accurate_policy
817
817
#endif
818
818
819
819
/* *
@@ -1448,6 +1448,9 @@ KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, rcp, "rcp.approx.f32", "f")
1448
1448
KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float , rsqrt, " rsqrt.approx.f32" , " f" )
1449
1449
KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float , tanh, " tanh.approx.f32;" , " f" )
1450
1450
1451
// X-macro list: expands to one invocation of the macro argument F for each of the
// unary operation names exp, exp2, exp10, log, log2, log10, sin, cos, tan, rcp,
// rsqrt and sqrt. Pass a one-argument macro to stamp out code per operation.
// NOTE(review): judging by the name, these are the ops that have a fast float
// implementation — confirm against the fast_policy definitions.
+ #define KERNEL_FLOAT_FAST_F32_MAP (F ) \
1452
+ F (exp) F(exp2) F(exp10) F(log) F(log2) F(log10) F(sin) F(cos) F(tan) F(rcp) F(rsqrt) F(sqrt)
1453
+
1451
1454
// KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, sin, "sin.approx.f32", "f")
1452
1455
// KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, cos, "cos.approx.f32", "f")
1453
1456
// KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, exp2, "ex2.approx.f32", "f")
@@ -1724,15 +1727,15 @@ using zip_common_type = vector<
1724
1727
* vec<float, 3> c = zip_common([](float x, float y){ return x + y; }, a, b); // returns [5.0f, 7.0f, 9.0f]
1725
1728
* ```
1726
1729
*/
1727
- template <typename F, typename L, typename R>
1730
+ template <typename Accuracy = default_policy, typename F, typename L, typename R>
1728
1731
KERNEL_FLOAT_INLINE zip_common_type<F, L, R> zip_common (F fun, const L& left, const R& right) {
1729
1732
using T = promoted_vector_value_type<L, R>;
1730
1733
using O = result_t <F, T, T>;
1731
1734
using E = broadcast_vector_extent_type<L, R>;
1732
1735
1733
1736
vector_storage<O, extent_size<E>> result;
1734
1737
1735
- detail::default_map_impl< F, extent_size<E>, O, T, T>::call (
1738
+ detail::map_impl<Accuracy, F, extent_size<E>, O, T, T>::call (
1736
1739
fun,
1737
1740
result.data (),
1738
1741
detail::convert_impl<vector_value_type<L>, vector_extent_type<L>, T, E>::call (
@@ -1745,10 +1748,17 @@ KERNEL_FLOAT_INLINE zip_common_type<F, L, R> zip_common(F fun, const L& left, co
1745
1748
return result;
1746
1749
}
1747
1750
1748
// Defines a function NAME(left, right) that forwards both operands to `zip_common`
// together with a default-constructed `ops::NAME<C>` functor; C defaults to
// `promoted_vector_value_type<L, R>`. The definition that carries the leading
// `Accuracy` template parameter (default `default_policy`) passes it on explicitly
// as `zip_common<Accuracy>(...)`; since it is the first template parameter and has
// a default, existing calls of the form NAME(left, right) still deduce L and R.
- #define KERNEL_FLOAT_DEFINE_BINARY_FUN (NAME ) \
1749
- template <typename L, typename R, typename C = promoted_vector_value_type<L, R>> \
1750
- KERNEL_FLOAT_INLINE zip_common_type<ops::NAME<C>, L, R> NAME (L&& left, R&& right) { \
1751
- return zip_common (ops::NAME<C> {}, static_cast <L&&>(left), static_cast <R&&>(right)); \
1751
+ #define KERNEL_FLOAT_DEFINE_BINARY_FUN (NAME ) \
1752
+ template < \
1753
+ typename Accuracy = default_policy, \
1754
+ typename L, \
1755
+ typename R, \
1756
+ typename C = promoted_vector_value_type<L, R>> \
1757
+ KERNEL_FLOAT_INLINE zip_common_type<ops::NAME<C>, L, R> NAME (L&& left, R&& right) { \
1758
+ return zip_common<Accuracy>( \
1759
+ ops::NAME<C> {}, \
1760
+ static_cast <L&&>(left), \
1761
+ static_cast <R&&>(right)); \
1752
1762
}
1753
1763
1754
1764
#define KERNEL_FLOAT_DEFINE_BINARY (NAME, EXPR, EXPR_F64, EXPR_F32 ) \
@@ -3887,11 +3897,20 @@ struct vector: public S {
3887
3897
}
3888
3898
3889
3899
/* *
3890
- * Returns the result of `* this + lhs * rhs`.
3900
+ * Returns the result of `this + lhs * rhs`.
3891
3901
*
3892
3902
* The operation is performed using a single `kernel_float::fma` call, which may be faster than performing
3893
3903
* the addition and multiplication separately.
3894
3904
*/
3905
// Overload taking two operands `lhs` and `rhs`.
// T2: result element type, `promote_t` over this vector's T and the value types of L and R.
// E2: result extent, `broadcast_extent` over this vector's E and the extents of L and R.
+ template <
3906
+ typename L,
3907
+ typename R,
3908
+ typename T2 = promote_t <T, vector_value_type<L>, vector_value_type<R>>,
3909
+ typename E2 = broadcast_extent<E, vector_extent_type<L>, vector_extent_type<R>>>
3910
+ KERNEL_FLOAT_INLINE vector<T2, E2 > add_mul (const L& lhs, const R& rhs) const {
// `lhs` and `rhs` are the first two fma arguments; this vector is passed as the third.
3911
+ return ::kernel_float::fma (lhs, rhs, *this );
3912
+ }
3913
+
3895
3914
template <
3896
3915
typename L,
3897
3916
typename R,
@@ -4138,6 +4157,22 @@ struct apply_impl<accurate_policy, ops::fma<half_t>, 2, half_t, half_t, half_t,
4138
4157
result[0 ] = r.x , result[1 ] = r.y ;
4139
4158
}
4140
4159
};
4160
+
4161
+ // clang-format off
4162
// For the operation name OP, defines a partial specialization of `apply_impl` for
// `fast_policy` and `ops::OP<half_t>` over any element count N. Its `call` routes the
// work through float: it casts the N half_t inputs into a local `float v[N]`,
// applies `ops::OP<float>` in place on `v` under `fast_policy`, and casts the
// results back to half_t into `output`.
// NOTE(review): presumably this lets half_t reuse the fast float implementations —
// confirm that a `fast_policy` float implementation exists for every OP it is used with.
+ #define KERNEL_FLOAT_FAST_FP16_DISPATCH (OP ) \
4163
+ template <size_t N> \
4164
+ struct apply_impl <fast_policy, ops::OP<half_t >, N, half_t , half_t > { \
4165
+ KERNEL_FLOAT_INLINE static void \
4166
+ call (ops::OP<half_t >, half_t * output, const half_t * input) { \
4167
+ float v[N]; \
4168
+ map_impl<fast_policy, ops::cast<half_t , float >, N, float , half_t >::call ({}, v, input); \
4169
+ map_impl<fast_policy, ops::OP<float >, N, float , float >::call ({}, v, v); \
4170
+ map_impl<fast_policy, ops::cast<float , half_t >, N, half_t , float >::call ({}, output, v); \
4171
+ } \
4172
+ };
4173
+ // clang-format on
4174
+
4175
+ KERNEL_FLOAT_FAST_F32_MAP (KERNEL_FLOAT_FAST_FP16_DISPATCH)
4141
4176
} // namespace detail
4142
4177
#endif
4143
4178
@@ -4390,6 +4425,22 @@ struct apply_impl<
4390
4425
result[0 ] = r.x , result[1 ] = r.y ;
4391
4426
}
4392
4427
};
4428
+
4429
+ // clang-format off
4430
// For the operation name OP, defines a partial specialization of `apply_impl` for
// `fast_policy` and `ops::OP<bfloat16_t>` over any element count N. Its `call` routes
// the work through float: it casts the N bfloat16_t inputs into a local `float v[N]`,
// applies `ops::OP<float>` in place on `v` under `fast_policy`, and casts the
// results back to bfloat16_t into `output`.
// NOTE(review): presumably this lets bfloat16_t reuse the fast float implementations —
// confirm that a `fast_policy` float implementation exists for every OP it is used with.
+ #define KERNEL_FLOAT_FAST_BF16_DISPATCH (OP ) \
4431
+ template <size_t N> \
4432
+ struct apply_impl <fast_policy, ops::OP<bfloat16_t >, N, bfloat16_t , bfloat16_t > { \
4433
+ KERNEL_FLOAT_INLINE static void \
4434
+ call (ops::OP<bfloat16_t >, bfloat16_t * output, const bfloat16_t * input) { \
4435
+ float v[N]; \
4436
+ map_impl<fast_policy, ops::cast<bfloat16_t , float >, N, float , bfloat16_t >::call ({}, v, input); \
4437
+ map_impl<fast_policy, ops::OP<float >, N, float , float >::call ({}, v, v); \
4438
+ map_impl<fast_policy, ops::cast<float , bfloat16_t >, N, bfloat16_t , float >::call ({}, output, v); \
4439
+ } \
4440
+ };
4441
+ // clang-format on
4442
+
4443
+ KERNEL_FLOAT_FAST_F32_MAP (KERNEL_FLOAT_FAST_BF16_DISPATCH)
4393
4444
} // namespace detail
4394
4445
#endif
4395
4446
@@ -4631,17 +4682,20 @@ KERNEL_FLOAT_DEVICE half2_t rcp(half2_t x) {
4631
4682
4632
4683
template <int Iter>
4633
4684
KERNEL_FLOAT_DEVICE half2_t rsqrt (half2_t x) {
4685
+ // A small number added such that rsqrt(0) does not return NaN
4686
+ static constexpr double EPS = 0.00000768899917602539 ;
4687
+
4634
4688
// Set top and bottom bits for both halves, then shift by 1, then invert
4635
4689
uint32_t r = ~((uint32_t (transmute<uint32_t >(x) >> 1 )) | ~uint32_t (0x3fff3fff ));
4636
- // uint32_t r = uint32_t(~(transmute<uint32_t>(arg) | (~uint32_t(0x3ffe3ffe)))) >> 1;
4637
4690
4638
- // Add bias (0x199c)
4639
- half2_t y = transmute<half2_t >(uint32_t (r) + uint32_t (0x199c199c ));
4691
+ // Add bias
4692
+ static constexpr uint32_t BIAS = 0x199c199c ;
4693
+ half2_t y = transmute<half2_t >(uint32_t (r) + BIAS);
4640
4694
4641
4695
// Newton-Raphson iterations
4642
4696
#pragma unroll
4643
4697
for (int i = 0 ; i < Iter; i++) {
4644
- half2_t half_x = make_half2 (-0.5 ) * x ;
4698
+ half2_t half_x = __hfma2 ( make_half2 (-0.5 ), x, make_half2 (-EPS)) ;
4645
4699
half2_t correction = __hfma2 (half_x, y * y, make_half2 (0.5 ));
4646
4700
y = __hfma2 (correction, y, y); // y += y * correction
4647
4701
}
@@ -4836,7 +4890,7 @@ template<int Level, typename F, typename T>
4836
4890
struct apply_impl <approx_level_policy<Level>, F, 1 , T, T> {
4837
4891
KERNEL_FLOAT_INLINE static void call (F fun, T* output, const T* input) {
4838
4892
T in2[2 ], out2[2 ];
4839
- out2 [0 ] = input[0 ];
4893
+ in2 [0 ] = input[0 ];
4840
4894
apply_impl<approx_level_policy<Level>, F, 2 , T, T>::call (fun, out2, in2);
4841
4895
output[0 ] = out2[0 ];
4842
4896
}
@@ -4867,6 +4921,8 @@ KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, sqrt, 1)
4867
4921
KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t , rcp, 1 )
4868
4922
KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t , exp, 0 )
4869
4923
KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t , log, 0 )
4924
+ KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t , asin, 2 )
4925
+ KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t , acos, 2 )
4870
4926
#endif
4871
4927
4872
4928
#if KERNEL_FLOAT_BF16_OPS_SUPPORTED
@@ -4960,7 +5016,7 @@ struct allow_float_fallback<__nv_fp8_e5m2> {
4960
5016
#define KERNEL_FLOAT_FP8_CAST2 (T, FP8_TY, FP8_INTERP ) \
4961
5017
namespace detail { \
4962
5018
template <> \
4963
- struct apply_impl <ops::cast<T, FP8_TY>, 2 , FP8_TY, T> { \
5019
+ struct apply_impl <accurate_policy, ops::cast<T, FP8_TY>, 2 , FP8_TY, T> { \
4964
5020
KERNEL_FLOAT_INLINE static void call (ops::cast<T, FP8_TY>, FP8_TY* result, const T* v) { \
4965
5021
__half2_raw x; \
4966
5022
memcpy (&x, v, 2 * sizeof (T)); \
@@ -4969,7 +5025,7 @@ struct allow_float_fallback<__nv_fp8_e5m2> {
4969
5025
} \
4970
5026
}; \
4971
5027
template <> \
4972
- struct apply_impl <ops::cast<FP8_TY, T>, 2 , T, FP8_TY> { \
5028
+ struct apply_impl <accurate_policy, ops::cast<FP8_TY, T>, 2 , T, FP8_TY> { \
4973
5029
KERNEL_FLOAT_INLINE static void call (ops::cast<FP8_TY, T>, T* result, const FP8_TY* v) { \
4974
5030
__nv_fp8x2_storage_t x; \
4975
5031
memcpy (&x, v, 2 * sizeof (FP8_TY)); \
@@ -4987,12 +5043,12 @@ KERNEL_FLOAT_FP8_CAST(double)
4987
5043
4988
5044
4989
5045
namespace kernel_float {
4990
- KERNEL_FLOAT_DEFINE_PROMOTED_TYPE (__half , __nv_fp8_e4m3)
4991
- KERNEL_FLOAT_DEFINE_PROMOTED_TYPE (__half , __nv_fp8_e5m2)
5046
+ KERNEL_FLOAT_DEFINE_PROMOTED_TYPE (half_t , __nv_fp8_e4m3)
5047
+ KERNEL_FLOAT_DEFINE_PROMOTED_TYPE (half_t , __nv_fp8_e5m2)
4992
5048
4993
- KERNEL_FLOAT_FP8_CAST (__half )
4994
- KERNEL_FLOAT_FP8_CAST2 (__half , __nv_fp8_e4m3, __NV_E4M3)
4995
- KERNEL_FLOAT_FP8_CAST2 (__half , __nv_fp8_e5m2, __NV_E5M2)
5049
+ KERNEL_FLOAT_FP8_CAST (half_t )
5050
+ KERNEL_FLOAT_FP8_CAST2 (half_t , __nv_fp8_e4m3, __NV_E4M3)
5051
+ KERNEL_FLOAT_FP8_CAST2 (half_t , __nv_fp8_e5m2, __NV_E5M2)
4996
5052
4997
5053
} // namespace kernel_float
4998
5054
#endif // KERNEL_FLOAT_FP16_AVAILABLE
@@ -5001,12 +5057,12 @@ KERNEL_FLOAT_FP8_CAST2(__half, __nv_fp8_e5m2, __NV_E5M2)
5001
5057
5002
5058
5003
5059
namespace kernel_float {
5004
- KERNEL_FLOAT_DEFINE_PROMOTED_TYPE (__nv_bfloat16 , __nv_fp8_e4m3)
5005
- KERNEL_FLOAT_DEFINE_PROMOTED_TYPE (__nv_bfloat16 , __nv_fp8_e5m2)
5060
+ KERNEL_FLOAT_DEFINE_PROMOTED_TYPE (bfloat16_t , __nv_fp8_e4m3)
5061
+ KERNEL_FLOAT_DEFINE_PROMOTED_TYPE (bfloat16_t , __nv_fp8_e5m2)
5006
5062
5007
- KERNEL_FLOAT_FP8_CAST (__nv_bfloat16 )
5008
- KERNEL_FLOAT_FP8_CAST2 (__nv_bfloat16 , __nv_fp8_e4m3, __NV_E4M3)
5009
- KERNEL_FLOAT_FP8_CAST2 (__nv_bfloat16 , __nv_fp8_e5m2, __NV_E5M2)
5063
+ KERNEL_FLOAT_FP8_CAST (bfloat16_t )
5064
+ KERNEL_FLOAT_FP8_CAST2 (bfloat16_t , __nv_fp8_e4m3, __NV_E4M3)
5065
+ KERNEL_FLOAT_FP8_CAST2 (bfloat16_t , __nv_fp8_e5m2, __NV_E5M2)
5010
5066
} // namespace kernel_float
5011
5067
#endif // KERNEL_FLOAT_BF16_AVAILABLE
5012
5068
@@ -5075,14 +5131,14 @@ KERNEL_FLOAT_TYPE_ALIAS(f64x, double)
5075
5131
KERNEL_FLOAT_TYPE_ALIAS (float64x, double )
5076
5132
5077
5133
#if KERNEL_FLOAT_FP16_AVAILABLE
5078
- KERNEL_FLOAT_TYPE_ALIAS (half, __half )
5079
- KERNEL_FLOAT_TYPE_ALIAS(f16x, __half )
5080
- KERNEL_FLOAT_TYPE_ALIAS(float16x, __half )
5134
+ KERNEL_FLOAT_TYPE_ALIAS (half, half_t )
5135
+ KERNEL_FLOAT_TYPE_ALIAS(f16x, half_t )
5136
+ KERNEL_FLOAT_TYPE_ALIAS(float16x, half_t )
5081
5137
#endif
5082
5138
5083
5139
#if KERNEL_FLOAT_BF16_AVAILABLE
5084
- KERNEL_FLOAT_TYPE_ALIAS (bfloat16x, __bfloat16 )
5085
- KERNEL_FLOAT_TYPE_ALIAS(bf16x, __bfloat16 )
5140
+ KERNEL_FLOAT_TYPE_ALIAS (bfloat16x, bfloat16_t )
5141
+ KERNEL_FLOAT_TYPE_ALIAS(bf16x, bfloat16_t )
5086
5142
#endif
5087
5143
5088
5144
#if KERNEL_FLOAT_BF8_AVAILABLE
0 commit comments