From 5490ea756b41c688b66dd69e776a58c9ce5b1ef2 Mon Sep 17 00:00:00 2001 From: stijn Date: Fri, 1 Nov 2024 17:19:59 +0100 Subject: [PATCH 01/25] Make code compile with HIP on Lumi --- CMakeLists.txt | 31 +- README.md | 15 +- examples/hip_compat.h | 22 ++ examples/pi/CMakeLists.txt | 18 +- examples/pi/main.cu | 7 +- examples/vector_add/CMakeLists.txt | 16 +- examples/vector_add/main.cu | 3 +- examples/vector_add_tiling/CMakeLists.txt | 16 +- examples/vector_add_tiling/main.cu | 1 + include/kernel_float/base.h | 4 + include/kernel_float/bf16.h | 218 ++++++----- include/kernel_float/binops.h | 2 +- include/kernel_float/fp16.h | 65 +++- include/kernel_float/macros.h | 79 ++-- include/kernel_float/reduce.h | 30 +- include/kernel_float/tiling.h | 4 +- include/kernel_float/unops.h | 31 +- include/kernel_float/vector.h | 1 + kernel_tuner/vector_add.cu | 1 + kernel_tuner/vector_add.py | 3 +- single_include/kernel_float.h | 444 ++++++++++++---------- tests/CMakeLists.txt | 20 +- tests/common.h | 41 +- 23 files changed, 639 insertions(+), 433 deletions(-) create mode 100644 examples/hip_compat.h diff --git a/CMakeLists.txt b/CMakeLists.txt index f563f74..bd5c28c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,13 +1,34 @@ cmake_minimum_required(VERSION 3.20) set (PROJECT_NAME kernel_float) -project(${PROJECT_NAME} CXX CUDA) +project(${PROJECT_NAME} LANGUAGES CXX) -set(CMAKE_C_STANDARD 11) +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +# Validate and enable the appropriate language +if (NOT DEFINED KERNEL_FLOAT_LANGUAGE) + set(KERNEL_FLOAT_LANGUAGE "CUDA") +endif() + +if (KERNEL_FLOAT_LANGUAGE STREQUAL "CUDA") + enable_language(CUDA) + set(KERNEL_FLOAT_LANGUAGE_CUDA ON) +elseif (KERNEL_FLOAT_LANGUAGE STREQUAL "HIP") + enable_language(HIP) + set(KERNEL_FLOAT_LANGUAGE_HIP ON) +else() + message(FATAL_ERROR "KERNEL_FLOAT_LANGUAGE must be either 'HIP' or 'CUDA'") +endif() + +# Create an interface library for kernel_float add_library(${PROJECT_NAME} INTERFACE) target_include_directories(${PROJECT_NAME} INTERFACE "${PROJECT_SOURCE_DIR}/include") +# Optionally build tests and examples if the corresponding flags are set +option(KERNEL_FLOAT_BUILD_TEST "Build kernel float tests" OFF) +option(KERNEL_FLOAT_BUILD_EXAMPLE "Build kernel float examples" OFF) + if (KERNEL_FLOAT_BUILD_TEST) add_subdirectory(tests) endif() @@ -15,3 +36,9 @@ endif() if (KERNEL_FLOAT_BUILD_EXAMPLE) add_subdirectory(examples) endif() + +# Display configuration +message(STATUS "=== Kernel Float ===") +message(STATUS "Using GPU Language: ${KERNEL_FLOAT_LANGUAGE}") +message(STATUS "Building Tests: ${KERNEL_FLOAT_BUILD_TEST}") +message(STATUS "Building Examples: ${KERNEL_FLOAT_BUILD_EXAMPLE}") \ No newline at end of file diff --git a/README.md b/README.md index 0670187..79748b0 100644 --- a/README.md +++ b/README.md @@ -9,12 +9,12 @@ ![GitHub Repo stars](https://img.shields.io/github/stars/KernelTuner/kernel_float?style=social) -_Kernel Float_ is a header-only library for CUDA that simplifies working with vector types and reduced precision floating-point arithmetic in GPU code. +_Kernel Float_ is a header-only library for CUDA/HIP that simplifies working with vector types and reduced precision floating-point arithmetic in GPU code. ## Summary -CUDA natively offers several reduced precision floating-point types (`__half`, `__nv_bfloat16`, `__nv_fp8_e4m3`, `__nv_fp8_e5m2`) +CUDA/HIP natively offers several reduced precision floating-point types (`__half`, `__nv_bfloat16`, `__nv_fp8_e4m3`, `__nv_fp8_e5m2`) and vector types (e.g., `__half2`, `__nv_fp8x4_e4m3`, `float3`). However, working with these types is cumbersome: mathematical operations require intrinsics (e.g., `__hadd2` performs addition for `__half2`), @@ -24,9 +24,9 @@ and some functionality is missing (e.g., one cannot convert a `__half` to `__nv_ _Kernel Float_ resolves this by offering a single data type `kernel_float::vec` that stores `N` elements of type `T`. Internally, the data is stored as a fixed-sized array of elements. Operator overloading (like `+`, `*`, `&&`) has been implemented such that the most optimal intrinsic for the available types is selected automatically. -Many mathetical functions (like `log`, `exp`, `sin`) and common operations (such as `sum`, `range`, `for_each`) are also available. +Many mathematical functions (like `log`, `exp`, `sin`) and common operations (such as `sum`, `range`, `for_each`) are also available. -By using this library, developers can avoid the complexity of working with reduced precision floating-point types in CUDA and focus on their applications. +Using Kernel Float, developers avoid the complexity of reduced precision floating-point types in CUDA and can focus on their applications. ## Features @@ -40,6 +40,7 @@ In a nutshell, _Kernel Float_ offers the following features: * Easy integration as a single header file. * Written for C++17. * Compatible with NVCC (NVIDIA Compiler) and NVRTC (NVIDIA Runtime Compilation). +* Compatible with HIPCC (AMD HIP Compiler) ## Example @@ -49,7 +50,7 @@ Check out the [examples](https://github.com/KernelTuner/kernel_float/tree/master Below shows a simple example of a CUDA kernel that adds a `constant` to the `input` array and writes the results to the `output` array. Each thread processes two elements. -Notice how easy it would be change the precision (for example, `double` to `half`) or the vector size (for example, 4 instead of 2 items per thread). +Notice how easy it would be to change the precision (for example, `double` to `half`) or the vector size (for example, 4 instead of 2 items per thread). ```cpp @@ -63,14 +64,14 @@ __global__ void kernel(const kf::vec* input, float constant, kf::vec #include +#include "../hip_compat.h" #include "kernel_float.h" #define CUDA_CHECK(call) \ @@ -9,12 +10,12 @@ if (__err != cudaSuccess) { \ fprintf( \ stderr, \ - "CUDA error at %s:%d code=%d(%s) \"%s\" \n", \ + "CUDA error at %s:%d (%s): %s (code %d) \n", \ __FILE__, \ __LINE__, \ - __err, \ + #call, \ cudaGetErrorString(__err), \ - #call); \ + __err); \ exit(EXIT_FAILURE); \ } \ } while (0) diff --git a/examples/vector_add/CMakeLists.txt b/examples/vector_add/CMakeLists.txt index b4d7bb8..eec05b9 100644 --- a/examples/vector_add/CMakeLists.txt +++ b/examples/vector_add/CMakeLists.txt @@ -1,12 +1,18 @@ cmake_minimum_required(VERSION 3.17) set (PROJECT_NAME kernel_float_vecadd) -project(${PROJECT_NAME} LANGUAGES CXX CUDA) -set (CMAKE_CXX_STANDARD 17) +project(${PROJECT_NAME} LANGUAGES CXX) +set (CMAKE_CXX_STANDARD 17) add_executable(${PROJECT_NAME} "${PROJECT_SOURCE_DIR}/main.cu") target_link_libraries(${PROJECT_NAME} kernel_float) -set_target_properties(${PROJECT_NAME} PROPERTIES CUDA_ARCHITECTURES "80") -find_package(CUDA REQUIRED) -target_include_directories(${PROJECT_NAME} PRIVATE ${CUDA_TOOLKIT_INCLUDE}) +if(${KERNEL_FLOAT_LANGUAGE_HIP}) + set_source_files_properties("${PROJECT_SOURCE_DIR}/main.cu" PROPERTIES LANGUAGE HIP) +endif() + +if(${KERNEL_FLOAT_LANGUAGE_CUDA}) + find_package(CUDA REQUIRED) + target_include_directories(${PROJECT_NAME} PRIVATE ${CUDA_TOOLKIT_INCLUDE}) + set_target_properties(${PROJECT_NAME} PROPERTIES CUDA_ARCHITECTURES "80") +endif() diff --git a/examples/vector_add/main.cu b/examples/vector_add/main.cu index 705cf8c..4c9d8b8 100644 --- a/examples/vector_add/main.cu +++ b/examples/vector_add/main.cu @@ -3,6 +3,7 @@ #include #include +#include "../hip_compat.h" #include "kernel_float.h" namespace kf = kernel_float; @@ -21,7 +22,7 @@ __global__ void my_kernel( int i = blockIdx.x * blockDim.x + threadIdx.x; if (i * N < length) { - output(i) = kf::fma(input[i], input[i], kf::cast<__half>(constant)); + output(i) = kf::fma(input[i], input[i], kf::cast(constant)); } } diff --git a/examples/vector_add_tiling/CMakeLists.txt b/examples/vector_add_tiling/CMakeLists.txt index a744c34..1992272 100644 --- a/examples/vector_add_tiling/CMakeLists.txt +++ b/examples/vector_add_tiling/CMakeLists.txt @@ -1,12 +1,18 @@ cmake_minimum_required(VERSION 3.17) set (PROJECT_NAME kernel_float_vecadd_tiling) -project(${PROJECT_NAME} LANGUAGES CXX CUDA) -set (CMAKE_CXX_STANDARD 17) +project(${PROJECT_NAME} LANGUAGES CXX) +set (CMAKE_CXX_STANDARD 17) add_executable(${PROJECT_NAME} "${PROJECT_SOURCE_DIR}/main.cu") target_link_libraries(${PROJECT_NAME} kernel_float) -set_target_properties(${PROJECT_NAME} PROPERTIES CUDA_ARCHITECTURES "80") -find_package(CUDA REQUIRED) -target_include_directories(${PROJECT_NAME} PRIVATE ${CUDA_TOOLKIT_INCLUDE}) +if(${KERNEL_FLOAT_LANGUAGE_HIP}) + set_source_files_properties("${PROJECT_SOURCE_DIR}/main.cu" PROPERTIES LANGUAGE HIP) +endif() + +if(${KERNEL_FLOAT_LANGUAGE_CUDA}) + find_package(CUDA REQUIRED) + target_include_directories(${PROJECT_NAME} PRIVATE ${CUDA_TOOLKIT_INCLUDE}) + set_target_properties(${PROJECT_NAME} PROPERTIES CUDA_ARCHITECTURES "80") +endif() diff --git a/examples/vector_add_tiling/main.cu b/examples/vector_add_tiling/main.cu index ddd28ab..14a0983 100644 --- a/examples/vector_add_tiling/main.cu +++ b/examples/vector_add_tiling/main.cu @@ -3,6 +3,7 @@ #include #include +#include "../hip_compat.h" #include "kernel_float.h" #include "kernel_float/tiling.h" namespace kf = kernel_float; diff --git a/include/kernel_float/base.h b/include/kernel_float/base.h index 403bceb..8c7c400 100644 --- a/include/kernel_float/base.h +++ b/include/kernel_float/base.h @@ -4,6 +4,10 @@ #include "macros.h" #include "meta.h" +#if KERNEL_FLOAT_IS_HIP +#include +#endif + namespace kernel_float { template diff --git a/include/kernel_float/bf16.h b/include/kernel_float/bf16.h index 7db292c..7aa99f1 100644 --- a/include/kernel_float/bf16.h +++ b/include/kernel_float/bf16.h @@ -4,7 +4,11 @@ #include "macros.h" #if KERNEL_FLOAT_BF16_AVAILABLE +#if KERNEL_FLOAT_IS_CUDA #include +#elif KERNEL_FLOAT_IS_HIP +#include +#endif #include "binops.h" #include "reduce.h" @@ -12,58 +16,66 @@ namespace kernel_float { +#if KERNEL_FLOAT_IS_CUDA +using __bfloat16 = __nv_bfloat16; +using __bfloat162 = __nv_bfloat162; +#elif KERNEL_FLOAT_IS_HIP +using __bfloat16 = __hip_bfloat16; +using __bfloat162 = __hip_bfloat162; +#endif + +#if KERNEL_FLOAT_IS_CUDA && __CUDA_ARCH__ >= 800 +#define KERNEL_FLOAT_BF16_OPS_SUPPORTED 1 +#endif + template<> -struct preferred_vector_size<__nv_bfloat16> { +struct preferred_vector_size<__bfloat16> { static constexpr size_t value = 2; }; -KERNEL_FLOAT_DEFINE_PROMOTED_FLOAT(__nv_bfloat16) -KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(float, __nv_bfloat16) -KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(double, __nv_bfloat16) +KERNEL_FLOAT_DEFINE_PROMOTED_FLOAT(__bfloat16) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(float, __bfloat16) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(double, __bfloat16) template<> -struct into_vector_impl<__nv_bfloat162> { - using value_type = __nv_bfloat16; +struct into_vector_impl<__bfloat162> { + using value_type = __bfloat16; using extent_type = extent<2>; KERNEL_FLOAT_INLINE - static vector_storage<__nv_bfloat16, 2> call(__nv_bfloat162 input) { + static vector_storage<__bfloat16, 2> call(__bfloat162 input) { return {input.x, input.y}; } }; namespace detail { template<> -struct allow_float_fallback<__nv_bfloat16> { +struct allow_float_fallback<__bfloat16> { static constexpr bool value = true; }; }; // namespace detail -#if KERNEL_FLOAT_IS_DEVICE -#define KERNEL_FLOAT_BF16_UNARY_FUN(NAME, FUN1, FUN2) \ - namespace ops { \ - template<> \ - struct NAME<__nv_bfloat16> { \ - KERNEL_FLOAT_INLINE __nv_bfloat16 operator()(__nv_bfloat16 input) { \ - return FUN1(input); \ - } \ - }; \ - } \ - namespace detail { \ - template<> \ - struct apply_impl, 2, __nv_bfloat16, __nv_bfloat16> { \ - KERNEL_FLOAT_INLINE static void \ - call(ops::NAME<__nv_bfloat16>, __nv_bfloat16* result, const __nv_bfloat16* a) { \ - __nv_bfloat162 r = FUN2(__nv_bfloat162 {a[0], a[1]}); \ - result[0] = r.x, result[1] = r.y; \ - } \ - }; \ +#if KERNEL_FLOAT_BF16_OPS_SUPPORTED +#define KERNEL_FLOAT_BF16_UNARY_FUN(NAME, FUN1, FUN2) \ + namespace ops { \ + template<> \ + struct NAME<__bfloat16> { \ + KERNEL_FLOAT_INLINE __bfloat16 operator()(__bfloat16 input) { \ + return FUN1(input); \ + } \ + }; \ + } \ + namespace detail { \ + template<> \ + struct apply_impl, 2, __bfloat16, __bfloat16> { \ + KERNEL_FLOAT_INLINE static void \ + call(ops::NAME<__bfloat16>, __bfloat16* result, const __bfloat16* a) { \ + __bfloat162 r = FUN2(__bfloat162 {a[0], a[1]}); \ + result[0] = r.x, result[1] = r.y; \ + } \ + }; \ } -#else -#define KERNEL_FLOAT_BF16_UNARY_FUN(NAME, FUN1, FUN2) -#endif -#if KERNEL_FLOAT_CUDA_ARCH >= 800 KERNEL_FLOAT_BF16_UNARY_FUN(abs, ::__habs, ::__habs2) KERNEL_FLOAT_BF16_UNARY_FUN(negate, ::__hneg, ::__hneg2) KERNEL_FLOAT_BF16_UNARY_FUN(ceil, ::hceil, ::h2ceil) @@ -81,33 +93,29 @@ KERNEL_FLOAT_BF16_UNARY_FUN(trunc, ::htrunc, ::h2trunc) KERNEL_FLOAT_BF16_UNARY_FUN(rcp, ::hrcp, ::h2rcp) #endif -#if KERNEL_FLOAT_CUDA_ARCH >= 800 -#define KERNEL_FLOAT_BF16_BINARY_FUN(NAME, FUN1, FUN2) \ - namespace ops { \ - template<> \ - struct NAME<__nv_bfloat16> { \ - KERNEL_FLOAT_INLINE __nv_bfloat16 \ - operator()(__nv_bfloat16 left, __nv_bfloat16 right) const { \ - return FUN1(left, right); \ - } \ - }; \ - } \ - namespace detail { \ - template<> \ - struct apply_impl, 2, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16> { \ - KERNEL_FLOAT_INLINE static void call( \ - ops::NAME<__nv_bfloat16>, \ - __nv_bfloat16* result, \ - const __nv_bfloat16* a, \ - const __nv_bfloat16* b) { \ - __nv_bfloat162 r = FUN2(__nv_bfloat162 {a[0], a[1]}, __nv_bfloat162 {b[0], b[1]}); \ - result[0] = r.x, result[1] = r.y; \ - } \ - }; \ +#if KERNEL_FLOAT_BF16_OPS_SUPPORTED +#define KERNEL_FLOAT_BF16_BINARY_FUN(NAME, FUN1, FUN2) \ + namespace ops { \ + template<> \ + struct NAME<__bfloat16> { \ + KERNEL_FLOAT_INLINE __bfloat16 operator()(__bfloat16 left, __bfloat16 right) const { \ + return FUN1(left, right); \ + } \ + }; \ + } \ + namespace detail { \ + template<> \ + struct apply_impl, 2, __bfloat16, __bfloat16, __bfloat16> { \ + KERNEL_FLOAT_INLINE static void call( \ + ops::NAME<__bfloat16>, \ + __bfloat16* result, \ + const __bfloat16* a, \ + const __bfloat16* b) { \ + __bfloat162 r = FUN2(__bfloat162 {a[0], a[1]}, __bfloat162 {b[0], b[1]}); \ + result[0] = r.x, result[1] = r.y; \ + } \ + }; \ } -#else -#define KERNEL_FLOAT_BF16_BINARY_FUN(NAME, FUN1, FUN2) -#endif KERNEL_FLOAT_BF16_BINARY_FUN(add, __hadd, __hadd2) KERNEL_FLOAT_BF16_BINARY_FUN(subtract, __hsub, __hsub2) @@ -122,13 +130,13 @@ KERNEL_FLOAT_BF16_BINARY_FUN(less, __hlt, __hlt2) KERNEL_FLOAT_BF16_BINARY_FUN(less_equal, __hle, __hle2) KERNEL_FLOAT_BF16_BINARY_FUN(greater, __hgt, __hgt2) KERNEL_FLOAT_BF16_BINARY_FUN(greater_equal, __hge, __hgt2) +#endif -#if KERNEL_FLOAT_CUDA_ARCH >= 800 +#if KERNEL_FLOAT_BF16_OPS_SUPPORTED namespace ops { template<> -struct fma<__nv_bfloat16> { - KERNEL_FLOAT_INLINE __nv_bfloat16 - operator()(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) const { +struct fma<__bfloat16> { + KERNEL_FLOAT_INLINE __bfloat16 operator()(__bfloat16 a, __bfloat16 b, __bfloat16 c) const { return __hfma(a, b, c); } }; @@ -136,23 +144,15 @@ struct fma<__nv_bfloat16> { namespace detail { template<> -struct apply_impl< - ops::fma<__nv_bfloat16>, - 2, - __nv_bfloat16, - __nv_bfloat16, - __nv_bfloat16, - __nv_bfloat16> { +struct apply_impl, 2, __bfloat16, __bfloat16, __bfloat16, __bfloat16> { KERNEL_FLOAT_INLINE static void call( - ops::fma<__nv_bfloat16>, - __nv_bfloat16* result, - const __nv_bfloat16* a, - const __nv_bfloat16* b, - const __nv_bfloat16* c) { - __nv_bfloat162 r = __hfma2( - __nv_bfloat162 {a[0], a[1]}, - __nv_bfloat162 {b[0], b[1]}, - __nv_bfloat162 {c[0], c[1]}); + ops::fma<__bfloat16>, + __bfloat16* result, + const __bfloat16* a, + const __bfloat16* b, + const __bfloat16* c) { + __bfloat162 r = + __hfma2(__bfloat162 {a[0], a[1]}, __bfloat162 {b[0], b[1]}, __bfloat162 {c[0], c[1]}); result[0] = r.x, result[1] = r.y; } }; @@ -161,44 +161,44 @@ struct apply_impl< namespace ops { template<> -struct cast { - KERNEL_FLOAT_INLINE __nv_bfloat16 operator()(double input) { +struct cast { + KERNEL_FLOAT_INLINE __bfloat16 operator()(double input) { return __double2bfloat16(input); }; }; template<> -struct cast { - KERNEL_FLOAT_INLINE __nv_bfloat16 operator()(float input) { +struct cast { + KERNEL_FLOAT_INLINE __bfloat16 operator()(float input) { return __float2bfloat16(input); }; }; template<> -struct cast<__nv_bfloat16, float> { - KERNEL_FLOAT_INLINE float operator()(__nv_bfloat16 input) { +struct cast<__bfloat16, float> { + KERNEL_FLOAT_INLINE float operator()(__bfloat16 input) { return __bfloat162float(input); }; }; } // namespace ops -#define KERNEL_FLOAT_BF16_CAST(T, TO_HALF, FROM_HALF) \ - namespace ops { \ - template<> \ - struct cast { \ - KERNEL_FLOAT_INLINE __nv_bfloat16 operator()(T input) { \ - return TO_HALF; \ - } \ - }; \ - template<> \ - struct cast<__nv_bfloat16, T> { \ - KERNEL_FLOAT_INLINE T operator()(__nv_bfloat16 input) { \ - return FROM_HALF; \ - } \ - }; \ +#define KERNEL_FLOAT_BF16_CAST(T, TO_HALF, FROM_HALF) \ + namespace ops { \ + template<> \ + struct cast { \ + KERNEL_FLOAT_INLINE __bfloat16 operator()(T input) { \ + return TO_HALF; \ + } \ + }; \ + template<> \ + struct cast<__bfloat16, T> { \ + KERNEL_FLOAT_INLINE T operator()(__bfloat16 input) { \ + return FROM_HALF; \ + } \ + }; \ } -#if KERNEL_FLOAT_CUDA_ARCH >= 800 +#if KERNEL_FLOAT_BF16_OPS_SUPPORTED // clang-format off // there are no official char casts. Instead, cast to int and then to char KERNEL_FLOAT_BF16_CAST(char, __int2bfloat16_rn(input), (char)__bfloat162int_rz(input)); @@ -215,17 +215,25 @@ KERNEL_FLOAT_BF16_CAST(unsigned int, __bfloat162uint_rz(input), __uint2bfloat16_ KERNEL_FLOAT_BF16_CAST(unsigned long, __ull2bfloat16_rn(input), (unsigned long)(__bfloat162ull_rz(input))); KERNEL_FLOAT_BF16_CAST(unsigned long long, __ull2bfloat16_rn(input), __bfloat162ull_rz(input)); // clang-format on -#else +#endif + +#if KERNEL_FLOAT_IS_CUDA KERNEL_FLOAT_BF16_CAST( bool, __nv_bfloat16_raw {input ? (unsigned short)0 : (unsigned short)0x3C00}, (__nv_bfloat16_raw(input).x & 0x7FFF) != 0); + +#elif KERNEL_FLOAT_IS_HIP +KERNEL_FLOAT_BF16_CAST( + bool, + __hip_bfloat16 {input ? (unsigned short)0 : (unsigned short)0x3C00}, + (__hip_bfloat16(input).data & 0x7FFF) != 0); #endif -using bfloat16 = __nv_bfloat16; -KERNEL_FLOAT_VECTOR_ALIAS(bfloat16x, __nv_bfloat16) -//KERNEL_FLOAT_TYPE_ALIAS(float16x, __nv_bfloat16) -//KERNEL_FLOAT_TYPE_ALIAS(f16x, __nv_bfloat16) +using bfloat16 = __bfloat16; +KERNEL_FLOAT_VECTOR_ALIAS(bfloat16x, __bfloat16) +//KERNEL_FLOAT_TYPE_ALIAS(float16x, __bfloat16) +//KERNEL_FLOAT_TYPE_ALIAS(f16x, __bfloat16) } // namespace kernel_float @@ -234,12 +242,12 @@ KERNEL_FLOAT_VECTOR_ALIAS(bfloat16x, __nv_bfloat16) namespace kernel_float { template<> -struct promote_type<__nv_bfloat16, __half> { +struct promote_type<__bfloat16, __half> { using type = float; }; template<> -struct promote_type<__half, __nv_bfloat16> { +struct promote_type<__half, __bfloat16> { using type = float; }; diff --git a/include/kernel_float/binops.h b/include/kernel_float/binops.h index eb958f1..77bc6ae 100644 --- a/include/kernel_float/binops.h +++ b/include/kernel_float/binops.h @@ -144,7 +144,7 @@ KERNEL_FLOAT_DEFINE_BINARY_OP_FALLBACK(bit_xor, ^, bool(left) ^ bool(right), boo // clang-format on // clang-format off -template typename F, typename T, typename E, typename R> +template typename F, typename T, typename E, typename R> static constexpr bool is_vector_assign_allowed = is_vector_broadcastable && is_implicit_convertible< diff --git a/include/kernel_float/fp16.h b/include/kernel_float/fp16.h index 0b62d9b..4228514 100644 --- a/include/kernel_float/fp16.h +++ b/include/kernel_float/fp16.h @@ -4,7 +4,11 @@ #include "macros.h" #if KERNEL_FLOAT_FP16_AVAILABLE +#if KERNEL_FLOAT_IS_CUDA #include +#elif KERNEL_FLOAT_IS_HIP +#include +#endif #include "vector.h" @@ -60,21 +64,21 @@ struct allow_float_fallback<__half> { #define KERNEL_FLOAT_FP16_UNARY_FUN(NAME, FUN1, FUN2) #endif -KERNEL_FLOAT_FP16_UNARY_FUN(abs, ::__habs, ::__habs2) -KERNEL_FLOAT_FP16_UNARY_FUN(negate, ::__hneg, ::__hneg2) -KERNEL_FLOAT_FP16_UNARY_FUN(ceil, ::hceil, ::h2ceil) -KERNEL_FLOAT_FP16_UNARY_FUN(cos, ::hcos, ::h2cos) -KERNEL_FLOAT_FP16_UNARY_FUN(exp, ::hexp, ::h2exp) -KERNEL_FLOAT_FP16_UNARY_FUN(exp10, ::hexp10, ::h2exp10) -KERNEL_FLOAT_FP16_UNARY_FUN(floor, ::hfloor, ::h2floor) -KERNEL_FLOAT_FP16_UNARY_FUN(log, ::hlog, ::h2log) -KERNEL_FLOAT_FP16_UNARY_FUN(log10, ::hlog10, ::h2log2) -KERNEL_FLOAT_FP16_UNARY_FUN(rint, ::hrint, ::h2rint) -KERNEL_FLOAT_FP16_UNARY_FUN(rsqrt, ::hrsqrt, ::h2rsqrt) -KERNEL_FLOAT_FP16_UNARY_FUN(sin, ::hsin, ::h2sin) -KERNEL_FLOAT_FP16_UNARY_FUN(sqrt, ::hsqrt, ::h2sqrt) -KERNEL_FLOAT_FP16_UNARY_FUN(trunc, ::htrunc, ::h2trunc) -KERNEL_FLOAT_FP16_UNARY_FUN(rcp, ::hrcp, ::h2rcp) +KERNEL_FLOAT_FP16_UNARY_FUN(abs, __habs, __habs2) +KERNEL_FLOAT_FP16_UNARY_FUN(negate, __hneg, __hneg2) +KERNEL_FLOAT_FP16_UNARY_FUN(ceil, hceil, h2ceil) +KERNEL_FLOAT_FP16_UNARY_FUN(cos, hcos, h2cos) +KERNEL_FLOAT_FP16_UNARY_FUN(exp, hexp, h2exp) +KERNEL_FLOAT_FP16_UNARY_FUN(exp10, hexp10, h2exp10) +KERNEL_FLOAT_FP16_UNARY_FUN(floor, hfloor, h2floor) +KERNEL_FLOAT_FP16_UNARY_FUN(log, hlog, h2log) +KERNEL_FLOAT_FP16_UNARY_FUN(log10, hlog10, h2log2) +KERNEL_FLOAT_FP16_UNARY_FUN(rint, hrint, h2rint) +KERNEL_FLOAT_FP16_UNARY_FUN(rsqrt, hrsqrt, h2rsqrt) +KERNEL_FLOAT_FP16_UNARY_FUN(sin, hsin, h2sin) +KERNEL_FLOAT_FP16_UNARY_FUN(sqrt, hsqrt, h2sqrt) +KERNEL_FLOAT_FP16_UNARY_FUN(trunc, htrunc, h2trunc) +KERNEL_FLOAT_FP16_UNARY_FUN(rcp, hrcp, h2rcp) #if KERNEL_FLOAT_IS_DEVICE #define KERNEL_FLOAT_FP16_BINARY_FUN(NAME, FUN1, FUN2) \ @@ -100,12 +104,16 @@ KERNEL_FLOAT_FP16_UNARY_FUN(rcp, ::hrcp, ::h2rcp) #define KERNEL_FLOAT_FP16_BINARY_FUN(NAME, FUN1, FUN2) #endif +// There are not available in HIP +#if KERNEL_FLOAT_IS_CUDA +KERNEL_FLOAT_FP16_BINARY_FUN(min, __hmin, __hmin2) +KERNEL_FLOAT_FP16_BINARY_FUN(max, __hmax, __hmax2) +#endif + KERNEL_FLOAT_FP16_BINARY_FUN(add, __hadd, __hadd2) KERNEL_FLOAT_FP16_BINARY_FUN(subtract, __hsub, __hsub2) KERNEL_FLOAT_FP16_BINARY_FUN(multiply, __hmul, __hmul2) KERNEL_FLOAT_FP16_BINARY_FUN(divide, __hdiv, __h2div) -KERNEL_FLOAT_FP16_BINARY_FUN(min, __hmin, __hmin2) -KERNEL_FLOAT_FP16_BINARY_FUN(max, __hmax, __hmax2) KERNEL_FLOAT_FP16_BINARY_FUN(equal_to, __heq, __heq2) KERNEL_FLOAT_FP16_BINARY_FUN(not_equal_to, __hneu, __hneu2) @@ -152,6 +160,28 @@ struct apply_impl, 2, __half, __half, __half, __half> { }; \ } +// Only CUDA has a special `__double2half` intrinsic +#if KERNEL_FLOAT_IS_HIP +#define KERNEL_FLOAT_FP16_CAST_FWD(T) \ + KERNEL_FLOAT_FP16_CAST(T, static_cast<_Float16>(input), static_cast(input)) + +KERNEL_FLOAT_FP16_CAST_FWD(double) +KERNEL_FLOAT_FP16_CAST_FWD(float) + +KERNEL_FLOAT_FP16_CAST_FWD(char) +KERNEL_FLOAT_FP16_CAST_FWD(signed char) +KERNEL_FLOAT_FP16_CAST_FWD(unsigned char) + +KERNEL_FLOAT_FP16_CAST_FWD(signed short) +KERNEL_FLOAT_FP16_CAST_FWD(signed int) +KERNEL_FLOAT_FP16_CAST_FWD(signed long) +KERNEL_FLOAT_FP16_CAST_FWD(signed long long) + +KERNEL_FLOAT_FP16_CAST_FWD(unsigned short) +KERNEL_FLOAT_FP16_CAST_FWD(unsigned int) +KERNEL_FLOAT_FP16_CAST_FWD(unsigned long) +KERNEL_FLOAT_FP16_CAST_FWD(unsigned long long) +#else KERNEL_FLOAT_FP16_CAST(double, __double2half(input), double(__half2float(input))); KERNEL_FLOAT_FP16_CAST(float, __float2half(input), __half2float(input)); @@ -169,6 +199,7 @@ KERNEL_FLOAT_FP16_CAST(unsigned short, __half2ushort_rz(input), __ushort2half_rn KERNEL_FLOAT_FP16_CAST(unsigned int, __half2uint_rz(input), __uint2half_rn(input)); KERNEL_FLOAT_FP16_CAST(unsigned long, __ull2half_rn(input), (unsigned long)(__half2ull_rz(input))); KERNEL_FLOAT_FP16_CAST(unsigned long long, __ull2half_rn(input), __half2ull_rz(input)); +#endif using half = __half; KERNEL_FLOAT_VECTOR_ALIAS(half, __half) diff --git a/include/kernel_float/macros.h b/include/kernel_float/macros.h index 4eee8f1..01b0254 100644 --- a/include/kernel_float/macros.h +++ b/include/kernel_float/macros.h @@ -1,42 +1,51 @@ #ifndef KERNEL_FLOAT_MACROS_H #define KERNEL_FLOAT_MACROS_H +#ifdef __HIPCC__ +#include "hip/hip_runtime.h" +#endif + +// clang-format off #ifdef __CUDACC__ -#define KERNEL_FLOAT_CUDA (1) + #define KERNEL_FLOAT_IS_CUDA (1) + + #ifdef __CUDA_ARCH__ + #define KERNEL_FLOAT_INLINE __forceinline__ __device__ + #define KERNEL_FLOAT_IS_DEVICE (1) + #else // __CUDA_ARCH__ + #define KERNEL_FLOAT_INLINE __forceinline__ __host__ + #define KERNEL_FLOAT_IS_HOST (1) + #endif // __CUDA_ARCH__ +#elif defined(__HIPCC__) + #define KERNEL_FLOAT_IS_HIP (1) -#ifdef __CUDA_ARCH__ -#define KERNEL_FLOAT_INLINE __forceinline__ __device__ -#define KERNEL_FLOAT_IS_DEVICE (1) -#define KERNEL_FLOAT_IS_HOST (0) -#define KERNEL_FLOAT_CUDA_ARCH (__CUDA_ARCH__) -#else // __CUDA_ARCH__ -#define KERNEL_FLOAT_INLINE __forceinline__ __host__ -#define KERNEL_FLOAT_IS_DEVICE (0) -#define KERNEL_FLOAT_IS_HOST (1) -#define KERNEL_FLOAT_CUDA_ARCH (0) -#endif // __CUDA_ARCH__ -#else // __CUDACC__ -#define KERNEL_FLOAT_INLINE inline -#define KERNEL_FLOAT_CUDA (0) -#define KERNEL_FLOAT_IS_HOST (1) -#define KERNEL_FLOAT_IS_DEVICE (0) -#define KERNEL_FLOAT_CUDA_ARCH (0) -#endif // __CUDACC__ + #ifdef __HIP_DEVICE_COMPILE__ + #define KERNEL_FLOAT_INLINE __attribute__((always_inline)) __host__ __device__ + #define KERNEL_FLOAT_IS_DEVICE (1) + #else + #define KERNEL_FLOAT_INLINE __attribute__((always_inline)) __host__ __device__ + #define KERNEL_FLOAT_IS_HOST (1) + #endif + +#else + #define KERNEL_FLOAT_INLINE inline + #define KERNEL_FLOAT_IS_HOST (1) +#endif #ifndef KERNEL_FLOAT_FP16_AVAILABLE -#define KERNEL_FLOAT_FP16_AVAILABLE (1) + #define KERNEL_FLOAT_FP16_AVAILABLE (1) #endif // KERNEL_FLOAT_FP16_AVAILABLE #ifndef KERNEL_FLOAT_BF16_AVAILABLE -#define KERNEL_FLOAT_BF16_AVAILABLE (1) + #define KERNEL_FLOAT_BF16_AVAILABLE (1) #endif // KERNEL_FLOAT_BF16_AVAILABLE #ifndef KERNEL_FLOAT_FP8_AVAILABLE -#ifdef __CUDACC_VER_MAJOR__ -#define KERNEL_FLOAT_FP8_AVAILABLE (__CUDACC_VER_MAJOR__ >= 12) -#else // __CUDACC_VER_MAJOR__ -#define KERNEL_FLOAT_FP8_AVAILABLE (0) -#endif // __CUDACC_VER_MAJOR__ + #ifdef __CUDACC_VER_MAJOR__ + #define KERNEL_FLOAT_FP8_AVAILABLE (__CUDACC_VER_MAJOR__ >= 12) + #else // __CUDACC_VER_MAJOR__ + #define KERNEL_FLOAT_FP8_AVAILABLE (0) + #endif // __CUDACC_VER_MAJOR__ #endif // KERNEL_FLOAT_FP8_AVAILABLE #define KERNEL_FLOAT_ASSERT(expr) \ @@ -51,20 +60,22 @@ // TOOD: check if this way is support across all compilers #if defined(__has_builtin) && 0 // Seems that `__builtin_assume_aligned` leads to segfaults -#if __has_builtin(__builtin_assume_aligned) -#define KERNEL_FLOAT_ASSUME_ALIGNED(TYPE, PTR, ALIGNMENT) static_cast ( - __builtin_assume_aligned(static_cast (PTR), (ALIGNMENT))) -#else -#define KERNEL_FLOAT_ASSUME_ALIGNED(TYPE, PTR, ALIGNMENT) (PTR) -#endif + #if __has_builtin(__builtin_assume_aligned) + #define KERNEL_FLOAT_ASSUME_ALIGNED(TYPE, PTR, ALIGNMENT) static_cast ( + __builtin_assume_aligned(static_cast (PTR), (ALIGNMENT))) + #else + #define KERNEL_FLOAT_ASSUME_ALIGNED(TYPE, PTR, ALIGNMENT) (PTR) + #endif #else -#define KERNEL_FLOAT_ASSUME_ALIGNED(TYPE, PTR, ALIGNMENT) (PTR) + #define KERNEL_FLOAT_ASSUME_ALIGNED(TYPE, PTR, ALIGNMENT) (PTR) #endif #define KERNEL_FLOAT_MAX_ALIGNMENT (32) #if KERNEL_FLOAT_FAST_MATH -#define KERNEL_FLOAT_POLICY ::kernel_float::fast_policy; + #define KERNEL_FLOAT_POLICY ::kernel_float::fast_policy; #endif + // clang-format on + #endif //KERNEL_FLOAT_MACROS_H diff --git a/include/kernel_float/reduce.h b/include/kernel_float/reduce.h index c859265..e616f17 100644 --- a/include/kernel_float/reduce.h +++ b/include/kernel_float/reduce.h @@ -267,21 +267,21 @@ struct magnitude_impl { }; // The 3-argument overload of hypot is only available on host from C++17 -#if defined(__cpp_lib_hypot) && KERNEL_FLOAT_IS_HOST -template<> -struct magnitude_impl { - static float call(const float* input) { - return ::hypot(input[0], input[1], input[2]); - } -}; - -template<> -struct magnitude_impl { - static double call(const double* input) { - return ::hypot(input[0], input[1], input[2]); - } -}; -#endif +//#if defined(__cpp_lib_hypot) && KERNEL_FLOAT_IS_HOST +//template<> +//struct magnitude_impl { +// static float call(const float* input) { +// return ::hypot(input[0], input[1], input[2]); +// } +//}; +// +//template<> +//struct magnitude_impl { +// static double call(const double* input) { +// return ::hypot(input[0], input[1], input[2]); +// } +//}; +//#endif } // namespace detail diff --git a/include/kernel_float/tiling.h b/include/kernel_float/tiling.h index 10ad379..d4f944a 100644 --- a/include/kernel_float/tiling.h +++ b/include/kernel_float/tiling.h @@ -329,9 +329,7 @@ struct tiling { using index_type = IndexType; using point_type = vector>; -#if KERNEL_FLOAT_IS_DEVICE - __forceinline__ __device__ tiling() : block_(threadIdx) {} -#endif + __forceinline__ __device__ tiling() : block_(dim3(threadIdx)) {} KERNEL_FLOAT_INLINE tiling(BlockDim block, vec offset = {}) : block_(block), offset_(offset) {} diff --git a/include/kernel_float/unops.h b/include/kernel_float/unops.h index fce130e..8381d60 100644 --- a/include/kernel_float/unops.h +++ b/include/kernel_float/unops.h @@ -180,9 +180,13 @@ KERNEL_FLOAT_DEFINE_UNARY_MATH(trunc) KERNEL_FLOAT_DEFINE_UNARY_MATH(rint) KERNEL_FLOAT_DEFINE_UNARY_MATH(rsqrt) KERNEL_FLOAT_DEFINE_UNARY_MATH(round) + +// There are not support on HIP +#if !KERNEL_FLOAT_IS_HIP KERNEL_FLOAT_DEFINE_UNARY_MATH(signbit) KERNEL_FLOAT_DEFINE_UNARY_MATH(isinf) KERNEL_FLOAT_DEFINE_UNARY_MATH(isnan) +#endif // CUDA offers special reciprocal functions (rcp), but only on the device. #if KERNEL_FLOAT_IS_DEVICE @@ -211,8 +215,8 @@ KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(tan) KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(exp2) KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(log2) -#if KERNEL_FLOAT_IS_DEVICE - +// This PTX is only supported on CUDA +#if KERNEL_FLOAT_IS_CUDA && KERNEL_FLOAT_IS_DEVICE #define KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(T, F, FAST_FUN) \ namespace detail { \ template<> \ @@ -231,24 +235,23 @@ KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, log, __logf) template<> \ struct apply_fastmath_impl, 1, T, T> { \ KERNEL_FLOAT_INLINE static void call(ops::F fun, T* result, const T* inputs) { \ - asm(INSTR : "=" REG(*result) : REG(*inputs)); \ + asm(INSTR " %0, %1;" : "=" REG(*result) : REG(*inputs)); \ } \ }; \ } -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(double, rcp, "rcp.approx.ftz.f64 %0, %1;", "d") -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(double, rsqrt, "rsqrt.approx.f64 %0, %1;", "d") - -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, sqrt, "sqrt.approx.f32 %0, %1;", "f") -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, rcp, "rcp.approx.f32 %0, %1;", "f") -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, rsqrt, "rsqrt.approx.f32 %0, %1;", "f") -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, sin, "sin.approx.f32 %0, %1;", "f") -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, cos, "cos.approx.f32 %0, %1;", "f") +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(double, rcp, "rcp.approx.ftz.f64", "d") +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(double, rsqrt, "rsqrt.approx.f64", "d") -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, exp2, "ex2.approx.f32 %0, %1;", "f") -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, log2, "lg2.approx.f32 %0, %1;", "f") -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, tanh, "tanh.approx.f32 %0, %1;", "f") +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, sqrt, "sqrt.approx.f32", "f") +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, rcp, "rcp.approx.f32", "f") +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, rsqrt, "rsqrt.approx.f32", "f") +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, sin, "sin.approx.f32", "f") +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, cos, "cos.approx.f32", "f") +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, exp2, "ex2.approx.f32", "f") +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, log2, "lg2.approx.f32", "f") +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, tanh, "tanh.approx.f32;", "f") #endif } // namespace kernel_float diff --git a/include/kernel_float/vector.h b/include/kernel_float/vector.h index e52fa02..e7b1737 100644 --- a/include/kernel_float/vector.h +++ b/include/kernel_float/vector.h @@ -325,6 +325,7 @@ template using vec8 = vec; #define KERNEL_FLOAT_VECTOR_ALIAS(NAME, T) \ template \ + using v##NAME = vec; \ using NAME##1 = vec; \ using NAME##2 = vec; \ using NAME##3 = vec; \ diff --git a/kernel_tuner/vector_add.cu b/kernel_tuner/vector_add.cu index 5260bf8..2c0f339 100644 --- a/kernel_tuner/vector_add.cu +++ b/kernel_tuner/vector_add.cu @@ -1,6 +1,7 @@ #include "kernel_float.h" namespace kf = kernel_float; +extern "C" __global__ void vector_add( kf::vec* c, const kf::vec* a, diff --git a/kernel_tuner/vector_add.py b/kernel_tuner/vector_add.py index 8c5b2ee..4085cba 100644 --- a/kernel_tuner/vector_add.py +++ b/kernel_tuner/vector_add.py @@ -9,6 +9,7 @@ ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) + "/../" flags = [f"-I{ROOT_DIR}/include", "-std=c++17"] + def tune(): # Prepare input data @@ -54,7 +55,7 @@ def tune(): answer=answer, observers=observers, metrics=metrics, - lang="cupy", + lang="CUDA", compiler_options=flags ) diff --git a/single_include/kernel_float.h b/single_include/kernel_float.h index a79da18..7e62861 100644 --- a/single_include/kernel_float.h +++ b/single_include/kernel_float.h @@ -16,49 +16,58 @@ //================================================================================ // this file has been auto-generated, do not modify its contents! -// date: 2024-09-23 14:12:25.024358 -// git hash: 3a88b56a57cce5e1f3365aa6e8efb76a14f7f865 +// date: 2024-11-01 17:19:21.255671 +// git hash: 8333b040cbdbb1ec66e1ab4e459f597d889fcd7e //================================================================================ #ifndef KERNEL_FLOAT_MACROS_H #define KERNEL_FLOAT_MACROS_H +#ifdef __HIPCC__ +#include "hip/hip_runtime.h" +#endif + +// clang-format off #ifdef __CUDACC__ -#define KERNEL_FLOAT_CUDA (1) - -#ifdef __CUDA_ARCH__ -#define KERNEL_FLOAT_INLINE __forceinline__ __device__ -#define KERNEL_FLOAT_IS_DEVICE (1) -#define KERNEL_FLOAT_IS_HOST (0) -#define KERNEL_FLOAT_CUDA_ARCH (__CUDA_ARCH__) -#else // __CUDA_ARCH__ -#define KERNEL_FLOAT_INLINE __forceinline__ __host__ -#define KERNEL_FLOAT_IS_DEVICE (0) -#define KERNEL_FLOAT_IS_HOST (1) -#define KERNEL_FLOAT_CUDA_ARCH (0) -#endif // __CUDA_ARCH__ -#else // __CUDACC__ -#define KERNEL_FLOAT_INLINE inline -#define KERNEL_FLOAT_CUDA (0) -#define KERNEL_FLOAT_IS_HOST (1) -#define KERNEL_FLOAT_IS_DEVICE (0) -#define KERNEL_FLOAT_CUDA_ARCH (0) -#endif // __CUDACC__ + #define KERNEL_FLOAT_IS_CUDA (1) + + #ifdef __CUDA_ARCH__ + #define KERNEL_FLOAT_INLINE __forceinline__ __device__ + #define KERNEL_FLOAT_IS_DEVICE (1) + #else // __CUDA_ARCH__ + #define KERNEL_FLOAT_INLINE __forceinline__ __host__ + #define KERNEL_FLOAT_IS_HOST (1) + #endif // __CUDA_ARCH__ +#elif defined(__HIPCC__) + #define KERNEL_FLOAT_IS_HIP (1) + + #ifdef __HIP_DEVICE_COMPILE__ + #define KERNEL_FLOAT_INLINE __attribute__((always_inline)) __host__ __device__ + #define KERNEL_FLOAT_IS_DEVICE (1) + #else + #define KERNEL_FLOAT_INLINE __attribute__((always_inline)) __host__ __device__ + #define KERNEL_FLOAT_IS_HOST (1) + #endif + +#else + #define KERNEL_FLOAT_INLINE inline + #define KERNEL_FLOAT_IS_HOST (1) +#endif #ifndef KERNEL_FLOAT_FP16_AVAILABLE -#define KERNEL_FLOAT_FP16_AVAILABLE (1) + #define KERNEL_FLOAT_FP16_AVAILABLE (1) #endif // KERNEL_FLOAT_FP16_AVAILABLE #ifndef KERNEL_FLOAT_BF16_AVAILABLE -#define KERNEL_FLOAT_BF16_AVAILABLE (1) + #define KERNEL_FLOAT_BF16_AVAILABLE (1) #endif // KERNEL_FLOAT_BF16_AVAILABLE #ifndef KERNEL_FLOAT_FP8_AVAILABLE -#ifdef __CUDACC_VER_MAJOR__ -#define KERNEL_FLOAT_FP8_AVAILABLE (__CUDACC_VER_MAJOR__ >= 12) -#else // __CUDACC_VER_MAJOR__ -#define KERNEL_FLOAT_FP8_AVAILABLE (0) -#endif // __CUDACC_VER_MAJOR__ + #ifdef __CUDACC_VER_MAJOR__ + #define KERNEL_FLOAT_FP8_AVAILABLE (__CUDACC_VER_MAJOR__ >= 12) + #else // __CUDACC_VER_MAJOR__ + #define KERNEL_FLOAT_FP8_AVAILABLE (0) + #endif // __CUDACC_VER_MAJOR__ #endif // KERNEL_FLOAT_FP8_AVAILABLE #define KERNEL_FLOAT_ASSERT(expr) \ @@ -73,22 +82,24 @@ // TOOD: check if this way is support across all compilers #if defined(__has_builtin) && 0 // Seems that `__builtin_assume_aligned` leads to segfaults -#if __has_builtin(__builtin_assume_aligned) -#define KERNEL_FLOAT_ASSUME_ALIGNED(TYPE, PTR, ALIGNMENT) static_cast ( - __builtin_assume_aligned(static_cast (PTR), (ALIGNMENT))) -#else -#define KERNEL_FLOAT_ASSUME_ALIGNED(TYPE, PTR, ALIGNMENT) (PTR) -#endif + #if __has_builtin(__builtin_assume_aligned) + #define KERNEL_FLOAT_ASSUME_ALIGNED(TYPE, PTR, ALIGNMENT) static_cast ( + __builtin_assume_aligned(static_cast (PTR), (ALIGNMENT))) + #else + #define KERNEL_FLOAT_ASSUME_ALIGNED(TYPE, PTR, ALIGNMENT) (PTR) + #endif #else -#define KERNEL_FLOAT_ASSUME_ALIGNED(TYPE, PTR, ALIGNMENT) (PTR) + #define KERNEL_FLOAT_ASSUME_ALIGNED(TYPE, PTR, ALIGNMENT) (PTR) #endif #define KERNEL_FLOAT_MAX_ALIGNMENT (32) #if KERNEL_FLOAT_FAST_MATH -#define KERNEL_FLOAT_POLICY ::kernel_float::fast_policy; + #define KERNEL_FLOAT_POLICY ::kernel_float::fast_policy; #endif + // clang-format on + #endif //KERNEL_FLOAT_MACROS_H #ifndef KERNEL_FLOAT_CORE_H #define KERNEL_FLOAT_CORE_H @@ -383,6 +394,10 @@ constexpr size_t round_up_to_power_of_two(size_t n) { +#if KERNEL_FLOAT_IS_HIP +#include +#endif + namespace kernel_float { template @@ -1297,9 +1312,13 @@ KERNEL_FLOAT_DEFINE_UNARY_MATH(trunc) KERNEL_FLOAT_DEFINE_UNARY_MATH(rint) KERNEL_FLOAT_DEFINE_UNARY_MATH(rsqrt) KERNEL_FLOAT_DEFINE_UNARY_MATH(round) + +// There are not support on HIP +#if !KERNEL_FLOAT_IS_HIP KERNEL_FLOAT_DEFINE_UNARY_MATH(signbit) KERNEL_FLOAT_DEFINE_UNARY_MATH(isinf) KERNEL_FLOAT_DEFINE_UNARY_MATH(isnan) +#endif // CUDA offers special reciprocal functions (rcp), but only on the device. #if KERNEL_FLOAT_IS_DEVICE @@ -1328,8 +1347,8 @@ KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(tan) KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(exp2) KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(log2) -#if KERNEL_FLOAT_IS_DEVICE - +// This PTX is only supported on CUDA +#if KERNEL_FLOAT_IS_CUDA && KERNEL_FLOAT_IS_DEVICE #define KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(T, F, FAST_FUN) \ namespace detail { \ template<> \ @@ -1348,24 +1367,23 @@ KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, log, __logf) template<> \ struct apply_fastmath_impl, 1, T, T> { \ KERNEL_FLOAT_INLINE static void call(ops::F fun, T* result, const T* inputs) { \ - asm(INSTR : "=" REG(*result) : REG(*inputs)); \ + asm(INSTR " %0, %1;" : "=" REG(*result) : REG(*inputs)); \ } \ }; \ } -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(double, rcp, "rcp.approx.ftz.f64 %0, %1;", "d") -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(double, rsqrt, "rsqrt.approx.f64 %0, %1;", "d") +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(double, rcp, "rcp.approx.ftz.f64", "d") +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(double, rsqrt, "rsqrt.approx.f64", "d") -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, sqrt, "sqrt.approx.f32 %0, %1;", "f") -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, rcp, "rcp.approx.f32 %0, %1;", "f") -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, rsqrt, "rsqrt.approx.f32 %0, %1;", "f") -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, sin, "sin.approx.f32 %0, %1;", "f") -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, cos, "cos.approx.f32 %0, %1;", "f") - -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, exp2, "ex2.approx.f32 %0, %1;", "f") -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, log2, "lg2.approx.f32 %0, %1;", "f") -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, tanh, "tanh.approx.f32 %0, %1;", "f") +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, sqrt, "sqrt.approx.f32", "f") +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, rcp, "rcp.approx.f32", "f") +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, rsqrt, "rsqrt.approx.f32", "f") +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, sin, "sin.approx.f32", "f") +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, cos, "cos.approx.f32", "f") +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, exp2, "ex2.approx.f32", "f") +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, log2, "lg2.approx.f32", "f") +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, tanh, "tanh.approx.f32;", "f") #endif } // namespace kernel_float @@ -1734,7 +1752,7 @@ KERNEL_FLOAT_DEFINE_BINARY_OP_FALLBACK(bit_xor, ^, bool(left) ^ bool(right), boo // clang-format on // clang-format off -template typename F, typename T, typename E, typename R> +template typename F, typename T, typename E, typename R> static constexpr bool is_vector_assign_allowed = is_vector_broadcastable && is_implicit_convertible< @@ -3458,21 +3476,21 @@ struct magnitude_impl { }; // The 3-argument overload of hypot is only available on host from C++17 -#if defined(__cpp_lib_hypot) && KERNEL_FLOAT_IS_HOST -template<> -struct magnitude_impl { - static float call(const float* input) { - return ::hypot(input[0], input[1], input[2]); - } -}; - -template<> -struct magnitude_impl { - static double call(const double* input) { - return ::hypot(input[0], input[1], input[2]); - } -}; -#endif +//#if defined(__cpp_lib_hypot) && KERNEL_FLOAT_IS_HOST +//template<> +//struct magnitude_impl { +// static float call(const float* input) { +// return ::hypot(input[0], input[1], input[2]); +// } +//}; +// +//template<> +//struct magnitude_impl { +// static double call(const double* input) { +// return ::hypot(input[0], input[1], input[2]); +// } +//}; +//#endif } // namespace detail @@ -3821,6 +3839,7 @@ template using vec8 = vec; #define KERNEL_FLOAT_VECTOR_ALIAS(NAME, T) \ template \ + using v##NAME = vec; \ using NAME##1 = vec; \ using NAME##2 = vec; \ using NAME##3 = vec; \ @@ -3874,7 +3893,11 @@ vec(Args&&... args) -> vec, sizeof...(Args)>; #if KERNEL_FLOAT_FP16_AVAILABLE +#if KERNEL_FLOAT_IS_CUDA #include +#elif KERNEL_FLOAT_IS_HIP +#include +#endif @@ -3930,21 +3953,21 @@ struct allow_float_fallback<__half> { #define KERNEL_FLOAT_FP16_UNARY_FUN(NAME, FUN1, FUN2) #endif -KERNEL_FLOAT_FP16_UNARY_FUN(abs, ::__habs, ::__habs2) -KERNEL_FLOAT_FP16_UNARY_FUN(negate, ::__hneg, ::__hneg2) -KERNEL_FLOAT_FP16_UNARY_FUN(ceil, ::hceil, ::h2ceil) -KERNEL_FLOAT_FP16_UNARY_FUN(cos, ::hcos, ::h2cos) -KERNEL_FLOAT_FP16_UNARY_FUN(exp, ::hexp, ::h2exp) -KERNEL_FLOAT_FP16_UNARY_FUN(exp10, ::hexp10, ::h2exp10) -KERNEL_FLOAT_FP16_UNARY_FUN(floor, ::hfloor, ::h2floor) -KERNEL_FLOAT_FP16_UNARY_FUN(log, ::hlog, ::h2log) -KERNEL_FLOAT_FP16_UNARY_FUN(log10, ::hlog10, ::h2log2) -KERNEL_FLOAT_FP16_UNARY_FUN(rint, ::hrint, ::h2rint) -KERNEL_FLOAT_FP16_UNARY_FUN(rsqrt, ::hrsqrt, ::h2rsqrt) -KERNEL_FLOAT_FP16_UNARY_FUN(sin, ::hsin, ::h2sin) -KERNEL_FLOAT_FP16_UNARY_FUN(sqrt, ::hsqrt, ::h2sqrt) -KERNEL_FLOAT_FP16_UNARY_FUN(trunc, ::htrunc, ::h2trunc) -KERNEL_FLOAT_FP16_UNARY_FUN(rcp, ::hrcp, ::h2rcp) +KERNEL_FLOAT_FP16_UNARY_FUN(abs, __habs, __habs2) +KERNEL_FLOAT_FP16_UNARY_FUN(negate, __hneg, __hneg2) +KERNEL_FLOAT_FP16_UNARY_FUN(ceil, hceil, h2ceil) +KERNEL_FLOAT_FP16_UNARY_FUN(cos, hcos, h2cos) +KERNEL_FLOAT_FP16_UNARY_FUN(exp, hexp, h2exp) +KERNEL_FLOAT_FP16_UNARY_FUN(exp10, hexp10, h2exp10) +KERNEL_FLOAT_FP16_UNARY_FUN(floor, hfloor, h2floor) +KERNEL_FLOAT_FP16_UNARY_FUN(log, hlog, h2log) +KERNEL_FLOAT_FP16_UNARY_FUN(log10, hlog10, h2log2) +KERNEL_FLOAT_FP16_UNARY_FUN(rint, hrint, h2rint) +KERNEL_FLOAT_FP16_UNARY_FUN(rsqrt, hrsqrt, h2rsqrt) +KERNEL_FLOAT_FP16_UNARY_FUN(sin, hsin, h2sin) +KERNEL_FLOAT_FP16_UNARY_FUN(sqrt, hsqrt, h2sqrt) +KERNEL_FLOAT_FP16_UNARY_FUN(trunc, htrunc, h2trunc) +KERNEL_FLOAT_FP16_UNARY_FUN(rcp, hrcp, h2rcp) #if KERNEL_FLOAT_IS_DEVICE #define KERNEL_FLOAT_FP16_BINARY_FUN(NAME, FUN1, FUN2) \ @@ -3970,12 +3993,16 @@ KERNEL_FLOAT_FP16_UNARY_FUN(rcp, ::hrcp, ::h2rcp) #define KERNEL_FLOAT_FP16_BINARY_FUN(NAME, FUN1, FUN2) #endif +// There are not available in HIP +#if KERNEL_FLOAT_IS_CUDA +KERNEL_FLOAT_FP16_BINARY_FUN(min, __hmin, __hmin2) +KERNEL_FLOAT_FP16_BINARY_FUN(max, __hmax, __hmax2) +#endif + KERNEL_FLOAT_FP16_BINARY_FUN(add, __hadd, __hadd2) KERNEL_FLOAT_FP16_BINARY_FUN(subtract, __hsub, __hsub2) KERNEL_FLOAT_FP16_BINARY_FUN(multiply, __hmul, __hmul2) KERNEL_FLOAT_FP16_BINARY_FUN(divide, __hdiv, __h2div) -KERNEL_FLOAT_FP16_BINARY_FUN(min, __hmin, __hmin2) -KERNEL_FLOAT_FP16_BINARY_FUN(max, __hmax, __hmax2) KERNEL_FLOAT_FP16_BINARY_FUN(equal_to, __heq, __heq2) KERNEL_FLOAT_FP16_BINARY_FUN(not_equal_to, __hneu, __hneu2) @@ -4022,6 +4049,28 @@ struct apply_impl, 2, __half, __half, __half, __half> { }; \ } +// Only CUDA has a special `__double2half` intrinsic +#if KERNEL_FLOAT_IS_HIP +#define KERNEL_FLOAT_FP16_CAST_FWD(T) \ + KERNEL_FLOAT_FP16_CAST(T, static_cast<_Float16>(input), static_cast(input)) + +KERNEL_FLOAT_FP16_CAST_FWD(double) +KERNEL_FLOAT_FP16_CAST_FWD(float) + +KERNEL_FLOAT_FP16_CAST_FWD(char) +KERNEL_FLOAT_FP16_CAST_FWD(signed char) +KERNEL_FLOAT_FP16_CAST_FWD(unsigned char) + +KERNEL_FLOAT_FP16_CAST_FWD(signed short) +KERNEL_FLOAT_FP16_CAST_FWD(signed int) +KERNEL_FLOAT_FP16_CAST_FWD(signed long) +KERNEL_FLOAT_FP16_CAST_FWD(signed long long) + +KERNEL_FLOAT_FP16_CAST_FWD(unsigned short) +KERNEL_FLOAT_FP16_CAST_FWD(unsigned int) +KERNEL_FLOAT_FP16_CAST_FWD(unsigned long) +KERNEL_FLOAT_FP16_CAST_FWD(unsigned long long) +#else KERNEL_FLOAT_FP16_CAST(double, __double2half(input), double(__half2float(input))); KERNEL_FLOAT_FP16_CAST(float, __float2half(input), __half2float(input)); @@ -4039,6 +4088,7 @@ KERNEL_FLOAT_FP16_CAST(unsigned short, __half2ushort_rz(input), __ushort2half_rn KERNEL_FLOAT_FP16_CAST(unsigned int, __half2uint_rz(input), __uint2half_rn(input)); KERNEL_FLOAT_FP16_CAST(unsigned long, __ull2half_rn(input), (unsigned long)(__half2ull_rz(input))); KERNEL_FLOAT_FP16_CAST(unsigned long long, __ull2half_rn(input), __half2ull_rz(input)); +#endif using half = __half; KERNEL_FLOAT_VECTOR_ALIAS(half, __half) @@ -4056,7 +4106,11 @@ KERNEL_FLOAT_VECTOR_ALIAS(half, __half) #if KERNEL_FLOAT_BF16_AVAILABLE +#if KERNEL_FLOAT_IS_CUDA #include +#elif KERNEL_FLOAT_IS_HIP +#include +#endif @@ -4064,58 +4118,66 @@ KERNEL_FLOAT_VECTOR_ALIAS(half, __half) namespace kernel_float { +#if KERNEL_FLOAT_IS_CUDA +using __bfloat16 = __nv_bfloat16; +using __bfloat162 = __nv_bfloat162; +#elif KERNEL_FLOAT_IS_HIP +using __bfloat16 = __hip_bfloat16; +using __bfloat162 = __hip_bfloat162; +#endif + +#if KERNEL_FLOAT_IS_CUDA && __CUDA_ARCH__ >= 800 +#define KERNEL_FLOAT_BF16_OPS_SUPPORTED 1 +#endif + template<> -struct preferred_vector_size<__nv_bfloat16> { +struct preferred_vector_size<__bfloat16> { static constexpr size_t value = 2; }; -KERNEL_FLOAT_DEFINE_PROMOTED_FLOAT(__nv_bfloat16) -KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(float, __nv_bfloat16) -KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(double, __nv_bfloat16) +KERNEL_FLOAT_DEFINE_PROMOTED_FLOAT(__bfloat16) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(float, __bfloat16) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(double, __bfloat16) template<> -struct into_vector_impl<__nv_bfloat162> { - using value_type = __nv_bfloat16; +struct into_vector_impl<__bfloat162> { + using value_type = __bfloat16; using extent_type = extent<2>; KERNEL_FLOAT_INLINE - static vector_storage<__nv_bfloat16, 2> call(__nv_bfloat162 input) { + static vector_storage<__bfloat16, 2> call(__bfloat162 input) { return {input.x, input.y}; } }; namespace detail { template<> -struct allow_float_fallback<__nv_bfloat16> { +struct allow_float_fallback<__bfloat16> { static constexpr bool value = true; }; }; // namespace detail -#if KERNEL_FLOAT_IS_DEVICE -#define KERNEL_FLOAT_BF16_UNARY_FUN(NAME, FUN1, FUN2) \ - namespace ops { \ - template<> \ - struct NAME<__nv_bfloat16> { \ - KERNEL_FLOAT_INLINE __nv_bfloat16 operator()(__nv_bfloat16 input) { \ - return FUN1(input); \ - } \ - }; \ - } \ - namespace detail { \ - template<> \ - struct apply_impl, 2, __nv_bfloat16, __nv_bfloat16> { \ - KERNEL_FLOAT_INLINE static void \ - call(ops::NAME<__nv_bfloat16>, __nv_bfloat16* result, const __nv_bfloat16* a) { \ - __nv_bfloat162 r = FUN2(__nv_bfloat162 {a[0], a[1]}); \ - result[0] = r.x, result[1] = r.y; \ - } \ - }; \ +#if KERNEL_FLOAT_BF16_OPS_SUPPORTED +#define KERNEL_FLOAT_BF16_UNARY_FUN(NAME, FUN1, FUN2) \ + namespace ops { \ + template<> \ + struct NAME<__bfloat16> { \ + KERNEL_FLOAT_INLINE __bfloat16 operator()(__bfloat16 input) { \ + return FUN1(input); \ + } \ + }; \ + } \ + namespace detail { \ + template<> \ + struct apply_impl, 2, __bfloat16, __bfloat16> { \ + KERNEL_FLOAT_INLINE static void \ + call(ops::NAME<__bfloat16>, __bfloat16* result, const __bfloat16* a) { \ + __bfloat162 r = FUN2(__bfloat162 {a[0], a[1]}); \ + result[0] = r.x, result[1] = r.y; \ + } \ + }; \ } -#else -#define KERNEL_FLOAT_BF16_UNARY_FUN(NAME, FUN1, FUN2) -#endif -#if KERNEL_FLOAT_CUDA_ARCH >= 800 KERNEL_FLOAT_BF16_UNARY_FUN(abs, ::__habs, ::__habs2) KERNEL_FLOAT_BF16_UNARY_FUN(negate, ::__hneg, ::__hneg2) KERNEL_FLOAT_BF16_UNARY_FUN(ceil, ::hceil, ::h2ceil) @@ -4133,33 +4195,29 @@ KERNEL_FLOAT_BF16_UNARY_FUN(trunc, ::htrunc, ::h2trunc) KERNEL_FLOAT_BF16_UNARY_FUN(rcp, ::hrcp, ::h2rcp) #endif -#if KERNEL_FLOAT_CUDA_ARCH >= 800 -#define KERNEL_FLOAT_BF16_BINARY_FUN(NAME, FUN1, FUN2) \ - namespace ops { \ - template<> \ - struct NAME<__nv_bfloat16> { \ - KERNEL_FLOAT_INLINE __nv_bfloat16 \ - operator()(__nv_bfloat16 left, __nv_bfloat16 right) const { \ - return FUN1(left, right); \ - } \ - }; \ - } \ - namespace detail { \ - template<> \ - struct apply_impl, 2, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16> { \ - KERNEL_FLOAT_INLINE static void call( \ - ops::NAME<__nv_bfloat16>, \ - __nv_bfloat16* result, \ - const __nv_bfloat16* a, \ - const __nv_bfloat16* b) { \ - __nv_bfloat162 r = FUN2(__nv_bfloat162 {a[0], a[1]}, __nv_bfloat162 {b[0], b[1]}); \ - result[0] = r.x, result[1] = r.y; \ - } \ - }; \ +#if KERNEL_FLOAT_BF16_OPS_SUPPORTED +#define KERNEL_FLOAT_BF16_BINARY_FUN(NAME, FUN1, FUN2) \ + namespace ops { \ + template<> \ + struct NAME<__bfloat16> { \ + KERNEL_FLOAT_INLINE __bfloat16 operator()(__bfloat16 left, __bfloat16 right) const { \ + return FUN1(left, right); \ + } \ + }; \ + } \ + namespace detail { \ + template<> \ + struct apply_impl, 2, __bfloat16, __bfloat16, __bfloat16> { \ + KERNEL_FLOAT_INLINE static void call( \ + ops::NAME<__bfloat16>, \ + __bfloat16* result, \ + const __bfloat16* a, \ + const __bfloat16* b) { \ + __bfloat162 r = FUN2(__bfloat162 {a[0], a[1]}, __bfloat162 {b[0], b[1]}); \ + result[0] = r.x, result[1] = r.y; \ + } \ + }; \ } -#else -#define KERNEL_FLOAT_BF16_BINARY_FUN(NAME, FUN1, FUN2) -#endif KERNEL_FLOAT_BF16_BINARY_FUN(add, __hadd, __hadd2) KERNEL_FLOAT_BF16_BINARY_FUN(subtract, __hsub, __hsub2) @@ -4174,13 +4232,13 @@ KERNEL_FLOAT_BF16_BINARY_FUN(less, __hlt, __hlt2) KERNEL_FLOAT_BF16_BINARY_FUN(less_equal, __hle, __hle2) KERNEL_FLOAT_BF16_BINARY_FUN(greater, __hgt, __hgt2) KERNEL_FLOAT_BF16_BINARY_FUN(greater_equal, __hge, __hgt2) +#endif -#if KERNEL_FLOAT_CUDA_ARCH >= 800 +#if KERNEL_FLOAT_BF16_OPS_SUPPORTED namespace ops { template<> -struct fma<__nv_bfloat16> { - KERNEL_FLOAT_INLINE __nv_bfloat16 - operator()(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) const { +struct fma<__bfloat16> { + KERNEL_FLOAT_INLINE __bfloat16 operator()(__bfloat16 a, __bfloat16 b, __bfloat16 c) const { return __hfma(a, b, c); } }; @@ -4188,23 +4246,15 @@ struct fma<__nv_bfloat16> { namespace detail { template<> -struct apply_impl< - ops::fma<__nv_bfloat16>, - 2, - __nv_bfloat16, - __nv_bfloat16, - __nv_bfloat16, - __nv_bfloat16> { +struct apply_impl, 2, __bfloat16, __bfloat16, __bfloat16, __bfloat16> { KERNEL_FLOAT_INLINE static void call( - ops::fma<__nv_bfloat16>, - __nv_bfloat16* result, - const __nv_bfloat16* a, - const __nv_bfloat16* b, - const __nv_bfloat16* c) { - __nv_bfloat162 r = __hfma2( - __nv_bfloat162 {a[0], a[1]}, - __nv_bfloat162 {b[0], b[1]}, - __nv_bfloat162 {c[0], c[1]}); + ops::fma<__bfloat16>, + __bfloat16* result, + const __bfloat16* a, + const __bfloat16* b, + const __bfloat16* c) { + __bfloat162 r = + __hfma2(__bfloat162 {a[0], a[1]}, __bfloat162 {b[0], b[1]}, __bfloat162 {c[0], c[1]}); result[0] = r.x, result[1] = r.y; } }; @@ -4213,44 +4263,44 @@ struct apply_impl< namespace ops { template<> -struct cast { - KERNEL_FLOAT_INLINE __nv_bfloat16 operator()(double input) { +struct cast { + KERNEL_FLOAT_INLINE __bfloat16 operator()(double input) { return __double2bfloat16(input); }; }; template<> -struct cast { - KERNEL_FLOAT_INLINE __nv_bfloat16 operator()(float input) { +struct cast { + KERNEL_FLOAT_INLINE __bfloat16 operator()(float input) { return __float2bfloat16(input); }; }; template<> -struct cast<__nv_bfloat16, float> { - KERNEL_FLOAT_INLINE float operator()(__nv_bfloat16 input) { +struct cast<__bfloat16, float> { + KERNEL_FLOAT_INLINE float operator()(__bfloat16 input) { return __bfloat162float(input); }; }; } // namespace ops -#define KERNEL_FLOAT_BF16_CAST(T, TO_HALF, FROM_HALF) \ - namespace ops { \ - template<> \ - struct cast { \ - KERNEL_FLOAT_INLINE __nv_bfloat16 operator()(T input) { \ - return TO_HALF; \ - } \ - }; \ - template<> \ - struct cast<__nv_bfloat16, T> { \ - KERNEL_FLOAT_INLINE T operator()(__nv_bfloat16 input) { \ - return FROM_HALF; \ - } \ - }; \ - } - -#if KERNEL_FLOAT_CUDA_ARCH >= 800 +#define KERNEL_FLOAT_BF16_CAST(T, TO_HALF, FROM_HALF) \ + namespace ops { \ + template<> \ + struct cast { \ + KERNEL_FLOAT_INLINE __bfloat16 operator()(T input) { \ + return TO_HALF; \ + } \ + }; \ + template<> \ + struct cast<__bfloat16, T> { \ + KERNEL_FLOAT_INLINE T operator()(__bfloat16 input) { \ + return FROM_HALF; \ + } \ + }; \ + } + +#if KERNEL_FLOAT_BF16_OPS_SUPPORTED // clang-format off // there are no official char casts. Instead, cast to int and then to char KERNEL_FLOAT_BF16_CAST(char, __int2bfloat16_rn(input), (char)__bfloat162int_rz(input)); @@ -4267,17 +4317,25 @@ KERNEL_FLOAT_BF16_CAST(unsigned int, __bfloat162uint_rz(input), __uint2bfloat16_ KERNEL_FLOAT_BF16_CAST(unsigned long, __ull2bfloat16_rn(input), (unsigned long)(__bfloat162ull_rz(input))); KERNEL_FLOAT_BF16_CAST(unsigned long long, __ull2bfloat16_rn(input), __bfloat162ull_rz(input)); // clang-format on -#else +#endif + +#if KERNEL_FLOAT_IS_CUDA KERNEL_FLOAT_BF16_CAST( bool, __nv_bfloat16_raw {input ? (unsigned short)0 : (unsigned short)0x3C00}, (__nv_bfloat16_raw(input).x & 0x7FFF) != 0); + +#elif KERNEL_FLOAT_IS_HIP +KERNEL_FLOAT_BF16_CAST( + bool, + __hip_bfloat16 {input ? (unsigned short)0 : (unsigned short)0x3C00}, + (__hip_bfloat16(input).data & 0x7FFF) != 0); #endif -using bfloat16 = __nv_bfloat16; -KERNEL_FLOAT_VECTOR_ALIAS(bfloat16x, __nv_bfloat16) -//KERNEL_FLOAT_TYPE_ALIAS(float16x, __nv_bfloat16) -//KERNEL_FLOAT_TYPE_ALIAS(f16x, __nv_bfloat16) +using bfloat16 = __bfloat16; +KERNEL_FLOAT_VECTOR_ALIAS(bfloat16x, __bfloat16) +//KERNEL_FLOAT_TYPE_ALIAS(float16x, __bfloat16) +//KERNEL_FLOAT_TYPE_ALIAS(f16x, __bfloat16) } // namespace kernel_float @@ -4286,12 +4344,12 @@ KERNEL_FLOAT_VECTOR_ALIAS(bfloat16x, __nv_bfloat16) namespace kernel_float { template<> -struct promote_type<__nv_bfloat16, __half> { +struct promote_type<__bfloat16, __half> { using type = float; }; template<> -struct promote_type<__half, __nv_bfloat16> { +struct promote_type<__half, __bfloat16> { using type = float; }; @@ -4876,9 +4934,7 @@ struct tiling { using index_type = IndexType; using point_type = vector>; -#if KERNEL_FLOAT_IS_DEVICE - __forceinline__ __device__ tiling() : block_(threadIdx) {} -#endif + __forceinline__ __device__ tiling() : block_(dim3(threadIdx)) {} KERNEL_FLOAT_INLINE tiling(BlockDim block, vec offset = {}) : block_(block), offset_(offset) {} diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index bd8edb6..523fe50 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -1,18 +1,24 @@ cmake_minimum_required(VERSION 3.16) -project(tests) +project(tests LANGUAGES CXX) set(CMAKE_CXX_STANDARD 17) file(GLOB FILES *.cu) add_executable(kernel_float_tests ${FILES}) target_link_libraries(kernel_float_tests PRIVATE kernel_float) -target_compile_options(kernel_float_tests PRIVATE "--extended-lambda") -set_target_properties(kernel_float_tests PROPERTIES CUDA_ARCHITECTURES "70;80") - -target_compile_options(kernel_float_tests PRIVATE "-ftime-report -ftime-report-details") add_subdirectory(Catch2) target_link_libraries(kernel_float_tests PRIVATE Catch2::Catch2WithMain) -find_package(CUDA REQUIRED) -target_include_directories(kernel_float_tests PRIVATE ${CUDA_TOOLKIT_INCLUDE}) +if(${KERNEL_FLOAT_LANGUAGE_CUDA}) + find_package(CUDA REQUIRED) + target_include_directories(kernel_float_tests PRIVATE ${CUDA_TOOLKIT_INCLUDE}) + + target_compile_options(kernel_float_tests PRIVATE "-ftime-report -ftime-report-details") + target_compile_options(kernel_float_tests PRIVATE "--extended-lambda") + set_target_properties(kernel_float_tests PROPERTIES CUDA_ARCHITECTURES "70;80") +endif() + +if(${KERNEL_FLOAT_LANGUAGE_HIP}) + set_source_files_properties(${FILES} PROPERTIES LANGUAGE HIP) +endif() \ No newline at end of file diff --git a/tests/common.h b/tests/common.h index ff8072d..bd257ca 100644 --- a/tests/common.h +++ b/tests/common.h @@ -1,8 +1,5 @@ #pragma once -#include -#include - #include #include "catch2/catch_all.hpp" @@ -10,19 +7,37 @@ namespace kf = kernel_float; +#if KERNEL_FLOAT_IS_HIP +#define cudaError_t hipError_t +#define cudaSuccess hipSuccess +#define cudaGetErrorString hipGetErrorString +#define cudaGetLastError hipGetLastError +#define cudaSetDevice hipSetDevice +#define cudaDeviceSynchronize hipDeviceSynchronize + +using __nv_bfloat16 = __hip_bfloat16; +#endif + namespace detail { -__attribute__((noinline)) static __host__ __device__ void +__attribute__((noinline)) static __host__ void __assertion_failed(const char* expr, const char* file, int line) { -#ifndef __CUDA_ARCH__ std::string msg = "assertion failed: " + std::string(expr) + " (" + file + ":" + std::to_string(line) + ")"; throw std::runtime_error(msg); -#else +} + +__attribute__((noinline)) static __device__ void +__assertion_failed(const char* expr, const char* file, int line) { printf("assertion failed: %s (%s:%d)\n", expr, file, line); + +#if KERNEL_FLOAT_IS_CUDA asm("trap;"); +#elif KERNEL_FLOAT_IS_HIP + __builtin_trap(); +#endif + while (1) ; -#endif } } // namespace detail @@ -52,14 +67,14 @@ struct equals_helper { template<> struct equals_helper { static __host__ __device__ bool call(const double& left, const double& right) { - return (isnan(left) && isnan(right)) || (left == right); + return (std::isnan(left) && std::isnan(right)) || (left == right); } }; template<> struct equals_helper { static __host__ __device__ bool call(const float& left, const float& right) { - return (isnan(left) && isnan(right)) || (left == right); + return (std::isnan(left) && std::isnan(right)) || (left == right); } }; @@ -73,7 +88,7 @@ struct equals_helper<__half> { template<> struct equals_helper<__nv_bfloat16> { static __host__ __device__ bool call(const __nv_bfloat16& left, const __nv_bfloat16& right) { - return equals_helper::call(float(left), float(right)); + return equals_helper::call(__bfloat162float(left), __bfloat162float(right)); } }; @@ -123,14 +138,14 @@ struct approx_helper { template<> struct approx_helper<__half> { static __host__ __device__ bool call(__half left, __half right) { - return approx_helper::call(double(left), double(right), 0.01); + return approx_helper::call(float(left), float(right), 0.01); } }; template<> struct approx_helper<__nv_bfloat16> { static __host__ __device__ bool call(__nv_bfloat16 left, __nv_bfloat16 right) { - return approx_helper::call(double(left), double(right), 0.05); + return approx_helper::call(__bfloat162float(left), __bfloat162float(right), 0.05); } }; } // namespace detail @@ -218,7 +233,7 @@ struct generator_value<__half> { template<> struct generator_value<__nv_bfloat16> { __host__ __device__ static __nv_bfloat16 call(uint64_t seed) { - return __nv_bfloat16(generator_value::call(seed)); + return __float2bfloat16(generator_value::call(seed)); } }; } // namespace detail From 938565527f9129209410a678f3906af0c1c962ce Mon Sep 17 00:00:00 2001 From: stijn Date: Mon, 18 Nov 2024 11:11:57 +0100 Subject: [PATCH 02/25] Changes to make code compile under HIPRTC --- include/kernel_float/base.h | 2 +- include/kernel_float/bf16.h | 69 ++++----- include/kernel_float/binops.h | 8 +- include/kernel_float/fp16.h | 68 +++++---- include/kernel_float/meta.h | 3 - include/kernel_float/prelude.h | 8 +- include/kernel_float/unops.h | 66 ++++----- include/kernel_float/vector.h | 19 ++- single_include/kernel_float.h | 247 +++++++++++++++++---------------- tests/common.h | 29 ++-- 10 files changed, 270 insertions(+), 249 deletions(-) diff --git a/include/kernel_float/base.h b/include/kernel_float/base.h index 8c7c400..2f658b5 100644 --- a/include/kernel_float/base.h +++ b/include/kernel_float/base.h @@ -270,7 +270,7 @@ using promoted_vector_value_type = promote_t...>; template KERNEL_FLOAT_INLINE vector_storage_type into_vector_storage(V&& input) { - return into_vector_impl::call(std::forward(input)); + return into_vector_impl::call(static_cast(input)); } } // namespace kernel_float diff --git a/include/kernel_float/bf16.h b/include/kernel_float/bf16.h index 7aa99f1..bad5f75 100644 --- a/include/kernel_float/bf16.h +++ b/include/kernel_float/bf16.h @@ -4,6 +4,11 @@ #include "macros.h" #if KERNEL_FLOAT_BF16_AVAILABLE +//#define CUDA_NO_BFLOAT16 (1) +//#define __CUDA_NO_BFLOAT16_OPERATORS__ (1) +//#define __CUDA_NO_BFLOAT162_OPERATORS__ (1) +//#define __CUDA_NO_BFLOAT16_CONVERSIONS__ (1) + #if KERNEL_FLOAT_IS_CUDA #include #elif KERNEL_FLOAT_IS_HIP @@ -76,21 +81,24 @@ struct allow_float_fallback<__bfloat16> { }; \ } -KERNEL_FLOAT_BF16_UNARY_FUN(abs, ::__habs, ::__habs2) -KERNEL_FLOAT_BF16_UNARY_FUN(negate, ::__hneg, ::__hneg2) -KERNEL_FLOAT_BF16_UNARY_FUN(ceil, ::hceil, ::h2ceil) +KERNEL_FLOAT_BF16_UNARY_FUN(sin, ::hsin, ::h2sin) KERNEL_FLOAT_BF16_UNARY_FUN(cos, ::hcos, ::h2cos) + KERNEL_FLOAT_BF16_UNARY_FUN(exp, ::hexp, ::h2exp) KERNEL_FLOAT_BF16_UNARY_FUN(exp10, ::hexp10, ::h2exp10) -KERNEL_FLOAT_BF16_UNARY_FUN(floor, ::hfloor, ::h2floor) KERNEL_FLOAT_BF16_UNARY_FUN(log, ::hlog, ::h2log) KERNEL_FLOAT_BF16_UNARY_FUN(log10, ::hlog10, ::h2log2) -KERNEL_FLOAT_BF16_UNARY_FUN(rint, ::hrint, ::h2rint) -KERNEL_FLOAT_BF16_UNARY_FUN(rsqrt, ::hrsqrt, ::h2rsqrt) -KERNEL_FLOAT_BF16_UNARY_FUN(sin, ::hsin, ::h2sin) + KERNEL_FLOAT_BF16_UNARY_FUN(sqrt, ::hsqrt, ::h2sqrt) -KERNEL_FLOAT_BF16_UNARY_FUN(trunc, ::htrunc, ::h2trunc) +KERNEL_FLOAT_BF16_UNARY_FUN(rsqrt, ::hrsqrt, ::h2rsqrt) KERNEL_FLOAT_BF16_UNARY_FUN(rcp, ::hrcp, ::h2rcp) + +KERNEL_FLOAT_BF16_UNARY_FUN(abs, ::__habs, ::__habs2) +KERNEL_FLOAT_BF16_UNARY_FUN(floor, ::hfloor, ::h2floor) +KERNEL_FLOAT_BF16_UNARY_FUN(ceil, ::hceil, ::h2ceil) +KERNEL_FLOAT_BF16_UNARY_FUN(rint, ::hrint, ::h2rint) +KERNEL_FLOAT_BF16_UNARY_FUN(trunc, ::htrunc, ::h2trunc) +KERNEL_FLOAT_BF16_UNARY_FUN(negate, ::__hneg, ::__hneg2) #endif #if KERNEL_FLOAT_BF16_OPS_SUPPORTED @@ -99,7 +107,7 @@ KERNEL_FLOAT_BF16_UNARY_FUN(rcp, ::hrcp, ::h2rcp) template<> \ struct NAME<__bfloat16> { \ KERNEL_FLOAT_INLINE __bfloat16 operator()(__bfloat16 left, __bfloat16 right) const { \ - return FUN1(left, right); \ + return ops::cast {}(FUN1(left, right)); \ } \ }; \ } \ @@ -159,29 +167,6 @@ struct apply_impl, 2, __bfloat16, __bfloat16, __bfloat16, _ } // namespace detail #endif -namespace ops { -template<> -struct cast { - KERNEL_FLOAT_INLINE __bfloat16 operator()(double input) { - return __double2bfloat16(input); - }; -}; - -template<> -struct cast { - KERNEL_FLOAT_INLINE __bfloat16 operator()(float input) { - return __float2bfloat16(input); - }; -}; - -template<> -struct cast<__bfloat16, float> { - KERNEL_FLOAT_INLINE float operator()(__bfloat16 input) { - return __bfloat162float(input); - }; -}; -} // namespace ops - #define KERNEL_FLOAT_BF16_CAST(T, TO_HALF, FROM_HALF) \ namespace ops { \ template<> \ @@ -198,6 +183,9 @@ struct cast<__bfloat16, float> { }; \ } +KERNEL_FLOAT_BF16_CAST(float, __float2bfloat16(input), __bfloat162float(input)) +KERNEL_FLOAT_BF16_CAST(double, __double2bfloat16(input), __bfloat162float(input)) + #if KERNEL_FLOAT_BF16_OPS_SUPPORTED // clang-format off // there are no official char casts. Instead, cast to int and then to char @@ -205,24 +193,23 @@ KERNEL_FLOAT_BF16_CAST(char, __int2bfloat16_rn(input), (char)__bfloat162int_rz(i KERNEL_FLOAT_BF16_CAST(signed char, __int2bfloat16_rn(input), (signed char)__bfloat162int_rz(input)); KERNEL_FLOAT_BF16_CAST(unsigned char, __int2bfloat16_rn(input), (unsigned char)__bfloat162int_rz(input)); -KERNEL_FLOAT_BF16_CAST(signed short, __bfloat162short_rz(input), __short2bfloat16_rn(input)); -KERNEL_FLOAT_BF16_CAST(signed int, __bfloat162int_rz(input), __int2bfloat16_rn(input)); +KERNEL_FLOAT_BF16_CAST(signed short, __short2bfloat16_rn(input), __bfloat162short_rz(input)); +KERNEL_FLOAT_BF16_CAST(signed int, __int2bfloat16_rn(input), __bfloat162int_rz(input)); KERNEL_FLOAT_BF16_CAST(signed long, __ll2bfloat16_rn(input), (signed long)(__bfloat162ll_rz(input))); KERNEL_FLOAT_BF16_CAST(signed long long, __ll2bfloat16_rn(input), __bfloat162ll_rz(input)); -KERNEL_FLOAT_BF16_CAST(unsigned short, __bfloat162ushort_rz(input), __ushort2bfloat16_rn(input)); -KERNEL_FLOAT_BF16_CAST(unsigned int, __bfloat162uint_rz(input), __uint2bfloat16_rn(input)); +KERNEL_FLOAT_BF16_CAST(unsigned short, __ushort2bfloat16_rn(input), __bfloat162ushort_rz(input)); +KERNEL_FLOAT_BF16_CAST(unsigned int, __uint2bfloat16_rn(input), __bfloat162uint_rz(input)); KERNEL_FLOAT_BF16_CAST(unsigned long, __ull2bfloat16_rn(input), (unsigned long)(__bfloat162ull_rz(input))); KERNEL_FLOAT_BF16_CAST(unsigned long long, __ull2bfloat16_rn(input), __bfloat162ull_rz(input)); // clang-format on #endif #if KERNEL_FLOAT_IS_CUDA -KERNEL_FLOAT_BF16_CAST( - bool, - __nv_bfloat16_raw {input ? (unsigned short)0 : (unsigned short)0x3C00}, - (__nv_bfloat16_raw(input).x & 0x7FFF) != 0); - +//KERNEL_FLOAT_BF16_CAST( +// bool, +// __nv_bfloat16_raw {input ? (unsigned short)0 : (unsigned short)0x3C00}, +// (__nv_bfloat16_raw(input).x & 0x7FFF) != 0); #elif KERNEL_FLOAT_IS_HIP KERNEL_FLOAT_BF16_CAST( bool, diff --git a/include/kernel_float/binops.h b/include/kernel_float/binops.h index 77bc6ae..5ac6ed9 100644 --- a/include/kernel_float/binops.h +++ b/include/kernel_float/binops.h @@ -65,10 +65,10 @@ KERNEL_FLOAT_INLINE zip_common_type zip_common(F fun, const L& left, co return result; } -#define KERNEL_FLOAT_DEFINE_BINARY_FUN(NAME) \ - template> \ - KERNEL_FLOAT_INLINE zip_common_type, L, R> NAME(L&& left, R&& right) { \ - return zip_common(ops::NAME {}, std::forward(left), std::forward(right)); \ +#define KERNEL_FLOAT_DEFINE_BINARY_FUN(NAME) \ + template> \ + KERNEL_FLOAT_INLINE zip_common_type, L, R> NAME(L&& left, R&& right) { \ + return zip_common(ops::NAME {}, static_cast(left), static_cast(right)); \ } #define KERNEL_FLOAT_DEFINE_BINARY(NAME, EXPR, EXPR_F64, EXPR_F32) \ diff --git a/include/kernel_float/fp16.h b/include/kernel_float/fp16.h index 4228514..0e90d8d 100644 --- a/include/kernel_float/fp16.h +++ b/include/kernel_float/fp16.h @@ -4,6 +4,11 @@ #include "macros.h" #if KERNEL_FLOAT_FP16_AVAILABLE +//#define CUDA_NO_HALF (1) +//#define __CUDA_NO_HALF_OPERATORS__ (1) +//#define __CUDA_NO_HALF2_OPERATORS__ (1) +//#define __CUDA_NO_HALF_CONVERSIONS__ (1) + #if KERNEL_FLOAT_IS_CUDA #include #elif KERNEL_FLOAT_IS_HIP @@ -64,41 +69,44 @@ struct allow_float_fallback<__half> { #define KERNEL_FLOAT_FP16_UNARY_FUN(NAME, FUN1, FUN2) #endif -KERNEL_FLOAT_FP16_UNARY_FUN(abs, __habs, __habs2) -KERNEL_FLOAT_FP16_UNARY_FUN(negate, __hneg, __hneg2) -KERNEL_FLOAT_FP16_UNARY_FUN(ceil, hceil, h2ceil) +KERNEL_FLOAT_FP16_UNARY_FUN(sin, hsin, h2sin) KERNEL_FLOAT_FP16_UNARY_FUN(cos, hcos, h2cos) + KERNEL_FLOAT_FP16_UNARY_FUN(exp, hexp, h2exp) KERNEL_FLOAT_FP16_UNARY_FUN(exp10, hexp10, h2exp10) -KERNEL_FLOAT_FP16_UNARY_FUN(floor, hfloor, h2floor) KERNEL_FLOAT_FP16_UNARY_FUN(log, hlog, h2log) KERNEL_FLOAT_FP16_UNARY_FUN(log10, hlog10, h2log2) -KERNEL_FLOAT_FP16_UNARY_FUN(rint, hrint, h2rint) -KERNEL_FLOAT_FP16_UNARY_FUN(rsqrt, hrsqrt, h2rsqrt) -KERNEL_FLOAT_FP16_UNARY_FUN(sin, hsin, h2sin) + KERNEL_FLOAT_FP16_UNARY_FUN(sqrt, hsqrt, h2sqrt) -KERNEL_FLOAT_FP16_UNARY_FUN(trunc, htrunc, h2trunc) +KERNEL_FLOAT_FP16_UNARY_FUN(rsqrt, hrsqrt, h2rsqrt) KERNEL_FLOAT_FP16_UNARY_FUN(rcp, hrcp, h2rcp) +KERNEL_FLOAT_FP16_UNARY_FUN(abs, __habs, __habs2) +KERNEL_FLOAT_FP16_UNARY_FUN(floor, hfloor, h2floor) +KERNEL_FLOAT_FP16_UNARY_FUN(ceil, hceil, h2ceil) +KERNEL_FLOAT_FP16_UNARY_FUN(rint, hrint, h2rint) +KERNEL_FLOAT_FP16_UNARY_FUN(trunc, htrunc, h2trunc) +KERNEL_FLOAT_FP16_UNARY_FUN(negate, __hneg, __hneg2) + #if KERNEL_FLOAT_IS_DEVICE -#define KERNEL_FLOAT_FP16_BINARY_FUN(NAME, FUN1, FUN2) \ - namespace ops { \ - template<> \ - struct NAME<__half> { \ - KERNEL_FLOAT_INLINE __half operator()(__half left, __half right) const { \ - return FUN1(left, right); \ - } \ - }; \ - } \ - namespace detail { \ - template<> \ - struct apply_impl, 2, __half, __half, __half> { \ - KERNEL_FLOAT_INLINE static void \ - call(ops::NAME<__half>, __half* result, const __half* a, const __half* b) { \ - __half2 r = FUN2(__half2 {a[0], a[1]}, __half2 {b[0], b[1]}); \ - result[0] = r.x, result[1] = r.y; \ - } \ - }; \ +#define KERNEL_FLOAT_FP16_BINARY_FUN(NAME, FUN1, FUN2) \ + namespace ops { \ + template<> \ + struct NAME<__half> { \ + KERNEL_FLOAT_INLINE __half operator()(__half left, __half right) const { \ + return ops::cast {}(FUN1(left, right)); \ + } \ + }; \ + } \ + namespace detail { \ + template<> \ + struct apply_impl, 2, __half, __half, __half> { \ + KERNEL_FLOAT_INLINE static void \ + call(ops::NAME<__half>, __half* result, const __half* a, const __half* b) { \ + __half2 r = FUN2(__half2 {a[0], a[1]}, __half2 {b[0], b[1]}); \ + result[0] = r.x, result[1] = r.y; \ + } \ + }; \ } #else #define KERNEL_FLOAT_FP16_BINARY_FUN(NAME, FUN1, FUN2) @@ -190,13 +198,13 @@ KERNEL_FLOAT_FP16_CAST(char, __int2half_rn(input), (char)__half2int_rz(input)); KERNEL_FLOAT_FP16_CAST(signed char, __int2half_rn(input), (signed char)__half2int_rz(input)); KERNEL_FLOAT_FP16_CAST(unsigned char, __int2half_rn(input), (unsigned char)__half2int_rz(input)); -KERNEL_FLOAT_FP16_CAST(signed short, __half2short_rz(input), __short2half_rn(input)); -KERNEL_FLOAT_FP16_CAST(signed int, __half2int_rz(input), __int2half_rn(input)); +KERNEL_FLOAT_FP16_CAST(signed short, __short2half_rn(input), __half2short_rz(input)); +KERNEL_FLOAT_FP16_CAST(signed int, __int2half_rn(input), __half2int_rz(input)); KERNEL_FLOAT_FP16_CAST(signed long, __ll2half_rn(input), (signed long)(__half2ll_rz(input))); KERNEL_FLOAT_FP16_CAST(signed long long, __ll2half_rn(input), __half2ll_rz(input)); -KERNEL_FLOAT_FP16_CAST(unsigned short, __half2ushort_rz(input), __ushort2half_rn(input)); -KERNEL_FLOAT_FP16_CAST(unsigned int, __half2uint_rz(input), __uint2half_rn(input)); +KERNEL_FLOAT_FP16_CAST(unsigned short, __ushort2half_rn(input), __half2ushort_rz(input)); +KERNEL_FLOAT_FP16_CAST(unsigned int, __uint2half_rn(input), __half2uint_rz(input)); KERNEL_FLOAT_FP16_CAST(unsigned long, __ull2half_rn(input), (unsigned long)(__half2ull_rz(input))); KERNEL_FLOAT_FP16_CAST(unsigned long long, __ull2half_rn(input), __half2ull_rz(input)); #endif diff --git a/include/kernel_float/meta.h b/include/kernel_float/meta.h index 005be9b..5256129 100644 --- a/include/kernel_float/meta.h +++ b/include/kernel_float/meta.h @@ -270,9 +270,6 @@ struct enable_if_impl { template using enable_if_t = typename detail::enable_if_impl::type; -template -using identity_t = T; - KERNEL_FLOAT_INLINE constexpr size_t round_up_to_power_of_two(size_t n) { size_t result = 1; diff --git a/include/kernel_float/prelude.h b/include/kernel_float/prelude.h index 59872a0..b16f054 100644 --- a/include/kernel_float/prelude.h +++ b/include/kernel_float/prelude.h @@ -67,8 +67,8 @@ KERNEL_FLOAT_TYPE_ALIAS(float16x, __half) #endif #if KERNEL_FLOAT_BF16_AVAILABLE -KERNEL_FLOAT_TYPE_ALIAS(bfloat16x, __nv_bfloat16) -KERNEL_FLOAT_TYPE_ALIAS(bf16x, __nv_bfloat16) +KERNEL_FLOAT_TYPE_ALIAS(bfloat16x, __bfloat16) +KERNEL_FLOAT_TYPE_ALIAS(bf16x, __bfloat16) #endif #if KERNEL_FLOAT_BF8_AVAILABLE @@ -82,12 +82,12 @@ static constexpr extent kextent = {}; template KERNEL_FLOAT_INLINE kvec, sizeof...(Args)> make_kvec(Args&&... args) { - return ::kernel_float::make_vec(std::forward(args)...); + return ::kernel_float::make_vec(static_cast(args)...); }; template KERNEL_FLOAT_INLINE into_vector_type into_kvec(V&& input) { - return ::kernel_float::into_vec(std::forward(input)); + return ::kernel_float::into_vec(static_cast(input)); } template diff --git a/include/kernel_float/unops.h b/include/kernel_float/unops.h index 8381d60..9e5fe42 100644 --- a/include/kernel_float/unops.h +++ b/include/kernel_float/unops.h @@ -123,8 +123,8 @@ KERNEL_FLOAT_DEFINE_UNARY_OP(logical_not, !, (ops::cast {}(!ops::cast \ struct NAME::value>> { \ KERNEL_FLOAT_INLINE T operator()(T input_arg) { \ - float input = float(input_arg); \ - return T(EXPR_F32); \ + float input = ops::cast {}(input_arg); \ + return ops::cast {}(EXPR_F32); \ } \ }; \ \ @@ -140,52 +140,56 @@ KERNEL_FLOAT_DEFINE_UNARY_OP(logical_not, !, (ops::cast {}(!ops::cast, E>::call(input)) {} + storage_type(detail::broadcast_impl, E>::call(input)) {} // For all other arguments, we convert it using `convert_storage` according to broadcast rules template, T>, int> = 0> @@ -82,6 +82,14 @@ struct vector: public S { return *this; } + /** + * Returns an instance of the `extent_type` for this vector. + */ + KERNEL_FLOAT_INLINE + extent_type extent() const { + return {}; + } + /** * Returns a pointer to the underlying storage data. */ @@ -206,7 +214,7 @@ struct vector: public S { */ KERNEL_FLOAT_INLINE void set(size_t x, T value) { - at(x) = std::move(value); + at(x) = static_cast(value); } /** @@ -239,7 +247,8 @@ struct vector: public S { * Broadcast this vector into a new size `(Ns...)`. */ template - KERNEL_FLOAT_INLINE vector> broadcast(extent new_size = {}) const { + KERNEL_FLOAT_INLINE vector> + broadcast(kernel_float::extent new_size = {}) const { return kernel_float::broadcast(*this, new_size); } @@ -274,7 +283,7 @@ struct vector: public S { */ template KERNEL_FLOAT_INLINE void for_each(F fun) const { - return kernel_float::for_each(*this, std::move(fun)); + return kernel_float::for_each(*this, fun); } /** @@ -303,7 +312,7 @@ struct vector: public S { */ template KERNEL_FLOAT_INLINE into_vector_type into_vec(V&& input) { - return into_vector_impl::call(std::forward(input)); + return into_vector_impl::call(static_cast(input)); } template diff --git a/single_include/kernel_float.h b/single_include/kernel_float.h index 7e62861..ea9787f 100644 --- a/single_include/kernel_float.h +++ b/single_include/kernel_float.h @@ -16,8 +16,8 @@ //================================================================================ // this file has been auto-generated, do not modify its contents! -// date: 2024-11-01 17:19:21.255671 -// git hash: 8333b040cbdbb1ec66e1ab4e459f597d889fcd7e +// date: 2024-11-18 11:10:30.225884 +// git hash: 5490ea756b41c688b66dd69e776a58c9ce5b1ef2 //================================================================================ #ifndef KERNEL_FLOAT_MACROS_H @@ -373,9 +373,6 @@ struct enable_if_impl { template using enable_if_t = typename detail::enable_if_impl::type; -template -using identity_t = T; - KERNEL_FLOAT_INLINE constexpr size_t round_up_to_power_of_two(size_t n) { size_t result = 1; @@ -660,7 +657,7 @@ using promoted_vector_value_type = promote_t...>; template KERNEL_FLOAT_INLINE vector_storage_type into_vector_storage(V&& input) { - return into_vector_impl::call(std::forward(input)); + return into_vector_impl::call(static_cast(input)); } } // namespace kernel_float @@ -1255,8 +1252,8 @@ KERNEL_FLOAT_DEFINE_UNARY_OP(logical_not, !, (ops::cast {}(!ops::cast \ struct NAME::value>> { \ KERNEL_FLOAT_INLINE T operator()(T input_arg) { \ - float input = float(input_arg); \ - return T(EXPR_F32); \ + float input = ops::cast {}(input_arg); \ + return ops::cast {}(EXPR_F32); \ } \ }; \ \ @@ -1272,52 +1269,56 @@ KERNEL_FLOAT_DEFINE_UNARY_OP(logical_not, !, (ops::cast {}(!ops::cast zip_common(F fun, const L& left, co return result; } -#define KERNEL_FLOAT_DEFINE_BINARY_FUN(NAME) \ - template> \ - KERNEL_FLOAT_INLINE zip_common_type, L, R> NAME(L&& left, R&& right) { \ - return zip_common(ops::NAME {}, std::forward(left), std::forward(right)); \ +#define KERNEL_FLOAT_DEFINE_BINARY_FUN(NAME) \ + template> \ + KERNEL_FLOAT_INLINE zip_common_type, L, R> NAME(L&& left, R&& right) { \ + return zip_common(ops::NAME {}, static_cast(left), static_cast(right)); \ } #define KERNEL_FLOAT_DEFINE_BINARY(NAME, EXPR, EXPR_F64, EXPR_F32) \ @@ -3552,7 +3553,7 @@ struct vector: public S { // Copy anything of type `storage_type` KERNEL_FLOAT_INLINE vector(const value_type& input = {}) : - storage_type(detail::broadcast_impl, E>::call(input)) {} + storage_type(detail::broadcast_impl, E>::call(input)) {} // For all other arguments, we convert it using `convert_storage` according to broadcast rules template, T>, int> = 0> @@ -3596,6 +3597,14 @@ struct vector: public S { return *this; } + /** + * Returns an instance of the `extent_type` for this vector. + */ + KERNEL_FLOAT_INLINE + extent_type extent() const { + return {}; + } + /** * Returns a pointer to the underlying storage data. */ @@ -3720,7 +3729,7 @@ struct vector: public S { */ KERNEL_FLOAT_INLINE void set(size_t x, T value) { - at(x) = std::move(value); + at(x) = static_cast(value); } /** @@ -3753,7 +3762,8 @@ struct vector: public S { * Broadcast this vector into a new size `(Ns...)`. */ template - KERNEL_FLOAT_INLINE vector> broadcast(extent new_size = {}) const { + KERNEL_FLOAT_INLINE vector> + broadcast(kernel_float::extent new_size = {}) const { return kernel_float::broadcast(*this, new_size); } @@ -3788,7 +3798,7 @@ struct vector: public S { */ template KERNEL_FLOAT_INLINE void for_each(F fun) const { - return kernel_float::for_each(*this, std::move(fun)); + return kernel_float::for_each(*this, fun); } /** @@ -3817,7 +3827,7 @@ struct vector: public S { */ template KERNEL_FLOAT_INLINE into_vector_type into_vec(V&& input) { - return into_vector_impl::call(std::forward(input)); + return into_vector_impl::call(static_cast(input)); } template @@ -3893,6 +3903,11 @@ vec(Args&&... args) -> vec, sizeof...(Args)>; #if KERNEL_FLOAT_FP16_AVAILABLE +//#define CUDA_NO_HALF (1) +//#define __CUDA_NO_HALF_OPERATORS__ (1) +//#define __CUDA_NO_HALF2_OPERATORS__ (1) +//#define __CUDA_NO_HALF_CONVERSIONS__ (1) + #if KERNEL_FLOAT_IS_CUDA #include #elif KERNEL_FLOAT_IS_HIP @@ -3953,41 +3968,44 @@ struct allow_float_fallback<__half> { #define KERNEL_FLOAT_FP16_UNARY_FUN(NAME, FUN1, FUN2) #endif -KERNEL_FLOAT_FP16_UNARY_FUN(abs, __habs, __habs2) -KERNEL_FLOAT_FP16_UNARY_FUN(negate, __hneg, __hneg2) -KERNEL_FLOAT_FP16_UNARY_FUN(ceil, hceil, h2ceil) +KERNEL_FLOAT_FP16_UNARY_FUN(sin, hsin, h2sin) KERNEL_FLOAT_FP16_UNARY_FUN(cos, hcos, h2cos) + KERNEL_FLOAT_FP16_UNARY_FUN(exp, hexp, h2exp) KERNEL_FLOAT_FP16_UNARY_FUN(exp10, hexp10, h2exp10) -KERNEL_FLOAT_FP16_UNARY_FUN(floor, hfloor, h2floor) KERNEL_FLOAT_FP16_UNARY_FUN(log, hlog, h2log) KERNEL_FLOAT_FP16_UNARY_FUN(log10, hlog10, h2log2) -KERNEL_FLOAT_FP16_UNARY_FUN(rint, hrint, h2rint) -KERNEL_FLOAT_FP16_UNARY_FUN(rsqrt, hrsqrt, h2rsqrt) -KERNEL_FLOAT_FP16_UNARY_FUN(sin, hsin, h2sin) + KERNEL_FLOAT_FP16_UNARY_FUN(sqrt, hsqrt, h2sqrt) -KERNEL_FLOAT_FP16_UNARY_FUN(trunc, htrunc, h2trunc) +KERNEL_FLOAT_FP16_UNARY_FUN(rsqrt, hrsqrt, h2rsqrt) KERNEL_FLOAT_FP16_UNARY_FUN(rcp, hrcp, h2rcp) +KERNEL_FLOAT_FP16_UNARY_FUN(abs, __habs, __habs2) +KERNEL_FLOAT_FP16_UNARY_FUN(floor, hfloor, h2floor) +KERNEL_FLOAT_FP16_UNARY_FUN(ceil, hceil, h2ceil) +KERNEL_FLOAT_FP16_UNARY_FUN(rint, hrint, h2rint) +KERNEL_FLOAT_FP16_UNARY_FUN(trunc, htrunc, h2trunc) +KERNEL_FLOAT_FP16_UNARY_FUN(negate, __hneg, __hneg2) + #if KERNEL_FLOAT_IS_DEVICE -#define KERNEL_FLOAT_FP16_BINARY_FUN(NAME, FUN1, FUN2) \ - namespace ops { \ - template<> \ - struct NAME<__half> { \ - KERNEL_FLOAT_INLINE __half operator()(__half left, __half right) const { \ - return FUN1(left, right); \ - } \ - }; \ - } \ - namespace detail { \ - template<> \ - struct apply_impl, 2, __half, __half, __half> { \ - KERNEL_FLOAT_INLINE static void \ - call(ops::NAME<__half>, __half* result, const __half* a, const __half* b) { \ - __half2 r = FUN2(__half2 {a[0], a[1]}, __half2 {b[0], b[1]}); \ - result[0] = r.x, result[1] = r.y; \ - } \ - }; \ +#define KERNEL_FLOAT_FP16_BINARY_FUN(NAME, FUN1, FUN2) \ + namespace ops { \ + template<> \ + struct NAME<__half> { \ + KERNEL_FLOAT_INLINE __half operator()(__half left, __half right) const { \ + return ops::cast {}(FUN1(left, right)); \ + } \ + }; \ + } \ + namespace detail { \ + template<> \ + struct apply_impl, 2, __half, __half, __half> { \ + KERNEL_FLOAT_INLINE static void \ + call(ops::NAME<__half>, __half* result, const __half* a, const __half* b) { \ + __half2 r = FUN2(__half2 {a[0], a[1]}, __half2 {b[0], b[1]}); \ + result[0] = r.x, result[1] = r.y; \ + } \ + }; \ } #else #define KERNEL_FLOAT_FP16_BINARY_FUN(NAME, FUN1, FUN2) @@ -4079,13 +4097,13 @@ KERNEL_FLOAT_FP16_CAST(char, __int2half_rn(input), (char)__half2int_rz(input)); KERNEL_FLOAT_FP16_CAST(signed char, __int2half_rn(input), (signed char)__half2int_rz(input)); KERNEL_FLOAT_FP16_CAST(unsigned char, __int2half_rn(input), (unsigned char)__half2int_rz(input)); -KERNEL_FLOAT_FP16_CAST(signed short, __half2short_rz(input), __short2half_rn(input)); -KERNEL_FLOAT_FP16_CAST(signed int, __half2int_rz(input), __int2half_rn(input)); +KERNEL_FLOAT_FP16_CAST(signed short, __short2half_rn(input), __half2short_rz(input)); +KERNEL_FLOAT_FP16_CAST(signed int, __int2half_rn(input), __half2int_rz(input)); KERNEL_FLOAT_FP16_CAST(signed long, __ll2half_rn(input), (signed long)(__half2ll_rz(input))); KERNEL_FLOAT_FP16_CAST(signed long long, __ll2half_rn(input), __half2ll_rz(input)); -KERNEL_FLOAT_FP16_CAST(unsigned short, __half2ushort_rz(input), __ushort2half_rn(input)); -KERNEL_FLOAT_FP16_CAST(unsigned int, __half2uint_rz(input), __uint2half_rn(input)); +KERNEL_FLOAT_FP16_CAST(unsigned short, __ushort2half_rn(input), __half2ushort_rz(input)); +KERNEL_FLOAT_FP16_CAST(unsigned int, __uint2half_rn(input), __half2uint_rz(input)); KERNEL_FLOAT_FP16_CAST(unsigned long, __ull2half_rn(input), (unsigned long)(__half2ull_rz(input))); KERNEL_FLOAT_FP16_CAST(unsigned long long, __ull2half_rn(input), __half2ull_rz(input)); #endif @@ -4106,6 +4124,11 @@ KERNEL_FLOAT_VECTOR_ALIAS(half, __half) #if KERNEL_FLOAT_BF16_AVAILABLE +//#define CUDA_NO_BFLOAT16 (1) +//#define __CUDA_NO_BFLOAT16_OPERATORS__ (1) +//#define __CUDA_NO_BFLOAT162_OPERATORS__ (1) +//#define __CUDA_NO_BFLOAT16_CONVERSIONS__ (1) + #if KERNEL_FLOAT_IS_CUDA #include #elif KERNEL_FLOAT_IS_HIP @@ -4178,21 +4201,24 @@ struct allow_float_fallback<__bfloat16> { }; \ } -KERNEL_FLOAT_BF16_UNARY_FUN(abs, ::__habs, ::__habs2) -KERNEL_FLOAT_BF16_UNARY_FUN(negate, ::__hneg, ::__hneg2) -KERNEL_FLOAT_BF16_UNARY_FUN(ceil, ::hceil, ::h2ceil) +KERNEL_FLOAT_BF16_UNARY_FUN(sin, ::hsin, ::h2sin) KERNEL_FLOAT_BF16_UNARY_FUN(cos, ::hcos, ::h2cos) + KERNEL_FLOAT_BF16_UNARY_FUN(exp, ::hexp, ::h2exp) KERNEL_FLOAT_BF16_UNARY_FUN(exp10, ::hexp10, ::h2exp10) -KERNEL_FLOAT_BF16_UNARY_FUN(floor, ::hfloor, ::h2floor) KERNEL_FLOAT_BF16_UNARY_FUN(log, ::hlog, ::h2log) KERNEL_FLOAT_BF16_UNARY_FUN(log10, ::hlog10, ::h2log2) -KERNEL_FLOAT_BF16_UNARY_FUN(rint, ::hrint, ::h2rint) -KERNEL_FLOAT_BF16_UNARY_FUN(rsqrt, ::hrsqrt, ::h2rsqrt) -KERNEL_FLOAT_BF16_UNARY_FUN(sin, ::hsin, ::h2sin) + KERNEL_FLOAT_BF16_UNARY_FUN(sqrt, ::hsqrt, ::h2sqrt) -KERNEL_FLOAT_BF16_UNARY_FUN(trunc, ::htrunc, ::h2trunc) +KERNEL_FLOAT_BF16_UNARY_FUN(rsqrt, ::hrsqrt, ::h2rsqrt) KERNEL_FLOAT_BF16_UNARY_FUN(rcp, ::hrcp, ::h2rcp) + +KERNEL_FLOAT_BF16_UNARY_FUN(abs, ::__habs, ::__habs2) +KERNEL_FLOAT_BF16_UNARY_FUN(floor, ::hfloor, ::h2floor) +KERNEL_FLOAT_BF16_UNARY_FUN(ceil, ::hceil, ::h2ceil) +KERNEL_FLOAT_BF16_UNARY_FUN(rint, ::hrint, ::h2rint) +KERNEL_FLOAT_BF16_UNARY_FUN(trunc, ::htrunc, ::h2trunc) +KERNEL_FLOAT_BF16_UNARY_FUN(negate, ::__hneg, ::__hneg2) #endif #if KERNEL_FLOAT_BF16_OPS_SUPPORTED @@ -4201,7 +4227,7 @@ KERNEL_FLOAT_BF16_UNARY_FUN(rcp, ::hrcp, ::h2rcp) template<> \ struct NAME<__bfloat16> { \ KERNEL_FLOAT_INLINE __bfloat16 operator()(__bfloat16 left, __bfloat16 right) const { \ - return FUN1(left, right); \ + return ops::cast {}(FUN1(left, right)); \ } \ }; \ } \ @@ -4261,29 +4287,6 @@ struct apply_impl, 2, __bfloat16, __bfloat16, __bfloat16, _ } // namespace detail #endif -namespace ops { -template<> -struct cast { - KERNEL_FLOAT_INLINE __bfloat16 operator()(double input) { - return __double2bfloat16(input); - }; -}; - -template<> -struct cast { - KERNEL_FLOAT_INLINE __bfloat16 operator()(float input) { - return __float2bfloat16(input); - }; -}; - -template<> -struct cast<__bfloat16, float> { - KERNEL_FLOAT_INLINE float operator()(__bfloat16 input) { - return __bfloat162float(input); - }; -}; -} // namespace ops - #define KERNEL_FLOAT_BF16_CAST(T, TO_HALF, FROM_HALF) \ namespace ops { \ template<> \ @@ -4300,6 +4303,9 @@ struct cast<__bfloat16, float> { }; \ } +KERNEL_FLOAT_BF16_CAST(float, __float2bfloat16(input), __bfloat162float(input)) +KERNEL_FLOAT_BF16_CAST(double, __double2bfloat16(input), __bfloat162float(input)) + #if KERNEL_FLOAT_BF16_OPS_SUPPORTED // clang-format off // there are no official char casts. Instead, cast to int and then to char @@ -4307,24 +4313,23 @@ KERNEL_FLOAT_BF16_CAST(char, __int2bfloat16_rn(input), (char)__bfloat162int_rz(i KERNEL_FLOAT_BF16_CAST(signed char, __int2bfloat16_rn(input), (signed char)__bfloat162int_rz(input)); KERNEL_FLOAT_BF16_CAST(unsigned char, __int2bfloat16_rn(input), (unsigned char)__bfloat162int_rz(input)); -KERNEL_FLOAT_BF16_CAST(signed short, __bfloat162short_rz(input), __short2bfloat16_rn(input)); -KERNEL_FLOAT_BF16_CAST(signed int, __bfloat162int_rz(input), __int2bfloat16_rn(input)); +KERNEL_FLOAT_BF16_CAST(signed short, __short2bfloat16_rn(input), __bfloat162short_rz(input)); +KERNEL_FLOAT_BF16_CAST(signed int, __int2bfloat16_rn(input), __bfloat162int_rz(input)); KERNEL_FLOAT_BF16_CAST(signed long, __ll2bfloat16_rn(input), (signed long)(__bfloat162ll_rz(input))); KERNEL_FLOAT_BF16_CAST(signed long long, __ll2bfloat16_rn(input), __bfloat162ll_rz(input)); -KERNEL_FLOAT_BF16_CAST(unsigned short, __bfloat162ushort_rz(input), __ushort2bfloat16_rn(input)); -KERNEL_FLOAT_BF16_CAST(unsigned int, __bfloat162uint_rz(input), __uint2bfloat16_rn(input)); +KERNEL_FLOAT_BF16_CAST(unsigned short, __ushort2bfloat16_rn(input), __bfloat162ushort_rz(input)); +KERNEL_FLOAT_BF16_CAST(unsigned int, __uint2bfloat16_rn(input), __bfloat162uint_rz(input)); KERNEL_FLOAT_BF16_CAST(unsigned long, __ull2bfloat16_rn(input), (unsigned long)(__bfloat162ull_rz(input))); KERNEL_FLOAT_BF16_CAST(unsigned long long, __ull2bfloat16_rn(input), __bfloat162ull_rz(input)); // clang-format on #endif #if KERNEL_FLOAT_IS_CUDA -KERNEL_FLOAT_BF16_CAST( - bool, - __nv_bfloat16_raw {input ? (unsigned short)0 : (unsigned short)0x3C00}, - (__nv_bfloat16_raw(input).x & 0x7FFF) != 0); - +//KERNEL_FLOAT_BF16_CAST( +// bool, +// __nv_bfloat16_raw {input ? (unsigned short)0 : (unsigned short)0x3C00}, +// (__nv_bfloat16_raw(input).x & 0x7FFF) != 0); #elif KERNEL_FLOAT_IS_HIP KERNEL_FLOAT_BF16_CAST( bool, @@ -4546,8 +4551,8 @@ KERNEL_FLOAT_TYPE_ALIAS(float16x, __half) #endif #if KERNEL_FLOAT_BF16_AVAILABLE -KERNEL_FLOAT_TYPE_ALIAS(bfloat16x, __nv_bfloat16) -KERNEL_FLOAT_TYPE_ALIAS(bf16x, __nv_bfloat16) +KERNEL_FLOAT_TYPE_ALIAS(bfloat16x, __bfloat16) +KERNEL_FLOAT_TYPE_ALIAS(bf16x, __bfloat16) #endif #if KERNEL_FLOAT_BF8_AVAILABLE @@ -4561,12 +4566,12 @@ static constexpr extent kextent = {}; template KERNEL_FLOAT_INLINE kvec, sizeof...(Args)> make_kvec(Args&&... args) { - return ::kernel_float::make_vec(std::forward(args)...); + return ::kernel_float::make_vec(static_cast(args)...); }; template KERNEL_FLOAT_INLINE into_vector_type into_kvec(V&& input) { - return ::kernel_float::into_vec(std::forward(input)); + return ::kernel_float::into_vec(static_cast(input)); } template diff --git a/tests/common.h b/tests/common.h index bd257ca..d209774 100644 --- a/tests/common.h +++ b/tests/common.h @@ -19,6 +19,22 @@ using __nv_bfloat16 = __hip_bfloat16; #endif namespace detail { +#if KERNEL_FLOAT_IS_CUDA +__attribute__((noinline)) static __host__ __device__ void +__assertion_failed(const char* expr, const char* file, int line) { +#if KERNEL_FLOAT_IS_HOST + std::string msg = + "assertion failed: " + std::string(expr) + " (" + file + ":" + std::to_string(line) + ")"; + throw std::runtime_error(msg); +#else + printf("assertion failed: %s (%s:%d)\n", expr, file, line); + asm("trap;"); + while (1) + ; +#endif +} + +#elif KERNEL_FLOAT_IS_HIP __attribute__((noinline)) static __host__ void __assertion_failed(const char* expr, const char* file, int line) { std::string msg = @@ -29,16 +45,11 @@ __assertion_failed(const char* expr, const char* file, int line) { __attribute__((noinline)) static __device__ void __assertion_failed(const char* expr, const char* file, int line) { printf("assertion failed: %s (%s:%d)\n", expr, file, line); - -#if KERNEL_FLOAT_IS_CUDA - asm("trap;"); -#elif KERNEL_FLOAT_IS_HIP __builtin_trap(); -#endif - while (1) ; } +#endif } // namespace detail #define ASSERT(...) \ @@ -81,7 +92,7 @@ struct equals_helper { template<> struct equals_helper<__half> { static __host__ __device__ bool call(const __half& left, const __half& right) { - return equals_helper::call(float(left), float(right)); + return equals_helper::call(__half2float(left), __half2float(right)); } }; @@ -138,7 +149,7 @@ struct approx_helper { template<> struct approx_helper<__half> { static __host__ __device__ bool call(__half left, __half right) { - return approx_helper::call(float(left), float(right), 0.01); + return approx_helper::call(__half2float(left), __half2float(right), 0.01); } }; @@ -226,7 +237,7 @@ struct generator_value>> { template<> struct generator_value<__half> { __host__ __device__ static __half call(uint64_t seed) { - return __half(generator_value::call(seed)); + return __float2half(generator_value::call(seed)); } }; From 4fca9cb22fe2199add5e4c7a9cb6a08ba469a678 Mon Sep 17 00:00:00 2001 From: stijn Date: Mon, 18 Nov 2024 12:11:22 +0100 Subject: [PATCH 03/25] Add approximation functions --- include/kernel_float.h | 1 + include/kernel_float/apply.h | 11 + include/kernel_float/approx.h | 391 +++++++++++++++++++++++++++++ include/kernel_float/bf16.h | 2 + include/kernel_float/binops.h | 3 +- include/kernel_float/fp16.h | 2 + include/kernel_float/macros.h | 2 + include/kernel_float/unops.h | 42 ++-- single_include/kernel_float.h | 457 ++++++++++++++++++++++++++++++++-- 9 files changed, 875 insertions(+), 36 deletions(-) create mode 100644 include/kernel_float/approx.h diff --git a/include/kernel_float.h b/include/kernel_float.h index ee098fc..be7e21c 100644 --- a/include/kernel_float.h +++ b/include/kernel_float.h @@ -1,6 +1,7 @@ #ifndef KERNEL_FLOAT_H #define KERNEL_FLOAT_H +#include "kernel_float/approx.h" #include "kernel_float/base.h" #include "kernel_float/bf16.h" #include "kernel_float/binops.h" diff --git a/include/kernel_float/apply.h b/include/kernel_float/apply.h index 3a7e02c..1421132 100644 --- a/include/kernel_float/apply.h +++ b/include/kernel_float/apply.h @@ -130,6 +130,9 @@ struct apply_impl { template struct apply_fastmath_impl: apply_impl {}; + +template +struct apply_approx_impl: apply_fastmath_impl {}; } // namespace detail struct accurate_policy { @@ -142,6 +145,14 @@ struct fast_policy { using type = detail::apply_fastmath_impl; }; +template +struct approximate_policy { + template + using type = detail::apply_approx_impl; +}; + +using default_approximate_policy = approximate_policy<>; + #ifdef KERNEL_FLOAT_POLICY using default_policy = KERNEL_FLOAT_POLICY; #else diff --git a/include/kernel_float/approx.h b/include/kernel_float/approx.h new file mode 100644 index 0000000..6725faf --- /dev/null +++ b/include/kernel_float/approx.h @@ -0,0 +1,391 @@ +#pragma once + +#include "apply.h" +#include "bf16.h" +#include "fp16.h" +#include "macros.h" + +namespace kernel_float { + +namespace approx { + +static_assert(sizeof(unsigned int) * 8 == 32, "invalid side of unsigned int"); +using uint32_t = unsigned int; + +template +KERNEL_FLOAT_DEVICE T transmute(const U& input) { + static_assert(sizeof(T) == sizeof(U), "types must have equal size"); + T result {}; + ::memcpy(&result, &input, sizeof(T)); + return result; +} + +KERNEL_FLOAT_DEVICE uint32_t +bitwise_if_else(uint32_t condition, uint32_t if_true, uint32_t if_false) { + uint32_t result; + +#if KERNEL_FLOAT_IS_CUDA + // equivalent to (condition & if_true) | ((~condition) & if_false) + asm("lop3.b32 %0, %1, %2, %3, 0xCA;" + : "=r"(result) + : "r"(condition), "r"(if_true), "r"(if_false)); +#else + result = (condition & if_true) | ((~condition) & if_false); +#endif + return result; +} + +template +KERNEL_FLOAT_DEVICE T2 eval_poly_recur(T2 y, T2 x) { + return y; +} + +template +KERNEL_FLOAT_DEVICE T2 eval_poly_recur(T2 y, T2 x, T coef, TRest... coefs) { + y = __hfma2(x, y, T2 {coef, coef}); + return eval_poly_recur(y, x, coefs...); +} + +template +KERNEL_FLOAT_DEVICE T2 eval_poly(T2 x, T coef, TRest... coefs) { + return eval_poly_recur(T2 {coef, coef}, x, coefs...); +} + +#define KERNEL_FLOAT_DEFINE_POLY(NAME, N, ...) \ + template \ + struct NAME { \ + template \ + static KERNEL_FLOAT_DEVICE T2 call(T2 x) { \ + return eval_poly(x, __VA_ARGS__); \ + } \ + }; + +template +struct sin_poly: sin_poly {}; +KERNEL_FLOAT_DEFINE_POLY(sin_poly, 1, 1.365) +KERNEL_FLOAT_DEFINE_POLY(sin_poly, 2, -21.56, 5.18) +KERNEL_FLOAT_DEFINE_POLY(sin_poly, 3, 53.53, -38.06, 6.184) +KERNEL_FLOAT_DEFINE_POLY(sin_poly, 4, -56.1, 77.94, -41.1, 6.277) +KERNEL_FLOAT_DEFINE_POLY(sin_poly, 5, 32.78, -74.5, 81.4, -41.34, 6.28) + +template +struct cos_poly: cos_poly {}; +KERNEL_FLOAT_DEFINE_POLY(cos_poly, 1, 0.0) +KERNEL_FLOAT_DEFINE_POLY(cos_poly, 2, -8.0, 0.6943) +KERNEL_FLOAT_DEFINE_POLY(cos_poly, 3, 38.94, -17.5, 0.9707) +KERNEL_FLOAT_DEFINE_POLY(cos_poly, 4, -59.66, 61.12, -19.56, 0.9985) +KERNEL_FLOAT_DEFINE_POLY(cos_poly, 5, 45.66, -82.4, 64.7, -19.73, 1.0) + +template +struct asin_poly: asin_poly {}; +KERNEL_FLOAT_DEFINE_POLY(asin_poly, 1, 1.531) +KERNEL_FLOAT_DEFINE_POLY(asin_poly, 2, -0.169, 1.567) +KERNEL_FLOAT_DEFINE_POLY(asin_poly, 3, 0.05167, -0.2057, 1.57) +KERNEL_FLOAT_DEFINE_POLY(asin_poly, 4, -0.02103, 0.077, -0.2129, 1.57) +KERNEL_FLOAT_DEFINE_POLY(asin_poly, 5, 0.009796, -0.03772, 0.0857, -0.2142, 1.57) + +#if KERNEL_FLOAT_FP16_AVAILABLE +KERNEL_FLOAT_DEVICE __half2 flipsign(__half2 input, __half2 sign) { + // Flip signbit of input when sign<0 + uint32_t result; + +#if KERNEL_FLOAT_IS_CUDA + asm("lop3.b32 %0, %1, %2, %3, 0x6A;" + : "=r"(result) + : "r"(0x80008000), "r"(transmute(sign)), "r"(transmute(input))); +#else + result = uint32_t(transmute(sign) & 0x80008000) ^ transmute(input); +#endif + + return transmute<__half2>(result); +} + +KERNEL_FLOAT_DEVICE uint32_t half2_gt_mask(__half2 a, __half2 b) { + uint32_t val; +#if KERNEL_FLOAT_IS_CUDA + uint32_t ai = *(reinterpret_cast(&a)); + uint32_t bi = *(reinterpret_cast(&b)); + asm("{ set.gt.u32.f16x2 %0,%1,%2;\n}" : "=r"(val) : "r"(ai), "r"(bi)); +#else + val = transmute(make_short2(a.x > b.x ? ~0 : 0, a.y > b.y ? ~0 : 0)); +#endif + return val; +} + +KERNEL_FLOAT_INLINE __half2 make_half2(half x) { + return {x, x}; +} + +KERNEL_FLOAT_DEVICE __half2 normalize_trig_input(__half2 x) { + /* Using rint is too slow. Round using floating-point magic instead. */ + // __half2 x = arg * make_half2(-0.15915494309); + // return __hfma2(arg, make_half2(0.15915494309), h2rint(x)); + + // 1/(2pi) = 0.15915494309189535 + static constexpr double ONE_OVER_TWOPI = 0.15915494309189535; + static constexpr double OFFSET = -2042.0; + + __half2 ws = __hfma2(x, make_half2(-ONE_OVER_TWOPI), make_half2(-OFFSET)) + make_half2(OFFSET); + return __hfma2(x, make_half2(ONE_OVER_TWOPI), ws); +} + +template +KERNEL_FLOAT_DEVICE __half2 cos(__half2 x) { + __half2 xf = normalize_trig_input(x); + return cos_poly::call(__hmul2(xf, xf)); +} + +template +KERNEL_FLOAT_DEVICE __half2 sin(__half2 x) { + __half2 xf = normalize_trig_input(x); + return sin_poly::call(__hmul2(xf, xf)) * xf; +} + +template +KERNEL_FLOAT_DEVICE __half2 rcp(__half2 x) { + // Flip bits + uint32_t m = ~transmute(x); + + // Multiply by bias (add contant) + __half2 y = transmute<__half2>(uint32_t(0x776d776d) + m); + +#pragma unroll + for (int i = 0; i < Iter; i++) { + // y += y * (1 - x * y) + y = __hfma2(y, __hfma2(-x, y, make_half2(1.0)), y); + } + + return y; +} + +template +KERNEL_FLOAT_DEVICE __half2 rsqrt(__half2 x) { + // Set top and bottom bits for both halfs, then shift by 1, then invert + uint32_t r = ~((uint32_t(transmute(x) >> 1)) | ~uint32_t(0x3fff3fff)); + //uint32_t r = uint32_t(~(transmute(arg) | (~uint32_t(0x3ffe3ffe)))) >> 1; + + // Add bias (0x199c) + __half2 y = transmute<__half2>(uint32_t(r) + uint32_t(0x199c199c)); + + // Newton-Raphson iterations +#pragma unroll + for (int i = 0; i < Iter; i++) { + __half2 half_x = make_half2(-0.5) * x; + __half2 correction = __hfma2(half_x, y * y, make_half2(0.5)); + y = __hfma2(correction, y, y); // y += y * correction + } + + return y; +} + +template +KERNEL_FLOAT_DEVICE __half2 sqrt(__half2 x) { + if (Iter == 1) { + __half2 y = rsqrt<0>(x); + + // This method uses only 4 muls, instead of 5 muls when using `arg * approx_rsqrt<1>(arg)` + __half2 xy = x * y; + return xy * __hfma2(make_half2(-0.5) * y, xy, make_half2(1.5)); + } + + return x * rsqrt(x); +} + +template +KERNEL_FLOAT_DEVICE __half2 asin(__half2 x) { + static constexpr double HALF_PI = 1.57079632679; + auto abs_x = __habs2(x); + auto v = asin_poly::call(abs_x); + auto abs_y = __hfma2(-v, sqrt(make_half2(1) - abs_x), make_half2(HALF_PI)); + return flipsign(abs_y, x); +} + +template +KERNEL_FLOAT_DEVICE __half2 acos(__half2 x) { + static constexpr double HALF_PI = 1.57079632679; + return make_half2(HALF_PI) - asin(x); +} + +template +KERNEL_FLOAT_DEVICE __half2 exp(__half2 x) { + __half2 y; + + if (Deg == 0) { + // Bring the value to range [32, 64] + // 1.442 = 1/log(2) + // 46.969 = 32.5/log(2) + __half2 m = __hfma2(x, make_half2(1.442), make_half2(46.9375)); + + // Transmute to int, shift higher mantissa bits into exponent field. + y = transmute<__half2>((transmute(m) & 0x03ff03ff) << 5); + } else { + // Add a large number to round to an integer + __half2 v = __hfma2(x, make_half2(1.442), make_half2(1231.0)); + + // The exponent is now in the lower 5 bits. Shift that into the exponent field. + __half2 exp = transmute<__half2>((transmute(v) & 0x001f001f) << 10); + + // The fractional part can be obtained from "1231-v". + // 0.6934 = log(2) + __half2 frac = __hfma2(make_half2(1231.0) - v, make_half2(0.6934), x); + + // This is the Taylor expansion of "exp(x)-1" around 0 + __half2 adjust; + if (Deg == 1) { + adjust = frac; + } else if (Deg == 2) { + // adjust = frac + 0.5 * frac^2 + adjust = __hfma2(frac, __hmul2(frac, make_half2(0.5)), frac); + } else /* if (Deg == 2) */ { + // adjust = frac + 0.5 * frac^2 + 1/6 * frac^3 + adjust = __hfma2( + frac, + __hmul2(__hfma2(frac, make_half2(0.1666), make_half2(0.5)), frac), + frac); + } + + // result = exp * (adjust + 1) + y = __hfma2(exp, adjust, exp); + } + + // Values below -10.39 (= -15*log(2)) become zero + uint32_t zero_mask = half2_gt_mask(x, make_half2(-10.390625)); + return transmute<__half2>(zero_mask & transmute(y)); +} + +template +KERNEL_FLOAT_DEVICE __half2 log(__half2 arg) { + // Shift exponent field into mantissa bits. Fill exponent bits with 0x5000 (= 32.0) + uint32_t bits = bitwise_if_else(0x03ff03ff, transmute(arg) >> 5, 0x50005000); + + // 0.6934 = log(2) + // 32.53 = 46.969*log(2) + return __hfma2(transmute<__half2>(bits), make_half2(0.6934), make_half2(-32.53125)); +} + +template +KERNEL_FLOAT_DEVICE __half2 tanh(__half2 x) { + if (Deg == 0) { + return x * rcp<0>(make_half2(0.2869) + __habs2(x)); + } else { + auto c0 = make_half2(0.4531); + auto c1 = make_half2(0.5156); + auto x2b = __hfma2(x, x, c1); + return (x * x2b) * rcp(__hfma2(x2b, __habs2(x), c0)); + } +} + +#endif // KERNEL_FLOAT_FP16_AVAILABLE + +#if KERNEL_FLOAT_BF16_OPS_SUPPORTED +KERNEL_FLOAT_DEVICE __bfloat162 make_bfloat162(__bfloat16 x) { + return {x, x}; +} + +KERNEL_FLOAT_DEVICE __bfloat162 make_bfloat162(double x) { + return {__double2bfloat16(x), __double2bfloat16(x)}; +} + +KERNEL_FLOAT_DEVICE __bfloat162 normalize_trig_input(__nv_bfloat162 x) { + static constexpr double ONE_OVER_TWOPI = 0.15915494309189535; + static constexpr double OFFSET = -2042.0; + + __bfloat162 ws = __hadd2( + __hfma2(x, make_bfloat162(-ONE_OVER_TWOPI), make_bfloat162(-OFFSET)), + make_bfloat162(OFFSET)); + return __hfma2(x, make_bfloat162(ONE_OVER_TWOPI), ws); +} + +template +KERNEL_FLOAT_DEVICE __bfloat162 cos(__bfloat162 x) { + __bfloat162 xf = normalize_trig_input(x); + return cos_poly<__bfloat16, Iter + 1>::call(__hmul2(xf, xf)); +} + +template +KERNEL_FLOAT_DEVICE __bfloat162 sin(__bfloat162 x) { + __bfloat162 xf = normalize_trig_input(x); + return __hmul2(sin_poly<__bfloat16, Iter>::call(__hmul2(xf, xf)), xf); +} + +template +KERNEL_FLOAT_DEVICE __bfloat162 rcp(__bfloat162 x) { + __bfloat162 y = transmute<__bfloat162>(uint32_t(0x7ef07ef0) + ~transmute(x)); + +#pragma unroll + for (int i = 0; i < Iter; i++) { + y = __hfma2(y, __hfma2(__hneg2(x), y, make_bfloat162(1.0)), y); + } + + return y; +} + +template +KERNEL_FLOAT_DEVICE __bfloat162 rsqrt(__bfloat162 x) { + // Set top and bottom bits for both halfs, then shift by 1, then invert + uint32_t r = ~((uint32_t(transmute(x) >> 1)) | ~uint32_t(0x3fff3fff)); + + // Add bias (0x1f36) + __bfloat162 y = transmute<__bfloat162>(uint32_t(r) + uint32_t(0x1f361f36)); + + // Newton-Raphson iterations +#pragma unroll + for (int i = 0; i < Iter; i++) { + __bfloat162 half_x = __hmul2(make_bfloat162(-0.5), x); + __bfloat162 correction = __hfma2(half_x, __hmul2(y, y), make_bfloat162(0.5)); + y = __hfma2(correction, y, y); // y += y * correction + } + + return y; +} + +template +KERNEL_FLOAT_DEVICE __bfloat162 sqrt(__bfloat162 x) { + return __hmul2(x, rsqrt(x)); +} + +template +KERNEL_FLOAT_DEVICE __bfloat162 exp(__bfloat162 arg) { + static constexpr float SCALE = 1.44272065994f / 256.0f; + static constexpr float OFFSET = 382.4958400542335; + + auto a = fmaf(__bfloat162float(arg.x), SCALE, OFFSET); + auto b = fmaf(__bfloat162float(arg.y), SCALE, OFFSET); + + return { + transmute<__bfloat16>(uint16_t(transmute(a))), + transmute<__bfloat16>(uint16_t(transmute(b)))}; +} +#endif +} // namespace approx + +#define KERNEL_FLOAT_DEFINE_APPROX_FUN(FULL_NAME, FUN, DEG) \ + namespace detail { \ + template \ + struct apply_approx_impl, 2, __half, __half> { \ + KERNEL_FLOAT_INLINE static void \ + call(ops::FUN<__half> fun, __half* output, const __half* input) { \ + __half2 res = approx::FUN(__half2 {input[0], input[1]}); \ + output[0] = res.x; \ + output[1] = res.y; \ + } \ + }; \ + template<> \ + struct apply_approx_impl<-1, ops::FUN<__half>, 2, __half, __half>: \ + apply_approx_impl, 2, __half, __half> {}; \ + } \ + \ + template \ + KERNEL_FLOAT_INLINE into_vector_type approx_##FUN(const V& args) { \ + return map>(ops::FUN> {}, args); \ + } + +KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_sin, sin, 4) +KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_cos, cos, 4) +KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_rsqrt, rsqrt, 1) +KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_sqrt, sqrt, 1) +KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_rcp, rcp, 1) +KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_exp, exp, 0) +KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_log, log, 0) + +} // namespace kernel_float diff --git a/include/kernel_float/bf16.h b/include/kernel_float/bf16.h index bad5f75..6784e89 100644 --- a/include/kernel_float/bf16.h +++ b/include/kernel_float/bf16.h @@ -85,8 +85,10 @@ KERNEL_FLOAT_BF16_UNARY_FUN(sin, ::hsin, ::h2sin) KERNEL_FLOAT_BF16_UNARY_FUN(cos, ::hcos, ::h2cos) KERNEL_FLOAT_BF16_UNARY_FUN(exp, ::hexp, ::h2exp) +KERNEL_FLOAT_BF16_UNARY_FUN(exp2, ::hexp2, ::h2exp2) KERNEL_FLOAT_BF16_UNARY_FUN(exp10, ::hexp10, ::h2exp10) KERNEL_FLOAT_BF16_UNARY_FUN(log, ::hlog, ::h2log) +KERNEL_FLOAT_BF16_UNARY_FUN(log2, ::hlog2, ::h2log2) KERNEL_FLOAT_BF16_UNARY_FUN(log10, ::hlog10, ::h2log2) KERNEL_FLOAT_BF16_UNARY_FUN(sqrt, ::hsqrt, ::h2sqrt) diff --git a/include/kernel_float/binops.h b/include/kernel_float/binops.h index 5ac6ed9..a19ca64 100644 --- a/include/kernel_float/binops.h +++ b/include/kernel_float/binops.h @@ -359,8 +359,7 @@ template< typename L, typename R, typename T = promoted_vector_value_type, - typename = - enable_if_t> && is_vector_broadcastable>>> + typename = enable_if_t<(vector_size == 3 && vector_size == 3)>> KERNEL_FLOAT_INLINE vector> cross(const L& left, const R& right) { return detail::cross_impl::call(convert_storage(left), convert_storage(right)); } diff --git a/include/kernel_float/fp16.h b/include/kernel_float/fp16.h index 0e90d8d..8d94c51 100644 --- a/include/kernel_float/fp16.h +++ b/include/kernel_float/fp16.h @@ -73,8 +73,10 @@ KERNEL_FLOAT_FP16_UNARY_FUN(sin, hsin, h2sin) KERNEL_FLOAT_FP16_UNARY_FUN(cos, hcos, h2cos) KERNEL_FLOAT_FP16_UNARY_FUN(exp, hexp, h2exp) +KERNEL_FLOAT_FP16_UNARY_FUN(exp2, hexp2, h2exp2) KERNEL_FLOAT_FP16_UNARY_FUN(exp10, hexp10, h2exp10) KERNEL_FLOAT_FP16_UNARY_FUN(log, hlog, h2log) +KERNEL_FLOAT_FP16_UNARY_FUN(log2, hlog2, h2log2) KERNEL_FLOAT_FP16_UNARY_FUN(log10, hlog10, h2log2) KERNEL_FLOAT_FP16_UNARY_FUN(sqrt, hsqrt, h2sqrt) diff --git a/include/kernel_float/macros.h b/include/kernel_float/macros.h index 01b0254..68be6e5 100644 --- a/include/kernel_float/macros.h +++ b/include/kernel_float/macros.h @@ -8,6 +8,7 @@ // clang-format off #ifdef __CUDACC__ #define KERNEL_FLOAT_IS_CUDA (1) + #define KERNEL_FLOAT_DEVICE __forceinline__ __device__ #ifdef __CUDA_ARCH__ #define KERNEL_FLOAT_INLINE __forceinline__ __device__ @@ -18,6 +19,7 @@ #endif // __CUDA_ARCH__ #elif defined(__HIPCC__) #define KERNEL_FLOAT_IS_HIP (1) + #define KERNEL_FLOAT_DEVICE __attribute__((always_inline)) __device__ #ifdef __HIP_DEVICE_COMPILE__ #define KERNEL_FLOAT_INLINE __attribute__((always_inline)) __host__ __device__ diff --git a/include/kernel_float/unops.h b/include/kernel_float/unops.h index 9e5fe42..739f795 100644 --- a/include/kernel_float/unops.h +++ b/include/kernel_float/unops.h @@ -178,7 +178,6 @@ KERNEL_FLOAT_DEFINE_UNARY_MATH(cbrt) KERNEL_FLOAT_DEFINE_UNARY_MATH(rcbrt) KERNEL_FLOAT_DEFINE_UNARY_MATH(abs) -KERNEL_FLOAT_DEFINE_UNARY_MATH(fabs) KERNEL_FLOAT_DEFINE_UNARY_MATH(floor) KERNEL_FLOAT_DEFINE_UNARY_MATH(round) KERNEL_FLOAT_DEFINE_UNARY_MATH(ceil) @@ -208,31 +207,43 @@ KERNEL_FLOAT_DEFINE_UNARY_FUN(rcp) return ::kernel_float::map(ops::NAME> {}, input); \ } -KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(exp) -KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(log) -KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(sqrt) -KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(rcp) -KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(rsqrt) KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(sin) KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(cos) KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(tan) + +KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(exp) KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(exp2) +KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(log) KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(log2) +KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(sqrt) +KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(rcp) +KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(rsqrt) + // This PTX is only supported on CUDA #if KERNEL_FLOAT_IS_CUDA && KERNEL_FLOAT_IS_DEVICE -#define KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(T, F, FAST_FUN) \ +#define KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(T, F, EXPR_F32) \ namespace detail { \ template<> \ struct apply_fastmath_impl, 1, T, T> { \ KERNEL_FLOAT_INLINE static void call(ops::F, T* result, const T* inputs) { \ - *result = FAST_FUN(*inputs); \ + T input = inputs[0]; \ + *result = EXPR_F32; \ } \ }; \ } -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, exp, __expf) -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, log, __logf) +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, exp, __expf(input)) +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, exp2, __exp2f(input)) +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, exp10, __exp10f(input)) + +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, log, __logf(input)) +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, log2, __log2f(input)) +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, log10, __log10f(input)) + +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, sin, __sinf(input)) +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, cos, __cosf(input)) +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, tan, __tanf(input)) #define KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(T, F, INSTR, REG) \ namespace detail { \ @@ -250,12 +261,13 @@ KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(double, rsqrt, "rsqrt.approx.f64", "d") KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, sqrt, "sqrt.approx.f32", "f") KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, rcp, "rcp.approx.f32", "f") KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, rsqrt, "rsqrt.approx.f32", "f") -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, sin, "sin.approx.f32", "f") -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, cos, "cos.approx.f32", "f") - -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, exp2, "ex2.approx.f32", "f") -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, log2, "lg2.approx.f32", "f") KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, tanh, "tanh.approx.f32;", "f") + +//KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, sin, "sin.approx.f32", "f") +//KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, cos, "cos.approx.f32", "f") +//KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, exp2, "ex2.approx.f32", "f") +//KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, log2, "lg2.approx.f32", "f") + #endif } // namespace kernel_float diff --git a/single_include/kernel_float.h b/single_include/kernel_float.h index ea9787f..7dc3fed 100644 --- a/single_include/kernel_float.h +++ b/single_include/kernel_float.h @@ -16,8 +16,8 @@ //================================================================================ // this file has been auto-generated, do not modify its contents! -// date: 2024-11-18 11:10:30.225884 -// git hash: 5490ea756b41c688b66dd69e776a58c9ce5b1ef2 +// date: 2024-11-18 12:11:06.609851 +// git hash: de62ad0ced81f2d5129b31bfb621fcbc0ce161e9 //================================================================================ #ifndef KERNEL_FLOAT_MACROS_H @@ -30,6 +30,7 @@ // clang-format off #ifdef __CUDACC__ #define KERNEL_FLOAT_IS_CUDA (1) + #define KERNEL_FLOAT_DEVICE __forceinline__ __device__ #ifdef __CUDA_ARCH__ #define KERNEL_FLOAT_INLINE __forceinline__ __device__ @@ -40,6 +41,7 @@ #endif // __CUDA_ARCH__ #elif defined(__HIPCC__) #define KERNEL_FLOAT_IS_HIP (1) + #define KERNEL_FLOAT_DEVICE __attribute__((always_inline)) __device__ #ifdef __HIP_DEVICE_COMPILE__ #define KERNEL_FLOAT_INLINE __attribute__((always_inline)) __host__ __device__ @@ -795,6 +797,9 @@ struct apply_impl { template struct apply_fastmath_impl: apply_impl {}; + +template +struct apply_approx_impl: apply_fastmath_impl {}; } // namespace detail struct accurate_policy { @@ -807,6 +812,14 @@ struct fast_policy { using type = detail::apply_fastmath_impl; }; +template +struct approximate_policy { + template + using type = detail::apply_approx_impl; +}; + +using default_approximate_policy = approximate_policy<>; + #ifdef KERNEL_FLOAT_POLICY using default_policy = KERNEL_FLOAT_POLICY; #else @@ -1307,7 +1320,6 @@ KERNEL_FLOAT_DEFINE_UNARY_MATH(cbrt) KERNEL_FLOAT_DEFINE_UNARY_MATH(rcbrt) KERNEL_FLOAT_DEFINE_UNARY_MATH(abs) -KERNEL_FLOAT_DEFINE_UNARY_MATH(fabs) KERNEL_FLOAT_DEFINE_UNARY_MATH(floor) KERNEL_FLOAT_DEFINE_UNARY_MATH(round) KERNEL_FLOAT_DEFINE_UNARY_MATH(ceil) @@ -1337,31 +1349,43 @@ KERNEL_FLOAT_DEFINE_UNARY_FUN(rcp) return ::kernel_float::map(ops::NAME> {}, input); \ } -KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(exp) -KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(log) -KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(sqrt) -KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(rcp) -KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(rsqrt) KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(sin) KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(cos) KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(tan) + +KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(exp) KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(exp2) +KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(log) KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(log2) +KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(sqrt) +KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(rcp) +KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(rsqrt) + // This PTX is only supported on CUDA #if KERNEL_FLOAT_IS_CUDA && KERNEL_FLOAT_IS_DEVICE -#define KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(T, F, FAST_FUN) \ +#define KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(T, F, EXPR_F32) \ namespace detail { \ template<> \ struct apply_fastmath_impl, 1, T, T> { \ KERNEL_FLOAT_INLINE static void call(ops::F, T* result, const T* inputs) { \ - *result = FAST_FUN(*inputs); \ + T input = inputs[0]; \ + *result = EXPR_F32; \ } \ }; \ } -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, exp, __expf) -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, log, __logf) +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, exp, __expf(input)) +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, exp2, __exp2f(input)) +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, exp10, __exp10f(input)) + +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, log, __logf(input)) +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, log2, __log2f(input)) +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, log10, __log10f(input)) + +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, sin, __sinf(input)) +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, cos, __cosf(input)) +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, tan, __tanf(input)) #define KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(T, F, INSTR, REG) \ namespace detail { \ @@ -1379,12 +1403,13 @@ KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(double, rsqrt, "rsqrt.approx.f64", "d") KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, sqrt, "sqrt.approx.f32", "f") KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, rcp, "rcp.approx.f32", "f") KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, rsqrt, "rsqrt.approx.f32", "f") -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, sin, "sin.approx.f32", "f") -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, cos, "cos.approx.f32", "f") - -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, exp2, "ex2.approx.f32", "f") -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, log2, "lg2.approx.f32", "f") KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, tanh, "tanh.approx.f32;", "f") + +//KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, sin, "sin.approx.f32", "f") +//KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, cos, "cos.approx.f32", "f") +//KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, exp2, "ex2.approx.f32", "f") +//KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, log2, "lg2.approx.f32", "f") + #endif } // namespace kernel_float @@ -1968,8 +1993,7 @@ template< typename L, typename R, typename T = promoted_vector_value_type, - typename = - enable_if_t> && is_vector_broadcastable>>> + typename = enable_if_t<(vector_size == 3 && vector_size == 3)>> KERNEL_FLOAT_INLINE vector> cross(const L& left, const R& right) { return detail::cross_impl::call(convert_storage(left), convert_storage(right)); } @@ -3972,8 +3996,10 @@ KERNEL_FLOAT_FP16_UNARY_FUN(sin, hsin, h2sin) KERNEL_FLOAT_FP16_UNARY_FUN(cos, hcos, h2cos) KERNEL_FLOAT_FP16_UNARY_FUN(exp, hexp, h2exp) +KERNEL_FLOAT_FP16_UNARY_FUN(exp2, hexp2, h2exp2) KERNEL_FLOAT_FP16_UNARY_FUN(exp10, hexp10, h2exp10) KERNEL_FLOAT_FP16_UNARY_FUN(log, hlog, h2log) +KERNEL_FLOAT_FP16_UNARY_FUN(log2, hlog2, h2log2) KERNEL_FLOAT_FP16_UNARY_FUN(log10, hlog10, h2log2) KERNEL_FLOAT_FP16_UNARY_FUN(sqrt, hsqrt, h2sqrt) @@ -4205,8 +4231,10 @@ KERNEL_FLOAT_BF16_UNARY_FUN(sin, ::hsin, ::h2sin) KERNEL_FLOAT_BF16_UNARY_FUN(cos, ::hcos, ::h2cos) KERNEL_FLOAT_BF16_UNARY_FUN(exp, ::hexp, ::h2exp) +KERNEL_FLOAT_BF16_UNARY_FUN(exp2, ::hexp2, ::h2exp2) KERNEL_FLOAT_BF16_UNARY_FUN(exp10, ::hexp10, ::h2exp10) KERNEL_FLOAT_BF16_UNARY_FUN(log, ::hlog, ::h2log) +KERNEL_FLOAT_BF16_UNARY_FUN(log2, ::hlog2, ::h2log2) KERNEL_FLOAT_BF16_UNARY_FUN(log10, ::hlog10, ::h2log2) KERNEL_FLOAT_BF16_UNARY_FUN(sqrt, ::hsqrt, ::h2sqrt) @@ -4364,6 +4392,397 @@ struct promote_type<__half, __bfloat16> { #endif #endif //KERNEL_FLOAT_BF16_H +#pragma once + + + + + + +namespace kernel_float { + +namespace approx { + +static_assert(sizeof(unsigned int) * 8 == 32, "invalid side of unsigned int"); +using uint32_t = unsigned int; + +template +KERNEL_FLOAT_DEVICE T transmute(const U& input) { + static_assert(sizeof(T) == sizeof(U), "types must have equal size"); + T result {}; + ::memcpy(&result, &input, sizeof(T)); + return result; +} + +KERNEL_FLOAT_DEVICE uint32_t +bitwise_if_else(uint32_t condition, uint32_t if_true, uint32_t if_false) { + uint32_t result; + +#if KERNEL_FLOAT_IS_CUDA + // equivalent to (condition & if_true) | ((~condition) & if_false) + asm("lop3.b32 %0, %1, %2, %3, 0xCA;" + : "=r"(result) + : "r"(condition), "r"(if_true), "r"(if_false)); +#else + result = (condition & if_true) | ((~condition) & if_false); +#endif + return result; +} + +template +KERNEL_FLOAT_DEVICE T2 eval_poly_recur(T2 y, T2 x) { + return y; +} + +template +KERNEL_FLOAT_DEVICE T2 eval_poly_recur(T2 y, T2 x, T coef, TRest... coefs) { + y = __hfma2(x, y, T2 {coef, coef}); + return eval_poly_recur(y, x, coefs...); +} + +template +KERNEL_FLOAT_DEVICE T2 eval_poly(T2 x, T coef, TRest... coefs) { + return eval_poly_recur(T2 {coef, coef}, x, coefs...); +} + +#define KERNEL_FLOAT_DEFINE_POLY(NAME, N, ...) \ + template \ + struct NAME { \ + template \ + static KERNEL_FLOAT_DEVICE T2 call(T2 x) { \ + return eval_poly(x, __VA_ARGS__); \ + } \ + }; + +template +struct sin_poly: sin_poly {}; +KERNEL_FLOAT_DEFINE_POLY(sin_poly, 1, 1.365) +KERNEL_FLOAT_DEFINE_POLY(sin_poly, 2, -21.56, 5.18) +KERNEL_FLOAT_DEFINE_POLY(sin_poly, 3, 53.53, -38.06, 6.184) +KERNEL_FLOAT_DEFINE_POLY(sin_poly, 4, -56.1, 77.94, -41.1, 6.277) +KERNEL_FLOAT_DEFINE_POLY(sin_poly, 5, 32.78, -74.5, 81.4, -41.34, 6.28) + +template +struct cos_poly: cos_poly {}; +KERNEL_FLOAT_DEFINE_POLY(cos_poly, 1, 0.0) +KERNEL_FLOAT_DEFINE_POLY(cos_poly, 2, -8.0, 0.6943) +KERNEL_FLOAT_DEFINE_POLY(cos_poly, 3, 38.94, -17.5, 0.9707) +KERNEL_FLOAT_DEFINE_POLY(cos_poly, 4, -59.66, 61.12, -19.56, 0.9985) +KERNEL_FLOAT_DEFINE_POLY(cos_poly, 5, 45.66, -82.4, 64.7, -19.73, 1.0) + +template +struct asin_poly: asin_poly {}; +KERNEL_FLOAT_DEFINE_POLY(asin_poly, 1, 1.531) +KERNEL_FLOAT_DEFINE_POLY(asin_poly, 2, -0.169, 1.567) +KERNEL_FLOAT_DEFINE_POLY(asin_poly, 3, 0.05167, -0.2057, 1.57) +KERNEL_FLOAT_DEFINE_POLY(asin_poly, 4, -0.02103, 0.077, -0.2129, 1.57) +KERNEL_FLOAT_DEFINE_POLY(asin_poly, 5, 0.009796, -0.03772, 0.0857, -0.2142, 1.57) + +#if KERNEL_FLOAT_FP16_AVAILABLE +KERNEL_FLOAT_DEVICE __half2 flipsign(__half2 input, __half2 sign) { + // Flip signbit of input when sign<0 + uint32_t result; + +#if KERNEL_FLOAT_IS_CUDA + asm("lop3.b32 %0, %1, %2, %3, 0x6A;" + : "=r"(result) + : "r"(0x80008000), "r"(transmute(sign)), "r"(transmute(input))); +#else + result = uint32_t(transmute(sign) & 0x80008000) ^ transmute(input); +#endif + + return transmute<__half2>(result); +} + +KERNEL_FLOAT_DEVICE uint32_t half2_gt_mask(__half2 a, __half2 b) { + uint32_t val; +#if KERNEL_FLOAT_IS_CUDA + uint32_t ai = *(reinterpret_cast(&a)); + uint32_t bi = *(reinterpret_cast(&b)); + asm("{ set.gt.u32.f16x2 %0,%1,%2;\n}" : "=r"(val) : "r"(ai), "r"(bi)); +#else + val = transmute(make_short2(a.x > b.x ? ~0 : 0, a.y > b.y ? ~0 : 0)); +#endif + return val; +} + +KERNEL_FLOAT_INLINE __half2 make_half2(half x) { + return {x, x}; +} + +KERNEL_FLOAT_DEVICE __half2 normalize_trig_input(__half2 x) { + /* Using rint is too slow. Round using floating-point magic instead. */ + // __half2 x = arg * make_half2(-0.15915494309); + // return __hfma2(arg, make_half2(0.15915494309), h2rint(x)); + + // 1/(2pi) = 0.15915494309189535 + static constexpr double ONE_OVER_TWOPI = 0.15915494309189535; + static constexpr double OFFSET = -2042.0; + + __half2 ws = __hfma2(x, make_half2(-ONE_OVER_TWOPI), make_half2(-OFFSET)) + make_half2(OFFSET); + return __hfma2(x, make_half2(ONE_OVER_TWOPI), ws); +} + +template +KERNEL_FLOAT_DEVICE __half2 cos(__half2 x) { + __half2 xf = normalize_trig_input(x); + return cos_poly::call(__hmul2(xf, xf)); +} + +template +KERNEL_FLOAT_DEVICE __half2 sin(__half2 x) { + __half2 xf = normalize_trig_input(x); + return sin_poly::call(__hmul2(xf, xf)) * xf; +} + +template +KERNEL_FLOAT_DEVICE __half2 rcp(__half2 x) { + // Flip bits + uint32_t m = ~transmute(x); + + // Multiply by bias (add contant) + __half2 y = transmute<__half2>(uint32_t(0x776d776d) + m); + +#pragma unroll + for (int i = 0; i < Iter; i++) { + // y += y * (1 - x * y) + y = __hfma2(y, __hfma2(-x, y, make_half2(1.0)), y); + } + + return y; +} + +template +KERNEL_FLOAT_DEVICE __half2 rsqrt(__half2 x) { + // Set top and bottom bits for both halfs, then shift by 1, then invert + uint32_t r = ~((uint32_t(transmute(x) >> 1)) | ~uint32_t(0x3fff3fff)); + //uint32_t r = uint32_t(~(transmute(arg) | (~uint32_t(0x3ffe3ffe)))) >> 1; + + // Add bias (0x199c) + __half2 y = transmute<__half2>(uint32_t(r) + uint32_t(0x199c199c)); + + // Newton-Raphson iterations +#pragma unroll + for (int i = 0; i < Iter; i++) { + __half2 half_x = make_half2(-0.5) * x; + __half2 correction = __hfma2(half_x, y * y, make_half2(0.5)); + y = __hfma2(correction, y, y); // y += y * correction + } + + return y; +} + +template +KERNEL_FLOAT_DEVICE __half2 sqrt(__half2 x) { + if (Iter == 1) { + __half2 y = rsqrt<0>(x); + + // This method uses only 4 muls, instead of 5 muls when using `arg * approx_rsqrt<1>(arg)` + __half2 xy = x * y; + return xy * __hfma2(make_half2(-0.5) * y, xy, make_half2(1.5)); + } + + return x * rsqrt(x); +} + +template +KERNEL_FLOAT_DEVICE __half2 asin(__half2 x) { + static constexpr double HALF_PI = 1.57079632679; + auto abs_x = __habs2(x); + auto v = asin_poly::call(abs_x); + auto abs_y = __hfma2(-v, sqrt(make_half2(1) - abs_x), make_half2(HALF_PI)); + return flipsign(abs_y, x); +} + +template +KERNEL_FLOAT_DEVICE __half2 acos(__half2 x) { + static constexpr double HALF_PI = 1.57079632679; + return make_half2(HALF_PI) - asin(x); +} + +template +KERNEL_FLOAT_DEVICE __half2 exp(__half2 x) { + __half2 y; + + if (Deg == 0) { + // Bring the value to range [32, 64] + // 1.442 = 1/log(2) + // 46.969 = 32.5/log(2) + __half2 m = __hfma2(x, make_half2(1.442), make_half2(46.9375)); + + // Transmute to int, shift higher mantissa bits into exponent field. + y = transmute<__half2>((transmute(m) & 0x03ff03ff) << 5); + } else { + // Add a large number to round to an integer + __half2 v = __hfma2(x, make_half2(1.442), make_half2(1231.0)); + + // The exponent is now in the lower 5 bits. Shift that into the exponent field. + __half2 exp = transmute<__half2>((transmute(v) & 0x001f001f) << 10); + + // The fractional part can be obtained from "1231-v". + // 0.6934 = log(2) + __half2 frac = __hfma2(make_half2(1231.0) - v, make_half2(0.6934), x); + + // This is the Taylor expansion of "exp(x)-1" around 0 + __half2 adjust; + if (Deg == 1) { + adjust = frac; + } else if (Deg == 2) { + // adjust = frac + 0.5 * frac^2 + adjust = __hfma2(frac, __hmul2(frac, make_half2(0.5)), frac); + } else /* if (Deg == 2) */ { + // adjust = frac + 0.5 * frac^2 + 1/6 * frac^3 + adjust = __hfma2( + frac, + __hmul2(__hfma2(frac, make_half2(0.1666), make_half2(0.5)), frac), + frac); + } + + // result = exp * (adjust + 1) + y = __hfma2(exp, adjust, exp); + } + + // Values below -10.39 (= -15*log(2)) become zero + uint32_t zero_mask = half2_gt_mask(x, make_half2(-10.390625)); + return transmute<__half2>(zero_mask & transmute(y)); +} + +template +KERNEL_FLOAT_DEVICE __half2 log(__half2 arg) { + // Shift exponent field into mantissa bits. Fill exponent bits with 0x5000 (= 32.0) + uint32_t bits = bitwise_if_else(0x03ff03ff, transmute(arg) >> 5, 0x50005000); + + // 0.6934 = log(2) + // 32.53 = 46.969*log(2) + return __hfma2(transmute<__half2>(bits), make_half2(0.6934), make_half2(-32.53125)); +} + +template +KERNEL_FLOAT_DEVICE __half2 tanh(__half2 x) { + if (Deg == 0) { + return x * rcp<0>(make_half2(0.2869) + __habs2(x)); + } else { + auto c0 = make_half2(0.4531); + auto c1 = make_half2(0.5156); + auto x2b = __hfma2(x, x, c1); + return (x * x2b) * rcp(__hfma2(x2b, __habs2(x), c0)); + } +} + +#endif // KERNEL_FLOAT_FP16_AVAILABLE + +#if KERNEL_FLOAT_BF16_OPS_SUPPORTED +KERNEL_FLOAT_DEVICE __bfloat162 make_bfloat162(__bfloat16 x) { + return {x, x}; +} + +KERNEL_FLOAT_DEVICE __bfloat162 make_bfloat162(double x) { + return {__double2bfloat16(x), __double2bfloat16(x)}; +} + +KERNEL_FLOAT_DEVICE __bfloat162 normalize_trig_input(__nv_bfloat162 x) { + static constexpr double ONE_OVER_TWOPI = 0.15915494309189535; + static constexpr double OFFSET = -2042.0; + + __bfloat162 ws = __hadd2( + __hfma2(x, make_bfloat162(-ONE_OVER_TWOPI), make_bfloat162(-OFFSET)), + make_bfloat162(OFFSET)); + return __hfma2(x, make_bfloat162(ONE_OVER_TWOPI), ws); +} + +template +KERNEL_FLOAT_DEVICE __bfloat162 cos(__bfloat162 x) { + __bfloat162 xf = normalize_trig_input(x); + return cos_poly<__bfloat16, Iter + 1>::call(__hmul2(xf, xf)); +} + +template +KERNEL_FLOAT_DEVICE __bfloat162 sin(__bfloat162 x) { + __bfloat162 xf = normalize_trig_input(x); + return __hmul2(sin_poly<__bfloat16, Iter>::call(__hmul2(xf, xf)), xf); +} + +template +KERNEL_FLOAT_DEVICE __bfloat162 rcp(__bfloat162 x) { + __bfloat162 y = transmute<__bfloat162>(uint32_t(0x7ef07ef0) + ~transmute(x)); + +#pragma unroll + for (int i = 0; i < Iter; i++) { + y = __hfma2(y, __hfma2(__hneg2(x), y, make_bfloat162(1.0)), y); + } + + return y; +} + +template +KERNEL_FLOAT_DEVICE __bfloat162 rsqrt(__bfloat162 x) { + // Set top and bottom bits for both halfs, then shift by 1, then invert + uint32_t r = ~((uint32_t(transmute(x) >> 1)) | ~uint32_t(0x3fff3fff)); + + // Add bias (0x1f36) + __bfloat162 y = transmute<__bfloat162>(uint32_t(r) + uint32_t(0x1f361f36)); + + // Newton-Raphson iterations +#pragma unroll + for (int i = 0; i < Iter; i++) { + __bfloat162 half_x = __hmul2(make_bfloat162(-0.5), x); + __bfloat162 correction = __hfma2(half_x, __hmul2(y, y), make_bfloat162(0.5)); + y = __hfma2(correction, y, y); // y += y * correction + } + + return y; +} + +template +KERNEL_FLOAT_DEVICE __bfloat162 sqrt(__bfloat162 x) { + return __hmul2(x, rsqrt(x)); +} + +template +KERNEL_FLOAT_DEVICE __bfloat162 exp(__bfloat162 arg) { + static constexpr float SCALE = 1.44272065994f / 256.0f; + static constexpr float OFFSET = 382.4958400542335; + + auto a = fmaf(__bfloat162float(arg.x), SCALE, OFFSET); + auto b = fmaf(__bfloat162float(arg.y), SCALE, OFFSET); + + return { + transmute<__bfloat16>(uint16_t(transmute(a))), + transmute<__bfloat16>(uint16_t(transmute(b)))}; +} +#endif +} // namespace approx + +#define KERNEL_FLOAT_DEFINE_APPROX_FUN(FULL_NAME, FUN, DEG) \ + namespace detail { \ + template \ + struct apply_approx_impl, 2, __half, __half> { \ + KERNEL_FLOAT_INLINE static void \ + call(ops::FUN<__half> fun, __half* output, const __half* input) { \ + __half2 res = approx::FUN(__half2 {input[0], input[1]}); \ + output[0] = res.x; \ + output[1] = res.y; \ + } \ + }; \ + template<> \ + struct apply_approx_impl<-1, ops::FUN<__half>, 2, __half, __half>: \ + apply_approx_impl, 2, __half, __half> {}; \ + } \ + \ + template \ + KERNEL_FLOAT_INLINE into_vector_type approx_##FUN(const V& args) { \ + return map>(ops::FUN> {}, args); \ + } + +KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_sin, sin, 4) +KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_cos, cos, 4) +KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_rsqrt, rsqrt, 1) +KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_sqrt, sqrt, 1) +KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_rcp, rcp, 1) +KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_exp, exp, 0) +KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_log, log, 0) + +} // namespace kernel_float #ifndef KERNEL_FLOAT_FP8_H #define KERNEL_FLOAT_FP8_H From ae0e6b16ac2d626e69bb08554044a77671f408ab Mon Sep 17 00:00:00 2001 From: stijn Date: Mon, 18 Nov 2024 13:01:16 +0100 Subject: [PATCH 04/25] Simplify how policies are implemented internally --- docs/guides/accuracy.md | 21 ++- include/kernel_float/apply.h | 93 +++++++----- include/kernel_float/approx.h | 38 ++--- include/kernel_float/bf16.h | 53 ++++--- include/kernel_float/binops.h | 18 ++- include/kernel_float/conversion.h | 7 +- include/kernel_float/fp16.h | 6 +- include/kernel_float/reduce.h | 8 +- include/kernel_float/triops.h | 12 +- include/kernel_float/unops.h | 4 +- single_include/kernel_float.h | 243 +++++++++++++++++------------- 11 files changed, 295 insertions(+), 208 deletions(-) diff --git a/docs/guides/accuracy.md b/docs/guides/accuracy.md index da7bb53..fd84588 100644 --- a/docs/guides/accuracy.md +++ b/docs/guides/accuracy.md @@ -25,13 +25,13 @@ kf::vec c = kf::fast_rcp(x); kf::vec d = kf::fast_div(a, b); ``` -These functions are only functional for 32-bit and 16-bit floats. +These functions are only functional for 32-bit and 16-bit floats. For other input types, the operation falls back to the regular version. ## Approximate Math -For 16-bit floats, several approximate functions are provided. -These use approximations (typically low-degree polynomials) to calculate rough estimates of the functions. +For 16-bit floats, several approximate functions are provided. +These use approximations (typically low-degree polynomials) to calculate rough estimates of the functions. This can be very fast but also less accurate. @@ -69,14 +69,15 @@ kf::vec a = kf::approx_sin<3>(x); ## Tuning Accuracy Level -Many functions in Kernel Float accept an additional Accuracy option as a template parameter. +Many functions in Kernel Float accept an additional `Accuracy` option as a template parameter. This allows you to tune the accuracy level without changing the function name. -There are four possible values for this parameter: +There are five possible values for this parameter: - `kf::accurate_policy`: Use the most accurate version of the function available. - `kf::fast_policy`: Use the "fast math" version. -- `kf::approx_policy`: Use the approximate version with degree `N`. +- `kf::approx_level_policy`: Use the approximate version with accuracy level `N` (higher is more accurate). +- `kf::approx_policy`: Use the approximate version with a default accuracy level. - `kf::default_policy`: Use a global default policy (see the next section). For example, consider this code: @@ -97,15 +98,19 @@ kf::vec c = kf::cos(input); kf::vec d = kf::cos(input); // Use the approximate policy -kf::vec e = kf::cos>(input); +kf::vec e = kf::cos(input); + +// Use the approximate policy with degree 3 polynomial. +kf::vec f = kf::cos>(input); // You can use aliases to define your own policy using my_own_policy = kf::fast_policy; -kf::vec f = kf::cos(input); +kf::vec g = kf::cos(input); ``` ## Setting `default_policy` +If no policy is explicitly set, any function use the `kf::default_policy`. By default, `kf::default_policy` is set to `kf::accurate_policy`. Set the preprocessor option `KERNEL_FLOAT_FAST_MATH=1` to change the default policy to `kf::fast_policy`. diff --git a/include/kernel_float/apply.h b/include/kernel_float/apply.h index 1421132..7d9a96e 100644 --- a/include/kernel_float/apply.h +++ b/include/kernel_float/apply.h @@ -116,10 +116,49 @@ broadcast_like(const V& input, const R& other) { return broadcast(input, vector_extent_type {}); } +/** + * The accurate_policy is designed for computations where maximum accuracy is essential. This policy ensures that all + * operations are performed without any approximations or optimizations that could potentially alter the precise + * outcome of the computations + */ +struct accurate_policy {}; + +/** + * The fast_policy is intended for scenarios where performance and execution speed are more critical than achieving + * the utmost accuracy. This policy leverages optimizations to accelerate computations, which may involve + * approximations that slightly compromise precision. + */ +struct fast_policy {}; + +/** + * This template policy allows developers to specify a custom degree of approximation for their computations. By + * adjusting the `Level` parameter, you can fine-tune the balance between accuracy and performance to meet the + * specific needs of your application. Higher values mean more precision. + */ +template +struct approx_level_policy {}; + +/** + * The approximate_policy serves as the default approximation policy, providing a standard level of approximation + * without requiring explicit configuration. It balances accuracy and performance, making it suitable for + * general-purpose use cases where neither extreme precision nor maximum speed is necessary. + */ +using approx_policy = approx_level_policy<>; + +#ifndef KERNEL_FLOAT_POLICY +#define KERNEL_FLOAT_POLICY accurate_policy; +#endif + +/** + * The `default_policy` acts as the standard computation policy. It can be configured externally using the + * `KERNEL_FLOAT_POLICY` macro. If `KERNEL_FLOAT_POLICY` is not defined, it defaults to `accurate_policy`. + */ +using default_policy = KERNEL_FLOAT_POLICY; + namespace detail { -template -struct apply_impl { +template +struct apply_base_impl { KERNEL_FLOAT_INLINE static void call(F fun, Output* output, const Args*... args) { #pragma unroll for (size_t i = 0; i < N; i++) { @@ -128,41 +167,23 @@ struct apply_impl { } }; -template -struct apply_fastmath_impl: apply_impl {}; - -template -struct apply_approx_impl: apply_fastmath_impl {}; -} // namespace detail - -struct accurate_policy { - template - using type = detail::apply_impl; -}; - -struct fast_policy { - template - using type = detail::apply_fastmath_impl; -}; - -template -struct approximate_policy { - template - using type = detail::apply_approx_impl; -}; +template +struct apply_impl: apply_base_impl {}; -using default_approximate_policy = approximate_policy<>; +template +struct apply_base_impl: + apply_impl {}; -#ifdef KERNEL_FLOAT_POLICY -using default_policy = KERNEL_FLOAT_POLICY; -#else -using default_policy = accurate_policy; -#endif +template +struct apply_base_impl: + apply_impl {}; -namespace detail { +template +struct apply_base_impl, F, N, Output, Args...>: + apply_impl {}; template -struct map_policy_impl { +struct map_impl { static constexpr size_t packet_size = preferred_vector_size::value; static constexpr size_t remainder = N % packet_size; @@ -170,7 +191,7 @@ struct map_policy_impl { if constexpr (N / packet_size > 0) { #pragma unroll for (size_t i = 0; i < N - remainder; i += packet_size) { - Policy::template type::call( + apply_impl::call( fun, output + i, (args + i)...); @@ -180,14 +201,14 @@ struct map_policy_impl { if constexpr (remainder > 0) { #pragma unroll for (size_t i = N - remainder; i < N; i++) { - Policy::template type::call(fun, output + i, (args + i)...); + apply_impl::call(fun, output + i, (args + i)...); } } } }; template -using map_impl = map_policy_impl; +using default_map_impl = map_impl; } // namespace detail @@ -211,7 +232,7 @@ KERNEL_FLOAT_INLINE map_type map(F fun, const Args&... args) { using E = broadcast_vector_extent_type; vector_storage> result; - detail::map_policy_impl, Output, vector_value_type...>::call( + detail::map_impl, Output, vector_value_type...>::call( fun, result.data(), (detail::broadcast_impl, vector_extent_type, E>::call( diff --git a/include/kernel_float/approx.h b/include/kernel_float/approx.h index 6725faf..945df81 100644 --- a/include/kernel_float/approx.h +++ b/include/kernel_float/approx.h @@ -359,25 +359,25 @@ KERNEL_FLOAT_DEVICE __bfloat162 exp(__bfloat162 arg) { #endif } // namespace approx -#define KERNEL_FLOAT_DEFINE_APPROX_FUN(FULL_NAME, FUN, DEG) \ - namespace detail { \ - template \ - struct apply_approx_impl, 2, __half, __half> { \ - KERNEL_FLOAT_INLINE static void \ - call(ops::FUN<__half> fun, __half* output, const __half* input) { \ - __half2 res = approx::FUN(__half2 {input[0], input[1]}); \ - output[0] = res.x; \ - output[1] = res.y; \ - } \ - }; \ - template<> \ - struct apply_approx_impl<-1, ops::FUN<__half>, 2, __half, __half>: \ - apply_approx_impl, 2, __half, __half> {}; \ - } \ - \ - template \ - KERNEL_FLOAT_INLINE into_vector_type approx_##FUN(const V& args) { \ - return map>(ops::FUN> {}, args); \ +#define KERNEL_FLOAT_DEFINE_APPROX_FUN(FULL_NAME, FUN, DEG) \ + namespace detail { \ + template \ + struct apply_impl, ops::FUN<__half>, 2, __half, __half> { \ + KERNEL_FLOAT_INLINE static void \ + call(ops::FUN<__half> fun, __half* output, const __half* input) { \ + __half2 res = approx::FUN(__half2 {input[0], input[1]}); \ + output[0] = res.x; \ + output[1] = res.y; \ + } \ + }; \ + template<> \ + struct apply_impl, 2, __half, __half>: \ + apply_impl, ops::FUN<__half>, 2, __half, __half> {}; \ + } \ + \ + template \ + KERNEL_FLOAT_INLINE into_vector_type approx_##FUN(const V& args) { \ + return map>(ops::FUN> {}, args); \ } KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_sin, sin, 4) diff --git a/include/kernel_float/bf16.h b/include/kernel_float/bf16.h index 6784e89..0d656aa 100644 --- a/include/kernel_float/bf16.h +++ b/include/kernel_float/bf16.h @@ -61,24 +61,24 @@ struct allow_float_fallback<__bfloat16> { }; // namespace detail #if KERNEL_FLOAT_BF16_OPS_SUPPORTED -#define KERNEL_FLOAT_BF16_UNARY_FUN(NAME, FUN1, FUN2) \ - namespace ops { \ - template<> \ - struct NAME<__bfloat16> { \ - KERNEL_FLOAT_INLINE __bfloat16 operator()(__bfloat16 input) { \ - return FUN1(input); \ - } \ - }; \ - } \ - namespace detail { \ - template<> \ - struct apply_impl, 2, __bfloat16, __bfloat16> { \ - KERNEL_FLOAT_INLINE static void \ - call(ops::NAME<__bfloat16>, __bfloat16* result, const __bfloat16* a) { \ - __bfloat162 r = FUN2(__bfloat162 {a[0], a[1]}); \ - result[0] = r.x, result[1] = r.y; \ - } \ - }; \ +#define KERNEL_FLOAT_BF16_UNARY_FUN(NAME, FUN1, FUN2) \ + namespace ops { \ + template<> \ + struct NAME<__bfloat16> { \ + KERNEL_FLOAT_INLINE __bfloat16 operator()(__bfloat16 input) { \ + return FUN1(input); \ + } \ + }; \ + } \ + namespace detail { \ + template<> \ + struct apply_impl, 2, __bfloat16, __bfloat16> { \ + KERNEL_FLOAT_INLINE static void \ + call(ops::NAME<__bfloat16>, __bfloat16* result, const __bfloat16* a) { \ + __bfloat162 r = FUN2(__bfloat162 {a[0], a[1]}); \ + result[0] = r.x, result[1] = r.y; \ + } \ + }; \ } KERNEL_FLOAT_BF16_UNARY_FUN(sin, ::hsin, ::h2sin) @@ -115,7 +115,13 @@ KERNEL_FLOAT_BF16_UNARY_FUN(negate, ::__hneg, ::__hneg2) } \ namespace detail { \ template<> \ - struct apply_impl, 2, __bfloat16, __bfloat16, __bfloat16> { \ + struct apply_impl< \ + accurate_policy, \ + ops::NAME<__bfloat16>, \ + 2, \ + __bfloat16, \ + __bfloat16, \ + __bfloat16> { \ KERNEL_FLOAT_INLINE static void call( \ ops::NAME<__bfloat16>, \ __bfloat16* result, \ @@ -154,7 +160,14 @@ struct fma<__bfloat16> { namespace detail { template<> -struct apply_impl, 2, __bfloat16, __bfloat16, __bfloat16, __bfloat16> { +struct apply_impl< + accurate_policy, + ops::fma<__bfloat16>, + 2, + __bfloat16, + __bfloat16, + __bfloat16, + __bfloat16> { KERNEL_FLOAT_INLINE static void call( ops::fma<__bfloat16>, __bfloat16* result, diff --git a/include/kernel_float/binops.h b/include/kernel_float/binops.h index a19ca64..7c9ec2d 100644 --- a/include/kernel_float/binops.h +++ b/include/kernel_float/binops.h @@ -52,7 +52,7 @@ KERNEL_FLOAT_INLINE zip_common_type zip_common(F fun, const L& left, co vector_storage> result; - detail::map_impl, O, T, T>::call( + detail::default_map_impl, O, T, T>::call( fun, result.data(), detail::convert_impl, vector_extent_type, T, E>::call( @@ -290,21 +290,25 @@ struct multiply { }; // namespace ops namespace detail { -template -struct apply_fastmath_impl, N, T, T, T> { +template +struct apply_impl, N, T, T, T> { KERNEL_FLOAT_INLINE static void call(ops::divide fun, T* result, const T* lhs, const T* rhs) { T rhs_rcp[N]; // Fast way to perform division is to multiply by the reciprocal - apply_fastmath_impl, N, T, T>::call({}, rhs_rcp, rhs); - apply_fastmath_impl, N, T, T, T>::call({}, result, lhs, rhs_rcp); + apply_impl, N, T, T>::call({}, rhs_rcp, rhs); + apply_impl, N, T, T, T>::call({}, result, lhs, rhs_rcp); } }; +template +struct apply_impl, N, T, T, T>: + apply_base_impl, N, T, T, T> {}; + #if KERNEL_FLOAT_IS_DEVICE template<> -struct apply_fastmath_impl, 1, float, float, float> { +struct apply_impl, 1, float, float, float> { KERNEL_FLOAT_INLINE static void call(ops::divide fun, float* result, const float* lhs, const float* rhs) { *result = __fdividef(*lhs, *rhs); @@ -319,7 +323,7 @@ fast_divide(const L& left, const R& right) { using E = broadcast_vector_extent_type; vector_storage> result; - detail::map_policy_impl, extent_size, T, T, T>::call( + detail::map_impl, extent_size, T, T, T>::call( ops::divide {}, result.data(), detail::convert_impl, vector_extent_type, T, E>::call( diff --git a/include/kernel_float/conversion.h b/include/kernel_float/conversion.h index 8538be7..8e84cdb 100644 --- a/include/kernel_float/conversion.h +++ b/include/kernel_float/conversion.h @@ -17,7 +17,10 @@ struct convert_impl { static vector_storage> call(vector_storage> input) { using F = ops::cast; vector_storage> intermediate; - detail::map_impl, T2, T>::call(F {}, intermediate.data(), input.data()); + detail::default_map_impl, T2, T>::call( + F {}, + intermediate.data(), + input.data()); return detail::broadcast_impl::call(intermediate); } }; @@ -48,7 +51,7 @@ struct convert_impl { using F = ops::cast; vector_storage> result; - detail::map_impl, T2, T>::call(F {}, result.data(), input.data()); + detail::default_map_impl, T2, T>::call(F {}, result.data(), input.data()); return result; } }; diff --git a/include/kernel_float/fp16.h b/include/kernel_float/fp16.h index 8d94c51..eccd457 100644 --- a/include/kernel_float/fp16.h +++ b/include/kernel_float/fp16.h @@ -58,7 +58,7 @@ struct allow_float_fallback<__half> { } \ namespace detail { \ template<> \ - struct apply_impl, 2, __half, __half> { \ + struct apply_impl, 2, __half, __half> { \ KERNEL_FLOAT_INLINE static void call(ops::NAME<__half>, __half* result, const __half* a) { \ __half2 r = FUN2(__half2 {a[0], a[1]}); \ result[0] = r.x, result[1] = r.y; \ @@ -102,7 +102,7 @@ KERNEL_FLOAT_FP16_UNARY_FUN(negate, __hneg, __hneg2) } \ namespace detail { \ template<> \ - struct apply_impl, 2, __half, __half, __half> { \ + struct apply_impl, 2, __half, __half, __half> { \ KERNEL_FLOAT_INLINE static void \ call(ops::NAME<__half>, __half* result, const __half* a, const __half* b) { \ __half2 r = FUN2(__half2 {a[0], a[1]}, __half2 {b[0], b[1]}); \ @@ -144,7 +144,7 @@ struct fma<__half> { namespace detail { template<> -struct apply_impl, 2, __half, __half, __half, __half> { +struct apply_impl, 2, __half, __half, __half, __half> { KERNEL_FLOAT_INLINE static void call(ops::fma<__half>, __half* result, const __half* a, const __half* b, const __half* c) { __half2 r = __hfma2(__half2 {a[0], a[1]}, __half2 {b[0], b[1]}, __half2 {c[0], c[1]}); diff --git a/include/kernel_float/reduce.h b/include/kernel_float/reduce.h index e616f17..853df3a 100644 --- a/include/kernel_float/reduce.h +++ b/include/kernel_float/reduce.h @@ -24,7 +24,7 @@ struct reduce_recur_impl { template KERNEL_FLOAT_INLINE static T call(F fun, const T* input) { vector_storage temp; - map_impl::call(fun, temp.data(), input, input + K); + default_map_impl::call(fun, temp.data(), input, input + K); if constexpr (N < 2 * K) { #pragma unroll @@ -183,11 +183,11 @@ struct dot_impl { if constexpr (N / K > 0) { T accum[K] = {T {}}; - apply_impl, K, T, T, T>::call({}, accum, left, right); + apply_impl, K, T, T, T>::call({}, accum, left, right); #pragma unroll for (size_t i = 1; i < N / K; i++) { - apply_impl, K, T, T, T, T>::call( + apply_impl, K, T, T, T, T>::call( ops::fma {}, accum, left + i * K, @@ -200,7 +200,7 @@ struct dot_impl { if constexpr (N % K > 0) { for (size_t i = N - N % K; i < N; i++) { - apply_impl, 1, T, T, T, T>::call( + apply_impl, 1, T, T, T, T>::call( {}, &result, left + i, diff --git a/include/kernel_float/triops.h b/include/kernel_float/triops.h index 12cca59..82d3a89 100644 --- a/include/kernel_float/triops.h +++ b/include/kernel_float/triops.h @@ -41,7 +41,7 @@ KERNEL_FLOAT_INLINE vector where(const C& cond, const L& true_values, cons using F = ops::conditional; vector_storage> result; - detail::map_impl, T, bool, T, T>::call( + detail::default_map_impl, T, bool, T, T>::call( F {}, result.data(), detail::convert_impl, vector_extent_type, bool, E>::call( @@ -98,13 +98,13 @@ struct fma { } // namespace ops namespace detail { -template -struct apply_impl, N, T, T, T, T> { +template +struct apply_impl, N, T, T, T, T> { KERNEL_FLOAT_INLINE static void call(ops::fma, T* output, const T* a, const T* b, const T* c) { T temp[N]; - apply_impl, N, T, T, T>::call({}, temp, a, b); - apply_impl, N, T, T, T>::call({}, output, temp, c); + apply_impl, N, T, T, T>::call({}, temp, a, b); + apply_impl, N, T, T, T>::call({}, output, temp, c); } }; } // namespace detail @@ -140,7 +140,7 @@ KERNEL_FLOAT_INLINE vector fma(const A& a, const B& b, const C& c) { using F = ops::fma; vector_storage> result; - detail::map_impl, T, T, T, T>::call( + detail::default_map_impl, T, T, T, T>::call( F {}, result.data(), detail::convert_impl, vector_extent_type, T, E>::call( diff --git a/include/kernel_float/unops.h b/include/kernel_float/unops.h index 739f795..d288a9c 100644 --- a/include/kernel_float/unops.h +++ b/include/kernel_float/unops.h @@ -225,7 +225,7 @@ KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(rsqrt) #define KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(T, F, EXPR_F32) \ namespace detail { \ template<> \ - struct apply_fastmath_impl, 1, T, T> { \ + struct apply_impl, 1, T, T> { \ KERNEL_FLOAT_INLINE static void call(ops::F, T* result, const T* inputs) { \ T input = inputs[0]; \ *result = EXPR_F32; \ @@ -248,7 +248,7 @@ KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, tan, __tanf(input)) #define KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(T, F, INSTR, REG) \ namespace detail { \ template<> \ - struct apply_fastmath_impl, 1, T, T> { \ + struct apply_impl, 1, T, T> { \ KERNEL_FLOAT_INLINE static void call(ops::F fun, T* result, const T* inputs) { \ asm(INSTR " %0, %1;" : "=" REG(*result) : REG(*inputs)); \ } \ diff --git a/single_include/kernel_float.h b/single_include/kernel_float.h index 7dc3fed..e1ff9a0 100644 --- a/single_include/kernel_float.h +++ b/single_include/kernel_float.h @@ -16,8 +16,8 @@ //================================================================================ // this file has been auto-generated, do not modify its contents! -// date: 2024-11-18 12:11:06.609851 -// git hash: de62ad0ced81f2d5129b31bfb621fcbc0ce161e9 +// date: 2024-11-18 13:09:09.880139 +// git hash: 5bd24b5f693cf68213714f0e9fd2d471abe04e66 //================================================================================ #ifndef KERNEL_FLOAT_MACROS_H @@ -783,10 +783,49 @@ broadcast_like(const V& input, const R& other) { return broadcast(input, vector_extent_type {}); } +/** + * The accurate_policy is designed for computations where maximum accuracy is essential. This policy ensures that all + * operations are performed without any approximations or optimizations that could potentially alter the precise + * outcome of the computations + */ +struct accurate_policy {}; + +/** + * The fast_policy is intended for scenarios where performance and execution speed are more critical than achieving + * the utmost accuracy. This policy leverages optimizations to accelerate computations, which may involve + * approximations that slightly compromise precision. + */ +struct fast_policy {}; + +/** + * This template policy allows developers to specify a custom degree of approximation for their computations. By + * adjusting the `Level` parameter, you can fine-tune the balance between accuracy and performance to meet the + * specific needs of your application. Higher values mean more precision. + */ +template +struct approx_level_policy {}; + +/** + * The approximate_policy serves as the default approximation policy, providing a standard level of approximation + * without requiring explicit configuration. It balances accuracy and performance, making it suitable for + * general-purpose use cases where neither extreme precision nor maximum speed is necessary. + */ +using approx_policy = approx_level_policy<>; + +#ifndef KERNEL_FLOAT_POLICY +#define KERNEL_FLOAT_POLICY accurate_policy; +#endif + +/** + * The `default_policy` acts as the standard computation policy. It can be configured externally using the + * `KERNEL_FLOAT_POLICY` macro. If `KERNEL_FLOAT_POLICY` is not defined, it defaults to `accurate_policy`. + */ +using default_policy = KERNEL_FLOAT_POLICY; + namespace detail { -template -struct apply_impl { +template +struct apply_base_impl { KERNEL_FLOAT_INLINE static void call(F fun, Output* output, const Args*... args) { #pragma unroll for (size_t i = 0; i < N; i++) { @@ -795,41 +834,23 @@ struct apply_impl { } }; -template -struct apply_fastmath_impl: apply_impl {}; - -template -struct apply_approx_impl: apply_fastmath_impl {}; -} // namespace detail - -struct accurate_policy { - template - using type = detail::apply_impl; -}; - -struct fast_policy { - template - using type = detail::apply_fastmath_impl; -}; - -template -struct approximate_policy { - template - using type = detail::apply_approx_impl; -}; +template +struct apply_impl: apply_base_impl {}; -using default_approximate_policy = approximate_policy<>; +template +struct apply_base_impl: + apply_impl {}; -#ifdef KERNEL_FLOAT_POLICY -using default_policy = KERNEL_FLOAT_POLICY; -#else -using default_policy = accurate_policy; -#endif +template +struct apply_base_impl: + apply_impl {}; -namespace detail { +template +struct apply_base_impl, F, N, Output, Args...>: + apply_impl {}; template -struct map_policy_impl { +struct map_impl { static constexpr size_t packet_size = preferred_vector_size::value; static constexpr size_t remainder = N % packet_size; @@ -837,7 +858,7 @@ struct map_policy_impl { if constexpr (N / packet_size > 0) { #pragma unroll for (size_t i = 0; i < N - remainder; i += packet_size) { - Policy::template type::call( + apply_impl::call( fun, output + i, (args + i)...); @@ -847,14 +868,14 @@ struct map_policy_impl { if constexpr (remainder > 0) { #pragma unroll for (size_t i = N - remainder; i < N; i++) { - Policy::template type::call(fun, output + i, (args + i)...); + apply_impl::call(fun, output + i, (args + i)...); } } } }; template -using map_impl = map_policy_impl; +using default_map_impl = map_impl; } // namespace detail @@ -878,7 +899,7 @@ KERNEL_FLOAT_INLINE map_type map(F fun, const Args&... args) { using E = broadcast_vector_extent_type; vector_storage> result; - detail::map_policy_impl, Output, vector_value_type...>::call( + detail::map_impl, Output, vector_value_type...>::call( fun, result.data(), (detail::broadcast_impl, vector_extent_type, E>::call( @@ -1367,7 +1388,7 @@ KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(rsqrt) #define KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(T, F, EXPR_F32) \ namespace detail { \ template<> \ - struct apply_fastmath_impl, 1, T, T> { \ + struct apply_impl, 1, T, T> { \ KERNEL_FLOAT_INLINE static void call(ops::F, T* result, const T* inputs) { \ T input = inputs[0]; \ *result = EXPR_F32; \ @@ -1390,7 +1411,7 @@ KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, tan, __tanf(input)) #define KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(T, F, INSTR, REG) \ namespace detail { \ template<> \ - struct apply_fastmath_impl, 1, T, T> { \ + struct apply_impl, 1, T, T> { \ KERNEL_FLOAT_INLINE static void call(ops::F fun, T* result, const T* inputs) { \ asm(INSTR " %0, %1;" : "=" REG(*result) : REG(*inputs)); \ } \ @@ -1434,7 +1455,10 @@ struct convert_impl { static vector_storage> call(vector_storage> input) { using F = ops::cast; vector_storage> intermediate; - detail::map_impl, T2, T>::call(F {}, intermediate.data(), input.data()); + detail::default_map_impl, T2, T>::call( + F {}, + intermediate.data(), + input.data()); return detail::broadcast_impl::call(intermediate); } }; @@ -1465,7 +1489,7 @@ struct convert_impl { using F = ops::cast; vector_storage> result; - detail::map_impl, T2, T>::call(F {}, result.data(), input.data()); + detail::default_map_impl, T2, T>::call(F {}, result.data(), input.data()); return result; } }; @@ -1686,7 +1710,7 @@ KERNEL_FLOAT_INLINE zip_common_type zip_common(F fun, const L& left, co vector_storage> result; - detail::map_impl, O, T, T>::call( + detail::default_map_impl, O, T, T>::call( fun, result.data(), detail::convert_impl, vector_extent_type, T, E>::call( @@ -1924,21 +1948,25 @@ struct multiply { }; // namespace ops namespace detail { -template -struct apply_fastmath_impl, N, T, T, T> { +template +struct apply_impl, N, T, T, T> { KERNEL_FLOAT_INLINE static void call(ops::divide fun, T* result, const T* lhs, const T* rhs) { T rhs_rcp[N]; // Fast way to perform division is to multiply by the reciprocal - apply_fastmath_impl, N, T, T>::call({}, rhs_rcp, rhs); - apply_fastmath_impl, N, T, T, T>::call({}, result, lhs, rhs_rcp); + apply_impl, N, T, T>::call({}, rhs_rcp, rhs); + apply_impl, N, T, T, T>::call({}, result, lhs, rhs_rcp); } }; +template +struct apply_impl, N, T, T, T>: + apply_base_impl, N, T, T, T> {}; + #if KERNEL_FLOAT_IS_DEVICE template<> -struct apply_fastmath_impl, 1, float, float, float> { +struct apply_impl, 1, float, float, float> { KERNEL_FLOAT_INLINE static void call(ops::divide fun, float* result, const float* lhs, const float* rhs) { *result = __fdividef(*lhs, *rhs); @@ -1953,7 +1981,7 @@ fast_divide(const L& left, const R& right) { using E = broadcast_vector_extent_type; vector_storage> result; - detail::map_policy_impl, extent_size, T, T, T>::call( + detail::map_impl, extent_size, T, T, T>::call( ops::divide {}, result.data(), detail::convert_impl, vector_extent_type, T, E>::call( @@ -3114,7 +3142,7 @@ KERNEL_FLOAT_INLINE vector where(const C& cond, const L& true_values, cons using F = ops::conditional; vector_storage> result; - detail::map_impl, T, bool, T, T>::call( + detail::default_map_impl, T, bool, T, T>::call( F {}, result.data(), detail::convert_impl, vector_extent_type, bool, E>::call( @@ -3171,13 +3199,13 @@ struct fma { } // namespace ops namespace detail { -template -struct apply_impl, N, T, T, T, T> { +template +struct apply_impl, N, T, T, T, T> { KERNEL_FLOAT_INLINE static void call(ops::fma, T* output, const T* a, const T* b, const T* c) { T temp[N]; - apply_impl, N, T, T, T>::call({}, temp, a, b); - apply_impl, N, T, T, T>::call({}, output, temp, c); + apply_impl, N, T, T, T>::call({}, temp, a, b); + apply_impl, N, T, T, T>::call({}, output, temp, c); } }; } // namespace detail @@ -3213,7 +3241,7 @@ KERNEL_FLOAT_INLINE vector fma(const A& a, const B& b, const C& c) { using F = ops::fma; vector_storage> result; - detail::map_impl, T, T, T, T>::call( + detail::default_map_impl, T, T, T, T>::call( F {}, result.data(), detail::convert_impl, vector_extent_type, T, E>::call( @@ -3258,7 +3286,7 @@ struct reduce_recur_impl { template KERNEL_FLOAT_INLINE static T call(F fun, const T* input) { vector_storage temp; - map_impl::call(fun, temp.data(), input, input + K); + default_map_impl::call(fun, temp.data(), input, input + K); if constexpr (N < 2 * K) { #pragma unroll @@ -3417,11 +3445,11 @@ struct dot_impl { if constexpr (N / K > 0) { T accum[K] = {T {}}; - apply_impl, K, T, T, T>::call({}, accum, left, right); + apply_impl, K, T, T, T>::call({}, accum, left, right); #pragma unroll for (size_t i = 1; i < N / K; i++) { - apply_impl, K, T, T, T, T>::call( + apply_impl, K, T, T, T, T>::call( ops::fma {}, accum, left + i * K, @@ -3434,7 +3462,7 @@ struct dot_impl { if constexpr (N % K > 0) { for (size_t i = N - N % K; i < N; i++) { - apply_impl, 1, T, T, T, T>::call( + apply_impl, 1, T, T, T, T>::call( {}, &result, left + i, @@ -3981,7 +4009,7 @@ struct allow_float_fallback<__half> { } \ namespace detail { \ template<> \ - struct apply_impl, 2, __half, __half> { \ + struct apply_impl, 2, __half, __half> { \ KERNEL_FLOAT_INLINE static void call(ops::NAME<__half>, __half* result, const __half* a) { \ __half2 r = FUN2(__half2 {a[0], a[1]}); \ result[0] = r.x, result[1] = r.y; \ @@ -4025,7 +4053,7 @@ KERNEL_FLOAT_FP16_UNARY_FUN(negate, __hneg, __hneg2) } \ namespace detail { \ template<> \ - struct apply_impl, 2, __half, __half, __half> { \ + struct apply_impl, 2, __half, __half, __half> { \ KERNEL_FLOAT_INLINE static void \ call(ops::NAME<__half>, __half* result, const __half* a, const __half* b) { \ __half2 r = FUN2(__half2 {a[0], a[1]}, __half2 {b[0], b[1]}); \ @@ -4067,7 +4095,7 @@ struct fma<__half> { namespace detail { template<> -struct apply_impl, 2, __half, __half, __half, __half> { +struct apply_impl, 2, __half, __half, __half, __half> { KERNEL_FLOAT_INLINE static void call(ops::fma<__half>, __half* result, const __half* a, const __half* b, const __half* c) { __half2 r = __hfma2(__half2 {a[0], a[1]}, __half2 {b[0], b[1]}, __half2 {c[0], c[1]}); @@ -4207,24 +4235,24 @@ struct allow_float_fallback<__bfloat16> { }; // namespace detail #if KERNEL_FLOAT_BF16_OPS_SUPPORTED -#define KERNEL_FLOAT_BF16_UNARY_FUN(NAME, FUN1, FUN2) \ - namespace ops { \ - template<> \ - struct NAME<__bfloat16> { \ - KERNEL_FLOAT_INLINE __bfloat16 operator()(__bfloat16 input) { \ - return FUN1(input); \ - } \ - }; \ - } \ - namespace detail { \ - template<> \ - struct apply_impl, 2, __bfloat16, __bfloat16> { \ - KERNEL_FLOAT_INLINE static void \ - call(ops::NAME<__bfloat16>, __bfloat16* result, const __bfloat16* a) { \ - __bfloat162 r = FUN2(__bfloat162 {a[0], a[1]}); \ - result[0] = r.x, result[1] = r.y; \ - } \ - }; \ +#define KERNEL_FLOAT_BF16_UNARY_FUN(NAME, FUN1, FUN2) \ + namespace ops { \ + template<> \ + struct NAME<__bfloat16> { \ + KERNEL_FLOAT_INLINE __bfloat16 operator()(__bfloat16 input) { \ + return FUN1(input); \ + } \ + }; \ + } \ + namespace detail { \ + template<> \ + struct apply_impl, 2, __bfloat16, __bfloat16> { \ + KERNEL_FLOAT_INLINE static void \ + call(ops::NAME<__bfloat16>, __bfloat16* result, const __bfloat16* a) { \ + __bfloat162 r = FUN2(__bfloat162 {a[0], a[1]}); \ + result[0] = r.x, result[1] = r.y; \ + } \ + }; \ } KERNEL_FLOAT_BF16_UNARY_FUN(sin, ::hsin, ::h2sin) @@ -4261,7 +4289,13 @@ KERNEL_FLOAT_BF16_UNARY_FUN(negate, ::__hneg, ::__hneg2) } \ namespace detail { \ template<> \ - struct apply_impl, 2, __bfloat16, __bfloat16, __bfloat16> { \ + struct apply_impl< \ + accurate_policy, \ + ops::NAME<__bfloat16>, \ + 2, \ + __bfloat16, \ + __bfloat16, \ + __bfloat16> { \ KERNEL_FLOAT_INLINE static void call( \ ops::NAME<__bfloat16>, \ __bfloat16* result, \ @@ -4300,7 +4334,14 @@ struct fma<__bfloat16> { namespace detail { template<> -struct apply_impl, 2, __bfloat16, __bfloat16, __bfloat16, __bfloat16> { +struct apply_impl< + accurate_policy, + ops::fma<__bfloat16>, + 2, + __bfloat16, + __bfloat16, + __bfloat16, + __bfloat16> { KERNEL_FLOAT_INLINE static void call( ops::fma<__bfloat16>, __bfloat16* result, @@ -4753,25 +4794,25 @@ KERNEL_FLOAT_DEVICE __bfloat162 exp(__bfloat162 arg) { #endif } // namespace approx -#define KERNEL_FLOAT_DEFINE_APPROX_FUN(FULL_NAME, FUN, DEG) \ - namespace detail { \ - template \ - struct apply_approx_impl, 2, __half, __half> { \ - KERNEL_FLOAT_INLINE static void \ - call(ops::FUN<__half> fun, __half* output, const __half* input) { \ - __half2 res = approx::FUN(__half2 {input[0], input[1]}); \ - output[0] = res.x; \ - output[1] = res.y; \ - } \ - }; \ - template<> \ - struct apply_approx_impl<-1, ops::FUN<__half>, 2, __half, __half>: \ - apply_approx_impl, 2, __half, __half> {}; \ - } \ - \ - template \ - KERNEL_FLOAT_INLINE into_vector_type approx_##FUN(const V& args) { \ - return map>(ops::FUN> {}, args); \ +#define KERNEL_FLOAT_DEFINE_APPROX_FUN(FULL_NAME, FUN, DEG) \ + namespace detail { \ + template \ + struct apply_impl, ops::FUN<__half>, 2, __half, __half> { \ + KERNEL_FLOAT_INLINE static void \ + call(ops::FUN<__half> fun, __half* output, const __half* input) { \ + __half2 res = approx::FUN(__half2 {input[0], input[1]}); \ + output[0] = res.x; \ + output[1] = res.y; \ + } \ + }; \ + template<> \ + struct apply_impl, 2, __half, __half>: \ + apply_impl, ops::FUN<__half>, 2, __half, __half> {}; \ + } \ + \ + template \ + KERNEL_FLOAT_INLINE into_vector_type approx_##FUN(const V& args) { \ + return map>(ops::FUN> {}, args); \ } KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_sin, sin, 4) From f89cf98f79e78ab6013063dea4b4b516ce163855 Mon Sep 17 00:00:00 2001 From: stijn Date: Mon, 18 Nov 2024 13:40:27 +0100 Subject: [PATCH 05/25] Rename FP16 primitive names: `__half` to `half_t` and `__nv_bfloat16` `to bfloat16_t` --- include/kernel_float/approx.h | 114 ++++++------- include/kernel_float/bf16.h | 107 +++++++------ include/kernel_float/fp16.h | 69 ++++---- single_include/kernel_float.h | 294 +++++++++++++++++----------------- 4 files changed, 298 insertions(+), 286 deletions(-) diff --git a/include/kernel_float/approx.h b/include/kernel_float/approx.h index 945df81..df81d30 100644 --- a/include/kernel_float/approx.h +++ b/include/kernel_float/approx.h @@ -85,7 +85,7 @@ KERNEL_FLOAT_DEFINE_POLY(asin_poly, 4, -0.02103, 0.077, -0.2129, 1.57) KERNEL_FLOAT_DEFINE_POLY(asin_poly, 5, 0.009796, -0.03772, 0.0857, -0.2142, 1.57) #if KERNEL_FLOAT_FP16_AVAILABLE -KERNEL_FLOAT_DEVICE __half2 flipsign(__half2 input, __half2 sign) { +KERNEL_FLOAT_DEVICE half2_t flipsign(half2_t input, half2_t sign) { // Flip signbit of input when sign<0 uint32_t result; @@ -97,10 +97,10 @@ KERNEL_FLOAT_DEVICE __half2 flipsign(__half2 input, __half2 sign) { result = uint32_t(transmute(sign) & 0x80008000) ^ transmute(input); #endif - return transmute<__half2>(result); + return transmute(result); } -KERNEL_FLOAT_DEVICE uint32_t half2_gt_mask(__half2 a, __half2 b) { +KERNEL_FLOAT_DEVICE uint32_t half2_gt_mask(half2_t a, half2_t b) { uint32_t val; #if KERNEL_FLOAT_IS_CUDA uint32_t ai = *(reinterpret_cast(&a)); @@ -112,42 +112,42 @@ KERNEL_FLOAT_DEVICE uint32_t half2_gt_mask(__half2 a, __half2 b) { return val; } -KERNEL_FLOAT_INLINE __half2 make_half2(half x) { +KERNEL_FLOAT_INLINE half2_t make_half2(half x) { return {x, x}; } -KERNEL_FLOAT_DEVICE __half2 normalize_trig_input(__half2 x) { +KERNEL_FLOAT_DEVICE half2_t normalize_trig_input(half2_t x) { /* Using rint is too slow. Round using floating-point magic instead. */ - // __half2 x = arg * make_half2(-0.15915494309); + // half2_t x = arg * make_half2(-0.15915494309); // return __hfma2(arg, make_half2(0.15915494309), h2rint(x)); // 1/(2pi) = 0.15915494309189535 static constexpr double ONE_OVER_TWOPI = 0.15915494309189535; static constexpr double OFFSET = -2042.0; - __half2 ws = __hfma2(x, make_half2(-ONE_OVER_TWOPI), make_half2(-OFFSET)) + make_half2(OFFSET); + half2_t ws = __hfma2(x, make_half2(-ONE_OVER_TWOPI), make_half2(-OFFSET)) + make_half2(OFFSET); return __hfma2(x, make_half2(ONE_OVER_TWOPI), ws); } template -KERNEL_FLOAT_DEVICE __half2 cos(__half2 x) { - __half2 xf = normalize_trig_input(x); +KERNEL_FLOAT_DEVICE half2_t cos(half2_t x) { + half2_t xf = normalize_trig_input(x); return cos_poly::call(__hmul2(xf, xf)); } template -KERNEL_FLOAT_DEVICE __half2 sin(__half2 x) { - __half2 xf = normalize_trig_input(x); +KERNEL_FLOAT_DEVICE half2_t sin(half2_t x) { + half2_t xf = normalize_trig_input(x); return sin_poly::call(__hmul2(xf, xf)) * xf; } template -KERNEL_FLOAT_DEVICE __half2 rcp(__half2 x) { +KERNEL_FLOAT_DEVICE half2_t rcp(half2_t x) { // Flip bits uint32_t m = ~transmute(x); // Multiply by bias (add contant) - __half2 y = transmute<__half2>(uint32_t(0x776d776d) + m); + half2_t y = transmute(uint32_t(0x776d776d) + m); #pragma unroll for (int i = 0; i < Iter; i++) { @@ -159,19 +159,19 @@ KERNEL_FLOAT_DEVICE __half2 rcp(__half2 x) { } template -KERNEL_FLOAT_DEVICE __half2 rsqrt(__half2 x) { +KERNEL_FLOAT_DEVICE half2_t rsqrt(half2_t x) { // Set top and bottom bits for both halfs, then shift by 1, then invert uint32_t r = ~((uint32_t(transmute(x) >> 1)) | ~uint32_t(0x3fff3fff)); //uint32_t r = uint32_t(~(transmute(arg) | (~uint32_t(0x3ffe3ffe)))) >> 1; // Add bias (0x199c) - __half2 y = transmute<__half2>(uint32_t(r) + uint32_t(0x199c199c)); + half2_t y = transmute(uint32_t(r) + uint32_t(0x199c199c)); // Newton-Raphson iterations #pragma unroll for (int i = 0; i < Iter; i++) { - __half2 half_x = make_half2(-0.5) * x; - __half2 correction = __hfma2(half_x, y * y, make_half2(0.5)); + half2_t half_x = make_half2(-0.5) * x; + half2_t correction = __hfma2(half_x, y * y, make_half2(0.5)); y = __hfma2(correction, y, y); // y += y * correction } @@ -179,12 +179,12 @@ KERNEL_FLOAT_DEVICE __half2 rsqrt(__half2 x) { } template -KERNEL_FLOAT_DEVICE __half2 sqrt(__half2 x) { +KERNEL_FLOAT_DEVICE half2_t sqrt(half2_t x) { if (Iter == 1) { - __half2 y = rsqrt<0>(x); + half2_t y = rsqrt<0>(x); // This method uses only 4 muls, instead of 5 muls when using `arg * approx_rsqrt<1>(arg)` - __half2 xy = x * y; + half2_t xy = x * y; return xy * __hfma2(make_half2(-0.5) * y, xy, make_half2(1.5)); } @@ -192,7 +192,7 @@ KERNEL_FLOAT_DEVICE __half2 sqrt(__half2 x) { } template -KERNEL_FLOAT_DEVICE __half2 asin(__half2 x) { +KERNEL_FLOAT_DEVICE half2_t asin(half2_t x) { static constexpr double HALF_PI = 1.57079632679; auto abs_x = __habs2(x); auto v = asin_poly::call(abs_x); @@ -201,36 +201,36 @@ KERNEL_FLOAT_DEVICE __half2 asin(__half2 x) { } template -KERNEL_FLOAT_DEVICE __half2 acos(__half2 x) { +KERNEL_FLOAT_DEVICE half2_t acos(half2_t x) { static constexpr double HALF_PI = 1.57079632679; return make_half2(HALF_PI) - asin(x); } template -KERNEL_FLOAT_DEVICE __half2 exp(__half2 x) { - __half2 y; +KERNEL_FLOAT_DEVICE half2_t exp(half2_t x) { + half2_t y; if (Deg == 0) { // Bring the value to range [32, 64] // 1.442 = 1/log(2) // 46.969 = 32.5/log(2) - __half2 m = __hfma2(x, make_half2(1.442), make_half2(46.9375)); + half2_t m = __hfma2(x, make_half2(1.442), make_half2(46.9375)); // Transmute to int, shift higher mantissa bits into exponent field. - y = transmute<__half2>((transmute(m) & 0x03ff03ff) << 5); + y = transmute((transmute(m) & 0x03ff03ff) << 5); } else { // Add a large number to round to an integer - __half2 v = __hfma2(x, make_half2(1.442), make_half2(1231.0)); + half2_t v = __hfma2(x, make_half2(1.442), make_half2(1231.0)); // The exponent is now in the lower 5 bits. Shift that into the exponent field. - __half2 exp = transmute<__half2>((transmute(v) & 0x001f001f) << 10); + half2_t exp = transmute((transmute(v) & 0x001f001f) << 10); // The fractional part can be obtained from "1231-v". // 0.6934 = log(2) - __half2 frac = __hfma2(make_half2(1231.0) - v, make_half2(0.6934), x); + half2_t frac = __hfma2(make_half2(1231.0) - v, make_half2(0.6934), x); // This is the Taylor expansion of "exp(x)-1" around 0 - __half2 adjust; + half2_t adjust; if (Deg == 1) { adjust = frac; } else if (Deg == 2) { @@ -250,21 +250,21 @@ KERNEL_FLOAT_DEVICE __half2 exp(__half2 x) { // Values below -10.39 (= -15*log(2)) become zero uint32_t zero_mask = half2_gt_mask(x, make_half2(-10.390625)); - return transmute<__half2>(zero_mask & transmute(y)); + return transmute(zero_mask & transmute(y)); } template -KERNEL_FLOAT_DEVICE __half2 log(__half2 arg) { +KERNEL_FLOAT_DEVICE half2_t log(half2_t arg) { // Shift exponent field into mantissa bits. Fill exponent bits with 0x5000 (= 32.0) uint32_t bits = bitwise_if_else(0x03ff03ff, transmute(arg) >> 5, 0x50005000); // 0.6934 = log(2) // 32.53 = 46.969*log(2) - return __hfma2(transmute<__half2>(bits), make_half2(0.6934), make_half2(-32.53125)); + return __hfma2(transmute(bits), make_half2(0.6934), make_half2(-32.53125)); } template -KERNEL_FLOAT_DEVICE __half2 tanh(__half2 x) { +KERNEL_FLOAT_DEVICE half2_t tanh(half2_t x) { if (Deg == 0) { return x * rcp<0>(make_half2(0.2869) + __habs2(x)); } else { @@ -278,39 +278,39 @@ KERNEL_FLOAT_DEVICE __half2 tanh(__half2 x) { #endif // KERNEL_FLOAT_FP16_AVAILABLE #if KERNEL_FLOAT_BF16_OPS_SUPPORTED -KERNEL_FLOAT_DEVICE __bfloat162 make_bfloat162(__bfloat16 x) { +KERNEL_FLOAT_DEVICE bfloat16x2_t make_bfloat162(bfloat16_t x) { return {x, x}; } -KERNEL_FLOAT_DEVICE __bfloat162 make_bfloat162(double x) { +KERNEL_FLOAT_DEVICE bfloat16x2_t make_bfloat162(double x) { return {__double2bfloat16(x), __double2bfloat16(x)}; } -KERNEL_FLOAT_DEVICE __bfloat162 normalize_trig_input(__nv_bfloat162 x) { +KERNEL_FLOAT_DEVICE bfloat16x2_t normalize_trig_input(bfloat16x2_t x) { static constexpr double ONE_OVER_TWOPI = 0.15915494309189535; static constexpr double OFFSET = -2042.0; - __bfloat162 ws = __hadd2( + bfloat16x2_t ws = __hadd2( __hfma2(x, make_bfloat162(-ONE_OVER_TWOPI), make_bfloat162(-OFFSET)), make_bfloat162(OFFSET)); return __hfma2(x, make_bfloat162(ONE_OVER_TWOPI), ws); } template -KERNEL_FLOAT_DEVICE __bfloat162 cos(__bfloat162 x) { - __bfloat162 xf = normalize_trig_input(x); +KERNEL_FLOAT_DEVICE bfloat16x2_t cos(bfloat16x2_t x) { + bfloat16x2_t xf = normalize_trig_input(x); return cos_poly<__bfloat16, Iter + 1>::call(__hmul2(xf, xf)); } template -KERNEL_FLOAT_DEVICE __bfloat162 sin(__bfloat162 x) { - __bfloat162 xf = normalize_trig_input(x); +KERNEL_FLOAT_DEVICE bfloat16x2_t sin(bfloat16x2_t x) { + bfloat16x2_t xf = normalize_trig_input(x); return __hmul2(sin_poly<__bfloat16, Iter>::call(__hmul2(xf, xf)), xf); } template -KERNEL_FLOAT_DEVICE __bfloat162 rcp(__bfloat162 x) { - __bfloat162 y = transmute<__bfloat162>(uint32_t(0x7ef07ef0) + ~transmute(x)); +KERNEL_FLOAT_DEVICE bfloat16x2_t rcp(bfloat16x2_t x) { + bfloat16x2_t y = transmute(uint32_t(0x7ef07ef0) + ~transmute(x)); #pragma unroll for (int i = 0; i < Iter; i++) { @@ -321,18 +321,18 @@ KERNEL_FLOAT_DEVICE __bfloat162 rcp(__bfloat162 x) { } template -KERNEL_FLOAT_DEVICE __bfloat162 rsqrt(__bfloat162 x) { +KERNEL_FLOAT_DEVICE bfloat16x2_t rsqrt(bfloat16x2_t x) { // Set top and bottom bits for both halfs, then shift by 1, then invert uint32_t r = ~((uint32_t(transmute(x) >> 1)) | ~uint32_t(0x3fff3fff)); // Add bias (0x1f36) - __bfloat162 y = transmute<__bfloat162>(uint32_t(r) + uint32_t(0x1f361f36)); + bfloat16x2_t y = transmute(uint32_t(r) + uint32_t(0x1f361f36)); // Newton-Raphson iterations #pragma unroll for (int i = 0; i < Iter; i++) { - __bfloat162 half_x = __hmul2(make_bfloat162(-0.5), x); - __bfloat162 correction = __hfma2(half_x, __hmul2(y, y), make_bfloat162(0.5)); + bfloat16x2_t half_x = __hmul2(make_bfloat162(-0.5), x); + bfloat16x2_t correction = __hfma2(half_x, __hmul2(y, y), make_bfloat162(0.5)); y = __hfma2(correction, y, y); // y += y * correction } @@ -340,17 +340,17 @@ KERNEL_FLOAT_DEVICE __bfloat162 rsqrt(__bfloat162 x) { } template -KERNEL_FLOAT_DEVICE __bfloat162 sqrt(__bfloat162 x) { +KERNEL_FLOAT_DEVICE bfloat16x2_t sqrt(bfloat16x2_t x) { return __hmul2(x, rsqrt(x)); } template -KERNEL_FLOAT_DEVICE __bfloat162 exp(__bfloat162 arg) { +KERNEL_FLOAT_DEVICE bfloat16x2_t exp(bfloat16x2_t arg) { static constexpr float SCALE = 1.44272065994f / 256.0f; static constexpr float OFFSET = 382.4958400542335; - auto a = fmaf(__bfloat162float(arg.x), SCALE, OFFSET); - auto b = fmaf(__bfloat162float(arg.y), SCALE, OFFSET); + auto a = fmaf(bfloat16x2_tfloat(arg.x), SCALE, OFFSET); + auto b = fmaf(bfloat16x2_tfloat(arg.y), SCALE, OFFSET); return { transmute<__bfloat16>(uint16_t(transmute(a))), @@ -362,17 +362,17 @@ KERNEL_FLOAT_DEVICE __bfloat162 exp(__bfloat162 arg) { #define KERNEL_FLOAT_DEFINE_APPROX_FUN(FULL_NAME, FUN, DEG) \ namespace detail { \ template \ - struct apply_impl, ops::FUN<__half>, 2, __half, __half> { \ + struct apply_impl, ops::FUN, 2, half_t, half_t> { \ KERNEL_FLOAT_INLINE static void \ - call(ops::FUN<__half> fun, __half* output, const __half* input) { \ - __half2 res = approx::FUN(__half2 {input[0], input[1]}); \ + call(ops::FUN fun, half_t* output, const half_t* input) { \ + half2_t res = approx::FUN(half2_t {input[0], input[1]}); \ output[0] = res.x; \ output[1] = res.y; \ } \ }; \ template<> \ - struct apply_impl, 2, __half, __half>: \ - apply_impl, ops::FUN<__half>, 2, __half, __half> {}; \ + struct apply_impl, 2, half_t, half_t>: \ + apply_impl, ops::FUN, 2, half_t, half_t> {}; \ } \ \ template \ diff --git a/include/kernel_float/bf16.h b/include/kernel_float/bf16.h index 0d656aa..251a12a 100644 --- a/include/kernel_float/bf16.h +++ b/include/kernel_float/bf16.h @@ -22,11 +22,11 @@ namespace kernel_float { #if KERNEL_FLOAT_IS_CUDA -using __bfloat16 = __nv_bfloat16; -using __bfloat162 = __nv_bfloat162; +using bfloat16_t = __nv_bfloat16; +using bfloat16x2_t = __nv_bfloat162; #elif KERNEL_FLOAT_IS_HIP -using __bfloat16 = __hip_bfloat16; -using __bfloat162 = __hip_bfloat162; +using bfloat16_t = __hip_bfloat16; +using bfloat16x2_t = __hip_bfloat162; #endif #if KERNEL_FLOAT_IS_CUDA && __CUDA_ARCH__ >= 800 @@ -34,28 +34,28 @@ using __bfloat162 = __hip_bfloat162; #endif template<> -struct preferred_vector_size<__bfloat16> { +struct preferred_vector_size { static constexpr size_t value = 2; }; -KERNEL_FLOAT_DEFINE_PROMOTED_FLOAT(__bfloat16) -KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(float, __bfloat16) -KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(double, __bfloat16) +KERNEL_FLOAT_DEFINE_PROMOTED_FLOAT(bfloat16_t) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(float, bfloat16_t) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(double, bfloat16_t) template<> -struct into_vector_impl<__bfloat162> { - using value_type = __bfloat16; +struct into_vector_impl { + using value_type = bfloat16_t; using extent_type = extent<2>; KERNEL_FLOAT_INLINE - static vector_storage<__bfloat16, 2> call(__bfloat162 input) { + static vector_storage call(bfloat16x2_t input) { return {input.x, input.y}; } }; namespace detail { template<> -struct allow_float_fallback<__bfloat16> { +struct allow_float_fallback { static constexpr bool value = true; }; }; // namespace detail @@ -64,18 +64,18 @@ struct allow_float_fallback<__bfloat16> { #define KERNEL_FLOAT_BF16_UNARY_FUN(NAME, FUN1, FUN2) \ namespace ops { \ template<> \ - struct NAME<__bfloat16> { \ - KERNEL_FLOAT_INLINE __bfloat16 operator()(__bfloat16 input) { \ + struct NAME { \ + KERNEL_FLOAT_INLINE bfloat16_t operator()(bfloat16_t input) { \ return FUN1(input); \ } \ }; \ } \ namespace detail { \ template<> \ - struct apply_impl, 2, __bfloat16, __bfloat16> { \ + struct apply_impl, 2, bfloat16_t, bfloat16_t> { \ KERNEL_FLOAT_INLINE static void \ - call(ops::NAME<__bfloat16>, __bfloat16* result, const __bfloat16* a) { \ - __bfloat162 r = FUN2(__bfloat162 {a[0], a[1]}); \ + call(ops::NAME, bfloat16_t* result, const bfloat16_t* a) { \ + bfloat16x2_t r = FUN2(bfloat16x2_t {a[0], a[1]}); \ result[0] = r.x, result[1] = r.y; \ } \ }; \ @@ -107,9 +107,9 @@ KERNEL_FLOAT_BF16_UNARY_FUN(negate, ::__hneg, ::__hneg2) #define KERNEL_FLOAT_BF16_BINARY_FUN(NAME, FUN1, FUN2) \ namespace ops { \ template<> \ - struct NAME<__bfloat16> { \ - KERNEL_FLOAT_INLINE __bfloat16 operator()(__bfloat16 left, __bfloat16 right) const { \ - return ops::cast {}(FUN1(left, right)); \ + struct NAME { \ + KERNEL_FLOAT_INLINE bfloat16_t operator()(bfloat16_t left, bfloat16_t right) const { \ + return ops::cast {}(FUN1(left, right)); \ } \ }; \ } \ @@ -117,17 +117,17 @@ KERNEL_FLOAT_BF16_UNARY_FUN(negate, ::__hneg, ::__hneg2) template<> \ struct apply_impl< \ accurate_policy, \ - ops::NAME<__bfloat16>, \ + ops::NAME, \ 2, \ - __bfloat16, \ - __bfloat16, \ - __bfloat16> { \ + bfloat16_t, \ + bfloat16_t, \ + bfloat16_t> { \ KERNEL_FLOAT_INLINE static void call( \ - ops::NAME<__bfloat16>, \ - __bfloat16* result, \ - const __bfloat16* a, \ - const __bfloat16* b) { \ - __bfloat162 r = FUN2(__bfloat162 {a[0], a[1]}, __bfloat162 {b[0], b[1]}); \ + ops::NAME, \ + bfloat16_t* result, \ + const bfloat16_t* a, \ + const bfloat16_t* b) { \ + bfloat16x2_t r = FUN2(bfloat16x2_t {a[0], a[1]}, bfloat16x2_t {b[0], b[1]}); \ result[0] = r.x, result[1] = r.y; \ } \ }; \ @@ -151,8 +151,8 @@ KERNEL_FLOAT_BF16_BINARY_FUN(greater_equal, __hge, __hgt2) #if KERNEL_FLOAT_BF16_OPS_SUPPORTED namespace ops { template<> -struct fma<__bfloat16> { - KERNEL_FLOAT_INLINE __bfloat16 operator()(__bfloat16 a, __bfloat16 b, __bfloat16 c) const { +struct fma { + KERNEL_FLOAT_INLINE bfloat16_t operator()(bfloat16_t a, bfloat16_t b, bfloat16_t c) const { return __hfma(a, b, c); } }; @@ -162,20 +162,22 @@ namespace detail { template<> struct apply_impl< accurate_policy, - ops::fma<__bfloat16>, + ops::fma, 2, - __bfloat16, - __bfloat16, - __bfloat16, - __bfloat16> { + bfloat16_t, + bfloat16_t, + bfloat16_t, + bfloat16_t> { KERNEL_FLOAT_INLINE static void call( - ops::fma<__bfloat16>, - __bfloat16* result, - const __bfloat16* a, - const __bfloat16* b, - const __bfloat16* c) { - __bfloat162 r = - __hfma2(__bfloat162 {a[0], a[1]}, __bfloat162 {b[0], b[1]}, __bfloat162 {c[0], c[1]}); + ops::fma, + bfloat16_t* result, + const bfloat16_t* a, + const bfloat16_t* b, + const bfloat16_t* c) { + bfloat16x2_t r = __hfma2( + bfloat16x2_t {a[0], a[1]}, + bfloat16x2_t {b[0], b[1]}, + bfloat16x2_t {c[0], c[1]}); result[0] = r.x, result[1] = r.y; } }; @@ -185,14 +187,14 @@ struct apply_impl< #define KERNEL_FLOAT_BF16_CAST(T, TO_HALF, FROM_HALF) \ namespace ops { \ template<> \ - struct cast { \ - KERNEL_FLOAT_INLINE __bfloat16 operator()(T input) { \ + struct cast { \ + KERNEL_FLOAT_INLINE bfloat16_t operator()(T input) { \ return TO_HALF; \ } \ }; \ template<> \ - struct cast<__bfloat16, T> { \ - KERNEL_FLOAT_INLINE T operator()(__bfloat16 input) { \ + struct cast { \ + KERNEL_FLOAT_INLINE T operator()(bfloat16_t input) { \ return FROM_HALF; \ } \ }; \ @@ -232,10 +234,9 @@ KERNEL_FLOAT_BF16_CAST( (__hip_bfloat16(input).data & 0x7FFF) != 0); #endif -using bfloat16 = __bfloat16; -KERNEL_FLOAT_VECTOR_ALIAS(bfloat16x, __bfloat16) -//KERNEL_FLOAT_TYPE_ALIAS(float16x, __bfloat16) -//KERNEL_FLOAT_TYPE_ALIAS(f16x, __bfloat16) +KERNEL_FLOAT_VECTOR_ALIAS(bfloat16x, bfloat16_t) +//KERNEL_FLOAT_TYPE_ALIAS(float16x, bfloat16_t) +//KERNEL_FLOAT_TYPE_ALIAS(f16x, bfloat16_t) } // namespace kernel_float @@ -244,12 +245,12 @@ KERNEL_FLOAT_VECTOR_ALIAS(bfloat16x, __bfloat16) namespace kernel_float { template<> -struct promote_type<__bfloat16, __half> { +struct promote_type { using type = float; }; template<> -struct promote_type<__half, __bfloat16> { +struct promote_type { using type = float; }; diff --git a/include/kernel_float/fp16.h b/include/kernel_float/fp16.h index eccd457..4c185ff 100644 --- a/include/kernel_float/fp16.h +++ b/include/kernel_float/fp16.h @@ -19,29 +19,35 @@ namespace kernel_float { +using half_t = ::__half; +using half2_t = ::__half2; + +using __half = void; +using __half2 = void; + template<> -struct preferred_vector_size<__half> { +struct preferred_vector_size { static constexpr size_t value = 2; }; -KERNEL_FLOAT_DEFINE_PROMOTED_FLOAT(__half) -KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(float, __half) -KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(double, __half) +KERNEL_FLOAT_DEFINE_PROMOTED_FLOAT(half_t) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(float, half_t) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(double, half_t) template<> -struct into_vector_impl<__half2> { - using value_type = __half; +struct into_vector_impl { + using value_type = half_t; using extent_type = extent<2>; KERNEL_FLOAT_INLINE - static vector_storage<__half, 2> call(__half2 input) { + static vector_storage call(half2_t input) { return {input.x, input.y}; } }; namespace detail { template<> -struct allow_float_fallback<__half> { +struct allow_float_fallback { static constexpr bool value = true; }; }; // namespace detail @@ -50,17 +56,17 @@ struct allow_float_fallback<__half> { #define KERNEL_FLOAT_FP16_UNARY_FUN(NAME, FUN1, FUN2) \ namespace ops { \ template<> \ - struct NAME<__half> { \ - KERNEL_FLOAT_INLINE __half operator()(__half input) { \ + struct NAME { \ + KERNEL_FLOAT_INLINE half_t operator()(half_t input) { \ return FUN1(input); \ } \ }; \ } \ namespace detail { \ template<> \ - struct apply_impl, 2, __half, __half> { \ - KERNEL_FLOAT_INLINE static void call(ops::NAME<__half>, __half* result, const __half* a) { \ - __half2 r = FUN2(__half2 {a[0], a[1]}); \ + struct apply_impl, 2, half_t, half_t> { \ + KERNEL_FLOAT_INLINE static void call(ops::NAME, half_t* result, const half_t* a) { \ + half2_t r = FUN2(half2_t {a[0], a[1]}); \ result[0] = r.x, result[1] = r.y; \ } \ }; \ @@ -94,18 +100,18 @@ KERNEL_FLOAT_FP16_UNARY_FUN(negate, __hneg, __hneg2) #define KERNEL_FLOAT_FP16_BINARY_FUN(NAME, FUN1, FUN2) \ namespace ops { \ template<> \ - struct NAME<__half> { \ - KERNEL_FLOAT_INLINE __half operator()(__half left, __half right) const { \ - return ops::cast {}(FUN1(left, right)); \ + struct NAME { \ + KERNEL_FLOAT_INLINE half_t operator()(half_t left, half_t right) const { \ + return ops::cast {}(FUN1(left, right)); \ } \ }; \ } \ namespace detail { \ template<> \ - struct apply_impl, 2, __half, __half, __half> { \ + struct apply_impl, 2, half_t, half_t, half_t> { \ KERNEL_FLOAT_INLINE static void \ - call(ops::NAME<__half>, __half* result, const __half* a, const __half* b) { \ - __half2 r = FUN2(__half2 {a[0], a[1]}, __half2 {b[0], b[1]}); \ + call(ops::NAME, half_t* result, const half_t* a, const half_t* b) { \ + half2_t r = FUN2(half2_t {a[0], a[1]}, half2_t {b[0], b[1]}); \ result[0] = r.x, result[1] = r.y; \ } \ }; \ @@ -135,8 +141,8 @@ KERNEL_FLOAT_FP16_BINARY_FUN(greater_equal, __hge, __hgt2) #if KERNEL_FLOAT_IS_DEVICE namespace ops { template<> -struct fma<__half> { - KERNEL_FLOAT_INLINE __half operator()(__half a, __half b, __half c) const { +struct fma { + KERNEL_FLOAT_INLINE half_t operator()(half_t a, half_t b, half_t c) const { return __hfma(a, b, c); } }; @@ -144,10 +150,10 @@ struct fma<__half> { namespace detail { template<> -struct apply_impl, 2, __half, __half, __half, __half> { +struct apply_impl, 2, half_t, half_t, half_t, half_t> { KERNEL_FLOAT_INLINE static void - call(ops::fma<__half>, __half* result, const __half* a, const __half* b, const __half* c) { - __half2 r = __hfma2(__half2 {a[0], a[1]}, __half2 {b[0], b[1]}, __half2 {c[0], c[1]}); + call(ops::fma, half_t* result, const half_t* a, const half_t* b, const half_t* c) { + half2_t r = __hfma2(half2_t {a[0], a[1]}, half2_t {b[0], b[1]}, half2_t {c[0], c[1]}); result[0] = r.x, result[1] = r.y; } }; @@ -157,14 +163,14 @@ struct apply_impl, 2, __half, __half, __half, #define KERNEL_FLOAT_FP16_CAST(T, TO_HALF, FROM_HALF) \ namespace ops { \ template<> \ - struct cast { \ - KERNEL_FLOAT_INLINE __half operator()(T input) { \ + struct cast { \ + KERNEL_FLOAT_INLINE half_t operator()(T input) { \ return TO_HALF; \ } \ }; \ template<> \ - struct cast<__half, T> { \ - KERNEL_FLOAT_INLINE T operator()(__half input) { \ + struct cast { \ + KERNEL_FLOAT_INLINE T operator()(half_t input) { \ return FROM_HALF; \ } \ }; \ @@ -211,10 +217,9 @@ KERNEL_FLOAT_FP16_CAST(unsigned long, __ull2half_rn(input), (unsigned long)(__ha KERNEL_FLOAT_FP16_CAST(unsigned long long, __ull2half_rn(input), __half2ull_rz(input)); #endif -using half = __half; -KERNEL_FLOAT_VECTOR_ALIAS(half, __half) -//KERNEL_FLOAT_TYPE_ALIAS(float16x, __half) -//KERNEL_FLOAT_TYPE_ALIAS(f16x, __half) +KERNEL_FLOAT_VECTOR_ALIAS(half, half_t) +//KERNEL_FLOAT_TYPE_ALIAS(float16x, half_t) +//KERNEL_FLOAT_TYPE_ALIAS(f16x, half_t) } // namespace kernel_float diff --git a/single_include/kernel_float.h b/single_include/kernel_float.h index e1ff9a0..1b9126d 100644 --- a/single_include/kernel_float.h +++ b/single_include/kernel_float.h @@ -16,8 +16,8 @@ //================================================================================ // this file has been auto-generated, do not modify its contents! -// date: 2024-11-18 13:09:09.880139 -// git hash: 5bd24b5f693cf68213714f0e9fd2d471abe04e66 +// date: 2024-11-18 13:40:03.668017 +// git hash: ae0e6b16ac2d626e69bb08554044a77671f408ab //================================================================================ #ifndef KERNEL_FLOAT_MACROS_H @@ -3970,29 +3970,35 @@ vec(Args&&... args) -> vec, sizeof...(Args)>; namespace kernel_float { +using half_t = ::__half; +using half2_t = ::__half2; + +using __half = void; +using __half2 = void; + template<> -struct preferred_vector_size<__half> { +struct preferred_vector_size { static constexpr size_t value = 2; }; -KERNEL_FLOAT_DEFINE_PROMOTED_FLOAT(__half) -KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(float, __half) -KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(double, __half) +KERNEL_FLOAT_DEFINE_PROMOTED_FLOAT(half_t) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(float, half_t) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(double, half_t) template<> -struct into_vector_impl<__half2> { - using value_type = __half; +struct into_vector_impl { + using value_type = half_t; using extent_type = extent<2>; KERNEL_FLOAT_INLINE - static vector_storage<__half, 2> call(__half2 input) { + static vector_storage call(half2_t input) { return {input.x, input.y}; } }; namespace detail { template<> -struct allow_float_fallback<__half> { +struct allow_float_fallback { static constexpr bool value = true; }; }; // namespace detail @@ -4001,17 +4007,17 @@ struct allow_float_fallback<__half> { #define KERNEL_FLOAT_FP16_UNARY_FUN(NAME, FUN1, FUN2) \ namespace ops { \ template<> \ - struct NAME<__half> { \ - KERNEL_FLOAT_INLINE __half operator()(__half input) { \ + struct NAME { \ + KERNEL_FLOAT_INLINE half_t operator()(half_t input) { \ return FUN1(input); \ } \ }; \ } \ namespace detail { \ template<> \ - struct apply_impl, 2, __half, __half> { \ - KERNEL_FLOAT_INLINE static void call(ops::NAME<__half>, __half* result, const __half* a) { \ - __half2 r = FUN2(__half2 {a[0], a[1]}); \ + struct apply_impl, 2, half_t, half_t> { \ + KERNEL_FLOAT_INLINE static void call(ops::NAME, half_t* result, const half_t* a) { \ + half2_t r = FUN2(half2_t {a[0], a[1]}); \ result[0] = r.x, result[1] = r.y; \ } \ }; \ @@ -4045,18 +4051,18 @@ KERNEL_FLOAT_FP16_UNARY_FUN(negate, __hneg, __hneg2) #define KERNEL_FLOAT_FP16_BINARY_FUN(NAME, FUN1, FUN2) \ namespace ops { \ template<> \ - struct NAME<__half> { \ - KERNEL_FLOAT_INLINE __half operator()(__half left, __half right) const { \ - return ops::cast {}(FUN1(left, right)); \ + struct NAME { \ + KERNEL_FLOAT_INLINE half_t operator()(half_t left, half_t right) const { \ + return ops::cast {}(FUN1(left, right)); \ } \ }; \ } \ namespace detail { \ template<> \ - struct apply_impl, 2, __half, __half, __half> { \ + struct apply_impl, 2, half_t, half_t, half_t> { \ KERNEL_FLOAT_INLINE static void \ - call(ops::NAME<__half>, __half* result, const __half* a, const __half* b) { \ - __half2 r = FUN2(__half2 {a[0], a[1]}, __half2 {b[0], b[1]}); \ + call(ops::NAME, half_t* result, const half_t* a, const half_t* b) { \ + half2_t r = FUN2(half2_t {a[0], a[1]}, half2_t {b[0], b[1]}); \ result[0] = r.x, result[1] = r.y; \ } \ }; \ @@ -4086,8 +4092,8 @@ KERNEL_FLOAT_FP16_BINARY_FUN(greater_equal, __hge, __hgt2) #if KERNEL_FLOAT_IS_DEVICE namespace ops { template<> -struct fma<__half> { - KERNEL_FLOAT_INLINE __half operator()(__half a, __half b, __half c) const { +struct fma { + KERNEL_FLOAT_INLINE half_t operator()(half_t a, half_t b, half_t c) const { return __hfma(a, b, c); } }; @@ -4095,10 +4101,10 @@ struct fma<__half> { namespace detail { template<> -struct apply_impl, 2, __half, __half, __half, __half> { +struct apply_impl, 2, half_t, half_t, half_t, half_t> { KERNEL_FLOAT_INLINE static void - call(ops::fma<__half>, __half* result, const __half* a, const __half* b, const __half* c) { - __half2 r = __hfma2(__half2 {a[0], a[1]}, __half2 {b[0], b[1]}, __half2 {c[0], c[1]}); + call(ops::fma, half_t* result, const half_t* a, const half_t* b, const half_t* c) { + half2_t r = __hfma2(half2_t {a[0], a[1]}, half2_t {b[0], b[1]}, half2_t {c[0], c[1]}); result[0] = r.x, result[1] = r.y; } }; @@ -4108,14 +4114,14 @@ struct apply_impl, 2, __half, __half, __half, #define KERNEL_FLOAT_FP16_CAST(T, TO_HALF, FROM_HALF) \ namespace ops { \ template<> \ - struct cast { \ - KERNEL_FLOAT_INLINE __half operator()(T input) { \ + struct cast { \ + KERNEL_FLOAT_INLINE half_t operator()(T input) { \ return TO_HALF; \ } \ }; \ template<> \ - struct cast<__half, T> { \ - KERNEL_FLOAT_INLINE T operator()(__half input) { \ + struct cast { \ + KERNEL_FLOAT_INLINE T operator()(half_t input) { \ return FROM_HALF; \ } \ }; \ @@ -4162,10 +4168,9 @@ KERNEL_FLOAT_FP16_CAST(unsigned long, __ull2half_rn(input), (unsigned long)(__ha KERNEL_FLOAT_FP16_CAST(unsigned long long, __ull2half_rn(input), __half2ull_rz(input)); #endif -using half = __half; -KERNEL_FLOAT_VECTOR_ALIAS(half, __half) -//KERNEL_FLOAT_TYPE_ALIAS(float16x, __half) -//KERNEL_FLOAT_TYPE_ALIAS(f16x, __half) +KERNEL_FLOAT_VECTOR_ALIAS(half, half_t) +//KERNEL_FLOAT_TYPE_ALIAS(float16x, half_t) +//KERNEL_FLOAT_TYPE_ALIAS(f16x, half_t) } // namespace kernel_float @@ -4196,11 +4201,11 @@ KERNEL_FLOAT_VECTOR_ALIAS(half, __half) namespace kernel_float { #if KERNEL_FLOAT_IS_CUDA -using __bfloat16 = __nv_bfloat16; -using __bfloat162 = __nv_bfloat162; +using bfloat16_t = __nv_bfloat16; +using bfloat16x2_t = __nv_bfloat162; #elif KERNEL_FLOAT_IS_HIP -using __bfloat16 = __hip_bfloat16; -using __bfloat162 = __hip_bfloat162; +using bfloat16_t = __hip_bfloat16; +using bfloat16x2_t = __hip_bfloat162; #endif #if KERNEL_FLOAT_IS_CUDA && __CUDA_ARCH__ >= 800 @@ -4208,28 +4213,28 @@ using __bfloat162 = __hip_bfloat162; #endif template<> -struct preferred_vector_size<__bfloat16> { +struct preferred_vector_size { static constexpr size_t value = 2; }; -KERNEL_FLOAT_DEFINE_PROMOTED_FLOAT(__bfloat16) -KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(float, __bfloat16) -KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(double, __bfloat16) +KERNEL_FLOAT_DEFINE_PROMOTED_FLOAT(bfloat16_t) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(float, bfloat16_t) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(double, bfloat16_t) template<> -struct into_vector_impl<__bfloat162> { - using value_type = __bfloat16; +struct into_vector_impl { + using value_type = bfloat16_t; using extent_type = extent<2>; KERNEL_FLOAT_INLINE - static vector_storage<__bfloat16, 2> call(__bfloat162 input) { + static vector_storage call(bfloat16x2_t input) { return {input.x, input.y}; } }; namespace detail { template<> -struct allow_float_fallback<__bfloat16> { +struct allow_float_fallback { static constexpr bool value = true; }; }; // namespace detail @@ -4238,18 +4243,18 @@ struct allow_float_fallback<__bfloat16> { #define KERNEL_FLOAT_BF16_UNARY_FUN(NAME, FUN1, FUN2) \ namespace ops { \ template<> \ - struct NAME<__bfloat16> { \ - KERNEL_FLOAT_INLINE __bfloat16 operator()(__bfloat16 input) { \ + struct NAME { \ + KERNEL_FLOAT_INLINE bfloat16_t operator()(bfloat16_t input) { \ return FUN1(input); \ } \ }; \ } \ namespace detail { \ template<> \ - struct apply_impl, 2, __bfloat16, __bfloat16> { \ + struct apply_impl, 2, bfloat16_t, bfloat16_t> { \ KERNEL_FLOAT_INLINE static void \ - call(ops::NAME<__bfloat16>, __bfloat16* result, const __bfloat16* a) { \ - __bfloat162 r = FUN2(__bfloat162 {a[0], a[1]}); \ + call(ops::NAME, bfloat16_t* result, const bfloat16_t* a) { \ + bfloat16x2_t r = FUN2(bfloat16x2_t {a[0], a[1]}); \ result[0] = r.x, result[1] = r.y; \ } \ }; \ @@ -4281,9 +4286,9 @@ KERNEL_FLOAT_BF16_UNARY_FUN(negate, ::__hneg, ::__hneg2) #define KERNEL_FLOAT_BF16_BINARY_FUN(NAME, FUN1, FUN2) \ namespace ops { \ template<> \ - struct NAME<__bfloat16> { \ - KERNEL_FLOAT_INLINE __bfloat16 operator()(__bfloat16 left, __bfloat16 right) const { \ - return ops::cast {}(FUN1(left, right)); \ + struct NAME { \ + KERNEL_FLOAT_INLINE bfloat16_t operator()(bfloat16_t left, bfloat16_t right) const { \ + return ops::cast {}(FUN1(left, right)); \ } \ }; \ } \ @@ -4291,17 +4296,17 @@ KERNEL_FLOAT_BF16_UNARY_FUN(negate, ::__hneg, ::__hneg2) template<> \ struct apply_impl< \ accurate_policy, \ - ops::NAME<__bfloat16>, \ + ops::NAME, \ 2, \ - __bfloat16, \ - __bfloat16, \ - __bfloat16> { \ + bfloat16_t, \ + bfloat16_t, \ + bfloat16_t> { \ KERNEL_FLOAT_INLINE static void call( \ - ops::NAME<__bfloat16>, \ - __bfloat16* result, \ - const __bfloat16* a, \ - const __bfloat16* b) { \ - __bfloat162 r = FUN2(__bfloat162 {a[0], a[1]}, __bfloat162 {b[0], b[1]}); \ + ops::NAME, \ + bfloat16_t* result, \ + const bfloat16_t* a, \ + const bfloat16_t* b) { \ + bfloat16x2_t r = FUN2(bfloat16x2_t {a[0], a[1]}, bfloat16x2_t {b[0], b[1]}); \ result[0] = r.x, result[1] = r.y; \ } \ }; \ @@ -4325,8 +4330,8 @@ KERNEL_FLOAT_BF16_BINARY_FUN(greater_equal, __hge, __hgt2) #if KERNEL_FLOAT_BF16_OPS_SUPPORTED namespace ops { template<> -struct fma<__bfloat16> { - KERNEL_FLOAT_INLINE __bfloat16 operator()(__bfloat16 a, __bfloat16 b, __bfloat16 c) const { +struct fma { + KERNEL_FLOAT_INLINE bfloat16_t operator()(bfloat16_t a, bfloat16_t b, bfloat16_t c) const { return __hfma(a, b, c); } }; @@ -4336,20 +4341,22 @@ namespace detail { template<> struct apply_impl< accurate_policy, - ops::fma<__bfloat16>, + ops::fma, 2, - __bfloat16, - __bfloat16, - __bfloat16, - __bfloat16> { + bfloat16_t, + bfloat16_t, + bfloat16_t, + bfloat16_t> { KERNEL_FLOAT_INLINE static void call( - ops::fma<__bfloat16>, - __bfloat16* result, - const __bfloat16* a, - const __bfloat16* b, - const __bfloat16* c) { - __bfloat162 r = - __hfma2(__bfloat162 {a[0], a[1]}, __bfloat162 {b[0], b[1]}, __bfloat162 {c[0], c[1]}); + ops::fma, + bfloat16_t* result, + const bfloat16_t* a, + const bfloat16_t* b, + const bfloat16_t* c) { + bfloat16x2_t r = __hfma2( + bfloat16x2_t {a[0], a[1]}, + bfloat16x2_t {b[0], b[1]}, + bfloat16x2_t {c[0], c[1]}); result[0] = r.x, result[1] = r.y; } }; @@ -4359,14 +4366,14 @@ struct apply_impl< #define KERNEL_FLOAT_BF16_CAST(T, TO_HALF, FROM_HALF) \ namespace ops { \ template<> \ - struct cast { \ - KERNEL_FLOAT_INLINE __bfloat16 operator()(T input) { \ + struct cast { \ + KERNEL_FLOAT_INLINE bfloat16_t operator()(T input) { \ return TO_HALF; \ } \ }; \ template<> \ - struct cast<__bfloat16, T> { \ - KERNEL_FLOAT_INLINE T operator()(__bfloat16 input) { \ + struct cast { \ + KERNEL_FLOAT_INLINE T operator()(bfloat16_t input) { \ return FROM_HALF; \ } \ }; \ @@ -4406,10 +4413,9 @@ KERNEL_FLOAT_BF16_CAST( (__hip_bfloat16(input).data & 0x7FFF) != 0); #endif -using bfloat16 = __bfloat16; -KERNEL_FLOAT_VECTOR_ALIAS(bfloat16x, __bfloat16) -//KERNEL_FLOAT_TYPE_ALIAS(float16x, __bfloat16) -//KERNEL_FLOAT_TYPE_ALIAS(f16x, __bfloat16) +KERNEL_FLOAT_VECTOR_ALIAS(bfloat16x, bfloat16_t) +//KERNEL_FLOAT_TYPE_ALIAS(float16x, bfloat16_t) +//KERNEL_FLOAT_TYPE_ALIAS(f16x, bfloat16_t) } // namespace kernel_float @@ -4418,12 +4424,12 @@ KERNEL_FLOAT_VECTOR_ALIAS(bfloat16x, __bfloat16) namespace kernel_float { template<> -struct promote_type<__bfloat16, __half> { +struct promote_type { using type = float; }; template<> -struct promote_type<__half, __bfloat16> { +struct promote_type { using type = float; }; @@ -4520,7 +4526,7 @@ KERNEL_FLOAT_DEFINE_POLY(asin_poly, 4, -0.02103, 0.077, -0.2129, 1.57) KERNEL_FLOAT_DEFINE_POLY(asin_poly, 5, 0.009796, -0.03772, 0.0857, -0.2142, 1.57) #if KERNEL_FLOAT_FP16_AVAILABLE -KERNEL_FLOAT_DEVICE __half2 flipsign(__half2 input, __half2 sign) { +KERNEL_FLOAT_DEVICE half2_t flipsign(half2_t input, half2_t sign) { // Flip signbit of input when sign<0 uint32_t result; @@ -4532,10 +4538,10 @@ KERNEL_FLOAT_DEVICE __half2 flipsign(__half2 input, __half2 sign) { result = uint32_t(transmute(sign) & 0x80008000) ^ transmute(input); #endif - return transmute<__half2>(result); + return transmute(result); } -KERNEL_FLOAT_DEVICE uint32_t half2_gt_mask(__half2 a, __half2 b) { +KERNEL_FLOAT_DEVICE uint32_t half2_gt_mask(half2_t a, half2_t b) { uint32_t val; #if KERNEL_FLOAT_IS_CUDA uint32_t ai = *(reinterpret_cast(&a)); @@ -4547,42 +4553,42 @@ KERNEL_FLOAT_DEVICE uint32_t half2_gt_mask(__half2 a, __half2 b) { return val; } -KERNEL_FLOAT_INLINE __half2 make_half2(half x) { +KERNEL_FLOAT_INLINE half2_t make_half2(half x) { return {x, x}; } -KERNEL_FLOAT_DEVICE __half2 normalize_trig_input(__half2 x) { +KERNEL_FLOAT_DEVICE half2_t normalize_trig_input(half2_t x) { /* Using rint is too slow. Round using floating-point magic instead. */ - // __half2 x = arg * make_half2(-0.15915494309); + // half2_t x = arg * make_half2(-0.15915494309); // return __hfma2(arg, make_half2(0.15915494309), h2rint(x)); // 1/(2pi) = 0.15915494309189535 static constexpr double ONE_OVER_TWOPI = 0.15915494309189535; static constexpr double OFFSET = -2042.0; - __half2 ws = __hfma2(x, make_half2(-ONE_OVER_TWOPI), make_half2(-OFFSET)) + make_half2(OFFSET); + half2_t ws = __hfma2(x, make_half2(-ONE_OVER_TWOPI), make_half2(-OFFSET)) + make_half2(OFFSET); return __hfma2(x, make_half2(ONE_OVER_TWOPI), ws); } template -KERNEL_FLOAT_DEVICE __half2 cos(__half2 x) { - __half2 xf = normalize_trig_input(x); +KERNEL_FLOAT_DEVICE half2_t cos(half2_t x) { + half2_t xf = normalize_trig_input(x); return cos_poly::call(__hmul2(xf, xf)); } template -KERNEL_FLOAT_DEVICE __half2 sin(__half2 x) { - __half2 xf = normalize_trig_input(x); +KERNEL_FLOAT_DEVICE half2_t sin(half2_t x) { + half2_t xf = normalize_trig_input(x); return sin_poly::call(__hmul2(xf, xf)) * xf; } template -KERNEL_FLOAT_DEVICE __half2 rcp(__half2 x) { +KERNEL_FLOAT_DEVICE half2_t rcp(half2_t x) { // Flip bits uint32_t m = ~transmute(x); // Multiply by bias (add contant) - __half2 y = transmute<__half2>(uint32_t(0x776d776d) + m); + half2_t y = transmute(uint32_t(0x776d776d) + m); #pragma unroll for (int i = 0; i < Iter; i++) { @@ -4594,19 +4600,19 @@ KERNEL_FLOAT_DEVICE __half2 rcp(__half2 x) { } template -KERNEL_FLOAT_DEVICE __half2 rsqrt(__half2 x) { +KERNEL_FLOAT_DEVICE half2_t rsqrt(half2_t x) { // Set top and bottom bits for both halfs, then shift by 1, then invert uint32_t r = ~((uint32_t(transmute(x) >> 1)) | ~uint32_t(0x3fff3fff)); //uint32_t r = uint32_t(~(transmute(arg) | (~uint32_t(0x3ffe3ffe)))) >> 1; // Add bias (0x199c) - __half2 y = transmute<__half2>(uint32_t(r) + uint32_t(0x199c199c)); + half2_t y = transmute(uint32_t(r) + uint32_t(0x199c199c)); // Newton-Raphson iterations #pragma unroll for (int i = 0; i < Iter; i++) { - __half2 half_x = make_half2(-0.5) * x; - __half2 correction = __hfma2(half_x, y * y, make_half2(0.5)); + half2_t half_x = make_half2(-0.5) * x; + half2_t correction = __hfma2(half_x, y * y, make_half2(0.5)); y = __hfma2(correction, y, y); // y += y * correction } @@ -4614,12 +4620,12 @@ KERNEL_FLOAT_DEVICE __half2 rsqrt(__half2 x) { } template -KERNEL_FLOAT_DEVICE __half2 sqrt(__half2 x) { +KERNEL_FLOAT_DEVICE half2_t sqrt(half2_t x) { if (Iter == 1) { - __half2 y = rsqrt<0>(x); + half2_t y = rsqrt<0>(x); // This method uses only 4 muls, instead of 5 muls when using `arg * approx_rsqrt<1>(arg)` - __half2 xy = x * y; + half2_t xy = x * y; return xy * __hfma2(make_half2(-0.5) * y, xy, make_half2(1.5)); } @@ -4627,7 +4633,7 @@ KERNEL_FLOAT_DEVICE __half2 sqrt(__half2 x) { } template -KERNEL_FLOAT_DEVICE __half2 asin(__half2 x) { +KERNEL_FLOAT_DEVICE half2_t asin(half2_t x) { static constexpr double HALF_PI = 1.57079632679; auto abs_x = __habs2(x); auto v = asin_poly::call(abs_x); @@ -4636,36 +4642,36 @@ KERNEL_FLOAT_DEVICE __half2 asin(__half2 x) { } template -KERNEL_FLOAT_DEVICE __half2 acos(__half2 x) { +KERNEL_FLOAT_DEVICE half2_t acos(half2_t x) { static constexpr double HALF_PI = 1.57079632679; return make_half2(HALF_PI) - asin(x); } template -KERNEL_FLOAT_DEVICE __half2 exp(__half2 x) { - __half2 y; +KERNEL_FLOAT_DEVICE half2_t exp(half2_t x) { + half2_t y; if (Deg == 0) { // Bring the value to range [32, 64] // 1.442 = 1/log(2) // 46.969 = 32.5/log(2) - __half2 m = __hfma2(x, make_half2(1.442), make_half2(46.9375)); + half2_t m = __hfma2(x, make_half2(1.442), make_half2(46.9375)); // Transmute to int, shift higher mantissa bits into exponent field. - y = transmute<__half2>((transmute(m) & 0x03ff03ff) << 5); + y = transmute((transmute(m) & 0x03ff03ff) << 5); } else { // Add a large number to round to an integer - __half2 v = __hfma2(x, make_half2(1.442), make_half2(1231.0)); + half2_t v = __hfma2(x, make_half2(1.442), make_half2(1231.0)); // The exponent is now in the lower 5 bits. Shift that into the exponent field. - __half2 exp = transmute<__half2>((transmute(v) & 0x001f001f) << 10); + half2_t exp = transmute((transmute(v) & 0x001f001f) << 10); // The fractional part can be obtained from "1231-v". // 0.6934 = log(2) - __half2 frac = __hfma2(make_half2(1231.0) - v, make_half2(0.6934), x); + half2_t frac = __hfma2(make_half2(1231.0) - v, make_half2(0.6934), x); // This is the Taylor expansion of "exp(x)-1" around 0 - __half2 adjust; + half2_t adjust; if (Deg == 1) { adjust = frac; } else if (Deg == 2) { @@ -4685,21 +4691,21 @@ KERNEL_FLOAT_DEVICE __half2 exp(__half2 x) { // Values below -10.39 (= -15*log(2)) become zero uint32_t zero_mask = half2_gt_mask(x, make_half2(-10.390625)); - return transmute<__half2>(zero_mask & transmute(y)); + return transmute(zero_mask & transmute(y)); } template -KERNEL_FLOAT_DEVICE __half2 log(__half2 arg) { +KERNEL_FLOAT_DEVICE half2_t log(half2_t arg) { // Shift exponent field into mantissa bits. Fill exponent bits with 0x5000 (= 32.0) uint32_t bits = bitwise_if_else(0x03ff03ff, transmute(arg) >> 5, 0x50005000); // 0.6934 = log(2) // 32.53 = 46.969*log(2) - return __hfma2(transmute<__half2>(bits), make_half2(0.6934), make_half2(-32.53125)); + return __hfma2(transmute(bits), make_half2(0.6934), make_half2(-32.53125)); } template -KERNEL_FLOAT_DEVICE __half2 tanh(__half2 x) { +KERNEL_FLOAT_DEVICE half2_t tanh(half2_t x) { if (Deg == 0) { return x * rcp<0>(make_half2(0.2869) + __habs2(x)); } else { @@ -4713,39 +4719,39 @@ KERNEL_FLOAT_DEVICE __half2 tanh(__half2 x) { #endif // KERNEL_FLOAT_FP16_AVAILABLE #if KERNEL_FLOAT_BF16_OPS_SUPPORTED -KERNEL_FLOAT_DEVICE __bfloat162 make_bfloat162(__bfloat16 x) { +KERNEL_FLOAT_DEVICE bfloat16x2_t make_bfloat162(bfloat16_t x) { return {x, x}; } -KERNEL_FLOAT_DEVICE __bfloat162 make_bfloat162(double x) { +KERNEL_FLOAT_DEVICE bfloat16x2_t make_bfloat162(double x) { return {__double2bfloat16(x), __double2bfloat16(x)}; } -KERNEL_FLOAT_DEVICE __bfloat162 normalize_trig_input(__nv_bfloat162 x) { +KERNEL_FLOAT_DEVICE bfloat16x2_t normalize_trig_input(bfloat16x2_t x) { static constexpr double ONE_OVER_TWOPI = 0.15915494309189535; static constexpr double OFFSET = -2042.0; - __bfloat162 ws = __hadd2( + bfloat16x2_t ws = __hadd2( __hfma2(x, make_bfloat162(-ONE_OVER_TWOPI), make_bfloat162(-OFFSET)), make_bfloat162(OFFSET)); return __hfma2(x, make_bfloat162(ONE_OVER_TWOPI), ws); } template -KERNEL_FLOAT_DEVICE __bfloat162 cos(__bfloat162 x) { - __bfloat162 xf = normalize_trig_input(x); +KERNEL_FLOAT_DEVICE bfloat16x2_t cos(bfloat16x2_t x) { + bfloat16x2_t xf = normalize_trig_input(x); return cos_poly<__bfloat16, Iter + 1>::call(__hmul2(xf, xf)); } template -KERNEL_FLOAT_DEVICE __bfloat162 sin(__bfloat162 x) { - __bfloat162 xf = normalize_trig_input(x); +KERNEL_FLOAT_DEVICE bfloat16x2_t sin(bfloat16x2_t x) { + bfloat16x2_t xf = normalize_trig_input(x); return __hmul2(sin_poly<__bfloat16, Iter>::call(__hmul2(xf, xf)), xf); } template -KERNEL_FLOAT_DEVICE __bfloat162 rcp(__bfloat162 x) { - __bfloat162 y = transmute<__bfloat162>(uint32_t(0x7ef07ef0) + ~transmute(x)); +KERNEL_FLOAT_DEVICE bfloat16x2_t rcp(bfloat16x2_t x) { + bfloat16x2_t y = transmute(uint32_t(0x7ef07ef0) + ~transmute(x)); #pragma unroll for (int i = 0; i < Iter; i++) { @@ -4756,18 +4762,18 @@ KERNEL_FLOAT_DEVICE __bfloat162 rcp(__bfloat162 x) { } template -KERNEL_FLOAT_DEVICE __bfloat162 rsqrt(__bfloat162 x) { +KERNEL_FLOAT_DEVICE bfloat16x2_t rsqrt(bfloat16x2_t x) { // Set top and bottom bits for both halfs, then shift by 1, then invert uint32_t r = ~((uint32_t(transmute(x) >> 1)) | ~uint32_t(0x3fff3fff)); // Add bias (0x1f36) - __bfloat162 y = transmute<__bfloat162>(uint32_t(r) + uint32_t(0x1f361f36)); + bfloat16x2_t y = transmute(uint32_t(r) + uint32_t(0x1f361f36)); // Newton-Raphson iterations #pragma unroll for (int i = 0; i < Iter; i++) { - __bfloat162 half_x = __hmul2(make_bfloat162(-0.5), x); - __bfloat162 correction = __hfma2(half_x, __hmul2(y, y), make_bfloat162(0.5)); + bfloat16x2_t half_x = __hmul2(make_bfloat162(-0.5), x); + bfloat16x2_t correction = __hfma2(half_x, __hmul2(y, y), make_bfloat162(0.5)); y = __hfma2(correction, y, y); // y += y * correction } @@ -4775,17 +4781,17 @@ KERNEL_FLOAT_DEVICE __bfloat162 rsqrt(__bfloat162 x) { } template -KERNEL_FLOAT_DEVICE __bfloat162 sqrt(__bfloat162 x) { +KERNEL_FLOAT_DEVICE bfloat16x2_t sqrt(bfloat16x2_t x) { return __hmul2(x, rsqrt(x)); } template -KERNEL_FLOAT_DEVICE __bfloat162 exp(__bfloat162 arg) { +KERNEL_FLOAT_DEVICE bfloat16x2_t exp(bfloat16x2_t arg) { static constexpr float SCALE = 1.44272065994f / 256.0f; static constexpr float OFFSET = 382.4958400542335; - auto a = fmaf(__bfloat162float(arg.x), SCALE, OFFSET); - auto b = fmaf(__bfloat162float(arg.y), SCALE, OFFSET); + auto a = fmaf(bfloat16x2_tfloat(arg.x), SCALE, OFFSET); + auto b = fmaf(bfloat16x2_tfloat(arg.y), SCALE, OFFSET); return { transmute<__bfloat16>(uint16_t(transmute(a))), @@ -4797,17 +4803,17 @@ KERNEL_FLOAT_DEVICE __bfloat162 exp(__bfloat162 arg) { #define KERNEL_FLOAT_DEFINE_APPROX_FUN(FULL_NAME, FUN, DEG) \ namespace detail { \ template \ - struct apply_impl, ops::FUN<__half>, 2, __half, __half> { \ + struct apply_impl, ops::FUN, 2, half_t, half_t> { \ KERNEL_FLOAT_INLINE static void \ - call(ops::FUN<__half> fun, __half* output, const __half* input) { \ - __half2 res = approx::FUN(__half2 {input[0], input[1]}); \ + call(ops::FUN fun, half_t* output, const half_t* input) { \ + half2_t res = approx::FUN(half2_t {input[0], input[1]}); \ output[0] = res.x; \ output[1] = res.y; \ } \ }; \ template<> \ - struct apply_impl, 2, __half, __half>: \ - apply_impl, ops::FUN<__half>, 2, __half, __half> {}; \ + struct apply_impl, 2, half_t, half_t>: \ + apply_impl, ops::FUN, 2, half_t, half_t> {}; \ } \ \ template \ From 014e32f040ce08295a9ab0fc4169e6fe68e0c1b7 Mon Sep 17 00:00:00 2001 From: stijn Date: Mon, 18 Nov 2024 13:50:33 +0100 Subject: [PATCH 06/25] Implement approximation for `pow` --- include/kernel_float/binops.h | 25 ++++++++++++++++++++++--- single_include/kernel_float.h | 29 ++++++++++++++++++++++++----- 2 files changed, 46 insertions(+), 8 deletions(-) diff --git a/include/kernel_float/binops.h b/include/kernel_float/binops.h index 7c9ec2d..0c3d04a 100644 --- a/include/kernel_float/binops.h +++ b/include/kernel_float/binops.h @@ -292,8 +292,7 @@ struct multiply { namespace detail { template struct apply_impl, N, T, T, T> { - KERNEL_FLOAT_INLINE static void - call(ops::divide fun, T* result, const T* lhs, const T* rhs) { + KERNEL_FLOAT_INLINE static void call(ops::divide, T* result, const T* lhs, const T* rhs) { T rhs_rcp[N]; // Fast way to perform division is to multiply by the reciprocal @@ -310,13 +309,33 @@ struct apply_impl, N, T, T, T>: template<> struct apply_impl, 1, float, float, float> { KERNEL_FLOAT_INLINE static void - call(ops::divide fun, float* result, const float* lhs, const float* rhs) { + call(ops::divide, float* result, const float* lhs, const float* rhs) { *result = __fdividef(*lhs, *rhs); } }; #endif } // namespace detail +namespace detail { +// Override `pow` using `log2` and `exp2` +template +struct apply_impl, N, T, T, T> { + KERNEL_FLOAT_INLINE static void call(ops::divide, T* result, const T* lhs, const T* rhs) { + T lhs_log[N]; + T result_log[N]; + + // Fast way to perform power function is using log2 and exp2 + apply_impl, N, T, T>::call({}, lhs_log, lhs); + apply_impl, N, T, T, T>::call({}, result_log, lhs_log, rhs); + apply_impl, N, T, T, T>::call({}, result, result_log); + } +}; + +template +struct apply_impl, N, T, T, T>: + apply_base_impl, N, T, T, T> {}; +} // namespace detail + template> KERNEL_FLOAT_INLINE zip_common_type, T, T> fast_divide(const L& left, const R& right) { diff --git a/single_include/kernel_float.h b/single_include/kernel_float.h index 1b9126d..31fcb0f 100644 --- a/single_include/kernel_float.h +++ b/single_include/kernel_float.h @@ -16,8 +16,8 @@ //================================================================================ // this file has been auto-generated, do not modify its contents! -// date: 2024-11-18 13:40:03.668017 -// git hash: ae0e6b16ac2d626e69bb08554044a77671f408ab +// date: 2024-11-18 13:50:24.614671 +// git hash: f89cf98f79e78ab6013063dea4b4b516ce163855 //================================================================================ #ifndef KERNEL_FLOAT_MACROS_H @@ -1950,8 +1950,7 @@ struct multiply { namespace detail { template struct apply_impl, N, T, T, T> { - KERNEL_FLOAT_INLINE static void - call(ops::divide fun, T* result, const T* lhs, const T* rhs) { + KERNEL_FLOAT_INLINE static void call(ops::divide, T* result, const T* lhs, const T* rhs) { T rhs_rcp[N]; // Fast way to perform division is to multiply by the reciprocal @@ -1968,13 +1967,33 @@ struct apply_impl, N, T, T, T>: template<> struct apply_impl, 1, float, float, float> { KERNEL_FLOAT_INLINE static void - call(ops::divide fun, float* result, const float* lhs, const float* rhs) { + call(ops::divide, float* result, const float* lhs, const float* rhs) { *result = __fdividef(*lhs, *rhs); } }; #endif } // namespace detail +namespace detail { +// Override `pow` using `log2` and `exp2` +template +struct apply_impl, N, T, T, T> { + KERNEL_FLOAT_INLINE static void call(ops::divide, T* result, const T* lhs, const T* rhs) { + T lhs_log[N]; + T result_log[N]; + + // Fast way to perform power function is using log2 and exp2 + apply_impl, N, T, T>::call({}, lhs_log, lhs); + apply_impl, N, T, T, T>::call({}, result_log, lhs_log, rhs); + apply_impl, N, T, T, T>::call({}, result, result_log); + } +}; + +template +struct apply_impl, N, T, T, T>: + apply_base_impl, N, T, T, T> {}; +} // namespace detail + template> KERNEL_FLOAT_INLINE zip_common_type, T, T> fast_divide(const L& left, const R& right) { From 003ce3677ecb97dc1602e38a3e774c103d05aa1a Mon Sep 17 00:00:00 2001 From: stijn Date: Mon, 18 Nov 2024 16:57:53 +0100 Subject: [PATCH 07/25] Add `apply_fallback_impl` struct --- include/kernel_float/apply.h | 38 +++++++++++++++++++++++++++-------- include/kernel_float/binops.h | 12 ++--------- include/kernel_float/fp16.h | 5 +---- include/kernel_float/triops.h | 8 ++++---- 4 files changed, 37 insertions(+), 26 deletions(-) diff --git a/include/kernel_float/apply.h b/include/kernel_float/apply.h index 7d9a96e..4e2bf8b 100644 --- a/include/kernel_float/apply.h +++ b/include/kernel_float/apply.h @@ -157,31 +157,53 @@ using default_policy = KERNEL_FLOAT_POLICY; namespace detail { +// template -struct apply_base_impl { +struct apply_fallback_impl { KERNEL_FLOAT_INLINE static void call(F fun, Output* output, const Args*... args) { -#pragma unroll - for (size_t i = 0; i < N; i++) { - output[i] = fun(args[i]...); - } + static_assert(N > 0, "operation not implemented"); } }; +template +struct apply_base_impl: apply_fallback_impl {}; + template struct apply_impl: apply_base_impl {}; +// `fast_policy` falls back to `accurate_policy` template -struct apply_base_impl: +struct apply_fallback_impl: apply_impl {}; +// `approx_policy` falls back to `fast_policy` template -struct apply_base_impl: +struct apply_fallback_impl: apply_impl {}; +// `approx_level_policy` falls back to `approx_policy` template -struct apply_base_impl, F, N, Output, Args...>: +struct apply_fallback_impl, F, N, Output, Args...>: apply_impl {}; +template +struct invoke_impl { + KERNEL_FLOAT_INLINE static Output call(F fun, Args... args) { + return fun(args...); + } +}; + +// Only for `accurate_policy` do we implement `apply_impl`, the others will fall back to `apply_base_impl`. +template +struct apply_impl { + KERNEL_FLOAT_INLINE static void call(F fun, Output* output, const Args*... args) { +#pragma unroll + for (size_t i = 0; i < N; i++) { + output[i] = invoke_impl::call(fun, args[i]...); + } + } +}; + template struct map_impl { static constexpr size_t packet_size = preferred_vector_size::value; diff --git a/include/kernel_float/binops.h b/include/kernel_float/binops.h index 0c3d04a..1bc66b6 100644 --- a/include/kernel_float/binops.h +++ b/include/kernel_float/binops.h @@ -291,7 +291,7 @@ struct multiply { namespace detail { template -struct apply_impl, N, T, T, T> { +struct apply_base_impl, N, T, T, T> { KERNEL_FLOAT_INLINE static void call(ops::divide, T* result, const T* lhs, const T* rhs) { T rhs_rcp[N]; @@ -301,10 +301,6 @@ struct apply_impl, N, T, T, T> { } }; -template -struct apply_impl, N, T, T, T>: - apply_base_impl, N, T, T, T> {}; - #if KERNEL_FLOAT_IS_DEVICE template<> struct apply_impl, 1, float, float, float> { @@ -319,7 +315,7 @@ struct apply_impl, 1, float, float, float> { namespace detail { // Override `pow` using `log2` and `exp2` template -struct apply_impl, N, T, T, T> { +struct apply_base_impl, N, T, T, T> { KERNEL_FLOAT_INLINE static void call(ops::divide, T* result, const T* lhs, const T* rhs) { T lhs_log[N]; T result_log[N]; @@ -330,10 +326,6 @@ struct apply_impl, N, T, T, T> { apply_impl, N, T, T, T>::call({}, result, result_log); } }; - -template -struct apply_impl, N, T, T, T>: - apply_base_impl, N, T, T, T> {}; } // namespace detail template> diff --git a/include/kernel_float/fp16.h b/include/kernel_float/fp16.h index 4c185ff..67b97fe 100644 --- a/include/kernel_float/fp16.h +++ b/include/kernel_float/fp16.h @@ -22,9 +22,6 @@ namespace kernel_float { using half_t = ::__half; using half2_t = ::__half2; -using __half = void; -using __half2 = void; - template<> struct preferred_vector_size { static constexpr size_t value = 2; @@ -50,7 +47,7 @@ template<> struct allow_float_fallback { static constexpr bool value = true; }; -}; // namespace detail +} // namespace detail #if KERNEL_FLOAT_IS_DEVICE #define KERNEL_FLOAT_FP16_UNARY_FUN(NAME, FUN1, FUN2) \ diff --git a/include/kernel_float/triops.h b/include/kernel_float/triops.h index 82d3a89..8e27e70 100644 --- a/include/kernel_float/triops.h +++ b/include/kernel_float/triops.h @@ -98,13 +98,13 @@ struct fma { } // namespace ops namespace detail { -template -struct apply_impl, N, T, T, T, T> { +template +struct apply_impl, N, T, T, T, T> { KERNEL_FLOAT_INLINE static void call(ops::fma, T* output, const T* a, const T* b, const T* c) { T temp[N]; - apply_impl, N, T, T, T>::call({}, temp, a, b); - apply_impl, N, T, T, T>::call({}, output, temp, c); + apply_impl, N, T, T, T>::call({}, temp, a, b); + apply_impl, N, T, T, T>::call({}, output, temp, c); } }; } // namespace detail From 76501fda40df9e396998d11840bc8f10b11ea47b Mon Sep 17 00:00:00 2001 From: stijn Date: Mon, 18 Nov 2024 16:58:20 +0100 Subject: [PATCH 08/25] Add `approx_*` functions --- include/kernel_float/approx.h | 94 ++++++++++++------- single_include/kernel_float.h | 165 +++++++++++++++++++++------------- 2 files changed, 169 insertions(+), 90 deletions(-) diff --git a/include/kernel_float/approx.h b/include/kernel_float/approx.h index df81d30..c1e7836 100644 --- a/include/kernel_float/approx.h +++ b/include/kernel_float/approx.h @@ -9,7 +9,7 @@ namespace kernel_float { namespace approx { -static_assert(sizeof(unsigned int) * 8 == 32, "invalid side of unsigned int"); +static_assert(sizeof(unsigned int) * 8 == 32, "invalid size of unsigned int"); using uint32_t = unsigned int; template @@ -346,11 +346,12 @@ KERNEL_FLOAT_DEVICE bfloat16x2_t sqrt(bfloat16x2_t x) { template KERNEL_FLOAT_DEVICE bfloat16x2_t exp(bfloat16x2_t arg) { - static constexpr float SCALE = 1.44272065994f / 256.0f; + static constexpr float SCALE = 1.44272065994 / 256.0; static constexpr float OFFSET = 382.4958400542335; + static constexpr float MINIMUM = 382; - auto a = fmaf(bfloat16x2_tfloat(arg.x), SCALE, OFFSET); - auto b = fmaf(bfloat16x2_tfloat(arg.y), SCALE, OFFSET); + float a = fmaxf(fmaf(bfloat162float(arg.x), SCALE, OFFSET), MINIMUM); + float b = fmaxf(fmaf(bfloat162float(arg.y), SCALE, OFFSET), MINIMUM); return { transmute<__bfloat16>(uint16_t(transmute(a))), @@ -359,33 +360,66 @@ KERNEL_FLOAT_DEVICE bfloat16x2_t exp(bfloat16x2_t arg) { #endif } // namespace approx -#define KERNEL_FLOAT_DEFINE_APPROX_FUN(FULL_NAME, FUN, DEG) \ - namespace detail { \ - template \ - struct apply_impl, ops::FUN, 2, half_t, half_t> { \ - KERNEL_FLOAT_INLINE static void \ - call(ops::FUN fun, half_t* output, const half_t* input) { \ - half2_t res = approx::FUN(half2_t {input[0], input[1]}); \ - output[0] = res.x; \ - output[1] = res.y; \ - } \ - }; \ - template<> \ - struct apply_impl, 2, half_t, half_t>: \ - apply_impl, ops::FUN, 2, half_t, half_t> {}; \ - } \ - \ - template \ - KERNEL_FLOAT_INLINE into_vector_type approx_##FUN(const V& args) { \ - return map>(ops::FUN> {}, args); \ +namespace detail { +template +struct apply_impl, F, 1, T, T> { + KERNEL_FLOAT_INLINE static void call(F fun, T* output, const T* input) { + T in2[2], out2[2]; + out2[0] = input[0]; + apply_impl, F, 2, T, T>::call(fun, out2, in2); + output[0] = out2[0]; + } +}; +} // namespace detail + +#define KERNEL_FLOAT_DEFINE_APPROX_IMPL(T, FUN, DEFAULT_LEVEL) \ + namespace detail { \ + template \ + struct apply_impl, ops::FUN, 2, T, T> { \ + KERNEL_FLOAT_INLINE static void call(ops::FUN, T* output, const T* input) { \ + auto res = approx::FUN({input[0], input[1]}); \ + output[0] = res.x; \ + output[1] = res.y; \ + } \ + }; \ + \ + template<> \ + struct apply_impl, 2, T, T>: \ + apply_impl, ops::FUN, 2, T, T> {}; \ + } + +#if KERNEL_FLOAT_FP16_AVAILABLE +KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, sin, 4) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, cos, 4) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, rsqrt, 1) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, sqrt, 1) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, rcp, 1) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, exp, 0) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, log, 0) +#endif + +#if KERNEL_FLOAT_BF16_OPS_SUPPORTED +KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, cos, 4) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, sin, 4) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, rcp, 1) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, rsqrt, 1) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, sqrt, 1) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, exp, 0) +//KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, log, 0) +#endif + +#define KERNEL_FLOAT_DEFINE_APPROX_FUN(FUN) \ + template \ + KERNEL_FLOAT_INLINE into_vector_type approx_##FUN(const V& args) { \ + return map>(ops::FUN> {}, args); \ } -KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_sin, sin, 4) -KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_cos, cos, 4) -KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_rsqrt, rsqrt, 1) -KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_sqrt, sqrt, 1) -KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_rcp, rcp, 1) -KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_exp, exp, 0) -KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_log, log, 0) +KERNEL_FLOAT_DEFINE_APPROX_FUN(sin) +KERNEL_FLOAT_DEFINE_APPROX_FUN(cos) +KERNEL_FLOAT_DEFINE_APPROX_FUN(rsqrt) +KERNEL_FLOAT_DEFINE_APPROX_FUN(sqrt) +KERNEL_FLOAT_DEFINE_APPROX_FUN(rcp) +KERNEL_FLOAT_DEFINE_APPROX_FUN(exp) +KERNEL_FLOAT_DEFINE_APPROX_FUN(log) } // namespace kernel_float diff --git a/single_include/kernel_float.h b/single_include/kernel_float.h index 31fcb0f..0e66057 100644 --- a/single_include/kernel_float.h +++ b/single_include/kernel_float.h @@ -16,8 +16,8 @@ //================================================================================ // this file has been auto-generated, do not modify its contents! -// date: 2024-11-18 13:50:24.614671 -// git hash: f89cf98f79e78ab6013063dea4b4b516ce163855 +// date: 2024-11-18 16:57:58.817191 +// git hash: 003ce3677ecb97dc1602e38a3e774c103d05aa1a //================================================================================ #ifndef KERNEL_FLOAT_MACROS_H @@ -824,31 +824,53 @@ using default_policy = KERNEL_FLOAT_POLICY; namespace detail { +// template -struct apply_base_impl { +struct apply_fallback_impl { KERNEL_FLOAT_INLINE static void call(F fun, Output* output, const Args*... args) { -#pragma unroll - for (size_t i = 0; i < N; i++) { - output[i] = fun(args[i]...); - } + static_assert(N > 0, "operation not implemented"); } }; +template +struct apply_base_impl: apply_fallback_impl {}; + template struct apply_impl: apply_base_impl {}; +// `fast_policy` falls back to `accurate_policy` template -struct apply_base_impl: +struct apply_fallback_impl: apply_impl {}; +// `approx_policy` falls back to `fast_policy` template -struct apply_base_impl: +struct apply_fallback_impl: apply_impl {}; +// `approx_level_policy` falls back to `approx_policy` template -struct apply_base_impl, F, N, Output, Args...>: +struct apply_fallback_impl, F, N, Output, Args...>: apply_impl {}; +template +struct invoke_impl { + KERNEL_FLOAT_INLINE static Output call(F fun, Args... args) { + return fun(args...); + } +}; + +// Only for `accurate_policy` do we implement `apply_impl`, the others will fall back to `apply_base_impl`. +template +struct apply_impl { + KERNEL_FLOAT_INLINE static void call(F fun, Output* output, const Args*... args) { +#pragma unroll + for (size_t i = 0; i < N; i++) { + output[i] = invoke_impl::call(fun, args[i]...); + } + } +}; + template struct map_impl { static constexpr size_t packet_size = preferred_vector_size::value; @@ -1949,7 +1971,7 @@ struct multiply { namespace detail { template -struct apply_impl, N, T, T, T> { +struct apply_base_impl, N, T, T, T> { KERNEL_FLOAT_INLINE static void call(ops::divide, T* result, const T* lhs, const T* rhs) { T rhs_rcp[N]; @@ -1959,10 +1981,6 @@ struct apply_impl, N, T, T, T> { } }; -template -struct apply_impl, N, T, T, T>: - apply_base_impl, N, T, T, T> {}; - #if KERNEL_FLOAT_IS_DEVICE template<> struct apply_impl, 1, float, float, float> { @@ -1977,7 +1995,7 @@ struct apply_impl, 1, float, float, float> { namespace detail { // Override `pow` using `log2` and `exp2` template -struct apply_impl, N, T, T, T> { +struct apply_base_impl, N, T, T, T> { KERNEL_FLOAT_INLINE static void call(ops::divide, T* result, const T* lhs, const T* rhs) { T lhs_log[N]; T result_log[N]; @@ -1988,10 +2006,6 @@ struct apply_impl, N, T, T, T> { apply_impl, N, T, T, T>::call({}, result, result_log); } }; - -template -struct apply_impl, N, T, T, T>: - apply_base_impl, N, T, T, T> {}; } // namespace detail template> @@ -3218,13 +3232,13 @@ struct fma { } // namespace ops namespace detail { -template -struct apply_impl, N, T, T, T, T> { +template +struct apply_impl, N, T, T, T, T> { KERNEL_FLOAT_INLINE static void call(ops::fma, T* output, const T* a, const T* b, const T* c) { T temp[N]; - apply_impl, N, T, T, T>::call({}, temp, a, b); - apply_impl, N, T, T, T>::call({}, output, temp, c); + apply_impl, N, T, T, T>::call({}, temp, a, b); + apply_impl, N, T, T, T>::call({}, output, temp, c); } }; } // namespace detail @@ -3992,9 +4006,6 @@ namespace kernel_float { using half_t = ::__half; using half2_t = ::__half2; -using __half = void; -using __half2 = void; - template<> struct preferred_vector_size { static constexpr size_t value = 2; @@ -4020,7 +4031,7 @@ template<> struct allow_float_fallback { static constexpr bool value = true; }; -}; // namespace detail +} // namespace detail #if KERNEL_FLOAT_IS_DEVICE #define KERNEL_FLOAT_FP16_UNARY_FUN(NAME, FUN1, FUN2) \ @@ -4469,7 +4480,7 @@ namespace kernel_float { namespace approx { -static_assert(sizeof(unsigned int) * 8 == 32, "invalid side of unsigned int"); +static_assert(sizeof(unsigned int) * 8 == 32, "invalid size of unsigned int"); using uint32_t = unsigned int; template @@ -4806,11 +4817,12 @@ KERNEL_FLOAT_DEVICE bfloat16x2_t sqrt(bfloat16x2_t x) { template KERNEL_FLOAT_DEVICE bfloat16x2_t exp(bfloat16x2_t arg) { - static constexpr float SCALE = 1.44272065994f / 256.0f; + static constexpr float SCALE = 1.44272065994 / 256.0; static constexpr float OFFSET = 382.4958400542335; + static constexpr float MINIMUM = 382; - auto a = fmaf(bfloat16x2_tfloat(arg.x), SCALE, OFFSET); - auto b = fmaf(bfloat16x2_tfloat(arg.y), SCALE, OFFSET); + float a = fmaxf(fmaf(bfloat162float(arg.x), SCALE, OFFSET), MINIMUM); + float b = fmaxf(fmaf(bfloat162float(arg.y), SCALE, OFFSET), MINIMUM); return { transmute<__bfloat16>(uint16_t(transmute(a))), @@ -4819,34 +4831,67 @@ KERNEL_FLOAT_DEVICE bfloat16x2_t exp(bfloat16x2_t arg) { #endif } // namespace approx -#define KERNEL_FLOAT_DEFINE_APPROX_FUN(FULL_NAME, FUN, DEG) \ - namespace detail { \ - template \ - struct apply_impl, ops::FUN, 2, half_t, half_t> { \ - KERNEL_FLOAT_INLINE static void \ - call(ops::FUN fun, half_t* output, const half_t* input) { \ - half2_t res = approx::FUN(half2_t {input[0], input[1]}); \ - output[0] = res.x; \ - output[1] = res.y; \ - } \ - }; \ - template<> \ - struct apply_impl, 2, half_t, half_t>: \ - apply_impl, ops::FUN, 2, half_t, half_t> {}; \ - } \ - \ - template \ - KERNEL_FLOAT_INLINE into_vector_type approx_##FUN(const V& args) { \ - return map>(ops::FUN> {}, args); \ - } - -KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_sin, sin, 4) -KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_cos, cos, 4) -KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_rsqrt, rsqrt, 1) -KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_sqrt, sqrt, 1) -KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_rcp, rcp, 1) -KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_exp, exp, 0) -KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_log, log, 0) +namespace detail { +template +struct apply_impl, F, 1, T, T> { + KERNEL_FLOAT_INLINE static void call(F fun, T* output, const T* input) { + T in2[2], out2[2]; + out2[0] = input[0]; + apply_impl, F, 2, T, T>::call(fun, out2, in2); + output[0] = out2[0]; + } +}; +} // namespace detail + +#define KERNEL_FLOAT_DEFINE_APPROX_IMPL(T, FUN, DEFAULT_LEVEL) \ + namespace detail { \ + template \ + struct apply_impl, ops::FUN, 2, T, T> { \ + KERNEL_FLOAT_INLINE static void call(ops::FUN, T* output, const T* input) { \ + auto res = approx::FUN({input[0], input[1]}); \ + output[0] = res.x; \ + output[1] = res.y; \ + } \ + }; \ + \ + template<> \ + struct apply_impl, 2, T, T>: \ + apply_impl, ops::FUN, 2, T, T> {}; \ + } + +#if KERNEL_FLOAT_FP16_AVAILABLE +KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, sin, 4) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, cos, 4) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, rsqrt, 1) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, sqrt, 1) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, rcp, 1) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, exp, 0) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, log, 0) +#endif + +#if KERNEL_FLOAT_BF16_OPS_SUPPORTED +KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, cos, 4) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, sin, 4) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, rcp, 1) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, rsqrt, 1) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, sqrt, 1) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, exp, 0) +//KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, log, 0) +#endif + +#define KERNEL_FLOAT_DEFINE_APPROX_FUN(FUN) \ + template \ + KERNEL_FLOAT_INLINE into_vector_type approx_##FUN(const V& args) { \ + return map>(ops::FUN> {}, args); \ + } + +KERNEL_FLOAT_DEFINE_APPROX_FUN(sin) +KERNEL_FLOAT_DEFINE_APPROX_FUN(cos) +KERNEL_FLOAT_DEFINE_APPROX_FUN(rsqrt) +KERNEL_FLOAT_DEFINE_APPROX_FUN(sqrt) +KERNEL_FLOAT_DEFINE_APPROX_FUN(rcp) +KERNEL_FLOAT_DEFINE_APPROX_FUN(exp) +KERNEL_FLOAT_DEFINE_APPROX_FUN(log) } // namespace kernel_float #ifndef KERNEL_FLOAT_FP8_H From ba7356a8805fd99618c615ae19b01a3fad4be705 Mon Sep 17 00:00:00 2001 From: stijn Date: Wed, 20 Nov 2024 10:37:16 +0100 Subject: [PATCH 09/25] Fix incorrect definition of `KERNEL_FLOAT_POLICY` --- include/kernel_float/apply.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/kernel_float/apply.h b/include/kernel_float/apply.h index 4e2bf8b..d558f90 100644 --- a/include/kernel_float/apply.h +++ b/include/kernel_float/apply.h @@ -146,7 +146,7 @@ struct approx_level_policy {}; using approx_policy = approx_level_policy<>; #ifndef KERNEL_FLOAT_POLICY -#define KERNEL_FLOAT_POLICY accurate_policy; +#define KERNEL_FLOAT_POLICY accurate_policy #endif /** From 4231f44a83927cb1c2f14c2fab0ab30064e064cb Mon Sep 17 00:00:00 2001 From: stijn Date: Wed, 20 Nov 2024 10:37:52 +0100 Subject: [PATCH 10/25] Add `Accuracy` parameter to `zip_common` --- include/kernel_float/binops.h | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/include/kernel_float/binops.h b/include/kernel_float/binops.h index 1bc66b6..75de26a 100644 --- a/include/kernel_float/binops.h +++ b/include/kernel_float/binops.h @@ -44,7 +44,7 @@ using zip_common_type = vector< * vec c = zip_common([](float x, float y){ return x + y; }, a, b); // returns [5.0f, 7.0f, 9.0f] * ``` */ -template +template KERNEL_FLOAT_INLINE zip_common_type zip_common(F fun, const L& left, const R& right) { using T = promoted_vector_value_type; using O = result_t; @@ -52,7 +52,7 @@ KERNEL_FLOAT_INLINE zip_common_type zip_common(F fun, const L& left, co vector_storage> result; - detail::default_map_impl, O, T, T>::call( + detail::map_impl, O, T, T>::call( fun, result.data(), detail::convert_impl, vector_extent_type, T, E>::call( @@ -65,10 +65,17 @@ KERNEL_FLOAT_INLINE zip_common_type zip_common(F fun, const L& left, co return result; } -#define KERNEL_FLOAT_DEFINE_BINARY_FUN(NAME) \ - template> \ - KERNEL_FLOAT_INLINE zip_common_type, L, R> NAME(L&& left, R&& right) { \ - return zip_common(ops::NAME {}, static_cast(left), static_cast(right)); \ +#define KERNEL_FLOAT_DEFINE_BINARY_FUN(NAME) \ + template< \ + typename Accuracy = default_policy, \ + typename L, \ + typename R, \ + typename C = promoted_vector_value_type> \ + KERNEL_FLOAT_INLINE zip_common_type, L, R> NAME(L&& left, R&& right) { \ + return zip_common( \ + ops::NAME {}, \ + static_cast(left), \ + static_cast(right)); \ } #define KERNEL_FLOAT_DEFINE_BINARY(NAME, EXPR, EXPR_F64, EXPR_F32) \ From f5edbc8a0a4d07e7c6d1279ec4e0969a20beb6da Mon Sep 17 00:00:00 2001 From: stijn Date: Wed, 20 Nov 2024 10:38:52 +0100 Subject: [PATCH 11/25] Add `add_mul` to `vector` --- include/kernel_float/prelude.h | 10 +++++----- include/kernel_float/vector.h | 11 ++++++++++- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/include/kernel_float/prelude.h b/include/kernel_float/prelude.h index b16f054..a723820 100644 --- a/include/kernel_float/prelude.h +++ b/include/kernel_float/prelude.h @@ -61,14 +61,14 @@ KERNEL_FLOAT_TYPE_ALIAS(f64x, double) KERNEL_FLOAT_TYPE_ALIAS(float64x, double) #if KERNEL_FLOAT_FP16_AVAILABLE -KERNEL_FLOAT_TYPE_ALIAS(half, __half) -KERNEL_FLOAT_TYPE_ALIAS(f16x, __half) -KERNEL_FLOAT_TYPE_ALIAS(float16x, __half) +KERNEL_FLOAT_TYPE_ALIAS(half, half_t) +KERNEL_FLOAT_TYPE_ALIAS(f16x, half_t) +KERNEL_FLOAT_TYPE_ALIAS(float16x, half_t) #endif #if KERNEL_FLOAT_BF16_AVAILABLE -KERNEL_FLOAT_TYPE_ALIAS(bfloat16x, __bfloat16) -KERNEL_FLOAT_TYPE_ALIAS(bf16x, __bfloat16) +KERNEL_FLOAT_TYPE_ALIAS(bfloat16x, bfloat16_t) +KERNEL_FLOAT_TYPE_ALIAS(bf16x, bfloat16_t) #endif #if KERNEL_FLOAT_BF8_AVAILABLE diff --git a/include/kernel_float/vector.h b/include/kernel_float/vector.h index b7656f7..222db2e 100644 --- a/include/kernel_float/vector.h +++ b/include/kernel_float/vector.h @@ -287,11 +287,20 @@ struct vector: public S { } /** - * Returns the result of `*this + lhs * rhs`. + * Returns the result of `this + lhs * rhs`. * * The operation is performed using a single `kernel_float::fma` call, which may be faster then perform * the addition and multiplication separately. */ + template< + typename L, + typename R, + typename T2 = promote_t, vector_value_type>, + typename E2 = broadcast_extent, vector_extent_type>> + KERNEL_FLOAT_INLINE vector add_mul(const L& lhs, const R& rhs) const { + return ::kernel_float::fma(lhs, rhs, *this); + } + template< typename L, typename R, From e6c8a7c6c572fd1bf42419b15986ce6a4e23b819 Mon Sep 17 00:00:00 2001 From: stijn Date: Wed, 20 Nov 2024 10:39:37 +0100 Subject: [PATCH 12/25] Overwrite `fast_policy` for FP16 and BF16 --- include/kernel_float/bf16.h | 16 ++++++++++++++++ include/kernel_float/fp16.h | 16 ++++++++++++++++ include/kernel_float/fp8.h | 24 ++++++++++++------------ include/kernel_float/unops.h | 3 +++ 4 files changed, 47 insertions(+), 12 deletions(-) diff --git a/include/kernel_float/bf16.h b/include/kernel_float/bf16.h index 251a12a..22ea8b8 100644 --- a/include/kernel_float/bf16.h +++ b/include/kernel_float/bf16.h @@ -181,6 +181,22 @@ struct apply_impl< result[0] = r.x, result[1] = r.y; } }; + +// clang-format off +#define KERNEL_FLOAT_FAST_BF16_DISPATCH(OP) \ + template \ + struct apply_impl, N, bfloat16_t, bfloat16_t> { \ + KERNEL_FLOAT_INLINE static void \ + call(ops::OP, bfloat16_t* output, const bfloat16_t* input) { \ + float v[N]; \ + map_impl, N, float, bfloat16_t>::call({}, v, input); \ + map_impl, N, float, float>::call({}, v, v); \ + map_impl, N, bfloat16_t, float>::call({}, output, v); \ + } \ + }; +// clang-format on + +KERNEL_FLOAT_FAST_F32_MAP(KERNEL_FLOAT_FAST_BF16_DISPATCH) } // namespace detail #endif diff --git a/include/kernel_float/fp16.h b/include/kernel_float/fp16.h index 67b97fe..cbc3383 100644 --- a/include/kernel_float/fp16.h +++ b/include/kernel_float/fp16.h @@ -154,6 +154,22 @@ struct apply_impl, 2, half_t, half_t, half_t, result[0] = r.x, result[1] = r.y; } }; + +// clang-format off +#define KERNEL_FLOAT_FAST_FP16_DISPATCH(OP) \ + template \ + struct apply_impl, N, half_t, half_t> { \ + KERNEL_FLOAT_INLINE static void \ + call(ops::OP, half_t* output, const half_t* input) { \ + float v[N]; \ + map_impl, N, float, half_t>::call({}, v, input); \ + map_impl, N, float, float>::call({}, v, v); \ + map_impl, N, half_t, float>::call({}, output, v); \ + } \ + }; +// clang-format on + +KERNEL_FLOAT_FAST_F32_MAP(KERNEL_FLOAT_FAST_FP16_DISPATCH) } // namespace detail #endif diff --git a/include/kernel_float/fp8.h b/include/kernel_float/fp8.h index 49c55d4..c6ee2e9 100644 --- a/include/kernel_float/fp8.h +++ b/include/kernel_float/fp8.h @@ -64,7 +64,7 @@ struct allow_float_fallback<__nv_fp8_e5m2> { #define KERNEL_FLOAT_FP8_CAST2(T, FP8_TY, FP8_INTERP) \ namespace detail { \ template<> \ - struct apply_impl, 2, FP8_TY, T> { \ + struct apply_impl, 2, FP8_TY, T> { \ KERNEL_FLOAT_INLINE static void call(ops::cast, FP8_TY* result, const T* v) { \ __half2_raw x; \ memcpy(&x, v, 2 * sizeof(T)); \ @@ -73,7 +73,7 @@ struct allow_float_fallback<__nv_fp8_e5m2> { } \ }; \ template<> \ - struct apply_impl, 2, T, FP8_TY> { \ + struct apply_impl, 2, T, FP8_TY> { \ KERNEL_FLOAT_INLINE static void call(ops::cast, T* result, const FP8_TY* v) { \ __nv_fp8x2_storage_t x; \ memcpy(&x, v, 2 * sizeof(FP8_TY)); \ @@ -91,12 +91,12 @@ KERNEL_FLOAT_FP8_CAST(double) #include "fp16.h" namespace kernel_float { -KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__half, __nv_fp8_e4m3) -KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__half, __nv_fp8_e5m2) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(half_t, __nv_fp8_e4m3) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(half_t, __nv_fp8_e5m2) -KERNEL_FLOAT_FP8_CAST(__half) -KERNEL_FLOAT_FP8_CAST2(__half, __nv_fp8_e4m3, __NV_E4M3) -KERNEL_FLOAT_FP8_CAST2(__half, __nv_fp8_e5m2, __NV_E5M2) +KERNEL_FLOAT_FP8_CAST(half_t) +KERNEL_FLOAT_FP8_CAST2(half_t, __nv_fp8_e4m3, __NV_E4M3) +KERNEL_FLOAT_FP8_CAST2(half_t, __nv_fp8_e5m2, __NV_E5M2) } // namespace kernel_float #endif // KERNEL_FLOAT_FP16_AVAILABLE @@ -105,12 +105,12 @@ KERNEL_FLOAT_FP8_CAST2(__half, __nv_fp8_e5m2, __NV_E5M2) #include "bf16.h" namespace kernel_float { -KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__nv_bfloat16, __nv_fp8_e4m3) -KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__nv_bfloat16, __nv_fp8_e5m2) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(bfloat16_t, __nv_fp8_e4m3) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(bfloat16_t, __nv_fp8_e5m2) -KERNEL_FLOAT_FP8_CAST(__nv_bfloat16) -KERNEL_FLOAT_FP8_CAST2(__nv_bfloat16, __nv_fp8_e4m3, __NV_E4M3) -KERNEL_FLOAT_FP8_CAST2(__nv_bfloat16, __nv_fp8_e5m2, __NV_E5M2) +KERNEL_FLOAT_FP8_CAST(bfloat16_t) +KERNEL_FLOAT_FP8_CAST2(bfloat16_t, __nv_fp8_e4m3, __NV_E4M3) +KERNEL_FLOAT_FP8_CAST2(bfloat16_t, __nv_fp8_e5m2, __NV_E5M2) } // namespace kernel_float #endif // KERNEL_FLOAT_BF16_AVAILABLE diff --git a/include/kernel_float/unops.h b/include/kernel_float/unops.h index d288a9c..6955e6c 100644 --- a/include/kernel_float/unops.h +++ b/include/kernel_float/unops.h @@ -263,6 +263,9 @@ KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, rcp, "rcp.approx.f32", "f") KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, rsqrt, "rsqrt.approx.f32", "f") KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, tanh, "tanh.approx.f32;", "f") +#define KERNEL_FLOAT_FAST_F32_MAP(F) \ + F(exp) F(exp2) F(exp10) F(log) F(log2) F(log10) F(sin) F(cos) F(tan) F(rcp) F(rsqrt) F(sqrt) + //KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, sin, "sin.approx.f32", "f") //KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, cos, "cos.approx.f32", "f") //KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, exp2, "ex2.approx.f32", "f") From 5c859b9763559ef1fa46d82711a46510bdddbb8e Mon Sep 17 00:00:00 2001 From: stijn Date: Wed, 20 Nov 2024 10:40:46 +0100 Subject: [PATCH 13/25] `kernel_float::approx::sqrt(0)` now returns 0 --- include/kernel_float/approx.h | 15 +++-- single_include/kernel_float.h | 120 +++++++++++++++++++++++++--------- 2 files changed, 98 insertions(+), 37 deletions(-) diff --git a/include/kernel_float/approx.h b/include/kernel_float/approx.h index c1e7836..db54ab2 100644 --- a/include/kernel_float/approx.h +++ b/include/kernel_float/approx.h @@ -160,17 +160,20 @@ KERNEL_FLOAT_DEVICE half2_t rcp(half2_t x) { template KERNEL_FLOAT_DEVICE half2_t rsqrt(half2_t x) { + // A small number added such that rsqrt(0) does not return NaN + static constexpr double EPS = 0.00000768899917602539; + // Set top and bottom bits for both halfs, then shift by 1, then invert uint32_t r = ~((uint32_t(transmute(x) >> 1)) | ~uint32_t(0x3fff3fff)); - //uint32_t r = uint32_t(~(transmute(arg) | (~uint32_t(0x3ffe3ffe)))) >> 1; - // Add bias (0x199c) - half2_t y = transmute(uint32_t(r) + uint32_t(0x199c199c)); + // Add bias + static constexpr uint32_t BIAS = 0x199c199c; + half2_t y = transmute(uint32_t(r) + BIAS); // Newton-Raphson iterations #pragma unroll for (int i = 0; i < Iter; i++) { - half2_t half_x = make_half2(-0.5) * x; + half2_t half_x = __hfma2(make_half2(-0.5), x, make_half2(-EPS)); half2_t correction = __hfma2(half_x, y * y, make_half2(0.5)); y = __hfma2(correction, y, y); // y += y * correction } @@ -365,7 +368,7 @@ template struct apply_impl, F, 1, T, T> { KERNEL_FLOAT_INLINE static void call(F fun, T* output, const T* input) { T in2[2], out2[2]; - out2[0] = input[0]; + in2[0] = input[0]; apply_impl, F, 2, T, T>::call(fun, out2, in2); output[0] = out2[0]; } @@ -396,6 +399,8 @@ KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, sqrt, 1) KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, rcp, 1) KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, exp, 0) KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, log, 0) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, asin, 2) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, acos, 2) #endif #if KERNEL_FLOAT_BF16_OPS_SUPPORTED diff --git a/single_include/kernel_float.h b/single_include/kernel_float.h index 0e66057..5ad772b 100644 --- a/single_include/kernel_float.h +++ b/single_include/kernel_float.h @@ -16,8 +16,8 @@ //================================================================================ // this file has been auto-generated, do not modify its contents! -// date: 2024-11-18 16:57:58.817191 -// git hash: 003ce3677ecb97dc1602e38a3e774c103d05aa1a +// date: 2024-11-20 10:36:45.284577 +// git hash: 76501fda40df9e396998d11840bc8f10b11ea47b //================================================================================ #ifndef KERNEL_FLOAT_MACROS_H @@ -813,7 +813,7 @@ struct approx_level_policy {}; using approx_policy = approx_level_policy<>; #ifndef KERNEL_FLOAT_POLICY -#define KERNEL_FLOAT_POLICY accurate_policy; +#define KERNEL_FLOAT_POLICY accurate_policy #endif /** @@ -1448,6 +1448,9 @@ KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, rcp, "rcp.approx.f32", "f") KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, rsqrt, "rsqrt.approx.f32", "f") KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, tanh, "tanh.approx.f32;", "f") +#define KERNEL_FLOAT_FAST_F32_MAP(F) \ + F(exp) F(exp2) F(exp10) F(log) F(log2) F(log10) F(sin) F(cos) F(tan) F(rcp) F(rsqrt) F(sqrt) + //KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, sin, "sin.approx.f32", "f") //KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, cos, "cos.approx.f32", "f") //KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, exp2, "ex2.approx.f32", "f") @@ -1724,7 +1727,7 @@ using zip_common_type = vector< * vec c = zip_common([](float x, float y){ return x + y; }, a, b); // returns [5.0f, 7.0f, 9.0f] * ``` */ -template +template KERNEL_FLOAT_INLINE zip_common_type zip_common(F fun, const L& left, const R& right) { using T = promoted_vector_value_type; using O = result_t; @@ -1732,7 +1735,7 @@ KERNEL_FLOAT_INLINE zip_common_type zip_common(F fun, const L& left, co vector_storage> result; - detail::default_map_impl, O, T, T>::call( + detail::map_impl, O, T, T>::call( fun, result.data(), detail::convert_impl, vector_extent_type, T, E>::call( @@ -1745,10 +1748,17 @@ KERNEL_FLOAT_INLINE zip_common_type zip_common(F fun, const L& left, co return result; } -#define KERNEL_FLOAT_DEFINE_BINARY_FUN(NAME) \ - template> \ - KERNEL_FLOAT_INLINE zip_common_type, L, R> NAME(L&& left, R&& right) { \ - return zip_common(ops::NAME {}, static_cast(left), static_cast(right)); \ +#define KERNEL_FLOAT_DEFINE_BINARY_FUN(NAME) \ + template< \ + typename Accuracy = default_policy, \ + typename L, \ + typename R, \ + typename C = promoted_vector_value_type> \ + KERNEL_FLOAT_INLINE zip_common_type, L, R> NAME(L&& left, R&& right) { \ + return zip_common( \ + ops::NAME {}, \ + static_cast(left), \ + static_cast(right)); \ } #define KERNEL_FLOAT_DEFINE_BINARY(NAME, EXPR, EXPR_F64, EXPR_F32) \ @@ -3887,11 +3897,20 @@ struct vector: public S { } /** - * Returns the result of `*this + lhs * rhs`. + * Returns the result of `this + lhs * rhs`. * * The operation is performed using a single `kernel_float::fma` call, which may be faster then perform * the addition and multiplication separately. */ + template< + typename L, + typename R, + typename T2 = promote_t, vector_value_type>, + typename E2 = broadcast_extent, vector_extent_type>> + KERNEL_FLOAT_INLINE vector add_mul(const L& lhs, const R& rhs) const { + return ::kernel_float::fma(lhs, rhs, *this); + } + template< typename L, typename R, @@ -4138,6 +4157,22 @@ struct apply_impl, 2, half_t, half_t, half_t, result[0] = r.x, result[1] = r.y; } }; + +// clang-format off +#define KERNEL_FLOAT_FAST_FP16_DISPATCH(OP) \ + template \ + struct apply_impl, N, half_t, half_t> { \ + KERNEL_FLOAT_INLINE static void \ + call(ops::OP, half_t* output, const half_t* input) { \ + float v[N]; \ + map_impl, N, float, half_t>::call({}, v, input); \ + map_impl, N, float, float>::call({}, v, v); \ + map_impl, N, half_t, float>::call({}, output, v); \ + } \ + }; +// clang-format on + +KERNEL_FLOAT_FAST_F32_MAP(KERNEL_FLOAT_FAST_FP16_DISPATCH) } // namespace detail #endif @@ -4390,6 +4425,22 @@ struct apply_impl< result[0] = r.x, result[1] = r.y; } }; + +// clang-format off +#define KERNEL_FLOAT_FAST_BF16_DISPATCH(OP) \ + template \ + struct apply_impl, N, bfloat16_t, bfloat16_t> { \ + KERNEL_FLOAT_INLINE static void \ + call(ops::OP, bfloat16_t* output, const bfloat16_t* input) { \ + float v[N]; \ + map_impl, N, float, bfloat16_t>::call({}, v, input); \ + map_impl, N, float, float>::call({}, v, v); \ + map_impl, N, bfloat16_t, float>::call({}, output, v); \ + } \ + }; +// clang-format on + +KERNEL_FLOAT_FAST_F32_MAP(KERNEL_FLOAT_FAST_BF16_DISPATCH) } // namespace detail #endif @@ -4631,17 +4682,20 @@ KERNEL_FLOAT_DEVICE half2_t rcp(half2_t x) { template KERNEL_FLOAT_DEVICE half2_t rsqrt(half2_t x) { + // A small number added such that rsqrt(0) does not return NaN + static constexpr double EPS = 0.00000768899917602539; + // Set top and bottom bits for both halfs, then shift by 1, then invert uint32_t r = ~((uint32_t(transmute(x) >> 1)) | ~uint32_t(0x3fff3fff)); - //uint32_t r = uint32_t(~(transmute(arg) | (~uint32_t(0x3ffe3ffe)))) >> 1; - // Add bias (0x199c) - half2_t y = transmute(uint32_t(r) + uint32_t(0x199c199c)); + // Add bias + static constexpr uint32_t BIAS = 0x199c199c; + half2_t y = transmute(uint32_t(r) + BIAS); // Newton-Raphson iterations #pragma unroll for (int i = 0; i < Iter; i++) { - half2_t half_x = make_half2(-0.5) * x; + half2_t half_x = __hfma2(make_half2(-0.5), x, make_half2(-EPS)); half2_t correction = __hfma2(half_x, y * y, make_half2(0.5)); y = __hfma2(correction, y, y); // y += y * correction } @@ -4836,7 +4890,7 @@ template struct apply_impl, F, 1, T, T> { KERNEL_FLOAT_INLINE static void call(F fun, T* output, const T* input) { T in2[2], out2[2]; - out2[0] = input[0]; + in2[0] = input[0]; apply_impl, F, 2, T, T>::call(fun, out2, in2); output[0] = out2[0]; } @@ -4867,6 +4921,8 @@ KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, sqrt, 1) KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, rcp, 1) KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, exp, 0) KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, log, 0) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, asin, 2) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, acos, 2) #endif #if KERNEL_FLOAT_BF16_OPS_SUPPORTED @@ -4960,7 +5016,7 @@ struct allow_float_fallback<__nv_fp8_e5m2> { #define KERNEL_FLOAT_FP8_CAST2(T, FP8_TY, FP8_INTERP) \ namespace detail { \ template<> \ - struct apply_impl, 2, FP8_TY, T> { \ + struct apply_impl, 2, FP8_TY, T> { \ KERNEL_FLOAT_INLINE static void call(ops::cast, FP8_TY* result, const T* v) { \ __half2_raw x; \ memcpy(&x, v, 2 * sizeof(T)); \ @@ -4969,7 +5025,7 @@ struct allow_float_fallback<__nv_fp8_e5m2> { } \ }; \ template<> \ - struct apply_impl, 2, T, FP8_TY> { \ + struct apply_impl, 2, T, FP8_TY> { \ KERNEL_FLOAT_INLINE static void call(ops::cast, T* result, const FP8_TY* v) { \ __nv_fp8x2_storage_t x; \ memcpy(&x, v, 2 * sizeof(FP8_TY)); \ @@ -4987,12 +5043,12 @@ KERNEL_FLOAT_FP8_CAST(double) namespace kernel_float { -KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__half, __nv_fp8_e4m3) -KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__half, __nv_fp8_e5m2) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(half_t, __nv_fp8_e4m3) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(half_t, __nv_fp8_e5m2) -KERNEL_FLOAT_FP8_CAST(__half) -KERNEL_FLOAT_FP8_CAST2(__half, __nv_fp8_e4m3, __NV_E4M3) -KERNEL_FLOAT_FP8_CAST2(__half, __nv_fp8_e5m2, __NV_E5M2) +KERNEL_FLOAT_FP8_CAST(half_t) +KERNEL_FLOAT_FP8_CAST2(half_t, __nv_fp8_e4m3, __NV_E4M3) +KERNEL_FLOAT_FP8_CAST2(half_t, __nv_fp8_e5m2, __NV_E5M2) } // namespace kernel_float #endif // KERNEL_FLOAT_FP16_AVAILABLE @@ -5001,12 +5057,12 @@ KERNEL_FLOAT_FP8_CAST2(__half, __nv_fp8_e5m2, __NV_E5M2) namespace kernel_float { -KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__nv_bfloat16, __nv_fp8_e4m3) -KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__nv_bfloat16, __nv_fp8_e5m2) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(bfloat16_t, __nv_fp8_e4m3) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(bfloat16_t, __nv_fp8_e5m2) -KERNEL_FLOAT_FP8_CAST(__nv_bfloat16) -KERNEL_FLOAT_FP8_CAST2(__nv_bfloat16, __nv_fp8_e4m3, __NV_E4M3) -KERNEL_FLOAT_FP8_CAST2(__nv_bfloat16, __nv_fp8_e5m2, __NV_E5M2) +KERNEL_FLOAT_FP8_CAST(bfloat16_t) +KERNEL_FLOAT_FP8_CAST2(bfloat16_t, __nv_fp8_e4m3, __NV_E4M3) +KERNEL_FLOAT_FP8_CAST2(bfloat16_t, __nv_fp8_e5m2, __NV_E5M2) } // namespace kernel_float #endif // KERNEL_FLOAT_BF16_AVAILABLE @@ -5075,14 +5131,14 @@ KERNEL_FLOAT_TYPE_ALIAS(f64x, double) KERNEL_FLOAT_TYPE_ALIAS(float64x, double) #if KERNEL_FLOAT_FP16_AVAILABLE -KERNEL_FLOAT_TYPE_ALIAS(half, __half) -KERNEL_FLOAT_TYPE_ALIAS(f16x, __half) -KERNEL_FLOAT_TYPE_ALIAS(float16x, __half) +KERNEL_FLOAT_TYPE_ALIAS(half, half_t) +KERNEL_FLOAT_TYPE_ALIAS(f16x, half_t) +KERNEL_FLOAT_TYPE_ALIAS(float16x, half_t) #endif #if KERNEL_FLOAT_BF16_AVAILABLE -KERNEL_FLOAT_TYPE_ALIAS(bfloat16x, __bfloat16) -KERNEL_FLOAT_TYPE_ALIAS(bf16x, __bfloat16) +KERNEL_FLOAT_TYPE_ALIAS(bfloat16x, bfloat16_t) +KERNEL_FLOAT_TYPE_ALIAS(bf16x, bfloat16_t) #endif #if KERNEL_FLOAT_BF8_AVAILABLE From 76c695a4cc5b13b3d5841ac5085574a5b47a299c Mon Sep 17 00:00:00 2001 From: stijn Date: Tue, 26 Nov 2024 13:48:27 +0100 Subject: [PATCH 14/25] Fix compilation error on HIP due to `KERNEL_FLOAT_FAST_F32_MAP` --- include/kernel_float/unops.h | 13 +++++++------ single_include/kernel_float.h | 17 +++++++++-------- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/include/kernel_float/unops.h b/include/kernel_float/unops.h index 6955e6c..f2059ac 100644 --- a/include/kernel_float/unops.h +++ b/include/kernel_float/unops.h @@ -212,16 +212,13 @@ KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(cos) KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(tan) KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(exp) -KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(exp2) KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(log) -KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(log2) KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(sqrt) KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(rcp) KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(rsqrt) -// This PTX is only supported on CUDA -#if KERNEL_FLOAT_IS_CUDA && KERNEL_FLOAT_IS_DEVICE +#if KERNEL_FLOAT_IS_DEVICE #define KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(T, F, EXPR_F32) \ namespace detail { \ template<> \ @@ -245,6 +242,8 @@ KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, sin, __sinf(input)) KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, cos, __cosf(input)) KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, tan, __tanf(input)) +// This PTX is only supported on CUDA +#if KERNEL_FLOAT_IS_CUDA #define KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(T, F, INSTR, REG) \ namespace detail { \ template<> \ @@ -261,7 +260,8 @@ KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(double, rsqrt, "rsqrt.approx.f64", "d") KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, sqrt, "sqrt.approx.f32", "f") KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, rcp, "rcp.approx.f32", "f") KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, rsqrt, "rsqrt.approx.f32", "f") -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, tanh, "tanh.approx.f32;", "f") +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, tanh, "tanh.approx.f32", "f") +#endif #define KERNEL_FLOAT_FAST_F32_MAP(F) \ F(exp) F(exp2) F(exp10) F(log) F(log2) F(log10) F(sin) F(cos) F(tan) F(rcp) F(rsqrt) F(sqrt) @@ -270,7 +270,8 @@ KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, tanh, "tanh.approx.f32;", "f") //KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, cos, "cos.approx.f32", "f") //KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, exp2, "ex2.approx.f32", "f") //KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, log2, "lg2.approx.f32", "f") - +#else +#define KERNEL_FLOAT_FAST_F32_MAP(F) #endif } // namespace kernel_float diff --git a/single_include/kernel_float.h b/single_include/kernel_float.h index 5ad772b..f0639dc 100644 --- a/single_include/kernel_float.h +++ b/single_include/kernel_float.h @@ -16,8 +16,8 @@ //================================================================================ // this file has been auto-generated, do not modify its contents! -// date: 2024-11-20 10:36:45.284577 -// git hash: 76501fda40df9e396998d11840bc8f10b11ea47b +// date: 2024-11-26 13:52:06.286983 +// git hash: c4c6ac09808d14b5407afb06ecdecd235cd50ed3 //================================================================================ #ifndef KERNEL_FLOAT_MACROS_H @@ -1397,16 +1397,13 @@ KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(cos) KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(tan) KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(exp) -KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(exp2) KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(log) -KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(log2) KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(sqrt) KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(rcp) KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(rsqrt) -// This PTX is only supported on CUDA -#if KERNEL_FLOAT_IS_CUDA && KERNEL_FLOAT_IS_DEVICE +#if KERNEL_FLOAT_IS_DEVICE #define KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(T, F, EXPR_F32) \ namespace detail { \ template<> \ @@ -1430,6 +1427,8 @@ KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, sin, __sinf(input)) KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, cos, __cosf(input)) KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, tan, __tanf(input)) +// This PTX is only supported on CUDA +#if KERNEL_FLOAT_IS_CUDA #define KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(T, F, INSTR, REG) \ namespace detail { \ template<> \ @@ -1446,7 +1445,8 @@ KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(double, rsqrt, "rsqrt.approx.f64", "d") KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, sqrt, "sqrt.approx.f32", "f") KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, rcp, "rcp.approx.f32", "f") KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, rsqrt, "rsqrt.approx.f32", "f") -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, tanh, "tanh.approx.f32;", "f") +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, tanh, "tanh.approx.f32", "f") +#endif #define KERNEL_FLOAT_FAST_F32_MAP(F) \ F(exp) F(exp2) F(exp10) F(log) F(log2) F(log10) F(sin) F(cos) F(tan) F(rcp) F(rsqrt) F(sqrt) @@ -1455,7 +1455,8 @@ KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, tanh, "tanh.approx.f32;", "f") //KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, cos, "cos.approx.f32", "f") //KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, exp2, "ex2.approx.f32", "f") //KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, log2, "lg2.approx.f32", "f") - +#else +#define KERNEL_FLOAT_FAST_F32_MAP(F) #endif } // namespace kernel_float From d8a53a370a35db9b26fca39957abb35d75b53cec Mon Sep 17 00:00:00 2001 From: stijn Date: Tue, 26 Nov 2024 14:21:40 +0100 Subject: [PATCH 15/25] Remove call to `__exp2f` since it does not exist --- include/kernel_float/apply.h | 14 +++++++------- include/kernel_float/unops.h | 12 +++++++----- single_include/kernel_float.h | 30 ++++++++++++++++-------------- 3 files changed, 30 insertions(+), 26 deletions(-) diff --git a/include/kernel_float/apply.h b/include/kernel_float/apply.h index d558f90..2698fa1 100644 --- a/include/kernel_float/apply.h +++ b/include/kernel_float/apply.h @@ -157,6 +157,13 @@ using default_policy = KERNEL_FLOAT_POLICY; namespace detail { +template +struct invoke_impl { + KERNEL_FLOAT_INLINE static Output call(F fun, Args... args) { + return fun(args...); + } +}; + // template struct apply_fallback_impl { @@ -186,13 +193,6 @@ template struct apply_fallback_impl, F, N, Output, Args...>: apply_impl {}; -template -struct invoke_impl { - KERNEL_FLOAT_INLINE static Output call(F fun, Args... args) { - return fun(args...); - } -}; - // Only for `accurate_policy` do we implement `apply_impl`, the others will fall back to `apply_base_impl`. template struct apply_impl { diff --git a/include/kernel_float/unops.h b/include/kernel_float/unops.h index f2059ac..12baef5 100644 --- a/include/kernel_float/unops.h +++ b/include/kernel_float/unops.h @@ -231,7 +231,7 @@ KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(rsqrt) } KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, exp, __expf(input)) -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, exp2, __exp2f(input)) +//KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, exp2, __exp2f(input)) // Seems to be missing? KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, exp10, __exp10f(input)) KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, log, __logf(input)) @@ -257,19 +257,21 @@ KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, tan, __tanf(input)) KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(double, rcp, "rcp.approx.ftz.f64", "d") KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(double, rsqrt, "rsqrt.approx.f64", "d") +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, exp2, "ex2.approx.f32", "f") KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, sqrt, "sqrt.approx.f32", "f") KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, rcp, "rcp.approx.f32", "f") KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, rsqrt, "rsqrt.approx.f32", "f") KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, tanh, "tanh.approx.f32", "f") + +// These are no longer necessary due to the KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN above +//KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, sin, "sin.approx.f32", "f") +//KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, cos, "cos.approx.f32", "f") +//KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, log2, "lg2.approx.f32", "f") #endif #define KERNEL_FLOAT_FAST_F32_MAP(F) \ F(exp) F(exp2) F(exp10) F(log) F(log2) F(log10) F(sin) F(cos) F(tan) F(rcp) F(rsqrt) F(sqrt) -//KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, sin, "sin.approx.f32", "f") -//KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, cos, "cos.approx.f32", "f") -//KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, exp2, "ex2.approx.f32", "f") -//KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, log2, "lg2.approx.f32", "f") #else #define KERNEL_FLOAT_FAST_F32_MAP(F) #endif diff --git a/single_include/kernel_float.h b/single_include/kernel_float.h index f0639dc..e59bb80 100644 --- a/single_include/kernel_float.h +++ b/single_include/kernel_float.h @@ -16,8 +16,8 @@ //================================================================================ // this file has been auto-generated, do not modify its contents! -// date: 2024-11-26 13:52:06.286983 -// git hash: c4c6ac09808d14b5407afb06ecdecd235cd50ed3 +// date: 2024-11-26 14:20:49.081641 +// git hash: 76c695a4cc5b13b3d5841ac5085574a5b47a299c //================================================================================ #ifndef KERNEL_FLOAT_MACROS_H @@ -824,6 +824,13 @@ using default_policy = KERNEL_FLOAT_POLICY; namespace detail { +template +struct invoke_impl { + KERNEL_FLOAT_INLINE static Output call(F fun, Args... args) { + return fun(args...); + } +}; + // template struct apply_fallback_impl { @@ -853,13 +860,6 @@ template struct apply_fallback_impl, F, N, Output, Args...>: apply_impl {}; -template -struct invoke_impl { - KERNEL_FLOAT_INLINE static Output call(F fun, Args... args) { - return fun(args...); - } -}; - // Only for `accurate_policy` do we implement `apply_impl`, the others will fall back to `apply_base_impl`. template struct apply_impl { @@ -1416,7 +1416,7 @@ KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(rsqrt) } KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, exp, __expf(input)) -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, exp2, __exp2f(input)) +//KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, exp2, __exp2f(input)) // Seems to be missing? KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, exp10, __exp10f(input)) KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, log, __logf(input)) @@ -1442,19 +1442,21 @@ KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, tan, __tanf(input)) KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(double, rcp, "rcp.approx.ftz.f64", "d") KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(double, rsqrt, "rsqrt.approx.f64", "d") +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, exp2, "ex2.approx.f32", "f") KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, sqrt, "sqrt.approx.f32", "f") KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, rcp, "rcp.approx.f32", "f") KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, rsqrt, "rsqrt.approx.f32", "f") KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, tanh, "tanh.approx.f32", "f") + +// These are no longer necessary due to the KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN above +//KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, sin, "sin.approx.f32", "f") +//KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, cos, "cos.approx.f32", "f") +//KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, log2, "lg2.approx.f32", "f") #endif #define KERNEL_FLOAT_FAST_F32_MAP(F) \ F(exp) F(exp2) F(exp10) F(log) F(log2) F(log10) F(sin) F(cos) F(tan) F(rcp) F(rsqrt) F(sqrt) -//KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, sin, "sin.approx.f32", "f") -//KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, cos, "cos.approx.f32", "f") -//KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, exp2, "ex2.approx.f32", "f") -//KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, log2, "lg2.approx.f32", "f") #else #define KERNEL_FLOAT_FAST_F32_MAP(F) #endif From a2b08a56e31d1c9a6302c8a49c740cf56fcc1607 Mon Sep 17 00:00:00 2001 From: stijn Date: Tue, 26 Nov 2024 14:22:45 +0100 Subject: [PATCH 16/25] Change github workflow to compile for all architectures --- .github/workflows/cmake-action.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/cmake-action.yml b/.github/workflows/cmake-action.yml index fd621db..37cdf95 100644 --- a/.github/workflows/cmake-action.yml +++ b/.github/workflows/cmake-action.yml @@ -33,7 +33,7 @@ jobs: - name: Configure CMake # Configure CMake in a 'build' subdirectory. `CMAKE_BUILD_TYPE` is only required if you are using a single-configuration generator such as make. # See https://cmake.org/cmake/help/latest/variable/CMAKE_BUILD_TYPE.html?highlight=cmake_build_type - run: cmake -B ${{github.workspace}}/build -DCMAKE_BUILD_TYPE=${{env.BUILD_TYPE}} -DKERNEL_FLOAT_BUILD_TEST=1 -DKERNEL_FLOAT_BUILD_EXAMPLE=1 + run: CUDAARCHS=all cmake -B ${{github.workspace}}/build -DCMAKE_BUILD_TYPE=${{env.BUILD_TYPE}} -DKERNEL_FLOAT_BUILD_TEST=1 -DKERNEL_FLOAT_BUILD_EXAMPLE=1 - name: Build # Build your program with the given configuration From 846de1f9aefaef76da15ebb5474080d531efaf38 Mon Sep 17 00:00:00 2001 From: stijn Date: Mon, 2 Dec 2024 11:01:01 +0100 Subject: [PATCH 17/25] Fix bug in `approx::exp(bfloat16)` for HIP --- include/kernel_float/approx.h | 10 ++++++---- single_include/kernel_float.h | 14 ++++++++------ 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/include/kernel_float/approx.h b/include/kernel_float/approx.h index db54ab2..6ad70d9 100644 --- a/include/kernel_float/approx.h +++ b/include/kernel_float/approx.h @@ -10,7 +10,9 @@ namespace kernel_float { namespace approx { static_assert(sizeof(unsigned int) * 8 == 32, "invalid size of unsigned int"); +static_assert(sizeof(unsigned short) * 8 == 16, "invalid size of unsigned short"); using uint32_t = unsigned int; +using uint16_t = unsigned short; template KERNEL_FLOAT_DEVICE T transmute(const U& input) { @@ -353,12 +355,12 @@ KERNEL_FLOAT_DEVICE bfloat16x2_t exp(bfloat16x2_t arg) { static constexpr float OFFSET = 382.4958400542335; static constexpr float MINIMUM = 382; - float a = fmaxf(fmaf(bfloat162float(arg.x), SCALE, OFFSET), MINIMUM); - float b = fmaxf(fmaf(bfloat162float(arg.y), SCALE, OFFSET), MINIMUM); + float a = fmaxf(fmaf(__bfloat162float(arg.x), SCALE, OFFSET), MINIMUM); + float b = fmaxf(fmaf(__bfloat162float(arg.y), SCALE, OFFSET), MINIMUM); return { - transmute<__bfloat16>(uint16_t(transmute(a))), - transmute<__bfloat16>(uint16_t(transmute(b)))}; + transmute(uint16_t(transmute(a))), + transmute(uint16_t(transmute(b)))}; } #endif } // namespace approx diff --git a/single_include/kernel_float.h b/single_include/kernel_float.h index e59bb80..f6b2493 100644 --- a/single_include/kernel_float.h +++ b/single_include/kernel_float.h @@ -16,8 +16,8 @@ //================================================================================ // this file has been auto-generated, do not modify its contents! -// date: 2024-11-26 14:20:49.081641 -// git hash: 76c695a4cc5b13b3d5841ac5085574a5b47a299c +// date: 2024-12-02 10:59:19.296684 +// git hash: a2b08a56e31d1c9a6302c8a49c740cf56fcc1607 //================================================================================ #ifndef KERNEL_FLOAT_MACROS_H @@ -4535,7 +4535,9 @@ namespace kernel_float { namespace approx { static_assert(sizeof(unsigned int) * 8 == 32, "invalid size of unsigned int"); +static_assert(sizeof(unsigned short) * 8 == 16, "invalid size of unsigned short"); using uint32_t = unsigned int; +using uint16_t = unsigned short; template KERNEL_FLOAT_DEVICE T transmute(const U& input) { @@ -4878,12 +4880,12 @@ KERNEL_FLOAT_DEVICE bfloat16x2_t exp(bfloat16x2_t arg) { static constexpr float OFFSET = 382.4958400542335; static constexpr float MINIMUM = 382; - float a = fmaxf(fmaf(bfloat162float(arg.x), SCALE, OFFSET), MINIMUM); - float b = fmaxf(fmaf(bfloat162float(arg.y), SCALE, OFFSET), MINIMUM); + float a = fmaxf(fmaf(__bfloat162float(arg.x), SCALE, OFFSET), MINIMUM); + float b = fmaxf(fmaf(__bfloat162float(arg.y), SCALE, OFFSET), MINIMUM); return { - transmute<__bfloat16>(uint16_t(transmute(a))), - transmute<__bfloat16>(uint16_t(transmute(b)))}; + transmute(uint16_t(transmute(a))), + transmute(uint16_t(transmute(b)))}; } #endif } // namespace approx From f94bd1068ba605130043a96f395084e168906826 Mon Sep 17 00:00:00 2001 From: stijn Date: Mon, 2 Dec 2024 18:49:12 +0100 Subject: [PATCH 18/25] Fix several issues related to HIP compilation for bfloat16 --- include/kernel_float/bf16.h | 30 ++++++++++++++++++++++-- include/kernel_float/binops.h | 6 +++-- include/kernel_float/macros.h | 3 +-- single_include/kernel_float.h | 43 ++++++++++++++++++++++++++++------- 4 files changed, 68 insertions(+), 14 deletions(-) diff --git a/include/kernel_float/bf16.h b/include/kernel_float/bf16.h index 22ea8b8..f89fc3c 100644 --- a/include/kernel_float/bf16.h +++ b/include/kernel_float/bf16.h @@ -60,7 +60,6 @@ struct allow_float_fallback { }; }; // namespace detail -#if KERNEL_FLOAT_BF16_OPS_SUPPORTED #define KERNEL_FLOAT_BF16_UNARY_FUN(NAME, FUN1, FUN2) \ namespace ops { \ template<> \ @@ -81,6 +80,7 @@ struct allow_float_fallback { }; \ } +#if KERNEL_FLOAT_BF16_OPS_SUPPORTED KERNEL_FLOAT_BF16_UNARY_FUN(sin, ::hsin, ::h2sin) KERNEL_FLOAT_BF16_UNARY_FUN(cos, ::hcos, ::h2cos) @@ -101,9 +101,34 @@ KERNEL_FLOAT_BF16_UNARY_FUN(ceil, ::hceil, ::h2ceil) KERNEL_FLOAT_BF16_UNARY_FUN(rint, ::hrint, ::h2rint) KERNEL_FLOAT_BF16_UNARY_FUN(trunc, ::htrunc, ::h2trunc) KERNEL_FLOAT_BF16_UNARY_FUN(negate, ::__hneg, ::__hneg2) + +// For some reason, HIP struggles with the functions `__habs` and `__hneg`. We define them here using bitwise ops. +// For CUDA, we can just use the regular bfloat16 functions (see above). +#elif KERNEL_FLOAT_IS_HIP +KERNEL_FLOAT_INLINE __hip_bfloat16 hip_habs(const __hip_bfloat16 a) { + __hip_bfloat16 res = a; + res.data &= 0x7FFF; + return res; +} + +KERNEL_FLOAT_INLINE __hip_bfloat16 hip_hneg(const __hip_bfloat16 a) { + __hip_bfloat16 res = a; + res.data ^= 0x8000; + return res; +} + +KERNEL_FLOAT_INLINE __hip_bfloat162 hip_habs2(const __hip_bfloat162 a) { + return {hip_habs(a.x), hip_habs(a.y)}; +} + +KERNEL_FLOAT_INLINE __hip_bfloat162 hip_hneg2(const __hip_bfloat162 a) { + return {hip_hneg(a.x), hip_hneg(a.y)}; +} + +KERNEL_FLOAT_BF16_UNARY_FUN(abs, hip_habs, hip_habs2) +KERNEL_FLOAT_BF16_UNARY_FUN(negate, hip_hneg, hip_hneg2) #endif -#if KERNEL_FLOAT_BF16_OPS_SUPPORTED #define KERNEL_FLOAT_BF16_BINARY_FUN(NAME, FUN1, FUN2) \ namespace ops { \ template<> \ @@ -133,6 +158,7 @@ KERNEL_FLOAT_BF16_UNARY_FUN(negate, ::__hneg, ::__hneg2) }; \ } +#if KERNEL_FLOAT_BF16_OPS_SUPPORTED KERNEL_FLOAT_BF16_BINARY_FUN(add, __hadd, __hadd2) KERNEL_FLOAT_BF16_BINARY_FUN(subtract, __hsub, __hsub2) KERNEL_FLOAT_BF16_BINARY_FUN(multiply, __hmul, __hmul2) diff --git a/include/kernel_float/binops.h b/include/kernel_float/binops.h index 75de26a..2e4c149 100644 --- a/include/kernel_float/binops.h +++ b/include/kernel_float/binops.h @@ -189,14 +189,16 @@ namespace ops { template struct min { KERNEL_FLOAT_INLINE T operator()(T left, T right) { - return left < right ? left : right; + auto cond = less {}(left, right); + return cast {}(cond) ? left : right; } }; template struct max { KERNEL_FLOAT_INLINE T operator()(T left, T right) { - return left > right ? left : right; + auto cond = greater {}(left, right); + return cast {}(cond) ? left : right; } }; diff --git a/include/kernel_float/macros.h b/include/kernel_float/macros.h index 68be6e5..88bbdbc 100644 --- a/include/kernel_float/macros.h +++ b/include/kernel_float/macros.h @@ -20,12 +20,11 @@ #elif defined(__HIPCC__) #define KERNEL_FLOAT_IS_HIP (1) #define KERNEL_FLOAT_DEVICE __attribute__((always_inline)) __device__ + #define KERNEL_FLOAT_INLINE __attribute__((always_inline)) __host__ __device__ #ifdef __HIP_DEVICE_COMPILE__ - #define KERNEL_FLOAT_INLINE __attribute__((always_inline)) __host__ __device__ #define KERNEL_FLOAT_IS_DEVICE (1) #else - #define KERNEL_FLOAT_INLINE __attribute__((always_inline)) __host__ __device__ #define KERNEL_FLOAT_IS_HOST (1) #endif diff --git a/single_include/kernel_float.h b/single_include/kernel_float.h index f6b2493..c77c7e6 100644 --- a/single_include/kernel_float.h +++ b/single_include/kernel_float.h @@ -16,8 +16,8 @@ //================================================================================ // this file has been auto-generated, do not modify its contents! -// date: 2024-12-02 10:59:19.296684 -// git hash: a2b08a56e31d1c9a6302c8a49c740cf56fcc1607 +// date: 2024-12-02 18:48:50.243676 +// git hash: 846de1f9aefaef76da15ebb5474080d531efaf38 //================================================================================ #ifndef KERNEL_FLOAT_MACROS_H @@ -42,12 +42,11 @@ #elif defined(__HIPCC__) #define KERNEL_FLOAT_IS_HIP (1) #define KERNEL_FLOAT_DEVICE __attribute__((always_inline)) __device__ + #define KERNEL_FLOAT_INLINE __attribute__((always_inline)) __host__ __device__ #ifdef __HIP_DEVICE_COMPILE__ - #define KERNEL_FLOAT_INLINE __attribute__((always_inline)) __host__ __device__ #define KERNEL_FLOAT_IS_DEVICE (1) #else - #define KERNEL_FLOAT_INLINE __attribute__((always_inline)) __host__ __device__ #define KERNEL_FLOAT_IS_HOST (1) #endif @@ -1875,14 +1874,16 @@ namespace ops { template struct min { KERNEL_FLOAT_INLINE T operator()(T left, T right) { - return left < right ? left : right; + auto cond = less {}(left, right); + return cast {}(cond) ? left : right; } }; template struct max { KERNEL_FLOAT_INLINE T operator()(T left, T right) { - return left > right ? left : right; + auto cond = greater {}(left, right); + return cast {}(cond) ? left : right; } }; @@ -4307,7 +4308,6 @@ struct allow_float_fallback { }; }; // namespace detail -#if KERNEL_FLOAT_BF16_OPS_SUPPORTED #define KERNEL_FLOAT_BF16_UNARY_FUN(NAME, FUN1, FUN2) \ namespace ops { \ template<> \ @@ -4328,6 +4328,7 @@ struct allow_float_fallback { }; \ } +#if KERNEL_FLOAT_BF16_OPS_SUPPORTED KERNEL_FLOAT_BF16_UNARY_FUN(sin, ::hsin, ::h2sin) KERNEL_FLOAT_BF16_UNARY_FUN(cos, ::hcos, ::h2cos) @@ -4348,9 +4349,34 @@ KERNEL_FLOAT_BF16_UNARY_FUN(ceil, ::hceil, ::h2ceil) KERNEL_FLOAT_BF16_UNARY_FUN(rint, ::hrint, ::h2rint) KERNEL_FLOAT_BF16_UNARY_FUN(trunc, ::htrunc, ::h2trunc) KERNEL_FLOAT_BF16_UNARY_FUN(negate, ::__hneg, ::__hneg2) + +// For some reason, HIP struggles with the functions `__habs` and `__hneg`. We define them here using bitwise ops. +// For CUDA, we can just use the regular bfloat16 functions (see above). +#elif KERNEL_FLOAT_IS_HIP +KERNEL_FLOAT_INLINE __hip_bfloat16 hip_habs(const __hip_bfloat16 a) { + __hip_bfloat16 res = a; + res.data &= 0x7FFF; + return res; +} + +KERNEL_FLOAT_INLINE __hip_bfloat16 hip_hneg(const __hip_bfloat16 a) { + __hip_bfloat16 res = a; + res.data ^= 0x8000; + return res; +} + +KERNEL_FLOAT_INLINE __hip_bfloat162 hip_habs2(const __hip_bfloat162 a) { + return {hip_habs(a.x), hip_habs(a.y)}; +} + +KERNEL_FLOAT_INLINE __hip_bfloat162 hip_hneg2(const __hip_bfloat162 a) { + return {hip_hneg(a.x), hip_hneg(a.y)}; +} + +KERNEL_FLOAT_BF16_UNARY_FUN(abs, hip_habs, hip_habs2) +KERNEL_FLOAT_BF16_UNARY_FUN(negate, hip_hneg, hip_hneg2) #endif -#if KERNEL_FLOAT_BF16_OPS_SUPPORTED #define KERNEL_FLOAT_BF16_BINARY_FUN(NAME, FUN1, FUN2) \ namespace ops { \ template<> \ @@ -4380,6 +4406,7 @@ KERNEL_FLOAT_BF16_UNARY_FUN(negate, ::__hneg, ::__hneg2) }; \ } +#if KERNEL_FLOAT_BF16_OPS_SUPPORTED KERNEL_FLOAT_BF16_BINARY_FUN(add, __hadd, __hadd2) KERNEL_FLOAT_BF16_BINARY_FUN(subtract, __hsub, __hsub2) KERNEL_FLOAT_BF16_BINARY_FUN(multiply, __hmul, __hmul2) From 27307892958727687895089083c10c77a122da2f Mon Sep 17 00:00:00 2001 From: stijn Date: Mon, 27 Jan 2025 16:20:17 +0100 Subject: [PATCH 19/25] `vector_ptr` now requires alignment in bytes instead of elements --- example.cu | 13 ++ examples/vector_add/main.cu | 2 +- include/kernel_float/memory.h | 250 ++++++++++++++++++++-------------- tests/memory.cu | 18 ++- 4 files changed, 177 insertions(+), 106 deletions(-) create mode 100644 example.cu diff --git a/example.cu b/example.cu new file mode 100644 index 0000000..f15e3e7 --- /dev/null +++ b/example.cu @@ -0,0 +1,13 @@ +#include "kernel_float.h" +#include + +namespace kf = kernel_float; + +__global__ void kernel( + kf::vec_ptr input, + float constant, + kf::vec_ptr output +) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + output(i) = input[i] + kf::cast(constant); +} diff --git a/examples/vector_add/main.cu b/examples/vector_add/main.cu index 4c9d8b8..2ac2eb4 100644 --- a/examples/vector_add/main.cu +++ b/examples/vector_add/main.cu @@ -22,7 +22,7 @@ __global__ void my_kernel( int i = blockIdx.x * blockDim.x + threadIdx.x; if (i * N < length) { - output(i) = kf::fma(input[i], input[i], kf::cast(constant)); + output[i] = kf::fma(input[i], input[i], kf::cast(constant)); } } diff --git a/include/kernel_float/memory.h b/include/kernel_float/memory.h index d76ed1e..2f0a92a 100644 --- a/include/kernel_float/memory.h +++ b/include/kernel_float/memory.h @@ -33,7 +33,7 @@ struct copy_impl> { * * ``` * // Load 2 elements at data[0] and data[8], skip data[2] and data[4] - * vec values = = read(data, make_vec(0, 2, 4, 8), make_vec(true, false, false, true)); + * vec values = read(data, make_vec(0, 2, 4, 8), make_vec(true, false, false, true)); * ``` */ template> @@ -80,7 +80,7 @@ KERNEL_FLOAT_INLINE void write(T* ptr, const I& indices, const V& values, const * vec values = read<4>(data); * * // Load 4 elements at locations data[10], data[11], data[12], data[13] - * vec values = read<4>(values + 10, data); + * vec values = read<4>(data + 10); * ``` */ template @@ -107,31 +107,42 @@ KERNEL_FLOAT_INLINE void write(T* ptr, const V& values) { } namespace detail { +/** + * Returns the greatest common divisor of `a` and `b`. + */ KERNEL_FLOAT_INLINE constexpr size_t gcd(size_t a, size_t b) { return b == 0 ? a : gcd(b, a % b); } -template +/** + * Returns true if a pointer having alignment of `a` bytes also has an alignment of `b` bytes. Returns false otherwise. + */ +KERNEL_FLOAT_INLINE +constexpr size_t alignment_divisible(size_t a, size_t b) { + return gcd(a, KERNEL_FLOAT_MAX_ALIGNMENT) % gcd(b, KERNEL_FLOAT_MAX_ALIGNMENT) == 0; +} + +template struct copy_aligned_impl { static constexpr size_t K = N > 8 ? 8 : (N > 4 ? 4 : (N > 2 ? 2 : 1)); - static constexpr size_t alignment_K = gcd(alignment, sizeof(T) * K); + static constexpr size_t Alignment_K = gcd(Alignment, sizeof(T) * K); KERNEL_FLOAT_INLINE static void load(T* output, const T* input) { - copy_aligned_impl::load(output, input); - copy_aligned_impl::load(output + K, input + K); + copy_aligned_impl::load(output, input); + copy_aligned_impl::load(output + K, input + K); } KERNEL_FLOAT_INLINE static void store(T* output, const T* input) { - copy_aligned_impl::store(output, input); - copy_aligned_impl::store(output + K, input + K); + copy_aligned_impl::store(output, input); + copy_aligned_impl::store(output + K, input + K); } }; -template -struct copy_aligned_impl { +template +struct copy_aligned_impl { KERNEL_FLOAT_INLINE static void load(T* output, const T* input) {} @@ -139,8 +150,8 @@ struct copy_aligned_impl { static void store(T* output, const T* input) {} }; -template -struct copy_aligned_impl { +template +struct copy_aligned_impl { using storage_type = T; KERNEL_FLOAT_INLINE @@ -154,9 +165,9 @@ struct copy_aligned_impl { } }; -template -struct copy_aligned_impl sizeof(T))>> { - static constexpr size_t storage_alignment = gcd(alignment, 2 * sizeof(T)); +template +struct copy_aligned_impl sizeof(T))>> { + static constexpr size_t storage_alignment = gcd(Alignment, 2 * sizeof(T)); struct alignas(storage_alignment) storage_type { T v0, v1; }; @@ -174,9 +185,9 @@ struct copy_aligned_impl sizeof(T))>> } }; -template -struct copy_aligned_impl 2 * sizeof(T))>> { - static constexpr size_t storage_alignment = gcd(alignment, 4 * sizeof(T)); +template +struct copy_aligned_impl 2 * sizeof(T))>> { + static constexpr size_t storage_alignment = gcd(Alignment, 4 * sizeof(T)); struct alignas(storage_alignment) storage_type { T v0, v1, v2, v3; }; @@ -200,9 +211,9 @@ struct copy_aligned_impl 2 * sizeof(T) } }; -template -struct copy_aligned_impl 4 * sizeof(T))>> { - static constexpr size_t storage_alignment = gcd(alignment, 8 * sizeof(T)); +template +struct copy_aligned_impl 4 * sizeof(T))>> { + static constexpr size_t storage_alignment = gcd(Alignment, 8 * sizeof(T)); struct alignas(storage_alignment) storage_type { T v0, v1, v2, v3, v4, v5, v6, v7; }; @@ -246,8 +257,8 @@ struct copy_aligned_impl 4 * sizeof(T) * // Load 4 elements at locations data[0], data[1], data[2], data[3] * vec values = read_aligned<4>(data); * - * // Load 4 elements at locations data[10], data[11], data[12], data[13] - * vec values2 = read_aligned<4>(data + 10); + * // Load 4 elements at locations data[12], data[13], data[14], data[15] + * vec values2 = read_aligned<4>(data + 12); * ``` */ template @@ -294,10 +305,13 @@ KERNEL_FLOAT_INLINE void write_aligned(T* ptr, const V& values) { * @tparam U The underlying storage type. Defaults to the same type as T. * @tparam Align The alignment constraint for read and write operations. */ -template +template struct vector_ref { + static constexpr size_t access_alignment = detail::gcd(Alignment, KERNEL_FLOAT_MAX_ALIGNMENT); + static_assert(access_alignment >= alignof(U), "invalid alignment for pointer type"); + using pointer_type = U*; - using value_type = decay_t; + using value_type = T; using vector_type = vector>; /** @@ -313,7 +327,12 @@ struct vector_ref { * @return vector_type A vector of type vector_type containing the read and converted data. */ KERNEL_FLOAT_INLINE vector_type read() const { - return convert(read_aligned(data_)); + vector_storage result; + detail::copy_aligned_impl::load( + result.data(), + KERNEL_FLOAT_ASSUME_ALIGNED(const U, data_, access_alignment)); + + return convert(result); } /** @@ -324,7 +343,9 @@ struct vector_ref { */ template KERNEL_FLOAT_INLINE void write(const V& values) const { - write_aligned(data_, convert(values)); + detail::copy_aligned_impl::store( + KERNEL_FLOAT_ASSUME_ALIGNED(U, data_, access_alignment), + convert_storage(values).data()); } /** @@ -357,16 +378,24 @@ struct vector_ref { /** * Specialization for `vector_ref` if the backing storage is const. */ -template -struct vector_ref { +template +struct vector_ref { + static constexpr size_t access_alignment = detail::gcd(Alignment, KERNEL_FLOAT_MAX_ALIGNMENT); + static_assert(access_alignment >= alignof(U), "invalid alignment for pointer type"); + using pointer_type = const U*; - using value_type = decay_t; + using value_type = T; using vector_type = vector>; KERNEL_FLOAT_INLINE explicit vector_ref(pointer_type data) : data_(data) {} KERNEL_FLOAT_INLINE vector_type read() const { - return convert(read_aligned(data_)); + vector_storage result; + detail::copy_aligned_impl::load( + result.data(), + KERNEL_FLOAT_ASSUME_ALIGNED(const U, data_, access_alignment)); + + return convert_storage(result); } KERNEL_FLOAT_INLINE operator vector_type() const { @@ -381,13 +410,13 @@ struct vector_ref { pointer_type data_ = nullptr; }; -#define KERNEL_FLOAT_VECTOR_REF_ASSIGN_OP(OP, OP_ASSIGN) \ - template \ - KERNEL_FLOAT_INLINE vector_ref operator OP_ASSIGN( \ - vector_ref ptr, \ - const V& value) { \ - ptr.write(ptr.read() OP value); \ - return ptr; \ +#define KERNEL_FLOAT_VECTOR_REF_ASSIGN_OP(OP, OP_ASSIGN) \ + template \ + KERNEL_FLOAT_INLINE vector_ref operator OP_ASSIGN( \ + vector_ref ptr, \ + const V& value) { \ + ptr.write(ptr.read() OP value); \ + return ptr; \ } KERNEL_FLOAT_VECTOR_REF_ASSIGN_OP(+, +=) @@ -395,6 +424,17 @@ KERNEL_FLOAT_VECTOR_REF_ASSIGN_OP(-, -=) KERNEL_FLOAT_VECTOR_REF_ASSIGN_OP(*, *=) KERNEL_FLOAT_VECTOR_REF_ASSIGN_OP(/, /=) +template +struct into_vector_impl> { + using value_type = T; + using extent_type = extent; + + KERNEL_FLOAT_INLINE + static vector_storage call(const vector_ref& reference) { + return reference.read(); + } +}; + /** * A wrapper for a pointer that enables vectorized access and supports type conversions.. * @@ -408,9 +448,11 @@ KERNEL_FLOAT_VECTOR_REF_ASSIGN_OP(/, /=) * @tparam T The type of the elements as viewed by the user. * @tparam N The alignment of T in number of elements. * @tparam U The underlying storage type, defaults to T. + * @tparam Alignment The assumed alignment of the underlying U* pointer. */ -template +template struct vector_ptr { + static constexpr size_t offset_alignment = detail::gcd(Alignment, sizeof(U) * N); using pointer_type = U*; using value_type = decay_t; @@ -428,44 +470,47 @@ struct vector_ptr { * Constructs a vector_ptr from another vector_ptr with potentially different alignment and type. This constructor * only allows conversion if the alignment of the source is greater than or equal to the alignment of the target. */ - template - KERNEL_FLOAT_INLINE vector_ptr(vector_ptr p, enable_if_t<(N2 % N == 0), int> = {}) : + template + KERNEL_FLOAT_INLINE vector_ptr( + vector_ptr p, + enable_if_t = {}) : data_(p.get()) {} /** - * Accesses a reference to a vector at a specific index with optional alignment considerations. - * - * @tparam N The number of elements in the vector to access, defaults to the alignment. - * @param index The index at which to access the vector. + * Shorthand for `at(0)`. */ - template - KERNEL_FLOAT_INLINE vector_ref at(size_t index) const { - return vector_ref {data_ + index * N}; + KERNEL_FLOAT_INLINE const vector_ref operator*() const { + return vector_ref {data_}; } /** - * Accesses a vector at a specific index. + * Accesses a reference to a vector at a specific index with optional alignment considerations. * - * @tparam K The number of elements to read, defaults to `N`. - * @param index The index from which to read the data. + * @tparam N The number of elements in the vector to access, defaults to N. + * @param index The index at which to access the vector. */ template - KERNEL_FLOAT_INLINE vector> read(size_t index) const { - return this->template at(index).read(); + KERNEL_FLOAT_INLINE vector_ref at(size_t index) const { + return vector_ref {data_ + index * N}; } /** - * Shorthand for `read(index)`. + * Shorthand for `at(index)`. */ - KERNEL_FLOAT_INLINE const vector> operator[](size_t index) const { - return read(index); + KERNEL_FLOAT_INLINE vector_ref + operator[](size_t index) const { + return at(index); } /** - * Shorthand for `read(0)`. + * Accesses a vector at a specific index. + * + * @tparam K The number of elements to read, defaults to `N`. + * @param index The index from which to read the data. */ - KERNEL_FLOAT_INLINE const vector> operator*() const { - return read(0); + template + KERNEL_FLOAT_INLINE vector> read(size_t index = 0) const { + return this->template at(index).read(); } /** @@ -481,15 +526,6 @@ struct vector_ptr { this->template at(index).write(values); } - /** - * Shorthand for `at(index)`. Returns a vector reference to can be used - * to assign to this pointer, contrary to `operator[]` that does not - * allow assignment. - */ - KERNEL_FLOAT_INLINE vector_ref operator()(size_t index) const { - return at(index); - } - /** * Gets the raw data pointer managed by this `vector_ptr`. */ @@ -504,26 +540,36 @@ struct vector_ptr { /** * Specialization for `vector_ptr` if the backing storage is const. */ -template -struct vector_ptr { +template +struct vector_ptr { + static constexpr size_t offset_alignment = detail::gcd(Alignment, sizeof(U) * N); using pointer_type = const U*; using value_type = decay_t; vector_ptr() = default; + KERNEL_FLOAT_INLINE explicit vector_ptr(pointer_type p) : data_(p) {} - template - KERNEL_FLOAT_INLINE - vector_ptr(vector_ptr p, enable_if_t<(N2 % N == 0), int> = {}) : + template + KERNEL_FLOAT_INLINE vector_ptr( + vector_ptr p, + enable_if_t = {}) : data_(p.get()) {} - template - KERNEL_FLOAT_INLINE vector_ptr(vector_ptr p, enable_if_t<(N2 % N == 0), int> = {}) : + template + KERNEL_FLOAT_INLINE vector_ptr( + vector_ptr p, + enable_if_t = {}) : data_(p.get()) {} + KERNEL_FLOAT_INLINE vector_ref operator*() const { + return vector_ref {data_}; + } + template - KERNEL_FLOAT_INLINE vector_ref at(size_t index) const { - return vector_ref {data_ + index * N}; + KERNEL_FLOAT_INLINE vector_ref + at(size_t index) const { + return vector_ref {data_ + index * N}; } template @@ -535,10 +581,6 @@ struct vector_ptr { return read(index); } - KERNEL_FLOAT_INLINE const vector> operator*() const { - return read(0); - } - KERNEL_FLOAT_INLINE pointer_type get() const { return data_; } @@ -547,44 +589,54 @@ struct vector_ptr { pointer_type data_ = nullptr; }; -template -KERNEL_FLOAT_INLINE vector_ptr operator+(vector_ptr p, size_t i) { - return vector_ptr {p.get() + i * N}; +template +KERNEL_FLOAT_INLINE vector_ptr +operator+(vector_ptr p, size_t i) { + return vector_ptr {p.get() + i * N}; } -template -KERNEL_FLOAT_INLINE vector_ptr operator+(size_t i, vector_ptr p) { +template +KERNEL_FLOAT_INLINE vector_ptr +operator+(size_t i, vector_ptr p) { return p + i; } +template< + typename T, + size_t N, + typename U, + size_t A, + typename = enable_if_t<(N * sizeof(U)) % A == 0>> +KERNEL_FLOAT_INLINE vector_ptr& operator+=(vector_ptr& p, size_t i) { + return p = p + i; +} + /** - * Creates a `vector_ptr` from a raw pointer `U*` by asserting a specific alignment `N`. + * Creates a `vector_ptr` from a raw pointer `T*` by asserting a specific alignment `N`. * - * @tparam T The type of the elements as viewed by the user. This type may differ from `U`. * @tparam N The alignment constraint for the vector_ptr. Defaults to KERNEL_FLOAT_MAX_ALIGNMENT. - * @tparam U The type of the elements pointed to by the raw pointer. + * @tparam T The type of the elements pointed to by the raw pointer. */ -template -KERNEL_FLOAT_INLINE vector_ptr assert_aligned(U* ptr) { - return vector_ptr {ptr}; +template +KERNEL_FLOAT_INLINE vector_ptr assert_aligned(T* ptr) { + return vector_ptr {ptr}; } // Doxygen cannot deal with the `assert_aligned` being defined twice, we ignore the second definition. /// @cond IGNORE /** - * Creates a `vector_ptr` from a raw pointer `T*` by asserting a specific alignment `N`. + * Creates a `vector_ptr` from a raw pointer `T*`. The alignment is assumed to be KERNEL_FLOAT_MAX_ALIGNMENT. * - * @tparam N The alignment constraint for the vector_ptr. Defaults to KERNEL_FLOAT_MAX_ALIGNMENT. * @tparam T The type of the elements pointed to by the raw pointer. */ -template -KERNEL_FLOAT_INLINE vector_ptr assert_aligned(T* ptr) { - return vector_ptr {ptr}; +template +KERNEL_FLOAT_INLINE vector_ptr assert_aligned(T* ptr) { + return vector_ptr {ptr}; } /// @endcond -template -using vec_ptr = vector_ptr; +template +using vec_ptr = vector_ptr; #if defined(__cpp_deduction_guides) template diff --git a/tests/memory.cu b/tests/memory.cu index 2621a92..7ab7ee4 100644 --- a/tests/memory.cu +++ b/tests/memory.cu @@ -156,7 +156,8 @@ struct vector_ptr_test { }; { - auto ptr = kf::vector_ptr {&storage.data[0]}; + kf::vector_ptr storage_ptr = kf::assert_aligned(storage.data); + kf::vector_ptr ptr = storage_ptr; ASSERT_EQ(ptr.get(), static_cast(storage.data)); T expected[N] = {T(double(N + I))...}; @@ -174,7 +175,8 @@ struct vector_ptr_test { } { - auto ptr = kf::vector_ptr {&storage.data[0]}; + kf::vector_ptr storage_ptr = kf::assert_aligned(storage.data); + kf::vector_ptr ptr = storage_ptr; ASSERT_EQ(ptr.get(), static_cast(storage.data)); T expected[N] = {T(double(N + I))...}; @@ -182,7 +184,7 @@ struct vector_ptr_test { auto a = ptr.read(1); ASSERT_EQ_ALL(a[I], expected[I]); - auto b = ptr[1]; + kf::vec b = ptr[1]; ASSERT_EQ_ALL(b[I], expected[I]); kf::vec c = ptr.at(1); @@ -191,16 +193,20 @@ struct vector_ptr_test { kf::vec overwrite = {T(double(100 + I))...}; ptr.at(1) = overwrite; - auto e = ptr[1]; + kf::vec e = ptr[1]; ASSERT_EQ_ALL(e[I], overwrite[I]); ptr.write(1, T(1337.0)); - auto f = ptr[1]; + kf::vec f = ptr[1]; ASSERT_EQ_ALL(f[I], T(1337.0)); ptr.at(1) += T(1.0); - auto g = ptr[1]; + kf::vec g = ptr[1]; ASSERT_EQ_ALL(g[I], T(1338.0)); + + kf::cast_to(ptr[1]) = double(3.14); + kf::vec h = ptr[1]; + ASSERT_EQ_ALL(h[I], T(3.14)); } } }; From 1611258545746b9d726df59b3a338048b591bcb0 Mon Sep 17 00:00:00 2001 From: stijn Date: Mon, 27 Jan 2025 16:24:13 +0100 Subject: [PATCH 20/25] Remove `apply_fallback_impl` --- include/kernel_float/apply.h | 69 ++++++++++++++++-------------------- 1 file changed, 30 insertions(+), 39 deletions(-) diff --git a/include/kernel_float/apply.h b/include/kernel_float/apply.h index 2698fa1..6b821cc 100644 --- a/include/kernel_float/apply.h +++ b/include/kernel_float/apply.h @@ -7,41 +7,41 @@ namespace kernel_float { namespace detail { template -struct broadcast_extent_helper; +struct broadcast_extent_impl; template -struct broadcast_extent_helper { +struct broadcast_extent_impl { using type = E; }; template -struct broadcast_extent_helper, extent> { +struct broadcast_extent_impl, extent> { using type = extent; }; template -struct broadcast_extent_helper, extent> { +struct broadcast_extent_impl, extent> { using type = extent; }; template -struct broadcast_extent_helper, extent<1>> { +struct broadcast_extent_impl, extent<1>> { using type = extent; }; template<> -struct broadcast_extent_helper, extent<1>> { +struct broadcast_extent_impl, extent<1>> { using type = extent<1>; }; template -struct broadcast_extent_helper: - broadcast_extent_helper::type, C, Rest...> {}; +struct broadcast_extent_impl: + broadcast_extent_impl::type, C, Rest...> {}; } // namespace detail template -using broadcast_extent = typename detail::broadcast_extent_helper::type; +using broadcast_extent = typename detail::broadcast_extent_impl::type; template using broadcast_vector_extent_type = broadcast_extent...>; @@ -128,7 +128,9 @@ struct accurate_policy {}; * the utmost accuracy. This policy leverages optimizations to accelerate computations, which may involve * approximations that slightly compromise precision. */ -struct fast_policy {}; +struct fast_policy { + using fallback_policy = accurate_policy; +}; /** * This template policy allows developers to specify a custom degree of approximation for their computations. By @@ -136,7 +138,14 @@ struct fast_policy {}; * specific needs of your application. Higher values mean more precision. */ template -struct approx_level_policy {}; +struct approx_level_policy { + using fallback_policy = approx_level_policy<>; +}; + +template<> +struct approx_level_policy<> { + using fallback_policy = fast_policy; +}; /** * The approximate_policy serves as the default approximation policy, providing a standard level of approximation @@ -145,15 +154,17 @@ struct approx_level_policy {}; */ using approx_policy = approx_level_policy<>; -#ifndef KERNEL_FLOAT_POLICY -#define KERNEL_FLOAT_POLICY accurate_policy -#endif - /** * The `default_policy` acts as the standard computation policy. It can be configured externally using the - * `KERNEL_FLOAT_POLICY` macro. If `KERNEL_FLOAT_POLICY` is not defined, it defaults to `accurate_policy`. + * `KERNEL_FLOAT_GLOBAL_POLICY` macro. If `KERNEL_FLOAT_GLOBAL_POLICY` is not defined, default to `accurate_policy`. */ +#if defined(KERNEL_FLOAT_GLOBAL_POLICY) +using default_policy = KERNEL_FLOAT_GLOBAL_POLICY; +#elif defined(KERNEL_FLOAT_POLICY) using default_policy = KERNEL_FLOAT_POLICY; +#else +using default_policy = accurate_policy; +#endif namespace detail { @@ -164,35 +175,15 @@ struct invoke_impl { } }; -// template -struct apply_fallback_impl { - KERNEL_FLOAT_INLINE static void call(F fun, Output* output, const Args*... args) { - static_assert(N > 0, "operation not implemented"); - } -}; +struct apply_impl; template -struct apply_base_impl: apply_fallback_impl {}; +struct apply_base_impl: apply_impl {}; template struct apply_impl: apply_base_impl {}; -// `fast_policy` falls back to `accurate_policy` -template -struct apply_fallback_impl: - apply_impl {}; - -// `approx_policy` falls back to `fast_policy` -template -struct apply_fallback_impl: - apply_impl {}; - -// `approx_level_policy` falls back to `approx_policy` -template -struct apply_fallback_impl, F, N, Output, Args...>: - apply_impl {}; - // Only for `accurate_policy` do we implement `apply_impl`, the others will fall back to `apply_base_impl`. template struct apply_impl { @@ -266,4 +257,4 @@ KERNEL_FLOAT_INLINE map_type map(F fun, const Args&... args) { } // namespace kernel_float -#endif // KERNEL_FLOAT_APPLY_H \ No newline at end of file +#endif // KERNEL_FLOAT_APPLY_H From c44c6ed2b3c8cb8e6040c1e286b7272736c86387 Mon Sep 17 00:00:00 2001 From: stijn Date: Mon, 27 Jan 2025 16:24:32 +0100 Subject: [PATCH 21/25] Fix incorrect type name in `approx.h` --- include/kernel_float/approx.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/kernel_float/approx.h b/include/kernel_float/approx.h index 6ad70d9..32a53aa 100644 --- a/include/kernel_float/approx.h +++ b/include/kernel_float/approx.h @@ -304,13 +304,13 @@ KERNEL_FLOAT_DEVICE bfloat16x2_t normalize_trig_input(bfloat16x2_t x) { template KERNEL_FLOAT_DEVICE bfloat16x2_t cos(bfloat16x2_t x) { bfloat16x2_t xf = normalize_trig_input(x); - return cos_poly<__bfloat16, Iter + 1>::call(__hmul2(xf, xf)); + return cos_poly::call(__hmul2(xf, xf)); } template KERNEL_FLOAT_DEVICE bfloat16x2_t sin(bfloat16x2_t x) { bfloat16x2_t xf = normalize_trig_input(x); - return __hmul2(sin_poly<__bfloat16, Iter>::call(__hmul2(xf, xf)), xf); + return __hmul2(sin_poly::call(__hmul2(xf, xf)), xf); } template From 212efee63e07fbc8579c3110900476ca2cd6691e Mon Sep 17 00:00:00 2001 From: stijn Date: Mon, 27 Jan 2025 16:25:46 +0100 Subject: [PATCH 22/25] Fix incorrect type name in `binops.h` --- include/kernel_float/binops.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/kernel_float/binops.h b/include/kernel_float/binops.h index 2e4c149..85ac823 100644 --- a/include/kernel_float/binops.h +++ b/include/kernel_float/binops.h @@ -325,7 +325,7 @@ namespace detail { // Override `pow` using `log2` and `exp2` template struct apply_base_impl, N, T, T, T> { - KERNEL_FLOAT_INLINE static void call(ops::divide, T* result, const T* lhs, const T* rhs) { + KERNEL_FLOAT_INLINE static void call(ops::pow, T* result, const T* lhs, const T* rhs) { T lhs_log[N]; T result_log[N]; From 09dc82096e4c013a079f0e315da1ccce17453c93 Mon Sep 17 00:00:00 2001 From: stijn Date: Mon, 27 Jan 2025 16:26:12 +0100 Subject: [PATCH 23/25] Change `AssignConversionProxy` to also accept rvalues --- include/kernel_float/conversion.h | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/include/kernel_float/conversion.h b/include/kernel_float/conversion.h index 8e84cdb..881cc51 100644 --- a/include/kernel_float/conversion.h +++ b/include/kernel_float/conversion.h @@ -87,11 +87,11 @@ KERNEL_FLOAT_INLINE vector> convert(const V& input, extent new_s template struct AssignConversionProxy { KERNEL_FLOAT_INLINE - explicit AssignConversionProxy(T* ptr) : ptr_(ptr) {} + explicit AssignConversionProxy(T&& ptr) : ptr_(std::forward(ptr)) {} template KERNEL_FLOAT_INLINE AssignConversionProxy& operator=(U&& values) { - *ptr_ = detail::convert_impl< + ptr_ = detail::convert_impl< vector_value_type, vector_extent_type, vector_value_type, @@ -102,12 +102,12 @@ struct AssignConversionProxy { } private: - T* ptr_; + T ptr_; }; /** * Takes a vector reference and gives back a helper object. This object allows you to assign - * a vector of a different type to another vector while perofrming implicit type converion. + * a vector of a different type to another vector while performing implicit type conversion. * * For example, if `x = expression;` does not compile because `x` and `expression` are * different vector types, you can use `cast_to(x) = expression;` to make it work. @@ -120,9 +120,9 @@ struct AssignConversionProxy { * cast_to(x) = y; // Normally, `x = y;` would give an error, but `cast_to` fixes that. * ``` */ -template -KERNEL_FLOAT_INLINE AssignConversionProxy cast_to(T& input) { - return AssignConversionProxy(&input); +template +KERNEL_FLOAT_INLINE AssignConversionProxy cast_to(T&& input) { + return AssignConversionProxy(std::forward(input)); } /** @@ -135,7 +135,7 @@ KERNEL_FLOAT_INLINE AssignConversionProxy cast_to(T& input) { * ``` */ template -KERNEL_FLOAT_INLINE vector> fill(T value = {}, extent = {}) { +KERNEL_FLOAT_INLINE vector> fill(T value, extent = {}) { vector_storage input = {value}; return detail::broadcast_impl, extent>::call(input); } From 126737c3e3e682fe421cf186fb93b78538f448eb Mon Sep 17 00:00:00 2001 From: stijn Date: Mon, 27 Jan 2025 16:26:54 +0100 Subject: [PATCH 24/25] Change `ops::cast` to get rid of `cast_float_fallback` --- include/kernel_float/unops.h | 58 ++--- single_include/kernel_float.h | 401 +++++++++++++++++++--------------- 2 files changed, 257 insertions(+), 202 deletions(-) diff --git a/include/kernel_float/unops.h b/include/kernel_float/unops.h index 12baef5..81ddec0 100644 --- a/include/kernel_float/unops.h +++ b/include/kernel_float/unops.h @@ -20,43 +20,49 @@ struct cast { }; template -struct cast { +struct cast { KERNEL_FLOAT_INLINE T operator()(T input) noexcept { return input; } }; -template -struct cast_float_fallback; - -template -struct cast_float_fallback { +template +struct cast { KERNEL_FLOAT_INLINE R operator()(T input) noexcept { - return R(input); + if constexpr ( + detail::allow_float_fallback::value || detail::allow_float_fallback::value) { + return cast {}(cast {}(input)); + } else { + return R(input); + } } }; -// clang-format off -template -struct cast_float_fallback< - T, - R, - enable_if_t< - !is_same_type && - !is_same_type && - (detail::allow_float_fallback::value || detail::allow_float_fallback::value) - > -> { - KERNEL_FLOAT_INLINE R operator()(T input) noexcept { - return cast {}(cast {}(input)); +template<> +struct cast { + KERNEL_FLOAT_INLINE float operator()(float input) noexcept { + return input; } }; -// clang-format on -template -struct cast { - KERNEL_FLOAT_INLINE R operator()(T input) noexcept { - return cast_float_fallback {}(input); +template +struct cast { + KERNEL_FLOAT_INLINE float operator()(float input) noexcept { + return input; + } +}; + +template +struct cast { + KERNEL_FLOAT_INLINE float operator()(T input) noexcept { + return float(input); + } +}; + +template +struct cast { + KERNEL_FLOAT_INLINE T operator()(float input) noexcept { + return T(input); } }; @@ -255,7 +261,7 @@ KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, tan, __tanf(input)) } KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(double, rcp, "rcp.approx.ftz.f64", "d") -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(double, rsqrt, "rsqrt.approx.f64", "d") +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(double, rsqrt, "rsqrt.approx.ftz.f64", "d") KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, exp2, "ex2.approx.f32", "f") KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, sqrt, "sqrt.approx.f32", "f") diff --git a/single_include/kernel_float.h b/single_include/kernel_float.h index c77c7e6..6d1754b 100644 --- a/single_include/kernel_float.h +++ b/single_include/kernel_float.h @@ -16,8 +16,8 @@ //================================================================================ // this file has been auto-generated, do not modify its contents! -// date: 2024-12-02 18:48:50.243676 -// git hash: 846de1f9aefaef76da15ebb5474080d531efaf38 +// date: 2025-01-27 16:26:28.827757 +// git hash: 09dc82096e4c013a079f0e315da1ccce17453c93 //================================================================================ #ifndef KERNEL_FLOAT_MACROS_H @@ -673,41 +673,41 @@ namespace kernel_float { namespace detail { template -struct broadcast_extent_helper; +struct broadcast_extent_impl; template -struct broadcast_extent_helper { +struct broadcast_extent_impl { using type = E; }; template -struct broadcast_extent_helper, extent> { +struct broadcast_extent_impl, extent> { using type = extent; }; template -struct broadcast_extent_helper, extent> { +struct broadcast_extent_impl, extent> { using type = extent; }; template -struct broadcast_extent_helper, extent<1>> { +struct broadcast_extent_impl, extent<1>> { using type = extent; }; template<> -struct broadcast_extent_helper, extent<1>> { +struct broadcast_extent_impl, extent<1>> { using type = extent<1>; }; template -struct broadcast_extent_helper: - broadcast_extent_helper::type, C, Rest...> {}; +struct broadcast_extent_impl: + broadcast_extent_impl::type, C, Rest...> {}; } // namespace detail template -using broadcast_extent = typename detail::broadcast_extent_helper::type; +using broadcast_extent = typename detail::broadcast_extent_impl::type; template using broadcast_vector_extent_type = broadcast_extent...>; @@ -794,7 +794,9 @@ struct accurate_policy {}; * the utmost accuracy. This policy leverages optimizations to accelerate computations, which may involve * approximations that slightly compromise precision. */ -struct fast_policy {}; +struct fast_policy { + using fallback_policy = accurate_policy; +}; /** * This template policy allows developers to specify a custom degree of approximation for their computations. By @@ -802,7 +804,14 @@ struct fast_policy {}; * specific needs of your application. Higher values mean more precision. */ template -struct approx_level_policy {}; +struct approx_level_policy { + using fallback_policy = approx_level_policy<>; +}; + +template<> +struct approx_level_policy<> { + using fallback_policy = fast_policy; +}; /** * The approximate_policy serves as the default approximation policy, providing a standard level of approximation @@ -811,15 +820,17 @@ struct approx_level_policy {}; */ using approx_policy = approx_level_policy<>; -#ifndef KERNEL_FLOAT_POLICY -#define KERNEL_FLOAT_POLICY accurate_policy -#endif - /** * The `default_policy` acts as the standard computation policy. It can be configured externally using the - * `KERNEL_FLOAT_POLICY` macro. If `KERNEL_FLOAT_POLICY` is not defined, it defaults to `accurate_policy`. + * `KERNEL_FLOAT_GLOBAL_POLICY` macro. If `KERNEL_FLOAT_GLOBAL_POLICY` is not defined, default to `accurate_policy`. */ +#if defined(KERNEL_FLOAT_GLOBAL_POLICY) +using default_policy = KERNEL_FLOAT_GLOBAL_POLICY; +#elif defined(KERNEL_FLOAT_POLICY) using default_policy = KERNEL_FLOAT_POLICY; +#else +using default_policy = accurate_policy; +#endif namespace detail { @@ -830,35 +841,15 @@ struct invoke_impl { } }; -// template -struct apply_fallback_impl { - KERNEL_FLOAT_INLINE static void call(F fun, Output* output, const Args*... args) { - static_assert(N > 0, "operation not implemented"); - } -}; +struct apply_impl; template -struct apply_base_impl: apply_fallback_impl {}; +struct apply_base_impl: apply_impl {}; template struct apply_impl: apply_base_impl {}; -// `fast_policy` falls back to `accurate_policy` -template -struct apply_fallback_impl: - apply_impl {}; - -// `approx_policy` falls back to `fast_policy` -template -struct apply_fallback_impl: - apply_impl {}; - -// `approx_level_policy` falls back to `approx_policy` -template -struct apply_fallback_impl, F, N, Output, Args...>: - apply_impl {}; - // Only for `accurate_policy` do we implement `apply_impl`, the others will fall back to `apply_base_impl`. template struct apply_impl { @@ -1204,43 +1195,49 @@ struct cast { }; template -struct cast { +struct cast { KERNEL_FLOAT_INLINE T operator()(T input) noexcept { return input; } }; -template -struct cast_float_fallback; - -template -struct cast_float_fallback { +template +struct cast { KERNEL_FLOAT_INLINE R operator()(T input) noexcept { - return R(input); + if constexpr ( + detail::allow_float_fallback::value || detail::allow_float_fallback::value) { + return cast {}(cast {}(input)); + } else { + return R(input); + } } }; -// clang-format off -template -struct cast_float_fallback< - T, - R, - enable_if_t< - !is_same_type && - !is_same_type && - (detail::allow_float_fallback::value || detail::allow_float_fallback::value) - > -> { - KERNEL_FLOAT_INLINE R operator()(T input) noexcept { - return cast {}(cast {}(input)); +template<> +struct cast { + KERNEL_FLOAT_INLINE float operator()(float input) noexcept { + return input; } }; -// clang-format on -template -struct cast { - KERNEL_FLOAT_INLINE R operator()(T input) noexcept { - return cast_float_fallback {}(input); +template +struct cast { + KERNEL_FLOAT_INLINE float operator()(float input) noexcept { + return input; + } +}; + +template +struct cast { + KERNEL_FLOAT_INLINE float operator()(T input) noexcept { + return float(input); + } +}; + +template +struct cast { + KERNEL_FLOAT_INLINE T operator()(float input) noexcept { + return T(input); } }; @@ -1439,7 +1436,7 @@ KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, tan, __tanf(input)) } KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(double, rcp, "rcp.approx.ftz.f64", "d") -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(double, rsqrt, "rsqrt.approx.f64", "d") +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(double, rsqrt, "rsqrt.approx.ftz.f64", "d") KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, exp2, "ex2.approx.f32", "f") KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, sqrt, "sqrt.approx.f32", "f") @@ -1552,11 +1549,11 @@ KERNEL_FLOAT_INLINE vector> convert(const V& input, extent new_s template struct AssignConversionProxy { KERNEL_FLOAT_INLINE - explicit AssignConversionProxy(T* ptr) : ptr_(ptr) {} + explicit AssignConversionProxy(T&& ptr) : ptr_(std::forward(ptr)) {} template KERNEL_FLOAT_INLINE AssignConversionProxy& operator=(U&& values) { - *ptr_ = detail::convert_impl< + ptr_ = detail::convert_impl< vector_value_type, vector_extent_type, vector_value_type, @@ -1567,12 +1564,12 @@ struct AssignConversionProxy { } private: - T* ptr_; + T ptr_; }; /** * Takes a vector reference and gives back a helper object. This object allows you to assign - * a vector of a different type to another vector while perofrming implicit type converion. + * a vector of a different type to another vector while performing implicit type conversion. * * For example, if `x = expression;` does not compile because `x` and `expression` are * different vector types, you can use `cast_to(x) = expression;` to make it work. @@ -1585,9 +1582,9 @@ struct AssignConversionProxy { * cast_to(x) = y; // Normally, `x = y;` would give an error, but `cast_to` fixes that. * ``` */ -template -KERNEL_FLOAT_INLINE AssignConversionProxy cast_to(T& input) { - return AssignConversionProxy(&input); +template +KERNEL_FLOAT_INLINE AssignConversionProxy cast_to(T&& input) { + return AssignConversionProxy(std::forward(input)); } /** @@ -1600,7 +1597,7 @@ KERNEL_FLOAT_INLINE AssignConversionProxy cast_to(T& input) { * ``` */ template -KERNEL_FLOAT_INLINE vector> fill(T value = {}, extent = {}) { +KERNEL_FLOAT_INLINE vector> fill(T value, extent = {}) { vector_storage input = {value}; return detail::broadcast_impl, extent>::call(input); } @@ -2010,7 +2007,7 @@ namespace detail { // Override `pow` using `log2` and `exp2` template struct apply_base_impl, N, T, T, T> { - KERNEL_FLOAT_INLINE static void call(ops::divide, T* result, const T* lhs, const T* rhs) { + KERNEL_FLOAT_INLINE static void call(ops::pow, T* result, const T* lhs, const T* rhs) { T lhs_log[N]; T result_log[N]; @@ -2574,7 +2571,7 @@ struct copy_impl> { * * ``` * // Load 2 elements at data[0] and data[8], skip data[2] and data[4] - * vec values = = read(data, make_vec(0, 2, 4, 8), make_vec(true, false, false, true)); + * vec values = read(data, make_vec(0, 2, 4, 8), make_vec(true, false, false, true)); * ``` */ template> @@ -2621,7 +2618,7 @@ KERNEL_FLOAT_INLINE void write(T* ptr, const I& indices, const V& values, const * vec values = read<4>(data); * * // Load 4 elements at locations data[10], data[11], data[12], data[13] - * vec values = read<4>(values + 10, data); + * vec values = read<4>(data + 10); * ``` */ template @@ -2648,31 +2645,42 @@ KERNEL_FLOAT_INLINE void write(T* ptr, const V& values) { } namespace detail { +/** + * Returns the greatest common divisor of `a` and `b`. + */ KERNEL_FLOAT_INLINE constexpr size_t gcd(size_t a, size_t b) { return b == 0 ? a : gcd(b, a % b); } -template +/** + * Returns true if a pointer having alignment of `a` bytes also has an alignment of `b` bytes. Returns false otherwise. + */ +KERNEL_FLOAT_INLINE +constexpr size_t alignment_divisible(size_t a, size_t b) { + return gcd(a, KERNEL_FLOAT_MAX_ALIGNMENT) % gcd(b, KERNEL_FLOAT_MAX_ALIGNMENT) == 0; +} + +template struct copy_aligned_impl { static constexpr size_t K = N > 8 ? 8 : (N > 4 ? 4 : (N > 2 ? 2 : 1)); - static constexpr size_t alignment_K = gcd(alignment, sizeof(T) * K); + static constexpr size_t Alignment_K = gcd(Alignment, sizeof(T) * K); KERNEL_FLOAT_INLINE static void load(T* output, const T* input) { - copy_aligned_impl::load(output, input); - copy_aligned_impl::load(output + K, input + K); + copy_aligned_impl::load(output, input); + copy_aligned_impl::load(output + K, input + K); } KERNEL_FLOAT_INLINE static void store(T* output, const T* input) { - copy_aligned_impl::store(output, input); - copy_aligned_impl::store(output + K, input + K); + copy_aligned_impl::store(output, input); + copy_aligned_impl::store(output + K, input + K); } }; -template -struct copy_aligned_impl { +template +struct copy_aligned_impl { KERNEL_FLOAT_INLINE static void load(T* output, const T* input) {} @@ -2680,8 +2688,8 @@ struct copy_aligned_impl { static void store(T* output, const T* input) {} }; -template -struct copy_aligned_impl { +template +struct copy_aligned_impl { using storage_type = T; KERNEL_FLOAT_INLINE @@ -2695,9 +2703,9 @@ struct copy_aligned_impl { } }; -template -struct copy_aligned_impl sizeof(T))>> { - static constexpr size_t storage_alignment = gcd(alignment, 2 * sizeof(T)); +template +struct copy_aligned_impl sizeof(T))>> { + static constexpr size_t storage_alignment = gcd(Alignment, 2 * sizeof(T)); struct alignas(storage_alignment) storage_type { T v0, v1; }; @@ -2715,9 +2723,9 @@ struct copy_aligned_impl sizeof(T))>> } }; -template -struct copy_aligned_impl 2 * sizeof(T))>> { - static constexpr size_t storage_alignment = gcd(alignment, 4 * sizeof(T)); +template +struct copy_aligned_impl 2 * sizeof(T))>> { + static constexpr size_t storage_alignment = gcd(Alignment, 4 * sizeof(T)); struct alignas(storage_alignment) storage_type { T v0, v1, v2, v3; }; @@ -2741,9 +2749,9 @@ struct copy_aligned_impl 2 * sizeof(T) } }; -template -struct copy_aligned_impl 4 * sizeof(T))>> { - static constexpr size_t storage_alignment = gcd(alignment, 8 * sizeof(T)); +template +struct copy_aligned_impl 4 * sizeof(T))>> { + static constexpr size_t storage_alignment = gcd(Alignment, 8 * sizeof(T)); struct alignas(storage_alignment) storage_type { T v0, v1, v2, v3, v4, v5, v6, v7; }; @@ -2787,8 +2795,8 @@ struct copy_aligned_impl 4 * sizeof(T) * // Load 4 elements at locations data[0], data[1], data[2], data[3] * vec values = read_aligned<4>(data); * - * // Load 4 elements at locations data[10], data[11], data[12], data[13] - * vec values2 = read_aligned<4>(data + 10); + * // Load 4 elements at locations data[12], data[13], data[14], data[15] + * vec values2 = read_aligned<4>(data + 12); * ``` */ template @@ -2835,10 +2843,13 @@ KERNEL_FLOAT_INLINE void write_aligned(T* ptr, const V& values) { * @tparam U The underlying storage type. Defaults to the same type as T. * @tparam Align The alignment constraint for read and write operations. */ -template +template struct vector_ref { + static constexpr size_t access_alignment = detail::gcd(Alignment, KERNEL_FLOAT_MAX_ALIGNMENT); + static_assert(access_alignment >= alignof(U), "invalid alignment for pointer type"); + using pointer_type = U*; - using value_type = decay_t; + using value_type = T; using vector_type = vector>; /** @@ -2854,7 +2865,12 @@ struct vector_ref { * @return vector_type A vector of type vector_type containing the read and converted data. */ KERNEL_FLOAT_INLINE vector_type read() const { - return convert(read_aligned(data_)); + vector_storage result; + detail::copy_aligned_impl::load( + result.data(), + KERNEL_FLOAT_ASSUME_ALIGNED(const U, data_, access_alignment)); + + return convert(result); } /** @@ -2865,7 +2881,9 @@ struct vector_ref { */ template KERNEL_FLOAT_INLINE void write(const V& values) const { - write_aligned(data_, convert(values)); + detail::copy_aligned_impl::store( + KERNEL_FLOAT_ASSUME_ALIGNED(U, data_, access_alignment), + convert_storage(values).data()); } /** @@ -2898,16 +2916,24 @@ struct vector_ref { /** * Specialization for `vector_ref` if the backing storage is const. */ -template -struct vector_ref { +template +struct vector_ref { + static constexpr size_t access_alignment = detail::gcd(Alignment, KERNEL_FLOAT_MAX_ALIGNMENT); + static_assert(access_alignment >= alignof(U), "invalid alignment for pointer type"); + using pointer_type = const U*; - using value_type = decay_t; + using value_type = T; using vector_type = vector>; KERNEL_FLOAT_INLINE explicit vector_ref(pointer_type data) : data_(data) {} KERNEL_FLOAT_INLINE vector_type read() const { - return convert(read_aligned(data_)); + vector_storage result; + detail::copy_aligned_impl::load( + result.data(), + KERNEL_FLOAT_ASSUME_ALIGNED(const U, data_, access_alignment)); + + return convert_storage(result); } KERNEL_FLOAT_INLINE operator vector_type() const { @@ -2922,13 +2948,13 @@ struct vector_ref { pointer_type data_ = nullptr; }; -#define KERNEL_FLOAT_VECTOR_REF_ASSIGN_OP(OP, OP_ASSIGN) \ - template \ - KERNEL_FLOAT_INLINE vector_ref operator OP_ASSIGN( \ - vector_ref ptr, \ - const V& value) { \ - ptr.write(ptr.read() OP value); \ - return ptr; \ +#define KERNEL_FLOAT_VECTOR_REF_ASSIGN_OP(OP, OP_ASSIGN) \ + template \ + KERNEL_FLOAT_INLINE vector_ref operator OP_ASSIGN( \ + vector_ref ptr, \ + const V& value) { \ + ptr.write(ptr.read() OP value); \ + return ptr; \ } KERNEL_FLOAT_VECTOR_REF_ASSIGN_OP(+, +=) @@ -2936,6 +2962,17 @@ KERNEL_FLOAT_VECTOR_REF_ASSIGN_OP(-, -=) KERNEL_FLOAT_VECTOR_REF_ASSIGN_OP(*, *=) KERNEL_FLOAT_VECTOR_REF_ASSIGN_OP(/, /=) +template +struct into_vector_impl> { + using value_type = T; + using extent_type = extent; + + KERNEL_FLOAT_INLINE + static vector_storage call(const vector_ref& reference) { + return reference.read(); + } +}; + /** * A wrapper for a pointer that enables vectorized access and supports type conversions.. * @@ -2949,9 +2986,11 @@ KERNEL_FLOAT_VECTOR_REF_ASSIGN_OP(/, /=) * @tparam T The type of the elements as viewed by the user. * @tparam N The alignment of T in number of elements. * @tparam U The underlying storage type, defaults to T. + * @tparam Alignment The assumed alignment of the underlying U* pointer. */ -template +template struct vector_ptr { + static constexpr size_t offset_alignment = detail::gcd(Alignment, sizeof(U) * N); using pointer_type = U*; using value_type = decay_t; @@ -2969,44 +3008,47 @@ struct vector_ptr { * Constructs a vector_ptr from another vector_ptr with potentially different alignment and type. This constructor * only allows conversion if the alignment of the source is greater than or equal to the alignment of the target. */ - template - KERNEL_FLOAT_INLINE vector_ptr(vector_ptr p, enable_if_t<(N2 % N == 0), int> = {}) : + template + KERNEL_FLOAT_INLINE vector_ptr( + vector_ptr p, + enable_if_t = {}) : data_(p.get()) {} /** - * Accesses a reference to a vector at a specific index with optional alignment considerations. - * - * @tparam N The number of elements in the vector to access, defaults to the alignment. - * @param index The index at which to access the vector. + * Shorthand for `at(0)`. */ - template - KERNEL_FLOAT_INLINE vector_ref at(size_t index) const { - return vector_ref {data_ + index * N}; + KERNEL_FLOAT_INLINE const vector_ref operator*() const { + return vector_ref {data_}; } /** - * Accesses a vector at a specific index. + * Accesses a reference to a vector at a specific index with optional alignment considerations. * - * @tparam K The number of elements to read, defaults to `N`. - * @param index The index from which to read the data. + * @tparam N The number of elements in the vector to access, defaults to N. + * @param index The index at which to access the vector. */ template - KERNEL_FLOAT_INLINE vector> read(size_t index) const { - return this->template at(index).read(); + KERNEL_FLOAT_INLINE vector_ref at(size_t index) const { + return vector_ref {data_ + index * N}; } /** - * Shorthand for `read(index)`. + * Shorthand for `at(index)`. */ - KERNEL_FLOAT_INLINE const vector> operator[](size_t index) const { - return read(index); + KERNEL_FLOAT_INLINE vector_ref + operator[](size_t index) const { + return at(index); } /** - * Shorthand for `read(0)`. + * Accesses a vector at a specific index. + * + * @tparam K The number of elements to read, defaults to `N`. + * @param index The index from which to read the data. */ - KERNEL_FLOAT_INLINE const vector> operator*() const { - return read(0); + template + KERNEL_FLOAT_INLINE vector> read(size_t index = 0) const { + return this->template at(index).read(); } /** @@ -3022,15 +3064,6 @@ struct vector_ptr { this->template at(index).write(values); } - /** - * Shorthand for `at(index)`. Returns a vector reference to can be used - * to assign to this pointer, contrary to `operator[]` that does not - * allow assignment. - */ - KERNEL_FLOAT_INLINE vector_ref operator()(size_t index) const { - return at(index); - } - /** * Gets the raw data pointer managed by this `vector_ptr`. */ @@ -3045,26 +3078,36 @@ struct vector_ptr { /** * Specialization for `vector_ptr` if the backing storage is const. */ -template -struct vector_ptr { +template +struct vector_ptr { + static constexpr size_t offset_alignment = detail::gcd(Alignment, sizeof(U) * N); using pointer_type = const U*; using value_type = decay_t; vector_ptr() = default; + KERNEL_FLOAT_INLINE explicit vector_ptr(pointer_type p) : data_(p) {} - template - KERNEL_FLOAT_INLINE - vector_ptr(vector_ptr p, enable_if_t<(N2 % N == 0), int> = {}) : + template + KERNEL_FLOAT_INLINE vector_ptr( + vector_ptr p, + enable_if_t = {}) : data_(p.get()) {} - template - KERNEL_FLOAT_INLINE vector_ptr(vector_ptr p, enable_if_t<(N2 % N == 0), int> = {}) : + template + KERNEL_FLOAT_INLINE vector_ptr( + vector_ptr p, + enable_if_t = {}) : data_(p.get()) {} + KERNEL_FLOAT_INLINE vector_ref operator*() const { + return vector_ref {data_}; + } + template - KERNEL_FLOAT_INLINE vector_ref at(size_t index) const { - return vector_ref {data_ + index * N}; + KERNEL_FLOAT_INLINE vector_ref + at(size_t index) const { + return vector_ref {data_ + index * N}; } template @@ -3076,10 +3119,6 @@ struct vector_ptr { return read(index); } - KERNEL_FLOAT_INLINE const vector> operator*() const { - return read(0); - } - KERNEL_FLOAT_INLINE pointer_type get() const { return data_; } @@ -3088,44 +3127,54 @@ struct vector_ptr { pointer_type data_ = nullptr; }; -template -KERNEL_FLOAT_INLINE vector_ptr operator+(vector_ptr p, size_t i) { - return vector_ptr {p.get() + i * N}; +template +KERNEL_FLOAT_INLINE vector_ptr +operator+(vector_ptr p, size_t i) { + return vector_ptr {p.get() + i * N}; } -template -KERNEL_FLOAT_INLINE vector_ptr operator+(size_t i, vector_ptr p) { +template +KERNEL_FLOAT_INLINE vector_ptr +operator+(size_t i, vector_ptr p) { return p + i; } +template< + typename T, + size_t N, + typename U, + size_t A, + typename = enable_if_t<(N * sizeof(U)) % A == 0>> +KERNEL_FLOAT_INLINE vector_ptr& operator+=(vector_ptr& p, size_t i) { + return p = p + i; +} + /** - * Creates a `vector_ptr` from a raw pointer `U*` by asserting a specific alignment `N`. + * Creates a `vector_ptr` from a raw pointer `T*` by asserting a specific alignment `N`. * - * @tparam T The type of the elements as viewed by the user. This type may differ from `U`. * @tparam N The alignment constraint for the vector_ptr. Defaults to KERNEL_FLOAT_MAX_ALIGNMENT. - * @tparam U The type of the elements pointed to by the raw pointer. + * @tparam T The type of the elements pointed to by the raw pointer. */ -template -KERNEL_FLOAT_INLINE vector_ptr assert_aligned(U* ptr) { - return vector_ptr {ptr}; +template +KERNEL_FLOAT_INLINE vector_ptr assert_aligned(T* ptr) { + return vector_ptr {ptr}; } // Doxygen cannot deal with the `assert_aligned` being defined twice, we ignore the second definition. /// @cond IGNORE /** - * Creates a `vector_ptr` from a raw pointer `T*` by asserting a specific alignment `N`. + * Creates a `vector_ptr` from a raw pointer `T*`. The alignment is assumed to be KERNEL_FLOAT_MAX_ALIGNMENT. * - * @tparam N The alignment constraint for the vector_ptr. Defaults to KERNEL_FLOAT_MAX_ALIGNMENT. * @tparam T The type of the elements pointed to by the raw pointer. */ -template -KERNEL_FLOAT_INLINE vector_ptr assert_aligned(T* ptr) { - return vector_ptr {ptr}; +template +KERNEL_FLOAT_INLINE vector_ptr assert_aligned(T* ptr) { + return vector_ptr {ptr}; } /// @endcond -template -using vec_ptr = vector_ptr; +template +using vec_ptr = vector_ptr; #if defined(__cpp_deduction_guides) template @@ -4856,13 +4905,13 @@ KERNEL_FLOAT_DEVICE bfloat16x2_t normalize_trig_input(bfloat16x2_t x) { template KERNEL_FLOAT_DEVICE bfloat16x2_t cos(bfloat16x2_t x) { bfloat16x2_t xf = normalize_trig_input(x); - return cos_poly<__bfloat16, Iter + 1>::call(__hmul2(xf, xf)); + return cos_poly::call(__hmul2(xf, xf)); } template KERNEL_FLOAT_DEVICE bfloat16x2_t sin(bfloat16x2_t x) { bfloat16x2_t xf = normalize_trig_input(x); - return __hmul2(sin_poly<__bfloat16, Iter>::call(__hmul2(xf, xf)), xf); + return __hmul2(sin_poly::call(__hmul2(xf, xf)), xf); } template From 4d185630a1d8c766b0154d0a3388d2c29f657142 Mon Sep 17 00:00:00 2001 From: stijn Date: Fri, 11 Apr 2025 09:56:12 +0200 Subject: [PATCH 25/25] Update `Jimver/cuda-toolkit` workflow action to newer version --- .github/workflows/cmake-action.yml | 2 +- .github/workflows/cmake.yml | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/cmake-action.yml b/.github/workflows/cmake-action.yml index 37cdf95..a47a820 100644 --- a/.github/workflows/cmake-action.yml +++ b/.github/workflows/cmake-action.yml @@ -19,7 +19,7 @@ jobs: runs-on: ubuntu-latest steps: - - uses: Jimver/cuda-toolkit@v0.2.11 + - uses: Jimver/cuda-toolkit@v0.2.22 id: cuda-toolkit with: method: network diff --git a/.github/workflows/cmake.yml b/.github/workflows/cmake.yml index 136fcd3..3384b57 100644 --- a/.github/workflows/cmake.yml +++ b/.github/workflows/cmake.yml @@ -13,16 +13,16 @@ jobs: build-cuda: uses: ./.github/workflows/cmake-action.yml with: - cuda-version: "12.2.0" + cuda-version: "12.8.0" - build-cuda-11-7: + build-cuda-12-6: needs: build-cuda uses: ./.github/workflows/cmake-action.yml with: - cuda-version: "11.7.0" + cuda-version: "12.6.0" - build-cuda-12-0: + build-cuda-12-5: needs: build-cuda uses: ./.github/workflows/cmake-action.yml with: - cuda-version: "12.0.0" + cuda-version: "12.5.0"