Implement gpu_kernel_multiple_outputs (pytorch#37969)
Summary:
This PR introduces a variant of `gpu_kernel` for functions that return multiple values packed in a `thrust::tuple`.
With it, I simplified `prelu_cuda_backward_share_weights_kernel`.

### Why use `thrust::tuple`?
Because `std::tuple` does not support `operator=` in device code, which would complicate the implementation.
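
For context, a minimal usage sketch of the new helper. This is not part of the diff: the sin/cos op, the `sin_cos_kernel_cuda_sketch` name, and the float/double-only assumption are illustrative, and the `TensorIterator` is assumed to have been built with its two outputs added before its single input (the same convention as the `prelu` backward change below). The lambda's arity matches the number of inputs, and the `thrust::tuple` size matches the number of outputs.

```cpp
// Hypothetical sketch (not from this PR): compute sin(x) and cos(x) in one pass,
// writing two outputs through gpu_kernel_multiple_outputs.
#include <thrust/tuple.h>

#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh>

namespace at { namespace native {

// `iter` is assumed to hold two outputs and one input, all CUDA tensors of the
// same floating-point dtype (half would need its own sin/cos overloads).
template <typename scalar_t>
void sin_cos_kernel_cuda_sketch(TensorIterator& iter) {
  gpu_kernel_multiple_outputs(iter,
      [] GPU_LAMBDA (scalar_t x) -> thrust::tuple<scalar_t, scalar_t> {
        // One tuple element per output tensor, in the order they were added.
        return {::sin(x), ::cos(x)};
      });
}

}} // namespace at::native
```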

Pull Request resolved: pytorch#37969

Reviewed By: paulshaoyuqiao

Differential Revision: D22868670

Pulled By: ngimel

fbshipit-source-id: eda0a29ac0347ad544b24bf60e3d809a7db1a929
crcrpar authored and facebook-github-bot committed Aug 6, 2020
1 parent 1848b43 commit eb9ae7c
Showing 4 changed files with 175 additions and 125 deletions.
88 changes: 9 additions & 79 deletions aten/src/ATen/native/cuda/Activation.cu
@@ -15,6 +15,8 @@
#include <ATen/native/cuda/Loops.cuh>
#include <c10/cuda/CUDAMathCompat.h>

#include <thrust/tuple.h>


namespace at { namespace native {

@@ -121,84 +123,6 @@ Tensor prelu_cuda(const Tensor& self, const Tensor& weight_) {
// -----------------------------------
// prelu backward
// -----------------------------------
template<typename scalar_t, typename inp_offset_calc_t, typename out_offset_calc_t>
__global__ void prelu_cuda_backward_share_weights_kernel(
int numel,
const scalar_t *input_data,
const scalar_t *grad_out_data,
scalar_t *input_grad_data,
scalar_t *weight_grad_collector_data,
const scalar_t *weight_data,
inp_offset_calc_t inp_calc,
out_offset_calc_t out_calc
) {
scalar_t inputs[THREAD_WORK_SIZE];
scalar_t grad_outs[THREAD_WORK_SIZE];
scalar_t weight = *weight_data;

int base_index = BLOCK_WORK_SIZE * blockIdx.x;
int remaining = std::min<int>(numel - base_index, BLOCK_WORK_SIZE);

// load data into registers
int thread_idx = threadIdx.x;
#pragma unroll
for(int i = 0; i < THREAD_WORK_SIZE; i++) {
if (thread_idx >= remaining) {
break;
}
int input_idx = thread_idx + base_index;
auto offsets = inp_calc.get(input_idx);
inputs[i] = input_data[offsets[0]];
grad_outs[i] = grad_out_data[offsets[1]];
thread_idx += num_threads;
}

// compute and store
thread_idx = threadIdx.x;
#pragma unroll
for(int i = 0; i < THREAD_WORK_SIZE; i++) {
if (thread_idx >= remaining) {
break;
}
int input_idx = thread_idx + base_index;
auto offsets = out_calc.get(input_idx);
input_grad_data[offsets[0]] = (inputs[i] > 0) ? grad_outs[i] : weight * grad_outs[i];
weight_grad_collector_data[offsets[1]] = (inputs[i] > 0) ? scalar_t(0) : inputs[i] * grad_outs[i];
thread_idx += num_threads;
}
}

template<typename scalar_t>
void launch_prelu_cuda_backward_share_weights_kernel(TensorIterator &iter, const scalar_t* weight_data) {
if (!iter.can_use_32bit_indexing()) {
for (auto& sub_iter : iter.with_32bit_indexing()) {
launch_prelu_cuda_backward_share_weights_kernel(sub_iter, weight_data);
}
return;
}

int64_t numel = iter.numel();
if (numel == 0) {
return;
}

TORCH_INTERNAL_ASSERT_DEBUG_ONLY(iter.can_use_32bit_indexing());

scalar_t *input_grad_data = static_cast<scalar_t *>(iter.data_ptr(0));
scalar_t *weight_grad_collector_data = static_cast<scalar_t *>(iter.data_ptr(1));
const scalar_t *input_data = static_cast<const scalar_t *>(iter.data_ptr(2));
const scalar_t *grad_out_data = static_cast<const scalar_t *>(iter.data_ptr(3));

int64_t grid = (numel + block_work_size - 1) / block_work_size;
auto stream = at::cuda::getCurrentCUDAStream();

TORCH_INTERNAL_ASSERT(iter.is_contiguous());
prelu_cuda_backward_share_weights_kernel<scalar_t><<<grid, num_threads, 0, stream>>>(
numel, input_data, grad_out_data, input_grad_data, weight_grad_collector_data, weight_data,
TrivialOffsetCalculator<2>(), TrivialOffsetCalculator<2>()
);
}

template <typename scalar_t>
void prelu_cuda_backward_kernel_share_weights(
const Tensor& input,
@@ -212,7 +136,13 @@ void prelu_cuda_backward_kernel_share_weights(
.add_input(input)
.add_input(grad_out)
.build();
- launch_prelu_cuda_backward_share_weights_kernel(iter, weight_data);

+ // N.B. `std::tuple` does not support `::operator=` on device code.
+ gpu_kernel_multiple_outputs(iter, [=] GPU_LAMBDA (scalar_t input, scalar_t grad_out) -> thrust::tuple<scalar_t, scalar_t> {
+   scalar_t input_grad = input > 0 ? grad_out : (*weight_data) * grad_out;
+   scalar_t weight_grad_collector = input > 0 ? scalar_t(0) : input * grad_out;
+   return {input_grad, weight_grad_collector};
+ });
}

template <typename scalar_t>
27 changes: 0 additions & 27 deletions aten/src/ATen/native/cuda/CUDALoops.cuh
@@ -36,7 +36,6 @@
#include <ATen/core/Array.h>
#include <ATen/detail/FunctionTraits.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <c10/macros/Macros.h>
#include <c10/core/ScalarType.h>
#include <c10/util/TypeCast.h>
@@ -60,32 +59,6 @@

namespace at { namespace native {

template<typename func_t, typename policy_t>
__device__ inline void elementwise_kernel_helper(func_t f, policy_t policy) {
using traits = function_traits<func_t>;
using return_t = typename traits::result_type;
using args_t = typename traits::ArgsTuple;

int idx = blockIdx.x;

return_t results[thread_work_size];
args_t args[thread_work_size];

// load
policy.load(args, idx);

// compute
#pragma unroll
for (int i = 0; i < thread_work_size; i++) {
if (policy.check_inbounds(i)) {
results[i] = c10::guts::apply(f, args[i]);
}
}

// store
policy.store(results, idx);
}

template<int vec_size, typename func_t, typename array_t>
C10_LAUNCH_BOUNDS_1(num_threads)
__global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
140 changes: 125 additions & 15 deletions aten/src/ATen/native/cuda/Loops.cuh
@@ -1,13 +1,6 @@

#pragma once

#include <ATen/detail/FunctionTraits.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/TensorIteratorDynamicCasting.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>

namespace at { namespace native {

#define NUM_THREADS (C10_WARP_SIZE * 2)
#define THREAD_WORK_SIZE 4
#define BLOCK_WORK_SIZE (THREAD_WORK_SIZE * num_threads)
@@ -16,25 +9,66 @@ constexpr int num_threads = NUM_THREADS;
constexpr int thread_work_size = THREAD_WORK_SIZE;
constexpr int block_work_size = BLOCK_WORK_SIZE;

#include <ATen/detail/FunctionTraits.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/TensorIteratorDynamicCasting.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/native/cuda/MemoryAccess.cuh>

#include <thrust/tuple.h>

namespace at { namespace native {

template<int N>
static OffsetCalculator<N> make_input_offset_calculator(const TensorIterator& iter) {
// array size can not be 0, this happens when N == 0
constexpr int array_size = std::max<int>(N, 1);
- TORCH_INTERNAL_ASSERT(N == iter.ntensors() - 1);
+ TORCH_INTERNAL_ASSERT(N == iter.ntensors() - iter.noutputs());
std::array<const int64_t*, array_size> strides;
int64_t element_sizes[array_size];
for (int i = 0; i < N; i++) {
-   strides[i] = iter.strides(i + 1).data();
-   element_sizes[i] = iter.element_size(i + 1);
+   strides[i] = iter.strides(i + iter.noutputs()).data();
+   element_sizes[i] = iter.element_size(i + iter.noutputs());
}
return OffsetCalculator<N>(iter.ndim(), iter.shape().data(), strides.data(), element_sizes);
}

- static OffsetCalculator<1> make_output_offset_calculator(const TensorIterator& iter) {
-   std::array<const int64_t*, 1> strides;
-   strides[0] = iter.strides(0).data();
-   int64_t element_size = iter.element_size(0);
-   return OffsetCalculator<1>(iter.ndim(), iter.shape().data(), strides.data(), &element_size);
+ template <int num_outputs = 1>
+ static OffsetCalculator<num_outputs> make_output_offset_calculator(const TensorIterator& iter) {
+   TORCH_INTERNAL_ASSERT(num_outputs == iter.noutputs());
+   std::array<const int64_t*, num_outputs> strides;
+   int64_t element_sizes[num_outputs];
+   for (int i = 0; i < num_outputs; i++) {
+     strides[i] = iter.strides(i).data();
+     element_sizes[i] = iter.element_size(i);
+   }
+   return OffsetCalculator<num_outputs>(iter.ndim(), iter.shape().data(), strides.data(), element_sizes);
}

template<typename func_t, typename policy_t>
__device__ inline void elementwise_kernel_helper(func_t f, policy_t policy) {
using traits = function_traits<func_t>;
using return_t = typename traits::result_type;
using args_t = typename traits::ArgsTuple;

int idx = blockIdx.x;

return_t results[thread_work_size];
args_t args[thread_work_size];

// load
policy.load(args, idx);

// compute
#pragma unroll
for (int i = 0; i < thread_work_size; i++) {
if (policy.check_inbounds(i)) {
results[i] = c10::guts::apply(f, args[i]);
}
}

// store
policy.store(results, idx);
}

}} // namespace at::native
@@ -130,4 +164,80 @@ void gpu_kernel_with_scalars(TensorIterator& iter, const func_t& f) {
}
}

namespace { // functions for `gpu_kernel_multiple_outputs`.

// check the return type is `thrust::tuple`, not `std::tuple`.
template <typename T> struct is_tuple: std::false_type {};

template <typename ...T> struct is_tuple<thrust::tuple<T...>>: std::true_type {};

template <int num_outputs, typename func_t, typename array_t, typename inp_calc_t, typename out_calc_t>
C10_LAUNCH_BOUNDS_1(num_threads)
__global__ void unrolled_elementwise_kernel_for_multi_outputs(int N, func_t f, array_t data, inp_calc_t ic, out_calc_t oc) {
int remaining = N - block_work_size * blockIdx.x;
elementwise_kernel_helper(f, memory::policies::multi_outputs_unroll<array_t, inp_calc_t, out_calc_t, num_outputs>(data, remaining, ic, oc));
}

template <int num_outputs, typename func_t, typename array_t, typename inp_calc_t, typename out_calc_t>
static inline void launch_unrolled_kernel_for_multi_outputs(int64_t N, const func_t& f, array_t data, inp_calc_t ic, out_calc_t oc) {
TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
int64_t grid = (N + block_work_size - 1) / block_work_size;
auto stream = at::cuda::getCurrentCUDAStream();
unrolled_elementwise_kernel_for_multi_outputs<num_outputs, func_t, array_t><<<grid, num_threads, 0, stream>>>(N, f, data, ic, oc);
AT_CUDA_CHECK(cudaGetLastError());
}

template <typename func_t>
void gpu_kernel_multiple_outputs_impl(TensorIterator& iter, const func_t& f) {
using traits = function_traits<func_t>;
using output_t = typename traits::result_type;
static_assert(is_tuple<output_t>::value, "f's return type must be `thrust::tuple`");
constexpr int num_outputs = thrust::tuple_size<output_t>::value;
constexpr int num_inputs = traits::arity;
constexpr int ntensors = num_outputs + num_inputs;

TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
TORCH_INTERNAL_ASSERT(iter.ntensors() == ntensors);

at::detail::Array<char*, ntensors> data;
for (int i = 0; i < ntensors; i++) {
data[i] = (char*)iter.data_ptr(i);
}

int64_t numel = iter.numel();

if (iter.is_contiguous()) {
auto input_calc = TrivialOffsetCalculator<num_inputs>();
auto output_calc = TrivialOffsetCalculator<num_outputs>();
launch_unrolled_kernel_for_multi_outputs<num_outputs>(numel, f, data, input_calc, output_calc);
} else {
auto input_calc = make_input_offset_calculator<num_inputs>(iter);
auto output_calc = make_output_offset_calculator<num_outputs>(iter);
launch_unrolled_kernel_for_multi_outputs<num_outputs>(numel, f, data, input_calc, output_calc);
}
}
} // namespace

template <typename func_t>
void gpu_kernel_multiple_outputs(TensorIterator& iter, const func_t& f) {
ASSERT_HOST_DEVICE_LAMBDA(func_t);

for (int arg = 0; arg < iter.ntensors(); arg++) {
TORCH_INTERNAL_ASSERT(iter.device(arg).is_cuda());
}

if (iter.numel() == 0) {
return;
}

if (!iter.can_use_32bit_indexing()) {
for (auto& sub_iter : iter.with_32bit_indexing()) {
gpu_kernel_multiple_outputs(sub_iter, f);
}
return;
}

gpu_kernel_multiple_outputs_impl(iter, f);
}

}} //namespace at::native
