diff --git a/aten/src/ATen/native/cuda/Activation.cu b/aten/src/ATen/native/cuda/Activation.cu
index 7a57e2405bd891..209bf29b9716f9 100644
--- a/aten/src/ATen/native/cuda/Activation.cu
+++ b/aten/src/ATen/native/cuda/Activation.cu
@@ -15,6 +15,8 @@
 #include
 #include
+#include <thrust/tuple.h>
+
 namespace at { namespace native {
 
@@ -121,84 +123,6 @@ Tensor prelu_cuda(const Tensor& self, const Tensor& weight_) {
 // -----------------------------------
 // prelu backward
 // -----------------------------------
-template <typename scalar_t, typename inp_offset_calc_t, typename out_offset_calc_t>
-__global__ void prelu_cuda_backward_share_weights_kernel(
-    int numel,
-    const scalar_t *input_data,
-    const scalar_t *grad_out_data,
-    scalar_t *input_grad_data,
-    scalar_t *weight_grad_collector_data,
-    const scalar_t *weight_data,
-    inp_offset_calc_t inp_calc,
-    out_offset_calc_t out_calc
-) {
-  scalar_t inputs[THREAD_WORK_SIZE];
-  scalar_t grad_outs[THREAD_WORK_SIZE];
-  scalar_t weight = *weight_data;
-
-  int base_index = BLOCK_WORK_SIZE * blockIdx.x;
-  int remaining = std::min(numel - base_index, BLOCK_WORK_SIZE);
-
-  // load data into registers
-  int thread_idx = threadIdx.x;
-  #pragma unroll
-  for(int i = 0; i < THREAD_WORK_SIZE; i++) {
-    if (thread_idx >= remaining) {
-      break;
-    }
-    int input_idx = thread_idx + base_index;
-    auto offsets = inp_calc.get(input_idx);
-    inputs[i] = input_data[offsets[0]];
-    grad_outs[i] = grad_out_data[offsets[1]];
-    thread_idx += num_threads;
-  }
-
-  // compute and store
-  thread_idx = threadIdx.x;
-  #pragma unroll
-  for(int i = 0; i < THREAD_WORK_SIZE; i++) {
-    if (thread_idx >= remaining) {
-      break;
-    }
-    int input_idx = thread_idx + base_index;
-    auto offsets = out_calc.get(input_idx);
-    input_grad_data[offsets[0]] = (inputs[i] > 0) ? grad_outs[i] : weight * grad_outs[i];
-    weight_grad_collector_data[offsets[1]] = (inputs[i] > 0) ? scalar_t(0) : inputs[i] * grad_outs[i];
-    thread_idx += num_threads;
-  }
-}
-
-template <typename scalar_t>
-void launch_prelu_cuda_backward_share_weights_kernel(TensorIterator &iter, const scalar_t* weight_data) {
-  if (!iter.can_use_32bit_indexing()) {
-    for (auto& sub_iter : iter.with_32bit_indexing()) {
-      launch_prelu_cuda_backward_share_weights_kernel(sub_iter, weight_data);
-    }
-    return;
-  }
-
-  int64_t numel = iter.numel();
-  if (numel == 0) {
-    return;
-  }
-
-  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(iter.can_use_32bit_indexing());
-
-  scalar_t *input_grad_data = static_cast<scalar_t*>(iter.data_ptr(0));
-  scalar_t *weight_grad_collector_data = static_cast<scalar_t*>(iter.data_ptr(1));
-  const scalar_t *input_data = static_cast<const scalar_t*>(iter.data_ptr(2));
-  const scalar_t *grad_out_data = static_cast<const scalar_t*>(iter.data_ptr(3));
-
-  int64_t grid = (numel + block_work_size - 1) / block_work_size;
-  auto stream = at::cuda::getCurrentCUDAStream();
-
-  TORCH_INTERNAL_ASSERT(iter.is_contiguous());
-  prelu_cuda_backward_share_weights_kernel<<<grid, num_threads, 0, stream>>>(
-    numel, input_data, grad_out_data, input_grad_data, weight_grad_collector_data, weight_data,
-    TrivialOffsetCalculator<2>(), TrivialOffsetCalculator<2>()
-  );
-}
-
 template <typename scalar_t>
 void prelu_cuda_backward_kernel_share_weights(
   const Tensor& input,
@@ -212,7 +136,13 @@ void prelu_cuda_backward_kernel_share_weights(
     .add_input(input)
     .add_input(grad_out)
     .build();
-  launch_prelu_cuda_backward_share_weights_kernel(iter, weight_data);
+
+  // N.B. `std::tuple` does not support `::operator=` on device code.
+  gpu_kernel_multiple_outputs(iter, [=] GPU_LAMBDA (scalar_t input, scalar_t grad_out) -> thrust::tuple<scalar_t, scalar_t> {
+    scalar_t input_grad = input > 0 ? grad_out : (*weight_data) * grad_out;
+    scalar_t weight_grad_collector = input > 0 ? scalar_t(0) : input * grad_out;
+    return {input_grad, weight_grad_collector};
+  });
 }
 
 template <typename scalar_t>
diff --git a/aten/src/ATen/native/cuda/CUDALoops.cuh b/aten/src/ATen/native/cuda/CUDALoops.cuh
index c6741fc879ce5a..093ace17297c1c 100644
--- a/aten/src/ATen/native/cuda/CUDALoops.cuh
+++ b/aten/src/ATen/native/cuda/CUDALoops.cuh
@@ -36,7 +36,6 @@
 #include
 #include
 #include
-#include
 #include
 #include
 #include
@@ -60,32 +59,6 @@ namespace at { namespace native {
 
-template <typename func_t, typename policy_t>
-__device__ inline void elementwise_kernel_helper(func_t f, policy_t policy) {
-  using traits = function_traits<func_t>;
-  using return_t = typename traits::result_type;
-  using args_t = typename traits::ArgsTuple;
-
-  int idx = blockIdx.x;
-
-  return_t results[thread_work_size];
-  args_t args[thread_work_size];
-
-  // load
-  policy.load(args, idx);
-
-  // compute
-  #pragma unroll
-  for (int i = 0; i < thread_work_size; i++) {
-    if (policy.check_inbounds(i)) {
-      results[i] = c10::guts::apply(f, args[i]);
-    }
-  }
-
-  // store
-  policy.store(results, idx);
-}
-
 template <int vec_size, typename func_t, typename array_t>
 C10_LAUNCH_BOUNDS_1(num_threads)
 __global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
diff --git a/aten/src/ATen/native/cuda/Loops.cuh b/aten/src/ATen/native/cuda/Loops.cuh
index e54a068a6a4bf6..bb913dc0ec9e78 100644
--- a/aten/src/ATen/native/cuda/Loops.cuh
+++ b/aten/src/ATen/native/cuda/Loops.cuh
@@ -1,13 +1,6 @@
 #pragma once
 
-#include
-#include
-#include
-#include
-
-namespace at { namespace native {
-
 #define NUM_THREADS (C10_WARP_SIZE * 2)
 #define THREAD_WORK_SIZE 4
 #define BLOCK_WORK_SIZE (THREAD_WORK_SIZE * num_threads)
 
@@ -16,25 +9,66 @@ constexpr int num_threads = NUM_THREADS;
 constexpr int thread_work_size = THREAD_WORK_SIZE;
 constexpr int block_work_size = BLOCK_WORK_SIZE;
 
+#include
+#include
+#include
+#include
+#include
+
+#include <thrust/tuple.h>
+
+namespace at { namespace native {
+
 template <int N>
 static OffsetCalculator<N> make_input_offset_calculator(const TensorIterator& iter) {
   // array size can not be 0, this happens when N == 0
   constexpr int array_size = std::max(N, 1);
-  TORCH_INTERNAL_ASSERT(N == iter.ntensors() - 1);
+  TORCH_INTERNAL_ASSERT(N == iter.ntensors() - iter.noutputs());
   std::array<const int64_t*, array_size> strides;
   int64_t element_sizes[array_size];
   for (int i = 0; i < N; i++) {
-    strides[i] = iter.strides(i + 1).data();
-    element_sizes[i] = iter.element_size(i + 1);
+    strides[i] = iter.strides(i + iter.noutputs()).data();
+    element_sizes[i] = iter.element_size(i + iter.noutputs());
   }
   return OffsetCalculator<N>(iter.ndim(), iter.shape().data(), strides.data(), element_sizes);
 }
 
-static OffsetCalculator<1> make_output_offset_calculator(const TensorIterator& iter) {
-  std::array<const int64_t*, 1> strides;
-  strides[0] = iter.strides(0).data();
-  int64_t element_size = iter.element_size(0);
-  return OffsetCalculator<1>(iter.ndim(), iter.shape().data(), strides.data(), &element_size);
+template <int num_outputs = 1>
+static OffsetCalculator<num_outputs> make_output_offset_calculator(const TensorIterator& iter) {
+  TORCH_INTERNAL_ASSERT(num_outputs == iter.noutputs());
+  std::array<const int64_t*, num_outputs> strides;
+  int64_t element_sizes[num_outputs];
+  for (int i = 0; i < num_outputs; i++) {
+    strides[i] = iter.strides(i).data();
+    element_sizes[i] = iter.element_size(i);
+  }
+  return OffsetCalculator<num_outputs>(iter.ndim(), iter.shape().data(), strides.data(), element_sizes);
+}
+
+template <typename func_t, typename policy_t>
+__device__ inline void elementwise_kernel_helper(func_t f, policy_t policy) {
+  using traits = function_traits<func_t>;
+  using return_t = typename traits::result_type;
+  using args_t = typename traits::ArgsTuple;
+
+  int idx = blockIdx.x;
+
+  return_t results[thread_work_size];
+  args_t args[thread_work_size];
+
+  // load
+  policy.load(args, idx);
+
+  // compute
+  #pragma unroll
+  for (int i = 0; i < thread_work_size; i++) {
+    if (policy.check_inbounds(i)) {
+      results[i] = c10::guts::apply(f, args[i]);
+    }
+  }
+
+  // store
+  policy.store(results, idx);
 }
 
 }} // namespace at::native
@@ -130,4 +164,80 @@ void gpu_kernel_with_scalars(TensorIterator& iter, const func_t& f) {
   }
 }
 
+namespace { // functions for `gpu_kernel_multiple_outputs`.
+
+// check the return type is `thrust::tuple`, not `std::tuple`.
+template <typename T> struct is_tuple: std::false_type {};
+
+template <typename ...T> struct is_tuple<thrust::tuple<T...>>: std::true_type {};
+
+template <int num_outputs, typename func_t, typename array_t, typename inp_calc_t, typename out_calc_t>
+C10_LAUNCH_BOUNDS_1(num_threads)
+__global__ void unrolled_elementwise_kernel_for_multi_outputs(int N, func_t f, array_t data, inp_calc_t ic, out_calc_t oc) {
+  int remaining = N - block_work_size * blockIdx.x;
+  elementwise_kernel_helper(f, memory::policies::multi_outputs_unroll<array_t, inp_calc_t, out_calc_t, num_outputs>(data, remaining, ic, oc));
+}
+
+template <int num_outputs, typename func_t, typename array_t, typename inp_calc_t, typename out_calc_t>
+static inline void launch_unrolled_kernel_for_multi_outputs(int64_t N, const func_t& f, array_t data, inp_calc_t ic, out_calc_t oc) {
+  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
+  int64_t grid = (N + block_work_size - 1) / block_work_size;
+  auto stream = at::cuda::getCurrentCUDAStream();
+  unrolled_elementwise_kernel_for_multi_outputs<num_outputs, func_t, array_t><<<grid, num_threads, 0, stream>>>(N, f, data, ic, oc);
+  AT_CUDA_CHECK(cudaGetLastError());
+}
+
+template <typename func_t>
+void gpu_kernel_multiple_outputs_impl(TensorIterator& iter, const func_t& f) {
+  using traits = function_traits<func_t>;
+  using output_t = typename traits::result_type;
+  static_assert(is_tuple<output_t>::value, "f's return type must be `thrust::tuple`");
+  constexpr int num_outputs = thrust::tuple_size<output_t>::value;
+  constexpr int num_inputs = traits::arity;
+  constexpr int ntensors = num_outputs + num_inputs;
+
+  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
+  TORCH_INTERNAL_ASSERT(iter.ntensors() == ntensors);
+
+  at::detail::Array<char*, ntensors> data;
+  for (int i = 0; i < ntensors; i++) {
+    data[i] = (char*)iter.data_ptr(i);
+  }
+
+  int64_t numel = iter.numel();
+
+  if (iter.is_contiguous()) {
+    auto input_calc = TrivialOffsetCalculator<num_inputs>();
+    auto output_calc = TrivialOffsetCalculator<num_outputs>();
+    launch_unrolled_kernel_for_multi_outputs<num_outputs>(numel, f, data, input_calc, output_calc);
+  } else {
+    auto input_calc = make_input_offset_calculator<num_inputs>(iter);
+    auto output_calc = make_output_offset_calculator<num_outputs>(iter);
+    launch_unrolled_kernel_for_multi_outputs<num_outputs>(numel, f, data, input_calc, output_calc);
+  }
+}
+} // namespace
+
+template <typename func_t>
+void gpu_kernel_multiple_outputs(TensorIterator& iter, const func_t& f) {
+  ASSERT_HOST_DEVICE_LAMBDA(func_t);
+
+  for (int arg = 0; arg < iter.ntensors(); arg++) {
+    TORCH_INTERNAL_ASSERT(iter.device(arg).is_cuda());
+  }
+
+  if (iter.numel() == 0) {
+    return;
+  }
+
+  if (!iter.can_use_32bit_indexing()) {
+    for (auto& sub_iter : iter.with_32bit_indexing()) {
+      gpu_kernel_multiple_outputs(sub_iter, f);
+    }
+    return;
+  }
+
+  gpu_kernel_multiple_outputs_impl(iter, f);
+}
+
 }} //namespace at::native
diff --git a/aten/src/ATen/native/cuda/MemoryAccess.cuh b/aten/src/ATen/native/cuda/MemoryAccess.cuh
index 03e24757310a91..9f79fd952504e2 100644
--- a/aten/src/ATen/native/cuda/MemoryAccess.cuh
+++ b/aten/src/ATen/native/cuda/MemoryAccess.cuh
@@ -9,6 +9,8 @@
 #include
 #include
+#include <thrust/tuple.h>
+
 // References:
 // https://devblogs.nvidia.com/cuda-pro-tip-increase-performance-with-vectorized-memory-access/
 
@@ -66,11 +68,24 @@ struct vectorized_load_helper {
 template <int arg_index>
 struct unroll_load_helper {
   template <typename args_t, typename policy_t, typename offset_t, typename loader_t>
-  static __device__ void apply(policy_t &self, args_t *args, offset_t offset, loader_t loader, int j) {
+  static __device__ void apply(policy_t &self, args_t *args, offset_t offset, loader_t loader, int j, int num_outputs) {
     using arg_t = std::tuple_element_t<arg_index, args_t>;
     // `data` hold the data_ptr for tensors [output, input0, input1, ...], so we
     // need a +1 offset to get the input
-    std::get<arg_index>(args[j]) = loader.template load<arg_t>(self.data[arg_index + 1], offset[arg_index], arg_index);
+    std::get<arg_index>(args[j]) = loader.template load<arg_t>(self.data[arg_index + num_outputs], offset[arg_index], arg_index);
+  }
+};
+
+template <int current>
+struct multi_outputs_store_helper {
+  template <int ntensors, int num_outputs, typename ...Args>
+  C10_HOST_DEVICE static void apply(
+      at::detail::Array<char*, ntensors> data,
+      at::detail::Array<uint32_t, num_outputs> offsets,
+      thrust::tuple<Args...> ret) {
+    using T = typename thrust::tuple_element<current, thrust::tuple<Args...>>::type;
+    T *to = reinterpret_cast<T *>(data[current]) + offsets[current];
+    *to = thrust::get<current>(ret);
   }
 };
 
@@ -135,7 +150,7 @@ namespace policies {
 
 // Assumption:
 // all tensors are contiguous, that is: stride == sizeof(type) for all tensors
-template <typename data_t, typename inp_calc_t, typename out_calc_t, typename loader_t, typename storer_t>
+template <typename data_t, typename inp_calc_t, typename out_calc_t, typename loader_t, typename storer_t, int num_outputs = 1>
 struct unroll {
 
   data_t data;
@@ -163,7 +178,7 @@ struct unroll {
       }
       int linear_idx = thread_idx + block_work_size * idx;
       auto offset = input_offset_calculator.get(linear_idx);
-      detail::static_unroll<detail::unroll_load_helper, arity>::with_args(*this, args, offset, loader, i);
+      detail::static_unroll<detail::unroll_load_helper, arity>::with_args(*this, args, offset, loader, i, num_outputs);
       thread_idx += num_threads;
     }
   }
@@ -244,6 +259,28 @@ struct vectorized {
   }
 };
 
+template <typename data_t, typename inp_calc_t, typename out_calc_t, int num_outputs>
+struct multi_outputs_unroll : unroll<data_t, inp_calc_t, out_calc_t, LoadWithoutCast, StoreWithoutCast, num_outputs> {
+
+  __device__ multi_outputs_unroll(data_t data, int remaining, inp_calc_t ic, out_calc_t oc):
+    unroll<data_t, inp_calc_t, out_calc_t, LoadWithoutCast, StoreWithoutCast, num_outputs>(data, remaining, ic, oc, LoadWithoutCast(), StoreWithoutCast()) {}
+
+  template <typename return_t>
+  __device__ inline void store(return_t *from, int idx) {
+    int thread_idx = threadIdx.x;
+    #pragma unroll
+    for (int i = 0; i < thread_work_size; i++) {
+      if (thread_idx >= this->remaining) {
+        return;
+      }
+      int linear_idx = thread_idx + block_work_size * idx;
+      auto offsets = this->output_offset_calculator.get(linear_idx);
+      memory::detail::static_unroll<detail::multi_outputs_store_helper, num_outputs>::with_args(this->data, offsets, from[i]);
+      thread_idx += num_threads;
+    }
+  }
+};
+
 } // namespace policies
 
 // This is only used in host, but we will wrap this into some templates
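
Usage note (illustrative, not part of the patch): the sketch below shows how a caller would drive the gpu_kernel_multiple_outputs() entry point that this patch adds to Loops.cuh, modeled on the prelu backward port in Activation.cu above. The function name two_outputs_cuda_kernel, the alpha parameter, and the arithmetic are hypothetical; the sketch assumes a .cu translation unit inside ATen that includes <ATen/native/cuda/Loops.cuh> and <thrust/tuple.h>, with `iter` built the same way as the prelu iterator (both outputs added before the two inputs).

// Hypothetical sketch, not part of the patch: one fused pass producing two
// elementwise outputs from two inputs via the new gpu_kernel_multiple_outputs().
template <typename scalar_t>
void two_outputs_cuda_kernel(at::TensorIterator& iter, scalar_t alpha) {
  // `iter` must have noutputs() == 2 and two inputs (ntensors() == 4), matching
  // the lambda arity and the thrust::tuple size below.
  at::native::gpu_kernel_multiple_outputs(iter,
      [=] GPU_LAMBDA (scalar_t a, scalar_t b) -> thrust::tuple<scalar_t, scalar_t> {
        // One tuple element per output, in the order the outputs were added to
        // the iterator. thrust::tuple is required: std::tuple has no usable
        // operator= in device code (see the N.B. in the Activation.cu hunk).
        scalar_t out0 = a + alpha * b;
        scalar_t out1 = a > 0 ? b : scalar_t(0);
        return {out0, out1};
      });
}

The helper itself splits iterators that need 64-bit indexing into 32-bit sub-iterators and picks the trivial or strided offset calculators internally, so a caller only supplies the functor.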