16
16
*/
17
17
18
18
#include " embedding_lookup.h"
19
- #include " tensorflow/core/framework/op_kernel.h"
20
19
#include " tensorflow/core/framework/resource_mgr.h"
21
20
#include " tensorflow/core/framework/resource_var.h"
22
21
@@ -45,62 +44,25 @@ class ReadVariableNoCopyOp : public OpKernel {
45
44
DataType dtype_;
46
45
};
47
46
48
- template <typename Device, typename T, typename Tindices>
49
- class EmbeddingLookupConstantHotnessOp : public OpKernel {
50
- public:
51
- explicit EmbeddingLookupConstantHotnessOp (OpKernelConstruction* context) : OpKernel(context) {
52
- OP_REQUIRES_OK (context, context->GetAttr (" combiner" , &_combiner));
53
- }
54
-
55
- void Compute (OpKernelContext* context) override {
56
- const Tensor& params = context->input (0 );
57
- const Tensor& ids = context->input (1 );
58
-
59
- auto num_rows = ids.dim_size (0 );
60
- auto nnz_per_row = ids.dim_size (1 );
61
- auto embedding_width = params.dim_size (1 );
62
-
63
- TensorShape output_shape ({num_rows, embedding_width});
64
- Tensor* output = nullptr ;
65
- OP_REQUIRES_OK (context, context->allocate_output (0 , output_shape, &output));
66
-
67
- EmbeddingLookupConstantHotnessFunctor<Device, T, Tindices>()(
68
- context->eigen_device <Device>(), output->flat <T>().data (), params.flat <T>().data (),
69
- ids.flat <Tindices>().data (), nnz_per_row, num_rows, embedding_width,
70
- StringToEnum (_combiner));
71
- }
72
-
73
- private:
74
- string _combiner;
75
- };
76
-
77
- template <typename Device, typename T, typename Tindices>
78
- class EmbeddingLookupConstantHotnessGradOp : public OpKernel {
47
+ template <typename Device, typename Tindices>
48
+ class RowToSplitOp : public OpKernel {
79
49
public:
80
- explicit EmbeddingLookupConstantHotnessGradOp (OpKernelConstruction* context) : OpKernel(context) {
81
- OP_REQUIRES_OK (context, context->GetAttr (" combiner" , &_combiner));
82
- }
50
+ explicit RowToSplitOp (OpKernelConstruction* context) : OpKernel(context) {}
83
51
84
52
void Compute (OpKernelContext* context) override {
85
- const Tensor& grad = context->input (0 );
86
- const Tensor& ids = context->input (1 );
53
+ // [n, 2]
54
+ const Tensor& row = context->input (0 );
55
+ auto num_ids = row.dim_size (0 );
56
+ auto num_rows = context->input (1 ).scalar <int32>()();
87
57
88
- auto num_rows = ids.dim_size (0 );
89
- auto nnz_per_row = ids.dim_size (1 );
90
- auto nnz = num_rows * nnz_per_row;
91
- auto embedding_width = grad.dim_size (1 );
92
-
93
- TensorShape output_shape ({nnz, embedding_width});
58
+ TensorShape output_shape ({num_rows + 1 });
94
59
Tensor* output = nullptr ;
95
60
OP_REQUIRES_OK (context, context->allocate_output (0 , output_shape, &output));
96
61
97
- EmbeddingLookupConstantHotnessGradFunctor <Device, T, Tindices>()(
98
- context-> eigen_device <Device>(), output->flat <T>(). data (), grad. flat <T >().data (),
99
- nnz_per_row, num_rows, embedding_width, StringToEnum (_combiner) );
62
+ RowToSplitFunctor <Device, Tindices>()(context-> eigen_device <Device>(),
63
+ output->flat <Tindices >().data (),
64
+ row. flat <Tindices>(). data (), num_ids, num_rows );
100
65
}
101
-
102
- private:
103
- string _combiner;
104
66
};
105
67
106
68
template <typename Device, typename T, typename Tindices>
@@ -118,14 +80,17 @@ class EmbeddingLookupVariableHotnessOp : public OpKernel {
118
80
auto num_rows = offsets.dim_size (0 ) - 1 ;
119
81
auto embedding_width = params.dim_size (1 );
120
82
83
+ auto num_ids = ids.dim_size (0 );
84
+ auto ave_red_len = num_ids / num_rows;
85
+
121
86
TensorShape output_shape ({num_rows, embedding_width});
122
87
Tensor* output = nullptr ;
123
88
OP_REQUIRES_OK (context, context->allocate_output (0 , output_shape, &output));
124
89
125
90
EmbeddingLookupVariableHotnessFunctor<Device, T, Tindices>()(
126
91
context->eigen_device <Device>(), output->flat <T>().data (), params.flat <T>().data (),
127
92
ids.flat <Tindices>().data (), offsets.flat <Tindices>().data (), num_rows, embedding_width,
128
- StringToEnum (_combiner));
93
+ StringToEnum (_combiner), ave_red_len );
129
94
}
130
95
131
96
private:
@@ -140,21 +105,20 @@ class EmbeddingLookupVariableHotnessGradOp : public OpKernel {
140
105
}
141
106
142
107
// Gathers the four op inputs, derives the sizes the backward functor needs,
// and dispatches to EmbeddingLookupVariableHotnessGradFunctor. No output is
// allocated in this method; the functor is handed the OpKernelContext itself.
void Compute(OpKernelContext* context) override {
  // Inputs, in op order: ids, offsets, incoming gradient, parameter tensor.
  const Tensor& ids_t = context->input(0);
  const Tensor& offsets_t = context->input(1);
  const Tensor& grad_t = context->input(2);
  const Tensor& param_t = context->input(3);

  // Sizes read from the ids / offsets inputs.
  auto id_count = ids_t.dim_size(0);
  auto row_count = offsets_t.dim_size(0) - 1;

  // Sizes read from the gradient and the parameter tensor.
  auto width = grad_t.dim_size(1);
  auto grad_rows = grad_t.dim_size(0);  // forwarded in the max_red_len slot
  auto param_rows = param_t.dim_size(0);

  EmbeddingLookupVariableHotnessGradFunctor<Device, T, Tindices> backward;
  backward(context,
           ids_t.flat<Tindices>().data(),
           offsets_t.flat<Tindices>().data(),
           grad_t.flat<T>().data(),
           id_count,
           width,
           row_count,
           param_rows,
           grad_rows,
           StringToEnum(_combiner));
}
159
123
160
124
private:
@@ -167,26 +131,21 @@ REGISTER_KERNEL_BUILDER(Name("ReadVariableNoCopy").Device(DEVICE_DEFAULT).HostMe
167
131
REGISTER_KERNEL_BUILDER (Name(" ReadVariableNoCopy" ).Device(DEVICE_GPU).HostMemory(" resource" ),
168
132
ReadVariableNoCopyOp);
169
133
170
- #define REGISTER_GPU (T, Tindices ) \
171
- REGISTER_KERNEL_BUILDER (Name(" EmbeddingLookupConstantHotness" ) \
172
- .Device(DEVICE_GPU) \
173
- .TypeConstraint<T>(" T" ) \
174
- .TypeConstraint<Tindices>(" Tindices" ), \
175
- EmbeddingLookupConstantHotnessOp<Eigen::GpuDevice, T, Tindices>); \
176
- REGISTER_KERNEL_BUILDER (Name(" EmbeddingLookupConstantHotnessGrad" ) \
177
- .Device(DEVICE_GPU) \
178
- .TypeConstraint<T>(" T" ) \
179
- .TypeConstraint<Tindices>(" Tindices" ), \
180
- EmbeddingLookupConstantHotnessGradOp<Eigen::GpuDevice, T, Tindices>); \
181
- REGISTER_KERNEL_BUILDER (Name(" EmbeddingLookupVariableHotness" ) \
182
- .Device(DEVICE_GPU) \
183
- .TypeConstraint<T>(" T" ) \
184
- .TypeConstraint<Tindices>(" Tindices" ), \
185
- EmbeddingLookupVariableHotnessOp<Eigen::GpuDevice, T, Tindices>); \
186
- REGISTER_KERNEL_BUILDER (Name(" EmbeddingLookupVariableHotnessGrad" ) \
187
- .Device(DEVICE_GPU) \
188
- .TypeConstraint<T>(" T" ) \
189
- .TypeConstraint<Tindices>(" Tindices" ), \
134
// Registers the GPU kernels for one (T, Tindices) combination:
//   - RowToSplit                          (constrained on Tindices only; its
//                                          "shape" input lives in host memory)
//   - EmbeddingLookupVariableHotness      (constrained on T and Tindices)
//   - EmbeddingLookupVariableHotnessGrad  (constrained on T and Tindices)
//
// The parameter list must follow the macro name with no whitespace in
// between, otherwise the preprocessor treats REGISTER_GPU as an object-like
// macro. Op and attribute names must match the op definitions exactly.
//
// NOTE(review): RowToSplit does not depend on T, so invoking this macro for a
// second T with the same Tindices would presumably register RowToSplit twice -
// confirm before adding more instantiations.
#define REGISTER_GPU(T, Tindices)                                                   \
  REGISTER_KERNEL_BUILDER(Name("RowToSplit")                                        \
                              .Device(DEVICE_GPU)                                   \
                              .TypeConstraint<Tindices>("Tindices")                 \
                              .HostMemory("shape"),                                 \
                          RowToSplitOp<Eigen::GpuDevice, Tindices>);                \
  REGISTER_KERNEL_BUILDER(Name("EmbeddingLookupVariableHotness")                    \
                              .Device(DEVICE_GPU)                                   \
                              .TypeConstraint<T>("T")                               \
                              .TypeConstraint<Tindices>("Tindices"),                \
                          EmbeddingLookupVariableHotnessOp<Eigen::GpuDevice, T, Tindices>); \
  REGISTER_KERNEL_BUILDER(Name("EmbeddingLookupVariableHotnessGrad")                \
                              .Device(DEVICE_GPU)                                   \
                              .TypeConstraint<T>("T")                               \
                              .TypeConstraint<Tindices>("Tindices"),                \
                          EmbeddingLookupVariableHotnessGradOp<Eigen::GpuDevice, T, Tindices>);
191
150
192
151
REGISTER_GPU (float , int64_t )
0 commit comments