forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Dropout.cu
416 lines (380 loc) · 15.3 KB
/
Dropout.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <ATen/Dispatch.h>
#include <ATen/Utils.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <c10/macros/Macros.h>
#include <curand_kernel.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/MemoryAccess.cuh>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_masked_scale_native.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/native_dropout_backward_native.h>
#include <ATen/ops/ones_like.h>
#include <ATen/ops/zeros_like.h>
#endif
namespace at::native {
namespace {
// Number of elements each thread handles per iteration of the scalar dropout kernel.
// One curand_uniform4() call on the Philox state yields a float4, i.e. four uniform
// samples (128 bits of randomness), and the kernel consumes exactly one lane per
// unrolled element -- so this value must remain 4. The host also uses it to size
// the Philox counter offset. The vectorized kernel caps VEC at 4 for the same reason.
constexpr int UNROLL = 4;
// Vectorized dropout kernel: each thread loads VEC consecutive elements of `a` with a
// single vector load, keeps each one with probability `p` (p is the KEEP probability),
// writes `a * keep / p` into `b` and the 0/1 keep flag into `c`, both with vector stores.
//
// Preconditions (not checked here): a, b and c are non-overlapping and dense with the
// same memory layout, totalElements is a multiple of VEC, and the data pointers are
// aligned for VEC-wide accesses. Indexing is purely linear, so ADims is unused.
//
// RNG accounting: one curand_uniform4() call yields 4 samples. For VEC == 4 every
// iteration draws a fresh float4. For VEC == 2 one float4 serves two consecutive
// iterations: (x, y) on the first, (z, w) on the next. Either way, each thread
// draws no more samples than the host-side Philox counter offset (sized with
// UNROLL == 4) reserves.
template <
    typename scalar_t,
    typename accscalar_t,
    typename IndexType,
    int ADims,
    int VEC,
    typename mask_t>
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(256, 4)
#endif
__global__ void
fused_dropout_kernel_vec(at::cuda::detail::TensorInfo<const scalar_t, IndexType> a,
                         at::cuda::detail::TensorInfo<scalar_t, IndexType> b,
                         at::cuda::detail::TensorInfo<mask_t, IndexType> c,
                         IndexType totalElements, accscalar_t p,
                         PhiloxCudaState philox_args) {
  // A float4 of randoms covers at most 4 elements per thread, and the half-reuse
  // logic below is only correct for 2 or 4.
  static_assert(VEC == 2 || VEC == 4, "Value of VEC must be in [2, 4]");

  using LoadT = memory::aligned_vector<scalar_t, VEC>;
  using MaskLoadT = memory::aligned_vector<mask_t, VEC>;

  auto seeds = at::cuda::philox::unpack(philox_args);
  IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  curand_init(std::get<0>(seeds),
              idx,
              std::get<1>(seeds),
              &state);

  // VEC == 2 only: false -> draw a fresh float4 this iteration and use (x, y);
  // true -> reuse the (z, w) half drawn on the previous iteration.
  bool gridxvec_loop_state = false;
  accscalar_t scale = 1.0 / p;

  float4 rand;

  // Vectorized loads: each thread advances by an extra factor of VEC, since it handles
  // VEC elements per iteration.
  for (IndexType linearIndex = idx * VEC;
       linearIndex < totalElements;
       linearIndex += gridDim.x * blockDim.x * VEC) {
    // local storage
    scalar_t src[VEC];
    // Viewed as a single vector so the load below is one vectorized access.
    LoadT *value = reinterpret_cast<LoadT*>(&src);

    // Float samples are drawn for every dtype: curand has no half-precision variant,
    // and curand_uniform_double is deliberately avoided.
    if ((VEC == 4) || !gridxvec_loop_state) {
      rand = curand_uniform4(&state);
    } else {
      // Use the upper two samples generated on the previous iteration.
      rand.x = rand.z;
      rand.y = rand.w;
    }
    // The toggle must happen on EVERY iteration. If it only happened in the reuse
    // branch, the state would never leave `false`, the reuse branch would be
    // unreachable, and VEC == 2 would draw twice as many samples as the host reserved.
    gridxvec_loop_state = !gridxvec_loop_state;

    rand.x = rand.x < p;
    rand.y = rand.y < p;
    if (VEC == 4) {
      rand.z = rand.z < p;
      rand.w = rand.w < p;
    }

    // The host only launches this kernel for non-overlapping, dense tensors, so
    // linearIndex is used directly as a memory offset instead of IndexToOffset. This
    // also allows vectorizing NHWC (or other dense) orderings.
    // Single vectorized load
    *value = *reinterpret_cast<const LoadT*>(&a.data[linearIndex]);

    scalar_t r[VEC];
    mask_t mask[VEC];

    // Perform the actual computation
    #pragma unroll
    for (int ii = 0; ii < VEC; ii++) {
      r[ii] = src[ii]*(&rand.x)[ii]*scale;
      mask[ii] = (mask_t)(&rand.x)[ii];
    }
    // Vectorized writes for both mask & result
    *(reinterpret_cast<LoadT*>(&b.data[linearIndex])) = *reinterpret_cast<LoadT*>(&r[0]);
    *(reinterpret_cast<MaskLoadT*>(&c.data[linearIndex])) = *reinterpret_cast<MaskLoadT*>(&mask[0]);

    // NOTE(review): this kernel uses no shared memory. The loop trip count can also
    // differ between threads of a block when totalElements is not a multiple of the
    // grid stride, which makes this barrier conditional. It is kept to preserve the
    // existing behavior; consider removing it.
    __syncthreads();
  }
}
// Generic (non-vectorized) dropout kernel for arbitrarily strided tensors.
//
// With S = blockDim.x * gridDim.x, thread t handles linear indices
// base, base + S, base + 2S, base + 3S on each iteration (UNROLL of them), using one
// curand_uniform4() draw for those elements. Element k is kept when sample k < p, so
// p is the KEEP probability. Kept values are scaled by 1 / p and stored in `b`, and
// the 0/1 keep flag is stored in `c`. Linear indices become memory offsets through
// IndexToOffset<ADims> for `a` and IndexToOffset<BDims> for `b`. `c` is written at b's
// offset, so it must share b's layout.
template <
    typename scalar_t,
    typename accscalar_t,
    typename IndexType,
    int ADims,
    int BDims = ADims,
    typename mask_t>
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(256, 4)
#endif
__global__ void
fused_dropout_kernel(cuda::detail::TensorInfo<const scalar_t, IndexType> a,
                     cuda::detail::TensorInfo<scalar_t, IndexType> b,
                     cuda::detail::TensorInfo<mask_t, IndexType> c,
                     IndexType totalElements, accscalar_t p,
                     PhiloxCudaState philox_args) {
  // One Philox subsequence per global thread id; the seed and offset come from the host.
  auto seeds = at::cuda::philox::unpack(philox_args);
  const IndexType tid = blockIdx.x * blockDim.x + threadIdx.x;
  curandStatePhilox4_32_10_t rng;
  curand_init(std::get<0>(seeds), tid, std::get<1>(seeds), &rng);

  const accscalar_t scale = 1.0 / p;
  const IndexType grid_threads = blockDim.x * gridDim.x;
  const IndexType iter_stride = grid_threads * UNROLL;

  // Pad the loop bound up to a whole number of iterations so every thread runs the
  // same trip count (which keeps the barrier below uniform). The per-element bounds
  // checks skip the padded tail.
  const IndexType padded_total = ((totalElements - 1) / iter_stride + 1) * iter_stride;

  for (IndexType base = tid; base < padded_total; base += iter_stride) {
    // Float samples for every dtype: curand has no half variant, and
    // curand_uniform_double is deliberately avoided.
    float4 draw = curand_uniform4(&rng);
    float* keep = &draw.x;
    scalar_t src[UNROLL];

    // Turn each uniform sample into a 0/1 keep flag.
    #pragma unroll
    for (int k = 0; k < UNROLL; ++k) {
      keep[k] = keep[k] < p;
    }

    // Gather phase: read all in-bounds inputs first.
    #pragma unroll
    for (int k = 0; k < UNROLL; ++k) {
      const IndexType li = base + grid_threads * k;
      if (li < totalElements) {
        const IndexType in_off =
            cuda::detail::IndexToOffset<const scalar_t, IndexType, ADims>::get(li, a);
        src[k] = a.data[in_off];
      }
    }

    // Scatter phase: write scaled outputs and mask flags at b's offsets.
    #pragma unroll
    for (int k = 0; k < UNROLL; ++k) {
      const IndexType li = base + grid_threads * k;
      if (li >= totalElements) {
        continue;
      }
      const IndexType out_off =
          cuda::detail::IndexToOffset<scalar_t, IndexType, BDims>::get(li, b);
      b.data[out_off] = src[k] * keep[k] * scale;
      c.data[out_off] = (mask_t)keep[k];
    }

    __syncthreads();
  }
}
// Elementwise ret = mask * src * scale, computed with a TensorIterator GPU kernel.
// src and mask have different dtypes, so the same-dtype check is turned off. The
// mask value is converted to float before the multiplication.
template<typename mask_t, typename scalar_t, typename accscalar_t>
void masked_scale_kernel(at::Tensor& ret, const at::Tensor& src, const at::Tensor& mask, accscalar_t scale){
  at::TensorIteratorConfig config;
  config.check_all_same_dtype(false);
  config.add_output(ret);
  config.add_const_input(src);
  config.add_const_input(mask);
  auto iter = config.build();

  auto scale_by_mask = [=]GPU_LAMBDA(const scalar_t value, const mask_t keep) -> scalar_t {
    return static_cast<float>(keep) * value * scale;
  };
  at::native::gpu_kernel(iter, scale_by_mask);
}
// Picks the vector width (4, 2 or 1) used to launch fused_dropout_kernel_vec; 1 means
// "use the scalar kernel".
//
// Returns 1 unless self, ret and mask are all non-overlapping and dense. Otherwise it
// starts from the widest width allowed by the alignment of self's data pointer and
// caps it at 4, the widest width the vectorized kernel (static_assert VEC <= 4) and
// the launcher support. It then halves the width until every tensor's element count
// is divisible by it, so no remainder handling is needed.
//
// NOTE(review): only self's pointer alignment is checked; ret and mask are assumed to
// be freshly allocated and therefore at least as aligned.
template <typename scalar_t>
int get_vector_size(const at::Tensor& self, const at::Tensor& ret, const at::Tensor& mask) {
  if (!self.is_non_overlapping_and_dense() || !ret.is_non_overlapping_and_dense() || !mask.is_non_overlapping_and_dense()) {
    return 1;
  }

  int vec_size = memory::can_vectorize_up_to<scalar_t>((const char*)self.const_data_ptr());
  // Defensive clamp. A wider width reported here would otherwise match no kernel
  // instantiation. Alignment for a wider vector implies alignment for 4.
  if (vec_size > 4) {
    vec_size = 4;
  }

  // Prefer a smaller vector with no remainder over a larger vector with a remainder.
  while (vec_size > 1 &&
         (self.numel() % vec_size != 0 || ret.numel() % vec_size != 0 || mask.numel() % vec_size != 0)) {
    vec_size /= 2;
  }
  return vec_size;
}
// Host-side dispatcher for the fused dropout kernels.
//
// Dispatches on self's dtype (floating types plus Half and BFloat16) and builds
// collapsed TensorInfo views of self/ret/mask with `index_type` indexing. It then
// launches on the current CUDA stream, with the given grid/block and Philox state:
//   * fused_dropout_kernel_vec with VEC = 4 or 2 when get_vector_size picks 4 or 2;
//   * otherwise fused_dropout_kernel, specialized for a 1-D (collapsed) self, for a
//     strided self with contiguous ret/mask (BDims = 1), or fully generic (-1).
// Any other vector width (not expected) falls back to the scalar kernel, so the
// outputs are always written instead of silently being left uninitialized.
// p is the keep probability.
template <typename index_type, typename mask_t>
inline void launcher(
    const Tensor& self,
    Tensor& ret,
    Tensor& mask,
    double p,
    const int64_t nelem,
    const PhiloxCudaState rng_engine_inputs,
    dim3 grid,
    dim3 dim_block) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      self.scalar_type(),
      "fused_dropout",
      [&] {
        using accscalar_t = acc_type<scalar_t, true>;
        accscalar_t pa = (accscalar_t)(p);
        auto self_info =
            cuda::detail::getTensorInfo<const scalar_t, index_type>(self);
        auto ret_info =
            cuda::detail::getTensorInfo<scalar_t, index_type>(ret);
        auto mask_info =
            cuda::detail::getTensorInfo<mask_t, index_type>(mask);
        // Merge dimensions that can be traversed together (see TensorInfo::collapseDims).
        self_info.collapseDims();
        ret_info.collapseDims();
        mask_info.collapseDims();

        const int vec_size = get_vector_size<scalar_t>(self, ret, mask);
        if (vec_size == 4) {
          fused_dropout_kernel_vec<scalar_t, accscalar_t, index_type, 1, 4>
              <<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
                  self_info,
                  ret_info,
                  mask_info,
                  nelem,
                  pa,
                  rng_engine_inputs);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        } else if (vec_size == 2) {
          fused_dropout_kernel_vec<scalar_t, accscalar_t, index_type, 1, 2>
              <<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
                  self_info,
                  ret_info,
                  mask_info,
                  nelem,
                  pa,
                  rng_engine_inputs);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        } else if (self_info.dims == 1) {
          fused_dropout_kernel<scalar_t, accscalar_t, index_type, 1>
              <<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
                  self_info,
                  ret_info,
                  mask_info,
                  nelem,
                  pa,
                  rng_engine_inputs);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        } else if (!self.is_contiguous() && ret.is_contiguous() &&
                   mask.is_contiguous()) {
          // Strided input, contiguous outputs: index outputs as 1-D.
          fused_dropout_kernel<scalar_t, accscalar_t, index_type, -1, 1>
              <<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
                  self_info,
                  ret_info,
                  mask_info,
                  nelem,
                  pa,
                  rng_engine_inputs);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        } else {
          fused_dropout_kernel<scalar_t, accscalar_t, index_type, -1>
              <<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
                  self_info,
                  ret_info,
                  mask_info,
                  nelem,
                  pa,
                  rng_engine_inputs);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        }
      });
}
} //anonymous namespace
// Runs fused dropout on `self`. NOTE: here `p` is the KEEP probability; kept values
// are scaled by 1 / p. Returns (output, mask), where mask has dtype mask_t and is
// allocated with empty_like(self).
//
// The Philox offset reserved from `gen` is the number of samples each thread may
// draw: UNROLL samples per grid-wide pass over block_size * grid.x * UNROLL elements.
template <typename mask_t>
std::tuple<Tensor,Tensor>
dropout_cuda(CUDAGeneratorImpl* gen, const Tensor& self, double p){
  Tensor mask = at::empty_like(self, self.options().dtype(c10::CppTypeToScalarType<mask_t>::value));
  const int64_t nelem = self.numel();
  // Empty tensors should not get here, but return early just in case. With
  // nelem == 0 the grid would have zero blocks and the counter-offset computation
  // below would divide by zero (FPE).
  if (nelem==0) return std::tuple<Tensor,Tensor>(self.clone(), mask);

  Tensor ret = at::empty_like(self);
  const int64_t block_size = 256;
  const auto* props = at::cuda::getCurrentDeviceProperties();
  const int64_t blocks_per_sm = props->maxThreadsPerMultiProcessor / block_size;
  const int64_t max_blocks = static_cast<int64_t>(props->multiProcessorCount) * blocks_per_sm;
  // Compute the block count in 64-bit and only narrow to dim3's unsigned int after
  // clamping, so a huge nelem cannot wrap around to a tiny (or zero) grid.
  const int64_t needed_blocks = (nelem + block_size - 1) / block_size;
  dim3 dim_block(block_size);
  dim3 grid(static_cast<unsigned int>(std::min(max_blocks, needed_blocks)));
  // Number of random samples each thread may draw; advances the Philox counter.
  int64_t counter_offset = ((nelem - 1)/(block_size*grid.x*UNROLL)+1)*UNROLL;
  PhiloxCudaState rng_engine_inputs;
  {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen->mutex_);
    rng_engine_inputs = gen->philox_cuda_state(counter_offset);
  }
  if (cuda::detail::canUse32BitIndexMath(self)){
    launcher<unsigned int, mask_t>(
        self, ret, mask, p, nelem, rng_engine_inputs, grid, dim_block);
  } else {
    launcher<uint64_t, mask_t>(
        self, ret, mask, p, nelem, rng_engine_inputs, grid, dim_block);
  }
  return std::tuple<Tensor,Tensor>(ret, mask);
}
// CUDA forward of native_dropout. Here `p` is the probability of DROPPING an element,
// while dropout_cuda expects the keep probability, so 1 - p is forwarded.
// Returns (output, bool mask):
//   * train == false: a clone of self and an all-true mask (no RNG used);
//   * p == 1: an all-zero output and an all-false mask (no RNG used);
//   * otherwise: fused dropout using the default CUDA generator.
std::tuple<Tensor,Tensor>
native_dropout_cuda(const Tensor& self, double p, std::optional<bool> train){
  const auto bool_options = self.options().dtype(c10::CppTypeToScalarType<bool>::value);

  // Evaluation short-cut; an unspecified `train` counts as training.
  if (!train.value_or(true)) {
    return std::make_tuple(self.clone(), at::ones_like(self, bool_options));
  }

  if (p == 1) {
    // native_dropout_cuda is in derivatives.yaml, so no data dependency from output
    // to input has to be added for autograd.
    Tensor zeros_out = at::zeros_like(self);
    Tensor zeros_mask = at::zeros_like(self, bool_options);
    return std::tuple<Tensor,Tensor>(zeros_out, zeros_mask);
  }

  CUDAGeneratorImpl* generator = get_generator_or_default<CUDAGeneratorImpl>(
      std::nullopt, cuda::detail::getDefaultCUDAGenerator());
  const double keep_prob = 1. - p;
  return dropout_cuda<bool>(generator, self, keep_prob);
}
// TODO: _fused_dropout_cuda is to be removed, see PR #63937
// Legacy fused dropout entry point. Unlike native_dropout_cuda, `p` is handed to
// dropout_cuda unchanged (i.e. it acts as the keep probability), and the mask dtype
// is uint8. The generator comes from get_generator_or_default(gen_, default CUDA
// generator).
std::tuple<Tensor,Tensor>
fused_dropout_cuda(const Tensor& self, double p, std::optional<Generator> gen_){
  CUDAGeneratorImpl* const generator = get_generator_or_default<CUDAGeneratorImpl>(
      gen_, cuda::detail::getDefaultCUDAGenerator());
  return dropout_cuda<uint8_t>(generator, self, p);
}
// Shared dropout backward: returns grad * mask * scale, elementwise. The result is
// allocated like grad, using grad's suggested memory format. mask_t is the mask
// element type (bool for native dropout, uint8 for the legacy masked_scale path).
template <typename mask_t>
Tensor dropout_backward_cuda(const Tensor& grad, const Tensor& mask, double scale){
  Tensor grad_input = at::empty_like(grad, grad.suggest_memory_format());
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      grad_input.scalar_type(),
      "masked_scale",
      [&] {
        using accscalar_t = acc_type<scalar_t, true>;
        const accscalar_t acc_scale = static_cast<accscalar_t>(scale);
        masked_scale_kernel<mask_t, scalar_t>(grad_input, grad, mask, acc_scale);
      });
  return grad_input;
}
// CUDA backward of native_dropout: requires a bool mask and returns grad * mask * scale.
// The error message previously concatenated the offending dtype directly onto the text
// ("...Scalar TypeByte"); it now separates them.
Tensor native_dropout_backward_cuda(const Tensor& grad, const Tensor& mask, double scale){
  TORCH_CHECK(mask.scalar_type() == at::ScalarType::Bool,
              "Mask should be Bool Scalar Type, but got ", mask.scalar_type());
  return dropout_backward_cuda<bool>(grad, mask, scale);
}
// TODO: masked_scale_cuda is to be removed, see PR #63937
// Legacy backward for uint8 (Byte) masks: returns self * mask * scale.
Tensor masked_scale_cuda(const Tensor& self, const Tensor& mask, double scale){
  TORCH_CHECK(
      mask.scalar_type() == at::ScalarType::Byte,
      "mask should be torch.uint8 dtype");
  Tensor scaled = dropout_backward_cuda<uint8_t>(self, mask, scale);
  return scaled;
}
} // namespace at::native