diff --git a/merlin/models/tf/__init__.py b/merlin/models/tf/__init__.py index e25abd1cd6..0cf5a75fb5 100644 --- a/merlin/models/tf/__init__.py +++ b/merlin/models/tf/__init__.py @@ -111,6 +111,7 @@ from merlin.models.tf.models.retrieval import ( MatrixFactorizationModel, TwoTowerModel, + TwoTowerModelV2, YoutubeDNNRetrievalModel, ) from merlin.models.tf.prediction_tasks.base import ParallelPredictionBlock, PredictionTask diff --git a/merlin/models/tf/core/prediction.py b/merlin/models/tf/core/prediction.py index 865e3b2df7..27b6d14c02 100644 --- a/merlin/models/tf/core/prediction.py +++ b/merlin/models/tf/core/prediction.py @@ -23,16 +23,18 @@ class PredictionContext(NamedTuple): features: Dict[str, TensorLike] targets: Optional[Union[tf.Tensor, Dict[str, tf.Tensor]]] = None + top_ids: Optional[tf.Tensor] = None mask: tf.Tensor = (None,) training: bool = False testing: bool = False def with_updates( - self, targets=None, features=None, mask=None, training=None, testing=None + self, targets=None, features=None, top_ids=None, mask=None, training=None, testing=None ) -> "PredictionContext": return PredictionContext( features if features is not None else self.features, targets if targets is not None else self.targets, + top_ids if top_ids is not None else self.top_ids, mask if mask is not None else self.mask, training or self.training, testing or self.testing, @@ -57,6 +59,7 @@ class Prediction(NamedTuple): targets: Optional[Union[tf.Tensor, Dict[str, tf.Tensor]]] = None sample_weight: Optional[tf.Tensor] = None features: Optional[Dict[str, TensorLike]] = None + top_ids: Optional[tf.Tensor] = None @property def predictions(self): diff --git a/merlin/models/tf/models/base.py b/merlin/models/tf/models/base.py index 8a27e0fd3a..5d27d4175d 100644 --- a/merlin/models/tf/models/base.py +++ b/merlin/models/tf/models/base.py @@ -9,6 +9,7 @@ import six import tensorflow as tf +from keras.engine import data_adapter from keras.utils.losses_utils import cast_losses_to_common_dtype from packaging import version from tensorflow.keras.utils import unpack_x_y_sample_weight @@ -26,6 +27,7 @@ from merlin.models.tf.models.utils import parse_prediction_tasks from merlin.models.tf.prediction_tasks.base import ParallelPredictionBlock, PredictionTask from merlin.models.tf.predictions.base import ContrastivePredictionBlock, PredictionBlock +from merlin.models.tf.predictions.topk import TopKLayer, TopKPrediction from merlin.models.tf.typing import TabularData from merlin.models.tf.utils.search_utils import find_all_instances_in_layers from merlin.models.tf.utils.tf_utils import ( @@ -318,6 +320,9 @@ def compile( ) self.prediction_blocks[0].compile(negative_sampling=negative_sampling) + k = kwargs.pop("k", None) + if k: + self.prediction_blocks[0].compile(k=k) # This flag will make Keras change the metric-names which is not needed in v2 from_serialized = kwargs.pop("from_serialized", num_v2_blocks > 0) @@ -950,8 +955,9 @@ def _call_child( if isinstance(outputs, Prediction): targets = outputs.targets if outputs.targets is not None else context.targets features = outputs.features if outputs.features is not None else context.features + top_ids = outputs.top_ids outputs = outputs[0] - context = context.with_updates(targets=targets, features=features) + context = context.with_updates(targets=targets, features=features, top_ids=top_ids) return outputs, context @@ -1406,3 +1412,102 @@ def _maybe_convert_merlin_dataset(data, batch_size, shuffle=True, **kwargs): kwargs.pop("shuffle", None) return data + + +class ItemRecommenderModel(Model): + """ + top-k-based recommender model + """ + + def __init__( + self, + *block: tf.keras.layers.Layer, + ): + if not isinstance(block[-1], TopKPrediction): + raise ValueError("The last layer of the model must be a TopKPrediction.") + self.has_item_corpus = False + super().__init__(*block) + + @classmethod + def from_item_encoder( + cls, + item_dataset: merlin.io.Dataset, + item_encoder: tf.keras.layers.Layer, + query_encoder: tf.keras.layers.Layer, + topk_index: "TopKLayer" = None, + id_column: str = None, + ) -> "ItemRecommenderModel": + """Define top-k recommender from an item-encoder block + + Parameters + ---------- + item_dataset : merlin.io.Dataset + Dataset to export item embeddings from + item_encoder : tf.keras.layers.Layer + The `encoder` block that generates item embeddings + query_encoder : tf.keras.layers.Layer + The `encoder` block that generates user/query embeddings + topk_index : TopKLayer + The index layer for retrieving top-candidates, by default None + id_column : str, optional + column name of item-ids, by default None + + Returns + ------- + ItemRecommenderModel + Top-k recommender model + """ + import numpy as np + + from merlin.models.tf.utils.batch_utils import TFModelEncode + + # Convert item_encoder to TopKPredictionBlock + if not id_column and getattr(item_dataset, "schema", None): + tagged = item_dataset.schema.select_by_tag(Tags.ITEM_ID) + if tagged.column_schemas: + id_column = tagged.first.name + model_encode = TFModelEncode(model=item_encoder, output_concat_func=np.concatenate) + + item_dataset = item_dataset.to_ddf() + embedding_ddf = item_dataset.map_partitions(model_encode, filter_input_columns=[id_column]) + item_embeddings = embedding_ddf.compute(scheduler="synchronous") + item_embeddings.set_index(id_column, inplace=True) + prediction = TopKPrediction(item_dataset=item_embeddings, prediction=topk_index) + + return cls(query_encoder, prediction) + + def predict_step(self, data, output_context=True, k=None): + """The logic for one inference step. + This method can be overridden to support custom inference logic. + This method is called by `Model.make_predict_function`. + This method should contain the mathematical logic for one step of inference. + This typically includes the forward pass. + Configuration details for *how* this logic is run (e.g. `tf.function` and + `tf.distribute.Strategy` settings), should be left to + `Model.make_predict_function`, which can also be overridden. + Args: + data: A nested structure of `Tensor`s. + Returns: + The result of one inference step, typically the output of calling the + `Model` on data. + """ + x, _, _ = data_adapter.unpack_x_y_sample_weight(data) + outputs, context = self(x, training=False, output_context=output_context) + return outputs, context.top_ids + + def batch_predict( + self, dataset: merlin.io.Dataset, batch_size: int, top_k=1, **kwargs + ) -> merlin.io.Dataset: + pass + + def query_embeddings( + self, + dataset: merlin.io.Dataset, + batch_size: int, + query_tag: Union[str, Tags] = Tags.USER, + query_id_tag: Union[str, Tags] = Tags.USER_ID, + ) -> merlin.io.Dataset: + pass + + def item_embeddings(self) -> merlin.io.Dataset: + pass diff --git a/merlin/models/tf/models/retrieval.py b/merlin/models/tf/models/retrieval.py index bee1625ef0..b77b63ae50 100644 --- a/merlin/models/tf/models/retrieval.py +++ b/merlin/models/tf/models/retrieval.py @@ -1,17 +1,22 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Union +import merlin.io from merlin.models.tf.blocks.mlp import MLPBlock +from merlin.models.tf.blocks.retrieval.base import TowerBlock from merlin.models.tf.blocks.retrieval.matrix_factorization import QueryItemIdsEmbeddingsBlock from merlin.models.tf.blocks.retrieval.two_tower import TwoTowerBlock from merlin.models.tf.blocks.sampling.base import ItemSampler from merlin.models.tf.core.base import Block, BlockType from merlin.models.tf.inputs.base import InputBlock from merlin.models.tf.inputs.embedding import EmbeddingOptions -from merlin.models.tf.models.base import Model, RetrievalModel +from merlin.models.tf.models.base import ItemRecommenderModel, Model, RetrievalModel from merlin.models.tf.models.utils import parse_prediction_tasks from merlin.models.tf.prediction_tasks.base import ParallelPredictionBlock, PredictionTask from merlin.models.tf.prediction_tasks.next_item import NextItemPredictionTask from merlin.models.tf.prediction_tasks.retrieval import ItemRetrievalTask +from merlin.models.tf.predictions.dot_product import DotProductCategoricalPrediction +from merlin.models.tf.predictions.topk import TopKLayer, TopKPrediction +from merlin.models.utils.dataset import unique_rows_by_features from merlin.schema import Schema, Tags @@ -241,7 +246,6 @@ def YoutubeDNNRetrievalModel( .. [4] Jean, Sébastien, et al. "On using very large target vocabulary for neural machine translation." arXiv preprint arXiv:1412.2007 (2014). - Parameters ---------- schema: Schema @@ -300,3 +304,223 @@ def YoutubeDNNRetrievalModel( # TODO: Figure out how to make this fit as # a RetrievalModel (which must have a RetrievalBlock) return Model(inputs, top_block, task) + + +class TwoTowerModelV2(RetrievalModel): + """Builds the Two-tower architecture, as proposed in [1]. + + Example Usage:: + two_tower = TwoTowerModel(schema, MLPBlock([256, 64])) + two_tower.compile(optimizer="adam") + two_tower.fit(train_data, epochs=10) + + References + ---------- + [1] Yi, Xinyang, et al. + "Sampling-bias-corrected neural modeling for large corpus item recommendations." + Proceedings of the 13th ACM Conference on Recommender Systems. 2019. + + Parameters + ---------- + schema: Schema + The `Schema` with the input features + query_tower: Block + The `Block` that combines user features + item_tower: Optional[Block], optional + The optional `Block` that combines items features. + If not provided, a copy of the query_tower is used. + query_tower_tag: Tag + The tag to select query features, by default `Tags.USER` + item_tower_tag: Tag + The tag to select item features, by default `Tags.ITEM` + embedding_options : EmbeddingOptions + Options for the input embeddings. + - embedding_dims: Optional[Dict[str, int]] - The dimension of the + embedding table for each feature (key), by default {} + - embedding_dim_default: int - Default dimension of the embedding + table, when the feature is not found in ``embedding_dims``, by default 64 + - infer_embedding_sizes : bool, Automatically defines the embedding + dimension from the feature cardinality in the schema, by default False + - infer_embedding_sizes_multiplier: int. Multiplier used by the heuristic + to infer the embedding dimension from its cardinality. Generally + reasonable values range between 2.0 and 10.0. By default 2.0. + post: Optional[Block], optional + The optional `Block` to apply on both outputs of Two-tower model + prediction_tasks: optional + The optional `PredictionTask` or list of `PredictionTask` to apply on the model. + logits_temperature: float + Parameter used to reduce model overconfidence, so that logits / T. + Defaults to 1. + loss: Optional[LossType] + Loss function. + Defaults to `categorical_crossentropy`. + samplers: List[ItemSampler] + List of samplers for negative sampling, by default `[InBatchSampler()]` + """ + + def __init__( + self, + schema: Schema, + query_tower: Block, + item_tower: Optional[Block] = None, + query_tower_tag=Tags.USER, + item_tower_tag=Tags.ITEM, + embedding_options: EmbeddingOptions = EmbeddingOptions( + embedding_dims=None, + embedding_dim_default=64, + infer_embedding_sizes=False, + infer_embedding_sizes_multiplier=2.0, + ), + post: Optional[BlockType] = None, + prediction_tasks: Optional[ + Union[PredictionTask, List[PredictionTask], ParallelPredictionBlock] + ] = None, + logits_temperature: float = 1.0, + samplers: Sequence[ItemSampler] = ["in-batch"], + **kwargs, + ): + if not prediction_tasks: + prediction_tasks = DotProductCategoricalPrediction( + schema, + **kwargs, + ) + + prediction_tasks = parse_prediction_tasks(schema, prediction_tasks) + two_tower = TwoTowerBlock( + schema=schema, + query_tower=query_tower, + item_tower=item_tower, + query_tower_tag=query_tower_tag, + item_tower_tag=item_tower_tag, + embedding_options=embedding_options, + post=post, + **kwargs, + ) + + super().__init__(two_tower, prediction_tasks, **kwargs) + + def query_block(self) -> TowerBlock: + return self.first._query_block + + def item_block(self) -> TowerBlock: + return self.first._item_block + + def query_embeddings( + self, + dataset: merlin.io.Dataset, + batch_size: int, + query_tag: Union[str, Tags] = Tags.USER, + query_id_tag: Union[str, Tags] = Tags.USER_ID, + ) -> merlin.io.Dataset: + """Export query embeddings from the model. + Parameters + ---------- + dataset : merlin.io.Dataset + Dataset to export embeddings from. + batch_size : int + Batch size to use for embedding extraction. + query_tag: Union[str, Tags], optional + Tag to use for the query. + query_id_tag: Union[str, Tags], optional + Tag to use for the query id. + Returns + ------- + merlin.io.Dataset + Dataset with the user/query features and the embeddings + (one dim per column in the data frame) + """ + from merlin.models.tf.utils.batch_utils import QueryEmbeddings + + get_user_emb = QueryEmbeddings(self, batch_size=batch_size) + + dataset = unique_rows_by_features(dataset, query_tag, query_id_tag).to_ddf() + embeddings = dataset.map_partitions(get_user_emb) + + return merlin.io.Dataset(embeddings) + + def item_embeddings( + self, + dataset: merlin.io.Dataset, + batch_size: int, + item_tag: Union[str, Tags] = Tags.ITEM, + item_id_tag: Union[str, Tags] = Tags.ITEM_ID, + filter_input_columns: bool = False, + ) -> merlin.io.Dataset: + """Export item embeddings from the model. + Parameters + ---------- + dataset : merlin.io.Dataset + Dataset to export embeddings from. + batch_size : int + Batch size to use for embedding extraction. + item_tag : Union[str, Tags], optional + Tag to use for the item. + item_id_tag : Union[str, Tags], optional + Tag to use for the item id, by default Tags.ITEM_ID + filter_input_columns: bool + Returns + ------- + merlin.io.Dataset + Dataset with the item features and the embeddings + (one dim per column in the data frame) + """ + from merlin.models.tf.utils.batch_utils import ItemEmbeddings + + get_item_emb = ItemEmbeddings(self, batch_size=batch_size) + + dataset = unique_rows_by_features(dataset, item_tag, item_id_tag).to_ddf() + if filter_input_columns: + id_column = self.schema.select_by_tag(item_id_tag).first.name + embeddings = dataset.map_partitions(get_item_emb, filter_input_columns=[id_column]) + else: + embeddings = dataset.map_partitions(get_item_emb) + + return merlin.io.Dataset(embeddings) + + def to_item_recommender( + self, + dataset: merlin.io.Dataset, + batch_size: int, + k: int = 10, + prediction: TopKLayer = None, + item_tag: Union[str, Tags] = Tags.ITEM, + item_id_tag: Union[str, Tags] = Tags.ITEM_ID, + ) -> ItemRecommenderModel: + """Convert the retrieval model to a top-k recommender model + for evaluation and inference + + Parameters + ---------- + dataset : merlin.io.Dataset + Dataset to export item embeddings from + batch_size : int + Batch size to use for embedding extraction + k : int, optional + Number of top candidates to retrieve, by default 10 + prediction : TopKLayer, optional + The index layer for retrieving top-candidates, by default None + item_tag : Union[str, Tags], optional + Tag to use for the item, by default Tags.ITEM + item_id_tag : Union[str, Tags], optional + Tag to use for the item-id column, by default Tags.ITEM_ID + + Returns + ------- + ItemRecommenderModel + Top-k recommender model + """ + item_embeddings = ( + self.item_embeddings( + dataset, + batch_size=batch_size, + item_tag=item_tag, + item_id_tag=item_id_tag, + filter_input_columns=True, + ) + .to_ddf() + .compute() + ) + id_column = self.schema.select_by_tag(item_id_tag).first.name + item_embeddings.set_index(id_column, inplace=True) + prediction = TopKPrediction(item_dataset=item_embeddings, prediction=prediction, k=k) + return ItemRecommenderModel(self.query_block(), prediction) diff --git a/merlin/models/tf/predictions/topk.py b/merlin/models/tf/predictions/topk.py new file mode 100644 index 0000000000..057af34a2b --- /dev/null +++ b/merlin/models/tf/predictions/topk.py @@ -0,0 +1,253 @@ +from typing import Optional, Union + +import numpy as np +import tensorflow as tf +from tensorflow.keras.layers import Layer + +import merlin.io +from merlin.core.dispatch import DataFrameType +from merlin.models.tf.core.base import Block +from merlin.models.tf.predictions.base import MetricsFn, Prediction, PredictionBlock +from merlin.models.tf.predictions.classification import default_categorical_prediction_metrics +from merlin.models.tf.utils import tf_utils +from merlin.schema import Tags + + +class TopKPrediction(PredictionBlock): + """Prediction block for top-k evaluation + + Parameters + ---------- + item_dataset: merlin.io.Dataset, + Dataset of the pretrained item embeddings to use for the top-k index. + prediction: TopKLayer, + The layer for indexing the pre-trained candidates and retrieving top-k candidates. + By default None + target: Union[str, Schema], optional + The name of the target. If a Schema is provided, the target is inferred from the schema. + pre: Optional[Block], optional + Optional block to transform predictions before computing the binary logits, + by default None + post: Optional[Block], optional + Optional block to transform the binary logits, + by default None + name: str, optional + The name of the task. + task_block: Block, optional + The block to use for the task. + logits_temperature: float, optional + Parameter used to reduce model overconfidence, so that logits / T. + by default 1. + default_loss: Union[str, tf.keras.losses.Loss], optional + Default loss to use for binary-classification + by 'binary_crossentropy' + default_metrics_fn: Callable + A function returning the list of default metrics + to use for binary-classification + """ + + def __init__( + self, + item_dataset: merlin.io.Dataset, + prediction: "TopKLayer" = None, + target: Optional[str] = None, + pre: Optional[Layer] = None, + post: Optional[Layer] = None, + logits_temperature: float = 1.0, + name: Optional[str] = None, + k: int = 10, + default_loss: Union[str, tf.keras.losses.Loss] = "categorical_crossentropy", + default_metrics_fn: MetricsFn = default_categorical_prediction_metrics, + **kwargs, + ): + if prediction is None: + prediction = BruteForce(k=k) + + prediction = prediction.index_from_dataset(item_dataset) + super().__init__( + prediction=prediction, + default_loss=default_loss, + default_metrics_fn=default_metrics_fn, + name=name, + target=target, + pre=pre, + post=post, + logits_temperature=logits_temperature, + **kwargs, + ) + + def compile(self, k=None): + self.prediction._k = k + + +@tf.keras.utils.register_keras_serializable(package="merlin_models") +class TopKLayer(Layer): + def __init__(self, k: int, **kwargs) -> None: + """Initializes the base class.""" + super().__init__(**kwargs) + self._k = k + + def index(self, candidates: tf.Tensor, identifiers: Optional[tf.Tensor] = None) -> "TopKLayer": + """Builds the retrieval index. + When called multiple times the existing index will be dropped and a new one + created. + + Parameters: + ----------- + candidates: tensor of candidate embeddings. + identifiers: Optional tensor of candidate identifiers. If + given, these will be used as identifiers of top candidates returned + when performing searches. If not given, indices into the candidates + tensor will be returned instead. + Returns: + Self + """ + raise NotImplementedError() + + def index_from_dataset( + self, data: merlin.io.Dataset, check_unique_ids: bool = True, **kwargs + ) -> "TopKLayer": + """Builds the retrieval index from a merlin dataset. + + Parameters + ---------- + data : merlin.io.Dataset + The dataset with the pre-trained item embeddings + check_unique_ids : bool, optional + Whether to check if `data` has unique indices, by default True + + Returns + ------- + TopKLayer + return the class with retrieval index + """ + ids, values = self.extract_ids_embeddings(data, check_unique_ids) + return self.index(candidates=values, identifiers=ids, **kwargs) + + @staticmethod + def _check_unique_ids(data: DataFrameType): + if data.index.to_series().nunique() != data.shape[0]: + raise ValueError("Please make sure that `data` contains unique indices") + + def extract_ids_embeddings(self, data: merlin.io.Dataset, check_unique_ids: bool = True): + """Extract tensors of candidates and indices from the merlin dataset + + Parameters + ---------- + data : merlin.io.Dataset + The dataset with the pre-trained item embeddings + check_unique_ids : bool, optional + Whether to check if `data` has unique indices, by default True + """ + if hasattr(data, "to_ddf"): + data = data.to_ddf() + if check_unique_ids: + self._check_unique_ids(data=data) + values = tf_utils.df_to_tensor(data) + ids = tf_utils.df_to_tensor(data.index) + + if len(ids.shape) == 2: + ids = tf.squeeze(ids) + return ids, values + + def get_candidates_dataset( + self, block: Block, data: merlin.io.Dataset, id_column: Optional[str] = None + ): + from merlin.models.tf.utils.batch_utils import TFModelEncode + + if not id_column and getattr(block, "schema", None): + tagged = block.schema.select_by_tag(Tags.ITEM_ID) + if tagged.column_schemas: + id_column = tagged.first.name + + model_encode = TFModelEncode(model=block, output_concat_func=np.concatenate) + + data = data.to_ddf() + embedding_ddf = data.map_partitions(model_encode, filter_input_columns=[id_column]) + embedding_df = embedding_ddf.compute(scheduler="synchronous") + + embedding_df.set_index(id_column, inplace=True) + return embedding_df + + def from_block( + self, block: Block, data: merlin.io.Dataset, id_column: Optional[str] = None, **kwargs + ): + """Build candidates embeddings from applying `block` to a dataset of features `data`. + + Parameters: + ----------- + block: Block + The Block that returns embeddings from raw item features. + data: merlin.io.Dataset + Dataset containing raw item features. + id_column: Optional[str] + The candidates ids column name. + Note, this will be inferred automatically if the block contains + a schema with an item-id Tag. + """ + candidates_dataset = self.get_candidates_dataset(block, data, id_column) + return self.index_from_dataset(candidates_dataset, **kwargs) + + def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor: + raise NotImplementedError() + + def _score(self, queries: tf.Tensor, candidates: tf.Tensor) -> tf.Tensor: + """Computes the standard dot product score from queries and candidates.""" + return tf.matmul(queries, candidates, transpose_b=True) + + +@tf.keras.utils.register_keras_serializable(package="merlin_models") +class BruteForce(TopKLayer): + """Brute force retrieval top-k layer.""" + + def __init__(self, k: int = 10, name: Optional[str] = None): + """Initializes the layer. + Args: + query_model: Optional Keras model for representing queries. If provided, + will be used to transform raw features into query embeddings when + querying the layer. If not provided, the layer will expect to be given + query embeddings as inputs. + k: Default k. + name: Name of the layer. + """ + + super().__init__(k=k, name=name) + + self._candidates = None + + def index(self, candidates: tf.Tensor, identifiers: Optional[tf.Tensor]) -> "BruteForce": + + self._ids = self.add_weight( + name="ids", + dtype=tf.int32, + shape=identifiers.shape, + initializer=tf.keras.initializers.Constant(value=tf.cast(identifiers, tf.int32)), + trainable=False, + ) + + self._candidates = self.add_weight( + name="candidates", + dtype=tf.float32, + shape=candidates.shape, + initializer=tf.keras.initializers.Zeros(), + trainable=False, + ) + + self._ids.assign(tf.cast(identifiers, tf.int32)) + self._candidates.assign(tf.cast(candidates, tf.float32)) + return self + + def call(self, inputs, targets=None, k=None, *args, **kwargs) -> "Prediction": + if not k: + k = self._k + scores = self._score(inputs, self._candidates) + top_scores, top_ids = tf.math.top_k(scores, k=k) + if targets is not None: + targets = tf.cast(tf.squeeze(targets), tf.int32) + targets = tf.cast(tf.expand_dims(targets, -1) == top_ids, tf.float32) + targets = tf.reshape(targets, tf.shape(top_scores)) + return Prediction(top_scores, targets, top_ids=top_ids) + + def compute_output_shape(self, input_shape): + batch_size = input_shape[0] + return tf.TensorShape((batch_size, self._k)), tf.TensorShape((batch_size, self._k)) diff --git a/tests/unit/tf/models/test_retrieval.py b/tests/unit/tf/models/test_retrieval.py index c57ce1e8eb..5d027e35d2 100644 --- a/tests/unit/tf/models/test_retrieval.py +++ b/tests/unit/tf/models/test_retrieval.py @@ -368,3 +368,37 @@ def last_interaction_as_target(inputs, targets): losses = model.fit(dataloader, epochs=1) assert losses is not None + + +@pytest.mark.parametrize("run_eagerly", [True, False]) +def test_to_recommender_v2(ecommerce_data: Dataset, run_eagerly): + def _item_id_as_target(inputs, targets): + items = inputs["item_id"] + return inputs, items + + dataloader = BatchedDataset(ecommerce_data, batch_size=50) + dataloader = dataloader.map(_item_id_as_target) + + model = mm.TwoTowerModelV2(ecommerce_data.schema, query_tower=mm.MLPBlock([64, 128])) + model.compile(run_eagerly=run_eagerly, optimizer="adam") + _ = model.fit(ecommerce_data, batch_size=50, epochs=5, steps_per_epoch=1) + + recommender = model.to_item_recommender(dataset=ecommerce_data, batch_size=50, k=10) + recommender.compile(run_eagerly=run_eagerly, optimizer="adam") + history = recommender.evaluate(dataloader, return_dict=True) + assert set(history.keys()) == { + "loss", + "recall_at_10", + "mrr_at_10", + "ndcg_at_10", + "map_at_10", + "precision_at_10", + "regularization_loss", + } + + scores, top_ids = recommender.predict(dataloader, output_context=True) + assert scores.shape[-1] == 10 + + recommender.compile(run_eagerly=run_eagerly, optimizer="adam", k=20) + scores, top_ids = recommender.predict(dataloader, output_context=True) + assert scores.shape[-1] == 20