forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathMKLDNNConversions.cpp
246 lines (220 loc) · 8.9 KB
/
MKLDNNConversions.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Config.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/mkldnn/MKLDNNCommon.h>
#include <ATen/native/mkldnn/Utils.h>
#include <ATen/native/utils/ParamUtils.h>
#include <torch/library.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_to_dense_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_native.h>
#include <ATen/ops/mkldnn_reorder_conv2d_weight_native.h>
#include <ATen/ops/mkldnn_reorder_conv3d_weight_native.h>
#include <ATen/ops/to_mkldnn_native.h>
#endif
namespace at { namespace native {
#if AT_MKLDNN_ENABLED()
// Converts an MKL-DNN tensor into a regular strided CPU tensor.
//
// Params:
//   mkldnn_tensor - source tensor; must hold float or bfloat16 data.
//   dtype         - optional target dtype (float or bfloat16). When absent, the
//                   source tensor's dtype is kept.
// Returns a contiguous strided CPU tensor with the same sizes as the source.
// Throws (via TORCH_CHECK) if the source dtype or requested dtype is unsupported.
Tensor mkldnn_to_dense(const Tensor& mkldnn_tensor, c10::optional<ScalarType> dtype) {
  TORCH_CHECK(mkldnn_tensor.scalar_type() == ScalarType::Float ||
              mkldnn_tensor.scalar_type() == ScalarType::BFloat16,
              "mkldnn_to_dense expects float or bfloat16 tensor input");
  ideep::tensor& stensor = itensor_from_mkldnn(mkldnn_tensor);
  auto dims = stensor.get_dims();
  auto data_type = dtype.has_value() ? dtype.value() : mkldnn_tensor.scalar_type();
  TORCH_CHECK(data_type == ScalarType::Float || data_type == ScalarType::BFloat16,
              "mkldnn tensor only can be converted to be a float or bfloat16 cpu tensor");
  // ideep dims are copied into an explicit int64_t vector because ATen sizes
  // are int64_t (NOTE(review): confirm ideep's dim integer type).
  Tensor cpu_tensor = at::empty(
      std::vector<int64_t>(dims.begin(), dims.end()),
      mkldnn_tensor.options().layout(c10::kStrided).dtype(data_type));
  // An empty source has nothing to copy.
  if (stensor.is_empty()) {
    return cpu_tensor;
  }
  // cpu_tensor's buffer is handed to to_public as the destination, together
  // with the requested element type.
  auto pub_tensor =
      data_type == ScalarType::Float
      ? stensor.to_public(cpu_tensor.template data_ptr<float>(),
                          ideep::tensor::data_type::f32)
      : stensor.to_public(cpu_tensor.template data_ptr<BFloat16>(),
                          ideep::tensor::data_type::bf16);
  // Adopt the strides reported by the public tensor, then normalize to a
  // contiguous result.
  cpu_tensor.as_strided_(dims, pub_tensor.get_strides());
  return cpu_tensor.contiguous();
}
// Converts a strided CPU tensor into an MKL-DNN tensor.
//
// Params:
//   cpu_tensor - strided CPU tensor holding float or bfloat16 data, with at most
//                5 dims. Non-contiguous input (including channels-last) is made
//                contiguous before conversion.
//   dtype      - optional target dtype (float or bfloat16); defaults to the
//                input tensor's dtype.
// Returns a newly allocated MKL-DNN tensor filled from `cpu_tensor`.
// Throws (via TORCH_CHECK) on any unsupported device, layout, dtype or rank.
Tensor dense_to_mkldnn(const Tensor& cpu_tensor, c10::optional<ScalarType> dtype) {
  TORCH_CHECK(cpu_tensor.device().is_cpu(),
             "dense_to_mkldnn expects CPU tensor input");
  TORCH_CHECK(cpu_tensor.layout() == Layout::Strided,
             "dense_to_mkldnn expects strided tensor input");
  TORCH_CHECK(cpu_tensor.scalar_type() == ScalarType::Float ||
              cpu_tensor.scalar_type() == ScalarType::BFloat16,
             "dense_to_mkldnn expects float or bfloat16 tensor input");
  TORCH_CHECK(cpu_tensor.dim() <= 5,
             "Can't convert cpu tensor with the number of dimensions > 5");
  auto data_type = dtype.has_value() ? dtype.value() : cpu_tensor.scalar_type();
  // Validate the target dtype before paying for a possible contiguous() copy.
  TORCH_CHECK(data_type == ScalarType::Float || data_type == ScalarType::BFloat16,
              "cpu tensor only can be converted to be a float or bfloat16 mkldnn tensor");
  // NOTE: forbid direct convert from non-contiguous (or channels last) to `ideep::tensor`.
  auto cpu_tensor_cont = cpu_tensor.contiguous();
  Tensor mkldnn_tensor = empty_mkldnn(cpu_tensor_cont.sizes(), data_type,
                                      cpu_tensor_cont.options().layout_opt(),
                                      cpu_tensor_cont.options().device_opt(),
                                      cpu_tensor_cont.options().pinned_memory_opt());
  ideep::tensor& dtensor = itensor_from_mkldnn(mkldnn_tensor);
  // The source pointer type follows the *input* dtype. feed_from is told the
  // source element type; presumably it converts to dtensor's type (`data_type`)
  // while copying — confirm against ideep.
  if (cpu_tensor.scalar_type() == ScalarType::Float) {
    dtensor.feed_from(dtensor.get_dims(),
                      ideep::tensor::data_type::f32,
                      cpu_tensor_cont.template data_ptr<float>());
  } else {
    dtensor.feed_from(dtensor.get_dims(),
                      ideep::tensor::data_type::bf16,
                      cpu_tensor_cont.template data_ptr<BFloat16>());
  }
  return mkldnn_tensor;
}
// Reorders a conv2d weight stored in an MKL-DNN tensor into the (non-public)
// layout that the ideep convolution kernel expects for these parameters.
//
// dense_to_mkldnn only produces tensors in the public format. If the weight is
// not already in the kernel's preferred format, the ideep conv kernel reorders
// it implicitly on every call. When this note was written, that on-the-fly
// reorder cost roughly 20% in performance.
//
// Params:
//   self       - MKL-DNN weight tensor; legacy 5-d grouped weights are accepted.
//   padding, stride, dilation, groups - convolution hyper-parameters.
//   input_size - optional input shape. When given, it is forwarded to ideep as
//                the source dims and the channels-last flag is set.
// Returns a new MKL-DNN tensor with the dtype and device of `self`.
Tensor mkldnn_reorder_conv2d_weight(
    const Tensor& self,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    c10::OptionalArrayRef<int64_t> input_size) {
  TORCH_CHECK(self.scalar_type() != ScalarType::BFloat16 || mkldnn_bf16_device_check(),
      "mkldnn_reorder_conv2d_weight: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq");

  // Work on a local copy of the ideep handle: the reshape below must not touch `self`.
  ideep::tensor src = itensor_from_mkldnn(self);

  // Legacy jitted conv2d modules may carry a 5-d weight [g, o/g, i, h, w] when
  // groups > 1. For backward compatibility, fold the two leading dims back into
  // one, giving [o, i, h, w].
  if (src.ndims() == 5) {
    const auto d = src.get_dims();
    src.reshape({d[0] * d[1], d[2], d[3], d[4]});
  }

  // A known input shape always selects the channels-last path.
  const bool is_channels_last = input_size.has_value();
  ideep::dims src_dims;
  if (is_channels_last) {
    src_dims = input_size.value().vec();
  }

  // The same padding is passed for both padding arguments (symmetric padding).
  auto desc = ideep::convolution_forward::expected_weights_desc(
      src.get_dims(),
      src.get_data_type(),
      {stride.begin(), stride.end()},
      {padding.begin(), padding.end()},
      {padding.begin(), padding.end()},
      {dilation.begin(), dilation.end()},
      groups,
      ideep::algorithm::convolution_direct,
      ideep::prop_kind::forward,
      src.get_data_type(),
      src_dims,
      ideep::attr_t(),
      is_channels_last);

  ideep::tensor reordered;
  reordered.init(desc);
  reordered.feed_from(src);

  const auto out_dtype = optTypeMetaToScalarType(self.options().dtype_opt());
  return new_with_itensor_mkldnn(std::move(reordered), out_dtype,
                                 self.options().device_opt());
}
// Reorders a conv3d weight stored in an MKL-DNN tensor into the layout that
// ideep's convolution kernel expects for these parameters. Presumably, as with
// conv2d, this avoids an implicit reorder on every call.
//
// Params:
//   self - MKL-DNN weight tensor.
//   padding, stride, dilation, groups - convolution hyper-parameters.
// Returns a new MKL-DNN tensor with the dtype and device of `self`.
Tensor mkldnn_reorder_conv3d_weight(
    const Tensor& self,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups) {
  TORCH_CHECK(self.scalar_type() != ScalarType::BFloat16 || mkldnn_bf16_device_check(),
      "mkldnn_reorder_conv3d_weight: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq");

  ideep::tensor src = itensor_from_mkldnn(self);

  // The same padding is passed for both padding arguments (symmetric padding).
  auto desc = ideep::convolution_forward::expected_weights_desc(
      src.get_dims(),
      src.get_data_type(),
      {stride.begin(), stride.end()},
      {padding.begin(), padding.end()},
      {padding.begin(), padding.end()},
      {dilation.begin(), dilation.end()},
      groups,
      ideep::algorithm::convolution_direct);

  ideep::tensor reordered;
  reordered.init(desc);
  reordered.feed_from(src);

  const auto out_dtype = optTypeMetaToScalarType(self.options().dtype_opt());
  return new_with_itensor_mkldnn(std::move(reordered), out_dtype,
                                 self.options().device_opt());
}
#else
// Stub for builds without MKL-DNN: unconditionally raises an error.
Tensor mkldnn_to_dense(
    const Tensor& mkldnn_tensor,
    c10::optional<ScalarType> dtype) {
  TORCH_CHECK(false, "MKL-DNN build is disabled");
}
// Stub for builds without MKL-DNN: unconditionally raises an error.
Tensor dense_to_mkldnn(
    const Tensor& cpu_tensor,
    c10::optional<ScalarType> dtype) {
  TORCH_CHECK(false, "MKL-DNN build is disabled");
}
// Stub for builds without MKL-DNN: unconditionally raises an error.
Tensor mkldnn_reorder_conv2d_weight(
    const Tensor& self,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    c10::OptionalArrayRef<int64_t> input_size) {
  TORCH_CHECK(false, "mkldnn_reorder_conv2d_weight: MKL-DNN build is disabled");
}
// Stub for builds without MKL-DNN: unconditionally raises an error.
Tensor mkldnn_reorder_conv3d_weight(
    const Tensor& self,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups) {
  TORCH_CHECK(false, "mkldnn_reorder_conv3d_weight: MKL-DNN build is disabled");
}
#endif // AT_MKLDNN_ENABLED()
#if AT_MKL_ENABLED() && AT_MKLDNN_ENABLED()
#include <mkl.h>
// Pre-packs a linear-layer weight with MKL's cblas_sgemm_pack, so that GEMMs
// with M == batch_size can reuse the packed "B" operand.
//
// Params:
//   weight     - 2-d float MKL-DNN tensor of shape [N, K] (presumably
//                [out_features, in_features]). It is packed transposed, with
//                leading dimension K.
//   batch_size - the GEMM M dimension the packed buffer is prepared for.
// Returns an MKL-DNN tensor of shape [pack_size, 1] whose storage holds MKL's
// opaque packed buffer (not a regular float matrix).
// Throws (via TORCH_CHECK) if `weight` is not float or not 2-d.
Tensor mkl_reorder_linear_weight(
    const Tensor& weight,
    const int64_t batch_size) {
  TORCH_CHECK(
      weight.scalar_type() == ScalarType::Float,
      "reorder_linear_weight: weight's dtype should be float");
  // Only size(0) and size(1) are read below. Any other rank would either be
  // packed with the wrong N/K or fail with an unrelated index error.
  TORCH_CHECK(
      weight.dim() == 2,
      "reorder_linear_weight: weight should be 2-d, but got ", weight.dim(), "-d");
  c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset);
  const auto M = batch_size;
  const auto N = weight.size(0);
  const auto K = weight.size(1);
  // cblas_sgemm_pack_get_size reports a size in bytes. Convert it to a float
  // count; the +1 guarantees enough room even if the size is not a multiple
  // of sizeof(float).
  const int64_t pack_size = static_cast<int64_t>(
      cblas_sgemm_pack_get_size(CblasBMatrix, M, N, K) / sizeof(float) + 1);
  auto packed_weight = empty_mkldnn(
      {pack_size, 1},
      weight.scalar_type(),
      weight.options().layout_opt(),
      weight.options().device_opt(),
      weight.options().pinned_memory_opt());
  ideep::tensor& mkl_weight = itensor_from_mkldnn(packed_weight);
  // NOTE(review): assumes the source ideep buffer is plain row-major [N, K];
  // confirm that weights reaching this op are never in a blocked layout.
  ideep::tensor& orig_w = itensor_from_mkldnn(weight);
  cblas_sgemm_pack(
      CblasRowMajor,
      CblasBMatrix,
      CblasTrans,
      M,
      N,
      K,
      1.0f,
      static_cast<float*>(orig_w.get_data_handle()),
      K,
      static_cast<float*>(mkl_weight.get_data_handle()));
  return packed_weight;
}
// Registers mkl_reorder_linear_weight as the MkldnnCPU kernel for the
// mkl::_mkl_reorder_linear_weight operator.
TORCH_LIBRARY_IMPL(mkl, MkldnnCPU, m) {
  m.impl(TORCH_SELECTIVE_NAME("mkl::_mkl_reorder_linear_weight"),
         TORCH_FN(mkl_reorder_linear_weight));
}
#endif // AT_MKL_ENABLED && AT_MKLDNN_ENABLED
}}