Support custom Keras layers in the DistributedEmbedding wrapper.
Switched the submodule from cub to thrust to avoid a build issue.
skyw authored and FDecaYed committed Feb 13, 2023
1 parent a581613 commit 5dafd70
Showing 7 changed files with 62 additions and 16 deletions.
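As context, here is a minimal sketch of how the new capability can be exercised, modeled on the CustomEmbedding test added in this commit. The layer name MyEmbedding, the table sizes, and the explicit hvd.init() call are illustrative assumptions, not part of the change.

import tensorflow as tf
import horovod.tensorflow as hvd
from distributed_embeddings.python.layers import dist_model_parallel as dmp

hvd.init()  # DistributedEmbedding shards embedding tables across Horovod ranks

class MyEmbedding(tf.keras.layers.Layer):
  """Illustrative custom layer, mirroring CustomEmbedding from the test below."""

  def __init__(self, input_dim, output_dim, **kwargs):
    super().__init__(**kwargs)
    self.input_dim = input_dim
    self.output_dim = output_dim

  def build(self, _):
    # The weight created here is tagged as model parallel by the wrapper's build().
    self.params = self.add_weight("params",
                                  shape=[self.input_dim, self.output_dim],
                                  dtype=tf.float32)

  def call(self, inputs):
    return tf.gather(self.params, inputs)

  def get_config(self):
    # Must round-trip the constructor arguments so each rank can rebuild the layer.
    return {'input_dim': self.input_dim, 'output_dim': self.output_dim}

embeddings = [MyEmbedding(1000, 8), MyEmbedding(500, 8)]
dist_embeddings = dmp.DistributedEmbedding(embeddings, strategy='basic')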
6 changes: 3 additions & 3 deletions .gitmodules
@@ -1,3 +1,3 @@
[submodule "third_party/cub"]
path = third_party/cub
url = https://github.com/NVIDIA/cub.git
[submodule "third_party/thrust"]
path = third_party/thrust
url = https://github.com/NVIDIA/thrust.git
2 changes: 1 addition & 1 deletion Makefile
@@ -40,7 +40,7 @@ TARGET_LIB = distributed_embeddings/python/ops/_embedding_lookup_ops.so
all: $(TARGET_LIB)

%_kernels.cu.o: distributed_embeddings/cc/kernels/%_kernels.cu distributed_embeddings/cc/kernels/%.h
$(NVCC) -c -o $@ $< -Ithird_party/cub $(CFLAGS) -I. -DGOOGLE_CUDA=1 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86 -x cu -Xcompiler -fPIC --expt-relaxed-constexpr
$(NVCC) -c -o $@ $< -Ithird_party/thrust/dependencies/cub $(CFLAGS) -I. -DGOOGLE_CUDA=1 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86 -x cu -Xcompiler -fPIC --expt-relaxed-constexpr

%_kernels.cc.o: distributed_embeddings/cc/kernels/%_kernels.cc distributed_embeddings/cc/kernels/%.h
$(CXX) -c -o $@ $< $(CFLAGS) -Wall -fPIC -I/usr/local/cuda/include
26 changes: 20 additions & 6 deletions distributed_embeddings/python/layers/dist_model_parallel.py
@@ -74,6 +74,9 @@ def __init__(self,
# column_slice can be used to enable more table concatenation, so keep it in a single process
self.column_slice_threshold = column_slice_threshold
self.global_configs = [e.get_config() for e in embeddings]
# Insert layer type information into the config dicts
for config, embedding in zip(self.global_configs, embeddings):
config['layer_type'] = type(embedding)
if input_table_map is None:
input_table_map = list(range(len(embeddings)))

@@ -274,8 +277,10 @@ def _create_concat(self, table_configs, input_maps):
for concat_config in concat_configs:
input_dims = concat_config.pop('input_dims')
if len(input_dims) > 1:
orig_initializer = initializers.deserialize(concat_config['embeddings_initializer'])
concat_config['embeddings_initializer'] = ConcatInitializer(orig_initializer, input_dims)
# TODO(deyuf): custom layers without an initializer are still concatenated, but their initializer is not wrapped
if 'embeddings_initializer' in concat_config:
orig_initializer = initializers.deserialize(concat_config['embeddings_initializer'])
concat_config['embeddings_initializer'] = ConcatInitializer(orig_initializer, input_dims)

# record weight offsets for get/set.
weight_offsets = [concat_config.pop('offsets', None) for concat_config in concat_configs]
@@ -363,8 +368,12 @@ def __init__(self,
# create local embeddings
self.local_embedding_layers = []
for config in self.strategy.local_configs[self.rank]:
config['synchronization'] = tf.VariableSynchronization.NONE
self.local_embedding_layers.append(Embedding.from_config(config))
layer_type = config.pop('layer_type')
# For the stock Keras Embedding, switch to this package's Embedding layer for better performance.
# For custom layers, the original layer type is used.
# TODO(deyuf): Check functionality coverage, add a fallback or a type-picking API
layer_type = Embedding if layer_type == tf.keras.layers.Embedding else layer_type
self.local_embedding_layers.append(layer_type.from_config(config))
self.offsets = [
None if offset == 0 else tf.constant([offset], dtype=tf.int64)
for offset in self.strategy.local_input_offsets[self.rank]
@@ -651,6 +660,11 @@ def build(self, input_shape):
F"Global batchsize {batch_sizes[0]} not divisible workers count {self.world_size}.")
for layer in self.local_embedding_layers:
layer.build(input_shape[0] if input_shape else None)
for var in layer.trainable_weights:
# Mark local (model-parallel) variables. Use the prefix 'de' (distributed embeddings) to avoid conflicts.
var.de_local = True
# Set the built flag so the build above does not trigger again and the marker above is not lost.
layer.built = True
self.built = True

def call(self, inputs): # pylint: disable=missing-function-docstring
@@ -671,7 +685,7 @@ def broadcast_variables(model_vars, root_rank=0): # pylint: disable=missing-any
dp_vars = []
mp_vars = []
for var in model_vars:
if var.synchronization == tf.VariableSynchronization.NONE:
if hasattr(var, 'de_local'):
mp_vars.append(var)
else:
dp_vars.append(var)
@@ -693,7 +707,7 @@ def gradient(self, target, sources, output_gradients=None):
mp_grads = []
split_infos = []
for grad, var in zip(gradients, sources):
if var.synchronization == tf.VariableSynchronization.NONE:
if hasattr(var, 'de_local'):
if isinstance(grad, tf.IndexedSlices):
mp_grads.append(tf.IndexedSlices(grad.values / hvd.size(), grad.indices,
grad.dense_shape))
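The core mechanism of this change sits in the hunks above: the wrapper records the concrete layer class alongside each get_config() result, and each rank later rebuilds its local tables with from_config(), substituting this package's optimized Embedding only when the recorded class is the stock Keras one. A minimal sketch of that round-trip, using a stock Keras layer purely for illustration:

import tensorflow as tf

layer = tf.keras.layers.Embedding(1000, 8)
config = layer.get_config()
config['layer_type'] = type(layer)        # recorded in DistributedEmbedding.__init__
layer_type = config.pop('layer_type')     # popped back out when building per-rank layers
rebuilt = layer_type.from_config(config)  # works for stock and custom layers alike,
                                          # provided get_config() round-trips the arguments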
39 changes: 37 additions & 2 deletions distributed_embeddings/python/layers/dist_model_parallel_test.py
@@ -21,9 +21,28 @@
import horovod.tensorflow as hvd
from distributed_embeddings.python.layers import dist_model_parallel as dmp


# There are some functions in TF that pylint can't inspect correctly, which leads to incorrect
# reports of unexpected-keyword-arg and no-value-for-parameter. Disable them globally here.
# pylint: disable=no-self-use,unexpected-keyword-arg,no-value-for-parameter,missing-docstring
class CustomEmbedding(tf.keras.layers.Layer):

def __init__(self, input_dim, output_dim, **kwargs):
super().__init__(**kwargs)
self.input_dim = input_dim
self.output_dim = output_dim

def build(self, _):
self.params = self.add_weight("params",
shape=[self.input_dim, self.output_dim],
dtype=tf.float32)

def call(self, inputs):
return tf.gather(params=self.params, indices=inputs, axis=None)

def get_config(self):
config = {'input_dim': self.input_dim, 'output_dim': self.output_dim}
return config


class EmbeddingListModel(tf.keras.Model):
@@ -35,11 +54,15 @@ def __init__(self,
strategy='basic',
dp_input=True,
input_table_map=None,
column_slice_threshold=None):
column_slice_threshold=None,
test_custom_layer=False):
super().__init__()
self.embeddings = []
for size in table_sizes:
self.embeddings.append(tf.keras.layers.Embedding(*size))
if test_custom_layer:
self.embeddings.append(CustomEmbedding(*size))
else:
self.embeddings.append(tf.keras.layers.Embedding(*size))
if distribute:
self.dist_embeddings = dmp.DistributedEmbedding(self.embeddings,
strategy=strategy,
@@ -386,6 +409,18 @@ def test_fewer_tables_than_workers(self):
dp_inputs, _ = self.gen_inputs(table_sizes)
self.run_and_test(ref_model, dp_inputs, test_model, dp_inputs)

def test_custom_embedding_layer(self):
table_sizes = self.gen_table_sizes()

ref_model = EmbeddingListModel(table_sizes, distribute=False, test_custom_layer=True)
test_model = EmbeddingListModel(table_sizes,
distribute=True,
strategy='basic',
test_custom_layer=True)

dp_inputs, _ = self.gen_inputs(table_sizes)
self.run_and_test(ref_model, dp_inputs, test_model, dp_inputs)


if __name__ == "__main__":
test.main()
3 changes: 0 additions & 3 deletions distributed_embeddings/python/layers/embedding.py
@@ -67,7 +67,6 @@ def __init__(self,
activity_regularizer=None,
embeddings_constraint=None,
combiner=None,
synchronization=tf.VariableSynchronization.AUTO,
**kwargs):
if 'input_shape' not in kwargs:
kwargs['input_shape'] = (None,)
@@ -89,7 +88,6 @@
self.activity_regularizer = regularizers.get(activity_regularizer)
self.embeddings_constraint = constraints.get(embeddings_constraint)
self.combiner = combiner
self.synchronization = synchronization

@tf_utils.shape_type_conversion
def build(self, input_shape): # pylint: disable=unused-argument
@@ -98,7 +96,6 @@ def build(self, input_shape): # pylint: disable=unused-argument
name='embeddings',
regularizer=self.embeddings_regularizer,
constraint=self.embeddings_constraint,
synchronization=self.synchronization,
experimental_autocast=False)
self.built = True

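With the synchronization argument removed from embedding.py above, model-parallel variables are no longer identified through tf.VariableSynchronization.NONE. Instead, DistributedEmbedding.build() tags them with a de_local attribute, which broadcast_variables() and the gradient tape then check via hasattr(). A small illustration of that tag-and-partition pattern; the variable names are made up:

import tensorflow as tf

mp_shard = tf.Variable(tf.zeros([100, 16]), name='embedding_shard')
mp_shard.de_local = True  # marker set on model-parallel weights in DistributedEmbedding.build()
dp_kernel = tf.Variable(tf.zeros([16, 4]), name='dense_kernel')

dp_vars, mp_vars = [], []
for var in [mp_shard, dp_kernel]:
  (mp_vars if hasattr(var, 'de_local') else dp_vars).append(var)

# dp_vars are broadcast and allreduced by Horovod; mp_vars stay local to each rank.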
1 change: 0 additions & 1 deletion third_party/cub
Submodule cub deleted from cdaa95
1 change: 1 addition & 0 deletions third_party/thrust
Submodule thrust added at 65fbe2
