diff --git a/README.md b/README.md
index fedf94f..1c7c15c 100644
--- a/README.md
+++ b/README.md
@@ -18,7 +18,7 @@ See more details at [User Guide](https://nvidia-merlin.github.io/distributed-emb
 
 ## Installation
 ### Requirements
-Python 3, CUDA 11 or newer, TensorFlow 2.6.0 or newer
+Python 3, CUDA 11 or newer, TensorFlow 2
 ### Containers ###
 You can build inside 22.03 or later NGC TF2 [image](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tensorflow):
 ```bash
diff --git a/build_pip_pkg.sh b/build_pip_pkg.sh
index 28d2173..8516d6d 100644
--- a/build_pip_pkg.sh
+++ b/build_pip_pkg.sh
@@ -11,6 +11,7 @@ echo "=== Copy TensorFlow Custom op files"
 cp setup.py "${TMPDIR}"
 cp MANIFEST.in "${TMPDIR}"
 cp requirements.txt "${TMPDIR}"
+cp version.txt "${TMPDIR}"
 rsync -avm -L --exclude='*_test.py' distributed_embeddings "${TMPDIR}"
 
 pushd ${TMPDIR}
diff --git a/distributed_embeddings/__init__.py b/distributed_embeddings/__init__.py
index 3449f46..5ff36c2 100644
--- a/distributed_embeddings/__init__.py
+++ b/distributed_embeddings/__init__.py
@@ -15,3 +15,4 @@
 """Distributed embedding API."""
 
 from distributed_embeddings.python.ops.embedding_lookup_ops import embedding_lookup
+from .version import __version__
diff --git a/distributed_embeddings/python/layers/dist_model_parallel.py b/distributed_embeddings/python/layers/dist_model_parallel.py
index 84862d3..7fe1135 100644
--- a/distributed_embeddings/python/layers/dist_model_parallel.py
+++ b/distributed_embeddings/python/layers/dist_model_parallel.py
@@ -13,9 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Distributed Embedding layers and utils"""
+import math
+import numpy as np
 import tensorflow as tf
 from tensorflow.python.keras.utils import tf_utils
 import horovod.tensorflow as hvd
+from distributed_embeddings.python.ops.embedding_lookup_ops import read_var_no_copy
 
 from .embedding import Embedding
 
@@ -48,19 +51,21 @@ def __init__(self,
       self.input_ids_list = [list(range(len(input_table_map)))]
       self.table_ids_list = [list(range(len(embeddings)))]
       return
 
+    # Create (possibly column-sliced) configs
     sliced_configs, self.sliced_out_ranges = self.create_sliced_configs(
         world_size, column_slice_threshold, input_table_map)
     # Apply strategy and save nested list containing table indices by rank
     self.table_ids_list = self.apply_stragety(strategy, world_size, sliced_configs)
-    # Nested list to split embedding output from each rank into tables
-    self.widths_list = []
+    # Nested list containing input indices by rank
     self.input_ids_list = []
     # Nested list containing local input to local table map by rank
     self.local_map_list = []
     # Nested list containing local configs by rank
     self.local_configs_list = []
+    # All local widths, ordered by rank and flattened into a single list
+    self.widths_list_flat = []
     # Each worker loop over all rank to get global view of strategy
     for rank_table_ids in self.table_ids_list:
       # calculate stats needed for each rank
@@ -73,20 +78,22 @@
           rank_input_ids.append(k)
           rank_input_map.append(m)
       self.local_configs_list.append(rank_configs)
-      self.widths_list.append(rank_widths)
+      self.widths_list_flat += rank_widths
       self.input_ids_list.append(rank_input_ids)
       self.local_map_list.append(rank_input_map)
 
-    # List of total embedding widths to split embedding output by rank after alltoall
-    self.total_local_widths = [sum(widths) for widths in self.widths_list]
+    # List that maps local inputs to local tables
     self.local_input_table_map = self.local_map_list[rank]
+    # flatten self.input_ids_list
     worker_order_input_ids = [item for sublist in self.input_ids_list for item in sublist]
+    # List of indices to shuffle worker-ordered embedding outputs back to the original order
     self.rev_global_input_ids = [
         index for _, index in sorted(zip(worker_order_input_ids, range(len(worker_order_input_ids))))
     ]
+    # List of configs used to create local embedding layers
     self.local_configs = self.local_configs_list[rank]
@@ -286,18 +293,17 @@ def _call_base(self, inputs):  # pylint: disable=missing-param-doc,missing-type-
         for m, inp in zip(self.strategy.local_input_table_map, inputs)
     ]
-    # concat last axis to make all2all slice correct, and reshape to make later split easier
     # TODO(Deyu): current assume 2D with same batch for all output, ideally should support general case
-    local_bs = inputs[0].shape[0] // self.world_size
-    mp_outs = tf.reshape(tf.concat(mp_outs, axis=-1), [-1, local_bs])
+    mp_outs = [tf.reshape(mp_out, [self.world_size, -1]) for mp_out in mp_outs]
+    mp_outs = tf.reshape(tf.concat(mp_outs, axis=1), [-1])
+    # cast before alltoall according to the dtype policy
+    mp_outs = tf.cast(mp_outs, self.compute_dtype)
     dp_outs = hvd.alltoall(mp_outs, name='out_mp_to_dp')
-    dp_outs = [
-        tf.reshape(t, [local_bs, -1]) for t in tf.split(dp_outs, self.strategy.total_local_widths)
-    ]
-    # split each worker result and re-order using id
-    worker_order_res = []
-    for dp_out, widths in zip(dp_outs, self.strategy.widths_list):
-      worker_order_res += tf.split(dp_out, widths, 1)
+    local_bs = inputs[0].shape[0] // self.world_size
+    num_elements = [local_bs * item for item in self.strategy.widths_list_flat]
+    split_outs = tf.split(dp_outs, num_elements)
+    worker_order_res = [tf.reshape(split_out, [local_bs, -1]) for split_out in split_outs]
+    # reorder outputs to match the original input order
     result = [worker_order_res[index] for index in self.strategy.rev_global_input_ids]
     return result
@@ -309,70 +315,149 @@ def _concat_column_slice_outputs(self, outs):
       outs[start:end] = [tf.concat(outs[start:end], axis=-1)]
     return outs
 
-  def set_weights(self, weights):  # pylint: disable=missing-param-doc,missing-type-doc
+  def set_weights(self, weights, chunk=134217728, use_lock=False):
     """Sets the weights of the layer, from NumPy arrays.
 
-    This override expects global weights for all tables as input.
+    Args:
+      weights (list): list containing global weights for all tables. Items in the
+          list can be either numpy arrays or file paths to load from.
+      chunk (int): maximum number of elements per chunk when setting weights on GPU
+          in chunks. This will be rounded to a whole number of rows based on the
+          weight shape.
+      use_lock (bool): If True, set weights rank by rank in lock step to avoid OOM.
+          Default False.
""" - if self.world_size == 1: - sliced_local_weights = weights - else: + if use_lock: + for _ in range(self.rank): + hvd.broadcast_object(0) + + if self.world_size > 1: slice_info = [[rank_tids.count(tid) for rank_tids in self.strategy.table_ids_list] for tid in range(len(weights))] - local_weights = [weights[index] for index in self.strategy.table_ids_list[self.rank]] + weights = [weights[index] for index in self.strategy.table_ids_list[self.rank]] + if isinstance(weights[0], str): + weights = [np.load(file=path, mmap_mode='r') for path in weights] local_info = [slice_info[index] for index in self.strategy.table_ids_list[self.rank]] + # array to handle multiple slice into same table case + # TODO(Deyu): avoid this by merge those table again after find strategy + rank_ids = self.strategy.table_ids_list[self.rank] + index_offset = [rank_ids[:i].count(rank_id) for i, rank_id in enumerate(rank_ids)] - def _slice_weight_for_rank(weight, info, global_rank): + def _slice_weight_for_rank(weight, info, global_rank, offset): num_columns = weight.shape[1] num_slices = sum(info) column_per_slice = num_columns // num_slices remainder = num_columns % num_slices - rank = info[:global_rank].count(1) + rank = sum(info[:global_rank]) + offset start = column_per_slice * rank + min(rank, remainder) rank += 1 end = column_per_slice * rank + min(rank, remainder) return weight[:, start:end] - sliced_local_weights = [ - _slice_weight_for_rank(weight, info, self.rank) - for weight, info in zip(local_weights, local_info) + weights = [ + _slice_weight_for_rank(weight, info, self.rank, offset) + for weight, info, offset in zip(weights, local_info, index_offset) ] - super().set_weights(sliced_local_weights) + # variable.assign and copy-on-write creates extra copy of weight that causes OOM + # so here we scatter update by ~128M elements chunks instead of just do + # super().set_weights(weights) + for weight, arr in zip(self.weights, weights): + if arr.size <= chunk: + weight.assign(arr) + else: + chunk_size_dim0 = chunk // weight.shape[1] + num_chunks = math.ceil(weight.shape[0] / chunk_size_dim0) + last_size = weight.shape[0] - chunk_size_dim0 * (num_chunks - 1) + chunk_sizes = [chunk_size_dim0] * (num_chunks - 1) + [last_size] + for i in range(num_chunks): + start = i * chunk_size_dim0 + end = start + chunk_sizes[i] + indices = tf.range(start=start, limit=end, dtype=tf.int64) + update = tf.IndexedSlices(values=arr[start:end], + indices=indices, + dense_shape=weight.shape) + weight.scatter_update(sparse_delta=update) + del weights + + if use_lock: + for _ in range(self.world_size - self.rank): + hvd.broadcast_object(0) + + # 1d split that works beyond 32bit indexing limit TF support + def _split_1d(self, tensor, lengths): + # choose a number close to int32 limit as maximum chunk size + # This will handle tensor with size up to square of int32_max + chunking_threshold = 2147483646 + if tensor.shape[0] <= chunking_threshold: + return tf.split(tensor, lengths) + num_chunks = math.ceil(tensor.shape[0] / chunking_threshold) + padding_len = math.ceil(tensor.shape[0] / num_chunks) * num_chunks - tensor.shape[0] + padded_tensor = tf.concat([tensor, tf.zeros(padding_len, tensor.dtype)], axis=0) + tensor_list = tf.unstack(tf.reshape(padded_tensor, [num_chunks, -1])) + result = [] + for length in lengths: + this_slice = [] + while length > 0: + if length > tensor_list[0].shape[0]: + this_slice.append(tensor_list.pop(0)) + else: + this_slice.append(tensor_list[0][:length]) + tensor_list[0] = tensor_list[0][length:] + length -= 
this_slice[-1].shape[0] + result.append(tf.concat(this_slice, axis=0)) + return result - def get_weights(self): + def get_weights(self, all_ranks=False): """Returns the current weights of the layer, as NumPy arrays. This override outputs global weights for all tables. + Args: + all_ranks (bool): If true, return weights in all ranks, otherwise only in rank 0. + Default False. """ + # avoid copy-on-read on dense access + local_weights = [read_var_no_copy(w) for w in self.weights] if self.world_size == 1: - return [weight.numpy() for weight in self.weights] + return [w.numpy() for w in local_weights] + + # mpi segfault on over 32bit range index, so we gather weights chunk by chunk here + # choose a number not very close to int32 limit as maximum chunk size just to be safe + chunking_threshold = 2000000000 + num_chunks = 1 + for local_configs in self.strategy.local_configs_list: + total_elements = sum([c['input_dim'] * c['output_dim'] for c in local_configs]) + num_chunks = max(num_chunks, math.ceil(self.world_size * total_elements / chunking_threshold)) - # mpi segfault on large sizes so we gather weights chunk by chunk here - num_chunks = 8 with tf.device('CPU:0'): - local_weights = tf.concat([tf.reshape(w, [-1]) for w in self.weights], axis=0) + local_weights = tf.concat([tf.reshape(w, [-1]) for w in local_weights], axis=0) chunk_size = local_weights.shape[0] // num_chunks last_size = local_weights.shape[0] - chunk_size * (num_chunks - 1) chunk_sizes = [chunk_size] * (num_chunks - 1) + [last_size] - local_weights = tf.split(local_weights, chunk_sizes) + local_weights = self._split_1d(local_weights, chunk_sizes) + # communicate chunk sizes all_sizes = hvd.allgather(chunk_sizes) # collect all chunks and split to reverse allgather concat chunks = [] for i, w in enumerate(local_weights): - chunks += tf.split(hvd.allgather(w), all_sizes[i::num_chunks]) + w = hvd.allgather(w) + if all_ranks or self.rank == 0: + chunks += self._split_1d(w, all_sizes[i::num_chunks]) + if not chunks: + return [] + # re-construct all local weights from chunks local_weights = [] for i in range(self.world_size): local_weights.append(tf.concat(chunks[i::self.world_size], axis=0)) + del chunks + # split flat local weights into correct sizes weights = [] for local_weight, local_configs in zip(local_weights, self.strategy.local_configs_list): local_shapes = [[c['input_dim'], c['output_dim']] for c in local_configs] local_sizes = [shape[0] * shape[1] for shape in local_shapes] - flat_weights = tf.split(local_weight, local_sizes) + flat_weights = self._split_1d(local_weight, local_sizes) weights += [tf.reshape(weight, shape) for weight, shape in zip(flat_weights, local_shapes)] # restore original table order # flatten self.strategy.table_ids_list @@ -408,6 +493,7 @@ def call(self, inputs): # pylint: disable=missing-function-docstring self.local_embedding_layers[m](inp) for m, inp in zip(self.strategy.local_input_table_map, inputs) ] + outputs = [tf.cast(output, self.compute_dtype) for output in outputs] return outputs # TODO(skyw): Revisit logics of selecting call functions for different strategy @@ -460,7 +546,10 @@ def gradient(self, target, sources, output_gradients=None): dp_vars.append(var) dp_grads.append(grad) split_infos.append((False, len(dp_grads) - 1)) - dp_grads = self._allreduce_grads(dp_grads, dp_vars) # pylint: disable=protected-access + # TODO(Deyu): make sure not reusing _allreduce_grads doesn't lead to any issue + dp_grads = [ + hvd.allreduce(g, name=f'dp_gradient_{i}', op=hvd.Average) for i, g in 
enumerate(dp_grads) + ] # put gradients back in original order grads = [] for info in split_infos: diff --git a/distributed_embeddings/python/layers/dist_model_parallel_test.py b/distributed_embeddings/python/layers/dist_model_parallel_test.py index d19abe3..9660484 100644 --- a/distributed_embeddings/python/layers/dist_model_parallel_test.py +++ b/distributed_embeddings/python/layers/dist_model_parallel_test.py @@ -158,7 +158,7 @@ def run_and_test(self, ref_model, ref_inputs, test_model, test_inputs): optimizer.apply_gradients(zip(ref_grads, ref_model.variables)) optimizer.apply_gradients(zip(test_grads, test_model.variables)) ref_weights = ref_model.get_weights() - test_weights = test_model.dist_embeddings.get_weights() + test_model.dense.get_weights() + test_weights = test_model.dist_embeddings.get_weights(True) + test_model.dense.get_weights() for ref_w, test_w in zip(ref_weights, test_weights): # assert close here since order of accumulations(inputs and batch dim) might have changed @@ -269,6 +269,18 @@ def test_column_slice_threshold(self): dp_inputs, _ = self.gen_inputs(table_sizes) self.run_and_test(ref_model, dp_inputs, test_model, dp_inputs) + def test_column_slice_dup_worker(self): + table_sizes = [[10, 4], [11, 2], [4, 2], [4, 2]] + ref_model = EmbeddingListModel(table_sizes, distribute=False) + test_model = EmbeddingListModel(table_sizes, + distribute=True, + strategy='memory_balanced', + dp_input=False, + column_slice_threshold=10) + mp_input_ids = test_model.dist_embeddings.strategy.input_ids_list[self.hvd_rank] + dp_inputs, mp_inputs = self.gen_inputs(table_sizes, mp_input_ids=mp_input_ids) + self.run_and_test(ref_model, dp_inputs, test_model, mp_inputs) + if __name__ == "__main__": test.main() diff --git a/setup.py b/setup.py index e3752b7..ad3795e 100644 --- a/setup.py +++ b/setup.py @@ -14,16 +14,44 @@ # limitations under the License. """Simple setup script""" +import os from setuptools import setup, find_packages +abspath = os.path.dirname(os.path.realpath(__file__)) + with open("requirements.txt", encoding='utf-8') as f: requirements = f.read().splitlines() # pylint: disable=invalid-name print(find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"])) +license_header = """# +# SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+"""
+
+# Generate version file
+with open(os.path.join(abspath, "version.txt"), encoding="utf-8") as f:
+  version = f.read().strip()
+with open(os.path.join(abspath, "distributed_embeddings/version.py"), "w", encoding="utf-8") as f:
+  f.write(license_header)
+  f.write(F"__version__ = \"{version}\"")
+
 setup(
     name="distributed-embeddings",
-    version="1.0.0",
+    version=version,
     description="Distributed Embedding",
     packages=find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]),
     install_requires=requirements,
diff --git a/version.txt b/version.txt
new file mode 100644
index 0000000..6e8bf73
--- /dev/null
+++ b/version.txt
@@ -0,0 +1 @@
+0.1.0
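
A note on the column-slice arithmetic in `_slice_weight_for_rank` above: a table's `num_columns` are spread over `num_slices` as evenly as possible, with the first `num_columns % num_slices` slices taking one extra column. A standalone sketch of just that arithmetic (the `slice_columns` helper below is illustrative, not part of the patch):

```python
import numpy as np

def slice_columns(weight, num_slices, slice_idx):
  """Return the columns of `weight` owned by slice `slice_idx` of `num_slices`."""
  num_columns = weight.shape[1]
  column_per_slice = num_columns // num_slices
  remainder = num_columns % num_slices
  # the first `remainder` slices each take one extra column
  start = column_per_slice * slice_idx + min(slice_idx, remainder)
  end = column_per_slice * (slice_idx + 1) + min(slice_idx + 1, remainder)
  return weight[:, start:end]

# A 10-column table sliced 3 ways yields widths 4, 3 and 3.
w = np.arange(20).reshape(2, 10)
print([slice_columns(w, 3, i).shape[1] for i in range(3)])  # [4, 3, 3]
```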
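The chunked path in `set_weights` exists because `variable.assign(array)` materializes a second full copy of the table (copy-on-write) before the old buffer is freed. A minimal standalone sketch of the same idea for a 2-D `tf.Variable`; the `assign_by_chunks` name is hypothetical, not the layer's API:

```python
import math
import numpy as np
import tensorflow as tf

def assign_by_chunks(variable, array, chunk=134217728):
  """Copy `array` into `variable`, scattering at most ~`chunk` elements at a time."""
  if array.size <= chunk:
    variable.assign(array)
    return
  # round the chunk down to a whole number of rows
  rows_per_chunk = chunk // variable.shape[1]
  num_chunks = math.ceil(variable.shape[0] / rows_per_chunk)
  for i in range(num_chunks):
    start = i * rows_per_chunk
    end = min(start + rows_per_chunk, variable.shape[0])
    update = tf.IndexedSlices(values=array[start:end],
                              indices=tf.range(start, end, dtype=tf.int64),
                              dense_shape=variable.shape)
    # scatter_update overwrites the selected rows in place, without a full temporary copy
    variable.scatter_update(sparse_delta=update)

v = tf.Variable(tf.zeros([1000, 16]))
assign_by_chunks(v, np.ones([1000, 16], dtype=np.float32), chunk=4096)
```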
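The `use_lock` fences rely only on `hvd.broadcast_object` being a collective: rank r joins r collectives before loading its weights and the remaining `world_size - r` afterwards, so each rank proceeds only after all lower ranks have finished. The pattern in isolation (a sketch; `run_in_lock_step` is a hypothetical name, and `hvd.init()` is assumed to have been called):

```python
import horovod.tensorflow as hvd

def run_in_lock_step(fn):
  """Run `fn` on one rank at a time, ordered by rank."""
  # wait: the k-th broadcast completes only once every rank has reached its
  # k-th broadcast, and lower ranks start broadcasting only after running `fn`
  for _ in range(hvd.rank()):
    hvd.broadcast_object(0)
  fn()
  # release the ranks still waiting behind us
  for _ in range(hvd.size() - hvd.rank()):
    hvd.broadcast_object(0)
```

Every rank ends up participating in exactly `hvd.size()` broadcasts, which is what keeps the collectives matched across ranks.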
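`_split_1d` works around a single `tf.split` not being able to index past 2^31 - 1 elements: the flat tensor is reshaped into equal sub-threshold chunks, and each requested length is carved out of the chunk list. The same logic as a standalone mirror with a toy threshold, so the chunked path can be exercised on a small tensor:

```python
import math
import tensorflow as tf

def split_1d(tensor, lengths, chunking_threshold=5):
  """Like tf.split for 1-D tensors, but never slices more than
  `chunking_threshold` elements in one op (the real method uses ~2**31 - 1)."""
  if tensor.shape[0] <= chunking_threshold:
    return tf.split(tensor, lengths)
  num_chunks = math.ceil(tensor.shape[0] / chunking_threshold)
  padding_len = math.ceil(tensor.shape[0] / num_chunks) * num_chunks - tensor.shape[0]
  padded = tf.concat([tensor, tf.zeros(padding_len, tensor.dtype)], axis=0)
  chunks = tf.unstack(tf.reshape(padded, [num_chunks, -1]))
  result = []
  for length in lengths:
    piece = []
    while length > 0:
      if length > chunks[0].shape[0]:
        piece.append(chunks.pop(0))  # consume a whole chunk
      else:
        piece.append(chunks[0][:length])  # take the head of the next chunk
        chunks[0] = chunks[0][length:]
      length -= piece[-1].shape[0]
    result.append(tf.concat(piece, axis=0))
  return result

t = tf.range(12, dtype=tf.float32)
print([p.numpy().tolist() for p in split_1d(t, [7, 5])])  # [0..6] and [7..11]
```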