Skip to content

Commit 6e7b613

Browse files
committed
version 0.2 updates
1 parent 427f869 commit 6e7b613

18 files changed

+962
-368
lines changed

.clang-format

+2
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,5 @@
22
BasedOnStyle: Google
33
ColumnLimit: 100
44
DerivePointerAlignment: false
5+
StatementMacros:
6+
- _Pragma

.gitmodules

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[submodule "third_party/cub"]
2+
path = third_party/cub
3+
url = https://github.com/NVIDIA/cub.git

.pylintrc

+1-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ enable=indexing-exception,old-raise-syntax
3838
# --enable=similarities". If you want to run only the classes checker, but have
3939
# no Warning level messages displayed, use"--disable=all --enable=classes
4040
# --disable=W"
41-
disable=design,similarities,no-self-use,attribute-defined-outside-init,locally-disabled,star-args,pointless-except,bad-option-value,global-statement,fixme,suppressed-message,useless-suppression,locally-enabled,no-member,no-name-in-module,import-error,unsubscriptable-object,unbalanced-tuple-unpacking,undefined-variable,not-context-manager,invalid-sequence-index,arguments-differ,missing-function-docstring,unexpected-keyword-arg,no-value-for-parameter
41+
disable=design,similarities,no-self-use,attribute-defined-outside-init,locally-disabled,star-args,pointless-except,bad-option-value,global-statement,fixme,suppressed-message,useless-suppression,locally-enabled,no-member,no-name-in-module,import-error,unsubscriptable-object,unbalanced-tuple-unpacking,undefined-variable,not-context-manager,invalid-sequence-index,arguments-differ,missing-function-docstring,unexpected-keyword-arg,no-value-for-parameter,missing-return-type-doc,missing-type-doc,missing-param-doc
4242

4343

4444
# Set the cache size for astng objects.

Makefile

+9-3
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,14 @@ PYTHON_BIN_PATH = python
2020

2121
TF_CFLAGS := $(shell $(PYTHON_BIN_PATH) -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))')
2222
TF_LFLAGS := $(shell $(PYTHON_BIN_PATH) -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))')
23-
24-
CFLAGS = ${TF_CFLAGS} -O3 -std=c++14
23+
TF_VERSION := $(shell $(PYTHON_BIN_PATH) -c 'import tensorflow as tf; print(int(tf.__version__.split(".")[1]))')
24+
ifeq ($(shell expr $(TF_VERSION) \>= 10), 1)
25+
CPP_STD := 17
26+
else
27+
CPP_STD := 14
28+
endif
29+
30+
CFLAGS = ${TF_CFLAGS} -O3 -std=c++${CPP_STD}
2531
LDFLAGS = -shared ${TF_LFLAGS}
2632

2733
SRC = embedding_lookup_kernels
@@ -34,7 +40,7 @@ TARGET_LIB = distributed_embeddings/python/ops/_embedding_lookup_ops.so
3440
all: $(TARGET_LIB)
3541

3642
%_kernels.cu.o: distributed_embeddings/cc/kernels/%_kernels.cu distributed_embeddings/cc/kernels/%.h
37-
$(NVCC) -c -o $@ $< $(CFLAGS) -I. -DGOOGLE_CUDA=1 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86 -x cu -Xcompiler -fPIC --expt-relaxed-constexpr
43+
$(NVCC) -c -o $@ $< -Ithird_party/cub $(CFLAGS) -I. -DGOOGLE_CUDA=1 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86 -x cu -Xcompiler -fPIC --expt-relaxed-constexpr
3844

3945
%_kernels.cc.o: distributed_embeddings/cc/kernels/%_kernels.cc distributed_embeddings/cc/kernels/%.h
4046
$(CXX) -c -o $@ $< $(CFLAGS) -Wall -fPIC -I/usr/local/cuda/include

distributed_embeddings/cc/kernels/embedding_lookup.h

+10-14
Original file line numberDiff line numberDiff line change
@@ -20,36 +20,32 @@
2020

