forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbatch_permutation_op.cc
171 lines (150 loc) · 4.4 KB
/
batch_permutation_op.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
#include "caffe2/operators/batch_permutation_op.h"
#include <cstdint>
#include <cstring>
#include <vector>
#ifdef CAFFE2_USE_MKLDNN
#include <caffe2/ideep/operators/operator_fallback_ideep.h>
#include <caffe2/ideep/utils/ideep_operator.h>
#endif
namespace caffe2 {
// Copies whole rows of a row-major [N, K] float buffer according to `indices`.
//
//   forwards == true  (gather):  dst[n, :] = src[indices[n], :]
//   forwards == false (inverse): dst[n, :] = src[j, :] where indices[j] == n,
//                                i.e. applies the inverse permutation.
//
// Preconditions (NOT checked here; callers must validate):
//   * every indices[i] lies in [0, N);
//   * src and dst each hold N * K floats and do not overlap.
// For forwards == false, `indices` is expected to be a permutation of [0, N).
// If it is not, a row n that never appears in `indices` is filled from src
// row 0, and for duplicated entries the last occurrence wins.
//
// Row offsets are computed in 64 bits so that N * K > INT_MAX does not overflow.
template <bool forwards>
void batch_permutation_loop(
    const int N,
    const int K,
    const float* src,
    const int* indices,
    float* dst) {
  // Bytes per row. size_t avoids truncation on platforms where `long` is 32-bit.
  const size_t numBytes = static_cast<size_t>(K) * sizeof(float);
  if (forwards) {
#ifdef _OPENMP
#if (_OPENMP >= 201307)
#pragma omp parallel for simd
#else
#pragma omp parallel for
#endif
#endif
    for (int n = 0; n < N; n++) {
      const int64_t dstOffset = static_cast<int64_t>(n) * K;
      const int64_t srcOffset = static_cast<int64_t>(indices[n]) * K;
      std::memcpy(dst + dstOffset, src + srcOffset, numBytes);
    }
  } else {
    // Build the inverse permutation: backward_indices[indices[i]] = i.
    std::vector<int> backward_indices(N);
    for (int i = 0; i < N; ++i) {
      backward_indices[indices[i]] = i;
    }
    for (int n = 0; n < N; n++) {
      const int64_t dstOffset = static_cast<int64_t>(n) * K;
      const int64_t srcOffset = static_cast<int64_t>(backward_indices[n]) * K;
      std::memcpy(dst + dstOffset, src + srcOffset, numBytes);
    }
  }
}
// CPU float forward pass: Y[n, ...] = X[indices[n], ...] for n in [0, N), where
// N = X.dim32(0). Y is allocated with X's shape.
// Enforces that `indices` is 1-D, has length N, and holds only values in
// [0, N). batch_permutation_loop does no bounds checking, so an out-of-range
// index would otherwise read outside X's buffer.
template <>
bool BatchPermutationOp<float, CPUContext>::RunOnDevice() {
  auto& X = Input(0);
  auto& indices = Input(1);

  CAFFE_ENFORCE(indices.dim() == 1, "indices must be 1-d");
  CAFFE_ENFORCE(
      X.dim32(0) == indices.dim32(0),
      "X.dim32(0) must be equal to indices.dim32(0)",
      "(",
      X.dim32(0),
      " vs. ",
      indices.dim32(0),
      ")");

  auto* Y = Output(0, X.sizes(), at::dtype<float>());

  const int N = X.dim32(0);
  if (N > 0) {
    const int* indicesData = indices.data<int>();
    for (int n = 0; n < N; ++n) {
      CAFFE_ENFORCE(
          indicesData[n] >= 0 && indicesData[n] < N,
          "indices[",
          n,
          "] = ",
          indicesData[n],
          " is out of range [0, ",
          N,
          ")");
    }
    batch_permutation_loop<true>(
        N,
        X.numel() / N,
        X.data<float>(),
        indicesData,
        Y->mutable_data<float>());
  }
  return true;
}
// CPU float backward pass. Inputs are (indices, dY); output dX has dY's shape.
// Applies the inverse of the forward gather: dX[indices[j], ...] = dY[j, ...].
// This equals the true gradient when `indices` is a permutation of [0, N).
//
// Validation mirrors the forward op: `indices` must be 1-D with length
// N = dY.dim32(0), and every value must lie in [0, N). Without these checks a
// short `indices` is read out of bounds, and an out-of-range value writes
// outside the inverse-permutation buffer.
// NOTE(review): duplicates are not rejected. For non-permutation indices the
// result is not the accumulated gradient.
template <>
bool BatchPermutationGradientOp<float, CPUContext>::RunOnDevice() {
  auto& indices = Input(0);
  auto& dY = Input(1);

  CAFFE_ENFORCE(indices.dim() == 1, "indices must be 1-d");
  CAFFE_ENFORCE(
      dY.dim32(0) == indices.dim32(0),
      "dY.dim32(0) must be equal to indices.dim32(0)",
      "(",
      dY.dim32(0),
      " vs. ",
      indices.dim32(0),
      ")");

  auto* dX = Output(0, dY.sizes(), at::dtype<float>());

  const int N = dY.dim32(0);
  if (N > 0) {
    const int* indicesData = indices.data<int>();
    for (int n = 0; n < N; ++n) {
      CAFFE_ENFORCE(
          indicesData[n] >= 0 && indicesData[n] < N,
          "indices[",
          n,
          "] = ",
          indicesData[n],
          " is out of range [0, ",
          N,
          ")");
    }
    batch_permutation_loop<false>(
        N,
        dY.numel() / N,
        dY.data<float>(),
        indicesData,
        dX->mutable_data<float>());
  }
  return true;
}
// When built with MKL-DNN, the IDEEP device has no native kernel for this op;
// it is registered through IDEEPFallbackOp wrapping the CPU float
// implementation below (see operator_fallback_ideep.h for how tensors are
// moved between devices).
#ifdef CAFFE2_USE_MKLDNN
REGISTER_IDEEP_OPERATOR(
BatchPermutation,
IDEEPFallbackOp<BatchPermutationOp<float, CPUContext>>);
#endif
// CPU float kernels for the forward op and its gradient.
REGISTER_CPU_OPERATOR(BatchPermutation, BatchPermutationOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(
BatchPermutationGradient,
BatchPermutationGradientOp<float, CPUContext>);
// Input: X, indices; Output: Y
OPERATOR_SCHEMA(BatchPermutation)
    .NumInputs(2)
    .NumOutputs(1)
    .SetDoc(R"DOC(
Batch permutation of an input tensor X given input indices. The first dimension
of X is the batch size N, and indices is a 1-D int32 tensor of length N that
holds a permutation of [0, N).
The output Y is a tensor of the same shape as X, with its rows re-ordered
within the batch so that Y[n] = X[indices[n]].
Example of batch permutation on a 2-D tensor with batch size 4:
X = [
  [1, 5, 2, 3, 4, 6, 0],
  [4, 3, 3, 5, 2, 3, 1],
  [2, 2, 3, 6, 0, 0, 1],
  [0, 0, 1, 1, 2, 2, 3]
]
indices = [2, 0, 1, 3]
Y = [
  [2, 2, 3, 6, 0, 0, 1],
  [1, 5, 2, 3, 4, 6, 0],
  [4, 3, 3, 5, 2, 3, 1],
  [0, 0, 1, 1, 2, 2, 3]
]
Example of batch permutation on a 3-D tensor (shape 4 x 2 x 3) with batch
size 4:
X = [
  [[1, 5, 2], [3, 4, 6]],
  [[4, 3, 3], [5, 2, 3]],
  [[2, 2, 3], [6, 0, 0]],
  [[0, 0, 1], [1, 2, 2]]
]
indices = [2, 0, 1, 3]
Y = [
  [[2, 2, 3], [6, 0, 0]],
  [[1, 5, 2], [3, 4, 6]],
  [[4, 3, 3], [5, 2, 3]],
  [[0, 0, 1], [1, 2, 2]]
]
)DOC")
    .Input(0, "X", "Input tensor, where 1st dimension equals batch size")
    .Input(
        1,
        "indices",
        "1-D int32 tensor of length N (the batch size) whose values lie in "
        "[0, N); Y[n] = X[indices[n]]")
    .Output(0, "Y", "Output permuted tensor");
// Input: indices, dY (aka "gradOutput"); Output: dX (aka "gradInput")
OPERATOR_SCHEMA(BatchPermutationGradient)
    .NumInputs(2)
    .NumOutputs(1)
    .Input(0, "indices", "The indices input of the forward BatchPermutation op")
    .Input(1, "dY", "Gradient with respect to the forward output Y")
    .Output(
        0,
        "dX",
        "Gradient with respect to the forward input X; same shape as dY");
// Gradient maker for BatchPermutation. It emits exactly one
// BatchPermutationGradient op. That op reads the forward op's `indices` input
// (I(1)) and the gradient of the forward output (GO(0)), and writes the
// gradient of the forward input X (GI(0)). No gradient is produced for
// `indices`.
class GetBatchPermutationGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    const vector<string> gradOpInputs{I(1), GO(0)};
    const vector<string> gradOpOutputs{GI(0)};
    return SingleGradientDef(
        "BatchPermutationGradient", "", gradOpInputs, gradOpOutputs);
  }
};
REGISTER_GRADIENT(BatchPermutation, GetBatchPermutationGradient);
} // namespace caffe2
// Exposes the CPU float BatchPermutation op to the c10 dispatcher under the
// schema string below. The operator class is passed through a type alias
// because an unparenthesized comma inside a template argument list would be
// treated by the preprocessor as a macro-argument separator.
using BatchPermutationOpFloatCPU =
caffe2::BatchPermutationOp<float, caffe2::CPUContext>;
C10_EXPORT_CAFFE2_OP_TO_C10_CPU(
BatchPermutation,
"_caffe2::BatchPermutation(Tensor X, Tensor indices) -> Tensor",
BatchPermutationOpFloatCPU);