DistributionBernoulli.cu
#include <ATen/Dispatch.h>
#include <ATen/ExpandUtils.h>
#include <ATen/NativeFunctions.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/AccumulateType.h>
#include <ATen/CUDAGenerator.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/native/cuda/DistributionTemplates.h>

#include <curand.h>
#include <curand_kernel.h>
#include <curand_philox4x32_x.h>

#include <ATen/native/Distributions.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/TensorIterator.h>
#include <ATen/LegacyTHFunctionsCUDA.h>

#include <THC/THCGeneral.h>
#include <THC/THCApply.cuh>
#include <THC/THCDeviceUtils.cuh>

#include <cstdint>
#include <limits>
#include <utility>
#include <functional>
#include <type_traits>
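
// This file implements the CUDA backend for Tensor.bernoulli_: a
// tensor-probability path built on CUDA_tensor_apply2, and a
// scalar-probability path built on the distribution_nullary_kernel template.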
namespace {
template<typename scalar_t, typename prob_t>
void bernoulli_tensor_cuda_kernel(
at::Tensor& ret, const at::Tensor& p,
std::pair<uint64_t, uint64_t> seeds) {
  // The template argument `4` below indicates that we want to operate on four
  // elements at a time. See NOTE [ CUDA_tensor_applyN helpers ] for details.
at::cuda::CUDA_tensor_apply2<scalar_t, prob_t, 4>(
ret, p,
[seeds] __device__(
int n, scalar_t& v1, scalar_t& v2, scalar_t& v3, scalar_t& v4,
const prob_t& p1, const prob_t& p2, const prob_t& p3, const prob_t& p4) {
curandStatePhilox4_32_10_t state;
curand_init(
seeds.first,
blockIdx.x * blockDim.x + threadIdx.x,
seeds.second,
&state);
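        // Each thread seeds Philox with the shared seed (seeds.first), its own
        // global thread index as the subsequence, and the generator offset
        // (seeds.second), so threads draw from disjoint random streams.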
// See Note [Register spilling in curand call for CUDA < 10]
float4 rand = curand_uniform4(&state);
switch (n) {
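          // `n` is the number of valid elements in this group of four; the
          // cases intentionally fall through so that a partial tail writes
          // only its live lanes.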
case 4: {
CUDA_KERNEL_ASSERT(0 <= p4 && p4 <= 1);
v4 = static_cast<scalar_t>(rand.w <= p4);
// fallthrough
}
case 3: {
CUDA_KERNEL_ASSERT(0 <= p3 && p3 <= 1);
v3 = static_cast<scalar_t>(rand.z <= p3);
// fallthrough
}
case 2: {
CUDA_KERNEL_ASSERT(0 <= p2 && p2 <= 1);
v2 = static_cast<scalar_t>(rand.y <= p2);
// fallthrough
}
case 1: {
CUDA_KERNEL_ASSERT(0 <= p1 && p1 <= 1);
v1 = static_cast<scalar_t>(rand.x <= p1);
}
}
}
);
}
} // namespace
namespace at { namespace native {
Tensor& bernoulli_tensor_cuda_(Tensor &self, const Tensor& p_, Generator gen_) {
NoNamesGuard guard;
auto gen = get_generator_or_default<CUDAGenerator>(gen_, cuda::detail::getDefaultCUDAGenerator());
std::pair<uint64_t, uint64_t> rng_engine_inputs;
{
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
rng_engine_inputs = gen->philox_engine_inputs(10);
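    // philox_engine_inputs returns a (seed, offset) pair and advances the
    // generator's offset, so later kernels consume a disjoint counter range.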
}
auto p = std::get<0>(expand_inplace(self, p_.to(kCUDA)));
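  // Move the probability tensor to CUDA and broadcast it to self's shape so
  // the apply above can walk both tensors elementwise.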
AT_DISPATCH_ALL_TYPES_AND3(
at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, self.scalar_type(), "bernoulli_tensor_cuda_self_", [&] {
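      // The inner dispatch below rebinds scalar_t to p's dtype, so alias the
      // outer binding (self's dtype) first.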
using self_t = scalar_t;
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, p.scalar_type(), "bernoulli_tensor_cuda_p_", [&] {
using p_t = scalar_t;
return bernoulli_tensor_cuda_kernel<self_t, p_t>(self, p, rng_engine_inputs);
});
});
return self;
}
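
// Illustrative call path (a sketch, not from the original file; the shapes
// and values are made up): this function backs the tensor-probability
// overload of bernoulli_ on CUDA, reached via e.g.
//   at::Tensor self = at::empty({8}, at::device(at::kCUDA).dtype(at::kFloat));
//   at::Tensor p    = at::full({8}, 0.5, at::device(at::kCUDA).dtype(at::kFloat));
//   self.bernoulli_(p);  // element i becomes 1 with probability p[i]
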
void bernoulli_scalar_cuda_kernel(TensorIterator& iter, double p_, Generator gen_) {
auto gen = get_generator_or_default<CUDAGenerator>(gen_, cuda::detail::getDefaultCUDAGenerator());
AT_DISPATCH_ALL_TYPES_AND3(
at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, iter.dtype(), "bernoulli_scalar_cuda_", [&] {
if (std::is_same<scalar_t, double>::value) {
// define lambda for bernoulli transformation
auto bernoulli_func = [p_] __device__ (double rand) {
return static_cast<scalar_t>(rand <= p_);
};
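        // curand_uniform2_double yields two doubles per call versus four
        // floats from curand_uniform4, so the unroll factor is halved here.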
distribution_nullary_kernel<scalar_t, double, curand4_engine_calls/2>(iter,
gen,
[] __device__ (curandStatePhilox4_32_10_t* state) { return curand_uniform2_double(state); },
bernoulli_func);
} else {
auto p = static_cast<float>(p_);
auto bernoulli_func = [p] __device__ (float rand) {
return static_cast<scalar_t>(rand <= p);
};
distribution_nullary_kernel<scalar_t, float, curand4_engine_calls>(iter,
gen,
[] __device__ (curandStatePhilox4_32_10_t* state) { return curand_uniform4(state); },
bernoulli_func);
}
});
}
Tensor& bernoulli_scalar_cuda_(Tensor &self, double p, Generator gen) {
TORCH_CHECK(0 <= p && p <= 1, "bernoulli_ expects p to be in [0, 1], but got p=", p);
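  // The scalar p is validated once on the host here; the tensor variant above
  // instead checks each probability on the device via CUDA_KERNEL_ASSERT.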
auto iter = TensorIterator::nullary_op(self);
bernoulli_scalar_cuda_kernel(iter, p, gen);
return self;
}
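
// Illustrative call path (a sketch, assuming standard ATen dispatch): a
// scalar-probability call such as
//   self.bernoulli_(0.25);
// on a CUDA tensor lands here through the nullary TensorIterator above.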
}} // namespace at::native