forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
DLConvertor.cpp
295 lines (283 loc) · 8.62 KB
/
DLConvertor.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
#include <ATen/DLConvertor.h>
#include <ATen/Functions.h>
#include <iostream>
#include <memory>
#include <sstream>
using namespace std;
namespace at {
// Describes the element type of `t` as a DLPack DLDataType.
//
// `lanes` is always 1 and `bits` is the element size of `t` in bits; `code`
// is chosen from the tensor's ScalarType. Scalar types that have no mapping
// here (Bool, quantized, bit-packed, Undefined, NumOptions) raise through
// TORCH_CHECK.
DLDataType getDLDataType(const Tensor& t) {
  DLDataType dtype;
  dtype.lanes = 1;
  dtype.bits = t.element_size() * 8;

  switch (t.scalar_type()) {
    // Unsigned integers.
    case ScalarType::Byte:
      dtype.code = DLDataTypeCode::kDLUInt;
      break;

    // Signed integers of every width share one code; `bits` tells them apart.
    case ScalarType::Char:
    case ScalarType::Short:
    case ScalarType::Int:
    case ScalarType::Long:
      dtype.code = DLDataTypeCode::kDLInt;
      break;

    // IEEE floating point.
    case ScalarType::Half:
    case ScalarType::Float:
    case ScalarType::Double:
      dtype.code = DLDataTypeCode::kDLFloat;
      break;

    case ScalarType::BFloat16:
      dtype.code = DLDataTypeCode::kDLBfloat;
      break;

    // Complex numbers.
    case ScalarType::ComplexHalf:
    case ScalarType::ComplexFloat:
    case ScalarType::ComplexDouble:
      dtype.code = DLDataTypeCode::kDLComplex;
      break;

    // Everything below is rejected.
    case ScalarType::Bool:
      TORCH_CHECK(false, "Bool type is not supported by dlpack");
      break;
    case ScalarType::QInt8:
    case ScalarType::QUInt8:
    case ScalarType::QInt32:
    case ScalarType::QUInt4x2:
    case ScalarType::QUInt2x4:
      TORCH_CHECK(false, "QUInt/QInt types are not supported by dlpack");
      break;
    case ScalarType::Bits1x8:
    case ScalarType::Bits2x4:
    case ScalarType::Bits4x2:
    case ScalarType::Bits8:
    case ScalarType::Bits16:
      TORCH_CHECK(false, "Bit types are not supported by dlpack");
      break;
    case ScalarType::Undefined:
      TORCH_CHECK(false, "Undefined is not a valid ScalarType");
    case ScalarType::NumOptions:
      TORCH_CHECK(false, "NumOptions is not a valid ScalarType");
  }
  return dtype;
}
// Translates the device of `tensor` into a DLPack DLDevice.
//
// `device_id` is stored as the DLPack device id. CPU, CUDA, OpenCL and HIP
// map to a fixed DLPack device type; XPU delegates to the XPU hooks, which
// receive the partially filled DLDevice and the tensor's data pointer. Any
// other device type raises through TORCH_CHECK.
static DLDevice getDLDevice(const Tensor& tensor, const int64_t& device_id) {
  const auto source_device = tensor.device();

  DLDevice dl_device;
  dl_device.device_id = device_id;

  switch (source_device.type()) {
    case DeviceType::CPU:
      dl_device.device_type = DLDeviceType::kDLCPU;
      break;
    case DeviceType::CUDA:
#ifdef USE_ROCM
      // ROCM, if enabled will look like cuda to PyTorch
      // while everyone else should see HIP
      dl_device.device_type = DLDeviceType::kDLROCM;
#else
      dl_device.device_type = DLDeviceType::kDLCUDA;
#endif
      break;
    case DeviceType::OPENCL:
      dl_device.device_type = DLDeviceType::kDLOpenCL;
      break;
    case DeviceType::HIP:
      dl_device.device_type = DLDeviceType::kDLROCM;
      break;
    case DeviceType::XPU:
      dl_device = at::detail::getXPUHooks().getDLPackDeviceFromATenDevice(
          dl_device, source_device, tensor.data_ptr());
      break;
    default:
      TORCH_CHECK(false, "Cannot pack tensors on " + source_device.str());
  }
  return dl_device;
}
// Translates a DLPack device descriptor back into an ATen Device.
//
// `data` is only consulted for oneAPI devices, where it is forwarded to the
// XPU hooks. Device types without a case below raise through TORCH_CHECK.
static Device getATenDevice(const DLDevice& ctx, void* data) {
#ifdef USE_ROCM
  // this looks funny, we need to return CUDA here to masquerade
  constexpr DeviceType rocm_as_aten = DeviceType::CUDA;
#else
  constexpr DeviceType rocm_as_aten = DeviceType::HIP;
#endif
  const auto index = ctx.device_id;

  switch (ctx.device_type) {
    case DLDeviceType::kDLCPU:
      return at::Device(DeviceType::CPU);
#ifndef USE_ROCM
    // if we are compiled under HIP, we cannot do cuda
    case DLDeviceType::kDLCUDA:
      return at::Device(DeviceType::CUDA, index);
#endif
    case DLDeviceType::kDLOpenCL:
      return at::Device(DeviceType::OPENCL, index);
    case DLDeviceType::kDLROCM:
      return at::Device(rocm_as_aten, index);
    case DLDeviceType::kDLOneAPI:
      return at::detail::getXPUHooks().getATenDeviceFromDLPackDevice(ctx, data);
    default:
      TORCH_CHECK(
          false, "Unsupported device_type: " + c10::to_string(ctx.device_type));
  }
}
// Maps a DLPack DLDataType onto the matching ATen ScalarType.
//
// Only single-lane types are accepted. The (code, bits) pairs handled are:
//   kDLUInt    : 8             -> Byte
//   kDLInt     : 8/16/32/64    -> Char/Short/Int/Long
//   kDLFloat   : 16/32/64      -> Half/Float/Double
//   kDLBfloat  : 16            -> BFloat16
//   kDLComplex : 32/64/128     -> ComplexHalf/ComplexFloat/ComplexDouble
// Any other combination raises through TORCH_CHECK, with a message naming
// the type code whose bit width was rejected.
ScalarType toScalarType(const DLDataType& dtype) {
  ScalarType stype;
  TORCH_CHECK(dtype.lanes == 1, "ATen does not support lanes != 1");
  switch (dtype.code) {
    case DLDataTypeCode::kDLUInt:
      switch (dtype.bits) {
        case 8:
          stype = ScalarType::Byte;
          break;
        default:
          TORCH_CHECK(
              false, "Unsupported kUInt bits " + c10::to_string(dtype.bits));
      }
      break;
    case DLDataTypeCode::kDLInt:
      switch (dtype.bits) {
        case 8:
          stype = ScalarType::Char;
          break;
        case 16:
          stype = ScalarType::Short;
          break;
        case 32:
          stype = ScalarType::Int;
          break;
        case 64:
          stype = ScalarType::Long;
          break;
        default:
          TORCH_CHECK(
              false, "Unsupported kInt bits " + c10::to_string(dtype.bits));
      }
      break;
    case DLDataTypeCode::kDLFloat:
      switch (dtype.bits) {
        case 16:
          stype = ScalarType::Half;
          break;
        case 32:
          stype = ScalarType::Float;
          break;
        case 64:
          stype = ScalarType::Double;
          break;
        default:
          TORCH_CHECK(
              false, "Unsupported kFloat bits " + c10::to_string(dtype.bits));
      }
      break;
    case DLDataTypeCode::kDLBfloat:
      switch (dtype.bits) {
        case 16:
          stype = ScalarType::BFloat16;
          break;
        default:
          // Name the bfloat code here rather than kFloat, so the error
          // identifies the type code that was actually rejected.
          TORCH_CHECK(
              false, "Unsupported kBfloat bits " + c10::to_string(dtype.bits));
      }
      break;
    case DLDataTypeCode::kDLComplex:
      switch (dtype.bits) {
        case 32:
          stype = ScalarType::ComplexHalf;
          break;
        case 64:
          stype = ScalarType::ComplexFloat;
          break;
        case 128:
          stype = ScalarType::ComplexDouble;
          break;
        default:
          // Likewise, report the complex code instead of kFloat.
          TORCH_CHECK(
              false, "Unsupported kComplex bits " + c10::to_string(dtype.bits));
      }
      break;
    default:
      TORCH_CHECK(
          false, "Unsupported code " + c10::to_string(dtype.code));
  }
  return stype;
}
// Heap-allocated bundle backing every DLManagedTensor produced by toDLPack().
// toDLPack() stores a pointer to this struct in `tensor.manager_ctx`, and the
// file-local deleter() casts that pointer back and deletes the whole bundle.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct ATenDLMTensor {
// The ATen tensor (a view of the exported source) held for as long as the
// bundle exists.
Tensor handle;
// The DLPack descriptor handed to the consumer; toDLPack() returns its
// address.
DLManagedTensor tensor;
};
// Deleter callback installed on tensors exported by toDLPack(): recovers the
// owning ATenDLMTensor from `manager_ctx` and destroys it.
static void deleter(DLManagedTensor* arg) {
  auto* owner = static_cast<ATenDLMTensor*>(arg->manager_ctx);
  delete owner;
}
// Exports an ATen tensor as a DLPack managed tensor.
//
// Returns a raw pointer to a heap-allocated DLManagedTensor. Ownership passes
// to the caller, who must eventually invoke the tensor's `deleter` callback
// to free it. The allocation also holds a Tensor handle to a view of `src`,
// and the exported `shape`/`strides` pointers refer to that view's metadata.
//
// Raises (via TORCH_CHECK in getDLDevice / getDLDataType) when the device or
// scalar type of `src` cannot be represented; nothing is leaked in that case.
DLManagedTensor* toDLPack(const Tensor& src) {
  // create a new tensor with possibly normalized strides
  // gh-83069
  auto shape = src.sizes();
  auto strides = src.strides().vec();
  const int64_t ndim = src.dim();
  for (int64_t i = 0; i < ndim; ++i) {
    // Dimensions of size 0 or 1 get a stride of 1.
    if (shape[i] < 2) {
      strides[i] = 1;
    }
  }
  auto view = src.as_strided(shape, strides, src.storage_offset());

  // Hold the allocation in a unique_ptr until it is fully populated, so an
  // exception from the conversions below cannot leak it.
  auto atDLMTensor = std::make_unique<ATenDLMTensor>();
  atDLMTensor->handle = view;
  atDLMTensor->tensor.manager_ctx = atDLMTensor.get();
  atDLMTensor->tensor.deleter = &deleter;
  atDLMTensor->tensor.dl_tensor.data = view.data_ptr();

  // Only CUDA tensors report their device index; everything else uses 0.
  int64_t device_id = 0;
  if (src.is_cuda()) {
    device_id = src.get_device();
  }
  atDLMTensor->tensor.dl_tensor.device = getDLDevice(src, device_id);
  atDLMTensor->tensor.dl_tensor.ndim = src.dim();
  atDLMTensor->tensor.dl_tensor.dtype = getDLDataType(src);
  atDLMTensor->tensor.dl_tensor.shape =
      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
      const_cast<int64_t*>(view.sizes().data());
  atDLMTensor->tensor.dl_tensor.strides =
      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
      const_cast<int64_t*>(view.strides().data());
  atDLMTensor->tensor.dl_tensor.byte_offset = 0;

  // Hand ownership to the caller; deleter() reclaims it through manager_ctx.
  return &(atDLMTensor.release()->tensor);
}
// Wraps a DLPack managed tensor as an ATen tensor, using a deleter that calls
// the producer-supplied `src->deleter` on `src` (when one is set).
Tensor fromDLPack(const DLManagedTensor* src) {
  auto release_source = [src](void* /*self*/) {
    if (!src->deleter) {
      return;
    }
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
    src->deleter(const_cast<DLManagedTensor*>(src));
  };
  return fromDLPack(src, std::move(release_source));
}
// Wraps the memory described by `src` as an ATen tensor without copying,
// passing `deleter` through to at::from_blob.
//
// Device and dtype are translated from the DLPack descriptor (either
// translation may raise through TORCH_CHECK). When the descriptor carries no
// strides, the overload of at::from_blob without a strides argument is used.
Tensor fromDLPack(
    const DLManagedTensor* src,
    std::function<void(void*)> deleter) {
  const auto& dl = src->dl_tensor;

  Device device = getATenDevice(dl.device, dl.data);
  ScalarType stype = toScalarType(dl.dtype);

  const IntArrayRef sizes(dl.shape, dl.ndim);
  const auto options = at::device(device).dtype(stype);

  if (dl.strides == nullptr) {
    return at::from_blob(dl.data, sizes, deleter, options, {device});
  }
  const IntArrayRef strides(dl.strides, dl.ndim);
  return at::from_blob(dl.data, sizes, strides, deleter, options, {device});
}
} // namespace at