diff --git a/.github/workflows/cmake-action.yml b/.github/workflows/cmake-action.yml index fd621db..a47a820 100644 --- a/.github/workflows/cmake-action.yml +++ b/.github/workflows/cmake-action.yml @@ -19,7 +19,7 @@ jobs: runs-on: ubuntu-latest steps: - - uses: Jimver/cuda-toolkit@v0.2.11 + - uses: Jimver/cuda-toolkit@v0.2.22 id: cuda-toolkit with: method: network @@ -33,7 +33,7 @@ jobs: - name: Configure CMake # Configure CMake in a 'build' subdirectory. `CMAKE_BUILD_TYPE` is only required if you are using a single-configuration generator such as make. # See https://cmake.org/cmake/help/latest/variable/CMAKE_BUILD_TYPE.html?highlight=cmake_build_type - run: cmake -B ${{github.workspace}}/build -DCMAKE_BUILD_TYPE=${{env.BUILD_TYPE}} -DKERNEL_FLOAT_BUILD_TEST=1 -DKERNEL_FLOAT_BUILD_EXAMPLE=1 + run: CUDAARCHS=all cmake -B ${{github.workspace}}/build -DCMAKE_BUILD_TYPE=${{env.BUILD_TYPE}} -DKERNEL_FLOAT_BUILD_TEST=1 -DKERNEL_FLOAT_BUILD_EXAMPLE=1 - name: Build # Build your program with the given configuration diff --git a/.github/workflows/cmake.yml b/.github/workflows/cmake.yml index 136fcd3..3384b57 100644 --- a/.github/workflows/cmake.yml +++ b/.github/workflows/cmake.yml @@ -13,16 +13,16 @@ jobs: build-cuda: uses: ./.github/workflows/cmake-action.yml with: - cuda-version: "12.2.0" + cuda-version: "12.8.0" - build-cuda-11-7: + build-cuda-12-6: needs: build-cuda uses: ./.github/workflows/cmake-action.yml with: - cuda-version: "11.7.0" + cuda-version: "12.6.0" - build-cuda-12-0: + build-cuda-12-5: needs: build-cuda uses: ./.github/workflows/cmake-action.yml with: - cuda-version: "12.0.0" + cuda-version: "12.5.0" diff --git a/CMakeLists.txt b/CMakeLists.txt index f563f74..bd5c28c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,13 +1,34 @@ cmake_minimum_required(VERSION 3.20) set (PROJECT_NAME kernel_float) -project(${PROJECT_NAME} CXX CUDA) +project(${PROJECT_NAME} LANGUAGES CXX) -set(CMAKE_C_STANDARD 11) +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +# Validate and enable the appropriate language +if (NOT DEFINED KERNEL_FLOAT_LANGUAGE) + set(KERNEL_FLOAT_LANGUAGE "CUDA") +endif() + +if (KERNEL_FLOAT_LANGUAGE STREQUAL "CUDA") + enable_language(CUDA) + set(KERNEL_FLOAT_LANGUAGE_CUDA ON) +elseif (KERNEL_FLOAT_LANGUAGE STREQUAL "HIP") + enable_language(HIP) + set(KERNEL_FLOAT_LANGUAGE_HIP ON) +else() + message(FATAL_ERROR "KERNEL_FLOAT_LANGUAGE must be either 'HIP' or 'CUDA'") +endif() + +# Create an interface library for kernel_float add_library(${PROJECT_NAME} INTERFACE) target_include_directories(${PROJECT_NAME} INTERFACE "${PROJECT_SOURCE_DIR}/include") +# Optionally build tests and examples if the corresponding flags are set +option(KERNEL_FLOAT_BUILD_TEST "Build kernel float tests" OFF) +option(KERNEL_FLOAT_BUILD_EXAMPLE "Build kernel float examples" OFF) + if (KERNEL_FLOAT_BUILD_TEST) add_subdirectory(tests) endif() @@ -15,3 +36,9 @@ endif() if (KERNEL_FLOAT_BUILD_EXAMPLE) add_subdirectory(examples) endif() + +# Display configuration +message(STATUS "=== Kernel Float ===") +message(STATUS "Using GPU Language: ${KERNEL_FLOAT_LANGUAGE}") +message(STATUS "Building Tests: ${KERNEL_FLOAT_BUILD_TEST}") +message(STATUS "Building Examples: ${KERNEL_FLOAT_BUILD_EXAMPLE}") \ No newline at end of file diff --git a/README.md b/README.md index 0670187..79748b0 100644 --- a/README.md +++ b/README.md @@ -9,12 +9,12 @@ ![GitHub Repo stars](https://img.shields.io/github/stars/KernelTuner/kernel_float?style=social) -_Kernel Float_ is a header-only library for CUDA that simplifies 
working with vector types and reduced precision floating-point arithmetic in GPU code. +_Kernel Float_ is a header-only library for CUDA/HIP that simplifies working with vector types and reduced precision floating-point arithmetic in GPU code. ## Summary -CUDA natively offers several reduced precision floating-point types (`__half`, `__nv_bfloat16`, `__nv_fp8_e4m3`, `__nv_fp8_e5m2`) +CUDA/HIP natively offers several reduced precision floating-point types (`__half`, `__nv_bfloat16`, `__nv_fp8_e4m3`, `__nv_fp8_e5m2`) and vector types (e.g., `__half2`, `__nv_fp8x4_e4m3`, `float3`). However, working with these types is cumbersome: mathematical operations require intrinsics (e.g., `__hadd2` performs addition for `__half2`), @@ -24,9 +24,9 @@ and some functionality is missing (e.g., one cannot convert a `__half` to `__nv_ _Kernel Float_ resolves this by offering a single data type `kernel_float::vec` that stores `N` elements of type `T`. Internally, the data is stored as a fixed-sized array of elements. Operator overloading (like `+`, `*`, `&&`) has been implemented such that the most optimal intrinsic for the available types is selected automatically. -Many mathetical functions (like `log`, `exp`, `sin`) and common operations (such as `sum`, `range`, `for_each`) are also available. +Many mathematical functions (like `log`, `exp`, `sin`) and common operations (such as `sum`, `range`, `for_each`) are also available. -By using this library, developers can avoid the complexity of working with reduced precision floating-point types in CUDA and focus on their applications. +Using Kernel Float, developers avoid the complexity of reduced precision floating-point types in CUDA and can focus on their applications. ## Features @@ -40,6 +40,7 @@ In a nutshell, _Kernel Float_ offers the following features: * Easy integration as a single header file. * Written for C++17. * Compatible with NVCC (NVIDIA Compiler) and NVRTC (NVIDIA Runtime Compilation). +* Compatible with HIPCC (AMD HIP Compiler) ## Example @@ -49,7 +50,7 @@ Check out the [examples](https://github.com/KernelTuner/kernel_float/tree/master Below shows a simple example of a CUDA kernel that adds a `constant` to the `input` array and writes the results to the `output` array. Each thread processes two elements. -Notice how easy it would be change the precision (for example, `double` to `half`) or the vector size (for example, 4 instead of 2 items per thread). +Notice how easy it would be to change the precision (for example, `double` to `half`) or the vector size (for example, 4 instead of 2 items per thread). ```cpp @@ -63,14 +64,14 @@ __global__ void kernel(const kf::vec* input, float constant, kf::vec c = kf::fast_rcp(x); kf::vec d = kf::fast_div(a, b); ``` -These functions are only functional for 32-bit and 16-bit floats. +These functions are only functional for 32-bit and 16-bit floats. For other input types, the operation falls back to the regular version. ## Approximate Math -For 16-bit floats, several approximate functions are provided. -These use approximations (typically low-degree polynomials) to calculate rough estimates of the functions. +For 16-bit floats, several approximate functions are provided. +These use approximations (typically low-degree polynomials) to calculate rough estimates of the functions. This can be very fast but also less accurate. @@ -69,14 +69,15 @@ kf::vec a = kf::approx_sin<3>(x); ## Tuning Accuracy Level -Many functions in Kernel Float accept an additional Accuracy option as a template parameter. 
+Many functions in Kernel Float accept an additional `Accuracy` option as a template parameter. This allows you to tune the accuracy level without changing the function name. -There are four possible values for this parameter: +There are five possible values for this parameter: - `kf::accurate_policy`: Use the most accurate version of the function available. - `kf::fast_policy`: Use the "fast math" version. -- `kf::approx_policy`: Use the approximate version with degree `N`. +- `kf::approx_level_policy`: Use the approximate version with accuracy level `N` (higher is more accurate). +- `kf::approx_policy`: Use the approximate version with a default accuracy level. - `kf::default_policy`: Use a global default policy (see the next section). For example, consider this code: @@ -97,15 +98,19 @@ kf::vec c = kf::cos(input); kf::vec d = kf::cos(input); // Use the approximate policy -kf::vec e = kf::cos>(input); +kf::vec e = kf::cos(input); + +// Use the approximate policy with degree 3 polynomial. +kf::vec f = kf::cos>(input); // You can use aliases to define your own policy using my_own_policy = kf::fast_policy; -kf::vec f = kf::cos(input); +kf::vec g = kf::cos(input); ``` ## Setting `default_policy` +If no policy is explicitly set, any function uses the `kf::default_policy`. By default, `kf::default_policy` is set to `kf::accurate_policy`. Set the preprocessor option `KERNEL_FLOAT_FAST_MATH=1` to change the default policy to `kf::fast_policy`. diff --git a/example.cu b/example.cu new file mode 100644 index 0000000..f15e3e7 --- /dev/null +++ b/example.cu @@ -0,0 +1,13 @@ +#include "kernel_float.h" +#include + +namespace kf = kernel_float; + +__global__ void kernel( + kf::vec_ptr input, + float constant, + kf::vec_ptr output +) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + output(i) = input[i] + kf::cast(constant); +} diff --git a/examples/hip_compat.h b/examples/hip_compat.h new file mode 100644 index 0000000..f4c0b87 --- /dev/null +++ b/examples/hip_compat.h @@ -0,0 +1,22 @@ +#pragma once + +/** + * This header file provides a mapping from CUDA-specific function names and types to their equivalent HIP + * counterparts, allowing for cross-platform development between CUDA and HIP. By including this header, code + * originally written for CUDA can be compiled with the HIP compiler (hipcc) by automatically replacing CUDA API + * calls with their HIP equivalents.
+ */ +#ifdef __HIPCC__ +#define cudaError_t hipError_t +#define cudaSuccess hipSuccess +#define cudaGetErrorString hipGetErrorString +#define cudaGetLastError hipGetLastError +#define cudaMalloc hipMalloc +#define cudaFree hipFree +#define cudaMemcpy hipMemcpy +#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost +#define cudaMemcpyDefault hipMemcpyDefault +#define cudaMemset hipMemset +#define cudaSetDevice hipSetDevice +#define cudaDeviceSynchronize hipDeviceSynchronize +#endif \ No newline at end of file diff --git a/examples/pi/CMakeLists.txt b/examples/pi/CMakeLists.txt index 6a767d3..92e9108 100644 --- a/examples/pi/CMakeLists.txt +++ b/examples/pi/CMakeLists.txt @@ -1,12 +1,18 @@ -cmake_minimum_required(VERSION 3.17) +cmake_minimum_required(VERSION 3.20) set (PROJECT_NAME kernel_float_pi) -project(${PROJECT_NAME} LANGUAGES CXX CUDA) -set (CMAKE_CXX_STANDARD 17) +project(${PROJECT_NAME} LANGUAGES CXX) +set (CMAKE_CXX_STANDARD 17) add_executable(${PROJECT_NAME} "${PROJECT_SOURCE_DIR}/main.cu") target_link_libraries(${PROJECT_NAME} kernel_float) -set_target_properties(${PROJECT_NAME} PROPERTIES CUDA_ARCHITECTURES "80") -find_package(CUDA REQUIRED) -target_include_directories(${PROJECT_NAME} PRIVATE ${CUDA_TOOLKIT_INCLUDE}) +if(${KERNEL_FLOAT_LANGUAGE_CUDA}) + find_package(CUDA REQUIRED) + target_include_directories(${PROJECT_NAME} PRIVATE ${CUDA_TOOLKIT_INCLUDE}) + set_target_properties(${PROJECT_NAME} PROPERTIES CUDA_ARCHITECTURES "80") +endif() + +if(${KERNEL_FLOAT_LANGUAGE_HIP}) + set_source_files_properties("${PROJECT_SOURCE_DIR}/main.cu" PROPERTIES LANGUAGE HIP) +endif() \ No newline at end of file diff --git a/examples/pi/main.cu b/examples/pi/main.cu index 58edd8b..970a2ff 100644 --- a/examples/pi/main.cu +++ b/examples/pi/main.cu @@ -1,6 +1,7 @@ #include #include +#include "../hip_compat.h" #include "kernel_float.h" #define CUDA_CHECK(call) \ @@ -9,12 +10,12 @@ if (__err != cudaSuccess) { \ fprintf( \ stderr, \ - "CUDA error at %s:%d code=%d(%s) \"%s\" \n", \ + "CUDA error at %s:%d (%s): %s (code %d) \n", \ __FILE__, \ __LINE__, \ - __err, \ + #call, \ cudaGetErrorString(__err), \ - #call); \ + __err); \ exit(EXIT_FAILURE); \ } \ } while (0) diff --git a/examples/vector_add/CMakeLists.txt b/examples/vector_add/CMakeLists.txt index b4d7bb8..eec05b9 100644 --- a/examples/vector_add/CMakeLists.txt +++ b/examples/vector_add/CMakeLists.txt @@ -1,12 +1,18 @@ cmake_minimum_required(VERSION 3.17) set (PROJECT_NAME kernel_float_vecadd) -project(${PROJECT_NAME} LANGUAGES CXX CUDA) -set (CMAKE_CXX_STANDARD 17) +project(${PROJECT_NAME} LANGUAGES CXX) +set (CMAKE_CXX_STANDARD 17) add_executable(${PROJECT_NAME} "${PROJECT_SOURCE_DIR}/main.cu") target_link_libraries(${PROJECT_NAME} kernel_float) -set_target_properties(${PROJECT_NAME} PROPERTIES CUDA_ARCHITECTURES "80") -find_package(CUDA REQUIRED) -target_include_directories(${PROJECT_NAME} PRIVATE ${CUDA_TOOLKIT_INCLUDE}) +if(${KERNEL_FLOAT_LANGUAGE_HIP}) + set_source_files_properties("${PROJECT_SOURCE_DIR}/main.cu" PROPERTIES LANGUAGE HIP) +endif() + +if(${KERNEL_FLOAT_LANGUAGE_CUDA}) + find_package(CUDA REQUIRED) + target_include_directories(${PROJECT_NAME} PRIVATE ${CUDA_TOOLKIT_INCLUDE}) + set_target_properties(${PROJECT_NAME} PROPERTIES CUDA_ARCHITECTURES "80") +endif() diff --git a/examples/vector_add/main.cu b/examples/vector_add/main.cu index 705cf8c..2ac2eb4 100644 --- a/examples/vector_add/main.cu +++ b/examples/vector_add/main.cu @@ -3,6 +3,7 @@ #include #include +#include "../hip_compat.h" #include "kernel_float.h" namespace kf = 
kernel_float; @@ -21,7 +22,7 @@ __global__ void my_kernel( int i = blockIdx.x * blockDim.x + threadIdx.x; if (i * N < length) { - output(i) = kf::fma(input[i], input[i], kf::cast<__half>(constant)); + output[i] = kf::fma(input[i], input[i], kf::cast(constant)); } } diff --git a/examples/vector_add_tiling/CMakeLists.txt b/examples/vector_add_tiling/CMakeLists.txt index a744c34..1992272 100644 --- a/examples/vector_add_tiling/CMakeLists.txt +++ b/examples/vector_add_tiling/CMakeLists.txt @@ -1,12 +1,18 @@ cmake_minimum_required(VERSION 3.17) set (PROJECT_NAME kernel_float_vecadd_tiling) -project(${PROJECT_NAME} LANGUAGES CXX CUDA) -set (CMAKE_CXX_STANDARD 17) +project(${PROJECT_NAME} LANGUAGES CXX) +set (CMAKE_CXX_STANDARD 17) add_executable(${PROJECT_NAME} "${PROJECT_SOURCE_DIR}/main.cu") target_link_libraries(${PROJECT_NAME} kernel_float) -set_target_properties(${PROJECT_NAME} PROPERTIES CUDA_ARCHITECTURES "80") -find_package(CUDA REQUIRED) -target_include_directories(${PROJECT_NAME} PRIVATE ${CUDA_TOOLKIT_INCLUDE}) +if(${KERNEL_FLOAT_LANGUAGE_HIP}) + set_source_files_properties("${PROJECT_SOURCE_DIR}/main.cu" PROPERTIES LANGUAGE HIP) +endif() + +if(${KERNEL_FLOAT_LANGUAGE_CUDA}) + find_package(CUDA REQUIRED) + target_include_directories(${PROJECT_NAME} PRIVATE ${CUDA_TOOLKIT_INCLUDE}) + set_target_properties(${PROJECT_NAME} PROPERTIES CUDA_ARCHITECTURES "80") +endif() diff --git a/examples/vector_add_tiling/main.cu b/examples/vector_add_tiling/main.cu index ddd28ab..14a0983 100644 --- a/examples/vector_add_tiling/main.cu +++ b/examples/vector_add_tiling/main.cu @@ -3,6 +3,7 @@ #include #include +#include "../hip_compat.h" #include "kernel_float.h" #include "kernel_float/tiling.h" namespace kf = kernel_float; diff --git a/include/kernel_float.h b/include/kernel_float.h index ee098fc..be7e21c 100644 --- a/include/kernel_float.h +++ b/include/kernel_float.h @@ -1,6 +1,7 @@ #ifndef KERNEL_FLOAT_H #define KERNEL_FLOAT_H +#include "kernel_float/approx.h" #include "kernel_float/base.h" #include "kernel_float/bf16.h" #include "kernel_float/binops.h" diff --git a/include/kernel_float/apply.h b/include/kernel_float/apply.h index 3a7e02c..6b821cc 100644 --- a/include/kernel_float/apply.h +++ b/include/kernel_float/apply.h @@ -7,41 +7,41 @@ namespace kernel_float { namespace detail { template -struct broadcast_extent_helper; +struct broadcast_extent_impl; template -struct broadcast_extent_helper { +struct broadcast_extent_impl { using type = E; }; template -struct broadcast_extent_helper, extent> { +struct broadcast_extent_impl, extent> { using type = extent; }; template -struct broadcast_extent_helper, extent> { +struct broadcast_extent_impl, extent> { using type = extent; }; template -struct broadcast_extent_helper, extent<1>> { +struct broadcast_extent_impl, extent<1>> { using type = extent; }; template<> -struct broadcast_extent_helper, extent<1>> { +struct broadcast_extent_impl, extent<1>> { using type = extent<1>; }; template -struct broadcast_extent_helper: - broadcast_extent_helper::type, C, Rest...> {}; +struct broadcast_extent_impl: + broadcast_extent_impl::type, C, Rest...> {}; } // namespace detail template -using broadcast_extent = typename detail::broadcast_extent_helper::type; +using broadcast_extent = typename detail::broadcast_extent_impl::type; template using broadcast_vector_extent_type = broadcast_extent...>; @@ -116,33 +116,51 @@ broadcast_like(const V& input, const R& other) { return broadcast(input, vector_extent_type {}); } -namespace detail { +/** + * The accurate_policy is 
designed for computations where maximum accuracy is essential. This policy ensures that all + * operations are performed without any approximations or optimizations that could potentially alter the precise + * outcome of the computations + */ +struct accurate_policy {}; -template -struct apply_impl { - KERNEL_FLOAT_INLINE static void call(F fun, Output* output, const Args*... args) { -#pragma unroll - for (size_t i = 0; i < N; i++) { - output[i] = fun(args[i]...); - } - } +/** + * The fast_policy is intended for scenarios where performance and execution speed are more critical than achieving + * the utmost accuracy. This policy leverages optimizations to accelerate computations, which may involve + * approximations that slightly compromise precision. + */ +struct fast_policy { + using fallback_policy = accurate_policy; }; -template -struct apply_fastmath_impl: apply_impl {}; -} // namespace detail - -struct accurate_policy { - template - using type = detail::apply_impl; +/** + * This template policy allows developers to specify a custom degree of approximation for their computations. By + * adjusting the `Level` parameter, you can fine-tune the balance between accuracy and performance to meet the + * specific needs of your application. Higher values mean more precision. + */ +template +struct approx_level_policy { + using fallback_policy = approx_level_policy<>; }; -struct fast_policy { - template - using type = detail::apply_fastmath_impl; +template<> +struct approx_level_policy<> { + using fallback_policy = fast_policy; }; -#ifdef KERNEL_FLOAT_POLICY +/** + * The approximate_policy serves as the default approximation policy, providing a standard level of approximation + * without requiring explicit configuration. It balances accuracy and performance, making it suitable for + * general-purpose use cases where neither extreme precision nor maximum speed is necessary. + */ +using approx_policy = approx_level_policy<>; + +/** + * The `default_policy` acts as the standard computation policy. It can be configured externally using the + * `KERNEL_FLOAT_GLOBAL_POLICY` macro. If `KERNEL_FLOAT_GLOBAL_POLICY` is not defined, default to `accurate_policy`. + */ +#if defined(KERNEL_FLOAT_GLOBAL_POLICY) +using default_policy = KERNEL_FLOAT_GLOBAL_POLICY; +#elif defined(KERNEL_FLOAT_POLICY) using default_policy = KERNEL_FLOAT_POLICY; #else using default_policy = accurate_policy; @@ -150,8 +168,35 @@ using default_policy = accurate_policy; namespace detail { +template +struct invoke_impl { + KERNEL_FLOAT_INLINE static Output call(F fun, Args... args) { + return fun(args...); + } +}; + +template +struct apply_impl; + +template +struct apply_base_impl: apply_impl {}; + +template +struct apply_impl: apply_base_impl {}; + +// Only for `accurate_policy` do we implement `apply_impl`, the others will fall back to `apply_base_impl`. +template +struct apply_impl { + KERNEL_FLOAT_INLINE static void call(F fun, Output* output, const Args*... 
args) { +#pragma unroll + for (size_t i = 0; i < N; i++) { + output[i] = invoke_impl::call(fun, args[i]...); + } + } +}; + template -struct map_policy_impl { +struct map_impl { static constexpr size_t packet_size = preferred_vector_size::value; static constexpr size_t remainder = N % packet_size; @@ -159,7 +204,7 @@ struct map_policy_impl { if constexpr (N / packet_size > 0) { #pragma unroll for (size_t i = 0; i < N - remainder; i += packet_size) { - Policy::template type::call( + apply_impl::call( fun, output + i, (args + i)...); @@ -169,14 +214,14 @@ struct map_policy_impl { if constexpr (remainder > 0) { #pragma unroll for (size_t i = N - remainder; i < N; i++) { - Policy::template type::call(fun, output + i, (args + i)...); + apply_impl::call(fun, output + i, (args + i)...); } } } }; template -using map_impl = map_policy_impl; +using default_map_impl = map_impl; } // namespace detail @@ -200,7 +245,7 @@ KERNEL_FLOAT_INLINE map_type map(F fun, const Args&... args) { using E = broadcast_vector_extent_type; vector_storage> result; - detail::map_policy_impl, Output, vector_value_type...>::call( + detail::map_impl, Output, vector_value_type...>::call( fun, result.data(), (detail::broadcast_impl, vector_extent_type, E>::call( @@ -212,4 +257,4 @@ KERNEL_FLOAT_INLINE map_type map(F fun, const Args&... args) { } // namespace kernel_float -#endif // KERNEL_FLOAT_APPLY_H \ No newline at end of file +#endif // KERNEL_FLOAT_APPLY_H diff --git a/include/kernel_float/approx.h b/include/kernel_float/approx.h new file mode 100644 index 0000000..32a53aa --- /dev/null +++ b/include/kernel_float/approx.h @@ -0,0 +1,432 @@ +#pragma once + +#include "apply.h" +#include "bf16.h" +#include "fp16.h" +#include "macros.h" + +namespace kernel_float { + +namespace approx { + +static_assert(sizeof(unsigned int) * 8 == 32, "invalid size of unsigned int"); +static_assert(sizeof(unsigned short) * 8 == 16, "invalid size of unsigned short"); +using uint32_t = unsigned int; +using uint16_t = unsigned short; + +template +KERNEL_FLOAT_DEVICE T transmute(const U& input) { + static_assert(sizeof(T) == sizeof(U), "types must have equal size"); + T result {}; + ::memcpy(&result, &input, sizeof(T)); + return result; +} + +KERNEL_FLOAT_DEVICE uint32_t +bitwise_if_else(uint32_t condition, uint32_t if_true, uint32_t if_false) { + uint32_t result; + +#if KERNEL_FLOAT_IS_CUDA + // equivalent to (condition & if_true) | ((~condition) & if_false) + asm("lop3.b32 %0, %1, %2, %3, 0xCA;" + : "=r"(result) + : "r"(condition), "r"(if_true), "r"(if_false)); +#else + result = (condition & if_true) | ((~condition) & if_false); +#endif + return result; +} + +template +KERNEL_FLOAT_DEVICE T2 eval_poly_recur(T2 y, T2 x) { + return y; +} + +template +KERNEL_FLOAT_DEVICE T2 eval_poly_recur(T2 y, T2 x, T coef, TRest... coefs) { + y = __hfma2(x, y, T2 {coef, coef}); + return eval_poly_recur(y, x, coefs...); +} + +template +KERNEL_FLOAT_DEVICE T2 eval_poly(T2 x, T coef, TRest... coefs) { + return eval_poly_recur(T2 {coef, coef}, x, coefs...); +} + +#define KERNEL_FLOAT_DEFINE_POLY(NAME, N, ...) 
\ + template \ + struct NAME { \ + template \ + static KERNEL_FLOAT_DEVICE T2 call(T2 x) { \ + return eval_poly(x, __VA_ARGS__); \ + } \ + }; + +template +struct sin_poly: sin_poly {}; +KERNEL_FLOAT_DEFINE_POLY(sin_poly, 1, 1.365) +KERNEL_FLOAT_DEFINE_POLY(sin_poly, 2, -21.56, 5.18) +KERNEL_FLOAT_DEFINE_POLY(sin_poly, 3, 53.53, -38.06, 6.184) +KERNEL_FLOAT_DEFINE_POLY(sin_poly, 4, -56.1, 77.94, -41.1, 6.277) +KERNEL_FLOAT_DEFINE_POLY(sin_poly, 5, 32.78, -74.5, 81.4, -41.34, 6.28) + +template +struct cos_poly: cos_poly {}; +KERNEL_FLOAT_DEFINE_POLY(cos_poly, 1, 0.0) +KERNEL_FLOAT_DEFINE_POLY(cos_poly, 2, -8.0, 0.6943) +KERNEL_FLOAT_DEFINE_POLY(cos_poly, 3, 38.94, -17.5, 0.9707) +KERNEL_FLOAT_DEFINE_POLY(cos_poly, 4, -59.66, 61.12, -19.56, 0.9985) +KERNEL_FLOAT_DEFINE_POLY(cos_poly, 5, 45.66, -82.4, 64.7, -19.73, 1.0) + +template +struct asin_poly: asin_poly {}; +KERNEL_FLOAT_DEFINE_POLY(asin_poly, 1, 1.531) +KERNEL_FLOAT_DEFINE_POLY(asin_poly, 2, -0.169, 1.567) +KERNEL_FLOAT_DEFINE_POLY(asin_poly, 3, 0.05167, -0.2057, 1.57) +KERNEL_FLOAT_DEFINE_POLY(asin_poly, 4, -0.02103, 0.077, -0.2129, 1.57) +KERNEL_FLOAT_DEFINE_POLY(asin_poly, 5, 0.009796, -0.03772, 0.0857, -0.2142, 1.57) + +#if KERNEL_FLOAT_FP16_AVAILABLE +KERNEL_FLOAT_DEVICE half2_t flipsign(half2_t input, half2_t sign) { + // Flip signbit of input when sign<0 + uint32_t result; + +#if KERNEL_FLOAT_IS_CUDA + asm("lop3.b32 %0, %1, %2, %3, 0x6A;" + : "=r"(result) + : "r"(0x80008000), "r"(transmute(sign)), "r"(transmute(input))); +#else + result = uint32_t(transmute(sign) & 0x80008000) ^ transmute(input); +#endif + + return transmute(result); +} + +KERNEL_FLOAT_DEVICE uint32_t half2_gt_mask(half2_t a, half2_t b) { + uint32_t val; +#if KERNEL_FLOAT_IS_CUDA + uint32_t ai = *(reinterpret_cast(&a)); + uint32_t bi = *(reinterpret_cast(&b)); + asm("{ set.gt.u32.f16x2 %0,%1,%2;\n}" : "=r"(val) : "r"(ai), "r"(bi)); +#else + val = transmute(make_short2(a.x > b.x ? ~0 : 0, a.y > b.y ? ~0 : 0)); +#endif + return val; +} + +KERNEL_FLOAT_INLINE half2_t make_half2(half x) { + return {x, x}; +} + +KERNEL_FLOAT_DEVICE half2_t normalize_trig_input(half2_t x) { + /* Using rint is too slow. Round using floating-point magic instead. 
*/ + // half2_t x = arg * make_half2(-0.15915494309); + // return __hfma2(arg, make_half2(0.15915494309), h2rint(x)); + + // 1/(2pi) = 0.15915494309189535 + static constexpr double ONE_OVER_TWOPI = 0.15915494309189535; + static constexpr double OFFSET = -2042.0; + + half2_t ws = __hfma2(x, make_half2(-ONE_OVER_TWOPI), make_half2(-OFFSET)) + make_half2(OFFSET); + return __hfma2(x, make_half2(ONE_OVER_TWOPI), ws); +} + +template +KERNEL_FLOAT_DEVICE half2_t cos(half2_t x) { + half2_t xf = normalize_trig_input(x); + return cos_poly::call(__hmul2(xf, xf)); +} + +template +KERNEL_FLOAT_DEVICE half2_t sin(half2_t x) { + half2_t xf = normalize_trig_input(x); + return sin_poly::call(__hmul2(xf, xf)) * xf; +} + +template +KERNEL_FLOAT_DEVICE half2_t rcp(half2_t x) { + // Flip bits + uint32_t m = ~transmute(x); + + // Multiply by bias (add constant) + half2_t y = transmute(uint32_t(0x776d776d) + m); + +#pragma unroll + for (int i = 0; i < Iter; i++) { + // y += y * (1 - x * y) + y = __hfma2(y, __hfma2(-x, y, make_half2(1.0)), y); + } + + return y; +} + +template +KERNEL_FLOAT_DEVICE half2_t rsqrt(half2_t x) { + // A small number added such that rsqrt(0) does not return NaN + static constexpr double EPS = 0.00000768899917602539; + + // Set top and bottom bits for both halves, then shift by 1, then invert + uint32_t r = ~((uint32_t(transmute(x) >> 1)) | ~uint32_t(0x3fff3fff)); + + // Add bias + static constexpr uint32_t BIAS = 0x199c199c; + half2_t y = transmute(uint32_t(r) + BIAS); + + // Newton-Raphson iterations +#pragma unroll + for (int i = 0; i < Iter; i++) { + half2_t half_x = __hfma2(make_half2(-0.5), x, make_half2(-EPS)); + half2_t correction = __hfma2(half_x, y * y, make_half2(0.5)); + y = __hfma2(correction, y, y); // y += y * correction + } + + return y; +} + +template +KERNEL_FLOAT_DEVICE half2_t sqrt(half2_t x) { + if (Iter == 1) { + half2_t y = rsqrt<0>(x); + + // This method uses only 4 muls, instead of 5 muls when using `arg * approx_rsqrt<1>(arg)` + half2_t xy = x * y; + return xy * __hfma2(make_half2(-0.5) * y, xy, make_half2(1.5)); + } + + return x * rsqrt(x); +} + +template +KERNEL_FLOAT_DEVICE half2_t asin(half2_t x) { + static constexpr double HALF_PI = 1.57079632679; + auto abs_x = __habs2(x); + auto v = asin_poly::call(abs_x); + auto abs_y = __hfma2(-v, sqrt(make_half2(1) - abs_x), make_half2(HALF_PI)); + return flipsign(abs_y, x); +} + +template +KERNEL_FLOAT_DEVICE half2_t acos(half2_t x) { + static constexpr double HALF_PI = 1.57079632679; + return make_half2(HALF_PI) - asin(x); +} + +template +KERNEL_FLOAT_DEVICE half2_t exp(half2_t x) { + half2_t y; + + if (Deg == 0) { + // Bring the value to range [32, 64] + // 1.442 = 1/log(2) + // 46.969 = 32.5/log(2) + half2_t m = __hfma2(x, make_half2(1.442), make_half2(46.9375)); + + // Transmute to int, shift higher mantissa bits into exponent field. + y = transmute((transmute(m) & 0x03ff03ff) << 5); + } else { + // Add a large number to round to an integer + half2_t v = __hfma2(x, make_half2(1.442), make_half2(1231.0)); + + // The exponent is now in the lower 5 bits. Shift that into the exponent field. + half2_t exp = transmute((transmute(v) & 0x001f001f) << 10); + + // The fractional part can be obtained from "1231-v".
+ // 0.6934 = log(2) + half2_t frac = __hfma2(make_half2(1231.0) - v, make_half2(0.6934), x); + + // This is the Taylor expansion of "exp(x)-1" around 0 + half2_t adjust; + if (Deg == 1) { + adjust = frac; + } else if (Deg == 2) { + // adjust = frac + 0.5 * frac^2 + adjust = __hfma2(frac, __hmul2(frac, make_half2(0.5)), frac); + } else /* if (Deg >= 3) */ { + // adjust = frac + 0.5 * frac^2 + 1/6 * frac^3 + adjust = __hfma2( + frac, + __hmul2(__hfma2(frac, make_half2(0.1666), make_half2(0.5)), frac), + frac); + } + + // result = exp * (adjust + 1) + y = __hfma2(exp, adjust, exp); + } + + // Values below -10.39 (= -15*log(2)) become zero + uint32_t zero_mask = half2_gt_mask(x, make_half2(-10.390625)); + return transmute(zero_mask & transmute(y)); +} + +template +KERNEL_FLOAT_DEVICE half2_t log(half2_t arg) { + // Shift exponent field into mantissa bits. Fill exponent bits with 0x5000 (= 32.0) + uint32_t bits = bitwise_if_else(0x03ff03ff, transmute(arg) >> 5, 0x50005000); + + // 0.6934 = log(2) + // 32.53 = 46.969*log(2) + return __hfma2(transmute(bits), make_half2(0.6934), make_half2(-32.53125)); +} + +template +KERNEL_FLOAT_DEVICE half2_t tanh(half2_t x) { + if (Deg == 0) { + return x * rcp<0>(make_half2(0.2869) + __habs2(x)); + } else { + auto c0 = make_half2(0.4531); + auto c1 = make_half2(0.5156); + auto x2b = __hfma2(x, x, c1); + return (x * x2b) * rcp(__hfma2(x2b, __habs2(x), c0)); + } +} + +#endif // KERNEL_FLOAT_FP16_AVAILABLE + +#if KERNEL_FLOAT_BF16_OPS_SUPPORTED +KERNEL_FLOAT_DEVICE bfloat16x2_t make_bfloat162(bfloat16_t x) { + return {x, x}; +} + +KERNEL_FLOAT_DEVICE bfloat16x2_t make_bfloat162(double x) { + return {__double2bfloat16(x), __double2bfloat16(x)}; +} + +KERNEL_FLOAT_DEVICE bfloat16x2_t normalize_trig_input(bfloat16x2_t x) { + static constexpr double ONE_OVER_TWOPI = 0.15915494309189535; + static constexpr double OFFSET = -2042.0; + + bfloat16x2_t ws = __hadd2( + __hfma2(x, make_bfloat162(-ONE_OVER_TWOPI), make_bfloat162(-OFFSET)), + make_bfloat162(OFFSET)); + return __hfma2(x, make_bfloat162(ONE_OVER_TWOPI), ws); +} + +template +KERNEL_FLOAT_DEVICE bfloat16x2_t cos(bfloat16x2_t x) { + bfloat16x2_t xf = normalize_trig_input(x); + return cos_poly::call(__hmul2(xf, xf)); +} + +template +KERNEL_FLOAT_DEVICE bfloat16x2_t sin(bfloat16x2_t x) { + bfloat16x2_t xf = normalize_trig_input(x); + return __hmul2(sin_poly::call(__hmul2(xf, xf)), xf); +} + +template +KERNEL_FLOAT_DEVICE bfloat16x2_t rcp(bfloat16x2_t x) { + bfloat16x2_t y = transmute(uint32_t(0x7ef07ef0) + ~transmute(x)); + +#pragma unroll + for (int i = 0; i < Iter; i++) { + y = __hfma2(y, __hfma2(__hneg2(x), y, make_bfloat162(1.0)), y); + } + + return y; +} + +template +KERNEL_FLOAT_DEVICE bfloat16x2_t rsqrt(bfloat16x2_t x) { + // Set top and bottom bits for both halves, then shift by 1, then invert + uint32_t r = ~((uint32_t(transmute(x) >> 1)) | ~uint32_t(0x3fff3fff)); + + // Add bias (0x1f36) + bfloat16x2_t y = transmute(uint32_t(r) + uint32_t(0x1f361f36)); + + // Newton-Raphson iterations +#pragma unroll + for (int i = 0; i < Iter; i++) { + bfloat16x2_t half_x = __hmul2(make_bfloat162(-0.5), x); + bfloat16x2_t correction = __hfma2(half_x, __hmul2(y, y), make_bfloat162(0.5)); + y = __hfma2(correction, y, y); // y += y * correction + } + + return y; +} + +template +KERNEL_FLOAT_DEVICE bfloat16x2_t sqrt(bfloat16x2_t x) { + return __hmul2(x, rsqrt(x)); +} + +template +KERNEL_FLOAT_DEVICE bfloat16x2_t exp(bfloat16x2_t arg) { + static constexpr float SCALE = 1.44272065994 / 256.0; + static constexpr float OFFSET = 
382.4958400542335; + static constexpr float MINIMUM = 382; + + float a = fmaxf(fmaf(__bfloat162float(arg.x), SCALE, OFFSET), MINIMUM); + float b = fmaxf(fmaf(__bfloat162float(arg.y), SCALE, OFFSET), MINIMUM); + + return { + transmute(uint16_t(transmute(a))), + transmute(uint16_t(transmute(b)))}; +} +#endif +} // namespace approx + +namespace detail { +template +struct apply_impl, F, 1, T, T> { + KERNEL_FLOAT_INLINE static void call(F fun, T* output, const T* input) { + T in2[2], out2[2]; + in2[0] = input[0]; + apply_impl, F, 2, T, T>::call(fun, out2, in2); + output[0] = out2[0]; + } +}; +} // namespace detail + +#define KERNEL_FLOAT_DEFINE_APPROX_IMPL(T, FUN, DEFAULT_LEVEL) \ + namespace detail { \ + template \ + struct apply_impl, ops::FUN, 2, T, T> { \ + KERNEL_FLOAT_INLINE static void call(ops::FUN, T* output, const T* input) { \ + auto res = approx::FUN({input[0], input[1]}); \ + output[0] = res.x; \ + output[1] = res.y; \ + } \ + }; \ + \ + template<> \ + struct apply_impl, 2, T, T>: \ + apply_impl, ops::FUN, 2, T, T> {}; \ + } + +#if KERNEL_FLOAT_FP16_AVAILABLE +KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, sin, 4) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, cos, 4) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, rsqrt, 1) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, sqrt, 1) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, rcp, 1) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, exp, 0) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, log, 0) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, asin, 2) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, acos, 2) +#endif + +#if KERNEL_FLOAT_BF16_OPS_SUPPORTED +KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, cos, 4) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, sin, 4) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, rcp, 1) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, rsqrt, 1) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, sqrt, 1) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, exp, 0) +//KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, log, 0) +#endif + +#define KERNEL_FLOAT_DEFINE_APPROX_FUN(FUN) \ + template \ + KERNEL_FLOAT_INLINE into_vector_type approx_##FUN(const V& args) { \ + return map>(ops::FUN> {}, args); \ + } + +KERNEL_FLOAT_DEFINE_APPROX_FUN(sin) +KERNEL_FLOAT_DEFINE_APPROX_FUN(cos) +KERNEL_FLOAT_DEFINE_APPROX_FUN(rsqrt) +KERNEL_FLOAT_DEFINE_APPROX_FUN(sqrt) +KERNEL_FLOAT_DEFINE_APPROX_FUN(rcp) +KERNEL_FLOAT_DEFINE_APPROX_FUN(exp) +KERNEL_FLOAT_DEFINE_APPROX_FUN(log) + +} // namespace kernel_float diff --git a/include/kernel_float/base.h b/include/kernel_float/base.h index 403bceb..2f658b5 100644 --- a/include/kernel_float/base.h +++ b/include/kernel_float/base.h @@ -4,6 +4,10 @@ #include "macros.h" #include "meta.h" +#if KERNEL_FLOAT_IS_HIP +#include +#endif + namespace kernel_float { template @@ -266,7 +270,7 @@ using promoted_vector_value_type = promote_t...>; template KERNEL_FLOAT_INLINE vector_storage_type into_vector_storage(V&& input) { - return into_vector_impl::call(std::forward(input)); + return into_vector_impl::call(static_cast(input)); } } // namespace kernel_float diff --git a/include/kernel_float/bf16.h b/include/kernel_float/bf16.h index 7db292c..f89fc3c 100644 --- a/include/kernel_float/bf16.h +++ b/include/kernel_float/bf16.h @@ -4,7 +4,16 @@ #include "macros.h" #if KERNEL_FLOAT_BF16_AVAILABLE +//#define CUDA_NO_BFLOAT16 (1) +//#define __CUDA_NO_BFLOAT16_OPERATORS__ (1) +//#define __CUDA_NO_BFLOAT162_OPERATORS__ (1) +//#define __CUDA_NO_BFLOAT16_CONVERSIONS__ (1) + +#if KERNEL_FLOAT_IS_CUDA #include +#elif KERNEL_FLOAT_IS_HIP +#include +#endif #include "binops.h" #include "reduce.h" @@ 
-12,103 +21,144 @@ namespace kernel_float { +#if KERNEL_FLOAT_IS_CUDA +using bfloat16_t = __nv_bfloat16; +using bfloat16x2_t = __nv_bfloat162; +#elif KERNEL_FLOAT_IS_HIP +using bfloat16_t = __hip_bfloat16; +using bfloat16x2_t = __hip_bfloat162; +#endif + +#if KERNEL_FLOAT_IS_CUDA && __CUDA_ARCH__ >= 800 +#define KERNEL_FLOAT_BF16_OPS_SUPPORTED 1 +#endif + template<> -struct preferred_vector_size<__nv_bfloat16> { +struct preferred_vector_size { static constexpr size_t value = 2; }; -KERNEL_FLOAT_DEFINE_PROMOTED_FLOAT(__nv_bfloat16) -KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(float, __nv_bfloat16) -KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(double, __nv_bfloat16) +KERNEL_FLOAT_DEFINE_PROMOTED_FLOAT(bfloat16_t) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(float, bfloat16_t) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(double, bfloat16_t) template<> -struct into_vector_impl<__nv_bfloat162> { - using value_type = __nv_bfloat16; +struct into_vector_impl { + using value_type = bfloat16_t; using extent_type = extent<2>; KERNEL_FLOAT_INLINE - static vector_storage<__nv_bfloat16, 2> call(__nv_bfloat162 input) { + static vector_storage call(bfloat16x2_t input) { return {input.x, input.y}; } }; namespace detail { template<> -struct allow_float_fallback<__nv_bfloat16> { +struct allow_float_fallback { static constexpr bool value = true; }; }; // namespace detail -#if KERNEL_FLOAT_IS_DEVICE -#define KERNEL_FLOAT_BF16_UNARY_FUN(NAME, FUN1, FUN2) \ - namespace ops { \ - template<> \ - struct NAME<__nv_bfloat16> { \ - KERNEL_FLOAT_INLINE __nv_bfloat16 operator()(__nv_bfloat16 input) { \ - return FUN1(input); \ - } \ - }; \ - } \ - namespace detail { \ - template<> \ - struct apply_impl, 2, __nv_bfloat16, __nv_bfloat16> { \ - KERNEL_FLOAT_INLINE static void \ - call(ops::NAME<__nv_bfloat16>, __nv_bfloat16* result, const __nv_bfloat16* a) { \ - __nv_bfloat162 r = FUN2(__nv_bfloat162 {a[0], a[1]}); \ - result[0] = r.x, result[1] = r.y; \ - } \ - }; \ +#define KERNEL_FLOAT_BF16_UNARY_FUN(NAME, FUN1, FUN2) \ + namespace ops { \ + template<> \ + struct NAME { \ + KERNEL_FLOAT_INLINE bfloat16_t operator()(bfloat16_t input) { \ + return FUN1(input); \ + } \ + }; \ + } \ + namespace detail { \ + template<> \ + struct apply_impl, 2, bfloat16_t, bfloat16_t> { \ + KERNEL_FLOAT_INLINE static void \ + call(ops::NAME, bfloat16_t* result, const bfloat16_t* a) { \ + bfloat16x2_t r = FUN2(bfloat16x2_t {a[0], a[1]}); \ + result[0] = r.x, result[1] = r.y; \ + } \ + }; \ } -#else -#define KERNEL_FLOAT_BF16_UNARY_FUN(NAME, FUN1, FUN2) -#endif -#if KERNEL_FLOAT_CUDA_ARCH >= 800 -KERNEL_FLOAT_BF16_UNARY_FUN(abs, ::__habs, ::__habs2) -KERNEL_FLOAT_BF16_UNARY_FUN(negate, ::__hneg, ::__hneg2) -KERNEL_FLOAT_BF16_UNARY_FUN(ceil, ::hceil, ::h2ceil) +#if KERNEL_FLOAT_BF16_OPS_SUPPORTED +KERNEL_FLOAT_BF16_UNARY_FUN(sin, ::hsin, ::h2sin) KERNEL_FLOAT_BF16_UNARY_FUN(cos, ::hcos, ::h2cos) + KERNEL_FLOAT_BF16_UNARY_FUN(exp, ::hexp, ::h2exp) +KERNEL_FLOAT_BF16_UNARY_FUN(exp2, ::hexp2, ::h2exp2) KERNEL_FLOAT_BF16_UNARY_FUN(exp10, ::hexp10, ::h2exp10) -KERNEL_FLOAT_BF16_UNARY_FUN(floor, ::hfloor, ::h2floor) KERNEL_FLOAT_BF16_UNARY_FUN(log, ::hlog, ::h2log) +KERNEL_FLOAT_BF16_UNARY_FUN(log2, ::hlog2, ::h2log2) KERNEL_FLOAT_BF16_UNARY_FUN(log10, ::hlog10, ::h2log2) -KERNEL_FLOAT_BF16_UNARY_FUN(rint, ::hrint, ::h2rint) -KERNEL_FLOAT_BF16_UNARY_FUN(rsqrt, ::hrsqrt, ::h2rsqrt) -KERNEL_FLOAT_BF16_UNARY_FUN(sin, ::hsin, ::h2sin) + KERNEL_FLOAT_BF16_UNARY_FUN(sqrt, ::hsqrt, ::h2sqrt) -KERNEL_FLOAT_BF16_UNARY_FUN(trunc, ::htrunc, ::h2trunc) +KERNEL_FLOAT_BF16_UNARY_FUN(rsqrt, ::hrsqrt, 
::h2rsqrt) KERNEL_FLOAT_BF16_UNARY_FUN(rcp, ::hrcp, ::h2rcp) + +KERNEL_FLOAT_BF16_UNARY_FUN(abs, ::__habs, ::__habs2) +KERNEL_FLOAT_BF16_UNARY_FUN(floor, ::hfloor, ::h2floor) +KERNEL_FLOAT_BF16_UNARY_FUN(ceil, ::hceil, ::h2ceil) +KERNEL_FLOAT_BF16_UNARY_FUN(rint, ::hrint, ::h2rint) +KERNEL_FLOAT_BF16_UNARY_FUN(trunc, ::htrunc, ::h2trunc) +KERNEL_FLOAT_BF16_UNARY_FUN(negate, ::__hneg, ::__hneg2) + +// For some reason, HIP struggles with the functions `__habs` and `__hneg`. We define them here using bitwise ops. +// For CUDA, we can just use the regular bfloat16 functions (see above). +#elif KERNEL_FLOAT_IS_HIP +KERNEL_FLOAT_INLINE __hip_bfloat16 hip_habs(const __hip_bfloat16 a) { + __hip_bfloat16 res = a; + res.data &= 0x7FFF; + return res; +} + +KERNEL_FLOAT_INLINE __hip_bfloat16 hip_hneg(const __hip_bfloat16 a) { + __hip_bfloat16 res = a; + res.data ^= 0x8000; + return res; +} + +KERNEL_FLOAT_INLINE __hip_bfloat162 hip_habs2(const __hip_bfloat162 a) { + return {hip_habs(a.x), hip_habs(a.y)}; +} + +KERNEL_FLOAT_INLINE __hip_bfloat162 hip_hneg2(const __hip_bfloat162 a) { + return {hip_hneg(a.x), hip_hneg(a.y)}; +} + +KERNEL_FLOAT_BF16_UNARY_FUN(abs, hip_habs, hip_habs2) +KERNEL_FLOAT_BF16_UNARY_FUN(negate, hip_hneg, hip_hneg2) #endif -#if KERNEL_FLOAT_CUDA_ARCH >= 800 -#define KERNEL_FLOAT_BF16_BINARY_FUN(NAME, FUN1, FUN2) \ - namespace ops { \ - template<> \ - struct NAME<__nv_bfloat16> { \ - KERNEL_FLOAT_INLINE __nv_bfloat16 \ - operator()(__nv_bfloat16 left, __nv_bfloat16 right) const { \ - return FUN1(left, right); \ - } \ - }; \ - } \ - namespace detail { \ - template<> \ - struct apply_impl, 2, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16> { \ - KERNEL_FLOAT_INLINE static void call( \ - ops::NAME<__nv_bfloat16>, \ - __nv_bfloat16* result, \ - const __nv_bfloat16* a, \ - const __nv_bfloat16* b) { \ - __nv_bfloat162 r = FUN2(__nv_bfloat162 {a[0], a[1]}, __nv_bfloat162 {b[0], b[1]}); \ - result[0] = r.x, result[1] = r.y; \ - } \ - }; \ +#define KERNEL_FLOAT_BF16_BINARY_FUN(NAME, FUN1, FUN2) \ + namespace ops { \ + template<> \ + struct NAME { \ + KERNEL_FLOAT_INLINE bfloat16_t operator()(bfloat16_t left, bfloat16_t right) const { \ + return ops::cast {}(FUN1(left, right)); \ + } \ + }; \ + } \ + namespace detail { \ + template<> \ + struct apply_impl< \ + accurate_policy, \ + ops::NAME, \ + 2, \ + bfloat16_t, \ + bfloat16_t, \ + bfloat16_t> { \ + KERNEL_FLOAT_INLINE static void call( \ + ops::NAME, \ + bfloat16_t* result, \ + const bfloat16_t* a, \ + const bfloat16_t* b) { \ + bfloat16x2_t r = FUN2(bfloat16x2_t {a[0], a[1]}, bfloat16x2_t {b[0], b[1]}); \ + result[0] = r.x, result[1] = r.y; \ + } \ + }; \ } -#else -#define KERNEL_FLOAT_BF16_BINARY_FUN(NAME, FUN1, FUN2) -#endif +#if KERNEL_FLOAT_BF16_OPS_SUPPORTED KERNEL_FLOAT_BF16_BINARY_FUN(add, __hadd, __hadd2) KERNEL_FLOAT_BF16_BINARY_FUN(subtract, __hsub, __hsub2) KERNEL_FLOAT_BF16_BINARY_FUN(multiply, __hmul, __hmul2) @@ -122,13 +172,13 @@ KERNEL_FLOAT_BF16_BINARY_FUN(less, __hlt, __hlt2) KERNEL_FLOAT_BF16_BINARY_FUN(less_equal, __hle, __hle2) KERNEL_FLOAT_BF16_BINARY_FUN(greater, __hgt, __hgt2) KERNEL_FLOAT_BF16_BINARY_FUN(greater_equal, __hge, __hgt2) +#endif -#if KERNEL_FLOAT_CUDA_ARCH >= 800 +#if KERNEL_FLOAT_BF16_OPS_SUPPORTED namespace ops { template<> -struct fma<__nv_bfloat16> { - KERNEL_FLOAT_INLINE __nv_bfloat16 - operator()(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) const { +struct fma { + KERNEL_FLOAT_INLINE bfloat16_t operator()(bfloat16_t a, bfloat16_t b, bfloat16_t c) const { return __hfma(a, b, c); } }; @@ 
-137,95 +187,98 @@ struct fma<__nv_bfloat16> { namespace detail { template<> struct apply_impl< - ops::fma<__nv_bfloat16>, + accurate_policy, + ops::fma, 2, - __nv_bfloat16, - __nv_bfloat16, - __nv_bfloat16, - __nv_bfloat16> { + bfloat16_t, + bfloat16_t, + bfloat16_t, + bfloat16_t> { KERNEL_FLOAT_INLINE static void call( - ops::fma<__nv_bfloat16>, - __nv_bfloat16* result, - const __nv_bfloat16* a, - const __nv_bfloat16* b, - const __nv_bfloat16* c) { - __nv_bfloat162 r = __hfma2( - __nv_bfloat162 {a[0], a[1]}, - __nv_bfloat162 {b[0], b[1]}, - __nv_bfloat162 {c[0], c[1]}); + ops::fma, + bfloat16_t* result, + const bfloat16_t* a, + const bfloat16_t* b, + const bfloat16_t* c) { + bfloat16x2_t r = __hfma2( + bfloat16x2_t {a[0], a[1]}, + bfloat16x2_t {b[0], b[1]}, + bfloat16x2_t {c[0], c[1]}); result[0] = r.x, result[1] = r.y; } }; -} // namespace detail -#endif - -namespace ops { -template<> -struct cast { - KERNEL_FLOAT_INLINE __nv_bfloat16 operator()(double input) { - return __double2bfloat16(input); - }; -}; -template<> -struct cast { - KERNEL_FLOAT_INLINE __nv_bfloat16 operator()(float input) { - return __float2bfloat16(input); +// clang-format off +#define KERNEL_FLOAT_FAST_BF16_DISPATCH(OP) \ + template \ + struct apply_impl, N, bfloat16_t, bfloat16_t> { \ + KERNEL_FLOAT_INLINE static void \ + call(ops::OP, bfloat16_t* output, const bfloat16_t* input) { \ + float v[N]; \ + map_impl, N, float, bfloat16_t>::call({}, v, input); \ + map_impl, N, float, float>::call({}, v, v); \ + map_impl, N, bfloat16_t, float>::call({}, output, v); \ + } \ }; -}; +// clang-format on -template<> -struct cast<__nv_bfloat16, float> { - KERNEL_FLOAT_INLINE float operator()(__nv_bfloat16 input) { - return __bfloat162float(input); - }; -}; -} // namespace ops +KERNEL_FLOAT_FAST_F32_MAP(KERNEL_FLOAT_FAST_BF16_DISPATCH) +} // namespace detail +#endif -#define KERNEL_FLOAT_BF16_CAST(T, TO_HALF, FROM_HALF) \ - namespace ops { \ - template<> \ - struct cast { \ - KERNEL_FLOAT_INLINE __nv_bfloat16 operator()(T input) { \ - return TO_HALF; \ - } \ - }; \ - template<> \ - struct cast<__nv_bfloat16, T> { \ - KERNEL_FLOAT_INLINE T operator()(__nv_bfloat16 input) { \ - return FROM_HALF; \ - } \ - }; \ +#define KERNEL_FLOAT_BF16_CAST(T, TO_HALF, FROM_HALF) \ + namespace ops { \ + template<> \ + struct cast { \ + KERNEL_FLOAT_INLINE bfloat16_t operator()(T input) { \ + return TO_HALF; \ + } \ + }; \ + template<> \ + struct cast { \ + KERNEL_FLOAT_INLINE T operator()(bfloat16_t input) { \ + return FROM_HALF; \ + } \ + }; \ } -#if KERNEL_FLOAT_CUDA_ARCH >= 800 +KERNEL_FLOAT_BF16_CAST(float, __float2bfloat16(input), __bfloat162float(input)) +KERNEL_FLOAT_BF16_CAST(double, __double2bfloat16(input), __bfloat162float(input)) + +#if KERNEL_FLOAT_BF16_OPS_SUPPORTED // clang-format off // there are no official char casts. 
Instead, cast to int and then to char KERNEL_FLOAT_BF16_CAST(char, __int2bfloat16_rn(input), (char)__bfloat162int_rz(input)); KERNEL_FLOAT_BF16_CAST(signed char, __int2bfloat16_rn(input), (signed char)__bfloat162int_rz(input)); KERNEL_FLOAT_BF16_CAST(unsigned char, __int2bfloat16_rn(input), (unsigned char)__bfloat162int_rz(input)); -KERNEL_FLOAT_BF16_CAST(signed short, __bfloat162short_rz(input), __short2bfloat16_rn(input)); -KERNEL_FLOAT_BF16_CAST(signed int, __bfloat162int_rz(input), __int2bfloat16_rn(input)); +KERNEL_FLOAT_BF16_CAST(signed short, __short2bfloat16_rn(input), __bfloat162short_rz(input)); +KERNEL_FLOAT_BF16_CAST(signed int, __int2bfloat16_rn(input), __bfloat162int_rz(input)); KERNEL_FLOAT_BF16_CAST(signed long, __ll2bfloat16_rn(input), (signed long)(__bfloat162ll_rz(input))); KERNEL_FLOAT_BF16_CAST(signed long long, __ll2bfloat16_rn(input), __bfloat162ll_rz(input)); -KERNEL_FLOAT_BF16_CAST(unsigned short, __bfloat162ushort_rz(input), __ushort2bfloat16_rn(input)); -KERNEL_FLOAT_BF16_CAST(unsigned int, __bfloat162uint_rz(input), __uint2bfloat16_rn(input)); +KERNEL_FLOAT_BF16_CAST(unsigned short, __ushort2bfloat16_rn(input), __bfloat162ushort_rz(input)); +KERNEL_FLOAT_BF16_CAST(unsigned int, __uint2bfloat16_rn(input), __bfloat162uint_rz(input)); KERNEL_FLOAT_BF16_CAST(unsigned long, __ull2bfloat16_rn(input), (unsigned long)(__bfloat162ull_rz(input))); KERNEL_FLOAT_BF16_CAST(unsigned long long, __ull2bfloat16_rn(input), __bfloat162ull_rz(input)); // clang-format on -#else +#endif + +#if KERNEL_FLOAT_IS_CUDA +//KERNEL_FLOAT_BF16_CAST( +// bool, +// __nv_bfloat16_raw {input ? (unsigned short)0 : (unsigned short)0x3C00}, +// (__nv_bfloat16_raw(input).x & 0x7FFF) != 0); +#elif KERNEL_FLOAT_IS_HIP KERNEL_FLOAT_BF16_CAST( bool, - __nv_bfloat16_raw {input ? (unsigned short)0 : (unsigned short)0x3C00}, - (__nv_bfloat16_raw(input).x & 0x7FFF) != 0); + __hip_bfloat16 {input ? 
(unsigned short)0 : (unsigned short)0x3C00}, + (__hip_bfloat16(input).data & 0x7FFF) != 0); #endif -using bfloat16 = __nv_bfloat16; -KERNEL_FLOAT_VECTOR_ALIAS(bfloat16x, __nv_bfloat16) -//KERNEL_FLOAT_TYPE_ALIAS(float16x, __nv_bfloat16) -//KERNEL_FLOAT_TYPE_ALIAS(f16x, __nv_bfloat16) +KERNEL_FLOAT_VECTOR_ALIAS(bfloat16x, bfloat16_t) +//KERNEL_FLOAT_TYPE_ALIAS(float16x, bfloat16_t) +//KERNEL_FLOAT_TYPE_ALIAS(f16x, bfloat16_t) } // namespace kernel_float @@ -234,12 +287,12 @@ KERNEL_FLOAT_VECTOR_ALIAS(bfloat16x, __nv_bfloat16) namespace kernel_float { template<> -struct promote_type<__nv_bfloat16, __half> { +struct promote_type { using type = float; }; template<> -struct promote_type<__half, __nv_bfloat16> { +struct promote_type { using type = float; }; diff --git a/include/kernel_float/binops.h b/include/kernel_float/binops.h index eb958f1..85ac823 100644 --- a/include/kernel_float/binops.h +++ b/include/kernel_float/binops.h @@ -44,7 +44,7 @@ using zip_common_type = vector< * vec c = zip_common([](float x, float y){ return x + y; }, a, b); // returns [5.0f, 7.0f, 9.0f] * ``` */ -template +template KERNEL_FLOAT_INLINE zip_common_type zip_common(F fun, const L& left, const R& right) { using T = promoted_vector_value_type; using O = result_t; @@ -52,7 +52,7 @@ KERNEL_FLOAT_INLINE zip_common_type zip_common(F fun, const L& left, co vector_storage> result; - detail::map_impl, O, T, T>::call( + detail::map_impl, O, T, T>::call( fun, result.data(), detail::convert_impl, vector_extent_type, T, E>::call( @@ -65,10 +65,17 @@ KERNEL_FLOAT_INLINE zip_common_type zip_common(F fun, const L& left, co return result; } -#define KERNEL_FLOAT_DEFINE_BINARY_FUN(NAME) \ - template> \ - KERNEL_FLOAT_INLINE zip_common_type, L, R> NAME(L&& left, R&& right) { \ - return zip_common(ops::NAME {}, std::forward(left), std::forward(right)); \ +#define KERNEL_FLOAT_DEFINE_BINARY_FUN(NAME) \ + template< \ + typename Accuracy = default_policy, \ + typename L, \ + typename R, \ + typename C = promoted_vector_value_type> \ + KERNEL_FLOAT_INLINE zip_common_type, L, R> NAME(L&& left, R&& right) { \ + return zip_common( \ + ops::NAME {}, \ + static_cast(left), \ + static_cast(right)); \ } #define KERNEL_FLOAT_DEFINE_BINARY(NAME, EXPR, EXPR_F64, EXPR_F32) \ @@ -144,7 +151,7 @@ KERNEL_FLOAT_DEFINE_BINARY_OP_FALLBACK(bit_xor, ^, bool(left) ^ bool(right), boo // clang-format on // clang-format off -template typename F, typename T, typename E, typename R> +template typename F, typename T, typename E, typename R> static constexpr bool is_vector_assign_allowed = is_vector_broadcastable && is_implicit_convertible< @@ -182,14 +189,16 @@ namespace ops { template struct min { KERNEL_FLOAT_INLINE T operator()(T left, T right) { - return left < right ? left : right; + auto cond = less {}(left, right); + return cast {}(cond) ? left : right; } }; template struct max { KERNEL_FLOAT_INLINE T operator()(T left, T right) { - return left > right ? left : right; + auto cond = greater {}(left, right); + return cast {}(cond) ? 
left : right; } }; @@ -290,36 +299,51 @@ struct multiply { }; // namespace ops namespace detail { -template -struct apply_fastmath_impl, N, T, T, T> { - KERNEL_FLOAT_INLINE static void - call(ops::divide fun, T* result, const T* lhs, const T* rhs) { +template +struct apply_base_impl, N, T, T, T> { + KERNEL_FLOAT_INLINE static void call(ops::divide, T* result, const T* lhs, const T* rhs) { T rhs_rcp[N]; // Fast way to perform division is to multiply by the reciprocal - apply_fastmath_impl, N, T, T>::call({}, rhs_rcp, rhs); - apply_fastmath_impl, N, T, T, T>::call({}, result, lhs, rhs_rcp); + apply_impl, N, T, T>::call({}, rhs_rcp, rhs); + apply_impl, N, T, T, T>::call({}, result, lhs, rhs_rcp); } }; #if KERNEL_FLOAT_IS_DEVICE template<> -struct apply_fastmath_impl, 1, float, float, float> { +struct apply_impl, 1, float, float, float> { KERNEL_FLOAT_INLINE static void - call(ops::divide fun, float* result, const float* lhs, const float* rhs) { + call(ops::divide, float* result, const float* lhs, const float* rhs) { *result = __fdividef(*lhs, *rhs); } }; #endif } // namespace detail +namespace detail { +// Override `pow` using `log2` and `exp2` +template +struct apply_base_impl, N, T, T, T> { + KERNEL_FLOAT_INLINE static void call(ops::pow, T* result, const T* lhs, const T* rhs) { + T lhs_log[N]; + T result_log[N]; + + // Fast way to perform power function is using log2 and exp2 + apply_impl, N, T, T>::call({}, lhs_log, lhs); + apply_impl, N, T, T, T>::call({}, result_log, lhs_log, rhs); + apply_impl, N, T, T, T>::call({}, result, result_log); + } +}; +} // namespace detail + template> KERNEL_FLOAT_INLINE zip_common_type, T, T> fast_divide(const L& left, const R& right) { using E = broadcast_vector_extent_type; vector_storage> result; - detail::map_policy_impl, extent_size, T, T, T>::call( + detail::map_impl, extent_size, T, T, T>::call( ops::divide {}, result.data(), detail::convert_impl, vector_extent_type, T, E>::call( @@ -359,8 +383,7 @@ template< typename L, typename R, typename T = promoted_vector_value_type, - typename = - enable_if_t> && is_vector_broadcastable>>> + typename = enable_if_t<(vector_size == 3 && vector_size == 3)>> KERNEL_FLOAT_INLINE vector> cross(const L& left, const R& right) { return detail::cross_impl::call(convert_storage(left), convert_storage(right)); } diff --git a/include/kernel_float/conversion.h b/include/kernel_float/conversion.h index 8538be7..881cc51 100644 --- a/include/kernel_float/conversion.h +++ b/include/kernel_float/conversion.h @@ -17,7 +17,10 @@ struct convert_impl { static vector_storage> call(vector_storage> input) { using F = ops::cast; vector_storage> intermediate; - detail::map_impl, T2, T>::call(F {}, intermediate.data(), input.data()); + detail::default_map_impl, T2, T>::call( + F {}, + intermediate.data(), + input.data()); return detail::broadcast_impl::call(intermediate); } }; @@ -48,7 +51,7 @@ struct convert_impl { using F = ops::cast; vector_storage> result; - detail::map_impl, T2, T>::call(F {}, result.data(), input.data()); + detail::default_map_impl, T2, T>::call(F {}, result.data(), input.data()); return result; } }; @@ -84,11 +87,11 @@ KERNEL_FLOAT_INLINE vector> convert(const V& input, extent new_s template struct AssignConversionProxy { KERNEL_FLOAT_INLINE - explicit AssignConversionProxy(T* ptr) : ptr_(ptr) {} + explicit AssignConversionProxy(T&& ptr) : ptr_(std::forward(ptr)) {} template KERNEL_FLOAT_INLINE AssignConversionProxy& operator=(U&& values) { - *ptr_ = detail::convert_impl< + ptr_ = detail::convert_impl< 
vector_value_type, vector_extent_type, vector_value_type, @@ -99,12 +102,12 @@ struct AssignConversionProxy { } private: - T* ptr_; + T ptr_; }; /** * Takes a vector reference and gives back a helper object. This object allows you to assign - * a vector of a different type to another vector while perofrming implicit type converion. + * a vector of a different type to another vector while performing implicit type conversion. * * For example, if `x = expression;` does not compile because `x` and `expression` are * different vector types, you can use `cast_to(x) = expression;` to make it work. @@ -117,9 +120,9 @@ struct AssignConversionProxy { * cast_to(x) = y; // Normally, `x = y;` would give an error, but `cast_to` fixes that. * ``` */ -template -KERNEL_FLOAT_INLINE AssignConversionProxy cast_to(T& input) { - return AssignConversionProxy(&input); +template +KERNEL_FLOAT_INLINE AssignConversionProxy cast_to(T&& input) { + return AssignConversionProxy(std::forward(input)); } /** @@ -132,7 +135,7 @@ KERNEL_FLOAT_INLINE AssignConversionProxy cast_to(T& input) { * ``` */ template -KERNEL_FLOAT_INLINE vector> fill(T value = {}, extent = {}) { +KERNEL_FLOAT_INLINE vector> fill(T value, extent = {}) { vector_storage input = {value}; return detail::broadcast_impl, extent>::call(input); } diff --git a/include/kernel_float/fp16.h b/include/kernel_float/fp16.h index 0b62d9b..cbc3383 100644 --- a/include/kernel_float/fp16.h +++ b/include/kernel_float/fp16.h @@ -4,54 +4,66 @@ #include "macros.h" #if KERNEL_FLOAT_FP16_AVAILABLE +//#define CUDA_NO_HALF (1) +//#define __CUDA_NO_HALF_OPERATORS__ (1) +//#define __CUDA_NO_HALF2_OPERATORS__ (1) +//#define __CUDA_NO_HALF_CONVERSIONS__ (1) + +#if KERNEL_FLOAT_IS_CUDA #include +#elif KERNEL_FLOAT_IS_HIP +#include +#endif #include "vector.h" namespace kernel_float { +using half_t = ::__half; +using half2_t = ::__half2; + template<> -struct preferred_vector_size<__half> { +struct preferred_vector_size { static constexpr size_t value = 2; }; -KERNEL_FLOAT_DEFINE_PROMOTED_FLOAT(__half) -KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(float, __half) -KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(double, __half) +KERNEL_FLOAT_DEFINE_PROMOTED_FLOAT(half_t) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(float, half_t) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(double, half_t) template<> -struct into_vector_impl<__half2> { - using value_type = __half; +struct into_vector_impl { + using value_type = half_t; using extent_type = extent<2>; KERNEL_FLOAT_INLINE - static vector_storage<__half, 2> call(__half2 input) { + static vector_storage call(half2_t input) { return {input.x, input.y}; } }; namespace detail { template<> -struct allow_float_fallback<__half> { +struct allow_float_fallback { static constexpr bool value = true; }; -}; // namespace detail +} // namespace detail #if KERNEL_FLOAT_IS_DEVICE #define KERNEL_FLOAT_FP16_UNARY_FUN(NAME, FUN1, FUN2) \ namespace ops { \ template<> \ - struct NAME<__half> { \ - KERNEL_FLOAT_INLINE __half operator()(__half input) { \ + struct NAME { \ + KERNEL_FLOAT_INLINE half_t operator()(half_t input) { \ return FUN1(input); \ } \ }; \ } \ namespace detail { \ template<> \ - struct apply_impl, 2, __half, __half> { \ - KERNEL_FLOAT_INLINE static void call(ops::NAME<__half>, __half* result, const __half* a) { \ - __half2 r = FUN2(__half2 {a[0], a[1]}); \ + struct apply_impl, 2, half_t, half_t> { \ + KERNEL_FLOAT_INLINE static void call(ops::NAME, half_t* result, const half_t* a) { \ + half2_t r = FUN2(half2_t {a[0], a[1]}); \ result[0] = r.x, result[1] = r.y; \ } \ }; \ @@ -60,52 
+72,61 @@ struct allow_float_fallback<__half> { #define KERNEL_FLOAT_FP16_UNARY_FUN(NAME, FUN1, FUN2) #endif -KERNEL_FLOAT_FP16_UNARY_FUN(abs, ::__habs, ::__habs2) -KERNEL_FLOAT_FP16_UNARY_FUN(negate, ::__hneg, ::__hneg2) -KERNEL_FLOAT_FP16_UNARY_FUN(ceil, ::hceil, ::h2ceil) -KERNEL_FLOAT_FP16_UNARY_FUN(cos, ::hcos, ::h2cos) -KERNEL_FLOAT_FP16_UNARY_FUN(exp, ::hexp, ::h2exp) -KERNEL_FLOAT_FP16_UNARY_FUN(exp10, ::hexp10, ::h2exp10) -KERNEL_FLOAT_FP16_UNARY_FUN(floor, ::hfloor, ::h2floor) -KERNEL_FLOAT_FP16_UNARY_FUN(log, ::hlog, ::h2log) -KERNEL_FLOAT_FP16_UNARY_FUN(log10, ::hlog10, ::h2log2) -KERNEL_FLOAT_FP16_UNARY_FUN(rint, ::hrint, ::h2rint) -KERNEL_FLOAT_FP16_UNARY_FUN(rsqrt, ::hrsqrt, ::h2rsqrt) -KERNEL_FLOAT_FP16_UNARY_FUN(sin, ::hsin, ::h2sin) -KERNEL_FLOAT_FP16_UNARY_FUN(sqrt, ::hsqrt, ::h2sqrt) -KERNEL_FLOAT_FP16_UNARY_FUN(trunc, ::htrunc, ::h2trunc) -KERNEL_FLOAT_FP16_UNARY_FUN(rcp, ::hrcp, ::h2rcp) +KERNEL_FLOAT_FP16_UNARY_FUN(sin, hsin, h2sin) +KERNEL_FLOAT_FP16_UNARY_FUN(cos, hcos, h2cos) + +KERNEL_FLOAT_FP16_UNARY_FUN(exp, hexp, h2exp) +KERNEL_FLOAT_FP16_UNARY_FUN(exp2, hexp2, h2exp2) +KERNEL_FLOAT_FP16_UNARY_FUN(exp10, hexp10, h2exp10) +KERNEL_FLOAT_FP16_UNARY_FUN(log, hlog, h2log) +KERNEL_FLOAT_FP16_UNARY_FUN(log2, hlog2, h2log2) +KERNEL_FLOAT_FP16_UNARY_FUN(log10, hlog10, h2log2) + +KERNEL_FLOAT_FP16_UNARY_FUN(sqrt, hsqrt, h2sqrt) +KERNEL_FLOAT_FP16_UNARY_FUN(rsqrt, hrsqrt, h2rsqrt) +KERNEL_FLOAT_FP16_UNARY_FUN(rcp, hrcp, h2rcp) + +KERNEL_FLOAT_FP16_UNARY_FUN(abs, __habs, __habs2) +KERNEL_FLOAT_FP16_UNARY_FUN(floor, hfloor, h2floor) +KERNEL_FLOAT_FP16_UNARY_FUN(ceil, hceil, h2ceil) +KERNEL_FLOAT_FP16_UNARY_FUN(rint, hrint, h2rint) +KERNEL_FLOAT_FP16_UNARY_FUN(trunc, htrunc, h2trunc) +KERNEL_FLOAT_FP16_UNARY_FUN(negate, __hneg, __hneg2) #if KERNEL_FLOAT_IS_DEVICE -#define KERNEL_FLOAT_FP16_BINARY_FUN(NAME, FUN1, FUN2) \ - namespace ops { \ - template<> \ - struct NAME<__half> { \ - KERNEL_FLOAT_INLINE __half operator()(__half left, __half right) const { \ - return FUN1(left, right); \ - } \ - }; \ - } \ - namespace detail { \ - template<> \ - struct apply_impl, 2, __half, __half, __half> { \ - KERNEL_FLOAT_INLINE static void \ - call(ops::NAME<__half>, __half* result, const __half* a, const __half* b) { \ - __half2 r = FUN2(__half2 {a[0], a[1]}, __half2 {b[0], b[1]}); \ - result[0] = r.x, result[1] = r.y; \ - } \ - }; \ +#define KERNEL_FLOAT_FP16_BINARY_FUN(NAME, FUN1, FUN2) \ + namespace ops { \ + template<> \ + struct NAME { \ + KERNEL_FLOAT_INLINE half_t operator()(half_t left, half_t right) const { \ + return ops::cast {}(FUN1(left, right)); \ + } \ + }; \ + } \ + namespace detail { \ + template<> \ + struct apply_impl, 2, half_t, half_t, half_t> { \ + KERNEL_FLOAT_INLINE static void \ + call(ops::NAME, half_t* result, const half_t* a, const half_t* b) { \ + half2_t r = FUN2(half2_t {a[0], a[1]}, half2_t {b[0], b[1]}); \ + result[0] = r.x, result[1] = r.y; \ + } \ + }; \ } #else #define KERNEL_FLOAT_FP16_BINARY_FUN(NAME, FUN1, FUN2) #endif +// There are not available in HIP +#if KERNEL_FLOAT_IS_CUDA +KERNEL_FLOAT_FP16_BINARY_FUN(min, __hmin, __hmin2) +KERNEL_FLOAT_FP16_BINARY_FUN(max, __hmax, __hmax2) +#endif + KERNEL_FLOAT_FP16_BINARY_FUN(add, __hadd, __hadd2) KERNEL_FLOAT_FP16_BINARY_FUN(subtract, __hsub, __hsub2) KERNEL_FLOAT_FP16_BINARY_FUN(multiply, __hmul, __hmul2) KERNEL_FLOAT_FP16_BINARY_FUN(divide, __hdiv, __h2div) -KERNEL_FLOAT_FP16_BINARY_FUN(min, __hmin, __hmin2) -KERNEL_FLOAT_FP16_BINARY_FUN(max, __hmax, __hmax2) KERNEL_FLOAT_FP16_BINARY_FUN(equal_to, 
__heq, __heq2) KERNEL_FLOAT_FP16_BINARY_FUN(not_equal_to, __hneu, __hneu2) @@ -117,8 +138,8 @@ KERNEL_FLOAT_FP16_BINARY_FUN(greater_equal, __hge, __hgt2) #if KERNEL_FLOAT_IS_DEVICE namespace ops { template<> -struct fma<__half> { - KERNEL_FLOAT_INLINE __half operator()(__half a, __half b, __half c) const { +struct fma { + KERNEL_FLOAT_INLINE half_t operator()(half_t a, half_t b, half_t c) const { return __hfma(a, b, c); } }; @@ -126,32 +147,70 @@ struct fma<__half> { namespace detail { template<> -struct apply_impl, 2, __half, __half, __half, __half> { +struct apply_impl, 2, half_t, half_t, half_t, half_t> { KERNEL_FLOAT_INLINE static void - call(ops::fma<__half>, __half* result, const __half* a, const __half* b, const __half* c) { - __half2 r = __hfma2(__half2 {a[0], a[1]}, __half2 {b[0], b[1]}, __half2 {c[0], c[1]}); + call(ops::fma, half_t* result, const half_t* a, const half_t* b, const half_t* c) { + half2_t r = __hfma2(half2_t {a[0], a[1]}, half2_t {b[0], b[1]}, half2_t {c[0], c[1]}); result[0] = r.x, result[1] = r.y; } }; + +// clang-format off +#define KERNEL_FLOAT_FAST_FP16_DISPATCH(OP) \ + template \ + struct apply_impl, N, half_t, half_t> { \ + KERNEL_FLOAT_INLINE static void \ + call(ops::OP, half_t* output, const half_t* input) { \ + float v[N]; \ + map_impl, N, float, half_t>::call({}, v, input); \ + map_impl, N, float, float>::call({}, v, v); \ + map_impl, N, half_t, float>::call({}, output, v); \ + } \ + }; +// clang-format on + +KERNEL_FLOAT_FAST_F32_MAP(KERNEL_FLOAT_FAST_FP16_DISPATCH) } // namespace detail #endif #define KERNEL_FLOAT_FP16_CAST(T, TO_HALF, FROM_HALF) \ namespace ops { \ template<> \ - struct cast { \ - KERNEL_FLOAT_INLINE __half operator()(T input) { \ + struct cast { \ + KERNEL_FLOAT_INLINE half_t operator()(T input) { \ return TO_HALF; \ } \ }; \ template<> \ - struct cast<__half, T> { \ - KERNEL_FLOAT_INLINE T operator()(__half input) { \ + struct cast { \ + KERNEL_FLOAT_INLINE T operator()(half_t input) { \ return FROM_HALF; \ } \ }; \ } +// Only CUDA has a special `__double2half` intrinsic +#if KERNEL_FLOAT_IS_HIP +#define KERNEL_FLOAT_FP16_CAST_FWD(T) \ + KERNEL_FLOAT_FP16_CAST(T, static_cast<_Float16>(input), static_cast(input)) + +KERNEL_FLOAT_FP16_CAST_FWD(double) +KERNEL_FLOAT_FP16_CAST_FWD(float) + +KERNEL_FLOAT_FP16_CAST_FWD(char) +KERNEL_FLOAT_FP16_CAST_FWD(signed char) +KERNEL_FLOAT_FP16_CAST_FWD(unsigned char) + +KERNEL_FLOAT_FP16_CAST_FWD(signed short) +KERNEL_FLOAT_FP16_CAST_FWD(signed int) +KERNEL_FLOAT_FP16_CAST_FWD(signed long) +KERNEL_FLOAT_FP16_CAST_FWD(signed long long) + +KERNEL_FLOAT_FP16_CAST_FWD(unsigned short) +KERNEL_FLOAT_FP16_CAST_FWD(unsigned int) +KERNEL_FLOAT_FP16_CAST_FWD(unsigned long) +KERNEL_FLOAT_FP16_CAST_FWD(unsigned long long) +#else KERNEL_FLOAT_FP16_CAST(double, __double2half(input), double(__half2float(input))); KERNEL_FLOAT_FP16_CAST(float, __float2half(input), __half2float(input)); @@ -160,20 +219,20 @@ KERNEL_FLOAT_FP16_CAST(char, __int2half_rn(input), (char)__half2int_rz(input)); KERNEL_FLOAT_FP16_CAST(signed char, __int2half_rn(input), (signed char)__half2int_rz(input)); KERNEL_FLOAT_FP16_CAST(unsigned char, __int2half_rn(input), (unsigned char)__half2int_rz(input)); -KERNEL_FLOAT_FP16_CAST(signed short, __half2short_rz(input), __short2half_rn(input)); -KERNEL_FLOAT_FP16_CAST(signed int, __half2int_rz(input), __int2half_rn(input)); +KERNEL_FLOAT_FP16_CAST(signed short, __short2half_rn(input), __half2short_rz(input)); +KERNEL_FLOAT_FP16_CAST(signed int, __int2half_rn(input), __half2int_rz(input)); 
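Editor's note: the integer casts in this hunk previously passed the to-half and from-half intrinsics in swapped slots; after the fix, the `signed int` entry should expand to roughly the following pair of specializations. This is a sketch derived from the `KERNEL_FLOAT_FP16_CAST` macro above; the exact namespace nesting is assumed from where the invocation sits.

```cpp
namespace kernel_float {
namespace ops {
template<>
struct cast<signed int, half_t> {
    // TO_HALF slot: round-to-nearest conversion into __half
    KERNEL_FLOAT_INLINE half_t operator()(signed int input) {
        return __int2half_rn(input);
    }
};
template<>
struct cast<half_t, signed int> {
    // FROM_HALF slot: truncating conversion back to int
    KERNEL_FLOAT_INLINE signed int operator()(half_t input) {
        return __half2int_rz(input);
    }
};
} // namespace ops
} // namespace kernel_float
```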
KERNEL_FLOAT_FP16_CAST(signed long, __ll2half_rn(input), (signed long)(__half2ll_rz(input))); KERNEL_FLOAT_FP16_CAST(signed long long, __ll2half_rn(input), __half2ll_rz(input)); -KERNEL_FLOAT_FP16_CAST(unsigned short, __half2ushort_rz(input), __ushort2half_rn(input)); -KERNEL_FLOAT_FP16_CAST(unsigned int, __half2uint_rz(input), __uint2half_rn(input)); +KERNEL_FLOAT_FP16_CAST(unsigned short, __ushort2half_rn(input), __half2ushort_rz(input)); +KERNEL_FLOAT_FP16_CAST(unsigned int, __uint2half_rn(input), __half2uint_rz(input)); KERNEL_FLOAT_FP16_CAST(unsigned long, __ull2half_rn(input), (unsigned long)(__half2ull_rz(input))); KERNEL_FLOAT_FP16_CAST(unsigned long long, __ull2half_rn(input), __half2ull_rz(input)); +#endif -using half = __half; -KERNEL_FLOAT_VECTOR_ALIAS(half, __half) -//KERNEL_FLOAT_TYPE_ALIAS(float16x, __half) -//KERNEL_FLOAT_TYPE_ALIAS(f16x, __half) +KERNEL_FLOAT_VECTOR_ALIAS(half, half_t) +//KERNEL_FLOAT_TYPE_ALIAS(float16x, half_t) +//KERNEL_FLOAT_TYPE_ALIAS(f16x, half_t) } // namespace kernel_float diff --git a/include/kernel_float/fp8.h b/include/kernel_float/fp8.h index 49c55d4..c6ee2e9 100644 --- a/include/kernel_float/fp8.h +++ b/include/kernel_float/fp8.h @@ -64,7 +64,7 @@ struct allow_float_fallback<__nv_fp8_e5m2> { #define KERNEL_FLOAT_FP8_CAST2(T, FP8_TY, FP8_INTERP) \ namespace detail { \ template<> \ - struct apply_impl, 2, FP8_TY, T> { \ + struct apply_impl, 2, FP8_TY, T> { \ KERNEL_FLOAT_INLINE static void call(ops::cast, FP8_TY* result, const T* v) { \ __half2_raw x; \ memcpy(&x, v, 2 * sizeof(T)); \ @@ -73,7 +73,7 @@ struct allow_float_fallback<__nv_fp8_e5m2> { } \ }; \ template<> \ - struct apply_impl, 2, T, FP8_TY> { \ + struct apply_impl, 2, T, FP8_TY> { \ KERNEL_FLOAT_INLINE static void call(ops::cast, T* result, const FP8_TY* v) { \ __nv_fp8x2_storage_t x; \ memcpy(&x, v, 2 * sizeof(FP8_TY)); \ @@ -91,12 +91,12 @@ KERNEL_FLOAT_FP8_CAST(double) #include "fp16.h" namespace kernel_float { -KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__half, __nv_fp8_e4m3) -KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__half, __nv_fp8_e5m2) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(half_t, __nv_fp8_e4m3) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(half_t, __nv_fp8_e5m2) -KERNEL_FLOAT_FP8_CAST(__half) -KERNEL_FLOAT_FP8_CAST2(__half, __nv_fp8_e4m3, __NV_E4M3) -KERNEL_FLOAT_FP8_CAST2(__half, __nv_fp8_e5m2, __NV_E5M2) +KERNEL_FLOAT_FP8_CAST(half_t) +KERNEL_FLOAT_FP8_CAST2(half_t, __nv_fp8_e4m3, __NV_E4M3) +KERNEL_FLOAT_FP8_CAST2(half_t, __nv_fp8_e5m2, __NV_E5M2) } // namespace kernel_float #endif // KERNEL_FLOAT_FP16_AVAILABLE @@ -105,12 +105,12 @@ KERNEL_FLOAT_FP8_CAST2(__half, __nv_fp8_e5m2, __NV_E5M2) #include "bf16.h" namespace kernel_float { -KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__nv_bfloat16, __nv_fp8_e4m3) -KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__nv_bfloat16, __nv_fp8_e5m2) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(bfloat16_t, __nv_fp8_e4m3) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(bfloat16_t, __nv_fp8_e5m2) -KERNEL_FLOAT_FP8_CAST(__nv_bfloat16) -KERNEL_FLOAT_FP8_CAST2(__nv_bfloat16, __nv_fp8_e4m3, __NV_E4M3) -KERNEL_FLOAT_FP8_CAST2(__nv_bfloat16, __nv_fp8_e5m2, __NV_E5M2) +KERNEL_FLOAT_FP8_CAST(bfloat16_t) +KERNEL_FLOAT_FP8_CAST2(bfloat16_t, __nv_fp8_e4m3, __NV_E4M3) +KERNEL_FLOAT_FP8_CAST2(bfloat16_t, __nv_fp8_e5m2, __NV_E5M2) } // namespace kernel_float #endif // KERNEL_FLOAT_BF16_AVAILABLE diff --git a/include/kernel_float/macros.h b/include/kernel_float/macros.h index 4eee8f1..88bbdbc 100644 --- a/include/kernel_float/macros.h +++ b/include/kernel_float/macros.h @@ -1,42 +1,52 @@ #ifndef KERNEL_FLOAT_MACROS_H #define 
KERNEL_FLOAT_MACROS_H +#ifdef __HIPCC__ +#include "hip/hip_runtime.h" +#endif + +// clang-format off #ifdef __CUDACC__ -#define KERNEL_FLOAT_CUDA (1) + #define KERNEL_FLOAT_IS_CUDA (1) + #define KERNEL_FLOAT_DEVICE __forceinline__ __device__ + + #ifdef __CUDA_ARCH__ + #define KERNEL_FLOAT_INLINE __forceinline__ __device__ + #define KERNEL_FLOAT_IS_DEVICE (1) + #else // __CUDA_ARCH__ + #define KERNEL_FLOAT_INLINE __forceinline__ __host__ + #define KERNEL_FLOAT_IS_HOST (1) + #endif // __CUDA_ARCH__ +#elif defined(__HIPCC__) + #define KERNEL_FLOAT_IS_HIP (1) + #define KERNEL_FLOAT_DEVICE __attribute__((always_inline)) __device__ + #define KERNEL_FLOAT_INLINE __attribute__((always_inline)) __host__ __device__ -#ifdef __CUDA_ARCH__ -#define KERNEL_FLOAT_INLINE __forceinline__ __device__ -#define KERNEL_FLOAT_IS_DEVICE (1) -#define KERNEL_FLOAT_IS_HOST (0) -#define KERNEL_FLOAT_CUDA_ARCH (__CUDA_ARCH__) -#else // __CUDA_ARCH__ -#define KERNEL_FLOAT_INLINE __forceinline__ __host__ -#define KERNEL_FLOAT_IS_DEVICE (0) -#define KERNEL_FLOAT_IS_HOST (1) -#define KERNEL_FLOAT_CUDA_ARCH (0) -#endif // __CUDA_ARCH__ -#else // __CUDACC__ -#define KERNEL_FLOAT_INLINE inline -#define KERNEL_FLOAT_CUDA (0) -#define KERNEL_FLOAT_IS_HOST (1) -#define KERNEL_FLOAT_IS_DEVICE (0) -#define KERNEL_FLOAT_CUDA_ARCH (0) -#endif // __CUDACC__ + #ifdef __HIP_DEVICE_COMPILE__ + #define KERNEL_FLOAT_IS_DEVICE (1) + #else + #define KERNEL_FLOAT_IS_HOST (1) + #endif + +#else + #define KERNEL_FLOAT_INLINE inline + #define KERNEL_FLOAT_IS_HOST (1) +#endif #ifndef KERNEL_FLOAT_FP16_AVAILABLE -#define KERNEL_FLOAT_FP16_AVAILABLE (1) + #define KERNEL_FLOAT_FP16_AVAILABLE (1) #endif // KERNEL_FLOAT_FP16_AVAILABLE #ifndef KERNEL_FLOAT_BF16_AVAILABLE -#define KERNEL_FLOAT_BF16_AVAILABLE (1) + #define KERNEL_FLOAT_BF16_AVAILABLE (1) #endif // KERNEL_FLOAT_BF16_AVAILABLE #ifndef KERNEL_FLOAT_FP8_AVAILABLE -#ifdef __CUDACC_VER_MAJOR__ -#define KERNEL_FLOAT_FP8_AVAILABLE (__CUDACC_VER_MAJOR__ >= 12) -#else // __CUDACC_VER_MAJOR__ -#define KERNEL_FLOAT_FP8_AVAILABLE (0) -#endif // __CUDACC_VER_MAJOR__ + #ifdef __CUDACC_VER_MAJOR__ + #define KERNEL_FLOAT_FP8_AVAILABLE (__CUDACC_VER_MAJOR__ >= 12) + #else // __CUDACC_VER_MAJOR__ + #define KERNEL_FLOAT_FP8_AVAILABLE (0) + #endif // __CUDACC_VER_MAJOR__ #endif // KERNEL_FLOAT_FP8_AVAILABLE #define KERNEL_FLOAT_ASSERT(expr) \ @@ -51,20 +61,22 @@ // TOOD: check if this way is support across all compilers #if defined(__has_builtin) && 0 // Seems that `__builtin_assume_aligned` leads to segfaults -#if __has_builtin(__builtin_assume_aligned) -#define KERNEL_FLOAT_ASSUME_ALIGNED(TYPE, PTR, ALIGNMENT) static_cast ( - __builtin_assume_aligned(static_cast (PTR), (ALIGNMENT))) -#else -#define KERNEL_FLOAT_ASSUME_ALIGNED(TYPE, PTR, ALIGNMENT) (PTR) -#endif + #if __has_builtin(__builtin_assume_aligned) + #define KERNEL_FLOAT_ASSUME_ALIGNED(TYPE, PTR, ALIGNMENT) static_cast ( + __builtin_assume_aligned(static_cast (PTR), (ALIGNMENT))) + #else + #define KERNEL_FLOAT_ASSUME_ALIGNED(TYPE, PTR, ALIGNMENT) (PTR) + #endif #else -#define KERNEL_FLOAT_ASSUME_ALIGNED(TYPE, PTR, ALIGNMENT) (PTR) + #define KERNEL_FLOAT_ASSUME_ALIGNED(TYPE, PTR, ALIGNMENT) (PTR) #endif #define KERNEL_FLOAT_MAX_ALIGNMENT (32) #if KERNEL_FLOAT_FAST_MATH -#define KERNEL_FLOAT_POLICY ::kernel_float::fast_policy; + #define KERNEL_FLOAT_POLICY ::kernel_float::fast_policy; #endif + // clang-format on + #endif //KERNEL_FLOAT_MACROS_H diff --git a/include/kernel_float/memory.h b/include/kernel_float/memory.h index d76ed1e..2f0a92a 100644 
--- a/include/kernel_float/memory.h +++ b/include/kernel_float/memory.h @@ -33,7 +33,7 @@ struct copy_impl> { * * ``` * // Load 2 elements at data[0] and data[8], skip data[2] and data[4] - * vec values = = read(data, make_vec(0, 2, 4, 8), make_vec(true, false, false, true)); + * vec values = read(data, make_vec(0, 2, 4, 8), make_vec(true, false, false, true)); * ``` */ template> @@ -80,7 +80,7 @@ KERNEL_FLOAT_INLINE void write(T* ptr, const I& indices, const V& values, const * vec values = read<4>(data); * * // Load 4 elements at locations data[10], data[11], data[12], data[13] - * vec values = read<4>(values + 10, data); + * vec values = read<4>(data + 10); * ``` */ template @@ -107,31 +107,42 @@ KERNEL_FLOAT_INLINE void write(T* ptr, const V& values) { } namespace detail { +/** + * Returns the greatest common divisor of `a` and `b`. + */ KERNEL_FLOAT_INLINE constexpr size_t gcd(size_t a, size_t b) { return b == 0 ? a : gcd(b, a % b); } -template +/** + * Returns true if a pointer having alignment of `a` bytes also has an alignment of `b` bytes. Returns false otherwise. + */ +KERNEL_FLOAT_INLINE +constexpr size_t alignment_divisible(size_t a, size_t b) { + return gcd(a, KERNEL_FLOAT_MAX_ALIGNMENT) % gcd(b, KERNEL_FLOAT_MAX_ALIGNMENT) == 0; +} + +template struct copy_aligned_impl { static constexpr size_t K = N > 8 ? 8 : (N > 4 ? 4 : (N > 2 ? 2 : 1)); - static constexpr size_t alignment_K = gcd(alignment, sizeof(T) * K); + static constexpr size_t Alignment_K = gcd(Alignment, sizeof(T) * K); KERNEL_FLOAT_INLINE static void load(T* output, const T* input) { - copy_aligned_impl::load(output, input); - copy_aligned_impl::load(output + K, input + K); + copy_aligned_impl::load(output, input); + copy_aligned_impl::load(output + K, input + K); } KERNEL_FLOAT_INLINE static void store(T* output, const T* input) { - copy_aligned_impl::store(output, input); - copy_aligned_impl::store(output + K, input + K); + copy_aligned_impl::store(output, input); + copy_aligned_impl::store(output + K, input + K); } }; -template -struct copy_aligned_impl { +template +struct copy_aligned_impl { KERNEL_FLOAT_INLINE static void load(T* output, const T* input) {} @@ -139,8 +150,8 @@ struct copy_aligned_impl { static void store(T* output, const T* input) {} }; -template -struct copy_aligned_impl { +template +struct copy_aligned_impl { using storage_type = T; KERNEL_FLOAT_INLINE @@ -154,9 +165,9 @@ struct copy_aligned_impl { } }; -template -struct copy_aligned_impl sizeof(T))>> { - static constexpr size_t storage_alignment = gcd(alignment, 2 * sizeof(T)); +template +struct copy_aligned_impl sizeof(T))>> { + static constexpr size_t storage_alignment = gcd(Alignment, 2 * sizeof(T)); struct alignas(storage_alignment) storage_type { T v0, v1; }; @@ -174,9 +185,9 @@ struct copy_aligned_impl sizeof(T))>> } }; -template -struct copy_aligned_impl 2 * sizeof(T))>> { - static constexpr size_t storage_alignment = gcd(alignment, 4 * sizeof(T)); +template +struct copy_aligned_impl 2 * sizeof(T))>> { + static constexpr size_t storage_alignment = gcd(Alignment, 4 * sizeof(T)); struct alignas(storage_alignment) storage_type { T v0, v1, v2, v3; }; @@ -200,9 +211,9 @@ struct copy_aligned_impl 2 * sizeof(T) } }; -template -struct copy_aligned_impl 4 * sizeof(T))>> { - static constexpr size_t storage_alignment = gcd(alignment, 8 * sizeof(T)); +template +struct copy_aligned_impl 4 * sizeof(T))>> { + static constexpr size_t storage_alignment = gcd(Alignment, 8 * sizeof(T)); struct alignas(storage_alignment) storage_type { T v0, v1, v2, v3, 
v4, v5, v6, v7; }; @@ -246,8 +257,8 @@ struct copy_aligned_impl 4 * sizeof(T) * // Load 4 elements at locations data[0], data[1], data[2], data[3] * vec values = read_aligned<4>(data); * - * // Load 4 elements at locations data[10], data[11], data[12], data[13] - * vec values2 = read_aligned<4>(data + 10); + * // Load 4 elements at locations data[12], data[13], data[14], data[15] + * vec values2 = read_aligned<4>(data + 12); * ``` */ template @@ -294,10 +305,13 @@ KERNEL_FLOAT_INLINE void write_aligned(T* ptr, const V& values) { * @tparam U The underlying storage type. Defaults to the same type as T. * @tparam Align The alignment constraint for read and write operations. */ -template +template struct vector_ref { + static constexpr size_t access_alignment = detail::gcd(Alignment, KERNEL_FLOAT_MAX_ALIGNMENT); + static_assert(access_alignment >= alignof(U), "invalid alignment for pointer type"); + using pointer_type = U*; - using value_type = decay_t; + using value_type = T; using vector_type = vector>; /** @@ -313,7 +327,12 @@ struct vector_ref { * @return vector_type A vector of type vector_type containing the read and converted data. */ KERNEL_FLOAT_INLINE vector_type read() const { - return convert(read_aligned(data_)); + vector_storage result; + detail::copy_aligned_impl::load( + result.data(), + KERNEL_FLOAT_ASSUME_ALIGNED(const U, data_, access_alignment)); + + return convert(result); } /** @@ -324,7 +343,9 @@ struct vector_ref { */ template KERNEL_FLOAT_INLINE void write(const V& values) const { - write_aligned(data_, convert(values)); + detail::copy_aligned_impl::store( + KERNEL_FLOAT_ASSUME_ALIGNED(U, data_, access_alignment), + convert_storage(values).data()); } /** @@ -357,16 +378,24 @@ struct vector_ref { /** * Specialization for `vector_ref` if the backing storage is const. 
*/ -template -struct vector_ref { +template +struct vector_ref { + static constexpr size_t access_alignment = detail::gcd(Alignment, KERNEL_FLOAT_MAX_ALIGNMENT); + static_assert(access_alignment >= alignof(U), "invalid alignment for pointer type"); + using pointer_type = const U*; - using value_type = decay_t; + using value_type = T; using vector_type = vector>; KERNEL_FLOAT_INLINE explicit vector_ref(pointer_type data) : data_(data) {} KERNEL_FLOAT_INLINE vector_type read() const { - return convert(read_aligned(data_)); + vector_storage result; + detail::copy_aligned_impl::load( + result.data(), + KERNEL_FLOAT_ASSUME_ALIGNED(const U, data_, access_alignment)); + + return convert_storage(result); } KERNEL_FLOAT_INLINE operator vector_type() const { @@ -381,13 +410,13 @@ struct vector_ref { pointer_type data_ = nullptr; }; -#define KERNEL_FLOAT_VECTOR_REF_ASSIGN_OP(OP, OP_ASSIGN) \ - template \ - KERNEL_FLOAT_INLINE vector_ref operator OP_ASSIGN( \ - vector_ref ptr, \ - const V& value) { \ - ptr.write(ptr.read() OP value); \ - return ptr; \ +#define KERNEL_FLOAT_VECTOR_REF_ASSIGN_OP(OP, OP_ASSIGN) \ + template \ + KERNEL_FLOAT_INLINE vector_ref operator OP_ASSIGN( \ + vector_ref ptr, \ + const V& value) { \ + ptr.write(ptr.read() OP value); \ + return ptr; \ } KERNEL_FLOAT_VECTOR_REF_ASSIGN_OP(+, +=) @@ -395,6 +424,17 @@ KERNEL_FLOAT_VECTOR_REF_ASSIGN_OP(-, -=) KERNEL_FLOAT_VECTOR_REF_ASSIGN_OP(*, *=) KERNEL_FLOAT_VECTOR_REF_ASSIGN_OP(/, /=) +template +struct into_vector_impl> { + using value_type = T; + using extent_type = extent; + + KERNEL_FLOAT_INLINE + static vector_storage call(const vector_ref& reference) { + return reference.read(); + } +}; + /** * A wrapper for a pointer that enables vectorized access and supports type conversions.. * @@ -408,9 +448,11 @@ KERNEL_FLOAT_VECTOR_REF_ASSIGN_OP(/, /=) * @tparam T The type of the elements as viewed by the user. * @tparam N The alignment of T in number of elements. * @tparam U The underlying storage type, defaults to T. + * @tparam Alignment The assumed alignment of the underlying U* pointer. */ -template +template struct vector_ptr { + static constexpr size_t offset_alignment = detail::gcd(Alignment, sizeof(U) * N); using pointer_type = U*; using value_type = decay_t; @@ -428,44 +470,47 @@ struct vector_ptr { * Constructs a vector_ptr from another vector_ptr with potentially different alignment and type. This constructor * only allows conversion if the alignment of the source is greater than or equal to the alignment of the target. */ - template - KERNEL_FLOAT_INLINE vector_ptr(vector_ptr p, enable_if_t<(N2 % N == 0), int> = {}) : + template + KERNEL_FLOAT_INLINE vector_ptr( + vector_ptr p, + enable_if_t = {}) : data_(p.get()) {} /** - * Accesses a reference to a vector at a specific index with optional alignment considerations. - * - * @tparam N The number of elements in the vector to access, defaults to the alignment. - * @param index The index at which to access the vector. + * Shorthand for `at(0)`. */ - template - KERNEL_FLOAT_INLINE vector_ref at(size_t index) const { - return vector_ref {data_ + index * N}; + KERNEL_FLOAT_INLINE const vector_ref operator*() const { + return vector_ref {data_}; } /** - * Accesses a vector at a specific index. + * Accesses a reference to a vector at a specific index with optional alignment considerations. * - * @tparam K The number of elements to read, defaults to `N`. - * @param index The index from which to read the data. + * @tparam N The number of elements in the vector to access, defaults to N. 
+ * @param index The index at which to access the vector. */ template - KERNEL_FLOAT_INLINE vector> read(size_t index) const { - return this->template at(index).read(); + KERNEL_FLOAT_INLINE vector_ref at(size_t index) const { + return vector_ref {data_ + index * N}; } /** - * Shorthand for `read(index)`. + * Shorthand for `at(index)`. */ - KERNEL_FLOAT_INLINE const vector> operator[](size_t index) const { - return read(index); + KERNEL_FLOAT_INLINE vector_ref + operator[](size_t index) const { + return at(index); } /** - * Shorthand for `read(0)`. + * Accesses a vector at a specific index. + * + * @tparam K The number of elements to read, defaults to `N`. + * @param index The index from which to read the data. */ - KERNEL_FLOAT_INLINE const vector> operator*() const { - return read(0); + template + KERNEL_FLOAT_INLINE vector> read(size_t index = 0) const { + return this->template at(index).read(); } /** @@ -481,15 +526,6 @@ struct vector_ptr { this->template at(index).write(values); } - /** - * Shorthand for `at(index)`. Returns a vector reference to can be used - * to assign to this pointer, contrary to `operator[]` that does not - * allow assignment. - */ - KERNEL_FLOAT_INLINE vector_ref operator()(size_t index) const { - return at(index); - } - /** * Gets the raw data pointer managed by this `vector_ptr`. */ @@ -504,26 +540,36 @@ struct vector_ptr { /** * Specialization for `vector_ptr` if the backing storage is const. */ -template -struct vector_ptr { +template +struct vector_ptr { + static constexpr size_t offset_alignment = detail::gcd(Alignment, sizeof(U) * N); using pointer_type = const U*; using value_type = decay_t; vector_ptr() = default; + KERNEL_FLOAT_INLINE explicit vector_ptr(pointer_type p) : data_(p) {} - template - KERNEL_FLOAT_INLINE - vector_ptr(vector_ptr p, enable_if_t<(N2 % N == 0), int> = {}) : + template + KERNEL_FLOAT_INLINE vector_ptr( + vector_ptr p, + enable_if_t = {}) : data_(p.get()) {} - template - KERNEL_FLOAT_INLINE vector_ptr(vector_ptr p, enable_if_t<(N2 % N == 0), int> = {}) : + template + KERNEL_FLOAT_INLINE vector_ptr( + vector_ptr p, + enable_if_t = {}) : data_(p.get()) {} + KERNEL_FLOAT_INLINE vector_ref operator*() const { + return vector_ref {data_}; + } + template - KERNEL_FLOAT_INLINE vector_ref at(size_t index) const { - return vector_ref {data_ + index * N}; + KERNEL_FLOAT_INLINE vector_ref + at(size_t index) const { + return vector_ref {data_ + index * N}; } template @@ -535,10 +581,6 @@ struct vector_ptr { return read(index); } - KERNEL_FLOAT_INLINE const vector> operator*() const { - return read(0); - } - KERNEL_FLOAT_INLINE pointer_type get() const { return data_; } @@ -547,44 +589,54 @@ struct vector_ptr { pointer_type data_ = nullptr; }; -template -KERNEL_FLOAT_INLINE vector_ptr operator+(vector_ptr p, size_t i) { - return vector_ptr {p.get() + i * N}; +template +KERNEL_FLOAT_INLINE vector_ptr +operator+(vector_ptr p, size_t i) { + return vector_ptr {p.get() + i * N}; } -template -KERNEL_FLOAT_INLINE vector_ptr operator+(size_t i, vector_ptr p) { +template +KERNEL_FLOAT_INLINE vector_ptr +operator+(size_t i, vector_ptr p) { return p + i; } +template< + typename T, + size_t N, + typename U, + size_t A, + typename = enable_if_t<(N * sizeof(U)) % A == 0>> +KERNEL_FLOAT_INLINE vector_ptr& operator+=(vector_ptr& p, size_t i) { + return p = p + i; +} + /** - * Creates a `vector_ptr` from a raw pointer `U*` by asserting a specific alignment `N`. + * Creates a `vector_ptr` from a raw pointer `T*` by asserting a specific alignment `N`. 
* - * @tparam T The type of the elements as viewed by the user. This type may differ from `U`. * @tparam N The alignment constraint for the vector_ptr. Defaults to KERNEL_FLOAT_MAX_ALIGNMENT. - * @tparam U The type of the elements pointed to by the raw pointer. + * @tparam T The type of the elements pointed to by the raw pointer. */ -template -KERNEL_FLOAT_INLINE vector_ptr assert_aligned(U* ptr) { - return vector_ptr {ptr}; +template +KERNEL_FLOAT_INLINE vector_ptr assert_aligned(T* ptr) { + return vector_ptr {ptr}; } // Doxygen cannot deal with the `assert_aligned` being defined twice, we ignore the second definition. /// @cond IGNORE /** - * Creates a `vector_ptr` from a raw pointer `T*` by asserting a specific alignment `N`. + * Creates a `vector_ptr` from a raw pointer `T*`. The alignment is assumed to be KERNEL_FLOAT_MAX_ALIGNMENT. * - * @tparam N The alignment constraint for the vector_ptr. Defaults to KERNEL_FLOAT_MAX_ALIGNMENT. * @tparam T The type of the elements pointed to by the raw pointer. */ -template -KERNEL_FLOAT_INLINE vector_ptr assert_aligned(T* ptr) { - return vector_ptr {ptr}; +template +KERNEL_FLOAT_INLINE vector_ptr assert_aligned(T* ptr) { + return vector_ptr {ptr}; } /// @endcond -template -using vec_ptr = vector_ptr; +template +using vec_ptr = vector_ptr; #if defined(__cpp_deduction_guides) template diff --git a/include/kernel_float/meta.h b/include/kernel_float/meta.h index 005be9b..5256129 100644 --- a/include/kernel_float/meta.h +++ b/include/kernel_float/meta.h @@ -270,9 +270,6 @@ struct enable_if_impl { template using enable_if_t = typename detail::enable_if_impl::type; -template -using identity_t = T; - KERNEL_FLOAT_INLINE constexpr size_t round_up_to_power_of_two(size_t n) { size_t result = 1; diff --git a/include/kernel_float/prelude.h b/include/kernel_float/prelude.h index 59872a0..a723820 100644 --- a/include/kernel_float/prelude.h +++ b/include/kernel_float/prelude.h @@ -61,14 +61,14 @@ KERNEL_FLOAT_TYPE_ALIAS(f64x, double) KERNEL_FLOAT_TYPE_ALIAS(float64x, double) #if KERNEL_FLOAT_FP16_AVAILABLE -KERNEL_FLOAT_TYPE_ALIAS(half, __half) -KERNEL_FLOAT_TYPE_ALIAS(f16x, __half) -KERNEL_FLOAT_TYPE_ALIAS(float16x, __half) +KERNEL_FLOAT_TYPE_ALIAS(half, half_t) +KERNEL_FLOAT_TYPE_ALIAS(f16x, half_t) +KERNEL_FLOAT_TYPE_ALIAS(float16x, half_t) #endif #if KERNEL_FLOAT_BF16_AVAILABLE -KERNEL_FLOAT_TYPE_ALIAS(bfloat16x, __nv_bfloat16) -KERNEL_FLOAT_TYPE_ALIAS(bf16x, __nv_bfloat16) +KERNEL_FLOAT_TYPE_ALIAS(bfloat16x, bfloat16_t) +KERNEL_FLOAT_TYPE_ALIAS(bf16x, bfloat16_t) #endif #if KERNEL_FLOAT_BF8_AVAILABLE @@ -82,12 +82,12 @@ static constexpr extent kextent = {}; template KERNEL_FLOAT_INLINE kvec, sizeof...(Args)> make_kvec(Args&&... 
args) { - return ::kernel_float::make_vec(std::forward(args)...); + return ::kernel_float::make_vec(static_cast(args)...); }; template KERNEL_FLOAT_INLINE into_vector_type into_kvec(V&& input) { - return ::kernel_float::into_vec(std::forward(input)); + return ::kernel_float::into_vec(static_cast(input)); } template diff --git a/include/kernel_float/reduce.h b/include/kernel_float/reduce.h index c859265..853df3a 100644 --- a/include/kernel_float/reduce.h +++ b/include/kernel_float/reduce.h @@ -24,7 +24,7 @@ struct reduce_recur_impl { template KERNEL_FLOAT_INLINE static T call(F fun, const T* input) { vector_storage temp; - map_impl::call(fun, temp.data(), input, input + K); + default_map_impl::call(fun, temp.data(), input, input + K); if constexpr (N < 2 * K) { #pragma unroll @@ -183,11 +183,11 @@ struct dot_impl { if constexpr (N / K > 0) { T accum[K] = {T {}}; - apply_impl, K, T, T, T>::call({}, accum, left, right); + apply_impl, K, T, T, T>::call({}, accum, left, right); #pragma unroll for (size_t i = 1; i < N / K; i++) { - apply_impl, K, T, T, T, T>::call( + apply_impl, K, T, T, T, T>::call( ops::fma {}, accum, left + i * K, @@ -200,7 +200,7 @@ struct dot_impl { if constexpr (N % K > 0) { for (size_t i = N - N % K; i < N; i++) { - apply_impl, 1, T, T, T, T>::call( + apply_impl, 1, T, T, T, T>::call( {}, &result, left + i, @@ -267,21 +267,21 @@ struct magnitude_impl { }; // The 3-argument overload of hypot is only available on host from C++17 -#if defined(__cpp_lib_hypot) && KERNEL_FLOAT_IS_HOST -template<> -struct magnitude_impl { - static float call(const float* input) { - return ::hypot(input[0], input[1], input[2]); - } -}; - -template<> -struct magnitude_impl { - static double call(const double* input) { - return ::hypot(input[0], input[1], input[2]); - } -}; -#endif +//#if defined(__cpp_lib_hypot) && KERNEL_FLOAT_IS_HOST +//template<> +//struct magnitude_impl { +// static float call(const float* input) { +// return ::hypot(input[0], input[1], input[2]); +// } +//}; +// +//template<> +//struct magnitude_impl { +// static double call(const double* input) { +// return ::hypot(input[0], input[1], input[2]); +// } +//}; +//#endif } // namespace detail diff --git a/include/kernel_float/tiling.h b/include/kernel_float/tiling.h index 10ad379..d4f944a 100644 --- a/include/kernel_float/tiling.h +++ b/include/kernel_float/tiling.h @@ -329,9 +329,7 @@ struct tiling { using index_type = IndexType; using point_type = vector>; -#if KERNEL_FLOAT_IS_DEVICE - __forceinline__ __device__ tiling() : block_(threadIdx) {} -#endif + __forceinline__ __device__ tiling() : block_(dim3(threadIdx)) {} KERNEL_FLOAT_INLINE tiling(BlockDim block, vec offset = {}) : block_(block), offset_(offset) {} diff --git a/include/kernel_float/triops.h b/include/kernel_float/triops.h index 12cca59..8e27e70 100644 --- a/include/kernel_float/triops.h +++ b/include/kernel_float/triops.h @@ -41,7 +41,7 @@ KERNEL_FLOAT_INLINE vector where(const C& cond, const L& true_values, cons using F = ops::conditional; vector_storage> result; - detail::map_impl, T, bool, T, T>::call( + detail::default_map_impl, T, bool, T, T>::call( F {}, result.data(), detail::convert_impl, vector_extent_type, bool, E>::call( @@ -99,12 +99,12 @@ struct fma { namespace detail { template -struct apply_impl, N, T, T, T, T> { +struct apply_impl, N, T, T, T, T> { KERNEL_FLOAT_INLINE static void call(ops::fma, T* output, const T* a, const T* b, const T* c) { T temp[N]; - apply_impl, N, T, T, T>::call({}, temp, a, b); - apply_impl, N, T, T, T>::call({}, output, 
temp, c); + apply_impl, N, T, T, T>::call({}, temp, a, b); + apply_impl, N, T, T, T>::call({}, output, temp, c); } }; } // namespace detail @@ -140,7 +140,7 @@ KERNEL_FLOAT_INLINE vector fma(const A& a, const B& b, const C& c) { using F = ops::fma; vector_storage> result; - detail::map_impl, T, T, T, T>::call( + detail::default_map_impl, T, T, T, T>::call( F {}, result.data(), detail::convert_impl, vector_extent_type, T, E>::call( diff --git a/include/kernel_float/unops.h b/include/kernel_float/unops.h index fce130e..81ddec0 100644 --- a/include/kernel_float/unops.h +++ b/include/kernel_float/unops.h @@ -20,43 +20,49 @@ struct cast { }; template -struct cast { +struct cast { KERNEL_FLOAT_INLINE T operator()(T input) noexcept { return input; } }; -template -struct cast_float_fallback; - -template -struct cast_float_fallback { +template +struct cast { KERNEL_FLOAT_INLINE R operator()(T input) noexcept { - return R(input); + if constexpr ( + detail::allow_float_fallback::value || detail::allow_float_fallback::value) { + return cast {}(cast {}(input)); + } else { + return R(input); + } } }; -// clang-format off -template -struct cast_float_fallback< - T, - R, - enable_if_t< - !is_same_type && - !is_same_type && - (detail::allow_float_fallback::value || detail::allow_float_fallback::value) - > -> { - KERNEL_FLOAT_INLINE R operator()(T input) noexcept { - return cast {}(cast {}(input)); +template<> +struct cast { + KERNEL_FLOAT_INLINE float operator()(float input) noexcept { + return input; } }; -// clang-format on -template -struct cast { - KERNEL_FLOAT_INLINE R operator()(T input) noexcept { - return cast_float_fallback {}(input); +template +struct cast { + KERNEL_FLOAT_INLINE float operator()(float input) noexcept { + return input; + } +}; + +template +struct cast { + KERNEL_FLOAT_INLINE float operator()(T input) noexcept { + return float(input); + } +}; + +template +struct cast { + KERNEL_FLOAT_INLINE T operator()(float input) noexcept { + return T(input); } }; @@ -123,8 +129,8 @@ KERNEL_FLOAT_DEFINE_UNARY_OP(logical_not, !, (ops::cast {}(!ops::cast \ struct NAME::value>> { \ KERNEL_FLOAT_INLINE T operator()(T input_arg) { \ - float input = float(input_arg); \ - return T(EXPR_F32); \ + float input = ops::cast {}(input_arg); \ + return ops::cast {}(EXPR_F32); \ } \ }; \ \ @@ -140,49 +146,56 @@ KERNEL_FLOAT_DEFINE_UNARY_OP(logical_not, !, (ops::cast {}(!ops::cast(ops::NAME> {}, input); \ } +KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(sin) +KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(cos) +KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(tan) + KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(exp) KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(log) + KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(sqrt) KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(rcp) KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(rsqrt) -KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(sin) -KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(cos) -KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(tan) -KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(exp2) -KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(log2) #if KERNEL_FLOAT_IS_DEVICE - -#define KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(T, F, FAST_FUN) \ +#define KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(T, F, EXPR_F32) \ namespace detail { \ template<> \ - struct apply_fastmath_impl, 1, T, T> { \ + struct apply_impl, 1, T, T> { \ KERNEL_FLOAT_INLINE static void call(ops::F, T* result, const T* inputs) { \ - *result = FAST_FUN(*inputs); \ + T input = inputs[0]; \ + *result = EXPR_F32; \ } \ }; \ } -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, exp, __expf) -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, log, __logf) 
+KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, exp, __expf(input)) +//KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, exp2, __exp2f(input)) // Seems to be missing? +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, exp10, __exp10f(input)) +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, log, __logf(input)) +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, log2, __log2f(input)) +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, log10, __log10f(input)) + +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, sin, __sinf(input)) +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, cos, __cosf(input)) +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, tan, __tanf(input)) + +// This PTX is only supported on CUDA +#if KERNEL_FLOAT_IS_CUDA #define KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(T, F, INSTR, REG) \ namespace detail { \ template<> \ - struct apply_fastmath_impl, 1, T, T> { \ + struct apply_impl, 1, T, T> { \ KERNEL_FLOAT_INLINE static void call(ops::F fun, T* result, const T* inputs) { \ - asm(INSTR : "=" REG(*result) : REG(*inputs)); \ + asm(INSTR " %0, %1;" : "=" REG(*result) : REG(*inputs)); \ } \ }; \ } -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(double, rcp, "rcp.approx.ftz.f64 %0, %1;", "d") -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(double, rsqrt, "rsqrt.approx.f64 %0, %1;", "d") +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(double, rcp, "rcp.approx.ftz.f64", "d") +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(double, rsqrt, "rsqrt.approx.ftz.f64", "d") -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, sqrt, "sqrt.approx.f32 %0, %1;", "f") -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, rcp, "rcp.approx.f32 %0, %1;", "f") -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, rsqrt, "rsqrt.approx.f32 %0, %1;", "f") -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, sin, "sin.approx.f32 %0, %1;", "f") -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, cos, "cos.approx.f32 %0, %1;", "f") +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, exp2, "ex2.approx.f32", "f") +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, sqrt, "sqrt.approx.f32", "f") +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, rcp, "rcp.approx.f32", "f") +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, rsqrt, "rsqrt.approx.f32", "f") +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, tanh, "tanh.approx.f32", "f") -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, exp2, "ex2.approx.f32 %0, %1;", "f") -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, log2, "lg2.approx.f32 %0, %1;", "f") -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, tanh, "tanh.approx.f32 %0, %1;", "f") +// These are no longer necessary due to the KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN above +//KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, sin, "sin.approx.f32", "f") +//KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, cos, "cos.approx.f32", "f") +//KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, log2, "lg2.approx.f32", "f") +#endif +#define KERNEL_FLOAT_FAST_F32_MAP(F) \ + F(exp) F(exp2) F(exp10) F(log) F(log2) F(log10) F(sin) F(cos) F(tan) F(rcp) F(rsqrt) F(sqrt) + +#else +#define KERNEL_FLOAT_FAST_F32_MAP(F) #endif } // namespace kernel_float diff --git a/include/kernel_float/vector.h b/include/kernel_float/vector.h index e52fa02..222db2e 100644 --- a/include/kernel_float/vector.h +++ b/include/kernel_float/vector.h @@ -38,7 +38,7 @@ struct vector: public S { // Copy anything of type `storage_type` KERNEL_FLOAT_INLINE vector(const value_type& input = {}) : - storage_type(detail::broadcast_impl, E>::call(input)) {} + storage_type(detail::broadcast_impl, E>::call(input)) {} // For all other arguments, we convert it using 
`convert_storage` according to broadcast rules template, T>, int> = 0> @@ -82,6 +82,14 @@ struct vector: public S { return *this; } + /** + * Returns an instance of the `extent_type` for this vector. + */ + KERNEL_FLOAT_INLINE + extent_type extent() const { + return {}; + } + /** * Returns a pointer to the underlying storage data. */ @@ -206,7 +214,7 @@ struct vector: public S { */ KERNEL_FLOAT_INLINE void set(size_t x, T value) { - at(x) = std::move(value); + at(x) = static_cast(value); } /** @@ -239,7 +247,8 @@ struct vector: public S { * Broadcast this vector into a new size `(Ns...)`. */ template - KERNEL_FLOAT_INLINE vector> broadcast(extent new_size = {}) const { + KERNEL_FLOAT_INLINE vector> + broadcast(kernel_float::extent new_size = {}) const { return kernel_float::broadcast(*this, new_size); } @@ -274,15 +283,24 @@ struct vector: public S { */ template KERNEL_FLOAT_INLINE void for_each(F fun) const { - return kernel_float::for_each(*this, std::move(fun)); + return kernel_float::for_each(*this, fun); } /** - * Returns the result of `*this + lhs * rhs`. + * Returns the result of `this + lhs * rhs`. * * The operation is performed using a single `kernel_float::fma` call, which may be faster then perform * the addition and multiplication separately. */ + template< + typename L, + typename R, + typename T2 = promote_t, vector_value_type>, + typename E2 = broadcast_extent, vector_extent_type>> + KERNEL_FLOAT_INLINE vector add_mul(const L& lhs, const R& rhs) const { + return ::kernel_float::fma(lhs, rhs, *this); + } + template< typename L, typename R, @@ -303,7 +321,7 @@ struct vector: public S { */ template KERNEL_FLOAT_INLINE into_vector_type into_vec(V&& input) { - return into_vector_impl::call(std::forward(input)); + return into_vector_impl::call(static_cast(input)); } template @@ -325,6 +343,7 @@ template using vec8 = vec; #define KERNEL_FLOAT_VECTOR_ALIAS(NAME, T) \ template \ + using v##NAME = vec; \ using NAME##1 = vec; \ using NAME##2 = vec; \ using NAME##3 = vec; \ diff --git a/kernel_tuner/vector_add.cu b/kernel_tuner/vector_add.cu index 5260bf8..2c0f339 100644 --- a/kernel_tuner/vector_add.cu +++ b/kernel_tuner/vector_add.cu @@ -1,6 +1,7 @@ #include "kernel_float.h" namespace kf = kernel_float; +extern "C" __global__ void vector_add( kf::vec* c, const kf::vec* a, diff --git a/kernel_tuner/vector_add.py b/kernel_tuner/vector_add.py index 8c5b2ee..4085cba 100644 --- a/kernel_tuner/vector_add.py +++ b/kernel_tuner/vector_add.py @@ -9,6 +9,7 @@ ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) + "/../" flags = [f"-I{ROOT_DIR}/include", "-std=c++17"] + def tune(): # Prepare input data @@ -54,7 +55,7 @@ def tune(): answer=answer, observers=observers, metrics=metrics, - lang="cupy", + lang="CUDA", compiler_options=flags ) diff --git a/single_include/kernel_float.h b/single_include/kernel_float.h index a79da18..6d1754b 100644 --- a/single_include/kernel_float.h +++ b/single_include/kernel_float.h @@ -16,49 +16,59 @@ //================================================================================ // this file has been auto-generated, do not modify its contents! 
-// date: 2024-09-23 14:12:25.024358 -// git hash: 3a88b56a57cce5e1f3365aa6e8efb76a14f7f865 +// date: 2025-01-27 16:26:28.827757 +// git hash: 09dc82096e4c013a079f0e315da1ccce17453c93 //================================================================================ #ifndef KERNEL_FLOAT_MACROS_H #define KERNEL_FLOAT_MACROS_H +#ifdef __HIPCC__ +#include "hip/hip_runtime.h" +#endif + +// clang-format off #ifdef __CUDACC__ -#define KERNEL_FLOAT_CUDA (1) - -#ifdef __CUDA_ARCH__ -#define KERNEL_FLOAT_INLINE __forceinline__ __device__ -#define KERNEL_FLOAT_IS_DEVICE (1) -#define KERNEL_FLOAT_IS_HOST (0) -#define KERNEL_FLOAT_CUDA_ARCH (__CUDA_ARCH__) -#else // __CUDA_ARCH__ -#define KERNEL_FLOAT_INLINE __forceinline__ __host__ -#define KERNEL_FLOAT_IS_DEVICE (0) -#define KERNEL_FLOAT_IS_HOST (1) -#define KERNEL_FLOAT_CUDA_ARCH (0) -#endif // __CUDA_ARCH__ -#else // __CUDACC__ -#define KERNEL_FLOAT_INLINE inline -#define KERNEL_FLOAT_CUDA (0) -#define KERNEL_FLOAT_IS_HOST (1) -#define KERNEL_FLOAT_IS_DEVICE (0) -#define KERNEL_FLOAT_CUDA_ARCH (0) -#endif // __CUDACC__ + #define KERNEL_FLOAT_IS_CUDA (1) + #define KERNEL_FLOAT_DEVICE __forceinline__ __device__ + + #ifdef __CUDA_ARCH__ + #define KERNEL_FLOAT_INLINE __forceinline__ __device__ + #define KERNEL_FLOAT_IS_DEVICE (1) + #else // __CUDA_ARCH__ + #define KERNEL_FLOAT_INLINE __forceinline__ __host__ + #define KERNEL_FLOAT_IS_HOST (1) + #endif // __CUDA_ARCH__ +#elif defined(__HIPCC__) + #define KERNEL_FLOAT_IS_HIP (1) + #define KERNEL_FLOAT_DEVICE __attribute__((always_inline)) __device__ + #define KERNEL_FLOAT_INLINE __attribute__((always_inline)) __host__ __device__ + + #ifdef __HIP_DEVICE_COMPILE__ + #define KERNEL_FLOAT_IS_DEVICE (1) + #else + #define KERNEL_FLOAT_IS_HOST (1) + #endif + +#else + #define KERNEL_FLOAT_INLINE inline + #define KERNEL_FLOAT_IS_HOST (1) +#endif #ifndef KERNEL_FLOAT_FP16_AVAILABLE -#define KERNEL_FLOAT_FP16_AVAILABLE (1) + #define KERNEL_FLOAT_FP16_AVAILABLE (1) #endif // KERNEL_FLOAT_FP16_AVAILABLE #ifndef KERNEL_FLOAT_BF16_AVAILABLE -#define KERNEL_FLOAT_BF16_AVAILABLE (1) + #define KERNEL_FLOAT_BF16_AVAILABLE (1) #endif // KERNEL_FLOAT_BF16_AVAILABLE #ifndef KERNEL_FLOAT_FP8_AVAILABLE -#ifdef __CUDACC_VER_MAJOR__ -#define KERNEL_FLOAT_FP8_AVAILABLE (__CUDACC_VER_MAJOR__ >= 12) -#else // __CUDACC_VER_MAJOR__ -#define KERNEL_FLOAT_FP8_AVAILABLE (0) -#endif // __CUDACC_VER_MAJOR__ + #ifdef __CUDACC_VER_MAJOR__ + #define KERNEL_FLOAT_FP8_AVAILABLE (__CUDACC_VER_MAJOR__ >= 12) + #else // __CUDACC_VER_MAJOR__ + #define KERNEL_FLOAT_FP8_AVAILABLE (0) + #endif // __CUDACC_VER_MAJOR__ #endif // KERNEL_FLOAT_FP8_AVAILABLE #define KERNEL_FLOAT_ASSERT(expr) \ @@ -73,22 +83,24 @@ // TOOD: check if this way is support across all compilers #if defined(__has_builtin) && 0 // Seems that `__builtin_assume_aligned` leads to segfaults -#if __has_builtin(__builtin_assume_aligned) -#define KERNEL_FLOAT_ASSUME_ALIGNED(TYPE, PTR, ALIGNMENT) static_cast ( - __builtin_assume_aligned(static_cast (PTR), (ALIGNMENT))) -#else -#define KERNEL_FLOAT_ASSUME_ALIGNED(TYPE, PTR, ALIGNMENT) (PTR) -#endif + #if __has_builtin(__builtin_assume_aligned) + #define KERNEL_FLOAT_ASSUME_ALIGNED(TYPE, PTR, ALIGNMENT) static_cast ( + __builtin_assume_aligned(static_cast (PTR), (ALIGNMENT))) + #else + #define KERNEL_FLOAT_ASSUME_ALIGNED(TYPE, PTR, ALIGNMENT) (PTR) + #endif #else -#define KERNEL_FLOAT_ASSUME_ALIGNED(TYPE, PTR, ALIGNMENT) (PTR) + #define KERNEL_FLOAT_ASSUME_ALIGNED(TYPE, PTR, ALIGNMENT) (PTR) #endif #define KERNEL_FLOAT_MAX_ALIGNMENT (32) 
#if KERNEL_FLOAT_FAST_MATH -#define KERNEL_FLOAT_POLICY ::kernel_float::fast_policy; + #define KERNEL_FLOAT_POLICY ::kernel_float::fast_policy; #endif + // clang-format on + #endif //KERNEL_FLOAT_MACROS_H #ifndef KERNEL_FLOAT_CORE_H #define KERNEL_FLOAT_CORE_H @@ -362,9 +374,6 @@ struct enable_if_impl { template using enable_if_t = typename detail::enable_if_impl::type; -template -using identity_t = T; - KERNEL_FLOAT_INLINE constexpr size_t round_up_to_power_of_two(size_t n) { size_t result = 1; @@ -383,6 +392,10 @@ constexpr size_t round_up_to_power_of_two(size_t n) { +#if KERNEL_FLOAT_IS_HIP +#include +#endif + namespace kernel_float { template @@ -645,7 +658,7 @@ using promoted_vector_value_type = promote_t...>; template KERNEL_FLOAT_INLINE vector_storage_type into_vector_storage(V&& input) { - return into_vector_impl::call(std::forward(input)); + return into_vector_impl::call(static_cast(input)); } } // namespace kernel_float @@ -660,41 +673,41 @@ namespace kernel_float { namespace detail { template -struct broadcast_extent_helper; +struct broadcast_extent_impl; template -struct broadcast_extent_helper { +struct broadcast_extent_impl { using type = E; }; template -struct broadcast_extent_helper, extent> { +struct broadcast_extent_impl, extent> { using type = extent; }; template -struct broadcast_extent_helper, extent> { +struct broadcast_extent_impl, extent> { using type = extent; }; template -struct broadcast_extent_helper, extent<1>> { +struct broadcast_extent_impl, extent<1>> { using type = extent; }; template<> -struct broadcast_extent_helper, extent<1>> { +struct broadcast_extent_impl, extent<1>> { using type = extent<1>; }; template -struct broadcast_extent_helper: - broadcast_extent_helper::type, C, Rest...> {}; +struct broadcast_extent_impl: + broadcast_extent_impl::type, C, Rest...> {}; } // namespace detail template -using broadcast_extent = typename detail::broadcast_extent_helper::type; +using broadcast_extent = typename detail::broadcast_extent_impl::type; template using broadcast_vector_extent_type = broadcast_extent...>; @@ -769,33 +782,51 @@ broadcast_like(const V& input, const R& other) { return broadcast(input, vector_extent_type {}); } -namespace detail { +/** + * The accurate_policy is designed for computations where maximum accuracy is essential. This policy ensures that all + * operations are performed without any approximations or optimizations that could potentially alter the precise + * outcome of the computations + */ +struct accurate_policy {}; -template -struct apply_impl { - KERNEL_FLOAT_INLINE static void call(F fun, Output* output, const Args*... args) { -#pragma unroll - for (size_t i = 0; i < N; i++) { - output[i] = fun(args[i]...); - } - } +/** + * The fast_policy is intended for scenarios where performance and execution speed are more critical than achieving + * the utmost accuracy. This policy leverages optimizations to accelerate computations, which may involve + * approximations that slightly compromise precision. + */ +struct fast_policy { + using fallback_policy = accurate_policy; }; -template -struct apply_fastmath_impl: apply_impl {}; -} // namespace detail - -struct accurate_policy { - template - using type = detail::apply_impl; +/** + * This template policy allows developers to specify a custom degree of approximation for their computations. By + * adjusting the `Level` parameter, you can fine-tune the balance between accuracy and performance to meet the + * specific needs of your application. Higher values mean more precision. 
+ */ +template +struct approx_level_policy { + using fallback_policy = approx_level_policy<>; }; -struct fast_policy { - template - using type = detail::apply_fastmath_impl; +template<> +struct approx_level_policy<> { + using fallback_policy = fast_policy; }; -#ifdef KERNEL_FLOAT_POLICY +/** + * The approximate_policy serves as the default approximation policy, providing a standard level of approximation + * without requiring explicit configuration. It balances accuracy and performance, making it suitable for + * general-purpose use cases where neither extreme precision nor maximum speed is necessary. + */ +using approx_policy = approx_level_policy<>; + +/** + * The `default_policy` acts as the standard computation policy. It can be configured externally using the + * `KERNEL_FLOAT_GLOBAL_POLICY` macro. If `KERNEL_FLOAT_GLOBAL_POLICY` is not defined, default to `accurate_policy`. + */ +#if defined(KERNEL_FLOAT_GLOBAL_POLICY) +using default_policy = KERNEL_FLOAT_GLOBAL_POLICY; +#elif defined(KERNEL_FLOAT_POLICY) using default_policy = KERNEL_FLOAT_POLICY; #else using default_policy = accurate_policy; @@ -803,8 +834,35 @@ using default_policy = accurate_policy; namespace detail { +template +struct invoke_impl { + KERNEL_FLOAT_INLINE static Output call(F fun, Args... args) { + return fun(args...); + } +}; + +template +struct apply_impl; + +template +struct apply_base_impl: apply_impl {}; + +template +struct apply_impl: apply_base_impl {}; + +// Only for `accurate_policy` do we implement `apply_impl`, the others will fall back to `apply_base_impl`. +template +struct apply_impl { + KERNEL_FLOAT_INLINE static void call(F fun, Output* output, const Args*... args) { +#pragma unroll + for (size_t i = 0; i < N; i++) { + output[i] = invoke_impl::call(fun, args[i]...); + } + } +}; + template -struct map_policy_impl { +struct map_impl { static constexpr size_t packet_size = preferred_vector_size::value; static constexpr size_t remainder = N % packet_size; @@ -812,7 +870,7 @@ struct map_policy_impl { if constexpr (N / packet_size > 0) { #pragma unroll for (size_t i = 0; i < N - remainder; i += packet_size) { - Policy::template type::call( + apply_impl::call( fun, output + i, (args + i)...); @@ -822,14 +880,14 @@ struct map_policy_impl { if constexpr (remainder > 0) { #pragma unroll for (size_t i = N - remainder; i < N; i++) { - Policy::template type::call(fun, output + i, (args + i)...); + apply_impl::call(fun, output + i, (args + i)...); } } } }; template -using map_impl = map_policy_impl; +using default_map_impl = map_impl; } // namespace detail @@ -853,7 +911,7 @@ KERNEL_FLOAT_INLINE map_type map(F fun, const Args&... 
args) { using E = broadcast_vector_extent_type; vector_storage> result; - detail::map_policy_impl, Output, vector_value_type...>::call( + detail::map_impl, Output, vector_value_type...>::call( fun, result.data(), (detail::broadcast_impl, vector_extent_type, E>::call( @@ -1137,43 +1195,49 @@ struct cast { }; template -struct cast { +struct cast { KERNEL_FLOAT_INLINE T operator()(T input) noexcept { return input; } }; -template -struct cast_float_fallback; - -template -struct cast_float_fallback { +template +struct cast { KERNEL_FLOAT_INLINE R operator()(T input) noexcept { - return R(input); + if constexpr ( + detail::allow_float_fallback::value || detail::allow_float_fallback::value) { + return cast {}(cast {}(input)); + } else { + return R(input); + } } }; -// clang-format off -template -struct cast_float_fallback< - T, - R, - enable_if_t< - !is_same_type && - !is_same_type && - (detail::allow_float_fallback::value || detail::allow_float_fallback::value) - > -> { - KERNEL_FLOAT_INLINE R operator()(T input) noexcept { - return cast {}(cast {}(input)); +template<> +struct cast { + KERNEL_FLOAT_INLINE float operator()(float input) noexcept { + return input; } }; -// clang-format on -template -struct cast { - KERNEL_FLOAT_INLINE R operator()(T input) noexcept { - return cast_float_fallback {}(input); +template +struct cast { + KERNEL_FLOAT_INLINE float operator()(float input) noexcept { + return input; + } +}; + +template +struct cast { + KERNEL_FLOAT_INLINE float operator()(T input) noexcept { + return float(input); + } +}; + +template +struct cast { + KERNEL_FLOAT_INLINE T operator()(float input) noexcept { + return T(input); } }; @@ -1240,8 +1304,8 @@ KERNEL_FLOAT_DEFINE_UNARY_OP(logical_not, !, (ops::cast {}(!ops::cast \ struct NAME::value>> { \ KERNEL_FLOAT_INLINE T operator()(T input_arg) { \ - float input = float(input_arg); \ - return T(EXPR_F32); \ + float input = ops::cast {}(input_arg); \ + return ops::cast {}(EXPR_F32); \ } \ }; \ \ @@ -1257,49 +1321,56 @@ KERNEL_FLOAT_DEFINE_UNARY_OP(logical_not, !, (ops::cast {}(!ops::cast(ops::NAME> {}, input); \ } +KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(sin) +KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(cos) +KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(tan) + KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(exp) KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(log) + KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(sqrt) KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(rcp) KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(rsqrt) -KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(sin) -KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(cos) -KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(tan) -KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(exp2) -KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(log2) #if KERNEL_FLOAT_IS_DEVICE - -#define KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(T, F, FAST_FUN) \ +#define KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(T, F, EXPR_F32) \ namespace detail { \ template<> \ - struct apply_fastmath_impl, 1, T, T> { \ + struct apply_impl, 1, T, T> { \ KERNEL_FLOAT_INLINE static void call(ops::F, T* result, const T* inputs) { \ - *result = FAST_FUN(*inputs); \ + T input = inputs[0]; \ + *result = EXPR_F32; \ } \ }; \ } -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, exp, __expf) -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, log, __logf) +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, exp, __expf(input)) +//KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, exp2, __exp2f(input)) // Seems to be missing? 
+KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, exp10, __exp10f(input)) +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, log, __logf(input)) +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, log2, __log2f(input)) +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, log10, __log10f(input)) + +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, sin, __sinf(input)) +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, cos, __cosf(input)) +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, tan, __tanf(input)) + +// This PTX is only supported on CUDA +#if KERNEL_FLOAT_IS_CUDA #define KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(T, F, INSTR, REG) \ namespace detail { \ template<> \ - struct apply_fastmath_impl, 1, T, T> { \ + struct apply_impl, 1, T, T> { \ KERNEL_FLOAT_INLINE static void call(ops::F fun, T* result, const T* inputs) { \ - asm(INSTR : "=" REG(*result) : REG(*inputs)); \ + asm(INSTR " %0, %1;" : "=" REG(*result) : REG(*inputs)); \ } \ }; \ } -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(double, rcp, "rcp.approx.ftz.f64 %0, %1;", "d") -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(double, rsqrt, "rsqrt.approx.f64 %0, %1;", "d") +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(double, rcp, "rcp.approx.ftz.f64", "d") +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(double, rsqrt, "rsqrt.approx.ftz.f64", "d") -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, sqrt, "sqrt.approx.f32 %0, %1;", "f") -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, rcp, "rcp.approx.f32 %0, %1;", "f") -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, rsqrt, "rsqrt.approx.f32 %0, %1;", "f") -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, sin, "sin.approx.f32 %0, %1;", "f") -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, cos, "cos.approx.f32 %0, %1;", "f") +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, exp2, "ex2.approx.f32", "f") +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, sqrt, "sqrt.approx.f32", "f") +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, rcp, "rcp.approx.f32", "f") +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, rsqrt, "rsqrt.approx.f32", "f") +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, tanh, "tanh.approx.f32", "f") + +// These are no longer necessary due to the KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN above +//KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, sin, "sin.approx.f32", "f") +//KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, cos, "cos.approx.f32", "f") +//KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, log2, "lg2.approx.f32", "f") +#endif -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, exp2, "ex2.approx.f32 %0, %1;", "f") -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, log2, "lg2.approx.f32 %0, %1;", "f") -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, tanh, "tanh.approx.f32 %0, %1;", "f") +#define KERNEL_FLOAT_FAST_F32_MAP(F) \ + F(exp) F(exp2) F(exp10) F(log) F(log2) F(log10) F(sin) F(cos) F(tan) F(rcp) F(rsqrt) F(sqrt) +#else +#define KERNEL_FLOAT_FAST_F32_MAP(F) #endif } // namespace kernel_float @@ -1390,7 +1479,10 @@ struct convert_impl { static vector_storage> call(vector_storage> input) { using F = ops::cast; vector_storage> intermediate; - detail::map_impl, T2, T>::call(F {}, intermediate.data(), input.data()); + detail::default_map_impl, T2, T>::call( + F {}, + intermediate.data(), + input.data()); return detail::broadcast_impl::call(intermediate); } }; @@ -1421,7 +1513,7 @@ struct convert_impl { using F = ops::cast; vector_storage> result; - detail::map_impl, T2, T>::call(F {}, result.data(), input.data()); + detail::default_map_impl, T2, T>::call(F {}, result.data(), input.data()); return result; } }; @@ -1457,11 +1549,11 @@ 
KERNEL_FLOAT_INLINE vector> convert(const V& input, extent new_s template struct AssignConversionProxy { KERNEL_FLOAT_INLINE - explicit AssignConversionProxy(T* ptr) : ptr_(ptr) {} + explicit AssignConversionProxy(T&& ptr) : ptr_(std::forward(ptr)) {} template KERNEL_FLOAT_INLINE AssignConversionProxy& operator=(U&& values) { - *ptr_ = detail::convert_impl< + ptr_ = detail::convert_impl< vector_value_type, vector_extent_type, vector_value_type, @@ -1472,12 +1564,12 @@ struct AssignConversionProxy { } private: - T* ptr_; + T ptr_; }; /** * Takes a vector reference and gives back a helper object. This object allows you to assign - * a vector of a different type to another vector while perofrming implicit type converion. + * a vector of a different type to another vector while performing implicit type conversion. * * For example, if `x = expression;` does not compile because `x` and `expression` are * different vector types, you can use `cast_to(x) = expression;` to make it work. @@ -1490,9 +1582,9 @@ struct AssignConversionProxy { * cast_to(x) = y; // Normally, `x = y;` would give an error, but `cast_to` fixes that. * ``` */ -template -KERNEL_FLOAT_INLINE AssignConversionProxy cast_to(T& input) { - return AssignConversionProxy(&input); +template +KERNEL_FLOAT_INLINE AssignConversionProxy cast_to(T&& input) { + return AssignConversionProxy(std::forward(input)); } /** @@ -1505,7 +1597,7 @@ KERNEL_FLOAT_INLINE AssignConversionProxy cast_to(T& input) { * ``` */ template -KERNEL_FLOAT_INLINE vector> fill(T value = {}, extent = {}) { +KERNEL_FLOAT_INLINE vector> fill(T value, extent = {}) { vector_storage input = {value}; return detail::broadcast_impl, extent>::call(input); } @@ -1634,7 +1726,7 @@ using zip_common_type = vector< * vec c = zip_common([](float x, float y){ return x + y; }, a, b); // returns [5.0f, 7.0f, 9.0f] * ``` */ -template +template KERNEL_FLOAT_INLINE zip_common_type zip_common(F fun, const L& left, const R& right) { using T = promoted_vector_value_type; using O = result_t; @@ -1642,7 +1734,7 @@ KERNEL_FLOAT_INLINE zip_common_type zip_common(F fun, const L& left, co vector_storage> result; - detail::map_impl, O, T, T>::call( + detail::map_impl, O, T, T>::call( fun, result.data(), detail::convert_impl, vector_extent_type, T, E>::call( @@ -1655,10 +1747,17 @@ KERNEL_FLOAT_INLINE zip_common_type zip_common(F fun, const L& left, co return result; } -#define KERNEL_FLOAT_DEFINE_BINARY_FUN(NAME) \ - template> \ - KERNEL_FLOAT_INLINE zip_common_type, L, R> NAME(L&& left, R&& right) { \ - return zip_common(ops::NAME {}, std::forward(left), std::forward(right)); \ +#define KERNEL_FLOAT_DEFINE_BINARY_FUN(NAME) \ + template< \ + typename Accuracy = default_policy, \ + typename L, \ + typename R, \ + typename C = promoted_vector_value_type> \ + KERNEL_FLOAT_INLINE zip_common_type, L, R> NAME(L&& left, R&& right) { \ + return zip_common( \ + ops::NAME {}, \ + static_cast(left), \ + static_cast(right)); \ } #define KERNEL_FLOAT_DEFINE_BINARY(NAME, EXPR, EXPR_F64, EXPR_F32) \ @@ -1734,7 +1833,7 @@ KERNEL_FLOAT_DEFINE_BINARY_OP_FALLBACK(bit_xor, ^, bool(left) ^ bool(right), boo // clang-format on // clang-format off -template typename F, typename T, typename E, typename R> +template typename F, typename T, typename E, typename R> static constexpr bool is_vector_assign_allowed = is_vector_broadcastable && is_implicit_convertible< @@ -1772,14 +1871,16 @@ namespace ops { template struct min { KERNEL_FLOAT_INLINE T operator()(T left, T right) { - return left < right ? 
left : right; + auto cond = less {}(left, right); + return cast {}(cond) ? left : right; } }; template struct max { KERNEL_FLOAT_INLINE T operator()(T left, T right) { - return left > right ? left : right; + auto cond = greater {}(left, right); + return cast {}(cond) ? left : right; } }; @@ -1880,36 +1981,51 @@ struct multiply { }; // namespace ops namespace detail { -template -struct apply_fastmath_impl, N, T, T, T> { - KERNEL_FLOAT_INLINE static void - call(ops::divide fun, T* result, const T* lhs, const T* rhs) { +template +struct apply_base_impl, N, T, T, T> { + KERNEL_FLOAT_INLINE static void call(ops::divide, T* result, const T* lhs, const T* rhs) { T rhs_rcp[N]; // Fast way to perform division is to multiply by the reciprocal - apply_fastmath_impl, N, T, T>::call({}, rhs_rcp, rhs); - apply_fastmath_impl, N, T, T, T>::call({}, result, lhs, rhs_rcp); + apply_impl, N, T, T>::call({}, rhs_rcp, rhs); + apply_impl, N, T, T, T>::call({}, result, lhs, rhs_rcp); } }; #if KERNEL_FLOAT_IS_DEVICE template<> -struct apply_fastmath_impl, 1, float, float, float> { +struct apply_impl, 1, float, float, float> { KERNEL_FLOAT_INLINE static void - call(ops::divide fun, float* result, const float* lhs, const float* rhs) { + call(ops::divide, float* result, const float* lhs, const float* rhs) { *result = __fdividef(*lhs, *rhs); } }; #endif } // namespace detail +namespace detail { +// Override `pow` using `log2` and `exp2` +template +struct apply_base_impl, N, T, T, T> { + KERNEL_FLOAT_INLINE static void call(ops::pow, T* result, const T* lhs, const T* rhs) { + T lhs_log[N]; + T result_log[N]; + + // Fast way to perform power function is using log2 and exp2 + apply_impl, N, T, T>::call({}, lhs_log, lhs); + apply_impl, N, T, T, T>::call({}, result_log, lhs_log, rhs); + apply_impl, N, T, T, T>::call({}, result, result_log); + } +}; +} // namespace detail + template> KERNEL_FLOAT_INLINE zip_common_type, T, T> fast_divide(const L& left, const R& right) { using E = broadcast_vector_extent_type; vector_storage> result; - detail::map_policy_impl, extent_size, T, T, T>::call( + detail::map_impl, extent_size, T, T, T>::call( ops::divide {}, result.data(), detail::convert_impl, vector_extent_type, T, E>::call( @@ -1949,8 +2065,7 @@ template< typename L, typename R, typename T = promoted_vector_value_type, - typename = - enable_if_t> && is_vector_broadcastable>>> + typename = enable_if_t<(vector_size == 3 && vector_size == 3)>> KERNEL_FLOAT_INLINE vector> cross(const L& left, const R& right) { return detail::cross_impl::call(convert_storage(left), convert_storage(right)); } @@ -2456,7 +2571,7 @@ struct copy_impl> { * * ``` * // Load 2 elements at data[0] and data[8], skip data[2] and data[4] - * vec values = = read(data, make_vec(0, 2, 4, 8), make_vec(true, false, false, true)); + * vec values = read(data, make_vec(0, 2, 4, 8), make_vec(true, false, false, true)); * ``` */ template> @@ -2503,7 +2618,7 @@ KERNEL_FLOAT_INLINE void write(T* ptr, const I& indices, const V& values, const * vec values = read<4>(data); * * // Load 4 elements at locations data[10], data[11], data[12], data[13] - * vec values = read<4>(values + 10, data); + * vec values = read<4>(data + 10); * ``` */ template @@ -2530,31 +2645,42 @@ KERNEL_FLOAT_INLINE void write(T* ptr, const V& values) { } namespace detail { +/** + * Returns the greatest common divisor of `a` and `b`. + */ KERNEL_FLOAT_INLINE constexpr size_t gcd(size_t a, size_t b) { return b == 0 ? 
a : gcd(b, a % b); } -template +/** + * Returns true if a pointer having alignment of `a` bytes also has an alignment of `b` bytes. Returns false otherwise. + */ +KERNEL_FLOAT_INLINE +constexpr size_t alignment_divisible(size_t a, size_t b) { + return gcd(a, KERNEL_FLOAT_MAX_ALIGNMENT) % gcd(b, KERNEL_FLOAT_MAX_ALIGNMENT) == 0; +} + +template struct copy_aligned_impl { static constexpr size_t K = N > 8 ? 8 : (N > 4 ? 4 : (N > 2 ? 2 : 1)); - static constexpr size_t alignment_K = gcd(alignment, sizeof(T) * K); + static constexpr size_t Alignment_K = gcd(Alignment, sizeof(T) * K); KERNEL_FLOAT_INLINE static void load(T* output, const T* input) { - copy_aligned_impl::load(output, input); - copy_aligned_impl::load(output + K, input + K); + copy_aligned_impl::load(output, input); + copy_aligned_impl::load(output + K, input + K); } KERNEL_FLOAT_INLINE static void store(T* output, const T* input) { - copy_aligned_impl::store(output, input); - copy_aligned_impl::store(output + K, input + K); + copy_aligned_impl::store(output, input); + copy_aligned_impl::store(output + K, input + K); } }; -template -struct copy_aligned_impl { +template +struct copy_aligned_impl { KERNEL_FLOAT_INLINE static void load(T* output, const T* input) {} @@ -2562,8 +2688,8 @@ struct copy_aligned_impl { static void store(T* output, const T* input) {} }; -template -struct copy_aligned_impl { +template +struct copy_aligned_impl { using storage_type = T; KERNEL_FLOAT_INLINE @@ -2577,9 +2703,9 @@ struct copy_aligned_impl { } }; -template -struct copy_aligned_impl sizeof(T))>> { - static constexpr size_t storage_alignment = gcd(alignment, 2 * sizeof(T)); +template +struct copy_aligned_impl sizeof(T))>> { + static constexpr size_t storage_alignment = gcd(Alignment, 2 * sizeof(T)); struct alignas(storage_alignment) storage_type { T v0, v1; }; @@ -2597,9 +2723,9 @@ struct copy_aligned_impl sizeof(T))>> } }; -template -struct copy_aligned_impl 2 * sizeof(T))>> { - static constexpr size_t storage_alignment = gcd(alignment, 4 * sizeof(T)); +template +struct copy_aligned_impl 2 * sizeof(T))>> { + static constexpr size_t storage_alignment = gcd(Alignment, 4 * sizeof(T)); struct alignas(storage_alignment) storage_type { T v0, v1, v2, v3; }; @@ -2623,9 +2749,9 @@ struct copy_aligned_impl 2 * sizeof(T) } }; -template -struct copy_aligned_impl 4 * sizeof(T))>> { - static constexpr size_t storage_alignment = gcd(alignment, 8 * sizeof(T)); +template +struct copy_aligned_impl 4 * sizeof(T))>> { + static constexpr size_t storage_alignment = gcd(Alignment, 8 * sizeof(T)); struct alignas(storage_alignment) storage_type { T v0, v1, v2, v3, v4, v5, v6, v7; }; @@ -2669,8 +2795,8 @@ struct copy_aligned_impl 4 * sizeof(T) * // Load 4 elements at locations data[0], data[1], data[2], data[3] * vec values = read_aligned<4>(data); * - * // Load 4 elements at locations data[10], data[11], data[12], data[13] - * vec values2 = read_aligned<4>(data + 10); + * // Load 4 elements at locations data[12], data[13], data[14], data[15] + * vec values2 = read_aligned<4>(data + 12); * ``` */ template @@ -2717,10 +2843,13 @@ KERNEL_FLOAT_INLINE void write_aligned(T* ptr, const V& values) { * @tparam U The underlying storage type. Defaults to the same type as T. * @tparam Align The alignment constraint for read and write operations. 
*/ -template +template struct vector_ref { + static constexpr size_t access_alignment = detail::gcd(Alignment, KERNEL_FLOAT_MAX_ALIGNMENT); + static_assert(access_alignment >= alignof(U), "invalid alignment for pointer type"); + using pointer_type = U*; - using value_type = decay_t; + using value_type = T; using vector_type = vector>; /** @@ -2736,7 +2865,12 @@ struct vector_ref { * @return vector_type A vector of type vector_type containing the read and converted data. */ KERNEL_FLOAT_INLINE vector_type read() const { - return convert(read_aligned(data_)); + vector_storage result; + detail::copy_aligned_impl::load( + result.data(), + KERNEL_FLOAT_ASSUME_ALIGNED(const U, data_, access_alignment)); + + return convert(result); } /** @@ -2747,7 +2881,9 @@ struct vector_ref { */ template KERNEL_FLOAT_INLINE void write(const V& values) const { - write_aligned(data_, convert(values)); + detail::copy_aligned_impl::store( + KERNEL_FLOAT_ASSUME_ALIGNED(U, data_, access_alignment), + convert_storage(values).data()); } /** @@ -2780,16 +2916,24 @@ struct vector_ref { /** * Specialization for `vector_ref` if the backing storage is const. */ -template -struct vector_ref { +template +struct vector_ref { + static constexpr size_t access_alignment = detail::gcd(Alignment, KERNEL_FLOAT_MAX_ALIGNMENT); + static_assert(access_alignment >= alignof(U), "invalid alignment for pointer type"); + using pointer_type = const U*; - using value_type = decay_t; + using value_type = T; using vector_type = vector>; KERNEL_FLOAT_INLINE explicit vector_ref(pointer_type data) : data_(data) {} KERNEL_FLOAT_INLINE vector_type read() const { - return convert(read_aligned(data_)); + vector_storage result; + detail::copy_aligned_impl::load( + result.data(), + KERNEL_FLOAT_ASSUME_ALIGNED(const U, data_, access_alignment)); + + return convert_storage(result); } KERNEL_FLOAT_INLINE operator vector_type() const { @@ -2804,13 +2948,13 @@ struct vector_ref { pointer_type data_ = nullptr; }; -#define KERNEL_FLOAT_VECTOR_REF_ASSIGN_OP(OP, OP_ASSIGN) \ - template \ - KERNEL_FLOAT_INLINE vector_ref operator OP_ASSIGN( \ - vector_ref ptr, \ - const V& value) { \ - ptr.write(ptr.read() OP value); \ - return ptr; \ +#define KERNEL_FLOAT_VECTOR_REF_ASSIGN_OP(OP, OP_ASSIGN) \ + template \ + KERNEL_FLOAT_INLINE vector_ref operator OP_ASSIGN( \ + vector_ref ptr, \ + const V& value) { \ + ptr.write(ptr.read() OP value); \ + return ptr; \ } KERNEL_FLOAT_VECTOR_REF_ASSIGN_OP(+, +=) @@ -2818,6 +2962,17 @@ KERNEL_FLOAT_VECTOR_REF_ASSIGN_OP(-, -=) KERNEL_FLOAT_VECTOR_REF_ASSIGN_OP(*, *=) KERNEL_FLOAT_VECTOR_REF_ASSIGN_OP(/, /=) +template +struct into_vector_impl> { + using value_type = T; + using extent_type = extent; + + KERNEL_FLOAT_INLINE + static vector_storage call(const vector_ref& reference) { + return reference.read(); + } +}; + /** * A wrapper for a pointer that enables vectorized access and supports type conversions.. * @@ -2831,9 +2986,11 @@ KERNEL_FLOAT_VECTOR_REF_ASSIGN_OP(/, /=) * @tparam T The type of the elements as viewed by the user. * @tparam N The alignment of T in number of elements. * @tparam U The underlying storage type, defaults to T. + * @tparam Alignment The assumed alignment of the underlying U* pointer. */ -template +template struct vector_ptr { + static constexpr size_t offset_alignment = detail::gcd(Alignment, sizeof(U) * N); using pointer_type = U*; using value_type = decay_t; @@ -2851,44 +3008,47 @@ struct vector_ptr { * Constructs a vector_ptr from another vector_ptr with potentially different alignment and type. 
This constructor * only allows conversion if the alignment of the source is greater than or equal to the alignment of the target. */ - template - KERNEL_FLOAT_INLINE vector_ptr(vector_ptr p, enable_if_t<(N2 % N == 0), int> = {}) : + template + KERNEL_FLOAT_INLINE vector_ptr( + vector_ptr p, + enable_if_t = {}) : data_(p.get()) {} /** - * Accesses a reference to a vector at a specific index with optional alignment considerations. - * - * @tparam N The number of elements in the vector to access, defaults to the alignment. - * @param index The index at which to access the vector. + * Shorthand for `at(0)`. */ - template - KERNEL_FLOAT_INLINE vector_ref at(size_t index) const { - return vector_ref {data_ + index * N}; + KERNEL_FLOAT_INLINE const vector_ref operator*() const { + return vector_ref {data_}; } /** - * Accesses a vector at a specific index. + * Accesses a reference to a vector at a specific index with optional alignment considerations. * - * @tparam K The number of elements to read, defaults to `N`. - * @param index The index from which to read the data. + * @tparam N The number of elements in the vector to access, defaults to N. + * @param index The index at which to access the vector. */ template - KERNEL_FLOAT_INLINE vector> read(size_t index) const { - return this->template at(index).read(); + KERNEL_FLOAT_INLINE vector_ref at(size_t index) const { + return vector_ref {data_ + index * N}; } /** - * Shorthand for `read(index)`. + * Shorthand for `at(index)`. */ - KERNEL_FLOAT_INLINE const vector> operator[](size_t index) const { - return read(index); + KERNEL_FLOAT_INLINE vector_ref + operator[](size_t index) const { + return at(index); } /** - * Shorthand for `read(0)`. + * Accesses a vector at a specific index. + * + * @tparam K The number of elements to read, defaults to `N`. + * @param index The index from which to read the data. */ - KERNEL_FLOAT_INLINE const vector> operator*() const { - return read(0); + template + KERNEL_FLOAT_INLINE vector> read(size_t index = 0) const { + return this->template at(index).read(); } /** @@ -2904,15 +3064,6 @@ struct vector_ptr { this->template at(index).write(values); } - /** - * Shorthand for `at(index)`. Returns a vector reference to can be used - * to assign to this pointer, contrary to `operator[]` that does not - * allow assignment. - */ - KERNEL_FLOAT_INLINE vector_ref operator()(size_t index) const { - return at(index); - } - /** * Gets the raw data pointer managed by this `vector_ptr`. */ @@ -2927,26 +3078,36 @@ struct vector_ptr { /** * Specialization for `vector_ptr` if the backing storage is const. 
*/ -template -struct vector_ptr { +template +struct vector_ptr { + static constexpr size_t offset_alignment = detail::gcd(Alignment, sizeof(U) * N); using pointer_type = const U*; using value_type = decay_t; vector_ptr() = default; + KERNEL_FLOAT_INLINE explicit vector_ptr(pointer_type p) : data_(p) {} - template - KERNEL_FLOAT_INLINE - vector_ptr(vector_ptr p, enable_if_t<(N2 % N == 0), int> = {}) : + template + KERNEL_FLOAT_INLINE vector_ptr( + vector_ptr p, + enable_if_t = {}) : data_(p.get()) {} - template - KERNEL_FLOAT_INLINE vector_ptr(vector_ptr p, enable_if_t<(N2 % N == 0), int> = {}) : + template + KERNEL_FLOAT_INLINE vector_ptr( + vector_ptr p, + enable_if_t = {}) : data_(p.get()) {} + KERNEL_FLOAT_INLINE vector_ref operator*() const { + return vector_ref {data_}; + } + template - KERNEL_FLOAT_INLINE vector_ref at(size_t index) const { - return vector_ref {data_ + index * N}; + KERNEL_FLOAT_INLINE vector_ref + at(size_t index) const { + return vector_ref {data_ + index * N}; } template @@ -2958,10 +3119,6 @@ struct vector_ptr { return read(index); } - KERNEL_FLOAT_INLINE const vector> operator*() const { - return read(0); - } - KERNEL_FLOAT_INLINE pointer_type get() const { return data_; } @@ -2970,44 +3127,54 @@ struct vector_ptr { pointer_type data_ = nullptr; }; -template -KERNEL_FLOAT_INLINE vector_ptr operator+(vector_ptr p, size_t i) { - return vector_ptr {p.get() + i * N}; +template +KERNEL_FLOAT_INLINE vector_ptr +operator+(vector_ptr p, size_t i) { + return vector_ptr {p.get() + i * N}; } -template -KERNEL_FLOAT_INLINE vector_ptr operator+(size_t i, vector_ptr p) { +template +KERNEL_FLOAT_INLINE vector_ptr +operator+(size_t i, vector_ptr p) { return p + i; } +template< + typename T, + size_t N, + typename U, + size_t A, + typename = enable_if_t<(N * sizeof(U)) % A == 0>> +KERNEL_FLOAT_INLINE vector_ptr& operator+=(vector_ptr& p, size_t i) { + return p = p + i; +} + /** - * Creates a `vector_ptr` from a raw pointer `U*` by asserting a specific alignment `N`. + * Creates a `vector_ptr` from a raw pointer `T*` by asserting a specific alignment `N`. * - * @tparam T The type of the elements as viewed by the user. This type may differ from `U`. * @tparam N The alignment constraint for the vector_ptr. Defaults to KERNEL_FLOAT_MAX_ALIGNMENT. - * @tparam U The type of the elements pointed to by the raw pointer. + * @tparam T The type of the elements pointed to by the raw pointer. */ -template -KERNEL_FLOAT_INLINE vector_ptr assert_aligned(U* ptr) { - return vector_ptr {ptr}; +template +KERNEL_FLOAT_INLINE vector_ptr assert_aligned(T* ptr) { + return vector_ptr {ptr}; } // Doxygen cannot deal with the `assert_aligned` being defined twice, we ignore the second definition. /// @cond IGNORE /** - * Creates a `vector_ptr` from a raw pointer `T*` by asserting a specific alignment `N`. + * Creates a `vector_ptr` from a raw pointer `T*`. The alignment is assumed to be KERNEL_FLOAT_MAX_ALIGNMENT. * - * @tparam N The alignment constraint for the vector_ptr. Defaults to KERNEL_FLOAT_MAX_ALIGNMENT. * @tparam T The type of the elements pointed to by the raw pointer. 
*/ -template -KERNEL_FLOAT_INLINE vector_ptr assert_aligned(T* ptr) { - return vector_ptr {ptr}; +template +KERNEL_FLOAT_INLINE vector_ptr assert_aligned(T* ptr) { + return vector_ptr {ptr}; } /// @endcond -template -using vec_ptr = vector_ptr; +template +using vec_ptr = vector_ptr; #if defined(__cpp_deduction_guides) template @@ -3071,7 +3238,7 @@ KERNEL_FLOAT_INLINE vector where(const C& cond, const L& true_values, cons using F = ops::conditional; vector_storage> result; - detail::map_impl, T, bool, T, T>::call( + detail::default_map_impl, T, bool, T, T>::call( F {}, result.data(), detail::convert_impl, vector_extent_type, bool, E>::call( @@ -3129,12 +3296,12 @@ struct fma { namespace detail { template -struct apply_impl, N, T, T, T, T> { +struct apply_impl, N, T, T, T, T> { KERNEL_FLOAT_INLINE static void call(ops::fma, T* output, const T* a, const T* b, const T* c) { T temp[N]; - apply_impl, N, T, T, T>::call({}, temp, a, b); - apply_impl, N, T, T, T>::call({}, output, temp, c); + apply_impl, N, T, T, T>::call({}, temp, a, b); + apply_impl, N, T, T, T>::call({}, output, temp, c); } }; } // namespace detail @@ -3170,7 +3337,7 @@ KERNEL_FLOAT_INLINE vector fma(const A& a, const B& b, const C& c) { using F = ops::fma; vector_storage> result; - detail::map_impl, T, T, T, T>::call( + detail::default_map_impl, T, T, T, T>::call( F {}, result.data(), detail::convert_impl, vector_extent_type, T, E>::call( @@ -3215,7 +3382,7 @@ struct reduce_recur_impl { template KERNEL_FLOAT_INLINE static T call(F fun, const T* input) { vector_storage temp; - map_impl::call(fun, temp.data(), input, input + K); + default_map_impl::call(fun, temp.data(), input, input + K); if constexpr (N < 2 * K) { #pragma unroll @@ -3374,11 +3541,11 @@ struct dot_impl { if constexpr (N / K > 0) { T accum[K] = {T {}}; - apply_impl, K, T, T, T>::call({}, accum, left, right); + apply_impl, K, T, T, T>::call({}, accum, left, right); #pragma unroll for (size_t i = 1; i < N / K; i++) { - apply_impl, K, T, T, T, T>::call( + apply_impl, K, T, T, T, T>::call( ops::fma {}, accum, left + i * K, @@ -3391,7 +3558,7 @@ struct dot_impl { if constexpr (N % K > 0) { for (size_t i = N - N % K; i < N; i++) { - apply_impl, 1, T, T, T, T>::call( + apply_impl, 1, T, T, T, T>::call( {}, &result, left + i, @@ -3458,21 +3625,21 @@ struct magnitude_impl { }; // The 3-argument overload of hypot is only available on host from C++17 -#if defined(__cpp_lib_hypot) && KERNEL_FLOAT_IS_HOST -template<> -struct magnitude_impl { - static float call(const float* input) { - return ::hypot(input[0], input[1], input[2]); - } -}; - -template<> -struct magnitude_impl { - static double call(const double* input) { - return ::hypot(input[0], input[1], input[2]); - } -}; -#endif +//#if defined(__cpp_lib_hypot) && KERNEL_FLOAT_IS_HOST +//template<> +//struct magnitude_impl { +// static float call(const float* input) { +// return ::hypot(input[0], input[1], input[2]); +// } +//}; +// +//template<> +//struct magnitude_impl { +// static double call(const double* input) { +// return ::hypot(input[0], input[1], input[2]); +// } +//}; +//#endif } // namespace detail @@ -3534,7 +3701,7 @@ struct vector: public S { // Copy anything of type `storage_type` KERNEL_FLOAT_INLINE vector(const value_type& input = {}) : - storage_type(detail::broadcast_impl, E>::call(input)) {} + storage_type(detail::broadcast_impl, E>::call(input)) {} // For all other arguments, we convert it using `convert_storage` according to broadcast rules template, T>, int> = 0> @@ -3578,6 +3745,14 @@ struct 
vector: public S { return *this; } + /** + * Returns an instance of the `extent_type` for this vector. + */ + KERNEL_FLOAT_INLINE + extent_type extent() const { + return {}; + } + /** * Returns a pointer to the underlying storage data. */ @@ -3702,7 +3877,7 @@ struct vector: public S { */ KERNEL_FLOAT_INLINE void set(size_t x, T value) { - at(x) = std::move(value); + at(x) = static_cast(value); } /** @@ -3735,7 +3910,8 @@ struct vector: public S { * Broadcast this vector into a new size `(Ns...)`. */ template - KERNEL_FLOAT_INLINE vector> broadcast(extent new_size = {}) const { + KERNEL_FLOAT_INLINE vector> + broadcast(kernel_float::extent new_size = {}) const { return kernel_float::broadcast(*this, new_size); } @@ -3770,15 +3946,24 @@ struct vector: public S { */ template KERNEL_FLOAT_INLINE void for_each(F fun) const { - return kernel_float::for_each(*this, std::move(fun)); + return kernel_float::for_each(*this, fun); } /** - * Returns the result of `*this + lhs * rhs`. + * Returns the result of `this + lhs * rhs`. * * The operation is performed using a single `kernel_float::fma` call, which may be faster then perform * the addition and multiplication separately. */ + template< + typename L, + typename R, + typename T2 = promote_t, vector_value_type>, + typename E2 = broadcast_extent, vector_extent_type>> + KERNEL_FLOAT_INLINE vector add_mul(const L& lhs, const R& rhs) const { + return ::kernel_float::fma(lhs, rhs, *this); + } + template< typename L, typename R, @@ -3799,7 +3984,7 @@ struct vector: public S { */ template KERNEL_FLOAT_INLINE into_vector_type into_vec(V&& input) { - return into_vector_impl::call(std::forward(input)); + return into_vector_impl::call(static_cast(input)); } template @@ -3821,6 +4006,7 @@ template using vec8 = vec; #define KERNEL_FLOAT_VECTOR_ALIAS(NAME, T) \ template \ + using v##NAME = vec; \ using NAME##1 = vec; \ using NAME##2 = vec; \ using NAME##3 = vec; \ @@ -3874,54 +4060,66 @@ vec(Args&&... 
args) -> vec, sizeof...(Args)>; #if KERNEL_FLOAT_FP16_AVAILABLE +//#define CUDA_NO_HALF (1) +//#define __CUDA_NO_HALF_OPERATORS__ (1) +//#define __CUDA_NO_HALF2_OPERATORS__ (1) +//#define __CUDA_NO_HALF_CONVERSIONS__ (1) + +#if KERNEL_FLOAT_IS_CUDA #include +#elif KERNEL_FLOAT_IS_HIP +#include +#endif namespace kernel_float { +using half_t = ::__half; +using half2_t = ::__half2; + template<> -struct preferred_vector_size<__half> { +struct preferred_vector_size { static constexpr size_t value = 2; }; -KERNEL_FLOAT_DEFINE_PROMOTED_FLOAT(__half) -KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(float, __half) -KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(double, __half) +KERNEL_FLOAT_DEFINE_PROMOTED_FLOAT(half_t) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(float, half_t) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(double, half_t) template<> -struct into_vector_impl<__half2> { - using value_type = __half; +struct into_vector_impl { + using value_type = half_t; using extent_type = extent<2>; KERNEL_FLOAT_INLINE - static vector_storage<__half, 2> call(__half2 input) { + static vector_storage call(half2_t input) { return {input.x, input.y}; } }; namespace detail { template<> -struct allow_float_fallback<__half> { +struct allow_float_fallback { static constexpr bool value = true; }; -}; // namespace detail +} // namespace detail #if KERNEL_FLOAT_IS_DEVICE #define KERNEL_FLOAT_FP16_UNARY_FUN(NAME, FUN1, FUN2) \ namespace ops { \ template<> \ - struct NAME<__half> { \ - KERNEL_FLOAT_INLINE __half operator()(__half input) { \ + struct NAME { \ + KERNEL_FLOAT_INLINE half_t operator()(half_t input) { \ return FUN1(input); \ } \ }; \ } \ namespace detail { \ template<> \ - struct apply_impl, 2, __half, __half> { \ - KERNEL_FLOAT_INLINE static void call(ops::NAME<__half>, __half* result, const __half* a) { \ - __half2 r = FUN2(__half2 {a[0], a[1]}); \ + struct apply_impl, 2, half_t, half_t> { \ + KERNEL_FLOAT_INLINE static void call(ops::NAME, half_t* result, const half_t* a) { \ + half2_t r = FUN2(half2_t {a[0], a[1]}); \ result[0] = r.x, result[1] = r.y; \ } \ }; \ @@ -3930,52 +4128,61 @@ struct allow_float_fallback<__half> { #define KERNEL_FLOAT_FP16_UNARY_FUN(NAME, FUN1, FUN2) #endif -KERNEL_FLOAT_FP16_UNARY_FUN(abs, ::__habs, ::__habs2) -KERNEL_FLOAT_FP16_UNARY_FUN(negate, ::__hneg, ::__hneg2) -KERNEL_FLOAT_FP16_UNARY_FUN(ceil, ::hceil, ::h2ceil) -KERNEL_FLOAT_FP16_UNARY_FUN(cos, ::hcos, ::h2cos) -KERNEL_FLOAT_FP16_UNARY_FUN(exp, ::hexp, ::h2exp) -KERNEL_FLOAT_FP16_UNARY_FUN(exp10, ::hexp10, ::h2exp10) -KERNEL_FLOAT_FP16_UNARY_FUN(floor, ::hfloor, ::h2floor) -KERNEL_FLOAT_FP16_UNARY_FUN(log, ::hlog, ::h2log) -KERNEL_FLOAT_FP16_UNARY_FUN(log10, ::hlog10, ::h2log2) -KERNEL_FLOAT_FP16_UNARY_FUN(rint, ::hrint, ::h2rint) -KERNEL_FLOAT_FP16_UNARY_FUN(rsqrt, ::hrsqrt, ::h2rsqrt) -KERNEL_FLOAT_FP16_UNARY_FUN(sin, ::hsin, ::h2sin) -KERNEL_FLOAT_FP16_UNARY_FUN(sqrt, ::hsqrt, ::h2sqrt) -KERNEL_FLOAT_FP16_UNARY_FUN(trunc, ::htrunc, ::h2trunc) -KERNEL_FLOAT_FP16_UNARY_FUN(rcp, ::hrcp, ::h2rcp) +KERNEL_FLOAT_FP16_UNARY_FUN(sin, hsin, h2sin) +KERNEL_FLOAT_FP16_UNARY_FUN(cos, hcos, h2cos) + +KERNEL_FLOAT_FP16_UNARY_FUN(exp, hexp, h2exp) +KERNEL_FLOAT_FP16_UNARY_FUN(exp2, hexp2, h2exp2) +KERNEL_FLOAT_FP16_UNARY_FUN(exp10, hexp10, h2exp10) +KERNEL_FLOAT_FP16_UNARY_FUN(log, hlog, h2log) +KERNEL_FLOAT_FP16_UNARY_FUN(log2, hlog2, h2log2) +KERNEL_FLOAT_FP16_UNARY_FUN(log10, hlog10, h2log2) + +KERNEL_FLOAT_FP16_UNARY_FUN(sqrt, hsqrt, h2sqrt) +KERNEL_FLOAT_FP16_UNARY_FUN(rsqrt, hrsqrt, h2rsqrt) +KERNEL_FLOAT_FP16_UNARY_FUN(rcp, hrcp, h2rcp) + 
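// Illustrative note (not part of this patch): each specialization above maps a unary
// operation on a pair of half_t values onto the corresponding __half2 intrinsic, so
// element-wise math over half vectors is evaluated two elements at a time on device.
// A minimal sketch, assuming the element-wise kf::sqrt / kf::rcp wrappers defined
// earlier in this header and the alias namespace kf = kernel_float:
//
//   // given kf::vec<kf::half_t, 4> x:
//   auto s = kf::sqrt(x);   // evaluated pairwise through h2sqrt
//   auto r = kf::rcp(x);    // evaluated pairwise through h2rcp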
+KERNEL_FLOAT_FP16_UNARY_FUN(abs, __habs, __habs2) +KERNEL_FLOAT_FP16_UNARY_FUN(floor, hfloor, h2floor) +KERNEL_FLOAT_FP16_UNARY_FUN(ceil, hceil, h2ceil) +KERNEL_FLOAT_FP16_UNARY_FUN(rint, hrint, h2rint) +KERNEL_FLOAT_FP16_UNARY_FUN(trunc, htrunc, h2trunc) +KERNEL_FLOAT_FP16_UNARY_FUN(negate, __hneg, __hneg2) #if KERNEL_FLOAT_IS_DEVICE -#define KERNEL_FLOAT_FP16_BINARY_FUN(NAME, FUN1, FUN2) \ - namespace ops { \ - template<> \ - struct NAME<__half> { \ - KERNEL_FLOAT_INLINE __half operator()(__half left, __half right) const { \ - return FUN1(left, right); \ - } \ - }; \ - } \ - namespace detail { \ - template<> \ - struct apply_impl, 2, __half, __half, __half> { \ - KERNEL_FLOAT_INLINE static void \ - call(ops::NAME<__half>, __half* result, const __half* a, const __half* b) { \ - __half2 r = FUN2(__half2 {a[0], a[1]}, __half2 {b[0], b[1]}); \ - result[0] = r.x, result[1] = r.y; \ - } \ - }; \ +#define KERNEL_FLOAT_FP16_BINARY_FUN(NAME, FUN1, FUN2) \ + namespace ops { \ + template<> \ + struct NAME { \ + KERNEL_FLOAT_INLINE half_t operator()(half_t left, half_t right) const { \ + return ops::cast {}(FUN1(left, right)); \ + } \ + }; \ + } \ + namespace detail { \ + template<> \ + struct apply_impl, 2, half_t, half_t, half_t> { \ + KERNEL_FLOAT_INLINE static void \ + call(ops::NAME, half_t* result, const half_t* a, const half_t* b) { \ + half2_t r = FUN2(half2_t {a[0], a[1]}, half2_t {b[0], b[1]}); \ + result[0] = r.x, result[1] = r.y; \ + } \ + }; \ } #else #define KERNEL_FLOAT_FP16_BINARY_FUN(NAME, FUN1, FUN2) #endif +// There are not available in HIP +#if KERNEL_FLOAT_IS_CUDA +KERNEL_FLOAT_FP16_BINARY_FUN(min, __hmin, __hmin2) +KERNEL_FLOAT_FP16_BINARY_FUN(max, __hmax, __hmax2) +#endif + KERNEL_FLOAT_FP16_BINARY_FUN(add, __hadd, __hadd2) KERNEL_FLOAT_FP16_BINARY_FUN(subtract, __hsub, __hsub2) KERNEL_FLOAT_FP16_BINARY_FUN(multiply, __hmul, __hmul2) KERNEL_FLOAT_FP16_BINARY_FUN(divide, __hdiv, __h2div) -KERNEL_FLOAT_FP16_BINARY_FUN(min, __hmin, __hmin2) -KERNEL_FLOAT_FP16_BINARY_FUN(max, __hmax, __hmax2) KERNEL_FLOAT_FP16_BINARY_FUN(equal_to, __heq, __heq2) KERNEL_FLOAT_FP16_BINARY_FUN(not_equal_to, __hneu, __hneu2) @@ -3987,8 +4194,8 @@ KERNEL_FLOAT_FP16_BINARY_FUN(greater_equal, __hge, __hgt2) #if KERNEL_FLOAT_IS_DEVICE namespace ops { template<> -struct fma<__half> { - KERNEL_FLOAT_INLINE __half operator()(__half a, __half b, __half c) const { +struct fma { + KERNEL_FLOAT_INLINE half_t operator()(half_t a, half_t b, half_t c) const { return __hfma(a, b, c); } }; @@ -3996,32 +4203,70 @@ struct fma<__half> { namespace detail { template<> -struct apply_impl, 2, __half, __half, __half, __half> { +struct apply_impl, 2, half_t, half_t, half_t, half_t> { KERNEL_FLOAT_INLINE static void - call(ops::fma<__half>, __half* result, const __half* a, const __half* b, const __half* c) { - __half2 r = __hfma2(__half2 {a[0], a[1]}, __half2 {b[0], b[1]}, __half2 {c[0], c[1]}); + call(ops::fma, half_t* result, const half_t* a, const half_t* b, const half_t* c) { + half2_t r = __hfma2(half2_t {a[0], a[1]}, half2_t {b[0], b[1]}, half2_t {c[0], c[1]}); result[0] = r.x, result[1] = r.y; } }; + +// clang-format off +#define KERNEL_FLOAT_FAST_FP16_DISPATCH(OP) \ + template \ + struct apply_impl, N, half_t, half_t> { \ + KERNEL_FLOAT_INLINE static void \ + call(ops::OP, half_t* output, const half_t* input) { \ + float v[N]; \ + map_impl, N, float, half_t>::call({}, v, input); \ + map_impl, N, float, float>::call({}, v, v); \ + map_impl, N, half_t, float>::call({}, output, v); \ + } \ + }; +// clang-format on + 
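// Illustrative note (not part of this patch): the dispatch macro above implements
// fast-policy math for half_t by widening to float, applying the fast float
// implementation, and narrowing back to half_t. Assuming the fast_* wrappers and
// the alias namespace kf = kernel_float, a call such as
//
//   // given kf::vec<kf::half_t, 2> x:
//   auto y = kf::fast_exp(x);   // half -> float, __expf per element, float -> half
//
// trades two extra conversions per element for the fast float intrinsic.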
+KERNEL_FLOAT_FAST_F32_MAP(KERNEL_FLOAT_FAST_FP16_DISPATCH) } // namespace detail #endif #define KERNEL_FLOAT_FP16_CAST(T, TO_HALF, FROM_HALF) \ namespace ops { \ template<> \ - struct cast { \ - KERNEL_FLOAT_INLINE __half operator()(T input) { \ + struct cast { \ + KERNEL_FLOAT_INLINE half_t operator()(T input) { \ return TO_HALF; \ } \ }; \ template<> \ - struct cast<__half, T> { \ - KERNEL_FLOAT_INLINE T operator()(__half input) { \ + struct cast { \ + KERNEL_FLOAT_INLINE T operator()(half_t input) { \ return FROM_HALF; \ } \ }; \ } +// Only CUDA has a special `__double2half` intrinsic +#if KERNEL_FLOAT_IS_HIP +#define KERNEL_FLOAT_FP16_CAST_FWD(T) \ + KERNEL_FLOAT_FP16_CAST(T, static_cast<_Float16>(input), static_cast(input)) + +KERNEL_FLOAT_FP16_CAST_FWD(double) +KERNEL_FLOAT_FP16_CAST_FWD(float) + +KERNEL_FLOAT_FP16_CAST_FWD(char) +KERNEL_FLOAT_FP16_CAST_FWD(signed char) +KERNEL_FLOAT_FP16_CAST_FWD(unsigned char) + +KERNEL_FLOAT_FP16_CAST_FWD(signed short) +KERNEL_FLOAT_FP16_CAST_FWD(signed int) +KERNEL_FLOAT_FP16_CAST_FWD(signed long) +KERNEL_FLOAT_FP16_CAST_FWD(signed long long) + +KERNEL_FLOAT_FP16_CAST_FWD(unsigned short) +KERNEL_FLOAT_FP16_CAST_FWD(unsigned int) +KERNEL_FLOAT_FP16_CAST_FWD(unsigned long) +KERNEL_FLOAT_FP16_CAST_FWD(unsigned long long) +#else KERNEL_FLOAT_FP16_CAST(double, __double2half(input), double(__half2float(input))); KERNEL_FLOAT_FP16_CAST(float, __float2half(input), __half2float(input)); @@ -4030,20 +4275,20 @@ KERNEL_FLOAT_FP16_CAST(char, __int2half_rn(input), (char)__half2int_rz(input)); KERNEL_FLOAT_FP16_CAST(signed char, __int2half_rn(input), (signed char)__half2int_rz(input)); KERNEL_FLOAT_FP16_CAST(unsigned char, __int2half_rn(input), (unsigned char)__half2int_rz(input)); -KERNEL_FLOAT_FP16_CAST(signed short, __half2short_rz(input), __short2half_rn(input)); -KERNEL_FLOAT_FP16_CAST(signed int, __half2int_rz(input), __int2half_rn(input)); +KERNEL_FLOAT_FP16_CAST(signed short, __short2half_rn(input), __half2short_rz(input)); +KERNEL_FLOAT_FP16_CAST(signed int, __int2half_rn(input), __half2int_rz(input)); KERNEL_FLOAT_FP16_CAST(signed long, __ll2half_rn(input), (signed long)(__half2ll_rz(input))); KERNEL_FLOAT_FP16_CAST(signed long long, __ll2half_rn(input), __half2ll_rz(input)); -KERNEL_FLOAT_FP16_CAST(unsigned short, __half2ushort_rz(input), __ushort2half_rn(input)); -KERNEL_FLOAT_FP16_CAST(unsigned int, __half2uint_rz(input), __uint2half_rn(input)); +KERNEL_FLOAT_FP16_CAST(unsigned short, __ushort2half_rn(input), __half2ushort_rz(input)); +KERNEL_FLOAT_FP16_CAST(unsigned int, __uint2half_rn(input), __half2uint_rz(input)); KERNEL_FLOAT_FP16_CAST(unsigned long, __ull2half_rn(input), (unsigned long)(__half2ull_rz(input))); KERNEL_FLOAT_FP16_CAST(unsigned long long, __ull2half_rn(input), __half2ull_rz(input)); +#endif -using half = __half; -KERNEL_FLOAT_VECTOR_ALIAS(half, __half) -//KERNEL_FLOAT_TYPE_ALIAS(float16x, __half) -//KERNEL_FLOAT_TYPE_ALIAS(f16x, __half) +KERNEL_FLOAT_VECTOR_ALIAS(half, half_t) +//KERNEL_FLOAT_TYPE_ALIAS(float16x, half_t) +//KERNEL_FLOAT_TYPE_ALIAS(f16x, half_t) } // namespace kernel_float @@ -4056,7 +4301,16 @@ KERNEL_FLOAT_VECTOR_ALIAS(half, __half) #if KERNEL_FLOAT_BF16_AVAILABLE +//#define CUDA_NO_BFLOAT16 (1) +//#define __CUDA_NO_BFLOAT16_OPERATORS__ (1) +//#define __CUDA_NO_BFLOAT162_OPERATORS__ (1) +//#define __CUDA_NO_BFLOAT16_CONVERSIONS__ (1) + +#if KERNEL_FLOAT_IS_CUDA #include +#elif KERNEL_FLOAT_IS_HIP +#include +#endif @@ -4064,103 +4318,144 @@ KERNEL_FLOAT_VECTOR_ALIAS(half, __half) namespace kernel_float { +#if 
KERNEL_FLOAT_IS_CUDA +using bfloat16_t = __nv_bfloat16; +using bfloat16x2_t = __nv_bfloat162; +#elif KERNEL_FLOAT_IS_HIP +using bfloat16_t = __hip_bfloat16; +using bfloat16x2_t = __hip_bfloat162; +#endif + +#if KERNEL_FLOAT_IS_CUDA && __CUDA_ARCH__ >= 800 +#define KERNEL_FLOAT_BF16_OPS_SUPPORTED 1 +#endif + template<> -struct preferred_vector_size<__nv_bfloat16> { +struct preferred_vector_size { static constexpr size_t value = 2; }; -KERNEL_FLOAT_DEFINE_PROMOTED_FLOAT(__nv_bfloat16) -KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(float, __nv_bfloat16) -KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(double, __nv_bfloat16) +KERNEL_FLOAT_DEFINE_PROMOTED_FLOAT(bfloat16_t) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(float, bfloat16_t) +KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(double, bfloat16_t) template<> -struct into_vector_impl<__nv_bfloat162> { - using value_type = __nv_bfloat16; +struct into_vector_impl { + using value_type = bfloat16_t; using extent_type = extent<2>; KERNEL_FLOAT_INLINE - static vector_storage<__nv_bfloat16, 2> call(__nv_bfloat162 input) { + static vector_storage call(bfloat16x2_t input) { return {input.x, input.y}; } }; namespace detail { template<> -struct allow_float_fallback<__nv_bfloat16> { +struct allow_float_fallback { static constexpr bool value = true; }; }; // namespace detail -#if KERNEL_FLOAT_IS_DEVICE -#define KERNEL_FLOAT_BF16_UNARY_FUN(NAME, FUN1, FUN2) \ - namespace ops { \ - template<> \ - struct NAME<__nv_bfloat16> { \ - KERNEL_FLOAT_INLINE __nv_bfloat16 operator()(__nv_bfloat16 input) { \ - return FUN1(input); \ - } \ - }; \ - } \ - namespace detail { \ - template<> \ - struct apply_impl, 2, __nv_bfloat16, __nv_bfloat16> { \ - KERNEL_FLOAT_INLINE static void \ - call(ops::NAME<__nv_bfloat16>, __nv_bfloat16* result, const __nv_bfloat16* a) { \ - __nv_bfloat162 r = FUN2(__nv_bfloat162 {a[0], a[1]}); \ - result[0] = r.x, result[1] = r.y; \ - } \ - }; \ - } -#else -#define KERNEL_FLOAT_BF16_UNARY_FUN(NAME, FUN1, FUN2) -#endif - -#if KERNEL_FLOAT_CUDA_ARCH >= 800 -KERNEL_FLOAT_BF16_UNARY_FUN(abs, ::__habs, ::__habs2) -KERNEL_FLOAT_BF16_UNARY_FUN(negate, ::__hneg, ::__hneg2) -KERNEL_FLOAT_BF16_UNARY_FUN(ceil, ::hceil, ::h2ceil) +#define KERNEL_FLOAT_BF16_UNARY_FUN(NAME, FUN1, FUN2) \ + namespace ops { \ + template<> \ + struct NAME { \ + KERNEL_FLOAT_INLINE bfloat16_t operator()(bfloat16_t input) { \ + return FUN1(input); \ + } \ + }; \ + } \ + namespace detail { \ + template<> \ + struct apply_impl, 2, bfloat16_t, bfloat16_t> { \ + KERNEL_FLOAT_INLINE static void \ + call(ops::NAME, bfloat16_t* result, const bfloat16_t* a) { \ + bfloat16x2_t r = FUN2(bfloat16x2_t {a[0], a[1]}); \ + result[0] = r.x, result[1] = r.y; \ + } \ + }; \ + } + +#if KERNEL_FLOAT_BF16_OPS_SUPPORTED +KERNEL_FLOAT_BF16_UNARY_FUN(sin, ::hsin, ::h2sin) KERNEL_FLOAT_BF16_UNARY_FUN(cos, ::hcos, ::h2cos) + KERNEL_FLOAT_BF16_UNARY_FUN(exp, ::hexp, ::h2exp) +KERNEL_FLOAT_BF16_UNARY_FUN(exp2, ::hexp2, ::h2exp2) KERNEL_FLOAT_BF16_UNARY_FUN(exp10, ::hexp10, ::h2exp10) -KERNEL_FLOAT_BF16_UNARY_FUN(floor, ::hfloor, ::h2floor) KERNEL_FLOAT_BF16_UNARY_FUN(log, ::hlog, ::h2log) +KERNEL_FLOAT_BF16_UNARY_FUN(log2, ::hlog2, ::h2log2) KERNEL_FLOAT_BF16_UNARY_FUN(log10, ::hlog10, ::h2log2) -KERNEL_FLOAT_BF16_UNARY_FUN(rint, ::hrint, ::h2rint) -KERNEL_FLOAT_BF16_UNARY_FUN(rsqrt, ::hrsqrt, ::h2rsqrt) -KERNEL_FLOAT_BF16_UNARY_FUN(sin, ::hsin, ::h2sin) + KERNEL_FLOAT_BF16_UNARY_FUN(sqrt, ::hsqrt, ::h2sqrt) -KERNEL_FLOAT_BF16_UNARY_FUN(trunc, ::htrunc, ::h2trunc) +KERNEL_FLOAT_BF16_UNARY_FUN(rsqrt, ::hrsqrt, ::h2rsqrt) KERNEL_FLOAT_BF16_UNARY_FUN(rcp, 
::hrcp, ::h2rcp) -#endif -#if KERNEL_FLOAT_CUDA_ARCH >= 800 -#define KERNEL_FLOAT_BF16_BINARY_FUN(NAME, FUN1, FUN2) \ - namespace ops { \ - template<> \ - struct NAME<__nv_bfloat16> { \ - KERNEL_FLOAT_INLINE __nv_bfloat16 \ - operator()(__nv_bfloat16 left, __nv_bfloat16 right) const { \ - return FUN1(left, right); \ - } \ - }; \ - } \ - namespace detail { \ - template<> \ - struct apply_impl, 2, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16> { \ - KERNEL_FLOAT_INLINE static void call( \ - ops::NAME<__nv_bfloat16>, \ - __nv_bfloat16* result, \ - const __nv_bfloat16* a, \ - const __nv_bfloat16* b) { \ - __nv_bfloat162 r = FUN2(__nv_bfloat162 {a[0], a[1]}, __nv_bfloat162 {b[0], b[1]}); \ - result[0] = r.x, result[1] = r.y; \ - } \ - }; \ - } -#else -#define KERNEL_FLOAT_BF16_BINARY_FUN(NAME, FUN1, FUN2) +KERNEL_FLOAT_BF16_UNARY_FUN(abs, ::__habs, ::__habs2) +KERNEL_FLOAT_BF16_UNARY_FUN(floor, ::hfloor, ::h2floor) +KERNEL_FLOAT_BF16_UNARY_FUN(ceil, ::hceil, ::h2ceil) +KERNEL_FLOAT_BF16_UNARY_FUN(rint, ::hrint, ::h2rint) +KERNEL_FLOAT_BF16_UNARY_FUN(trunc, ::htrunc, ::h2trunc) +KERNEL_FLOAT_BF16_UNARY_FUN(negate, ::__hneg, ::__hneg2) + +// For some reason, HIP struggles with the functions `__habs` and `__hneg`. We define them here using bitwise ops. +// For CUDA, we can just use the regular bfloat16 functions (see above). +#elif KERNEL_FLOAT_IS_HIP +KERNEL_FLOAT_INLINE __hip_bfloat16 hip_habs(const __hip_bfloat16 a) { + __hip_bfloat16 res = a; + res.data &= 0x7FFF; + return res; +} + +KERNEL_FLOAT_INLINE __hip_bfloat16 hip_hneg(const __hip_bfloat16 a) { + __hip_bfloat16 res = a; + res.data ^= 0x8000; + return res; +} + +KERNEL_FLOAT_INLINE __hip_bfloat162 hip_habs2(const __hip_bfloat162 a) { + return {hip_habs(a.x), hip_habs(a.y)}; +} + +KERNEL_FLOAT_INLINE __hip_bfloat162 hip_hneg2(const __hip_bfloat162 a) { + return {hip_hneg(a.x), hip_hneg(a.y)}; +} + +KERNEL_FLOAT_BF16_UNARY_FUN(abs, hip_habs, hip_habs2) +KERNEL_FLOAT_BF16_UNARY_FUN(negate, hip_hneg, hip_hneg2) #endif +#define KERNEL_FLOAT_BF16_BINARY_FUN(NAME, FUN1, FUN2) \ + namespace ops { \ + template<> \ + struct NAME { \ + KERNEL_FLOAT_INLINE bfloat16_t operator()(bfloat16_t left, bfloat16_t right) const { \ + return ops::cast {}(FUN1(left, right)); \ + } \ + }; \ + } \ + namespace detail { \ + template<> \ + struct apply_impl< \ + accurate_policy, \ + ops::NAME, \ + 2, \ + bfloat16_t, \ + bfloat16_t, \ + bfloat16_t> { \ + KERNEL_FLOAT_INLINE static void call( \ + ops::NAME, \ + bfloat16_t* result, \ + const bfloat16_t* a, \ + const bfloat16_t* b) { \ + bfloat16x2_t r = FUN2(bfloat16x2_t {a[0], a[1]}, bfloat16x2_t {b[0], b[1]}); \ + result[0] = r.x, result[1] = r.y; \ + } \ + }; \ + } + +#if KERNEL_FLOAT_BF16_OPS_SUPPORTED KERNEL_FLOAT_BF16_BINARY_FUN(add, __hadd, __hadd2) KERNEL_FLOAT_BF16_BINARY_FUN(subtract, __hsub, __hsub2) KERNEL_FLOAT_BF16_BINARY_FUN(multiply, __hmul, __hmul2) @@ -4174,13 +4469,13 @@ KERNEL_FLOAT_BF16_BINARY_FUN(less, __hlt, __hlt2) KERNEL_FLOAT_BF16_BINARY_FUN(less_equal, __hle, __hle2) KERNEL_FLOAT_BF16_BINARY_FUN(greater, __hgt, __hgt2) KERNEL_FLOAT_BF16_BINARY_FUN(greater_equal, __hge, __hgt2) +#endif -#if KERNEL_FLOAT_CUDA_ARCH >= 800 +#if KERNEL_FLOAT_BF16_OPS_SUPPORTED namespace ops { template<> -struct fma<__nv_bfloat16> { - KERNEL_FLOAT_INLINE __nv_bfloat16 - operator()(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) const { +struct fma { + KERNEL_FLOAT_INLINE bfloat16_t operator()(bfloat16_t a, bfloat16_t b, bfloat16_t c) const { return __hfma(a, b, c); } }; @@ -4189,95 +4484,98 @@ struct 
fma<__nv_bfloat16> { namespace detail { template<> struct apply_impl< - ops::fma<__nv_bfloat16>, + accurate_policy, + ops::fma, 2, - __nv_bfloat16, - __nv_bfloat16, - __nv_bfloat16, - __nv_bfloat16> { + bfloat16_t, + bfloat16_t, + bfloat16_t, + bfloat16_t> { KERNEL_FLOAT_INLINE static void call( - ops::fma<__nv_bfloat16>, - __nv_bfloat16* result, - const __nv_bfloat16* a, - const __nv_bfloat16* b, - const __nv_bfloat16* c) { - __nv_bfloat162 r = __hfma2( - __nv_bfloat162 {a[0], a[1]}, - __nv_bfloat162 {b[0], b[1]}, - __nv_bfloat162 {c[0], c[1]}); + ops::fma, + bfloat16_t* result, + const bfloat16_t* a, + const bfloat16_t* b, + const bfloat16_t* c) { + bfloat16x2_t r = __hfma2( + bfloat16x2_t {a[0], a[1]}, + bfloat16x2_t {b[0], b[1]}, + bfloat16x2_t {c[0], c[1]}); result[0] = r.x, result[1] = r.y; } }; -} // namespace detail -#endif - -namespace ops { -template<> -struct cast { - KERNEL_FLOAT_INLINE __nv_bfloat16 operator()(double input) { - return __double2bfloat16(input); - }; -}; -template<> -struct cast { - KERNEL_FLOAT_INLINE __nv_bfloat16 operator()(float input) { - return __float2bfloat16(input); +// clang-format off +#define KERNEL_FLOAT_FAST_BF16_DISPATCH(OP) \ + template \ + struct apply_impl, N, bfloat16_t, bfloat16_t> { \ + KERNEL_FLOAT_INLINE static void \ + call(ops::OP, bfloat16_t* output, const bfloat16_t* input) { \ + float v[N]; \ + map_impl, N, float, bfloat16_t>::call({}, v, input); \ + map_impl, N, float, float>::call({}, v, v); \ + map_impl, N, bfloat16_t, float>::call({}, output, v); \ + } \ }; -}; +// clang-format on -template<> -struct cast<__nv_bfloat16, float> { - KERNEL_FLOAT_INLINE float operator()(__nv_bfloat16 input) { - return __bfloat162float(input); - }; -}; -} // namespace ops +KERNEL_FLOAT_FAST_F32_MAP(KERNEL_FLOAT_FAST_BF16_DISPATCH) +} // namespace detail +#endif -#define KERNEL_FLOAT_BF16_CAST(T, TO_HALF, FROM_HALF) \ - namespace ops { \ - template<> \ - struct cast { \ - KERNEL_FLOAT_INLINE __nv_bfloat16 operator()(T input) { \ - return TO_HALF; \ - } \ - }; \ - template<> \ - struct cast<__nv_bfloat16, T> { \ - KERNEL_FLOAT_INLINE T operator()(__nv_bfloat16 input) { \ - return FROM_HALF; \ - } \ - }; \ - } - -#if KERNEL_FLOAT_CUDA_ARCH >= 800 +#define KERNEL_FLOAT_BF16_CAST(T, TO_HALF, FROM_HALF) \ + namespace ops { \ + template<> \ + struct cast { \ + KERNEL_FLOAT_INLINE bfloat16_t operator()(T input) { \ + return TO_HALF; \ + } \ + }; \ + template<> \ + struct cast { \ + KERNEL_FLOAT_INLINE T operator()(bfloat16_t input) { \ + return FROM_HALF; \ + } \ + }; \ + } + +KERNEL_FLOAT_BF16_CAST(float, __float2bfloat16(input), __bfloat162float(input)) +KERNEL_FLOAT_BF16_CAST(double, __double2bfloat16(input), __bfloat162float(input)) + +#if KERNEL_FLOAT_BF16_OPS_SUPPORTED // clang-format off // there are no official char casts. 
Instead, cast to int and then to char KERNEL_FLOAT_BF16_CAST(char, __int2bfloat16_rn(input), (char)__bfloat162int_rz(input)); KERNEL_FLOAT_BF16_CAST(signed char, __int2bfloat16_rn(input), (signed char)__bfloat162int_rz(input)); KERNEL_FLOAT_BF16_CAST(unsigned char, __int2bfloat16_rn(input), (unsigned char)__bfloat162int_rz(input)); -KERNEL_FLOAT_BF16_CAST(signed short, __bfloat162short_rz(input), __short2bfloat16_rn(input)); -KERNEL_FLOAT_BF16_CAST(signed int, __bfloat162int_rz(input), __int2bfloat16_rn(input)); +KERNEL_FLOAT_BF16_CAST(signed short, __short2bfloat16_rn(input), __bfloat162short_rz(input)); +KERNEL_FLOAT_BF16_CAST(signed int, __int2bfloat16_rn(input), __bfloat162int_rz(input)); KERNEL_FLOAT_BF16_CAST(signed long, __ll2bfloat16_rn(input), (signed long)(__bfloat162ll_rz(input))); KERNEL_FLOAT_BF16_CAST(signed long long, __ll2bfloat16_rn(input), __bfloat162ll_rz(input)); -KERNEL_FLOAT_BF16_CAST(unsigned short, __bfloat162ushort_rz(input), __ushort2bfloat16_rn(input)); -KERNEL_FLOAT_BF16_CAST(unsigned int, __bfloat162uint_rz(input), __uint2bfloat16_rn(input)); +KERNEL_FLOAT_BF16_CAST(unsigned short, __ushort2bfloat16_rn(input), __bfloat162ushort_rz(input)); +KERNEL_FLOAT_BF16_CAST(unsigned int, __uint2bfloat16_rn(input), __bfloat162uint_rz(input)); KERNEL_FLOAT_BF16_CAST(unsigned long, __ull2bfloat16_rn(input), (unsigned long)(__bfloat162ull_rz(input))); KERNEL_FLOAT_BF16_CAST(unsigned long long, __ull2bfloat16_rn(input), __bfloat162ull_rz(input)); // clang-format on -#else +#endif + +#if KERNEL_FLOAT_IS_CUDA +//KERNEL_FLOAT_BF16_CAST( +// bool, +// __nv_bfloat16_raw {input ? (unsigned short)0 : (unsigned short)0x3C00}, +// (__nv_bfloat16_raw(input).x & 0x7FFF) != 0); +#elif KERNEL_FLOAT_IS_HIP KERNEL_FLOAT_BF16_CAST( bool, - __nv_bfloat16_raw {input ? (unsigned short)0 : (unsigned short)0x3C00}, - (__nv_bfloat16_raw(input).x & 0x7FFF) != 0); + __hip_bfloat16 {input ? 
(unsigned short)0 : (unsigned short)0x3C00}, + (__hip_bfloat16(input).data & 0x7FFF) != 0); #endif -using bfloat16 = __nv_bfloat16; -KERNEL_FLOAT_VECTOR_ALIAS(bfloat16x, __nv_bfloat16) -//KERNEL_FLOAT_TYPE_ALIAS(float16x, __nv_bfloat16) -//KERNEL_FLOAT_TYPE_ALIAS(f16x, __nv_bfloat16) +KERNEL_FLOAT_VECTOR_ALIAS(bfloat16x, bfloat16_t) +//KERNEL_FLOAT_TYPE_ALIAS(float16x, bfloat16_t) +//KERNEL_FLOAT_TYPE_ALIAS(f16x, bfloat16_t) } // namespace kernel_float @@ -4286,12 +4584,12 @@ KERNEL_FLOAT_VECTOR_ALIAS(bfloat16x, __nv_bfloat16) namespace kernel_float { template<> -struct promote_type<__nv_bfloat16, __half> { +struct promote_type { using type = float; }; template<> -struct promote_type<__half, __nv_bfloat16> { +struct promote_type { using type = float; }; @@ -4301,6 +4599,438 @@ struct promote_type<__half, __nv_bfloat16> { #endif #endif //KERNEL_FLOAT_BF16_H +#pragma once + + + + + + +namespace kernel_float { + +namespace approx { + +static_assert(sizeof(unsigned int) * 8 == 32, "invalid size of unsigned int"); +static_assert(sizeof(unsigned short) * 8 == 16, "invalid size of unsigned short"); +using uint32_t = unsigned int; +using uint16_t = unsigned short; + +template +KERNEL_FLOAT_DEVICE T transmute(const U& input) { + static_assert(sizeof(T) == sizeof(U), "types must have equal size"); + T result {}; + ::memcpy(&result, &input, sizeof(T)); + return result; +} + +KERNEL_FLOAT_DEVICE uint32_t +bitwise_if_else(uint32_t condition, uint32_t if_true, uint32_t if_false) { + uint32_t result; + +#if KERNEL_FLOAT_IS_CUDA + // equivalent to (condition & if_true) | ((~condition) & if_false) + asm("lop3.b32 %0, %1, %2, %3, 0xCA;" + : "=r"(result) + : "r"(condition), "r"(if_true), "r"(if_false)); +#else + result = (condition & if_true) | ((~condition) & if_false); +#endif + return result; +} + +template +KERNEL_FLOAT_DEVICE T2 eval_poly_recur(T2 y, T2 x) { + return y; +} + +template +KERNEL_FLOAT_DEVICE T2 eval_poly_recur(T2 y, T2 x, T coef, TRest... coefs) { + y = __hfma2(x, y, T2 {coef, coef}); + return eval_poly_recur(y, x, coefs...); +} + +template +KERNEL_FLOAT_DEVICE T2 eval_poly(T2 x, T coef, TRest... coefs) { + return eval_poly_recur(T2 {coef, coef}, x, coefs...); +} + +#define KERNEL_FLOAT_DEFINE_POLY(NAME, N, ...) 
\ + template \ + struct NAME { \ + template \ + static KERNEL_FLOAT_DEVICE T2 call(T2 x) { \ + return eval_poly(x, __VA_ARGS__); \ + } \ + }; + +template +struct sin_poly: sin_poly {}; +KERNEL_FLOAT_DEFINE_POLY(sin_poly, 1, 1.365) +KERNEL_FLOAT_DEFINE_POLY(sin_poly, 2, -21.56, 5.18) +KERNEL_FLOAT_DEFINE_POLY(sin_poly, 3, 53.53, -38.06, 6.184) +KERNEL_FLOAT_DEFINE_POLY(sin_poly, 4, -56.1, 77.94, -41.1, 6.277) +KERNEL_FLOAT_DEFINE_POLY(sin_poly, 5, 32.78, -74.5, 81.4, -41.34, 6.28) + +template +struct cos_poly: cos_poly {}; +KERNEL_FLOAT_DEFINE_POLY(cos_poly, 1, 0.0) +KERNEL_FLOAT_DEFINE_POLY(cos_poly, 2, -8.0, 0.6943) +KERNEL_FLOAT_DEFINE_POLY(cos_poly, 3, 38.94, -17.5, 0.9707) +KERNEL_FLOAT_DEFINE_POLY(cos_poly, 4, -59.66, 61.12, -19.56, 0.9985) +KERNEL_FLOAT_DEFINE_POLY(cos_poly, 5, 45.66, -82.4, 64.7, -19.73, 1.0) + +template +struct asin_poly: asin_poly {}; +KERNEL_FLOAT_DEFINE_POLY(asin_poly, 1, 1.531) +KERNEL_FLOAT_DEFINE_POLY(asin_poly, 2, -0.169, 1.567) +KERNEL_FLOAT_DEFINE_POLY(asin_poly, 3, 0.05167, -0.2057, 1.57) +KERNEL_FLOAT_DEFINE_POLY(asin_poly, 4, -0.02103, 0.077, -0.2129, 1.57) +KERNEL_FLOAT_DEFINE_POLY(asin_poly, 5, 0.009796, -0.03772, 0.0857, -0.2142, 1.57) + +#if KERNEL_FLOAT_FP16_AVAILABLE +KERNEL_FLOAT_DEVICE half2_t flipsign(half2_t input, half2_t sign) { + // Flip signbit of input when sign<0 + uint32_t result; + +#if KERNEL_FLOAT_IS_CUDA + asm("lop3.b32 %0, %1, %2, %3, 0x6A;" + : "=r"(result) + : "r"(0x80008000), "r"(transmute(sign)), "r"(transmute(input))); +#else + result = uint32_t(transmute(sign) & 0x80008000) ^ transmute(input); +#endif + + return transmute(result); +} + +KERNEL_FLOAT_DEVICE uint32_t half2_gt_mask(half2_t a, half2_t b) { + uint32_t val; +#if KERNEL_FLOAT_IS_CUDA + uint32_t ai = *(reinterpret_cast(&a)); + uint32_t bi = *(reinterpret_cast(&b)); + asm("{ set.gt.u32.f16x2 %0,%1,%2;\n}" : "=r"(val) : "r"(ai), "r"(bi)); +#else + val = transmute(make_short2(a.x > b.x ? ~0 : 0, a.y > b.y ? ~0 : 0)); +#endif + return val; +} + +KERNEL_FLOAT_INLINE half2_t make_half2(half x) { + return {x, x}; +} + +KERNEL_FLOAT_DEVICE half2_t normalize_trig_input(half2_t x) { + /* Using rint is too slow. Round using floating-point magic instead. 
*/ + // half2_t x = arg * make_half2(-0.15915494309); + // return __hfma2(arg, make_half2(0.15915494309), h2rint(x)); + + // 1/(2pi) = 0.15915494309189535 + static constexpr double ONE_OVER_TWOPI = 0.15915494309189535; + static constexpr double OFFSET = -2042.0; + + half2_t ws = __hfma2(x, make_half2(-ONE_OVER_TWOPI), make_half2(-OFFSET)) + make_half2(OFFSET); + return __hfma2(x, make_half2(ONE_OVER_TWOPI), ws); +} + +template +KERNEL_FLOAT_DEVICE half2_t cos(half2_t x) { + half2_t xf = normalize_trig_input(x); + return cos_poly::call(__hmul2(xf, xf)); +} + +template +KERNEL_FLOAT_DEVICE half2_t sin(half2_t x) { + half2_t xf = normalize_trig_input(x); + return sin_poly::call(__hmul2(xf, xf)) * xf; +} + +template +KERNEL_FLOAT_DEVICE half2_t rcp(half2_t x) { + // Flip bits + uint32_t m = ~transmute(x); + + // Multiply by bias (add contant) + half2_t y = transmute(uint32_t(0x776d776d) + m); + +#pragma unroll + for (int i = 0; i < Iter; i++) { + // y += y * (1 - x * y) + y = __hfma2(y, __hfma2(-x, y, make_half2(1.0)), y); + } + + return y; +} + +template +KERNEL_FLOAT_DEVICE half2_t rsqrt(half2_t x) { + // A small number added such that rsqrt(0) does not return NaN + static constexpr double EPS = 0.00000768899917602539; + + // Set top and bottom bits for both halfs, then shift by 1, then invert + uint32_t r = ~((uint32_t(transmute(x) >> 1)) | ~uint32_t(0x3fff3fff)); + + // Add bias + static constexpr uint32_t BIAS = 0x199c199c; + half2_t y = transmute(uint32_t(r) + BIAS); + + // Newton-Raphson iterations +#pragma unroll + for (int i = 0; i < Iter; i++) { + half2_t half_x = __hfma2(make_half2(-0.5), x, make_half2(-EPS)); + half2_t correction = __hfma2(half_x, y * y, make_half2(0.5)); + y = __hfma2(correction, y, y); // y += y * correction + } + + return y; +} + +template +KERNEL_FLOAT_DEVICE half2_t sqrt(half2_t x) { + if (Iter == 1) { + half2_t y = rsqrt<0>(x); + + // This method uses only 4 muls, instead of 5 muls when using `arg * approx_rsqrt<1>(arg)` + half2_t xy = x * y; + return xy * __hfma2(make_half2(-0.5) * y, xy, make_half2(1.5)); + } + + return x * rsqrt(x); +} + +template +KERNEL_FLOAT_DEVICE half2_t asin(half2_t x) { + static constexpr double HALF_PI = 1.57079632679; + auto abs_x = __habs2(x); + auto v = asin_poly::call(abs_x); + auto abs_y = __hfma2(-v, sqrt(make_half2(1) - abs_x), make_half2(HALF_PI)); + return flipsign(abs_y, x); +} + +template +KERNEL_FLOAT_DEVICE half2_t acos(half2_t x) { + static constexpr double HALF_PI = 1.57079632679; + return make_half2(HALF_PI) - asin(x); +} + +template +KERNEL_FLOAT_DEVICE half2_t exp(half2_t x) { + half2_t y; + + if (Deg == 0) { + // Bring the value to range [32, 64] + // 1.442 = 1/log(2) + // 46.969 = 32.5/log(2) + half2_t m = __hfma2(x, make_half2(1.442), make_half2(46.9375)); + + // Transmute to int, shift higher mantissa bits into exponent field. + y = transmute((transmute(m) & 0x03ff03ff) << 5); + } else { + // Add a large number to round to an integer + half2_t v = __hfma2(x, make_half2(1.442), make_half2(1231.0)); + + // The exponent is now in the lower 5 bits. Shift that into the exponent field. + half2_t exp = transmute((transmute(v) & 0x001f001f) << 10); + + // The fractional part can be obtained from "1231-v". 
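// (Explanatory sketch, inferred from the code around this point: with k = round(x / ln 2),
//  exp(x) = 2^k * exp(x - k*ln 2). The constant 1231.0 is rounding magic that yields k,
//  the variable "exp" above holds the 2^k factor, and "frac" below recovers x - k*ln 2.)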
+ // 0.6934 = log(2) + half2_t frac = __hfma2(make_half2(1231.0) - v, make_half2(0.6934), x); + + // This is the Taylor expansion of "exp(x)-1" around 0 + half2_t adjust; + if (Deg == 1) { + adjust = frac; + } else if (Deg == 2) { + // adjust = frac + 0.5 * frac^2 + adjust = __hfma2(frac, __hmul2(frac, make_half2(0.5)), frac); + } else /* if (Deg == 2) */ { + // adjust = frac + 0.5 * frac^2 + 1/6 * frac^3 + adjust = __hfma2( + frac, + __hmul2(__hfma2(frac, make_half2(0.1666), make_half2(0.5)), frac), + frac); + } + + // result = exp * (adjust + 1) + y = __hfma2(exp, adjust, exp); + } + + // Values below -10.39 (= -15*log(2)) become zero + uint32_t zero_mask = half2_gt_mask(x, make_half2(-10.390625)); + return transmute(zero_mask & transmute(y)); +} + +template +KERNEL_FLOAT_DEVICE half2_t log(half2_t arg) { + // Shift exponent field into mantissa bits. Fill exponent bits with 0x5000 (= 32.0) + uint32_t bits = bitwise_if_else(0x03ff03ff, transmute(arg) >> 5, 0x50005000); + + // 0.6934 = log(2) + // 32.53 = 46.969*log(2) + return __hfma2(transmute(bits), make_half2(0.6934), make_half2(-32.53125)); +} + +template +KERNEL_FLOAT_DEVICE half2_t tanh(half2_t x) { + if (Deg == 0) { + return x * rcp<0>(make_half2(0.2869) + __habs2(x)); + } else { + auto c0 = make_half2(0.4531); + auto c1 = make_half2(0.5156); + auto x2b = __hfma2(x, x, c1); + return (x * x2b) * rcp(__hfma2(x2b, __habs2(x), c0)); + } +} + +#endif // KERNEL_FLOAT_FP16_AVAILABLE + +#if KERNEL_FLOAT_BF16_OPS_SUPPORTED +KERNEL_FLOAT_DEVICE bfloat16x2_t make_bfloat162(bfloat16_t x) { + return {x, x}; +} + +KERNEL_FLOAT_DEVICE bfloat16x2_t make_bfloat162(double x) { + return {__double2bfloat16(x), __double2bfloat16(x)}; +} + +KERNEL_FLOAT_DEVICE bfloat16x2_t normalize_trig_input(bfloat16x2_t x) { + static constexpr double ONE_OVER_TWOPI = 0.15915494309189535; + static constexpr double OFFSET = -2042.0; + + bfloat16x2_t ws = __hadd2( + __hfma2(x, make_bfloat162(-ONE_OVER_TWOPI), make_bfloat162(-OFFSET)), + make_bfloat162(OFFSET)); + return __hfma2(x, make_bfloat162(ONE_OVER_TWOPI), ws); +} + +template +KERNEL_FLOAT_DEVICE bfloat16x2_t cos(bfloat16x2_t x) { + bfloat16x2_t xf = normalize_trig_input(x); + return cos_poly::call(__hmul2(xf, xf)); +} + +template +KERNEL_FLOAT_DEVICE bfloat16x2_t sin(bfloat16x2_t x) { + bfloat16x2_t xf = normalize_trig_input(x); + return __hmul2(sin_poly::call(__hmul2(xf, xf)), xf); +} + +template +KERNEL_FLOAT_DEVICE bfloat16x2_t rcp(bfloat16x2_t x) { + bfloat16x2_t y = transmute(uint32_t(0x7ef07ef0) + ~transmute(x)); + +#pragma unroll + for (int i = 0; i < Iter; i++) { + y = __hfma2(y, __hfma2(__hneg2(x), y, make_bfloat162(1.0)), y); + } + + return y; +} + +template +KERNEL_FLOAT_DEVICE bfloat16x2_t rsqrt(bfloat16x2_t x) { + // Set top and bottom bits for both halfs, then shift by 1, then invert + uint32_t r = ~((uint32_t(transmute(x) >> 1)) | ~uint32_t(0x3fff3fff)); + + // Add bias (0x1f36) + bfloat16x2_t y = transmute(uint32_t(r) + uint32_t(0x1f361f36)); + + // Newton-Raphson iterations +#pragma unroll + for (int i = 0; i < Iter; i++) { + bfloat16x2_t half_x = __hmul2(make_bfloat162(-0.5), x); + bfloat16x2_t correction = __hfma2(half_x, __hmul2(y, y), make_bfloat162(0.5)); + y = __hfma2(correction, y, y); // y += y * correction + } + + return y; +} + +template +KERNEL_FLOAT_DEVICE bfloat16x2_t sqrt(bfloat16x2_t x) { + return __hmul2(x, rsqrt(x)); +} + +template +KERNEL_FLOAT_DEVICE bfloat16x2_t exp(bfloat16x2_t arg) { + static constexpr float SCALE = 1.44272065994 / 256.0; + static constexpr float OFFSET = 
+
+template<int = 0>
+KERNEL_FLOAT_DEVICE bfloat16x2_t exp(bfloat16x2_t arg) {
+    static constexpr float SCALE = 1.44272065994 / 256.0;
+    static constexpr float OFFSET = 382.4958400542335;
+    static constexpr float MINIMUM = 382;
+
+    float a = fmaxf(fmaf(__bfloat162float(arg.x), SCALE, OFFSET), MINIMUM);
+    float b = fmaxf(fmaf(__bfloat162float(arg.y), SCALE, OFFSET), MINIMUM);
+
+    return {
+        transmute<bfloat16_t>(uint16_t(transmute<uint32_t>(a))),
+        transmute<bfloat16_t>(uint16_t(transmute<uint32_t>(b)))};
+}
+#endif
+}  // namespace approx
+
+namespace detail {
+template
+struct apply_impl, F, 1, T, T> {
+    KERNEL_FLOAT_INLINE static void call(F fun, T* output, const T* input) {
+        T in2[2], out2[2];
+        in2[0] = input[0];
+        apply_impl, F, 2, T, T>::call(fun, out2, in2);
+        output[0] = out2[0];
+    }
+};
+}  // namespace detail
+
+#define KERNEL_FLOAT_DEFINE_APPROX_IMPL(T, FUN, DEFAULT_LEVEL)                      \
+    namespace detail {                                                              \
+    template                                                                        \
+    struct apply_impl, ops::FUN, 2, T, T> {                                         \
+        KERNEL_FLOAT_INLINE static void call(ops::FUN, T* output, const T* input) { \
+            auto res = approx::FUN({input[0], input[1]});                           \
+            output[0] = res.x;                                                      \
+            output[1] = res.y;                                                      \
+        }                                                                           \
+    };                                                                              \
+                                                                                    \
+    template<>                                                                      \
+    struct apply_impl, 2, T, T>:                                                    \
+        apply_impl, ops::FUN, 2, T, T> {};                                          \
+    }
+
+#if KERNEL_FLOAT_FP16_AVAILABLE
+KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, sin, 4)
+KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, cos, 4)
+KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, rsqrt, 1)
+KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, sqrt, 1)
+KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, rcp, 1)
+KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, exp, 0)
+KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, log, 0)
+KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, asin, 2)
+KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, acos, 2)
+#endif
+
+#if KERNEL_FLOAT_BF16_OPS_SUPPORTED
+KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, cos, 4)
+KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, sin, 4)
+KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, rcp, 1)
+KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, rsqrt, 1)
+KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, sqrt, 1)
+KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, exp, 0)
+//KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, log, 0)
+#endif
+
+#define KERNEL_FLOAT_DEFINE_APPROX_FUN(FUN)                              \
+    template                                                             \
+    KERNEL_FLOAT_INLINE into_vector_type approx_##FUN(const V& args) {   \
+        return map>(ops::FUN> {}, args);                                 \
+    }
+
+KERNEL_FLOAT_DEFINE_APPROX_FUN(sin)
+KERNEL_FLOAT_DEFINE_APPROX_FUN(cos)
+KERNEL_FLOAT_DEFINE_APPROX_FUN(rsqrt)
+KERNEL_FLOAT_DEFINE_APPROX_FUN(sqrt)
+KERNEL_FLOAT_DEFINE_APPROX_FUN(rcp)
+KERNEL_FLOAT_DEFINE_APPROX_FUN(exp)
+KERNEL_FLOAT_DEFINE_APPROX_FUN(log)
+
+}  // namespace kernel_float
 #ifndef KERNEL_FLOAT_FP8_H
 #define KERNEL_FLOAT_FP8_H
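With the specializations registered above, the `approx_*` entry points dispatch vectors of `half_t` or `bfloat16_t` elements to the two-wide device routines defined earlier in this file. A minimal kernel using that interface could look like the sketch below; the explicit level on `approx_sin` assumes the accuracy template parameter exposed by `KERNEL_FLOAT_DEFINE_APPROX_FUN`, and `kf::half_t` is assumed to name the library's half type.

```cpp
#include "kernel_float.h"
namespace kf = kernel_float;

// Apply approximate exp and sin to pairs of half-precision values.
__global__ void approx_demo(const kf::vec<kf::half_t, 2>* input, kf::vec<kf::half_t, 2>* output, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        kf::vec<kf::half_t, 2> x = input[i];
        kf::vec<kf::half_t, 2> e = kf::approx_exp(x);     // default accuracy level
        kf::vec<kf::half_t, 2> s = kf::approx_sin<3>(x);  // explicit level (assumed parameter)
        output[i] = e + s;
    }
}
```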
@@ -4367,7 +5097,7 @@ struct allow_float_fallback<__nv_fp8_e5m2> {
 #define KERNEL_FLOAT_FP8_CAST2(T, FP8_TY, FP8_INTERP)                                  \
     namespace detail {                                                                 \
     template<>                                                                         \
-    struct apply_impl, 2, FP8_TY, T> {                                                 \
+    struct apply_impl, 2, FP8_TY, T> {                                                 \
         KERNEL_FLOAT_INLINE static void call(ops::cast, FP8_TY* result, const T* v) {  \
             __half2_raw x;                                                             \
             memcpy(&x, v, 2 * sizeof(T));                                              \
@@ -4376,7 +5106,7 @@ struct allow_float_fallback<__nv_fp8_e5m2> {
         }                                                                              \
     };                                                                                 \
     template<>                                                                         \
-    struct apply_impl, 2, T, FP8_TY> {                                                 \
+    struct apply_impl, 2, T, FP8_TY> {                                                 \
         KERNEL_FLOAT_INLINE static void call(ops::cast, T* result, const FP8_TY* v) {  \
             __nv_fp8x2_storage_t x;                                                    \
             memcpy(&x, v, 2 * sizeof(FP8_TY));                                         \
@@ -4394,12 +5124,12 @@ KERNEL_FLOAT_FP8_CAST(double)
 
 namespace kernel_float {
 
-KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__half, __nv_fp8_e4m3)
-KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__half, __nv_fp8_e5m2)
+KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(half_t, __nv_fp8_e4m3)
+KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(half_t, __nv_fp8_e5m2)
 
-KERNEL_FLOAT_FP8_CAST(__half)
-KERNEL_FLOAT_FP8_CAST2(__half, __nv_fp8_e4m3, __NV_E4M3)
-KERNEL_FLOAT_FP8_CAST2(__half, __nv_fp8_e5m2, __NV_E5M2)
+KERNEL_FLOAT_FP8_CAST(half_t)
+KERNEL_FLOAT_FP8_CAST2(half_t, __nv_fp8_e4m3, __NV_E4M3)
+KERNEL_FLOAT_FP8_CAST2(half_t, __nv_fp8_e5m2, __NV_E5M2)
 
 }  // namespace kernel_float
 #endif  // KERNEL_FLOAT_FP16_AVAILABLE
@@ -4408,12 +5138,12 @@
 
 namespace kernel_float {
 
-KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__nv_bfloat16, __nv_fp8_e4m3)
-KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__nv_bfloat16, __nv_fp8_e5m2)
+KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(bfloat16_t, __nv_fp8_e4m3)
+KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(bfloat16_t, __nv_fp8_e5m2)
 
-KERNEL_FLOAT_FP8_CAST(__nv_bfloat16)
-KERNEL_FLOAT_FP8_CAST2(__nv_bfloat16, __nv_fp8_e4m3, __NV_E4M3)
-KERNEL_FLOAT_FP8_CAST2(__nv_bfloat16, __nv_fp8_e5m2, __NV_E5M2)
+KERNEL_FLOAT_FP8_CAST(bfloat16_t)
+KERNEL_FLOAT_FP8_CAST2(bfloat16_t, __nv_fp8_e4m3, __NV_E4M3)
+KERNEL_FLOAT_FP8_CAST2(bfloat16_t, __nv_fp8_e5m2, __NV_E5M2)
 
 }  // namespace kernel_float
 #endif  // KERNEL_FLOAT_BF16_AVAILABLE
@@ -4482,14 +5212,14 @@ KERNEL_FLOAT_TYPE_ALIAS(f64x, double)
 KERNEL_FLOAT_TYPE_ALIAS(float64x, double)
 
 #if KERNEL_FLOAT_FP16_AVAILABLE
-KERNEL_FLOAT_TYPE_ALIAS(half, __half)
-KERNEL_FLOAT_TYPE_ALIAS(f16x, __half)
-KERNEL_FLOAT_TYPE_ALIAS(float16x, __half)
+KERNEL_FLOAT_TYPE_ALIAS(half, half_t)
+KERNEL_FLOAT_TYPE_ALIAS(f16x, half_t)
+KERNEL_FLOAT_TYPE_ALIAS(float16x, half_t)
 #endif
 
 #if KERNEL_FLOAT_BF16_AVAILABLE
-KERNEL_FLOAT_TYPE_ALIAS(bfloat16x, __nv_bfloat16)
-KERNEL_FLOAT_TYPE_ALIAS(bf16x, __nv_bfloat16)
+KERNEL_FLOAT_TYPE_ALIAS(bfloat16x, bfloat16_t)
+KERNEL_FLOAT_TYPE_ALIAS(bf16x, bfloat16_t)
 #endif
 
 #if KERNEL_FLOAT_BF8_AVAILABLE
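The `KERNEL_FLOAT_FP8_CAST2` specializations above convert two elements at a time through the packed `__nv_fp8x2_storage_t` representation, and the `KERNEL_FLOAT_DEFINE_PROMOTED_TYPE` lines make `half_t`/`bfloat16_t` the promotion targets for mixed FP8 arithmetic. From user code the conversions are reached through the library's ordinary cast machinery; the sketch below assumes the `kf::cast<R>` spelling and the CUDA 12 `<cuda_fp8.h>` types.

```cpp
#include <cuda_fp8.h>
#include "kernel_float.h"
namespace kf = kernel_float;

// Compress half-precision pairs to FP8 (e4m3) and expand them back.
// Sketch only: assumes kf::cast<R> performs the element-wise type conversion.
__global__ void fp8_roundtrip(const kf::vec<kf::half_t, 2>* input, kf::vec<kf::half_t, 2>* output, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        kf::vec<__nv_fp8_e4m3, 2> packed = kf::cast<__nv_fp8_e4m3>(input[i]);
        output[i] = kf::cast<kf::half_t>(packed);  // lossy round trip
    }
}
```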
@@ -4503,12 +5233,12 @@ static constexpr extent kextent = {};
 
 template<typename... Args>
 KERNEL_FLOAT_INLINE kvec<promote_t<Args...>, sizeof...(Args)> make_kvec(Args&&... args) {
-    return ::kernel_float::make_vec(std::forward<Args>(args)...);
+    return ::kernel_float::make_vec(static_cast<Args&&>(args)...);
 };
 
 template<typename V>
 KERNEL_FLOAT_INLINE into_vector_type<V> into_kvec(V&& input) {
-    return ::kernel_float::into_vec(std::forward<V>(input));
+    return ::kernel_float::into_vec(static_cast<V&&>(input));
 }
 
 template
@@ -4876,9 +5606,7 @@ struct tiling {
     using index_type = IndexType;
     using point_type = vector>;
 
-#if KERNEL_FLOAT_IS_DEVICE
-    __forceinline__ __device__ tiling() : block_(threadIdx) {}
-#endif
+    __forceinline__ __device__ tiling() : block_(dim3(threadIdx)) {}
 
     KERNEL_FLOAT_INLINE
     tiling(BlockDim block, vec offset = {}) : block_(block), offset_(offset) {}
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index bd8edb6..523fe50 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -1,18 +1,24 @@
 cmake_minimum_required(VERSION 3.16)
-project(tests)
+project(tests LANGUAGES CXX)
 set(CMAKE_CXX_STANDARD 17)
 
 file(GLOB FILES *.cu)
 add_executable(kernel_float_tests ${FILES})
 target_link_libraries(kernel_float_tests PRIVATE kernel_float)
-target_compile_options(kernel_float_tests PRIVATE "--extended-lambda")
-set_target_properties(kernel_float_tests PROPERTIES CUDA_ARCHITECTURES "70;80")
-
-target_compile_options(kernel_float_tests PRIVATE "-ftime-report -ftime-report-details")
 
 add_subdirectory(Catch2)
 target_link_libraries(kernel_float_tests PRIVATE Catch2::Catch2WithMain)
 
-find_package(CUDA REQUIRED)
-target_include_directories(kernel_float_tests PRIVATE ${CUDA_TOOLKIT_INCLUDE})
+if(${KERNEL_FLOAT_LANGUAGE_CUDA})
+    find_package(CUDA REQUIRED)
+    target_include_directories(kernel_float_tests PRIVATE ${CUDA_TOOLKIT_INCLUDE})
+
+    target_compile_options(kernel_float_tests PRIVATE "-ftime-report -ftime-report-details")
+    target_compile_options(kernel_float_tests PRIVATE "--extended-lambda")
+    set_target_properties(kernel_float_tests PROPERTIES CUDA_ARCHITECTURES "70;80")
+endif()
+
+if(${KERNEL_FLOAT_LANGUAGE_HIP})
+    set_source_files_properties(${FILES} PROPERTIES LANGUAGE HIP)
+endif()
\ No newline at end of file
diff --git a/tests/common.h b/tests/common.h
index ff8072d..d209774 100644
--- a/tests/common.h
+++ b/tests/common.h
@@ -1,8 +1,5 @@
 #pragma once
 
-#include <cuda_fp16.h>
-#include <cuda_bf16.h>
-
 #include <kernel_float.h>
 
 #include "catch2/catch_all.hpp"
@@ -10,10 +7,22 @@
 
 namespace kf = kernel_float;
 
+#if KERNEL_FLOAT_IS_HIP
+#define cudaError_t hipError_t
+#define cudaSuccess hipSuccess
+#define cudaGetErrorString hipGetErrorString
+#define cudaGetLastError hipGetLastError
+#define cudaSetDevice hipSetDevice
+#define cudaDeviceSynchronize hipDeviceSynchronize
+
+using __nv_bfloat16 = __hip_bfloat16;
+#endif
+
 namespace detail {
+#if KERNEL_FLOAT_IS_CUDA
 __attribute__((noinline)) static __host__ __device__ void
 __assertion_failed(const char* expr, const char* file, int line) {
-#ifndef __CUDA_ARCH__
+#if KERNEL_FLOAT_IS_HOST
     std::string msg =
         "assertion failed: " + std::string(expr) + " (" + file + ":" + std::to_string(line) + ")";
     throw std::runtime_error(msg);
@@ -24,6 +33,23 @@ __assertion_failed(const char* expr, const char* file, int line) {
     ;
 #endif
 }
+
+#elif KERNEL_FLOAT_IS_HIP
+__attribute__((noinline)) static __host__ void
+__assertion_failed(const char* expr, const char* file, int line) {
+    std::string msg =
+        "assertion failed: " + std::string(expr) + " (" + file + ":" + std::to_string(line) + ")";
+    throw std::runtime_error(msg);
+}
+
+__attribute__((noinline)) static __device__ void
+__assertion_failed(const char* expr, const char* file, int line) {
+    printf("assertion failed: %s (%s:%d)\n", expr, file, line);
+    __builtin_trap();
+    while (1)
+        ;
+}
+#endif
 }  // namespace detail
 
 #define ASSERT(...) \
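The `#define` block added to tests/common.h above lets the test sources keep using CUDA runtime spellings while compiling under HIP. The same pattern works in any small harness; below is a trimmed-down, self-contained sketch (the `check` helper is hypothetical, not part of Kernel Float).

```cpp
#include <cstdio>
#include <cstdlib>

#if defined(__HIP_PLATFORM_AMD__)
#include <hip/hip_runtime.h>
// Map the CUDA names used below onto their HIP equivalents.
#define cudaError_t hipError_t
#define cudaSuccess hipSuccess
#define cudaGetErrorString hipGetErrorString
#define cudaDeviceSynchronize hipDeviceSynchronize
#else
#include <cuda_runtime.h>
#endif

// Hypothetical helper: abort with a readable message if a runtime call failed.
static void check(cudaError_t code) {
    if (code != cudaSuccess) {
        std::fprintf(stderr, "GPU error: %s\n", cudaGetErrorString(code));
        std::exit(1);
    }
}

int main() {
    check(cudaDeviceSynchronize());
    std::puts("runtime OK");
}
```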
@@ -52,28 +78,28 @@ struct equals_helper {
 template<>
 struct equals_helper<double> {
     static __host__ __device__ bool call(const double& left, const double& right) {
-        return (isnan(left) && isnan(right)) || (left == right);
+        return (std::isnan(left) && std::isnan(right)) || (left == right);
     }
 };
 
 template<>
 struct equals_helper<float> {
     static __host__ __device__ bool call(const float& left, const float& right) {
-        return (isnan(left) && isnan(right)) || (left == right);
+        return (std::isnan(left) && std::isnan(right)) || (left == right);
     }
 };
 
 template<>
 struct equals_helper<__half> {
     static __host__ __device__ bool call(const __half& left, const __half& right) {
-        return equals_helper<float>::call(float(left), float(right));
+        return equals_helper<float>::call(__half2float(left), __half2float(right));
     }
 };
 
 template<>
 struct equals_helper<__nv_bfloat16> {
     static __host__ __device__ bool call(const __nv_bfloat16& left, const __nv_bfloat16& right) {
-        return equals_helper<float>::call(float(left), float(right));
+        return equals_helper<float>::call(__bfloat162float(left), __bfloat162float(right));
     }
 };
 
@@ -123,14 +149,14 @@ struct approx_helper {
 template<>
 struct approx_helper<__half> {
     static __host__ __device__ bool call(__half left, __half right) {
-        return approx_helper<double>::call(double(left), double(right), 0.01);
+        return approx_helper<double>::call(__half2float(left), __half2float(right), 0.01);
     }
 };
 
 template<>
 struct approx_helper<__nv_bfloat16> {
     static __host__ __device__ bool call(__nv_bfloat16 left, __nv_bfloat16 right) {
-        return approx_helper<double>::call(double(left), double(right), 0.05);
+        return approx_helper<double>::call(__bfloat162float(left), __bfloat162float(right), 0.05);
     }
 };
 }  // namespace detail
@@ -211,14 +237,14 @@ struct generator_value>> {
 template<>
 struct generator_value<__half> {
     __host__ __device__ static __half call(uint64_t seed) {
-        return __half(generator_value<float>::call(seed));
+        return __float2half(generator_value<float>::call(seed));
     }
 };
 
 template<>
 struct generator_value<__nv_bfloat16> {
     __host__ __device__ static __nv_bfloat16 call(uint64_t seed) {
-        return __nv_bfloat16(generator_value<float>::call(seed));
+        return __float2bfloat16(generator_value<float>::call(seed));
     }
 };
 }  // namespace detail
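The `equals_helper` specializations changed above intentionally treat two NaNs as equal, because `NaN == NaN` is false under IEEE 754 and would otherwise make NaN-producing cases untestable. The rule in isolation:

```cpp
#include <cmath>
#include <cstdio>

// "Same value" comparison used by the tests: two NaNs compare equal,
// everything else falls back to ordinary floating-point equality.
static bool same_value(float left, float right) {
    return (std::isnan(left) && std::isnan(right)) || (left == right);
}

int main() {
    float nan = std::nanf("");
    std::printf("%d %d %d\n",
                same_value(nan, nan),     // 1: both NaN
                same_value(1.0f, 1.0f),   // 1: equal values
                same_value(0.0f, 1.0f));  // 0: different values
}
```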
diff --git a/tests/memory.cu b/tests/memory.cu
index 2621a92..7ab7ee4 100644
--- a/tests/memory.cu
+++ b/tests/memory.cu
@@ -156,7 +156,8 @@ struct vector_ptr_test {
     };
 
     {
-        auto ptr = kf::vector_ptr {&storage.data[0]};
+        kf::vector_ptr storage_ptr = kf::assert_aligned(storage.data);
+        kf::vector_ptr ptr = storage_ptr;
         ASSERT_EQ(ptr.get(), static_cast(storage.data));
 
         T expected[N] = {T(double(N + I))...};
@@ -174,7 +175,8 @@ struct vector_ptr_test {
     }
 
     {
-        auto ptr = kf::vector_ptr {&storage.data[0]};
+        kf::vector_ptr storage_ptr = kf::assert_aligned(storage.data);
+        kf::vector_ptr ptr = storage_ptr;
         ASSERT_EQ(ptr.get(), static_cast(storage.data));
 
         T expected[N] = {T(double(N + I))...};
@@ -182,7 +184,7 @@ struct vector_ptr_test {
         auto a = ptr.read(1);
         ASSERT_EQ_ALL(a[I], expected[I]);
 
-        auto b = ptr[1];
+        kf::vec b = ptr[1];
         ASSERT_EQ_ALL(b[I], expected[I]);
 
         kf::vec c = ptr.at(1);
@@ -191,16 +193,20 @@ struct vector_ptr_test {
         kf::vec overwrite = {T(double(100 + I))...};
         ptr.at(1) = overwrite;
 
-        auto e = ptr[1];
+        kf::vec e = ptr[1];
         ASSERT_EQ_ALL(e[I], overwrite[I]);
 
         ptr.write(1, T(1337.0));
-        auto f = ptr[1];
+        kf::vec f = ptr[1];
         ASSERT_EQ_ALL(f[I], T(1337.0));
 
         ptr.at(1) += T(1.0);
-        auto g = ptr[1];
+        kf::vec g = ptr[1];
         ASSERT_EQ_ALL(g[I], T(1338.0));
+
+        kf::cast_to(ptr[1]) = double(3.14);
+        kf::vec h = ptr[1];
+        ASSERT_EQ_ALL(h[I], T(3.14));
     }
 }
};
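The updated test above exercises `vector_ptr` through `kf::assert_aligned`, `read`, `write`, `at`, and indexing. A kernel-side sketch of that access pattern is shown below; the exact template arguments of `vector_ptr` and the raw-pointer form of `assert_aligned` used here are assumptions rather than a copy of the library's signatures.

```cpp
#include "kernel_float.h"
namespace kf = kernel_float;

// Each thread loads one 4-wide vector, scales it, and stores it back.
// Sketch only: assumes kf::vector_ptr<float, 4> and kf::assert_aligned<4>(float*).
__global__ void scale(float* data, float factor, int num_vectors) {
    kf::vector_ptr<float, 4> ptr = kf::assert_aligned<4>(data);  // data assumed 16-byte aligned

    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < num_vectors) {
        kf::vec<float, 4> v = ptr.read(i);  // vectorized load of elements 4*i .. 4*i+3
        ptr.write(i, v * factor);           // vectorized store
    }
}
```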