2121
#include <string>
2222

23+
#include "tensorflow/core/framework/op_kernel.h"
24+
2325
namespace tensorflow {
2426
enum class Combiner { Mean = 0, Sum = 1 };
2527
inline Combiner StringToEnum(std::string combiner) {
2628
return combiner == "mean" ? Combiner::Mean : Combiner::Sum;
2729
}
2830

29-
template <typename Device, typename T, typename Tindices>
30-
struct EmbeddingLookupConstantHotnessFunctor {
31-
void operator()(const Device& d, T* output_ptr, const T* param_ptr, const Tindices* ids_ptr,
32-
Tindices nnz_per_row, Tindices num_rows, Tindices embedding_width,
33-
Combiner combiner) const;
34-
};
35-
36-
template <typename Device, typename T, typename Tindices>
37-
struct EmbeddingLookupConstantHotnessGradFunctor {
38-
void operator()(const Device& d, T* output_ptr, const T* grad_ptr, Tindices nnz_per_row,
39-
Tindices num_rows, Tindices embedding_width, Combiner combiner) const;
31+
template <typename Device, typename Tindices>
32+
struct RowToSplitFunctor {
33+
void operator()(const Device& d, Tindices* split_ptr, const Tindices* row_ptr, Tindices num_ids,
34+
Tindices num_rows) const;
4035
};
4136

4237
template <typename Device, typename T, typename Tindices>
4338
struct EmbeddingLookupVariableHotnessFunctor {
4439
void operator()(const Device& d, T* output_ptr, const T* param_ptr, const Tindices* ids_ptr,
4540
const Tindices* offsets_ptr, Tindices num_rows, Tindices embedding_width,
46-
Combiner combiner) const;
41+
Combiner combiner, Tindices ave_red_len) const;
4742
};
4843

4944
template <typename Device, typename T, typename Tindices>
5045
struct EmbeddingLookupVariableHotnessGradFunctor {
51-
void operator()(const Device& d, T* output_ptr, const T* grad_ptr, const Tindices* offsets_ptr,
52-
Tindices num_rows, Tindices embedding_width, Combiner combiner) const;
46+
void operator()(OpKernelContext* context, const Tindices* ids_ptr, const Tindices* row_ptr,
47+
const T* grad_ptr, int64_t num_ids, Tindices embedding_width, Tindices num_rows,
48+
int64_t dense_shape_dim0, int64_t max_red_len, Combiner combiner) const;
5349
};
5450

5551
} // namespace tensorflow

distributed_embeddings/cc/kernels/embedding_lookup_kernels.cc

+41-82
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
*/
1717

1818
#include "embedding_lookup.h"
19-
#include "tensorflow/core/framework/op_kernel.h"
2019
#include "tensorflow/core/framework/resource_mgr.h"
2120
#include "tensorflow/core/framework/resource_var.h"
2221

@@ -45,62 +44,25 @@ class ReadVariableNoCopyOp : public OpKernel {
4544
DataType dtype_;
4645
};
4746

