From 5dafd7047c0ab58b6a6f0eddcb6d64e135f32329 Mon Sep 17 00:00:00 2001
From: Hao Wu
Date: Sun, 12 Feb 2023 20:53:22 -0800
Subject: [PATCH] Support custom Keras layers in DistributedEmbedding wrapper.

Switched the submodule to thrust to avoid a build issue.
---
 .gitmodules                                    |  6 +--
 Makefile                                       |  2 +-
 .../python/layers/dist_model_parallel.py       | 26 ++++++++++---
 .../python/layers/dist_model_parallel_test.py  | 39 ++++++++++++++++++-
 .../python/layers/embedding.py                 |  3 --
 third_party/cub                                |  1 -
 third_party/thrust                             |  1 +
 7 files changed, 62 insertions(+), 16 deletions(-)
 delete mode 160000 third_party/cub
 create mode 160000 third_party/thrust

diff --git a/.gitmodules b/.gitmodules
index 17ccc9e..f4cc8c1 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,3 +1,3 @@
-[submodule "third_party/cub"]
-	path = third_party/cub
-	url = https://github.com/NVIDIA/cub.git
+[submodule "third_party/thrust"]
+	path = third_party/thrust
+	url = https://github.com/NVIDIA/thrust.git
diff --git a/Makefile b/Makefile
index 4b8e1f9..df2d7a0 100644
--- a/Makefile
+++ b/Makefile
@@ -40,7 +40,7 @@ TARGET_LIB = distributed_embeddings/python/ops/_embedding_lookup_ops.so
 all: $(TARGET_LIB)
 
 %_kernels.cu.o: distributed_embeddings/cc/kernels/%_kernels.cu distributed_embeddings/cc/kernels/%.h
-	$(NVCC) -c -o $@ $< -Ithird_party/cub $(CFLAGS) -I. -DGOOGLE_CUDA=1 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86 -x cu -Xcompiler -fPIC --expt-relaxed-constexpr
+	$(NVCC) -c -o $@ $< -Ithird_party/thrust/dependencies/cub $(CFLAGS) -I. -DGOOGLE_CUDA=1 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86 -x cu -Xcompiler -fPIC --expt-relaxed-constexpr
 
 %_kernels.cc.o: distributed_embeddings/cc/kernels/%_kernels.cc distributed_embeddings/cc/kernels/%.h
 	$(CXX) -c -o $@ $< $(CFLAGS) -Wall -fPIC -I/usr/local/cuda/include
diff --git a/distributed_embeddings/python/layers/dist_model_parallel.py b/distributed_embeddings/python/layers/dist_model_parallel.py
index 8a4af81..db1f464 100644
--- a/distributed_embeddings/python/layers/dist_model_parallel.py
+++ b/distributed_embeddings/python/layers/dist_model_parallel.py
@@ -74,6 +74,9 @@ def __init__(self,
     # column_slice can be used to enable more table concat, so keep it in single process
     self.column_slice_threshold = column_slice_threshold
     self.global_configs = [e.get_config() for e in embeddings]
+    # Insert layer type information into the config dicts
+    for config, embedding in zip(self.global_configs, embeddings):
+      config['layer_type'] = type(embedding)
     if input_table_map is None:
       input_table_map = list(range(len(embeddings)))
 
@@ -274,8 +277,10 @@ def _create_concat(self, table_configs, input_maps):
     for concat_config in concat_configs:
       input_dims = concat_config.pop('input_dims')
       if len(input_dims) > 1:
-        orig_initializer = initializers.deserialize(concat_config['embeddings_initializer'])
-        concat_config['embeddings_initializer'] = ConcatInitializer(orig_initializer, input_dims)
+        # TODO(deyuf): custom layers without an initializer get concatenated, but init is not wrapped
+        if 'embeddings_initializer' in concat_config:
+          orig_initializer = initializers.deserialize(concat_config['embeddings_initializer'])
+          concat_config['embeddings_initializer'] = ConcatInitializer(orig_initializer, input_dims)
 
     # record weight offsets for get/set.
     weight_offsets = [concat_config.pop('offsets', None) for concat_config in concat_configs]
@@ -363,8 +368,12 @@ def __init__(self,
     # create local embeddings
     self.local_embedding_layers = []
     for config in self.strategy.local_configs[self.rank]:
-      config['synchronization'] = tf.VariableSynchronization.NONE
-      self.local_embedding_layers.append(Embedding.from_config(config))
+      layer_type = config.pop('layer_type')
+      # For a stock Keras Embedding, switch to our Embedding layer for better performance.
+      # If the input is a custom layer, the original layer type is used.
+      # TODO(deyuf): check functionality coverage, add a fallback or type-picking API
+      layer_type = Embedding if layer_type == tf.keras.layers.Embedding else layer_type
+      self.local_embedding_layers.append(layer_type.from_config(config))
     self.offsets = [
         None if offset == 0 else tf.constant([offset], dtype=tf.int64)
         for offset in self.strategy.local_input_offsets[self.rank]
@@ -651,6 +660,11 @@ def build(self, input_shape):
           F"Global batchsize {batch_sizes[0]} not divisible workers count {self.world_size}.")
     for layer in self.local_embedding_layers:
       layer.build(input_shape[0] if input_shape else None)
+      for var in layer.trainable_weights:
+        # Mark local (model-parallel) variables; the de (distributed embeddings) prefix avoids conflicts.
+        var.de_local = True
+      # Set the built flag so the build call above does not run again and drop the flag.
+      layer.built = True
     self.built = True
 
   def call(self, inputs):  # pylint: disable=missing-function-docstring
@@ -671,7 +685,7 @@ def broadcast_variables(model_vars, root_rank=0):  # pylint: disable=missing-any
   dp_vars = []
   mp_vars = []
   for var in model_vars:
-    if var.synchronization == tf.VariableSynchronization.NONE:
+    if hasattr(var, 'de_local'):
       mp_vars.append(var)
     else:
       dp_vars.append(var)
@@ -693,7 +707,7 @@ def gradient(self, target, sources, output_gradients=None):
     mp_grads = []
     split_infos = []
     for grad, var in zip(gradients, sources):
-      if var.synchronization == tf.VariableSynchronization.NONE:
+      if hasattr(var, 'de_local'):
         if isinstance(grad, tf.IndexedSlices):
           mp_grads.append(tf.IndexedSlices(grad.values / hvd.size(), grad.indices,
                                            grad.dense_shape))
diff --git a/distributed_embeddings/python/layers/dist_model_parallel_test.py b/distributed_embeddings/python/layers/dist_model_parallel_test.py
index 89dea50..6a8be15 100644
--- a/distributed_embeddings/python/layers/dist_model_parallel_test.py
+++ b/distributed_embeddings/python/layers/dist_model_parallel_test.py
@@ -21,9 +21,28 @@
 import horovod.tensorflow as hvd
 from distributed_embeddings.python.layers import dist_model_parallel as dmp
+
 
 # There are some functions in TF that pylint can't inspect correctly which leads to incorrect
 # report of unexpected-keyword-arg, no-value-for-parameter. Disable them globally here
 # pylint: disable=no-self-use,unexpected-keyword-arg,no-value-for-parameter,missing-docstring
+class CustomEmbedding(tf.keras.layers.Layer):
+
+  def __init__(self, input_dim, output_dim, **kwargs):
+    super().__init__(**kwargs)
+    self.input_dim = input_dim
+    self.output_dim = output_dim
+
+  def build(self, _):
+    self.params = self.add_weight("params",
+                                  shape=[self.input_dim, self.output_dim],
+                                  dtype=tf.float32)
+
+  def call(self, inputs):
+    return tf.gather(params=self.params, indices=inputs, axis=None)
+
+  def get_config(self):
+    config = {'input_dim': self.input_dim, 'output_dim': self.output_dim}
+    return config
 
 
 class EmbeddingListModel(tf.keras.Model):
@@ -35,11 +54,15 @@ def __init__(self,
                strategy='basic',
                dp_input=True,
                input_table_map=None,
-               column_slice_threshold=None):
+               column_slice_threshold=None,
+               test_custom_layer=False):
     super().__init__()
     self.embeddings = []
     for size in table_sizes:
-      self.embeddings.append(tf.keras.layers.Embedding(*size))
+      if test_custom_layer:
+        self.embeddings.append(CustomEmbedding(*size))
+      else:
+        self.embeddings.append(tf.keras.layers.Embedding(*size))
     if distribute:
       self.dist_embeddings = dmp.DistributedEmbedding(self.embeddings,
                                                       strategy=strategy,
@@ -386,6 +409,18 @@ def test_fewer_tables_than_workers(self):
     dp_inputs, _ = self.gen_inputs(table_sizes)
     self.run_and_test(ref_model, dp_inputs, test_model, dp_inputs)
 
+  def test_custom_embedding_layer(self):
+    table_sizes = self.gen_table_sizes()
+
+    ref_model = EmbeddingListModel(table_sizes, distribute=False, test_custom_layer=True)
+    test_model = EmbeddingListModel(table_sizes,
+                                    distribute=True,
+                                    strategy='basic',
+                                    test_custom_layer=True)
+
+    dp_inputs, _ = self.gen_inputs(table_sizes)
+    self.run_and_test(ref_model, dp_inputs, test_model, dp_inputs)
+
 
 if __name__ == "__main__":
   test.main()
diff --git a/distributed_embeddings/python/layers/embedding.py b/distributed_embeddings/python/layers/embedding.py
index 96bdfe1..9b159e7 100644
--- a/distributed_embeddings/python/layers/embedding.py
+++ b/distributed_embeddings/python/layers/embedding.py
@@ -67,7 +67,6 @@ def __init__(self,
                activity_regularizer=None,
                embeddings_constraint=None,
                combiner=None,
-               synchronization=tf.VariableSynchronization.AUTO,
                **kwargs):
     if 'input_shape' not in kwargs:
       kwargs['input_shape'] = (None,)
@@ -89,7 +88,6 @@ def __init__(self,
     self.activity_regularizer = regularizers.get(activity_regularizer)
     self.embeddings_constraint = constraints.get(embeddings_constraint)
     self.combiner = combiner
-    self.synchronization = synchronization
 
   @tf_utils.shape_type_conversion
   def build(self, input_shape):  # pylint: disable=unused-argument
@@ -98,7 +96,6 @@ def build(self, input_shape):  # pylint: disable=unused-argument
                                       name='embeddings',
                                       regularizer=self.embeddings_regularizer,
                                       constraint=self.embeddings_constraint,
-                                      synchronization=self.synchronization,
                                       experimental_autocast=False)
     self.built = True
 
diff --git a/third_party/cub b/third_party/cub
deleted file mode 160000
index cdaa955..0000000
--- a/third_party/cub
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit cdaa9558a85e45d849016e5fe7b6e4ee79113f95
diff --git a/third_party/thrust b/third_party/thrust
new file mode 160000
index 0000000..65fbe23
--- /dev/null
+++ b/third_party/thrust
@@ -0,0 +1 @@
+Subproject commit 65fbe23ab95d58966a2bc44245c084576f093b71
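
Usage sketch (not part of the patch): with this change, DistributedEmbedding can wrap custom Keras
layers as long as they round-trip through get_config()/from_config(); a stock
tf.keras.layers.Embedding is still swapped for the package's optimized Embedding on each worker,
while a custom layer is rebuilt with its original type. The example below is a minimal,
hypothetical illustration; the layer is copied from the new test above, and the table sizes and
'basic' strategy are made up for illustration only.

import tensorflow as tf
import horovod.tensorflow as hvd
from distributed_embeddings.python.layers import dist_model_parallel as dmp


class CustomEmbedding(tf.keras.layers.Layer):
  """Same toy layer as in dist_model_parallel_test.py above."""

  def __init__(self, input_dim, output_dim, **kwargs):
    super().__init__(**kwargs)
    self.input_dim = input_dim
    self.output_dim = output_dim

  def build(self, _):
    # Plain trainable table; DistributedEmbedding.build tags it with de_local on the owning worker.
    self.params = self.add_weight("params",
                                  shape=[self.input_dim, self.output_dim],
                                  dtype=tf.float32)

  def call(self, inputs):
    return tf.gather(params=self.params, indices=inputs)

  def get_config(self):
    return {'input_dim': self.input_dim, 'output_dim': self.output_dim}


hvd.init()

# Mix of a custom layer and a stock Keras embedding (sizes are illustrative).
embeddings = [CustomEmbedding(1000, 16), tf.keras.layers.Embedding(2000, 32)]

# Each layer's config is stamped with its 'layer_type' and rebuilt per worker via from_config().
dist_embeddings = dmp.DistributedEmbedding(embeddings, strategy='basic')

# Model-parallel variables carry the de_local attribute, so dmp.broadcast_variables()
# broadcasts only the data-parallel ones.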