
Commit

Refactor NextItemPredictionTask to fix serialization and graph-mode (#309)

* refactor NextItemPredictionTask for serialization and graph-mode

* fix failing test_transformation test
sararb authored Nov 3, 2021
1 parent 72a1953 commit f87e4b8
Showing 2 changed files with 73 additions and 11 deletions.
37 changes: 37 additions & 0 deletions tests/tf/model/test_head.py
@@ -44,6 +44,43 @@ def test_serialization_simple_heads(tf_tabular_features, tf_tabular_data, predic
test_utils.assert_loss_and_metrics_are_valid(copy_head, tf_tabular_data, targets)


def test_serialization_item_prediction_head(tf_yoochoose_like, yoochoose_schema):
input_module = tr.TabularSequenceFeatures.from_schema(
yoochoose_schema,
max_sequence_length=20,
continuous_projection=64,
d_output=64,
)
body = tr.SequentialBlock([input_module, tr.MLPBlock([64])])
task = tr.NextItemPredictionTask(weight_tying=True, metrics=[])

head = task.to_head(body, input_module)

copy_head = test_utils.assert_serialization(head)
targets = tf_yoochoose_like["item_id/list"]

loss = copy_head.compute_loss(tf_yoochoose_like, targets, call_body=True)
assert loss is not None


@test_utils.mark_run_eagerly_modes
def test_item_prediction_yoochoose_model(yoochoose_schema, tf_yoochoose_like, run_eagerly):
input_module = tr.TabularSequenceFeatures.from_schema(
yoochoose_schema,
max_sequence_length=20,
continuous_projection=64,
d_output=64,
)
body = tr.SequentialBlock([input_module, tr.MLPBlock([64])])
task = tr.NextItemPredictionTask(weight_tying=True, metrics=[])

model = task.to_model(body, input_module)
model.compile(optimizer="adam", run_eagerly=run_eagerly)

outputs = model(tf_yoochoose_like)
assert outputs.shape[-1] == 51997


@pytest.mark.parametrize("task", [tr.BinaryClassificationTask, tr.RegressionTask])
@pytest.mark.parametrize("task_block", [None, tr.MLPBlock([32])])
@pytest.mark.parametrize("summary", ["last", "first", "mean", "cls_index"])
47 changes: 36 additions & 11 deletions transformers4rec/tf/model/prediction_task.py
@@ -6,6 +6,7 @@

from ..block.mlp import MLPBlock
from ..ranking_metric import AvgPrecisionAt, NDCGAt, RecallAt
from ..utils.tf_utils import maybe_deserialize_keras_objects, maybe_serialize_keras_objects
from .base import PredictionTask


@@ -107,7 +108,9 @@ class NextItemPredictionTask(PredictionTask):
Value 1.0 reduces to regular softmax.
"""

DEFAULT_LOSS = tf.keras.losses.SparseCategoricalCrossentropy()
DEFAULT_LOSS = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True,
)
DEFAULT_METRICS = (
# default metrics suppose labels are int encoded
NDCGAt(top_ks=[10, 20], labels_onehot=True),
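
The new default loss sets from_logits=True, i.e. the task is now expected to emit unnormalized scores rather than probabilities. A quick sanity check of that flag with plain Keras (a toy 3-class example, not library code):

import tensorflow as tf

logits = tf.constant([[2.0, 0.5, -1.0]])
targets = tf.constant([0])

# from_logits=True applies the softmax internally...
loss_from_logits = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)(targets, logits)
# ...so it matches the loss computed on explicitly normalized probabilities.
loss_from_probs = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)(targets, tf.nn.softmax(logits))
tf.debugging.assert_near(loss_from_logits, loss_from_probs)
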
@@ -122,9 +125,10 @@ def __init__(
target_name: Optional[str] = None,
task_name: Optional[str] = None,
task_block: Optional[Layer] = None,
weight_tying: bool = False,
weight_tying: bool = True,
target_dim: int = None,
softmax_temperature: float = 1,
padding_idx: int = 0,
**kwargs,
):
super().__init__(
Expand All @@ -138,6 +142,7 @@ def __init__(
self.weight_tying = weight_tying
self.target_dim = target_dim
self.softmax_temperature = softmax_temperature
self.padding_idx = padding_idx

def build(self, input_shape, body, inputs=None):
# Retrieve the embedding module to get the name of itemid col and its related table
@@ -196,7 +201,7 @@ def call(self, inputs, **kwargs):
labels = self.embeddings.item_seq

# remove vectors of padded items
trg_flat = tf.reshape(labels, -1)
trg_flat = tf.reshape(labels, (-1,))
non_pad_mask = trg_flat != self.padding_idx
x = self.remove_pad_3d(x, non_pad_mask)

@@ -208,7 +213,7 @@ def remove_pad_3d(self, inp_tensor, non_pad_mask):
# inp_tensor: (n_batch x seqlen x emb_dim)
inp_tensor = tf.reshape(inp_tensor, (-1, inp_tensor.shape[-1]))
inp_tensor_fl = tf.boolean_mask(
inp_tensor, tf.broadcast_to(tf.expand_dims(non_pad_mask, 1), inp_tensor.shape)
inp_tensor, tf.broadcast_to(tf.expand_dims(non_pad_mask, 1), tf.shape(inp_tensor))
)
out_tensor = tf.reshape(inp_tensor_fl, (-1, inp_tensor.shape[1]))
return out_tensor
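
The switch from inp_tensor.shape to tf.shape(inp_tensor) is the graph-mode part of the fix: under tf.function the static shape usually contains None for the batch dimension, whereas tf.shape() returns the runtime shape as a tensor that tf.broadcast_to can consume. A standalone sketch of the same pattern (generic code with made-up shapes, not the library method):

import tensorflow as tf

@tf.function
def remove_padded_rows(x, non_pad_mask):
    # x: (batch * seqlen, emb_dim); non_pad_mask: (batch * seqlen,) of bools.
    # tf.shape(x) works even when x.shape has unknown dimensions in graph mode.
    mask_2d = tf.broadcast_to(tf.expand_dims(non_pad_mask, 1), tf.shape(x))
    kept = tf.boolean_mask(x, mask_2d)
    return tf.reshape(kept, (-1, x.shape[-1]))  # emb_dim is statically known

x = tf.random.normal((6, 4))
mask = tf.constant([True, False, True, True, False, True])
print(remove_padded_rows(x, mask).shape)  # (4, 4)
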
@@ -230,10 +235,14 @@ def compute_loss(  # type: ignore
# retrieve labels from masking
if self.masking:
targets = self.masking.masked_targets
# flatten labels and remove padding index
targets = tf.reshape(targets, -1)
non_pad_mask = targets != self.padding_idx
targets = tf.boolean_mask(targets, non_pad_mask)

else:
targets = self.embeddings.item_seq

# flatten labels and remove padding index
targets = tf.reshape(targets, (-1,))
non_pad_mask = targets != self.padding_idx
targets = tf.boolean_mask(targets, non_pad_mask)

loss = self.loss(y_true=targets, y_pred=predictions, sample_weight=sample_weight)
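
compute_loss now flattens the targets and drops the padding index in both the masking and non-masking branches before computing the loss. With the default padding_idx of 0, the effect is simply (toy values, not library data):

import tensorflow as tf

targets = tf.constant([[3, 7, 0, 0],
                       [5, 0, 0, 0]])       # 0 marks padded positions
flat = tf.reshape(targets, (-1,))           # shape (8,)
non_pad_mask = flat != 0
print(tf.boolean_mask(flat, non_pad_mask))  # tf.Tensor([3 7 5], ...)
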

@@ -294,6 +303,7 @@ def metric_results(self, mode: str = None) -> Dict[str, tf.Tensor]:
return results


@tf.keras.utils.register_keras_serializable(package="transformers4rec")
class _NextItemPredictionTask(tf.keras.layers.Layer):
"""Predict the interacted item-id probabilities.
- During inference, the task consists of predicting the next item.
@@ -317,11 +327,12 @@ class _NextItemPredictionTask(tf.keras.layers.Layer):
def __init__(
self,
target_dim: int,
weight_tying: bool = False,
weight_tying: bool = True,
item_embedding_table: Optional[tf.Variable] = None,
softmax_temperature: float = 0,
**kwargs,
):
super().__init__()
super().__init__(**kwargs)
self.target_dim = target_dim
self.weight_tying = weight_tying
self.item_embedding_table = item_embedding_table
@@ -347,6 +358,20 @@ def __init__(
name="logits",
)

@classmethod
def from_config(cls, config):
config = maybe_deserialize_keras_objects(config, ["output_layer"])
return super().from_config(config)

def get_config(self):
config = super().get_config()
config = maybe_serialize_keras_objects(self, config, ["output_layer"])
config["target_dim"] = self.target_dim
config["weight_tying"] = self.weight_tying
config["softmax_temperature"] = self.softmax_temperature

return config

def call(self, inputs: tf.Tensor, **kwargs):
if self.weight_tying:
logits = tf.matmul(inputs, tf.transpose(self.item_embedding_table))
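
With weight_tying (now the default), the output projection reuses the item embedding table, so each item's logit is the dot product between the transformer output and that item's embedding. A standalone sketch with made-up dimensions:

import tensorflow as tf

vocab_size, d_model = 1000, 64
item_embedding_table = tf.Variable(tf.random.normal((vocab_size, d_model)))
hidden = tf.random.normal((32, d_model))   # one row per non-padded position
# Tied weights: score every item by a dot product with its embedding,
# equivalent to tf.matmul(hidden, tf.transpose(item_embedding_table)).
logits = tf.matmul(hidden, item_embedding_table, transpose_b=True)
print(logits.shape)  # (32, 1000)
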
@@ -363,4 +388,4 @@ def call(self, inputs: tf.Tensor, **kwargs):
return predictions

def _get_name(self) -> str:
return "NextItemPredictionTask"
return "_NextItemPredictionTask"
