Save and load SOK model embeddings #951

Open: wants to merge 28 commits into base: main

Commits (28)
0141ad4
new sok class
Nov 4, 2022
d7ad9ba
new sok class
Nov 4, 2022
7825963
test sok dynamic variable
Nov 7, 2022
475145a
test sok dynamic variable
Nov 7, 2022
4ef7a9e
bug fix comma
Nov 7, 2022
720eacf
add some comments and test distributed var
Nov 16, 2022
7a3a177
format the comments
Nov 17, 2022
54f795b
assert condition in sok lookup sparse
Dec 6, 2022
f3136b7
Merge branch 'main' into fea-sok-integration-wj
edknv Dec 14, 2022
7555c6f
Move SOKEmbedding to a separate file
edknv Dec 14, 2022
d0111b1
Clean up
edknv Dec 14, 2022
debb6e2
Clean up
edknv Dec 14, 2022
a4e3ffc
fix some import and param bug
Dec 27, 2022
97d51f5
remove some unused variable
Dec 28, 2022
b978afa
remove intial vals
Dec 28, 2022
4f22c0f
fix import
Dec 28, 2022
16fb414
reorder the import
Dec 28, 2022
9b37d4a
Merge branch 'main' into fea-sok-integration-wj
marcromeyn Jan 9, 2023
557af98
Merge branch 'main' into fea-sok-integration-wj
edknv Jan 9, 2023
f41f52f
fix import error in test embedding
Jan 12, 2023
f03e4ef
Merge branch 'fea-sok-integration-wj' of https://github.com/NVIDIA-Me…
Jan 12, 2023
fe34f9d
format the code
Jan 13, 2023
98bb17b
change the way of import
Jan 13, 2023
b0de517
Merge branch 'main' into fea-sok-integration-wj
WonderingWJ Jan 15, 2023
3e3c3b7
Merge branch 'main' into fea-sok-integration-wj
Jan 15, 2023
f7ff20f
Merge branch 'fea-sok-integration-wj' of https://github.com/NVIDIA-Me…
Jan 15, 2023
0d5a23f
ckpt to load and dump model
Jan 31, 2023
1560d48
load/dump interface for sok
Feb 15, 2023
Files changed
223 changes: 223 additions & 0 deletions merlin/models/tf/distributed/embedding.py
@@ -0,0 +1,223 @@
from typing import Union

import tensorflow as tf
from sparse_operation_kit import experiment as sok

from merlin.models.tf.inputs.embedding import EmbeddingTableBase
from merlin.models.utils.schema_utils import (
schema_to_tensorflow_metadata_json,
tensorflow_metadata_json_to_schema,
)
from merlin.schema import ColumnSchema


@tf.keras.utils.register_keras_serializable(package="merlin.models")
class SOKEmbedding(EmbeddingTableBase):
"""
Wrap GPU accelerated opererations dedicated for sparse training / inference case.
dim: int The last dimension of the variable
vocab_sizes: list, rows of the variable list
initializer: string, list = "uniform"
When it's string, it specifies the initializer used to generate initial values.
For sok.DynamicVariable, currently, only support "random" or string of a float
value(meaning const initializer).
For sok.Variable, it is compatible with tf.Variable.
Default value is "uniform".
When it's list, it specifies the values in the embedding table.
For sok.DynamicVariable, initializer[i] must be list of [index, value],
and will be used as the initial indices and value for i-th sok.DynamicVariable.
For sok.Variable, initializer[i] must be a numpy with shape
[vocab_size[i], embedding_vec_size],
and will be used as the initial value for i-th sok.Variable.
use_dynamic_variable: bool = "False" use sok.DynamicVariable or sok.Variable. DynamicVariable
can allocates memory dynamically. Variable is a model-parallel distributed variable
localized: When utilizing sok.Variable, we change choose two mode: distributed(Distributed Va
riable) and localized(Localized Variable). If set to None, use Distributed Variable,
otherwise Localized Variable. where the list indicates which GPU you want to put this
variable on.
Default is None.
Examples
--------
.. code-block:: python
Notes
-----
"""

    def __init__(
        self,
        dim: int,
        *col_schemas: ColumnSchema,
        vocab_sizes: list,
        initializer: Union[str, tf.Tensor, list] = "uniform",
        use_dynamic_variable=False,
        localized=None,
        trainable=True,
        name=None,
        dtype=None,
        **kwargs,
    ):
        super(SOKEmbedding, self).__init__(
            dim, *col_schemas, trainable=trainable, name=name, dtype=dtype, **kwargs
        )
        self._embedding_vec_size = dim
        self._vocab_sizes = vocab_sizes
        self._initializer = initializer
        self._use_dynamic_variable = use_dynamic_variable
        self._localized = localized
        self._vars = []
        if self._localized is None and self._use_dynamic_variable is False:
            # Distributed variables: each table is sharded across all GPUs.
            for i in range(len(vocab_sizes)):
                if isinstance(initializer, str):
                    v = sok.Variable(
                        shape=[self._vocab_sizes[i], self._embedding_vec_size],
                        initializer=tf.keras.initializers.get(initializer),
                        dtype=tf.float32,
                    )
                else:
                    v = sok.Variable(initializer[i])
                self._trainable_weights.append(v)
                self._vars.append(v)
        else:
            for i in range(len(vocab_sizes)):
                if self._use_dynamic_variable:
                    # Dynamic variables allocate rows on demand.
                    if isinstance(initializer, str):
                        v = sok.DynamicVariable(
                            dimension=self._embedding_vec_size, initializer=initializer
                        )
                    else:
                        # Initialize from explicit [indices, values] pairs.
                        v = sok.DynamicVariable(dimension=self._embedding_vec_size)
                        indices = tf.convert_to_tensor(initializer[i][0])
                        values = tf.convert_to_tensor(initializer[i][1])
                        sok.assign(v, indices, values)
                elif self._localized is not None:
                    # Localized variables: each table lives on a single GPU.
                    if isinstance(initializer, str):
                        v = sok.Variable(
                            shape=[self._vocab_sizes[i], self._embedding_vec_size],
                            initializer=tf.keras.initializers.get(initializer),
                            dtype=tf.float32,
                            mode="localized:%d" % self._localized[i],
                        )
                    else:
                        v = sok.Variable(
                            initializer[i],
                            mode="localized:%d" % self._localized[i],
                        )
                else:
                    raise ValueError(
                        "Invalid combination of use_dynamic_variable and localized"
                    )
                self._trainable_weights.append(v)
                self._vars.append(v)

    def call(self, inputs, combiners, training=True):
        """
        inputs: list, tuple
            a list or tuple of tf.SparseTensor or tf.RaggedTensor.
        combiners: list, tuple
            a list or tuple of strings that specify the combiner of each lookup.
        """
        # Validate every input, whether a single tensor or a list of tensors.
        inputs_to_check = inputs if isinstance(inputs, (list, tuple)) else [inputs]
        for cur_input in inputs_to_check:
            if not isinstance(cur_input, (tf.SparseTensor, tf.RaggedTensor)):
                raise ValueError(
                    "The input must be a list of tf.SparseTensor or tf.RaggedTensor"
                )
            if isinstance(cur_input, tf.RaggedTensor) and len(cur_input.shape) != 2:
                raise ValueError("The rank of input RaggedTensor must be 2")
        emb_vectors = sok.lookup_sparse(
            self._vars,
            inputs,
            combiners,
        )
        return emb_vectors
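
    # A lookup sketch under assumed shapes (``table`` is an instance of this
    # class holding a single variable):
    #
    #     ids = tf.ragged.constant([[0, 1], [2]], dtype=tf.int64)
    #     emb = table(inputs=[ids], combiners=["sum"])
    #
    # Each output row is the combined embedding of that example's ids.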

    def load(self, path):
        """Load the weights of this table's SOK variables from ``path``."""
        sok.load(path + "/sok/" + self.name + ".weights", self._vars)

    def dump(self, path):
        """Dump the weights of this table's SOK variables to ``path``."""
        sok.dump(path + "/sok/" + self.name + ".weights", self._vars)

@classmethod
def from_pretrained(
cls,
dim: int,
data: list,
trainable=True,
name=None,
col_schema=None,
use_dynamic_variable=False,
localized=None,
**kwargs,
) -> "SOKEmbedding":
"""Create From pre-trained embeddings from a Dataset or DataFrame.
Parameters
----------
data :
A list of numpy.array or A list of dict {"indice": numpy.array, "values": numpy.array}
trainable : bool
Whether the layer should be trained or not.
name : str
The name of the layer.
"""
weights = []
for i, item in enumerate(data):
if use_dynamic_variable:
if isinstance(item, dict) and "indice" in item and "values" in item:
weights.append([item["indice"], item["values"]])
else:
raise ValueError("DynamicVariable should be initialized with indice and values")
else:
weights.append(item)

return cls(
dim,
col_schema,
name=name,
initializer=weights,
use_dynamic_variable=use_dynamic_variable,
localized=localized,
trainable=trainable,
**kwargs,
)
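
    # A usage sketch for ``from_pretrained`` (the array shape and the
    # ``item_id_schema`` below are illustrative assumptions):
    #
    #     weights = [np.random.rand(1000, 16).astype(np.float32)]
    #     table = SOKEmbedding.from_pretrained(
    #         16, weights, col_schema=item_id_schema, use_dynamic_variable=False
    #     )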

def get_config(self):
config = super().get_config()
config["dim"] = self.dim

schema = schema_to_tensorflow_metadata_json(self.schema)
config["schema"] = schema
config["vocab_sizes"] = self._vocab_sizes
config["initializer"] = self._initializer
config["use_dynamic_variable"] = self._use_dynamic_variable
config["localized"] = self._localized

return config

    @classmethod
    def from_config(cls, config):
        dim = config.pop("dim")
        schema = tensorflow_metadata_json_to_schema(config.pop("schema"))
        vocab_sizes = config.pop("vocab_sizes")
        initializer = config.pop("initializer")
        use_dynamic_variable = config.pop("use_dynamic_variable")
        localized = config.pop("localized")

        return cls(
            dim,
            *schema,
            vocab_sizes=vocab_sizes,
            initializer=initializer,
            use_dynamic_variable=use_dynamic_variable,
            localized=localized,
            **config,
        )

def model_dump(model: tf.keras.Model, optimizer: tf.keras.optimizers.Optimizer, path: str):
    """Save the SOK embedding weights and the other model variables under ``path``."""
    sok_variables, other_variables = sok.filter_variables(model.weights)
    ckpt = tf.train.Checkpoint(variables=other_variables, optimizer=optimizer)
    sok.dump(path + "/sok_var/sok_weights", sok_variables)
    ckpt.save(path)


def model_load(model: tf.keras.Model, optimizer: tf.keras.optimizers.Optimizer, path: str):
    """Restore the SOK embedding weights and the other model variables from ``path``."""
    sok_variables, other_variables = sok.filter_variables(model.weights)
    ckpt = tf.train.Checkpoint(variables=other_variables, optimizer=optimizer)
    ckpt.restore(tf.train.latest_checkpoint(path))
    sok.load(path + "/sok_var/sok_weights", sok_variables)

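For reference, a minimal sketch of the save/load round trip these helpers enable. The model, optimizer, and path below are illustrative placeholders, not part of the diff:

    import tensorflow as tf

    from merlin.models.tf.distributed.embedding import model_dump, model_load

    optimizer = tf.keras.optimizers.Adam()
    # `model` is assumed to be any Keras model whose weights include SOK
    # variables; filter_variables() splits them from ordinary TF variables.
    model_dump(model, optimizer, "/tmp/sok_ckpt")
    model_load(model, optimizer, "/tmp/sok_ckpt")
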
1 change: 1 addition & 0 deletions requirements/horovod.txt
@@ -1 +1,2 @@
horovod
merlin-sok
120 changes: 120 additions & 0 deletions tests/unit/tf/horovod/test_embedding.py
@@ -0,0 +1,120 @@
import numpy as np
import pytest
import tensorflow as tf

import merlin.models.tf as mm
from merlin.schema import ColumnSchema, Tags

hvd = pytest.importorskip("horovod.tensorflow")
sok = pytest.importorskip("sparse_operation_kit.experiment")

from merlin.models.tf.distributed.embedding import SOKEmbedding  # noqa: E402
from merlin.models.tf.utils import testing_utils  # noqa: E402

hvd.init()
sok.init()


class TestSOKEmbedding:
sample_column_schema = ColumnSchema(
"item_id",
dtype=np.int32,
properties={"domain": {"min": 0, "max": 10, "name": "item_id"}},
tags=[Tags.CATEGORICAL],
)

def test_raises_with_invalid_schema(self):
column_schema = ColumnSchema("item_id")
with pytest.raises(ValueError) as exc_info:
mm.EmbeddingTable(16, column_schema)
assert "needs to have an int-domain" in str(exc_info.value)

@pytest.mark.parametrize("dim", [16, 32])
def test_sok_dynamic_variables(self, dim):

rows = [65536 * 10, 65536]
cols = [128, 4]
initial_vals = [13, 17]

# sok variables
sok_vars = [
sok.DynamicVariable(dimension=cols[i], initializer=str(initial_vals[i]))
for i in range(len(cols))
]
        # Rows are distributed round-robin across ranks: rank r owns rows
        # r, r + hvd.size(), r + 2 * hvd.size(), ...
        local_indices = []
        for row in rows:
            local_size = row // hvd.size()
            if hvd.rank() < row % hvd.size():
                local_size += 1
            indices = np.arange(local_size) * hvd.size() + hvd.rank()
            indices = tf.convert_to_tensor(indices, dtype=tf.int64)
            local_indices.append(indices)
out1 = []
for i in range(len(sok_vars)):
out1.append(tf.nn.embedding_lookup(sok_vars[i], local_indices[i]))

tf_vars = [
tf.Variable(tf.constant(initial_vals[i], shape=[rows[i], cols[i]], dtype=tf.float32))
for i in range(len(rows))
]
out2 = []
for i, v in enumerate(tf_vars):
out2.append(tf.nn.embedding_lookup(v, local_indices[i]))

# Check results
diff = 0
for i in range(len(out1)):
length = out1[i] ** 2 + out2[i] ** 2 + 1e-8
diff = diff + tf.reduce_sum((out1[i] - out2[i]) ** 2 / length)
print("[SOK INFO] diff:", diff)
assert diff < 1e-6

@pytest.mark.parametrize("dim", [16, 32])
def test_distributed_variables(self, dim):
rows = [65536 * 10, 65536]
cols = [128, 4]

# initial value of embedding table
weights = []
for i in range(len(rows)):
weight = np.random.rand(rows[i], cols[i]).astype(np.float32)
weight = tf.convert_to_tensor(weight, dtype=tf.float32)
# make sure the weight is same on each rank
weight = hvd.allreduce(weight)
weights.append(weight)

# sok variables
sok_vars = [sok.Variable(w) for w in weights]
        # Same round-robin distribution of rows across ranks as above.
        local_indices = []
for row in rows:
local_size = row // hvd.size()
if hvd.rank() < row % hvd.size():
local_size += 1
indices = np.arange(local_size) * hvd.size() + hvd.rank()
indices = tf.convert_to_tensor(indices, dtype=tf.int64)
local_indices.append(indices)

        # Each rank of a distributed sok.Variable holds exactly the rows it
        # owns, so its local values should match a lookup of local_indices
        # in the replicated tf.Variable.
        out1 = sok_vars
tf_vars = [tf.Variable(w) for w in weights]
out2 = []
for i, v in enumerate(tf_vars):
out2.append(tf.nn.embedding_lookup(v, local_indices[i]))

# Check results
diff = 0
for i in range(len(out1)):
length = out1[i] ** 2 + out2[i] ** 2 + 1e-8
diff = diff + tf.reduce_sum((out1[i] - out2[i]) ** 2 / length)
print("[SOK INFO] diff:", diff)
assert diff < 1e-6

    def test_sok_embedding_in_model(self, music_streaming_data):
        cat_schema = music_streaming_data.schema.select_by_tag(Tags.CATEGORICAL)
        input_block = mm.InputBlockV2(
            cat_schema,
            categorical=mm.Embeddings(cat_schema, dim=16, table_cls=SOKEmbedding),
        )
        model = mm.DCNModel(
            music_streaming_data.schema,
            depth=2,
            input_block=input_block,
            deep_block=mm.MLPBlock([64, 32]),
            prediction_tasks=mm.BinaryClassificationTask("click"),
        )
        testing_utils.model_test(model, music_streaming_data)
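
Note: these tests shard embedding rows across Horovod ranks, so they are normally launched with one process per GPU (for example via horovodrun); a single-process run still exercises the same code path, with hvd.size() == 1.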