forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
GcdLcmKernel.cu
58 lines (52 loc) · 1.94 KB
/
GcdLcmKernel.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/cuda/JitLoops.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/Math.cuh>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/BinaryOps.h>
#include <ATen/native/cuda/jit_utils.h>
// NOTE: CUDA on Windows requires that the enclosing function
// of a __device__ lambda not have internal linkage.
namespace at::native {
// See note [Jiterator]
CONSTEXPR_EXCEPT_WIN_CUDA char gcd_name[] = "gcd";

// CUDA entry point for gcd.
//
// Dispatches on iter.common_dtype() across the integral types and hands the
// work to one of two back ends, chosen at compile time by AT_USE_JITERATOR():
//   * jitted_gpu_kernel, given gcd_name and gcd_string (gcd_string is declared
//     outside this file -- presumably the jiterator source text; confirm), or
//   * gpu_kernel, given a binary device lambda that forwards to calc_gcd.
//
// The dispatch macro is written out once per branch on purpose: a
// preprocessor conditional must not be placed inside a macro's argument list.
void gcd_kernel_cuda(TensorIteratorBase& iter) {
  const auto dtype = iter.common_dtype();
#if AT_USE_JITERATOR()
  AT_DISPATCH_INTEGRAL_TYPES(dtype, "gcd_cuda", [&]() {
    // Template arguments, in order: kernel name, return dtype, common dtype,
    // arity (two inputs).
    jitted_gpu_kernel<gcd_name, scalar_t, scalar_t, /*arity=*/2>(
        iter, gcd_string);
  });
#else
  AT_DISPATCH_INTEGRAL_TYPES(dtype, "gcd_cuda", [&]() {
    // Binary functor: both operands and the result share scalar_t.
    auto gcd_op = [] GPU_LAMBDA (scalar_t lhs, scalar_t rhs) -> scalar_t {
      return calc_gcd(lhs, rhs);
    };
    gpu_kernel(iter, gcd_op);
  });
#endif // AT_USE_JITERATOR()
}
// See note [Jiterator]
CONSTEXPR_EXCEPT_WIN_CUDA char lcm_name[] = "lcm";

// CUDA entry point for lcm.
//
// Dispatches on iter.common_dtype() across the integral types and hands the
// work to one of two back ends, chosen at compile time by AT_USE_JITERATOR():
//   * jitted_gpu_kernel, given lcm_name and lcm_string (lcm_string is declared
//     outside this file -- presumably the jiterator source text; confirm), or
//   * gpu_kernel, given a binary device lambda computing |a / gcd(a, b) * b|.
//
// The dispatch macro is written out once per branch on purpose: a
// preprocessor conditional must not be placed inside a macro's argument list.
void lcm_kernel_cuda(TensorIteratorBase& iter) {
  const auto dtype = iter.common_dtype();
#if AT_USE_JITERATOR()
  AT_DISPATCH_INTEGRAL_TYPES(dtype, "lcm_cuda", [&]() {
    // Template arguments, in order: kernel name, return dtype, common dtype,
    // arity (two inputs).
    jitted_gpu_kernel<lcm_name, scalar_t, scalar_t, /*arity=*/2>(
        iter, lcm_string);
  });
#else
  AT_DISPATCH_INTEGRAL_TYPES(dtype, "lcm_cuda", [&]() {
    auto lcm_op = [] GPU_LAMBDA (scalar_t lhs, scalar_t rhs) -> scalar_t {
      const scalar_t divisor = calc_gcd(lhs, rhs);
      // A zero gcd would make the division below divide by zero; the result
      // is defined as 0 in that case.
      if (divisor == 0) {
        return 0;
      }
      // Divide before multiplying, then take the absolute value.
      return ::abs(lhs / divisor * rhs);
    };
    gpu_kernel(iter, lcm_op);
  });
#endif // AT_USE_JITERATOR()
}
// Bind the two CUDA kernels defined in this file to their dispatch stubs.
// gcd_stub / lcm_stub are declared outside this file (presumably in
// ATen/native/BinaryOps.h -- confirm).
REGISTER_DISPATCH(gcd_stub, &gcd_kernel_cuda);
REGISTER_DISPATCH(lcm_stub, &lcm_kernel_cuda);
} // namespace at::native