Introduce distributed embeddings #974

Open · wants to merge 27 commits into main from distributed-embeddings

Commits (27)
8b111c9  rough draft (edknv, Feb 4, 2023)
39cc981  Introduce distributed embeddings (edknv, Feb 7, 2023)
b46bd20  Merge branch 'main' into distributed-embeddings (edknv, Feb 7, 2023)
dd3df0c  install distributed-embeddings package in tox (edknv, Feb 8, 2023)
847c05b  check if distributed-embeddings is installed before loading the class… (edknv, Feb 8, 2023)
7b19fb2  Merge branch 'main' into distributed-embeddings (edknv, Feb 8, 2023)
ded191e  lint (edknv, Feb 8, 2023)
99a9b2d  install distributed-embeddings from github repo (edknv, Feb 8, 2023)
77a4f5c  Merge branch 'main' into distributed-embeddings (edknv, Feb 8, 2023)
674d53f  graph mode support (edknv, Feb 15, 2023)
4a8bd97  add distributed-embeddings to ci (edknv, Feb 15, 2023)
1f75984  Add multi-gpu ci tests (edknv, Mar 7, 2023)
aa3c7b7  Merge branch 'main' into distributed-embeddings (edknv, Mar 7, 2023)
a0c58d7  remove graph mode error (edknv, Mar 7, 2023)
9892be8  lint and minor rearrangement (edknv, Mar 7, 2023)
c76b528  lint (edknv, Mar 7, 2023)
4fe3863  revert to using tensor.shape (edknv, Mar 7, 2023)
9ac60c7  whitelist sh in tox (edknv, Mar 7, 2023)
3c4b529  specify path in gha (edknv, Mar 7, 2023)
4a8df39  Merge branch 'main' into distributed-embeddings (edknv, Mar 7, 2023)
8c2e0e6  fix horovod cpu gha workflow (edknv, Mar 7, 2023)
5c7fd5a  move horovod installation to multi-gpu (edknv, Mar 7, 2023)
36bb606  use python -m in tox (edknv, Mar 7, 2023)
b69078f  Remove horovod installation (edknv, Mar 7, 2023)
2763cd8  clean up and add documentation (edknv, Mar 8, 2023)
c916eb0  Merge branch 'main' into distributed-embeddings (edknv, Mar 8, 2023)
2e2d449  Merge branch 'main' into distributed-embeddings (rnyak, Mar 15, 2023)
Files changed
24 changes: 24 additions & 0 deletions examples/usecases/multi-gpu/install_distributed_embeddings.sh
@@ -0,0 +1,24 @@
#!/bin/bash

set -e

INSTALL_DIR=$1

WORK_DIR=$(pwd)

cd $INSTALL_DIR

if [ ! -d "distributed-embeddings" ]; then
git clone https://github.com/NVIDIA-Merlin/distributed-embeddings.git
fi

cd distributed-embeddings

git submodule update --init --recursive
make pip_pkg
python -m pip install --force-reinstall artifacts/*.whl
python setup.py install

cd $WORK_DIR

python -c "import distributed_embeddings"
8 changes: 7 additions & 1 deletion merlin/models/tf/__init__.py
@@ -127,7 +127,10 @@
from merlin.models.tf.prediction_tasks.multi import PredictionTasks
from merlin.models.tf.prediction_tasks.regression import RegressionTask
from merlin.models.tf.prediction_tasks.retrieval import ItemRetrievalTask
from merlin.models.utils.dependencies import is_transformers_available
from merlin.models.utils.dependencies import (
is_distributed_embeddings_available,
is_transformers_available,
)

if is_transformers_available():
from merlin.models.tf.transformers.block import (
@@ -145,6 +148,9 @@
LastHiddenStateAndAttention,
)

if is_distributed_embeddings_available():
from merlin.models.tf.distributed.embedding import DistributedEmbeddings

from merlin.models.tf.transforms.features import (
BroadcastToSequence,
CategoryEncoding,
9 changes: 9 additions & 0 deletions merlin/models/tf/distributed/backend.py
@@ -1,13 +1,22 @@
hvd = None
hvd_installed = False

dmp = None
dmp_installed = False

try:
import horovod.tensorflow.keras as hvd # noqa: F401

hvd_installed = True
except ImportError:
pass

try:
from distributed_embeddings.python.layers import dist_model_parallel as dmp # noqa: F401

dmp_installed = True
except ImportError:
pass

if hvd_installed:
hvd.init()
155 changes: 155 additions & 0 deletions merlin/models/tf/distributed/embedding.py
@@ -0,0 +1,155 @@
from typing import Dict, List, Optional, Union

import tensorflow as tf

from merlin.models.tf.core.tabular import TabularBlock
from merlin.models.tf.distributed.backend import dmp, dmp_installed, hvd_installed
from merlin.models.utils.schema_utils import infer_embedding_dim
from merlin.schema import Schema


@tf.keras.utils.register_keras_serializable(package="merlin.models")
class DistributedEmbeddings(TabularBlock):
"""Large embedding table that automatically distributes embedding tables
to multiple GPUs.

Parameters
----------
schema: Schema
Schema containing the columns used in embedding tables.
dim: Optional[Union[Dict[str, int], int]], optional
If an int, the same embedding size is used for all features. If a dict,
maps feature name to embedding size, e.g. {"feature_name": embedding_size}.
If None, sizes are inferred from the schema. Default: None.
strategy: str
Indicates how embedding tables are distributed.
One of ["basic", "memory_balanced"]. Default: "basic".
column_slice_threshold: Optional[int]
Desired upper bound of element count in each slice.
dp_input: bool
If True, takes data-parallel input in shape [local_batch_size x global_num_embeddings].
Otherwise takes model-parallel input in shape [global_batch_size x local_num_embeddings].
Default: True.
input_table_map: Optional[List[int]]
A list with same length as inputs. Maps `input[i]` to `table[input_table_map[i]]`.
If None, `input[i]` maps to `table[i]`. Default: None.
"""

def __init__(
self,
schema: Schema,
dim: Optional[Union[Dict[str, int], int]] = None,
strategy: str = "basic",
column_slice_threshold: Optional[int] = None,
dp_input: bool = True,
input_table_map: Optional[List[int]] = None,
**kwargs,
):
if not hvd_installed or not dmp_installed:
raise ImportError(
"'horovod' and 'distributed-embeddings' are required to use "
f"{self.__class__.__name__}."
)

super(DistributedEmbeddings, self).__init__(schema=schema, **kwargs)

self.dim = dim
self.table_names = []
self.embedding_layers = []

for col in self.schema:
table_name = col.int_domain.name or col.name
self.table_names.append(table_name)
self.embedding_layers.append(
tf.keras.layers.Embedding(
input_dim=self._infer_input_dim(col),
output_dim=self._infer_output_dim(col, dim),
name=table_name,
)
)

self.embedding_layers = dmp.DistributedEmbedding(
self.embedding_layers,
strategy=strategy,
column_slice_threshold=column_slice_threshold,
dp_input=dp_input,
input_table_map=input_table_map,
)

def _infer_input_dim(self, col_schema):
return col_schema.int_domain.max + 1

def _infer_output_dim(self, col_schema, embedding_dims):
if isinstance(embedding_dims, dict):
dim = embedding_dims.get(col_schema.name)
elif isinstance(embedding_dims, int):
dim = embedding_dims
else:
dim = None

if dim is None:
dim = infer_embedding_dim(col_schema)

return dim

def build(self, input_shapes):
super().build(input_shapes)

if self.embedding_layers.built is True:
return

if isinstance(input_shapes, dict):
ordered_input_shapes = []
for feature_name in self.table_names:
ordered_input_shapes.append(input_shapes[feature_name])
elif isinstance(input_shapes, list):
ordered_input_shapes = input_shapes
else:
raise ValueError(f"Unexpected input type encountered: {input_shapes}")
self.embedding_layers.build(ordered_input_shapes)

@tf.function
def call(
self, inputs: Union[Dict[str, tf.Tensor], List[tf.Tensor]]
) -> Union[Dict[str, tf.Tensor], List[tf.Tensor]]:
"""
Parameters
----------
inputs : Union[Dict[str, tf.Tensor], List[tf.Tensor]]
Tensors or dictionary of tensors representing the input batch.

Returns
-------
A list of tensors or dict of tensors corresponding to the embeddings for the inputs
"""

if isinstance(inputs, dict):
ordered_inputs = []
outputs = {}
for feature_name in self.table_names:
ordered_inputs.append(inputs[feature_name])
ordered_outputs = self.embedding_layers(ordered_inputs)
for feature_name, output in zip(self.schema.column_names, ordered_outputs):
outputs[feature_name] = output
elif isinstance(inputs, list):
outputs = self.embedding_layers(inputs)
else:
raise ValueError(f"Unexpected input type encountered: {inputs}")

return outputs

@tf.function
def compute_call_output_shape(self, input_shapes):
def _get_output_shape(input_shape):
batch_size = input_shape[0]
output_shape = tf.TensorShape([batch_size, self.dim])
return output_shape

if isinstance(input_shapes, dict):
output_shapes = {k: _get_output_shape(v) for k, v in input_shapes.items()}
elif isinstance(input_shapes, list):
output_shapes = [_get_output_shape(x) for x in input_shapes]
else:
raise ValueError(f"Unexpected input type encountered: {input_shapes}")

return output_shapes
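For context, a minimal usage sketch of the block added above (not part of this diff): it builds a `DistributedEmbeddings` from a categorical schema and looks up embeddings for a dict batch. It assumes horovod and distributed-embeddings are installed (so `mm.DistributedEmbeddings` is exported and `hvd.init()` has run via the backend module) and that the process is launched with horovodrun; the column names, domains, and values are illustrative only.

```python
import numpy as np
import tensorflow as tf

import merlin.models.tf as mm
from merlin.schema import ColumnSchema, Schema, Tags

# Two illustrative categorical columns; input_dim is inferred as int_domain.max + 1.
schema = Schema(
    [
        ColumnSchema(
            "item_id",
            dtype=np.int32,
            properties={"domain": {"min": 0, "max": 100, "name": "item_id"}},
            tags=[Tags.CATEGORICAL],
        ),
        ColumnSchema(
            "user_id",
            dtype=np.int32,
            properties={"domain": {"min": 0, "max": 50, "name": "user_id"}},
            tags=[Tags.CATEGORICAL],
        ),
    ]
)

# One table per categorical column; tables are placed across GPUs by
# dmp.DistributedEmbedding using the default "basic" strategy.
embeddings = mm.DistributedEmbeddings(schema, dim=8)

batch = {
    "item_id": tf.constant([1, 2, 3], dtype=tf.int64),
    "user_id": tf.constant([4, 5, 6], dtype=tf.int64),
}
outputs = embeddings(batch)  # dict of [batch_size, 8] tensors keyed by column name
```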
10 changes: 10 additions & 0 deletions merlin/models/utils/dependencies.py
@@ -47,3 +47,13 @@ def is_transformers_available() -> bool:
except ImportError:
transformers = None
return transformers is not None


def is_distributed_embeddings_available() -> bool:
try:
import horovod # isort: skip
import distributed_embeddings
except ImportError:
horovod = None
distributed_embeddings = None
return horovod is not None and distributed_embeddings is not None
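A short sketch (not part of this diff) of how calling code might use the new helper to pick between the model-parallel block and the standard tables when the optional packages are missing; the `build_embeddings` helper and the `mm.Embeddings` fallback are assumptions for illustration, not part of this PR.

```python
import merlin.models.tf as mm
from merlin.models.utils.dependencies import is_distributed_embeddings_available
from merlin.schema import Tags


def build_embeddings(schema, dim=16):
    """Use model-parallel embeddings when available, else the default tables."""
    cat_schema = schema.select_by_tag(Tags.CATEGORICAL)
    if is_distributed_embeddings_available():
        # Tables sharded across GPUs by distributed-embeddings (requires horovod).
        return mm.DistributedEmbeddings(cat_schema, dim=dim)
    # Standard data-parallel embedding tables.
    return mm.Embeddings(cat_schema, dim=dim)
```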
81 changes: 81 additions & 0 deletions tests/unit/tf/horovod/test_embedding.py
@@ -0,0 +1,81 @@
import numpy as np
import pytest
import tensorflow as tf

import merlin.models.tf as mm
from merlin.schema import ColumnSchema, Schema, Tags

hvd = pytest.importorskip("horovod.tensorflow.keras")
dmp = pytest.importorskip("distributed_embeddings.python.layers.dist_model_parallel")


def generate_inputs(input_dims, global_batch_size):
global_inputs = [
tf.random.uniform(shape=[global_batch_size], minval=0, maxval=dim, dtype=tf.int64)
for dim in input_dims
]
for t in global_inputs:
hvd.broadcast(t, root_rank=0)
local_batch_size = global_batch_size // hvd.size()
rank = hvd.rank()
inputs = [t[rank * local_batch_size : (rank + 1) * local_batch_size] for t in global_inputs]
return inputs


def test_distributed_embeddings_basic(embedding_dim=4, global_batch_size=8):
column_schema_0 = ColumnSchema(
"col0",
dtype=np.int32,
properties={"domain": {"min": 0, "max": 10, "name": "col0"}},
tags=[Tags.CATEGORICAL],
)
column_schema_1 = ColumnSchema(
"col1",
dtype=np.int32,
properties={"domain": {"min": 0, "max": 20, "name": "col1"}},
tags=[Tags.CATEGORICAL],
)
schema = Schema([column_schema_0, column_schema_1])

inputs = generate_inputs([10, 20], global_batch_size)
table = mm.DistributedEmbeddings(schema, embedding_dim)
outputs = table(inputs)

assert len(outputs) == 2
assert outputs[0].shape == (global_batch_size // hvd.size(), embedding_dim)
assert outputs[1].shape == (global_batch_size // hvd.size(), embedding_dim)


@pytest.mark.parametrize("run_eagerly", [True, False])
def test_dlrm_model_with_embeddings(
music_streaming_data, run_eagerly, batch_size=8, embedding_dim=16, learning_rate=0.03
):
music_streaming_data.schema = music_streaming_data.schema.select_by_name(
["item_id", "user_id", "user_age", "click"]
)
train = music_streaming_data.repartition(npartitions=hvd.size())
train_loader = mm.Loader(
train,
schema=train.schema,
batch_size=batch_size,
shuffle=True,
drop_last=True,
)

target_column = train.schema.select_by_tag(Tags.TARGET).column_names[0]

model = mm.DLRMModel(
train.schema,
embeddings=mm.DistributedEmbeddings(
train.schema.select_by_tag(Tags.CATEGORICAL), dim=embedding_dim
),
bottom_block=mm.MLPBlock([32, embedding_dim]),
top_block=mm.MLPBlock([32, embedding_dim]),
prediction_tasks=mm.BinaryClassificationTask(target_column),
)

opt = tf.keras.optimizers.Adagrad(learning_rate=learning_rate)
model.compile(optimizer=opt, run_eagerly=run_eagerly, metrics=[tf.keras.metrics.AUC()])

losses = model.fit(train_loader, epochs=2)
assert all(measure >= 0 for metric in losses.history for measure in losses.history[metric])
15 changes: 11 additions & 4 deletions tox.ini
@@ -28,19 +28,23 @@ commands =
; Runs in: Github Actions
; Runs GPU-based tests.
allowlist_externals =
bash
horovodrun
deps =
-rrequirements/test.txt
passenv =
OPAL_PREFIX
setenv =
TF_GPU_ALLOCATOR=cuda_malloc_async
sitepackages=true
commands =
# Install Merlin packages
python -m pip install --upgrade git+https://github.com/NVIDIA-Merlin/core.git@{posargs:main}
python -m pip install --upgrade git+https://github.com/NVIDIA-Merlin/dataloader.git@{posargs:main}
python -m pip install --upgrade git+https://github.com/NVIDIA-Merlin/nvtabular.git@{posargs:main}
horovodrun -np 2 sh examples/usecases/multi-gpu/hvd_wrapper.sh pytest -m horovod -rxs tests/unit
# Install distributed embeddings and check build
# TODO: Move distributed-embeddings installation to CI runner.
bash examples/usecases/multi-gpu/install_distributed_embeddings.sh {envtmpdir}
# Run multi-gpu tests marked with `horovod` marker
horovodrun -np 2 sh examples/usecases/multi-gpu/hvd_wrapper.sh python -m pytest -m horovod -rxs tests/unit

[testenv:py38-horovod-cpu]
setenv =
@@ -51,12 +55,15 @@
commands =
conda update --yes --name base --channel defaults conda
conda env create --prefix {envdir}/env --file requirements/horovod-cpu-environment.yml --force
# Install horovod and check build
{envdir}/env/bin/python -m pip install horovod --no-cache-dir
{envdir}/env/bin/horovodrun --check-build
# Install Merlin packages
{envdir}/env/bin/python -m pip install --upgrade git+https://github.com/NVIDIA-Merlin/core.git
{envdir}/env/bin/python -m pip install --upgrade git+https://github.com/NVIDIA-Merlin/dataloader.git
{envdir}/env/bin/python -m pip install --upgrade git+https://github.com/NVIDIA-Merlin/nvtabular.git
{envdir}/env/bin/horovodrun -np 2 sh examples/usecases/multi-gpu/hvd_wrapper.sh pytest -m horovod -rxs tests/unit
# Run multi-gpu tests marked with `horovod` marker
{envdir}/env/bin/horovodrun -np 2 sh examples/usecases/multi-gpu/hvd_wrapper.sh {envdir}/env/bin/python -m pytest -m horovod -rxs tests/unit

[testenv:py38-nvtabular-cpu]
passenv=GIT_COMMIT