diff --git a/tests/torch/features/test_embedding.py b/tests/torch/features/test_embedding.py
index b2ec93f63a..f7f87d9bae 100644
--- a/tests/torch/features/test_embedding.py
+++ b/tests/torch/features/test_embedding.py
@@ -151,6 +151,34 @@ def test_embedding_features_yoochoose_custom_initializers(yoochoose_schema, torc
     )
 
 
+@pytest.mark.parametrize("trainable", [True, False])
+def test_pre_trained_embeddings_initializer(yoochoose_schema, torch_yoochoose_like, trainable):
+    item_id_cardinality = (
+        yoochoose_schema.select_by_name("item_id/list").feature[0].int_domain.max + 1
+    )
+    embedding_dim = 64
+    pre_trained_item_embeddings = np.random.rand(item_id_cardinality, embedding_dim)
+
+    schema = yoochoose_schema.select_by_tag(Tag.CATEGORICAL)
+    emb_module = tr.EmbeddingFeatures.from_schema(
+        schema,
+        embedding_dims={"item_id/list": embedding_dim},
+        embeddings_initializers={
+            "item_id/list": tr.PretrainedEmbeddingsInitializer(
+                pre_trained_item_embeddings, trainable=trainable
+            ),
+        },
+    )
+
+    assert np.allclose(
+        emb_module.embedding_tables["item_id/list"].weight.detach().numpy(),
+        pre_trained_item_embeddings,
+    )
+
+    assert emb_module.embedding_tables["item_id/list"].weight.requires_grad == trainable
+    _ = emb_module(torch_yoochoose_like)
+
+
 def test_soft_embedding_invalid_num_embeddings():
     with pytest.raises(AssertionError) as excinfo:
         tr.SoftEmbedding(num_embeddings=0, embeddings_dim=16)
diff --git a/transformers4rec/torch/__init__.py b/transformers4rec/torch/__init__.py
index 0b066203b8..de222e3a3a 100644
--- a/transformers4rec/torch/__init__.py
+++ b/transformers4rec/torch/__init__.py
@@ -40,6 +40,7 @@ from .features.embedding import (
     EmbeddingFeatures,
     FeatureConfig,
+    PretrainedEmbeddingsInitializer,
     SoftEmbedding,
     SoftEmbeddingFeatures,
     TableConfig,
@@ -103,6 +104,7 @@
     "ContinuousFeatures",
     "EmbeddingFeatures",
     "SoftEmbeddingFeatures",
+    "PretrainedEmbeddingsInitializer",
     "TabularSequenceFeatures",
     "SequenceEmbeddingFeatures",
     "FeatureConfig",
diff --git a/transformers4rec/torch/features/embedding.py b/transformers4rec/torch/features/embedding.py
index 4bd1df55d7..f9bc83ff56 100644
--- a/transformers4rec/torch/features/embedding.py
+++ b/transformers4rec/torch/features/embedding.py
@@ -15,7 +15,7 @@
 #
 from functools import partial
-from typing import Any, Callable, Dict, Optional, Text, Union
+from typing import Any, Callable, Dict, List, Optional, Text, Union
 
 import torch
 
@@ -497,3 +497,40 @@ def forward(self, input_numeric):
         soft_one_hot_embeddings = (weights.unsqueeze(-1) * self.embedding_table.weight).sum(-2)
 
         return soft_one_hot_embeddings
+
+
+class PretrainedEmbeddingsInitializer(torch.nn.Module):
+    """
+    Initializer of embedding tables with pre-trained weights
+
+    Parameters
+    ----------
+    weight_matrix : Union[torch.Tensor, List[List[float]]]
+        A 2D torch or numpy tensor or list of lists with the pre-trained
+        weights for embeddings. The expected dims are
+        (embedding_cardinality, embedding_dim). The embedding_cardinality
+        can be inferred from the column schema, for example,
+        `schema.select_by_name("item_id").feature[0].int_domain.max + 1`.
+        The first position of the embedding table is reserved for padded
+        items (id=0).
+    trainable : bool
+        Whether the embedding table should be trainable or not
+    """
+
+    def __init__(
+        self,
+        weight_matrix: Union[torch.Tensor, List[List[float]]],
+        trainable: bool = False,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        # The weight matrix is kept on CPU; when forward() is called to
+        # initialize the embedding table, the weight is copied to the
+        # embedding table's device (e.g. cuda)
+        self.weight_matrix = torch.tensor(weight_matrix, device="cpu")
+        self.trainable = trainable
+
+    def forward(self, x):
+        with torch.no_grad():
+            x.copy_(self.weight_matrix)
+        x.requires_grad = self.trainable
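
Usage sketch (not part of the diff): based on the new test above, this is roughly how the initializer is wired into EmbeddingFeatures.from_schema. The `schema` object and the random weight matrix are placeholders for a real categorical schema (one containing an "item_id/list" column) and real pre-trained vectors.

import numpy as np
import transformers4rec.torch as tr

# `schema` is assumed to be a merlin-standard-lib Schema holding the categorical
# columns, e.g. the yoochoose schema filtered with select_by_tag(Tag.CATEGORICAL)
# as in the test above.
cardinality = schema.select_by_name("item_id/list").feature[0].int_domain.max + 1
embedding_dim = 64
pre_trained_weights = np.random.rand(cardinality, embedding_dim)  # stand-in for real vectors

emb_module = tr.EmbeddingFeatures.from_schema(
    schema,
    embedding_dims={"item_id/list": embedding_dim},
    embeddings_initializers={
        "item_id/list": tr.PretrainedEmbeddingsInitializer(
            pre_trained_weights, trainable=False
        ),
    },
)

With trainable=False the copied weights stay frozen (requires_grad remains False), matching the test's assertion; passing trainable=True lets the pre-trained table be fine-tuned during training.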