48-
template <typename Device, typename T, typename Tindices>
49-
class EmbeddingLookupConstantHotnessOp : public OpKernel {
50-
public:
51-
explicit EmbeddingLookupConstantHotnessOp(OpKernelConstruction* context) : OpKernel(context) {
52-
OP_REQUIRES_OK(context, context->GetAttr("combiner", &_combiner));
53-
}
54-
55-
void Compute(OpKernelContext* context) override {
56-
const Tensor& params = context->input(0);
57-
const Tensor& ids = context->input(1);
58-
59-
auto num_rows = ids.dim_size(0);
60-
auto nnz_per_row = ids.dim_size(1);
61-
auto embedding_width = params.dim_size(1);
62-
63-
TensorShape output_shape({num_rows, embedding_width});
64-
Tensor* output = nullptr;
65-
OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
66-
67-
EmbeddingLookupConstantHotnessFunctor<Device, T, Tindices>()(
68-
context->eigen_device<Device>(), output->flat<T>().data(), params.flat<T>().data(),
69-
ids.flat<Tindices>().data(), nnz_per_row, num_rows, embedding_width,
70-
StringToEnum(_combiner));
71-
}
72-
73-
private:
74-
string _combiner;
75-
};
76-
77-
template <typename Device, typename T, typename Tindices>
78-
class EmbeddingLookupConstantHotnessGradOp : public OpKernel {
47+
template <typename Device, typename Tindices>
48+
class RowToSplitOp : public OpKernel {
7949
public:
80-
explicit EmbeddingLookupConstantHotnessGradOp(OpKernelConstruction* context) : OpKernel(context) {
81-
OP_REQUIRES_OK(context, context->GetAttr("combiner", &_combiner));
82-
}
50+
explicit RowToSplitOp(OpKernelConstruction* context) : OpKernel(context) {}
8351

8452
void Compute(OpKernelContext* context) override {
85-
const Tensor& grad = context->input(0);
86-
const Tensor& ids = context->input(1);
53+
// [n, 2]
54+
const Tensor& row = context->input(0);
55+
auto num_ids = row.dim_size(0);
56+
auto num_rows = context->input(1).scalar<int32>()();
8757

88-
auto num_rows = ids.dim_size(0);
89-
auto nnz_per_row = ids.dim_size(1);
90-
auto nnz = num_rows * nnz_per_row;
91-
auto embedding_width = grad.dim_size(1);
92-
93-
TensorShape output_shape({nnz, embedding_width});
58+
TensorShape output_shape({num_rows + 1});
9459
Tensor* output = nullptr;
9560
OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
9661

97-
EmbeddingLookupConstantHotnessGradFunctor<Device, T, Tindices>()(
98-
context->eigen_device<Device>(), output->flat<T>().data(), grad.flat<T>().data(),
99-
nnz_per_row, num_rows, embedding_width, StringToEnum(_combiner));
62+
RowToSplitFunctor<Device, Tindices>()(context->eigen_device<Device>(),
63+
output->flat<Tindices>().data(),
64+
row.flat<Tindices>().data(), num_ids, num_rows);
10065
}
101-
102-
private:
103-
string _combiner;
10466
};
10567

10668
template <typename Device, typename T, typename Tindices>
@@ -118,14 +80,17 @@ class EmbeddingLookupVariableHotnessOp : public OpKernel {
11880
auto num_rows = offsets.dim_size(0) - 1;
11981
auto embedding_width = params.dim_size(1);
12082

83+
auto num_ids = ids.dim_size(0);
84+
auto ave_red_len = num_ids / num_rows;
85+
12186
TensorShape output_shape({num_rows, embedding_width});
12287
Tensor* output = nullptr;
12388
OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
12489

12590
EmbeddingLookupVariableHotnessFunctor<Device, T, Tindices>()(
12691
context->eigen_device<Device>(), output->flat<T>().data(), params.flat<T>().data(),
12792
ids.flat<Tindices>().data(), offsets.flat<Tindices>().data(), num_rows, embedding_width,
128-
StringToEnum(_combiner));
93+
StringToEnum(_combiner), ave_red_len);
12994
}
13095

13196
private:
@@ -140,21 +105,20 @@ class EmbeddingLookupVariableHotnessGradOp : public OpKernel {
140105
}
141106

