#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import copy
import inspect
import os
import pathlib
from collections import defaultdict
from types import SimpleNamespace
from typing import Callable, Dict, Iterable, List, Optional, Type, Union, cast

import numpy as np
import torch
import torchmetrics as tm
from merlin.models.utils.registry import camelcase_to_snakecase
from merlin.schema import ColumnSchema
from merlin.schema import Schema as Core_Schema
from merlin.schema import Tags
from tqdm import tqdm
from transformers.modeling_utils import SequenceSummary

from merlin_standard_lib import Schema

from ..block.base import BlockBase, BlockOrModule, BlockType
from ..features.base import InputBlock
from ..features.sequence import TabularFeaturesType
from ..typing import TabularData
from ..utils.padding import pad_inputs
from ..utils.torch_utils import LossMixin, MetricsMixin


def name_fn(name, inp):
    return "/".join([name, inp]) if name else None

class PredictionTask(torch.nn.Module, LossMixin, MetricsMixin):
    """Individual prediction-task of a model.

    Parameters
    ----------
    loss: torch.nn.Module
        The loss to use during training of this task.
    metrics: Iterable[tm.Metric], optional
        The metrics to calculate during training & evaluation.
    target_name: str, optional
        Name of the target, this is needed when there are multiple targets.
    task_name: str, optional
        Name of the prediction task, if not provided a name will be automatically constructed
        based on the target-name & class-name.
    forward_to_prediction_fn: Callable[[torch.Tensor], torch.Tensor]
        Function to apply to the predictions before computing metrics.
    task_block: BlockType, optional
        Module to transform the input tensor before computing predictions.
    pre: BlockType, optional
        Module applied after `task_block` to compute the final prediction probabilities.
    summary_type: str
        This is used to summarize a sequence into a single tensor. Accepted values are:

            - `"last"` -- Take the last token hidden state (like XLNet)
            - `"first"` -- Take the first token hidden state (like BERT)
            - `"mean"` -- Take the mean of all tokens' hidden states
            - `"cls_index"` -- Supply a Tensor of classification-token positions (GPT/GPT-2)
            - `"attn"` -- Not implemented for now, would use multi-head attention
    """

    def __init__(
        self,
        loss: torch.nn.Module,
        metrics: Optional[Iterable[tm.Metric]] = None,
        target_name: Optional[str] = None,
        task_name: Optional[str] = None,
        forward_to_prediction_fn: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
        task_block: Optional[BlockType] = None,
        pre: Optional[BlockType] = None,
        summary_type: str = "last",
    ):
        super().__init__()
        self.summary_type = summary_type
        self.sequence_summary = SequenceSummary(
            SimpleNamespace(summary_type=self.summary_type)  # type: ignore
        )  # noqa
        self.target_name = target_name
        self.forward_to_prediction_fn = forward_to_prediction_fn
        self.set_metrics(metrics)
        self.loss = loss
        self.pre = pre
        self.task_block = task_block
        self._task_name = task_name

    def build(
        self,
        body: BlockType,
        input_size,
        inputs: Optional[InputBlock] = None,
        device=None,
        task_block: Optional[BlockType] = None,
        pre=None,
    ):
        """
        This method is called when the block is converted to a model,
        i.e. when it is linked to a prediction head.

        Parameters
        ----------
        body:
            The model block to link with the head.
        input_size:
            The size of the input to this task.
        inputs:
            The input block of the model, if any.
        device:
            Set the device for the metrics and layers of the task.
        task_block:
            Optional module to transform the input before computing predictions.
        pre:
            Optional module to compute the final predictions.
        """
        if task_block:
            # TODO: What to do when `self.task_block is not None`?
            self.task_block = task_block
        if pre:
            # TODO: What to do when `self.pre is not None`?
            self.pre = pre

        # Build task block
        pre_input_size = input_size
        if self.task_block:
            if isinstance(self.task_block, torch.nn.Module):
                self.task_block = copy.deepcopy(self.task_block)
            else:
                self.task_block = self.task_block.build(input_size)
            pre_input_size = self.task_block.output_size()  # type: ignore

        if self.pre:
            if isinstance(self.pre, torch.nn.Module):
                self.pre = copy.deepcopy(self.pre)
            else:
                self.pre = self.pre.build(pre_input_size)

        if device:
            self.to(device)
            for metric in self.metrics:
                metric.to(device)
        self.built = True

    def forward(
        self,
        inputs: torch.Tensor,
        targets: Optional[torch.Tensor] = None,
        training: bool = False,
        testing: bool = False,
    ):
        x = inputs

        if len(x.size()) == 3 and self.summary_type:
            x = self.sequence_summary(x)

        if self.task_block:
            x = self.task_block(x)  # type: ignore

        if self.pre:
            x = self.pre(x)  # type: ignore

        if training or testing:
            # During training or testing, compute the loss inside the forward
            # pass and return a dictionary as the standard output.
            if self.summary_type is None:
                if targets.dim() != 2:
                    raise ValueError(
                        "If `summary_type==None`, targets are expected to be a 2D tensor, "
                        f"but got a tensor with shape {targets.shape}"
                    )

            loss = self.loss(x, target=targets)
            return {"loss": loss, "labels": targets, "predictions": x}

        return x

    @property
    def task_name(self):
        if self._task_name:
            return self._task_name

        base_name = camelcase_to_snakecase(self.__class__.__name__)

        return name_fn(self.target_name, base_name) if self.target_name else base_name

    def child_name(self, name):
        return name_fn(self.task_name, name)

    def set_metrics(self, metrics):
        self.metrics = torch.nn.ModuleList(metrics)

    def calculate_metrics(  # type: ignore
        self,
        predictions: torch.Tensor,
        targets: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        outputs = {}

        predictions = self.forward_to_prediction_fn(cast(torch.Tensor, predictions))

        from .prediction_task import BinaryClassificationTask

        for metric in self.metrics:
            if isinstance(metric, tuple(type(x) for x in BinaryClassificationTask.DEFAULT_METRICS)):
                targets = cast(torch.Tensor, targets).int()
            outputs[self.metric_name(metric)] = metric(predictions, targets)

        return outputs

    def compute_metrics(self, **kwargs):
        return {self.metric_name(metric): metric.compute() for metric in self.metrics}

    def metric_name(self, metric: tm.Metric) -> str:
        return self.child_name(camelcase_to_snakecase(metric.__class__.__name__))

    def reset_metrics(self):
        for metric in self.metrics:
            metric.reset()

    def to_head(self, body, inputs=None, **kwargs) -> "Head":
        return Head(body, self, inputs=inputs, **kwargs)

    def to_model(self, body, inputs=None, **kwargs) -> "Model":
        return Model(Head(body, self, inputs=inputs, **kwargs), **kwargs)
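
    # Illustrative sketch (assumes a built `body` block and a concrete task
    # subclass such as BinaryClassificationTask from .prediction_task):
    #
    #   task = BinaryClassificationTask("click")
    #   model = task.to_model(body)  # equivalent to Model(Head(body, task))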


class Head(torch.nn.Module, LossMixin, MetricsMixin):
    """Head of a Model, a head has a single body but could have multiple prediction-tasks.

    Parameters
    ----------
    body: Block
        The block that computes the representation shared by all prediction tasks.
    prediction_tasks: Union[List[PredictionTask], PredictionTask], optional
        One or more prediction tasks to compute on top of the body.
    task_blocks: Union[BlockType, Dict[str, BlockType]], optional
        Task-specific block(s) applied before each task, either a single block
        shared by all tasks or a dictionary keyed by task name.
    task_weights: List[float], optional
        Weight of each task's loss when reducing the losses to a single value.
    loss_reduction: str, default="mean"
        Reduction to apply over the per-task losses, e.g. "mean" or "sum".
    inputs: TabularFeaturesType, optional
        The input block of the model, passed through when building each task.
    """

    def __init__(
        self,
        body: BlockBase,
        prediction_tasks: Union[List[PredictionTask], PredictionTask],
        task_blocks: Optional[Union[BlockType, Dict[str, BlockType]]] = None,
        task_weights: Optional[List[float]] = None,
        loss_reduction: str = "mean",
        inputs: Optional[TabularFeaturesType] = None,
    ):
        super().__init__()
        self.body = body
        self.loss_reduction = loss_reduction
        self.prediction_task_dict = torch.nn.ModuleDict()
        if prediction_tasks:
            if not isinstance(prediction_tasks, list):
                prediction_tasks = [prediction_tasks]
            for task in prediction_tasks:
                self.prediction_task_dict[task.task_name] = task

        self._task_weights = defaultdict(lambda: 1.0)
        if task_weights:
            for task, val in zip(cast(List[PredictionTask], prediction_tasks), task_weights):
                self._task_weights[task.task_name] = val

        self.build(inputs=inputs, task_blocks=task_blocks)

    def build(self, inputs=None, device=None, task_blocks=None):
        """Build each prediction task that's part of the head.

        Parameters
        ----------
        inputs:
            The input block of the model, if any.
        device:
            Set the device for the tasks' metrics and layers.
        task_blocks:
            Task-specific block(s), either shared or keyed by task name.
        """
        if not getattr(self.body, "output_size", lambda: None)():
            raise ValueError(
                "Can't infer output-size of the body, please provide "
                "a `Block` with an output-size. You can wrap any torch.Module in a Block."
            )

        input_size = self.body.output_size()

        if device:
            self.to(device)

        for name, task in self.prediction_task_dict.items():
            task_block = task_blocks
            if task_blocks and isinstance(task_blocks, dict) and name in task_blocks:
                task_block = task_blocks[name]
            task.build(self.body, input_size, inputs=inputs, device=device, task_block=task_block)
        self.input_size = input_size

    @classmethod
    def from_schema(
        cls,
        schema: Schema,
        body: BlockBase,
        task_blocks: Optional[Union[BlockType, Dict[str, BlockType]]] = None,
        task_weight_dict: Optional[Dict[str, float]] = None,
        loss_reduction: str = "mean",
        inputs: Optional[TabularFeaturesType] = None,
    ) -> "Head":
        """Instantiate a Head from a Schema through tagged targets.

        Parameters
        ----------
        schema: Schema
            Schema to use for inferring all targets based on the tags.
        body
        task_blocks
        task_weight_dict
        loss_reduction
        inputs

        Returns
        -------
        Head
        """
        task_weight_dict = task_weight_dict or {}
        tasks: List[PredictionTask] = []
        task_weights = []

        from .prediction_task import BinaryClassificationTask, RegressionTask

        for binary_target in schema.select_by_tag([Tags.BINARY, Tags.CLASSIFICATION]).column_names:
            tasks.append(BinaryClassificationTask(binary_target))
            task_weights.append(task_weight_dict.get(binary_target, 1.0))

        for regression_target in schema.select_by_tag(Tags.REGRESSION).column_names:
            tasks.append(RegressionTask(regression_target))
            task_weights.append(task_weight_dict.get(regression_target, 1.0))

        # TODO: Add multi-class classification here. Figure out how to get number of classes

        return cls(
            body,
            tasks,
            task_blocks=task_blocks,
            task_weights=task_weights,
            loss_reduction=loss_reduction,
            inputs=inputs,
        )
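
    # Illustrative sketch (assumes a schema whose target columns carry
    # Tags.BINARY/Tags.CLASSIFICATION or Tags.REGRESSION, e.g. a "click"
    # column tagged as binary):
    #
    #   head = Head.from_schema(
    #       schema,
    #       body,
    #       task_weight_dict={"click": 2.0},  # keyed by target column name
    #   )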

    def pop_labels(self, inputs: TabularData) -> TabularData:
        """Pop the labels from the different prediction_tasks from the inputs.

        Parameters
        ----------
        inputs: TabularData
            Input dictionary containing all targets.

        Returns
        -------
        TabularData
        """
        outputs = {}
        for name in self.prediction_task_dict.keys():
            outputs[name] = inputs.pop(name)

        return outputs

    def forward(
        self,
        body_outputs: Union[torch.Tensor, TabularData],
        training: bool = False,
        testing: bool = False,
        targets: Optional[Union[torch.Tensor, TabularData]] = None,
        call_body: bool = False,
        top_k: Optional[int] = None,
        **kwargs,
    ) -> Union[torch.Tensor, TabularData]:
        outputs = {}

        from transformers4rec.torch.model.prediction_task import NextItemPredictionTask

        if call_body:
            body_outputs = self.body(body_outputs, training=training, testing=testing, **kwargs)

        if training or testing:
            losses = []
            labels = {}
            predictions = {}
            for name, task in self.prediction_task_dict.items():
                if isinstance(targets, dict):
                    label = targets.get(task.target_name, None)
                else:
                    label = targets
                if label is not None:
                    label = label.float()
                task_output = task(
                    body_outputs, targets=label, training=training, testing=testing, **kwargs
                )
                labels[name] = task_output["labels"]
                predictions[name] = task_output["predictions"]
                losses.append(task_output["loss"] * self._task_weights[name])
            loss_tensor = torch.stack(losses)
            loss = getattr(loss_tensor, self.loss_reduction)()
            outputs = {"loss": loss, "labels": labels, "predictions": predictions}
        else:
            for name, task in self.prediction_task_dict.items():
                if isinstance(task, NextItemPredictionTask):
                    outputs[name] = task(
                        body_outputs,
                        targets=targets,
                        training=training,
                        testing=testing,
                        top_k=top_k,
                        **kwargs,
                    )
                else:
                    outputs[name] = task(
                        body_outputs, targets=targets, training=training, testing=testing, **kwargs
                    )

        return outputs
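
    # Output contract (as implemented above): with `training=True` or
    # `testing=True` the head returns a dict with "loss", "labels" and
    # "predictions"; otherwise it returns one entry per task, e.g.
    #
    #   out = head(body_outputs, call_body=True, targets=y, training=True)
    #   out["loss"].backward()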

    def calculate_metrics(  # type: ignore
        self,
        predictions: Union[torch.Tensor, TabularData],
        targets: Union[torch.Tensor, TabularData],
    ) -> Dict[str, Union[Dict[str, torch.Tensor], torch.Tensor]]:
        """Calculate metrics of the task(s) set in the Head instance.

        Parameters
        ----------
        predictions: Union[torch.Tensor, TabularData]
            The prediction tensors to use for calculating the metrics.
            They can be either a torch.Tensor if a single task is used or
            a dictionary of torch.Tensor if multiple tasks are used. In the
            second case, the dictionary is indexed by the task names.
        targets:
            The tensor or dictionary of targets to use for computing the metrics of
            one or multiple tasks.
        """
        metrics = {}

        for name, task in self.prediction_task_dict.items():
            label = targets
            output = predictions
            if isinstance(targets, dict):
                # The labels are retrieved from the task's output
                # and indexed by the task name.
                label = targets[name]
            if isinstance(predictions, dict):
                output = predictions[name]

            metrics.update(
                task.calculate_metrics(
                    predictions=output,
                    targets=label,
                )
            )

        return _output_metrics(metrics)

    def compute_metrics(self, mode: Optional[str] = None) -> Dict[str, Union[float, torch.Tensor]]:
        def name_fn(x):
            return "_".join([mode, x]) if mode else x

        metrics = {
            name_fn(name): task.compute_metrics()
            for name, task in self.prediction_task_dict.items()
        }

        return _output_metrics(metrics)

    def reset_metrics(self):
        """Reset the metrics of all tasks in this head."""
        for task in self.prediction_task_dict.values():
            task.reset_metrics()

    @property
    def task_blocks(self) -> Dict[str, Optional[BlockOrModule]]:
        return {name: task.task_block for name, task in self.prediction_task_dict.items()}

    def to_model(self, **kwargs) -> "Model":
        """Convert the head to a Model.

        Returns
        -------
        Model
        """
        return Model(self, **kwargs)


class Model(torch.nn.Module, LossMixin, MetricsMixin):
    def __init__(
        self,
        *head: Head,
        head_weights: Optional[List[float]] = None,
        head_reduction: str = "mean",
        optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
        name: Optional[str] = None,
        max_sequence_length: Optional[int] = None,
        top_k: Optional[int] = None,
    ):
        """Model class that can aggregate one or multiple heads.

        Parameters
        ----------
        head: Head
            One or more heads of the model.
        head_weights: List[float], optional
            Weight-value to use for each head.
        head_reduction: str, optional
            How to reduce the losses into a single tensor when multiple heads are used.
        optimizer: Type[torch.optim.Optimizer]
            Optimizer class to use during fitting.
        name: str, optional
            Name of the model.
        max_sequence_length: int, optional
            The maximum sequence length supported by the model.
            Used to truncate sequence inputs longer than this value.
        top_k: int, optional
            The number of items to return at the inference step once the model is deployed.
            Default is None, which will return all items.
        """
        if head_weights:
            if not isinstance(head_weights, list):
                raise ValueError("`head_weights` must be a list")
            if not len(head_weights) == len(head):
                raise ValueError(
                    "`head_weights` needs to have the same length as the number of heads"
                )

        super().__init__()

        self.name = name
        self.heads = torch.nn.ModuleList(head)
        self.head_weights = head_weights or [1.0] * len(head)
        self.head_reduction = head_reduction
        self.optimizer = optimizer
        self.max_sequence_length = max_sequence_length
        self.top_k = top_k
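
    # Illustrative sketch (assumes heads built as above; names are examples):
    #
    #   model = Model(head)                                      # single head
    #   model = Model(head_a, head_b, head_weights=[1.0, 0.5])   # weighted heads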

    def forward(self, inputs: TabularData, targets=None, training=False, testing=False, **kwargs):
        # Convert inputs to float32, the default dtype expected by PyTorch
        for name, val in inputs.items():
            if torch.is_floating_point(val):
                inputs[name] = val.to(torch.float32)

        # Pad ragged inputs
        inputs = pad_inputs(inputs, self.max_sequence_length)

        if isinstance(targets, dict) and len(targets) == 0:
            # The `pyarrow` dataloader returns {} instead of None
            # TODO: remove this code when `PyarrowDataLoader` is dropped
            targets = None

        # TODO: Optimize this
        if training or testing:
            losses = []
            labels = {}
            predictions = {}
            for i, head in enumerate(self.heads):
                head_output = head(
                    inputs,
                    call_body=True,
                    targets=targets,
                    training=training,
                    testing=testing,
                    **kwargs,
                )
                labels.update(head_output["labels"])
                predictions.update(head_output["predictions"])
                losses.append(head_output["loss"] * self.head_weights[i])
            loss_tensor = torch.stack(losses)
            loss = getattr(loss_tensor, self.head_reduction)()
            if len(labels) == 1:
                labels = list(labels.values())[0]
                predictions = list(predictions.values())[0]
            return {"loss": loss, "labels": labels, "predictions": predictions}
        else:
            outputs = {}
            for head in self.heads:
                outputs.update(
                    head(
                        inputs,
                        call_body=True,
                        targets=targets,
                        training=training,
                        testing=testing,
                        top_k=self.top_k,
                        **kwargs,
                    )
                )
            if len(outputs) == 1:
                return list(outputs.values())[0]

            return outputs
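
    # Illustrative calls (a sketch; `batch` is a TabularData dict of tensors):
    #
    #   out = model(batch, targets=y, training=True)  # {"loss", "labels", "predictions"}
    #   preds = model(batch)                          # predictions only (inference)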

    def calculate_metrics(  # type: ignore
        self,
        predictions: Union[torch.Tensor, TabularData],
        targets: Union[torch.Tensor, TabularData],
    ) -> Dict[str, Union[Dict[str, torch.Tensor], torch.Tensor]]:
        """Calculate metrics of the task(s) set in the Head instance.

        Parameters
        ----------
        predictions: Union[torch.Tensor, TabularData]
            The prediction tensors returned by the model.
            They can be either a torch.Tensor if a single task is used or
            a dictionary of torch.Tensor if multiple heads/tasks are used. In the
            second case, the dictionary is indexed by the task names.
        targets:
            The tensor or dictionary of targets returned by the model.
            They are used for computing the metrics of one or multiple tasks.
        """
        outputs = {}
        for head in self.heads:
            outputs.update(
                head.calculate_metrics(
                    predictions,
                    targets,
                )
            )

        return outputs

    def compute_metrics(self, mode=None) -> Dict[str, Union[float, torch.Tensor]]:
        metrics = {}
        for head in self.heads:
            metrics.update(head.compute_metrics(mode=mode))

        return metrics

    def reset_metrics(self):
        for head in self.heads:
            head.reset_metrics()

    def to_lightning(self):
        import pytorch_lightning as pl

        parent_self = self

        class BlockWithHeadLightning(pl.LightningModule):
            def __init__(self):
                super(BlockWithHeadLightning, self).__init__()
                self.parent = parent_self

            def forward(self, inputs, targets=None, training=False, testing=False, *args, **kwargs):
                return self.parent(
                    inputs, targets=targets, training=training, testing=testing, *args, **kwargs
                )

            def training_step(self, batch, batch_idx, targets=None, training=True, testing=False):
                loss = self.parent(*batch, targets=targets, training=training, testing=testing)[
                    "loss"
                ]
                self.log("train_loss", loss)

                return loss

            def configure_optimizers(self):
                optimizer = self.parent.optimizer(self.parent.parameters(), lr=1e-3)

                return optimizer

        return BlockWithHeadLightning()
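
    # Illustrative sketch (assumes pytorch_lightning is installed):
    #
    #   import pytorch_lightning as pl
    #   pl.Trainer(max_epochs=1).fit(model.to_lightning(), train_dataloader)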

    def fit(
        self,
        dataloader,
        optimizer=torch.optim.Adam,
        eval_dataloader=None,
        num_epochs=1,
        amp=False,
        train=True,
        verbose=True,
        compute_metric=True,
    ):
        if isinstance(dataloader, torch.utils.data.DataLoader):
            dataset = dataloader.dataset
        else:
            dataset = dataloader

        if inspect.isclass(optimizer):
            optimizer = optimizer(self.parameters())

        self.train(mode=train)
        epoch_losses = []
        with torch.set_grad_enabled(mode=train):
            for epoch in range(num_epochs):
                losses = []
                batch_iterator = enumerate(iter(dataset))
                if verbose:
                    batch_iterator = tqdm(batch_iterator)
                for batch_idx, (x, y) in batch_iterator:
                    if amp:
                        with torch.cuda.amp.autocast():
                            output = self(x, targets=y, training=True)
                    else:
                        output = self(x, targets=y, training=True)
                    losses.append(float(output["loss"]))
                    if compute_metric:
                        self.calculate_metrics(
                            output["predictions"],
                            targets=output["labels"],
                        )
                    if train:
                        optimizer.zero_grad()
                        output["loss"].backward()
                        optimizer.step()
                if verbose:
                    print(self.compute_metrics(mode="train"))
                    if eval_dataloader:
                        print(self.evaluate(eval_dataloader, verbose=False))
                epoch_losses.append(np.mean(losses))

        return np.array(epoch_losses)
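
    # Illustrative sketch (assumes a dataloader yielding (inputs, targets)
    # batches, e.g. a Merlin dataloader):
    #
    #   epoch_losses = model.fit(train_loader, num_epochs=3, amp=True)  # mean loss per epoch
    #   metrics = model.evaluate(eval_loader)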

    def evaluate(
        self, dataloader, targets=None, training=False, testing=True, verbose=True, mode="eval"
    ):
        if isinstance(dataloader, torch.utils.data.DataLoader):
            dataset = dataloader.dataset
        else:
            dataset = dataloader

        batch_iterator = enumerate(iter(dataset))
        if verbose:
            batch_iterator = tqdm(batch_iterator)
        self.reset_metrics()
        for batch_idx, (x, y) in batch_iterator:
            output = self(x, targets=y, training=training, testing=testing)
            self.calculate_metrics(
                output["predictions"],
                targets=output["labels"],
            )

        return self.compute_metrics(mode=mode)

    def _get_name(self):
        if self.name:
            return self.name

        return super(Model, self)._get_name()

    @property
    def input_schema(self):
        # Return the input schema given by the model:
        # loop over the heads to get the input schemas
        schemas = []
        for head in self.heads:
            schemas.append(head.body.inputs.schema)
        if all(isinstance(s, Core_Schema) for s in schemas):
            return sum(schemas, Core_Schema())

        model_schema = sum(schemas, Schema())

        # TODO: rework T4R to use Merlin Schemas.
        # In the meantime, we convert model_schema to a merlin core schema
        core_schema = Core_Schema()
        for column in model_schema:
            name = column.name

            # Map the proto feature type (0=UNKNOWN, 2=INT, 3=FLOAT) to a numpy dtype
            dtype = {0: np.float32, 2: np.int64, 3: np.float32}[column.type]
            tags = column.tags
            dims = None
            if column.value_count.max > 0:
                dims = (None, (column.value_count.min, column.value_count.max))
            int_domain = {"min": column.int_domain.min, "max": column.int_domain.max}
            properties = {
                "int_domain": int_domain,
            }

            col_schema = ColumnSchema(
                name, dtype=dtype, tags=tags, properties=properties, dims=dims
            )
            core_schema[name] = col_schema
        return core_schema

    @property
    def output_schema(self):
        from merlin.schema import Tags

        from .prediction_task import BinaryClassificationTask, RegressionTask

        # If the model has one head with one task, the output is a tensor;
        # with multiple heads and/or multiple prediction tasks, the output is a dictionary
        output_cols = []
        for head in self.heads:
            dims = None
            for name, task in head.prediction_task_dict.items():
                target_dim = task.target_dim
                int_domain = {"min": target_dim, "max": target_dim}
                if (
                    isinstance(task, (BinaryClassificationTask, RegressionTask))
                    and not task.summary_type
                ):
                    dims = (None, (1, None))
                elif (
                    isinstance(task, (BinaryClassificationTask, RegressionTask))
                    and task.summary_type
                ):
                    dims = (None,)
                else:
                    dims = (None, task.target_dim)
                properties = {
                    "int_domain": int_domain,
                }
                # In case top_k is set at the inference step, we return two outputs
                if self.top_k:
                    # Make sure the categorical item-id dtype in the model's
                    # input schema and output schema matches
                    col_name = self.input_schema.select_by_tag(Tags.ITEM_ID).column_names[0]
                    col_dtype = (
                        self.input_schema.select_by_tag(Tags.ITEM_ID)
                        .column_schemas[col_name]
                        .dtype.name
                    )
                    col_schema_scores = ColumnSchema(
                        "item_id_scores", dtype=np.float32, properties=properties, dims=dims
                    )
                    col_schema_ids = ColumnSchema(
                        "item_ids", dtype=np.dtype(col_dtype), properties=properties, dims=dims
                    )
                    output_cols.append(col_schema_scores)
                    output_cols.append(col_schema_ids)
                else:
                    col_schema = ColumnSchema(
                        name, dtype=np.float32, properties=properties, dims=dims
                    )
                    output_cols.append(col_schema)

        return Core_Schema(output_cols)

    @property
    def prediction_tasks(self):
        return [task for head in self.heads for task in list(head.prediction_task_dict.values())]

    def save(self, path: Union[str, os.PathLike], model_name="t4rec_model_class"):
        """Saves the model to f"{export_path}/{model_name}.pkl" using `cloudpickle`.

        Parameters
        ----------
        path : Union[str, os.PathLike]
            Path to the directory where the T4Rec model should be saved.
        model_name : str, optional
            The name given to the pickle file storing the T4Rec model,
            by default 't4rec_model_class'.
        """
        try:
            import cloudpickle
        except ImportError:
            raise ValueError("cloudpickle is required to save model class")

        export_path = pathlib.Path(path)
        export_path.mkdir(exist_ok=True)

        model_name = model_name + ".pkl"
        export_path = export_path / model_name
        with open(export_path, "wb") as out:
            cloudpickle.dump(self, out)

    @classmethod
    def load(cls, path: Union[str, os.PathLike], model_name="t4rec_model_class") -> "Model":
        """Loads a T4Rec model that was saved with `model.save()`.

        Parameters
        ----------
        path : Union[str, os.PathLike]
            Path to the directory where the T4Rec model is saved.
        model_name : str, optional
            The name given to the pickle file storing the T4Rec model,
            by default 't4rec_model_class'.
        """
        try:
            import cloudpickle
        except ImportError:
            raise ValueError("cloudpickle is required to load T4Rec model")

        export_path = pathlib.Path(path)
        model_name = model_name + ".pkl"
        export_path = export_path / model_name
        with open(export_path, "rb") as inp:
            return cloudpickle.load(inp)
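
    # Illustrative round-trip (path is an example):
    #
    #   model.save("/tmp/t4rec")
    #   loaded = Model.load("/tmp/t4rec")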


def _output_metrics(metrics):
    # If there is only a single head with metrics, return just those metrics
    if len(metrics) == 1 and isinstance(metrics[list(metrics.keys())[0]], dict):
        return metrics[list(metrics.keys())[0]]

    return metrics