diff --git a/.gitignore b/.gitignore
index efa59fdfc962..0804b24a9ee7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -463,3 +463,6 @@ dask-worker-space/
 *.pub
 *.rdp
 *_rsa
+
+# hipify-perl -inplace leaves behind *.prehip files
+*.prehip
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 4f57cf9622e6..f94b4a604355 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -4,6 +4,7 @@ option(USE_GPU "Enable GPU-accelerated training" OFF)
 option(USE_SWIG "Enable SWIG to generate Java API" OFF)
 option(USE_TIMETAG "Set to ON to output time costs" OFF)
 option(USE_CUDA "Enable CUDA-accelerated training " OFF)
+option(USE_ROCM "Enable ROCM-accelerated training " OFF)
 option(USE_DEBUG "Set to ON for Debug mode" OFF)
 option(USE_SANITIZER "Use sanitizer flags" OFF)
 set(
@@ -160,6 +161,11 @@ if(USE_CUDA)
     set(USE_OPENMP ON CACHE BOOL "CUDA requires OpenMP" FORCE)
 endif()
 
+if(USE_ROCM)
+    enable_language(HIP)
+    set(USE_OPENMP ON CACHE BOOL "ROCM requires OpenMP" FORCE)
+endif()
+
 if(USE_OPENMP)
     if(APPLE)
         find_package(OpenMP)
@@ -271,35 +277,53 @@ if(USE_CUDA)
     message(STATUS "ALLFEATS_DEFINES: ${ALLFEATS_DEFINES}")
     message(STATUS "FULLDATA_DEFINES: ${FULLDATA_DEFINES}")
+endif()
 
-    function(add_histogram hsize hname hadd hconst hdir)
-        add_library(histo${hsize}${hname} OBJECT src/treelearner/kernels/histogram${hsize}.cu)
-        set_target_properties(
-            histo${hsize}${hname}
-            PROPERTIES
-                CUDA_SEPARABLE_COMPILATION ON
-                CUDA_ARCHITECTURES ${CUDA_ARCHS}
-        )
-        if(hadd)
-            list(APPEND histograms histo${hsize}${hname})
-            set(histograms ${histograms} PARENT_SCOPE)
-        endif()
-        target_compile_definitions(
-            histo${hsize}${hname}
-            PRIVATE
-            -DCONST_HESSIAN=${hconst}
-            ${hdir}
-        )
-    endfunction()
+if(USE_ROCM)
+    find_package(HIP)
+    include_directories(${HIP_INCLUDE_DIRS})
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__HIP_PLATFORM_AMD__")
+    set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} ${OpenMP_CXX_FLAGS} -fPIC -Wall")
+
+    # avoid warning: unused variable 'mask' due to __shfl_down_sync work-around
+    set(DISABLED_WARNINGS "${DISABLED_WARNINGS} -Wno-unused-variable")
+    # avoid warning: 'hipHostAlloc' is deprecated: use hipHostMalloc instead
+    set(DISABLED_WARNINGS "${DISABLED_WARNINGS} -Wno-deprecated-declarations")
+    # avoid many warnings about missing overrides
+    set(DISABLED_WARNINGS "${DISABLED_WARNINGS} -Wno-inconsistent-missing-override")
+    # avoid warning: shift count >= width of type in feature_histogram.hpp
+    set(DISABLED_WARNINGS "${DISABLED_WARNINGS} -Wno-shift-count-overflow")
 
-    foreach(hsize _16_64_256)
-        add_histogram("${hsize}" "_sp_const" "True" "1" "${BASE_DEFINES}")
-        add_histogram("${hsize}" "_sp" "True" "0" "${BASE_DEFINES}")
-        add_histogram("${hsize}" "-allfeats_sp_const" "False" "1" "${ALLFEATS_DEFINES}")
-        add_histogram("${hsize}" "-allfeats_sp" "False" "0" "${ALLFEATS_DEFINES}")
-        add_histogram("${hsize}" "-fulldata_sp_const" "True" "1" "${FULLDATA_DEFINES}")
-        add_histogram("${hsize}" "-fulldata_sp" "True" "0" "${FULLDATA_DEFINES}")
-    endforeach()
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${DISABLED_WARNINGS}")
+    set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} ${DISABLED_WARNINGS}")
+
+    if(USE_DEBUG)
+        set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -g -O0")
+    else()
+        set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -O3")
+    endif()
+    message(STATUS "CMAKE_HIP_FLAGS: ${CMAKE_HIP_FLAGS}")
+
+    add_definitions(-DUSE_CUDA)
+
+    set(
+        BASE_DEFINES
+        -DPOWER_FEATURE_WORKGROUPS=12
+        -DUSE_CONSTANT_BUF=0
+    )
+    set(
+        ALLFEATS_DEFINES
+        ${BASE_DEFINES}
+        -DENABLE_ALL_FEATURES
+    )
+    set(
+        FULLDATA_DEFINES
+        ${ALLFEATS_DEFINES}
+        -DIGNORE_INDICES
+    )
+
+    message(STATUS "ALLFEATS_DEFINES: ${ALLFEATS_DEFINES}")
+    message(STATUS "FULLDATA_DEFINES: ${FULLDATA_DEFINES}")
 endif()
 
 include(CheckCXXSourceCompiles)
@@ -634,7 +658,9 @@ if(USE_CUDA)
         CUDA_RESOLVE_DEVICE_SYMBOLS ON
     )
 endif()
+endif()
+if(USE_ROCM OR USE_CUDA)
   # histograms are list of object libraries. Linking object library to other
   # object libraries only gets usage requirements, the linked objects won't be
   # used. Thus we have to call target_link_libraries on final targets here.
diff --git a/include/LightGBM/cuda/cuda_algorithms.hpp b/include/LightGBM/cuda/cuda_algorithms.hpp
index abda07b1582f..33fd49eedb59 100644
--- a/include/LightGBM/cuda/cuda_algorithms.hpp
+++ b/include/LightGBM/cuda/cuda_algorithms.hpp
@@ -1,6 +1,7 @@
 /*!
  * Copyright (c) 2021 Microsoft Corporation. All rights reserved.
  * Licensed under the MIT License. See LICENSE file in the project root for license information.
+ * Modifications Copyright(C) 2023 Advanced Micro Devices, Inc. All rights reserved.
  */
 
 #ifndef LIGHTGBM_CUDA_CUDA_ALGORITHMS_HPP_
@@ -14,6 +15,7 @@
 #include
 #include
+#include
 
 #include
 #include
@@ -174,7 +176,7 @@ __device__ __forceinline__ void GlobalMemoryPrefixSum(T* array, const size_t len
   for (size_t index = start; index < end; ++index) {
     thread_sum += array[index];
   }
-  __shared__ T shared_mem[32];
+  __shared__ T shared_mem[WARPSIZE];
   const T thread_base = ShufflePrefixSumExclusive(thread_sum, shared_mem);
   if (start < end) {
     array[start] += thread_base;
@@ -483,7 +485,7 @@ __device__ void ShuffleSortedPrefixSumDevice(const VAL_T* in_values,
     const INDEX_T* sorted_indices,
     REDUCE_VAL_T* out_values,
     const INDEX_T num_data) {
-  __shared__ REDUCE_VAL_T shared_buffer[32];
+  __shared__ REDUCE_VAL_T shared_buffer[WARPSIZE];
   const INDEX_T num_data_per_thread = (num_data + static_cast(blockDim.x) - 1) / static_cast(blockDim.x);
   const INDEX_T start = num_data_per_thread * static_cast(threadIdx.x);
   const INDEX_T end = min(start + num_data_per_thread, num_data);
@@ -572,8 +574,48 @@ __device__ VAL_T PercentileDevice(const VAL_T* values,
     INDEX_T* indices,
     REDUCE_WEIGHT_T* weights_prefix_sum,
     const double alpha,
-    const INDEX_T len);
-
+    const INDEX_T len) {
+  if (len <= 1) {
+    return values[0];
+  }
+  if (!USE_WEIGHT) {
+    BitonicArgSortDevice(values, indices, len);
+    const double float_pos = (1.0f - alpha) * len;
+    const INDEX_T pos = static_cast(float_pos);
+    if (pos < 1) {
+      return values[indices[0]];
+    } else if (pos >= len) {
+      return values[indices[len - 1]];
+    } else {
+      const double bias = float_pos - pos;
+      const VAL_T v1 = values[indices[pos - 1]];
+      const VAL_T v2 = values[indices[pos]];
+      return static_cast(v1 - (v1 - v2) * bias);
+    }
+  } else {
+    BitonicArgSortDevice(values, indices, len);
+    ShuffleSortedPrefixSumDevice(weights, indices, weights_prefix_sum, len);
+    const REDUCE_WEIGHT_T threshold = weights_prefix_sum[len - 1] * (1.0f - alpha);
+    __shared__ INDEX_T pos;
+    if (threadIdx.x == 0) {
+      pos = len;
+    }
+    __syncthreads();
+    for (INDEX_T index = static_cast(threadIdx.x); index < len; index += static_cast(blockDim.x)) {
+      if (weights_prefix_sum[index] > threshold && (index == 0 || weights_prefix_sum[index - 1] <= threshold)) {
+        pos = index;
+      }
+    }
+    __syncthreads();
+    pos = min(pos, len - 1);
+    if (pos == 0 || pos == len - 1) {
+      return values[pos];
+    }
+    const VAL_T v1 = values[indices[pos - 1]];
+    const VAL_T v2 = values[indices[pos]];
+    return static_cast(v1 - (v1 - v2) * (threshold - weights_prefix_sum[pos - 1]) / (weights_prefix_sum[pos] - weights_prefix_sum[pos - 1]));
+  }
+}
 }  // namespace LightGBM
diff --git a/include/LightGBM/cuda/cuda_rocm_interop.h b/include/LightGBM/cuda/cuda_rocm_interop.h
new file mode 100644
index 000000000000..670a7b84b547
--- /dev/null
+++ b/include/LightGBM/cuda/cuda_rocm_interop.h
@@ -0,0 +1,20 @@
+/*!
+ * Copyright(C) 2023 Advanced Micro Devices, Inc. All rights reserved.
+ */
+#ifdef USE_CUDA
+
+#if defined(__HIP_PLATFORM_AMD__) || defined(__HIP__)
+// ROCm doesn't have __shfl_down_sync, only __shfl_down without mask.
+// Since mask is full 0xffffffff, we can use __shfl_down instead.
+#define __shfl_down_sync(mask, val, offset) __shfl_down(val, offset)
+#define __shfl_up_sync(mask, val, offset) __shfl_up(val, offset)
+// ROCm warpSize is constexpr and is either 32 or 64 depending on gfx arch.
+#define WARPSIZE warpSize
+// ROCm doesn't have atomicAdd_block, but it should be semantically the same as atomicAdd
+#define atomicAdd_block atomicAdd
+#else
+// CUDA warpSize is not a constexpr, but always 32
+#define WARPSIZE 32
+#endif
+
+#endif
diff --git a/include/LightGBM/cuda/cuda_split_info.hpp b/include/LightGBM/cuda/cuda_split_info.hpp
index f01ce2b02a02..8f933f06d58a 100644
--- a/include/LightGBM/cuda/cuda_split_info.hpp
+++ b/include/LightGBM/cuda/cuda_split_info.hpp
@@ -2,6 +2,7 @@
  * Copyright (c) 2021 Microsoft Corporation. All rights reserved.
  * Licensed under the MIT License. See LICENSE file in the project root for
  * license information.
+ * Modifications Copyright(C) 2023 Advanced Micro Devices, Inc. All rights reserved.
  */
 
 #ifdef USE_CUDA
@@ -40,24 +41,24 @@ class CUDASplitInfo {
   uint32_t* cat_threshold = nullptr;
   int* cat_threshold_real = nullptr;
 
-  __device__ CUDASplitInfo() {
+  __host__ __device__ CUDASplitInfo() {
     num_cat_threshold = 0;
     cat_threshold = nullptr;
     cat_threshold_real = nullptr;
   }
 
-  __device__ ~CUDASplitInfo() {
+  __host__ __device__ ~CUDASplitInfo() {
    if (num_cat_threshold > 0) {
      if (cat_threshold != nullptr) {
-        cudaFree(cat_threshold);
+        CUDASUCCESS_OR_FATAL(cudaFree(cat_threshold));
      }
      if (cat_threshold_real != nullptr) {
-        cudaFree(cat_threshold_real);
+        CUDASUCCESS_OR_FATAL(cudaFree(cat_threshold_real));
      }
    }
  }
 
-  __device__ CUDASplitInfo& operator=(const CUDASplitInfo& other) {
+  __host__ __device__ CUDASplitInfo& operator=(const CUDASplitInfo& other) {
    is_valid = other.is_valid;
    leaf_index = other.leaf_index;
    gain = other.gain;
diff --git a/include/LightGBM/cuda/vector_cudahost.h b/include/LightGBM/cuda/vector_cudahost.h
index 83fbe5cda9b7..a217a659b14f 100644
--- a/include/LightGBM/cuda/vector_cudahost.h
+++ b/include/LightGBM/cuda/vector_cudahost.h
@@ -1,6 +1,7 @@
 /*!
  * Copyright (c) 2020 IBM Corporation, Microsoft Corporation. All rights reserved.
  * Licensed under the MIT License. See LICENSE file in the project root for license information.
+ * Modifications Copyright(C) 2023 Advanced Micro Devices, Inc. All rights reserved.
  */
 #ifndef LIGHTGBM_CUDA_VECTOR_CUDAHOST_H_
 #define LIGHTGBM_CUDA_VECTOR_CUDAHOST_H_
@@ -45,7 +46,7 @@ struct CHAllocator {
     n = SIZE_ALIGNED(n);
     #ifdef USE_CUDA
     if (LGBM_config_::current_device == lgbm_device_cuda) {
-      cudaError_t ret = cudaHostAlloc(&ptr, n*sizeof(T), cudaHostAllocPortable);
+      cudaError_t ret = cudaHostAlloc(reinterpret_cast(&ptr), n*sizeof(T), cudaHostAllocPortable);
       if (ret != cudaSuccess) {
         Log::Warning("Defaulting to malloc in CHAllocator!!!");
         ptr = reinterpret_cast(_mm_malloc(n*sizeof(T), 16));
diff --git a/src/cuda/cuda_algorithms.cu b/src/cuda/cuda_algorithms.cu
index 19c1507419e9..6a4b842cb844 100644
--- a/src/cuda/cuda_algorithms.cu
+++ b/src/cuda/cuda_algorithms.cu
@@ -1,17 +1,19 @@
 /*!
  * Copyright (c) 2021 Microsoft Corporation. All rights reserved.
  * Licensed under the MIT License. See LICENSE file in the project root for license information.
+ * Modifications Copyright(C) 2023 Advanced Micro Devices, Inc. All rights reserved.
  */
 
 #ifdef USE_CUDA
 
 #include
+#include
 
 namespace LightGBM {
 
 template
 __global__ void ShufflePrefixSumGlobalKernel(T* values, size_t len, T* block_prefix_sum_buffer) {
-  __shared__ T shared_mem_buffer[32];
+  __shared__ T shared_mem_buffer[WARPSIZE];
   const size_t index = static_cast(threadIdx.x + blockIdx.x * blockDim.x);
   T value = 0;
   if (index < len) {
@@ -26,7 +28,7 @@ __global__ void ShufflePrefixSumGlobalKernel(T* values, size_t len, T* block_pre
 
 template
 __global__ void ShufflePrefixSumGlobalReduceBlockKernel(T* block_prefix_sum_buffer, int num_blocks) {
-  __shared__ T shared_mem_buffer[32];
+  __shared__ T shared_mem_buffer[WARPSIZE];
   const int num_blocks_per_thread = (num_blocks + GLOBAL_PREFIX_SUM_BLOCK_SIZE - 2) / (GLOBAL_PREFIX_SUM_BLOCK_SIZE - 1);
   int thread_block_start = threadIdx.x == 0 ? 0 : (threadIdx.x - 1) * num_blocks_per_thread;
   int thread_block_end = threadIdx.x == 0 ? 0 : min(thread_block_start + num_blocks_per_thread, num_blocks);
@@ -96,7 +98,7 @@ void BitonicArgSortItemsGlobal(
 
 template
 __global__ void BlockReduceSum(T* block_buffer, const data_size_t num_blocks) {
-  __shared__ T shared_buffer[32];
+  __shared__ T shared_buffer[WARPSIZE];
   T thread_sum = 0;
   for (data_size_t block_index = static_cast(threadIdx.x); block_index < num_blocks; block_index += static_cast(blockDim.x)) {
     thread_sum += block_buffer[block_index];
   }
@@ -109,7 +111,7 @@
 template
 __global__ void ShuffleReduceSumGlobalKernel(const VAL_T* values, const data_size_t num_value, REDUCE_T* block_buffer) {
-  __shared__ REDUCE_T shared_buffer[32];
+  __shared__ REDUCE_T shared_buffer[WARPSIZE];
   const data_size_t data_index = static_cast(blockIdx.x * blockDim.x + threadIdx.x);
   const REDUCE_T value = (data_index < num_value ? static_cast(values[data_index]) : 0.0f);
   const REDUCE_T reduce_value = ShuffleReduceSum(value, shared_buffer, blockDim.x);
@@ -131,7 +133,7 @@ template void ShuffleReduceSumGlobal(const double* values, size_
 
 template
 __global__ void ShuffleReduceMinGlobalKernel(const VAL_T* values, const data_size_t num_value, REDUCE_T* block_buffer) {
-  __shared__ REDUCE_T shared_buffer[32];
+  __shared__ REDUCE_T shared_buffer[WARPSIZE];
   const data_size_t data_index = static_cast(blockIdx.x * blockDim.x + threadIdx.x);
   const REDUCE_T value = (data_index < num_value ?
     static_cast(values[data_index]) : 0.0f);
   const REDUCE_T reduce_value = ShuffleReduceMin(value, shared_buffer, blockDim.x);
@@ -142,7 +144,7 @@ __global__ void ShuffleReduceMinGlobalKernel(const VAL_T* values, const data_siz
 
 template
 __global__ void ShuffleBlockReduceMin(T* block_buffer, const data_size_t num_blocks) {
-  __shared__ T shared_buffer[32];
+  __shared__ T shared_buffer[WARPSIZE];
   T thread_min = 0;
   for (data_size_t block_index = static_cast(threadIdx.x); block_index < num_blocks; block_index += static_cast(blockDim.x)) {
     const T value = block_buffer[block_index];
@@ -168,7 +170,7 @@ template void ShuffleReduceMinGlobal(const label_t* values, siz
 
 template
 __global__ void ShuffleReduceDotProdGlobalKernel(const VAL_T* values1, const VAL_T* values2, const data_size_t num_value, REDUCE_T* block_buffer) {
-  __shared__ REDUCE_T shared_buffer[32];
+  __shared__ REDUCE_T shared_buffer[WARPSIZE];
   const data_size_t data_index = static_cast(blockIdx.x * blockDim.x + threadIdx.x);
   const REDUCE_T value1 = (data_index < num_value ? static_cast(values1[data_index]) : 0.0f);
   const REDUCE_T value2 = (data_index < num_value ? static_cast(values2[data_index]) : 0.0f);
@@ -191,7 +193,7 @@ template void ShuffleReduceDotProdGlobal(const label_t* values1
 
 template
 __global__ void GlobalInclusiveArgPrefixSumKernel(
   const INDEX_T* sorted_indices, const VAL_T* in_values, REDUCE_T* out_values, REDUCE_T* block_buffer, data_size_t num_data) {
-  __shared__ REDUCE_T shared_buffer[32];
+  __shared__ REDUCE_T shared_buffer[WARPSIZE];
   const data_size_t data_index = static_cast(threadIdx.x + blockIdx.x * blockDim.x);
   REDUCE_T value = static_cast(data_index < num_data ? in_values[sorted_indices[data_index]] : 0);
   __syncthreads();
@@ -206,7 +208,7 @@
 template
 __global__ void GlobalInclusivePrefixSumReduceBlockKernel(T* block_buffer, data_size_t num_blocks) {
-  __shared__ T shared_buffer[32];
+  __shared__ T shared_buffer[WARPSIZE];
   T thread_sum = 0;
   const data_size_t num_blocks_per_thread = (num_blocks + static_cast(blockDim.x)) / static_cast(blockDim.x);
   const data_size_t thread_start_block_index = static_cast(threadIdx.x) * num_blocks_per_thread;
@@ -441,72 +443,6 @@ void BitonicArgSortGlobal(const data_size_t* values, int
   BitonicArgSortGlobalHelper(values, indices, len);
 }
 
-template
-__device__ VAL_T PercentileDevice(const VAL_T* values,
-    const WEIGHT_T* weights,
-    INDEX_T* indices,
-    REDUCE_WEIGHT_T* weights_prefix_sum,
-    const double alpha,
-    const INDEX_T len) {
-  if (len <= 1) {
-    return values[0];
-  }
-  if (!USE_WEIGHT) {
-    BitonicArgSortDevice(values, indices, len);
-    const double float_pos = (1.0f - alpha) * len;
-    const INDEX_T pos = static_cast(float_pos);
-    if (pos < 1) {
-      return values[indices[0]];
-    } else if (pos >= len) {
-      return values[indices[len - 1]];
-    } else {
-      const double bias = float_pos - pos;
-      const VAL_T v1 = values[indices[pos - 1]];
-      const VAL_T v2 = values[indices[pos]];
-      return static_cast(v1 - (v1 - v2) * bias);
-    }
-  } else {
-    BitonicArgSortDevice(values, indices, len);
-    ShuffleSortedPrefixSumDevice(weights, indices, weights_prefix_sum, len);
-    const REDUCE_WEIGHT_T threshold = weights_prefix_sum[len - 1] * (1.0f - alpha);
-    __shared__ INDEX_T pos;
-    if (threadIdx.x == 0) {
-      pos = len;
-    }
-    __syncthreads();
-    for (INDEX_T index = static_cast(threadIdx.x); index < len; index += static_cast(blockDim.x)) {
-      if (weights_prefix_sum[index] > threshold && (index == 0 || weights_prefix_sum[index - 1] <= threshold)) {
-        pos = index;
-      }
-    }
-    __syncthreads();
-    pos = min(pos, len - 1);
-    if (pos == 0 || pos == len - 1) {
-      return values[pos];
-    }
-    const VAL_T v1 = values[indices[pos - 1]];
-    const VAL_T v2 = values[indices[pos]];
-    return static_cast(v1 - (v1 - v2) * (threshold - weights_prefix_sum[pos - 1]) / (weights_prefix_sum[pos] - weights_prefix_sum[pos - 1]));
-  }
-}
-
-template __device__ double PercentileDevice(
-    const double* values,
-    const label_t* weights,
-    data_size_t* indices,
-    double* weights_prefix_sum,
-    const double alpha,
-    const data_size_t len);
-
-template __device__ double PercentileDevice(
-    const double* values,
-    const label_t* weights,
-    data_size_t* indices,
-    double* weights_prefix_sum,
-    const double alpha,
-    const data_size_t len);
-
-
 }  // namespace LightGBM
 
 #endif  // USE_CUDA
diff --git a/src/metric/cuda/cuda_pointwise_metric.cu b/src/metric/cuda/cuda_pointwise_metric.cu
index 6b200798f4c8..bd6f2841f894 100644
--- a/src/metric/cuda/cuda_pointwise_metric.cu
+++ b/src/metric/cuda/cuda_pointwise_metric.cu
@@ -2,11 +2,13 @@
  * Copyright (c) 2022 Microsoft Corporation. All rights reserved.
  * Licensed under the MIT License. See LICENSE file in the project root for
  * license information.
+ * Modifications Copyright(C) 2023 Advanced Micro Devices, Inc. All rights reserved.
  */
 
 #ifdef USE_CUDA
 
 #include
+#include
 
 #include "cuda_binary_metric.hpp"
 #include "cuda_pointwise_metric.hpp"
@@ -17,7 +19,7 @@ namespace LightGBM {
 
 template
 __global__ void EvalKernel(const data_size_t num_data, const label_t* labels, const label_t* weights, const double* scores, double* reduce_block_buffer, const double param) {
-  __shared__ double shared_mem_buffer[32];
+  __shared__ double shared_mem_buffer[WARPSIZE];
   const data_size_t index = static_cast(threadIdx.x + blockIdx.x * blockDim.x);
   double point_metric = 0.0;
   if (index < num_data) {
diff --git a/src/objective/cuda/cuda_binary_objective.cu b/src/objective/cuda/cuda_binary_objective.cu
index 6f01c1745f72..08dcd3121a3e 100644
--- a/src/objective/cuda/cuda_binary_objective.cu
+++ b/src/objective/cuda/cuda_binary_objective.cu
@@ -2,20 +2,23 @@
  * Copyright (c) 2021 Microsoft Corporation. All rights reserved.
  * Licensed under the MIT License. See LICENSE file in the project root for
  * license information.
+ * Modifications Copyright(C) 2023 Advanced Micro Devices, Inc. All rights reserved.
  */
 
 #ifdef USE_CUDA
 
-#include
-
 #include "cuda_binary_objective.hpp"
+#include
+
+#include
+
 namespace LightGBM {
 
 template
 __global__ void BoostFromScoreKernel_1_BinaryLogloss(const label_t* cuda_labels, const data_size_t num_data,
   double* out_cuda_sum_labels, double* out_cuda_sum_weights, const label_t* cuda_weights) {
-  __shared__ double shared_buffer[32];
+  __shared__ double shared_buffer[WARPSIZE];
   const uint32_t mask = 0xffffffff;
   const uint32_t warpLane = threadIdx.x % warpSize;
   const uint32_t warpID = threadIdx.x / warpSize;
diff --git a/src/objective/cuda/cuda_rank_objective.cu b/src/objective/cuda/cuda_rank_objective.cu
index af9f595f1aed..9975e9c59ef8 100644
--- a/src/objective/cuda/cuda_rank_objective.cu
+++ b/src/objective/cuda/cuda_rank_objective.cu
@@ -2,6 +2,7 @@
  * Copyright (c) 2021 Microsoft Corporation. All rights reserved.
  * Licensed under the MIT License. See LICENSE file in the project root for
  * license information.
+ * Modifications Copyright(C) 2023 Advanced Micro Devices, Inc. All rights reserved.
  */
 
 #ifdef USE_CUDA
@@ -307,6 +308,9 @@ __global__ void GetGradientsKernel_LambdarankNDCG_Sorted(
 void CUDALambdarankNDCG::LaunchGetGradientsKernel(const double* score, score_t* gradients, score_t* hessians) const {
   const int num_blocks = (num_queries_ + NUM_QUERY_PER_BLOCK - 1) / NUM_QUERY_PER_BLOCK;
   const data_size_t num_rank_label = static_cast(label_gain_.size());
+  const int device_index = GetCUDADevice(__FILE__, __LINE__);
+  cudaDeviceProp device_prop;
+  CUDASUCCESS_OR_FATAL(cudaGetDeviceProperties(&device_prop, device_index));
 
 #define GetGradientsKernel_LambdarankNDCG_ARGS \
   score, cuda_labels_, num_data_, \
@@ -321,7 +325,7 @@ void CUDALambdarankNDCG::LaunchGetGradientsKernel(const double* score, score_t*
   gradients, hessians
 
   if (max_items_in_query_aligned_ <= 1024) {
-    if (num_rank_label <= 32) {
+    if (num_rank_label <= 32 && device_prop.warpSize == 32) {
      GetGradientsKernel_LambdarankNDCG<<>>(GetGradientsKernel_LambdarankNDCG_ARGS);
    } else if (num_rank_label <= 64) {
      GetGradientsKernel_LambdarankNDCG<<>>(GetGradientsKernel_LambdarankNDCG_ARGS);
@@ -337,7 +341,7 @@ void CUDALambdarankNDCG::LaunchGetGradientsKernel(const double* score, score_t*
      GetGradientsKernel_LambdarankNDCG<<>>(GetGradientsKernel_LambdarankNDCG_ARGS);
    }
  } else if (max_items_in_query_aligned_ <= 2048) {
-    if (num_rank_label <= 32) {
+    if (num_rank_label <= 32 && device_prop.warpSize == 32) {
      GetGradientsKernel_LambdarankNDCG<<>>(GetGradientsKernel_LambdarankNDCG_ARGS);
    } else if (num_rank_label <= 64) {
      GetGradientsKernel_LambdarankNDCG<<>>(GetGradientsKernel_LambdarankNDCG_ARGS);
@@ -354,7 +358,7 @@ void CUDALambdarankNDCG::LaunchGetGradientsKernel(const double* score, score_t*
    }
  } else {
    BitonicArgSortItemsGlobal(score, num_queries_, cuda_query_boundaries_, cuda_item_indices_buffer_.RawData());
-    if (num_rank_label <= 32) {
+    if (num_rank_label <= 32 && device_prop.warpSize == 32) {
      GetGradientsKernel_LambdarankNDCG_Sorted<32><<>>(GetGradientsKernel_LambdarankNDCG_Sorted_ARGS);
    } else if (num_rank_label <= 64) {
      GetGradientsKernel_LambdarankNDCG_Sorted<64><<>>(GetGradientsKernel_LambdarankNDCG_Sorted_ARGS);
diff --git a/src/treelearner/cuda/cuda_best_split_finder.cu b/src/treelearner/cuda/cuda_best_split_finder.cu
index d5c819d392c9..982ba73d9a68 100644
--- a/src/treelearner/cuda/cuda_best_split_finder.cu
+++ b/src/treelearner/cuda/cuda_best_split_finder.cu
@@ -2,14 +2,17 @@
  * Copyright (c) 2021 Microsoft Corporation. All rights reserved.
  * Licensed under the MIT License. See LICENSE file in the project root for
  * license information.
+ * Modifications Copyright(C) 2023 Advanced Micro Devices, Inc. All rights reserved.
  */
 
 #ifdef USE_CUDA
 
-#include
+#include "cuda_best_split_finder.hpp"
 
 #include
 
-#include "cuda_best_split_finder.hpp"
+#include
+
+#include
 
 namespace LightGBM {
@@ -161,9 +164,9 @@ __device__ void FindBestSplitsForLeafKernelInner(
    }
  }
  __shared__ uint32_t best_thread_index;
-  __shared__ double shared_double_buffer[32];
-  __shared__ bool shared_bool_buffer[32];
-  __shared__ uint32_t shared_int_buffer[32];
+  __shared__ double shared_double_buffer[WARPSIZE];
+  __shared__ bool shared_bool_buffer[WARPSIZE];
+  __shared__ uint32_t shared_int_buffer[WARPSIZE];
  const unsigned int threadIdx_x = threadIdx.x;
  const bool skip_sum = REVERSE ?
    (task->skip_default_bin && (task->num_bin - 1 - threadIdx_x) == static_cast(task->default_bin)) :
@@ -515,9 +518,9 @@ __device__ void FindBestSplitsForLeafKernelCategoricalInner(
  const double parent_output,
  // output parameters
  CUDASplitInfo* cuda_best_split_info) {
-  __shared__ double shared_gain_buffer[32];
-  __shared__ bool shared_found_buffer[32];
-  __shared__ uint32_t shared_thread_index_buffer[32];
+  __shared__ double shared_gain_buffer[WARPSIZE];
+  __shared__ bool shared_found_buffer[WARPSIZE];
+  __shared__ uint32_t shared_thread_index_buffer[WARPSIZE];
  __shared__ uint32_t best_thread_index;
  const double cnt_factor = num_data / sum_hessians;
  const double min_gain_shift = parent_gain + min_gain_to_split;
@@ -605,8 +608,8 @@ __device__ void FindBestSplitsForLeafKernelCategoricalInner(
  } else {
    __shared__ double shared_value_buffer[NUM_THREADS_PER_BLOCK_BEST_SPLIT_FINDER];
    __shared__ int16_t shared_index_buffer[NUM_THREADS_PER_BLOCK_BEST_SPLIT_FINDER];
-    __shared__ uint16_t shared_mem_buffer_uint16[32];
-    __shared__ double shared_mem_buffer_double[32];
+    __shared__ uint16_t shared_mem_buffer_uint16[WARPSIZE];
+    __shared__ double shared_mem_buffer_double[WARPSIZE];
    __shared__ int used_bin;
    l2 += cat_l2;
    uint16_t is_valid_bin = 0;
@@ -1086,9 +1089,9 @@ __device__ void FindBestSplitsForLeafKernelInner_GlobalMemory(
    }
  }
  __shared__ uint32_t best_thread_index;
-  __shared__ double shared_double_buffer[32];
-  __shared__ bool shared_found_buffer[32];
-  __shared__ uint32_t shared_thread_index_buffer[32];
+  __shared__ double shared_double_buffer[WARPSIZE];
+  __shared__ bool shared_found_buffer[WARPSIZE];
+  __shared__ uint32_t shared_thread_index_buffer[WARPSIZE];
  const unsigned int threadIdx_x = threadIdx.x;
  const uint32_t feature_num_bin_minus_offset = task->num_bin - task->mfb_offset;
  if (!REVERSE) {
@@ -1301,9 +1304,9 @@ __device__ void FindBestSplitsForLeafKernelCategoricalInner_GlobalMemory(
  data_size_t* hist_index_buffer_ptr,
  // output parameters
  CUDASplitInfo* cuda_best_split_info) {
-  __shared__ double shared_gain_buffer[32];
-  __shared__ bool shared_found_buffer[32];
-  __shared__ uint32_t shared_thread_index_buffer[32];
+  __shared__ double shared_gain_buffer[WARPSIZE];
+  __shared__ bool shared_found_buffer[WARPSIZE];
+  __shared__ uint32_t shared_thread_index_buffer[WARPSIZE];
  __shared__ uint32_t best_thread_index;
  const double cnt_factor = num_data / sum_hessians;
  const double min_gain_shift = parent_gain + min_gain_to_split;
@@ -1390,7 +1393,7 @@ __device__ void FindBestSplitsForLeafKernelCategoricalInner_GlobalMemory(
        sum_right_hessian, lambda_l1, l2, right_output);
    }
  } else {
-    __shared__ uint16_t shared_mem_buffer_uint16[32];
+    __shared__ uint16_t shared_mem_buffer_uint16[WARPSIZE];
    __shared__ int used_bin;
    l2 += cat_l2;
    uint16_t is_valid_bin = 0;
@@ -1927,9 +1930,9 @@ __global__ void SyncBestSplitForLeafKernel(const int smaller_leaf_index, const i
  const int num_blocks_per_leaf, const bool larger_only, const int num_leaves) {
-  __shared__ double shared_gain_buffer[32];
-  __shared__ bool shared_found_buffer[32];
-  __shared__ uint32_t shared_thread_index_buffer[32];
+  __shared__ double shared_gain_buffer[WARPSIZE];
+  __shared__ bool shared_found_buffer[WARPSIZE];
+  __shared__ uint32_t shared_thread_index_buffer[WARPSIZE];
  const uint32_t threadIdx_x = threadIdx.x;
  const uint32_t blockIdx_x = blockIdx.x;
 
@@ -2113,8 +2116,8 @@ void CUDABestSplitFinder::LaunchSyncBestSplitForLeafKernel(
 
 __global__ void FindBestFromAllSplitsKernel(const int cur_num_leaves, CUDASplitInfo* cuda_leaf_best_split_info,
  int* cuda_best_split_info_buffer) {
-  __shared__ double gain_shared_buffer[32];
-  __shared__ int leaf_index_shared_buffer[32];
+  __shared__ double gain_shared_buffer[WARPSIZE];
+  __shared__ int leaf_index_shared_buffer[WARPSIZE];
  double thread_best_gain = kMinScore;
  int thread_best_leaf_index = -1;
  const int threadIdx_x = static_cast(threadIdx.x);
diff --git a/src/treelearner/cuda/cuda_data_partition.cu b/src/treelearner/cuda/cuda_data_partition.cu
index 4ca9d9279443..1845ffa71235 100644
--- a/src/treelearner/cuda/cuda_data_partition.cu
+++ b/src/treelearner/cuda/cuda_data_partition.cu
@@ -2,6 +2,7 @@
  * Copyright (c) 2021 Microsoft Corporation. All rights reserved.
  * Licensed under the MIT License. See LICENSE file in the project root for
  * license information.
+ * Modifications Copyright(C) 2023 Advanced Micro Devices, Inc. All rights reserved.
  */
 
 #ifdef USE_CUDA
@@ -9,6 +10,7 @@
 #include "cuda_data_partition.hpp"
 
 #include
+#include
 #include
 #include
@@ -290,7 +292,7 @@ __global__ void GenDataToLeftBitVectorKernel(
  uint16_t* block_to_left_offset,
  data_size_t* block_to_left_offset_buffer, data_size_t* block_to_right_offset_buffer) {
-  __shared__ uint16_t shared_mem_buffer[32];
+  __shared__ uint16_t shared_mem_buffer[WARPSIZE];
  uint16_t thread_to_left_offset_cnt = 0;
  const unsigned int local_data_index = blockIdx.x * blockDim.x + threadIdx.x;
  if (local_data_index < num_data_in_leaf) {
@@ -585,7 +587,7 @@ __global__ void GenDataToLeftBitVectorKernel_Categorical(
  const uint8_t split_default_to_left, uint16_t* block_to_left_offset,
  data_size_t* block_to_left_offset_buffer, data_size_t* block_to_right_offset_buffer) {
-  __shared__ uint16_t shared_mem_buffer[32];
+  __shared__ uint16_t shared_mem_buffer[WARPSIZE];
  uint16_t thread_to_left_offset_cnt = 0;
  const unsigned int local_data_index = blockIdx.x * blockDim.x + threadIdx.x;
  if (local_data_index < num_data_in_leaf) {
@@ -683,7 +685,7 @@ __global__ void AggregateBlockOffsetKernel0(
  data_size_t* block_to_right_offset_buffer, data_size_t* cuda_leaf_data_start,
  data_size_t* cuda_leaf_data_end, data_size_t* cuda_leaf_num_data, const data_size_t* cuda_data_indices,
  const data_size_t num_blocks) {
-  __shared__ uint32_t shared_mem_buffer[32];
+  __shared__ uint32_t shared_mem_buffer[WARPSIZE];
  __shared__ uint32_t to_left_total_count;
  const data_size_t num_data_in_leaf = cuda_leaf_num_data[left_leaf_index];
  const unsigned int blockDim_x = blockDim.x;
@@ -747,7 +749,7 @@ __global__ void AggregateBlockOffsetKernel1(
  data_size_t* block_to_right_offset_buffer, data_size_t* cuda_leaf_data_start,
  data_size_t* cuda_leaf_data_end, data_size_t* cuda_leaf_num_data, const data_size_t* cuda_data_indices,
  const data_size_t num_blocks) {
-  __shared__ uint32_t shared_mem_buffer[32];
+  __shared__ uint32_t shared_mem_buffer[WARPSIZE];
  __shared__ uint32_t to_left_total_count;
  const data_size_t num_data_in_leaf = cuda_leaf_num_data[left_leaf_index];
  const unsigned int threadIdx_x = threadIdx.x;
diff --git a/src/treelearner/cuda/cuda_histogram_constructor.cu b/src/treelearner/cuda/cuda_histogram_constructor.cu
index 03d3b8979439..c7218ae68d38 100644
--- a/src/treelearner/cuda/cuda_histogram_constructor.cu
+++ b/src/treelearner/cuda/cuda_histogram_constructor.cu
@@ -2,6 +2,7 @@
  * Copyright (c) 2021 Microsoft Corporation. All rights reserved.
  * Licensed under the MIT License. See LICENSE file in the project root for
  * license information.
+ * Modifications Copyright(C) 2023 Advanced Micro Devices, Inc. All rights reserved.
  */
 
 #ifdef USE_CUDA
@@ -9,6 +10,7 @@
 #include "cuda_histogram_constructor.hpp"
 
 #include
+#include
 
 #include
@@ -742,7 +744,7 @@ __global__ void FixHistogramKernel(
  const int* cuda_need_fix_histogram_features,
  const uint32_t* cuda_need_fix_histogram_features_num_bin_aligned,
  const CUDALeafSplitsStruct* cuda_smaller_leaf_splits) {
-  __shared__ hist_t shared_mem_buffer[32];
+  __shared__ hist_t shared_mem_buffer[WARPSIZE];
  const unsigned int blockIdx_x = blockIdx.x;
  const int feature_index = cuda_need_fix_histogram_features[blockIdx_x];
  const uint32_t num_bin_aligned = cuda_need_fix_histogram_features_num_bin_aligned[blockIdx_x];
diff --git a/src/treelearner/cuda/cuda_leaf_splits.cu b/src/treelearner/cuda/cuda_leaf_splits.cu
index 0c796be9f20a..ca0d33674cca 100644
--- a/src/treelearner/cuda/cuda_leaf_splits.cu
+++ b/src/treelearner/cuda/cuda_leaf_splits.cu
@@ -2,6 +2,7 @@
  * Copyright (c) 2021 Microsoft Corporation. All rights reserved.
  * Licensed under the MIT License. See LICENSE file in the project root for
  * license information.
+ * Modifications Copyright(C) 2023 Advanced Micro Devices, Inc. All rights reserved.
  */
 
@@ -9,6 +10,7 @@
 #include "cuda_leaf_splits.hpp"
 
 #include
+#include
 
 namespace LightGBM {
@@ -16,7 +18,7 @@
 template
 __global__ void CUDAInitValuesKernel1(const score_t* cuda_gradients, const score_t* cuda_hessians, const data_size_t num_data,
  const data_size_t* cuda_bagging_data_indices, double* cuda_sum_of_gradients, double* cuda_sum_of_hessians) {
-  __shared__ double shared_mem_buffer[32];
+  __shared__ double shared_mem_buffer[WARPSIZE];
  const data_size_t data_index = static_cast(threadIdx.x + blockIdx.x * blockDim.x);
  double gradient = 0.0f;
  double hessian = 0.0f;
@@ -43,7 +45,7 @@ __global__ void CUDAInitValuesKernel2(
  const data_size_t* cuda_data_indices_in_leaf,
  hist_t* cuda_hist_in_leaf,
  CUDALeafSplitsStruct* cuda_struct) {
-  __shared__ double shared_mem_buffer[32];
+  __shared__ double shared_mem_buffer[WARPSIZE];
  double thread_sum_of_gradients = 0.0f;
  double thread_sum_of_hessians = 0.0f;
  for (int block_index = static_cast(threadIdx.x); block_index < num_blocks_to_reduce; block_index += static_cast(blockDim.x)) {
diff --git a/src/treelearner/cuda/cuda_single_gpu_tree_learner.cu b/src/treelearner/cuda/cuda_single_gpu_tree_learner.cu
index 670f1f36d643..1fb4a6c0255e 100644
--- a/src/treelearner/cuda/cuda_single_gpu_tree_learner.cu
+++ b/src/treelearner/cuda/cuda_single_gpu_tree_learner.cu
@@ -2,14 +2,16 @@
  * Copyright (c) 2021 Microsoft Corporation. All rights reserved.
  * Licensed under the MIT License. See LICENSE file in the project root for
  * license information.
+ * Modifications Copyright(C) 2023 Advanced Micro Devices, Inc. All rights reserved.
  */
 
 #ifdef USE_CUDA
 
-#include
-
 #include "cuda_single_gpu_tree_learner.hpp"
+#include
+#include
+
 #include
 
 namespace LightGBM {
@@ -167,7 +169,7 @@ void CUDASingleGPUTreeLearner::LaunchReduceLeafStatKernel(
 
 template
 __global__ void CalcBitsetLenKernel(const CUDASplitInfo* best_split_info, size_t* out_len_buffer) {
-  __shared__ size_t shared_mem_buffer[32];
+  __shared__ size_t shared_mem_buffer[WARPSIZE];
  const T* vals = nullptr;
  if (IS_INNER) {
    vals = reinterpret_cast(best_split_info->cat_threshold);
@@ -187,7 +189,7 @@ __global__ void CalcBitsetLenKernel(const CUDASplitInfo* best_split_info, size_t
 }
 
 __global__ void ReduceBlockMaxLen(size_t* out_len_buffer, const int num_blocks) {
-  __shared__ size_t shared_mem_buffer[32];
+  __shared__ size_t shared_mem_buffer[WARPSIZE];
  size_t max_len = 0;
  for (int i = static_cast(threadIdx.x); i < num_blocks; i += static_cast(blockDim.x)) {
    max_len = max(out_len_buffer[i], max_len);
diff --git a/src/treelearner/kernels/histogram_16_64_256.cu b/src/treelearner/kernels/histogram_16_64_256.cu
index 9d8427a6f9a8..c98714f51a68 100644
--- a/src/treelearner/kernels/histogram_16_64_256.cu
+++ b/src/treelearner/kernels/histogram_16_64_256.cu
@@ -1,15 +1,16 @@
 /*!
  * Copyright (c) 2020 IBM Corporation. All rights reserved.
  * Licensed under the MIT License. See LICENSE file in the project root for license information.
+ * Modifications Copyright(C) 2023 Advanced Micro Devices, Inc. All rights reserved.
  */
 
+#include "histogram_16_64_256.hu"
+
 #include
 #include
 #include
 
-#include "histogram_16_64_256.hu"
-
 namespace LightGBM {
 
 // atomic add for float number in local memory
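
The pattern the patch relies on, WARPSIZE-sized shared buffers plus the __shfl_down_sync/__shfl_up_sync shims from the new include/LightGBM/cuda/cuda_rocm_interop.h, is easier to see in isolation. The sketch below is illustrative only and is not part of the diff; the helper names (WarpReduceSumSketch, BlockReduceSumSketch) are hypothetical and not LightGBM APIs (the real helpers are ShuffleReduceSum and friends in cuda_algorithms.hpp), but the same source is intended to build under CUDA (warpSize == 32) and ROCm (warpSize == 32 or 64), assuming USE_CUDA is defined as in LightGBM's GPU builds.

// Illustrative sketch only -- not part of the patch above. Assumes -DUSE_CUDA,
// as in LightGBM's CUDA/ROCm builds, so cuda_rocm_interop.h defines WARPSIZE
// and, on ROCm, maps __shfl_down_sync(mask, v, off) to __shfl_down(v, off).
#include <LightGBM/cuda/cuda_rocm_interop.h>

template <typename T>
__device__ T WarpReduceSumSketch(T value) {
  // Full mask; on ROCm the shim drops it, which is why the patch also adds
  // -Wno-unused-variable to DISABLED_WARNINGS in CMakeLists.txt.
  const unsigned int mask = 0xffffffffU;
  // warpSize is 32 on CUDA and 32 or 64 on ROCm, so start the shuffle
  // distance at warpSize / 2 instead of a hard-coded 16.
  for (int offset = warpSize / 2; offset > 0; offset >>= 1) {
    value += __shfl_down_sync(mask, value, offset);
  }
  return value;
}

template <typename T>
__device__ T BlockReduceSumSketch(T value) {
  // One slot per warp: sizing with WARPSIZE instead of 32 keeps the buffer
  // large enough for blockDim.x / warpSize warps on both warp widths
  // (per the interop header, warpSize is a constexpr on ROCm).
  __shared__ T shared[WARPSIZE];
  const unsigned int lane = threadIdx.x % warpSize;
  const unsigned int warp_id = threadIdx.x / warpSize;
  value = WarpReduceSumSketch(value);
  if (lane == 0) {
    shared[warp_id] = value;
  }
  __syncthreads();
  // The first warp reduces the per-warp partial sums.
  const unsigned int num_warps = (blockDim.x + warpSize - 1) / warpSize;
  value = (threadIdx.x < num_warps) ? shared[lane] : T(0);
  if (warp_id == 0) {
    value = WarpReduceSumSketch(value);
  }
  return value;
}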