forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathDeviceType.h
51 lines (41 loc) · 1.39 KB
/
DeviceType.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
#pragma once
// This is directly synchronized with caffe2/proto/caffe2.proto, but
// doesn't require me to figure out how to get Protobuf headers into
// ATen/core (which would require a lot more build system hacking.)
// If you modify me, keep me synchronized with that file.
#include <c10/macros/Macros.h>

#include <cstddef>
#include <cstdint>
#include <functional>
#include <ostream>
#include <string>
namespace c10 {
// Identifies a kind of device. The underlying type is int16_t. Per the note
// at the top of this file, the numeric values are kept synchronized with
// caffe2/proto/caffe2.proto.
enum class DeviceType : int16_t {
  CPU = 0,
  CUDA = 1, // CUDA
  MKLDNN = 2, // reserved for explicit MKLDNN
  OPENGL = 3, // OpenGL
  OPENCL = 4, // OpenCL
  IDEEP = 5, // IDEEP
  HIP = 6, // AMD HIP
  FPGA = 7, // FPGA

  // NB: when adding another device type:
  //   - update the implementations of DeviceTypeName and isValidDeviceType
  //     in DeviceType.cpp
  //   - bump the count on the next line
  COMPILE_TIME_MAX_DEVICE_TYPES = 8,

  ONLY_FOR_TEST = 20901, // this device type exists only for tests
};
// Explicit int constant carrying the same value as the enumerator
// DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES.
constexpr int COMPILE_TIME_MAX_DEVICE_TYPES{
    static_cast<int>(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)};
// Returns a std::string for device type `d`. Only the declaration is visible
// here; the implementation lives in DeviceType.cpp.
// NOTE(review): `lower_case` (default: false) presumably selects a lower-case
// spelling of the returned name -- confirm against DeviceType.cpp.
C10_API std::string DeviceTypeName(
DeviceType d,
bool lower_case = false);
// Reports whether `d` is a valid device type. Implemented in DeviceType.cpp.
// NOTE(review): which values count as valid is not visible from this header
// -- confirm against the implementation.
C10_API bool isValidDeviceType(DeviceType d);
// Stream-insertion operator for DeviceType; the definition is not in this
// header.
// NOTE(review): presumably writes the DeviceTypeName text for `type` and
// returns `stream` -- confirm against the definition.
C10_API std::ostream& operator<<(std::ostream& stream, DeviceType type);
} // namespace c10
namespace std {
// std::hash specialization for c10::DeviceType: the enumerator is converted
// to int and hashed with std::hash<int>.
template <>
struct hash<c10::DeviceType> {
  std::size_t operator()(c10::DeviceType device_type) const {
    const int as_int = static_cast<int>(device_type);
    return std::hash<int>{}(as_int);
  }
};
} // namespace std