Skip to content

Commit

Permalink
Add support of transformer-based retrieval models (#1128)
Browse files Browse the repository at this point in the history
* add ragged support in topk block

* extend candidate embeddings extraction to CategoricalOutput

* make bias term optional in the weight-tying class

* add comment about top-k only works for the last item in the session

---------

Co-authored-by: rnyak <[email protected]>
Co-authored-by: edknv <[email protected]>
  • Loading branch information
3 people authored Jun 12, 2023
1 parent ed45657 commit 9980689
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 10 deletions.
3 changes: 2 additions & 1 deletion merlin/models/tf/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
from merlin.models.tf.metrics.topk import TopKMetricsAggregator, filter_topk_metrics, split_metrics
from merlin.models.tf.models.utils import parse_prediction_blocks
from merlin.models.tf.outputs.base import ModelOutput, ModelOutputType
from merlin.models.tf.outputs.classification import CategoricalOutput
from merlin.models.tf.outputs.contrastive import ContrastiveOutput
from merlin.models.tf.prediction_tasks.base import ParallelPredictionBlock, PredictionTask
from merlin.models.tf.transforms.features import PrepareFeatures, expected_input_cols_from_schema
Expand Down Expand Up @@ -2374,7 +2375,7 @@ def candidate_embeddings(

return candidate.encode(dataset, index=index, **kwargs)

if isinstance(self.last, ContrastiveOutput):
if isinstance(self.last, (ContrastiveOutput, CategoricalOutput)):
return self.last.to_dataset()

raise Exception(...)
Expand Down
19 changes: 12 additions & 7 deletions merlin/models/tf/outputs/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,8 @@ class EmbeddingTablePrediction(Layer):
The embedding table to use as the weight matrix
bias_initializer : str, optional
Initializer for the bias vector, by default "zeros"
use_bias: bool, optional
Whether to add a bias term to weight-tying, by default False
References:
----------
Expand All @@ -312,18 +314,20 @@ class EmbeddingTablePrediction(Layer):
arXiv:1611.01462 (2016).
"""

def __init__(self, table: EmbeddingTable, bias_initializer="zeros", **kwargs):
def __init__(self, table: EmbeddingTable, bias_initializer="zeros", use_bias=False, **kwargs):
self.table = table
self.num_classes = table.input_dim
self.bias_initializer = bias_initializer
self.use_bias = use_bias
super().__init__(**kwargs)

def build(self, input_shape):
self.bias = self.add_weight(
name="output_layer_bias",
shape=(self.num_classes,),
initializer=self.bias_initializer,
)
if self.use_bias:
self.bias = self.add_weight(
name="output_layer_bias",
shape=(self.num_classes,),
initializer=self.bias_initializer,
)
self.table.build(input_shape)
return super().build(input_shape)

Expand All @@ -333,7 +337,8 @@ def call(self, inputs, training=False, **kwargs) -> tf.Tensor:
original_inputs = inputs
inputs = inputs.flat_values
logits = tf.matmul(inputs, self.table.table.embeddings, transpose_b=True)
logits = tf.nn.bias_add(logits, self.bias)
if self.use_bias:
logits = tf.nn.bias_add(logits, self.bias)
if is_ragged:
logits = original_inputs.with_flat_values(logits)
return logits
Expand Down
10 changes: 10 additions & 0 deletions merlin/models/tf/outputs/topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,11 @@ def call(
"You should call the `index` method first to " "set the _candidates index."
)

if isinstance(inputs, tf.RaggedTensor):
# Evaluates on last session's item only
# (which is the default mode during inference too).
# TODO extend top-k generation to other items in the input session.
inputs = tf.squeeze(inputs.to_tensor(), axis=1)
tf.assert_equal(
tf.shape(inputs)[1],
tf.shape(self._candidates)[1],
Expand All @@ -220,6 +225,11 @@ def call(
assert targets is not None, ValueError(
"Targets should be provided during the evaluation mode"
)
if isinstance(targets, tf.RaggedTensor):
targets = tf.ragged.boolean_mask(
targets, targets._keras_mask.with_row_splits_dtype(targets.row_splits.dtype)
)
targets = targets.to_tensor()
targets = tf.cast(tf.squeeze(targets), tf.int32)
targets = tf.cast(tf.expand_dims(targets, -1) == top_ids, tf.float32)
targets = tf.reshape(targets, tf.shape(top_scores))
Expand Down
5 changes: 3 additions & 2 deletions tests/unit/tf/outputs/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ def test_categorical_output(sequence_testing_data: Dataset, run_eagerly):


@pytest.mark.parametrize("run_eagerly", [True, False])
def test_last_item_prediction(sequence_testing_data: Dataset, run_eagerly):
@pytest.mark.parametrize("use_bias", [True, False])
def test_last_item_prediction(sequence_testing_data: Dataset, run_eagerly, use_bias):
dataloader, schema = testing_utils.loader_for_last_item_prediction(sequence_testing_data)
embeddings = mm.Embeddings(
schema,
Expand All @@ -110,7 +111,7 @@ def test_last_item_prediction(sequence_testing_data: Dataset, run_eagerly):
schema["item_id_seq"],
CategoricalTarget(schema["item_id_seq"]),
embeddings["item_id_seq"],
EmbeddingTablePrediction(embeddings["item_id_seq"]),
EmbeddingTablePrediction(embeddings["item_id_seq"], use_bias=use_bias),
]

for target in predictions:
Expand Down
8 changes: 8 additions & 0 deletions tests/unit/tf/outputs/test_topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import numpy as np
import pytest
import tensorflow as tf

Expand All @@ -28,6 +29,10 @@ def test_brute_force_layer():
candidates = tf.random.uniform(shape=(num_candidates, 4), dtype=tf.float32)
query = tf.random.uniform(shape=(num_queries, 4), dtype=tf.float32)

# Create a ragged query
elements = np.random.rand(num_queries, 1, 4)
ragged_query = tf.ragged.constant(elements)

wrong_candidates_rank = tf.random.uniform(shape=(num_candidates,), dtype=tf.float32)
wrong_query_dim = tf.random.uniform(shape=(num_queries, 8), dtype=tf.float32)
wrong_identifiers_shape = tf.range(num_candidates + 1, dtype=tf.int32)
Expand Down Expand Up @@ -60,6 +65,9 @@ def test_brute_force_layer():
assert list(topk_output.scores.shape) == [num_queries, top_k]
assert list(topk_output.identifiers.shape) == [num_queries, top_k]
assert isinstance(topk_output, TopKPrediction)
assert list(topk_output.scores.shape) == [num_queries, top_k]
ragged_topk_output = brute_force(ragged_query)
assert list(ragged_topk_output.scores.shape) == [num_queries, top_k]

with pytest.raises(Exception) as excinfo:
brute_force(query, targets=None, testing=True)
Expand Down

0 comments on commit 9980689

Please sign in to comment.