Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Top-k recommender model #663

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions merlin/models/tf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@
from merlin.models.tf.models.retrieval import (
MatrixFactorizationModel,
TwoTowerModel,
TwoTowerModelV2,
YoutubeDNNRetrievalModel,
)
from merlin.models.tf.prediction_tasks.base import ParallelPredictionBlock, PredictionTask
Expand Down
5 changes: 4 additions & 1 deletion merlin/models/tf/core/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,18 @@
class PredictionContext(NamedTuple):
features: Dict[str, TensorLike]
targets: Optional[Union[tf.Tensor, Dict[str, tf.Tensor]]] = None
top_ids: Optional[tf.Tensor] = None
mask: tf.Tensor = (None,)
training: bool = False
testing: bool = False

def with_updates(
self, targets=None, features=None, mask=None, training=None, testing=None
self, targets=None, features=None, top_ids=None, mask=None, training=None, testing=None
) -> "PredictionContext":
return PredictionContext(
features if features is not None else self.features,
targets if targets is not None else self.targets,
top_ids if top_ids is not None else self.top_ids,
mask if mask is not None else self.mask,
training or self.training,
testing or self.testing,
Expand All @@ -57,6 +59,7 @@ class Prediction(NamedTuple):
targets: Optional[Union[tf.Tensor, Dict[str, tf.Tensor]]] = None
sample_weight: Optional[tf.Tensor] = None
features: Optional[Dict[str, TensorLike]] = None
top_ids: Optional[tf.Tensor] = None

@property
def predictions(self):
Expand Down
107 changes: 106 additions & 1 deletion merlin/models/tf/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import six
import tensorflow as tf
from keras.engine import data_adapter
from keras.utils.losses_utils import cast_losses_to_common_dtype
from packaging import version
from tensorflow.keras.utils import unpack_x_y_sample_weight
Expand All @@ -26,6 +27,7 @@
from merlin.models.tf.models.utils import parse_prediction_tasks
from merlin.models.tf.prediction_tasks.base import ParallelPredictionBlock, PredictionTask
from merlin.models.tf.predictions.base import ContrastivePredictionBlock, PredictionBlock
from merlin.models.tf.predictions.topk import TopKLayer, TopKPrediction
from merlin.models.tf.typing import TabularData
from merlin.models.tf.utils.search_utils import find_all_instances_in_layers
from merlin.models.tf.utils.tf_utils import (
Expand Down Expand Up @@ -318,6 +320,9 @@ def compile(
)
self.prediction_blocks[0].compile(negative_sampling=negative_sampling)

k = kwargs.pop("k", None)
if k:
self.prediction_blocks[0].compile(k=k)
# This flag will make Keras change the metric-names which is not needed in v2
from_serialized = kwargs.pop("from_serialized", num_v2_blocks > 0)

Expand Down Expand Up @@ -950,8 +955,9 @@ def _call_child(
if isinstance(outputs, Prediction):
targets = outputs.targets if outputs.targets is not None else context.targets
features = outputs.features if outputs.features is not None else context.features
top_ids = outputs.top_ids
outputs = outputs[0]
context = context.with_updates(targets=targets, features=features)
context = context.with_updates(targets=targets, features=features, top_ids=top_ids)

return outputs, context

Expand Down Expand Up @@ -1406,3 +1412,102 @@ def _maybe_convert_merlin_dataset(data, batch_size, shuffle=True, **kwargs):
kwargs.pop("shuffle", None)

return data


class ItemRecommenderModel(Model):
"""
top-k-based recommender model
"""

def __init__(
    self,
    *block: tf.keras.layers.Layer,
):
    """Create the recommender from a sequence of blocks.

    Parameters
    ----------
    *block : tf.keras.layers.Layer
        Layers making up the model; the last one must be a `TopKPrediction`.

    Raises
    ------
    ValueError
        If no block is given, or if the last block is not a `TopKPrediction`.
    """
    # Check for an empty argument list first so that `block[-1]` cannot raise
    # an opaque IndexError; both failure modes surface as the same ValueError.
    if not block or not isinstance(block[-1], TopKPrediction):
        raise ValueError("The last layer of the model must be a TopKPrediction.")
    self.has_item_corpus = False
    super().__init__(*block)

@classmethod
def from_item_encoder(
    cls,
    item_dataset: merlin.io.Dataset,
    item_encoder: tf.keras.layers.Layer,
    query_encoder: tf.keras.layers.Layer,
    topk_index: "TopKLayer" = None,
    id_column: str = None,
) -> "ItemRecommenderModel":
    """Define a top-k recommender from an item-encoder block.

    The item encoder is applied to every partition of `item_dataset`, the
    resulting embeddings are indexed by `id_column`, and they are wrapped in a
    `TopKPrediction` that is placed after `query_encoder`.

    Parameters
    ----------
    item_dataset : merlin.io.Dataset
        Dataset to export item embeddings from
    item_encoder : tf.keras.layers.Layer
        The `encoder` block that generates item embeddings
    query_encoder : tf.keras.layers.Layer
        The `encoder` block that generates user/query embeddings
    topk_index : TopKLayer
        The index layer for retrieving top-candidates, by default None
    id_column : str, optional
        Column name of item-ids, by default None. When omitted, the first
        column tagged `Tags.ITEM_ID` in the dataset schema is used.

    Returns
    -------
    ItemRecommenderModel
        Top-k recommender model

    Raises
    ------
    ValueError
        If `id_column` is not provided and cannot be inferred from the
        dataset schema.
    """
    # Resolve the id column up front so that a missing column fails with a
    # clear message instead of propagating `None` into the encoding step.
    if not id_column and getattr(item_dataset, "schema", None):
        tagged = item_dataset.schema.select_by_tag(Tags.ITEM_ID)
        if tagged.column_schemas:
            id_column = tagged.first.name
    if not id_column:
        raise ValueError(
            "`id_column` was not provided and no column tagged as ITEM_ID "
            "could be found in the schema of `item_dataset`."
        )

    import numpy as np

    from merlin.models.tf.utils.batch_utils import TFModelEncode

    # Convert item_encoder to TopKPredictionBlock
    model_encode = TFModelEncode(model=item_encoder, output_concat_func=np.concatenate)

    # Keep the caller's argument untouched; work on the dask dataframe view.
    item_ddf = item_dataset.to_ddf()
    embedding_ddf = item_ddf.map_partitions(model_encode, filter_input_columns=[id_column])
    item_embeddings = embedding_ddf.compute(scheduler="synchronous")
    item_embeddings.set_index(id_column, inplace=True)
    prediction = TopKPrediction(item_dataset=item_embeddings, prediction=topk_index)

    return cls(query_encoder, prediction)

def predict_step(self, data, output_context=True, k=None):
    """Run a single inference step and return predictions with top-k ids.

    This overrides the Keras hook called by `Model.make_predict_function`.
    How the step is executed (e.g. `tf.function` and `tf.distribute.Strategy`
    settings) is left to `Model.make_predict_function`.

    Args:
        data: A nested structure of `Tensor`s; only the features part is used.
        output_context: Forwarded to the model call.
        k: Accepted but not used by this method.
    Returns:
        A tuple of the model outputs and the `top_ids` stored on the
        prediction context returned by the model call.
    """
    # NOTE(review): the unpacking below assumes the model call returns an
    # (outputs, context) pair — confirm the behaviour when
    # `output_context=False`.
    features = data_adapter.unpack_x_y_sample_weight(data)[0]
    predictions, prediction_context = self(
        features, training=False, output_context=output_context
    )
    return predictions, prediction_context.top_ids

def batch_predict(
    self, dataset: merlin.io.Dataset, batch_size: int, top_k=1, **kwargs
) -> merlin.io.Dataset:
    """Batch prediction over a dataset (not implemented yet).

    Parameters
    ----------
    dataset : merlin.io.Dataset
        Input dataset (currently unused).
    batch_size : int
        Batch size (currently unused).
    top_k : int, optional
        Currently unused, by default 1.
    **kwargs
        Currently unused.

    Raises
    ------
    NotImplementedError
        Always; failing loudly avoids silently returning `None` where a
        dataset is expected.
    """
    raise NotImplementedError("batch_predict is not implemented for this model yet.")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should raise NotImplementedError()


def query_embeddings(
    self,
    dataset: merlin.io.Dataset,
    batch_size: int,
    query_tag: Union[str, Tags] = Tags.USER,
    query_id_tag: Union[str, Tags] = Tags.USER_ID,
) -> merlin.io.Dataset:
    """Export query embeddings for a dataset (not implemented yet).

    Parameters
    ----------
    dataset : merlin.io.Dataset
        Input dataset (currently unused).
    batch_size : int
        Batch size (currently unused).
    query_tag : Union[str, Tags], optional
        Currently unused, by default Tags.USER.
    query_id_tag : Union[str, Tags], optional
        Currently unused, by default Tags.USER_ID.

    Raises
    ------
    NotImplementedError
        Always; failing loudly avoids silently returning `None` where a
        dataset is expected.
    """
    raise NotImplementedError("query_embeddings is not implemented for this model yet.")

def item_embeddings(self) -> merlin.io.Dataset:
    """Export item embeddings (not implemented yet).

    Raises
    ------
    NotImplementedError
        Always; failing loudly avoids silently returning `None` where a
        dataset is expected.
    """
    raise NotImplementedError("item_embeddings is not implemented for this model yet.")
Loading