142107
void Compute(OpKernelContext* context) override {
143-
const Tensor& grad = context->input(0);
144-
const Tensor& ids = context->input(1);
145-
const Tensor& offsets = context->input(2);
146-
147-
auto num_rows = offsets.dim_size(0) - 1;
108+
const Tensor& ids = context->input(0);
109+
const Tensor& offset_in = context->input(1);
110+
const Tensor& grad = context->input(2);
111+
const Tensor& param = context->input(3);
112+
auto num_ids = ids.dim_size(0);
113+
auto num_rows = offset_in.dim_size(0) - 1;
148114
auto embedding_width = grad.dim_size(1);
149-
auto nnz = ids.dim_size(0);
150-
151-
TensorShape output_shape({nnz, embedding_width});
152-
Tensor* output = nullptr;
153-
OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
115+
auto max_red_len = grad.dim_size(0);
116+
auto dense_shape_dim0 = param.dim_size(0);
154117

155118
EmbeddingLookupVariableHotnessGradFunctor<Device, T, Tindices>()(
156-
context->eigen_device<Device>(), output->flat<T>().data(), grad.flat<T>().data(),
157-
offsets.flat<Tindices>().data(), num_rows, embedding_width, StringToEnum(_combiner));
119+
context, ids.flat<Tindices>().data(), offset_in.flat<Tindices>().data(),
120+
grad.flat<T>().data(), num_ids, embedding_width, num_rows, dense_shape_dim0, max_red_len,
121+
StringToEnum(_combiner));
158122
}
159123

160124
private:
@@ -167,26 +131,21 @@ REGISTER_KERNEL_BUILDER(Name("ReadVariableNoCopy").Device(DEVICE_DEFAULT).HostMe
167131
REGISTER_KERNEL_BUILDER(Name("ReadVariableNoCopy").Device(DEVICE_GPU).HostMemory("resource"),
168132
ReadVariableNoCopyOp);
169133

170-
#define REGISTER_GPU(T, Tindices) \
171-
REGISTER_KERNEL_BUILDER(Name("EmbeddingLookupConstantHotness") \
172-
.Device(DEVICE_GPU) \
173-
.TypeConstraint<T>("T") \
174-
.TypeConstraint<Tindices>("Tindices"), \
175-
EmbeddingLookupConstantHotnessOp<Eigen::GpuDevice, T, Tindices>); \
176-
REGISTER_KERNEL_BUILDER(Name("EmbeddingLookupConstantHotnessGrad") \
177-
.Device(DEVICE_GPU) \
178-
.TypeConstraint<T>("T") \
179-
.TypeConstraint<Tindices>("Tindices"), \
180-
EmbeddingLookupConstantHotnessGradOp<Eigen::GpuDevice, T, Tindices>); \
181-
REGISTER_KERNEL_BUILDER(Name("EmbeddingLookupVariableHotness") \
182-
.Device(DEVICE_GPU) \
183-
.TypeConstraint<T>("T") \
184-
.TypeConstraint<Tindices>("Tindices"), \
185-
EmbeddingLookupVariableHotnessOp<Eigen::GpuDevice, T, Tindices>); \
186-
REGISTER_KERNEL_BUILDER(Name("EmbeddingLookupVariableHotnessGrad") \
187-
.Device(DEVICE_GPU) \
188-
.TypeConstraint<T>("T") \
189-
.TypeConstraint<Tindices>("Tindices"), \
134+
#define REGISTER_GPU(T, Tindices) \
135+
REGISTER_KERNEL_BUILDER(Name("RowToSplit") \
136+
.Device(DEVICE_GPU) \
137+
.TypeConstraint<Tindices>("Tindices") \
138+
.HostMemory("shape"), \
139+
RowToSplitOp<Eigen::GpuDevice, Tindices>); \
140+
REGISTER_KERNEL_BUILDER(Name("EmbeddingLookupVariableHotness") \
141+
.Device(DEVICE_GPU) \
142+
.TypeConstraint<T>("T") \
143+
.TypeConstraint<Tindices>("Tindices"), \
144+
EmbeddingLookupVariableHotnessOp<Eigen::GpuDevice, T, Tindices>); \
145+
REGISTER_KERNEL_BUILDER(Name("EmbeddingLookupVariableHotnessGrad") \
146+
.Device(DEVICE_GPU) \
147+
.TypeConstraint<T>("T") \
148+
.TypeConstraint<Tindices>("Tindices"), \
190149
EmbeddingLookupVariableHotnessGradOp<Eigen::GpuDevice, T, Tindices>);
191150

192151
REGISTER_GPU(float, int64_t)

0 commit comments

Comments
 (0)