#ifndef CAFFE2_FB_OPERATORS_CC_BMM_BG_H_
#define CAFFE2_FB_OPERATORS_CC_BMM_BG_H_

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/types.h"
#include "caffe2/utils/math.h"
namespace caffe2 {

// The operator is hard-coded to float data and int indices.
using T = float;
using TInd = int;
using Engine = DefaultEngine;
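
// Fused Concat + BatchMatMul + BatchGather. Input(0) holds gather indices;
// Inputs(1..N) are N feature tensors of identical shape, whose dims before
// axis 1 form the batch and whose remaining dims are flattened into the
// embedding. For each batch item, the N feature rows are concatenated into
// an N x embed_size matrix X, the Gram matrix X * X^T (all pairwise dot
// products) is computed, and the entries selected by the indices are
// written to the output.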
template <class Context>
class ConcatBatchMatMulBatchGatherOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  ConcatBatchMatMulBatchGatherOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws) {}
  bool RunOnDevice() override;

 protected:
  int axis_ = 1;
  int add_axis_ = 1;
  bool trans_a_ = false;
  bool trans_b_ = true;
  bool broadcast_ = false;
};

template <class Context>
bool ConcatBatchMatMulBatchGatherOp<Context>::RunOnDevice() {
  auto& indices = Input(0);
  auto& input_zero = Input(1);
  int adj_size = input_zero.dim() + 1;
  int canonical_axis = 1;
  CAFFE_ENFORCE_LT(canonical_axis, adj_size, "Axis not in input ndim range.");
  // All data inputs (1..InputSize()-1) must share input_zero's dtype.
  for (int i = 2; i < InputSize(); ++i) {
    CAFFE_ENFORCE(
        Input(i).dtype() == input_zero.dtype(),
        "All inputs must have the same type, expected: ",
        input_zero.dtype().name(),
        " but got: ",
        Input(i).dtype().name(),
        " for input: ",
        i);
  }
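
  // Flatten the dims before the concat axis into "before" (the batch size)
  // and the remaining dims into "after" (the embedding size).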
  int before = 1, after = 1;
  for (int i = 0; i < input_zero.dim(); ++i) {
    int dim = input_zero.dim32(i);
    if (i < canonical_axis) {
      before *= dim;
    } else { // i > canonical_axis || (i == canonical_axis && add_axis_)
      after *= dim;
    }
    // Check that the input dims are compatible.
    for (int j = 2; j < InputSize(); ++j) {
      int dim_j = Input(j).dim32(i);
      CAFFE_ENFORCE(
          dim == dim_j,
          "Expect dimension = ",
          dim,
          " got ",
          dim_j,
          " at axis = ",
          i,
          " for input: ",
          j,
          ". The input tensors can only have different dimensions "
          "when arg 'add_axis' = 0 and along the axis = ",
          canonical_axis,
          " <",
          input_zero.sizes(),
          "> vs <",
          Input(j).sizes(),
          ">.");
    }
  }
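
  // ndata feature inputs are concatenated per batch item; gather_size is the
  // number of pairwise products selected per batch item, so the output shape
  // is batch_size followed by the shape of the indices tensor.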
  auto ndata = InputSize() - 1;
  auto batch_size = before;
  auto embed_size = after;
  auto gather_size = indices.sizes()[0];

  vector<int64_t> output_dims;
  output_dims.push_back(batch_size);
  output_dims.insert(
      output_dims.begin() + 1, indices.sizes().begin(), indices.sizes().end());
  auto* output = Output(0, output_dims, at::dtype<T>());
  // std::stringstream ss;
  // ss << "[";
  // for (int i = 0; i < output_dims.size(); i++) ss << output_dims[i];
  // ss << "]";
  // LOG(INFO) << "output size: " << ss.str();
  auto* output_data = output->template mutable_data<T>();
  auto* indices_data = indices.template data<TInd>();

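// Batches are split across OpenMP threads; each thread keeps private scratch
// buffers for the concatenated input (ndata x embed_size) and for the
// ndata x ndata product matrix.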
#pragma omp parallel
  {
    std::vector<T> scratch_input(ndata * embed_size);
    std::vector<T> scratch_output(ndata * ndata);
#pragma omp for
    for (int b = 0; b < batch_size; ++b) {
      // Concatenate row b of every data input into the scratch matrix.
      for (int i = 1; i < InputSize(); ++i) {
        auto* input_data = Input(i).template data<T>();
        memcpy(
            &scratch_input[(i - 1) * embed_size],
            input_data + b * embed_size,
            embed_size * Input(i).itemsize());
      }
      // Compute scratch_output = scratch_input * scratch_input^T via
      // math::Gemm (dispatched to MKL/BLAS depending on the build), i.e.
      // the ndata x ndata matrix of pairwise dot products.
      math::Gemm<T, Context, Engine>(
          CblasNoTrans,
          CblasTrans,
          ndata,
          ndata,
          embed_size,
          1.0f,
          &scratch_input[0],
          &scratch_input[0],
          0.0f,
          &scratch_output[0],
          &context_);
      // Gather the selected pairwise products into this batch item's output.
      int64_t output_offset = b * gather_size;
      for (int i = 0; i < gather_size; i++) {
        output_data[output_offset + i] = scratch_output[indices_data[i]];
      }
    }
  }
  return true;
}

} // namespace caffe2

#endif // CAFFE2_FB_OPERATORS_CC_BMM_BG_H_