Skip to content

Commit

Permalink
Support to pre-trained embeddings initializer (trainable or not) (#572)
Browse files Browse the repository at this point in the history
* Added PretrainedEmbeddingsInitializer

* Fixing lint issue

* Added docstrings

* Keeping weight matrix on CPU, as it will be copied to the embedding table to any device in forward()

* Adding missing List import

* Fixing lint issues
  • Loading branch information
gabrielspmoreira authored Dec 13, 2022
1 parent 650f932 commit be4ff73
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 1 deletion.
28 changes: 28 additions & 0 deletions tests/torch/features/test_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,34 @@ def test_embedding_features_yoochoose_custom_initializers(yoochoose_schema, torc
)


@pytest.mark.parametrize("trainable", [True, False])
def test_pre_trained_embeddings_initializer(yoochoose_schema, torch_yoochoose_like, trainable):
item_id_cardinality = (
yoochoose_schema.select_by_name("item_id/list").feature[0].int_domain.max + 1
)
embedding_dim = 64
pre_trained_item_embeddings = np.random.rand(item_id_cardinality, embedding_dim)

schema = yoochoose_schema.select_by_tag(Tag.CATEGORICAL)
emb_module = tr.EmbeddingFeatures.from_schema(
schema,
embedding_dims={"item_id/list": embedding_dim},
embeddings_initializers={
"item_id/list": tr.PretrainedEmbeddingsInitializer(
pre_trained_item_embeddings, trainable=trainable
),
},
)

assert np.allclose(
emb_module.embedding_tables["item_id/list"].weight.detach().numpy(),
pre_trained_item_embeddings,
)

assert emb_module.embedding_tables["item_id/list"].weight.requires_grad == trainable
_ = emb_module(torch_yoochoose_like)


def test_soft_embedding_invalid_num_embeddings():
with pytest.raises(AssertionError) as excinfo:
tr.SoftEmbedding(num_embeddings=0, embeddings_dim=16)
Expand Down
2 changes: 2 additions & 0 deletions transformers4rec/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from .features.embedding import (
EmbeddingFeatures,
FeatureConfig,
PretrainedEmbeddingsInitializer,
SoftEmbedding,
SoftEmbeddingFeatures,
TableConfig,
Expand Down Expand Up @@ -103,6 +104,7 @@
"ContinuousFeatures",
"EmbeddingFeatures",
"SoftEmbeddingFeatures",
"PretrainedEmbeddingsInitializer",
"TabularSequenceFeatures",
"SequenceEmbeddingFeatures",
"FeatureConfig",
Expand Down
39 changes: 38 additions & 1 deletion transformers4rec/torch/features/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#

from functools import partial
from typing import Any, Callable, Dict, Optional, Text, Union
from typing import Any, Callable, Dict, List, Optional, Text, Union

import torch

Expand Down Expand Up @@ -497,3 +497,40 @@ def forward(self, input_numeric):
soft_one_hot_embeddings = (weights.unsqueeze(-1) * self.embedding_table.weight).sum(-2)

return soft_one_hot_embeddings


class PretrainedEmbeddingsInitializer(torch.nn.Module):
"""
Initializer of embedding tables with pre-trained weights
Parameters
----------
weight_matrix : Union[torch.Tensor, List[List[float]]]
A 2D torch or numpy tensor or lists of lists with the pre-trained
weights for embeddings. The expect dims are
(embedding_cardinality, embedding_dim). The embedding_cardinality
can be inferred from the column schema, for example,
`schema.select_by_name("item_id").feature[0].int_domain.max + 1`.
The first position of the embedding table is reserved for padded
items (id=0).
trainable : bool
Whether the embedding table should be trainable or not
"""

def __init__(
self,
weight_matrix: Union[torch.Tensor, List[List[float]]],
trainable: bool = False,
**kwargs,
):
super().__init__(**kwargs)
# The weight matrix is kept in CPU, but when forward() is called
# to initialize the embedding table weight will be copied to
# the embedding table device (e.g. cuda)
self.weight_matrix = torch.tensor(weight_matrix, device="cpu")
self.trainable = trainable

def forward(self, x):
with torch.no_grad():
x.copy_(self.weight_matrix)
x.requires_grad = self.trainable

0 comments on commit be4ff73

Please sign in to comment.