Commit
Add automatic slicing when there are more workers than tables and no column_slice_threshold is set. These cases now run without a NotImplementedError.
FDecaYed committed Feb 13, 2023
1 parent 70b7d20 commit a581613
Showing 2 changed files with 31 additions and 7 deletions.
23 changes: 19 additions & 4 deletions distributed_embeddings/python/layers/dist_model_parallel.py
@@ -178,6 +178,16 @@ def create_sliced_configs(self, world_size, column_slice_threshold, input_table_
       sliced_out_ranges (list): each element is a list of 2 integers, representing output ranges
         that need to be concatenated to re-form the output due to the slicing above.
     """
+    # TODO(Deyu): in auto slicing, when there are equal-sized tables, allow slicing only some of them
+    # Fewer tables than workers: try our best to slice into world_size slices (may go over)
+    if column_slice_threshold is None:
+      table_sizes = [config['input_dim'] * config['output_dim'] for config in self.global_configs]
+      while world_size > len(table_sizes):
+        table_sizes.sort()
+        column_slice_threshold = table_sizes[-1] - 1
+        cur_max_size = table_sizes.pop(-1)
+        table_sizes += [cur_max_size // 2, cur_max_size // 2]
+
     sliced_configs = []
     for global_config in self.global_configs:
       maybe_sliced_config = self.maybe_slice_table_column(global_config, column_slice_threshold,
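For readers skimming the hunk above: the new loop greedily halves the largest table until there are at least world_size pieces, and the threshold ends up one less than the last size it split. A minimal standalone sketch of that selection logic (auto_threshold is a hypothetical name, not part of the library):

```python
def auto_threshold(table_sizes, world_size):
  """Mirror of the greedy loop above: halve the largest table until there
  are at least world_size pieces; return the resulting slice threshold."""
  sizes = list(table_sizes)
  threshold = None
  while world_size > len(sizes):
    sizes.sort()
    threshold = sizes[-1] - 1  # so the current largest table qualifies for slicing
    largest = sizes.pop(-1)
    sizes += [largest // 2, largest // 2]
  return threshold

# Two tables of 100 and 60 elements, four workers: both are halved once,
# leaving pieces [50, 50, 30, 30] and a threshold of 59.
print(auto_threshold([100, 60], world_size=4))  # -> 59
```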
@@ -300,8 +310,11 @@ class DistributedEmbedding(tf.keras.layers.Layer):
     embeddings (list of keras Embedding layers): embedding tables to be distributed
     strategy (str): A string indicating how embedding tables are distributed.
         Choices are ["basic", "memory_balanced"]. Default "basic"
-    column_slice_threshold (int or None): If not None, embedding tables with more elements than
-        column_slice_threshold will be divide into N even pieces alone embedded width dimension.
+    column_slice_threshold (int or None): If None, column slicing only happens when there are
+        more workers than tables; in that case column_slice_threshold is chosen automatically
+        so that each worker receives at least one slice.
+        If not None, embedding tables with more elements than column_slice_threshold are divided
+        into N even pieces along the embedding width dimension.
         N is the smallest power of 2 that makes each slice smaller than column_slice_threshold. Default None.
     row_slice (TBD): Describes which embeddings need to be row sliced
     dp_input (bool): If True, takes data parallel input, i.e. in shape
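The power-of-2 rule in the docstring is easy to check by hand; a hedged worked example (num_slices is illustrative, not the library's internal helper):

```python
def num_slices(table_elements, column_slice_threshold):
  """Smallest power of 2 such that each slice is smaller than the threshold."""
  n = 1
  while table_elements / n >= column_slice_threshold:
    n *= 2
  return n

# A 1000-element table with a threshold of 300: halves of 500 are still too
# large, quarters of 250 fit, so N = 4.
print(num_slices(1000, 300))  # -> 4
```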
@@ -342,8 +355,10 @@ def __init__(self,
         strategy,
         input_table_map=input_table_map,
         column_slice_threshold=column_slice_threshold)
-    if len(self.strategy.global_configs) < self.world_size:
-      raise NotImplementedError
+    # Handle an explicit threshold and corner cases in which a worker may receive no configs
+    if not all(rank_configs for rank_configs in self.strategy.local_configs):
+      raise ValueError("Not enough tables after slicing to run on all workers. "
+                       "Try decreasing column_slice_threshold or the worker count.")

     # create local embeddings
     self.local_embedding_layers = []
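With this guard in place, the previously unsupported fewer-tables-than-workers case works whenever column_slice_threshold is left as None. A hedged usage sketch (the import path follows this file's location in the diff; Horovod initialization and a multi-worker launch are assumed):

```python
import horovod.tensorflow as hvd
import tensorflow as tf

from distributed_embeddings.python.layers import dist_model_parallel as dmp

hvd.init()  # one process per worker

# A single 100x16 table across several workers: with the default
# column_slice_threshold=None it is now auto-sliced instead of
# raising NotImplementedError.
table = tf.keras.layers.Embedding(input_dim=100, output_dim=16)
embedding = dmp.DistributedEmbedding([table], strategy='memory_balanced')
```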
15 changes: 12 additions & 3 deletions distributed_embeddings/python/layers/dist_model_parallel_test.py
@@ -96,11 +96,11 @@ def __init__(self, *args, **kwargs):
   def gen_table_sizes(self, num_tables=None):
     random.seed(self.seed)
     if num_tables is None:
-      num_tables = random.randint(self.hvd_size, 2 * self.hvd_size)
+      num_tables = random.randint(1, 2 * self.hvd_size)
     table_sizes = []
     for _ in range(num_tables):
       table_height = random.randint(3, 20)
-      table_width = random.randint(3, 15)
+      table_width = random.randint(4, 15)
       table_sizes.append([table_height, table_width])
     return table_sizes

@@ -278,7 +278,7 @@ def test_column_slice_merge(self):
     self.assertEqual(len(tables), len(set(tables)))

   def test_column_slice_threshold(self):
-    table_sizes = self.gen_table_sizes()
+    table_sizes = self.gen_table_sizes(self.hvd_size + 1)
     ref_model = EmbeddingListModel(table_sizes, distribute=False)
     test_model = EmbeddingListModel(table_sizes,
                                     distribute=True,
@@ -377,6 +377,15 @@ def test_indivisible_batch(self):
     with self.assertRaisesRegex(ValueError, "not divisible"):
       self.run_and_test(ref_model, dp_inputs, test_model, mp_inputs)

+  def test_fewer_tables_than_workers(self):
+    table_sizes = self.gen_table_sizes(1)
+
+    ref_model = EmbeddingListModel(table_sizes, distribute=False)
+    test_model = EmbeddingListModel(table_sizes, distribute=True, strategy='memory_balanced')
+
+    dp_inputs, _ = self.gen_inputs(table_sizes)
+    self.run_and_test(ref_model, dp_inputs, test_model, dp_inputs)
+

 if __name__ == "__main__":
   test.main()
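The new test_fewer_tables_than_workers covers the one-table case end to end. Like the rest of the suite it depends on the Horovod world size (hvd_size), so it only exercises the new path under a multi-worker launch, e.g. something like `horovodrun -np 4 python dist_model_parallel_test.py` (launch command is an assumption, not taken from this diff).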
