Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for ragged inputs to model #666

Merged
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions tests/unit/torch/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,46 @@ def test_simple_model(torch_tabular_features, torch_tabular_data):
assert all(loss.min() >= 0 and loss.max() <= 1 for loss in losses)


@pytest.mark.parametrize("task", [tr.BinaryClassificationTask, tr.RegressionTask])
def test_sequential_prediction_model_with_ragged_inputs(
torch_yoochoose_tabular_transformer_features, torch_yoochoose_like, task
):
inputs = torch_yoochoose_tabular_transformer_features

transformer_config = tconf.XLNetConfig.build(
d_model=64, n_head=4, n_layer=2, total_seq_length=20
)
body = tr.SequentialBlock(inputs, tr.MLPBlock([64]), tr.TransformerBlock(transformer_config))

head_1 = tr.Head(
body,
tr.NextItemPredictionTask(weight_tying=True),
inputs=inputs,
)
head_2 = task("target", summary_type="mean").to_head(body, inputs)

bc_targets = torch.randint(2, (100,)).float()

model = tr.Model(head_1, head_2)
output = model(torch_yoochoose_like, training=True, targets=bc_targets)

assert isinstance(output, dict)
assert len(list(output.keys())) == 3
assert len(list(output["predictions"])) == 2
assert set(list(output.keys())) == set(["loss", "labels", "predictions"])

# test inference inputs with only one item
inference_inputs = tr.data.tabular_sequence_testing_data.torch_synthetic_data(
num_rows=10, min_session_length=1, max_session_length=4, ragged=True
)
_ = model(inference_inputs)

inference_inputs_2 = tr.data.tabular_sequence_testing_data.torch_synthetic_data(
num_rows=20, min_session_length=1, max_session_length=10, ragged=True
)
_ = model(inference_inputs_2)


@pytest.mark.parametrize("task", [tr.BinaryClassificationTask, tr.RegressionTask])
def test_sequential_prediction_model(
torch_yoochoose_tabular_transformer_features, torch_yoochoose_like, task
Expand Down
32 changes: 32 additions & 0 deletions tests/unit/utils/test_padding.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,38 @@ def test_pad_values_offsets_dict():
)


def test_pad_with_truncation():
data = [[1, 2], [], [3, 4, 5, 4, 7]]
values, offsets = _get_values_offsets(data)

x = {
"a__values": values,
"a__offsets": offsets,
"b": torch.tensor([[1, 2, 3, 4], [6, 7, 8, 9]]),
}

padded_x = pad_batch(x, {"a": 3, "b": 2})
assert torch.equal(
padded_x["a"],
torch.tensor(
[
[1, 2, 0],
[0, 0, 0],
[3, 4, 5],
]
),
)
assert torch.equal(
padded_x["b"],
torch.tensor(
[
[1, 2],
[6, 7],
]
),
)


def test_pad_values_dense():
data = [[1, 2], [], [3, 4, 5]]
values, offsets = _get_values_offsets(data)
Expand Down
3 changes: 2 additions & 1 deletion transformers4rec/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def merlin_schema(self) -> CoreSchema:
return TensorflowMetadata.from_json(self.schema.to_json()).to_merlin_schema()

def torch_synthetic_data(
self, num_rows=100, min_session_length=5, max_session_length=20, device=None
self, num_rows=100, min_session_length=5, max_session_length=20, device=None, ragged=False
):
from transformers4rec.torch.utils import schema_utils

Expand All @@ -51,6 +51,7 @@ def torch_synthetic_data(
min_session_length=min_session_length,
max_session_length=max_session_length,
device=device,
ragged=ragged,
)

def tf_synthetic_data(self, num_rows=100, min_session_length=5, max_session_length=20):
Expand Down
42 changes: 42 additions & 0 deletions transformers4rec/torch/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from ..features.base import InputBlock
from ..features.sequence import TabularFeaturesType
from ..typing import TabularData
from ..utils.padding import pad_batch
from ..utils.torch_utils import LossMixin, MetricsMixin


Expand Down Expand Up @@ -481,6 +482,7 @@ def __init__(
head_reduction: str = "mean",
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
name: str = None,
max_sequence_length: Optional[int] = None,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a max_sequence_length to limit the size of the padding when receiving ragged inputs.

):
"""Model class that can aggregate one or multiple heads.
Parameters
Expand All @@ -495,6 +497,9 @@ def __init__(
Optimizer-class to use during fitting
name: str, optional
Name of the model.
max_sequence_length : int, optional
The maximum sequence length supported by the model.
Used to truncate sequence inputs longer than this value.
"""
if head_weights:
if not isinstance(head_weights, list):
Expand All @@ -511,13 +516,50 @@ def __init__(
self.head_weights = head_weights or [1.0] * len(head)
self.head_reduction = head_reduction
self.optimizer = optimizer
self.max_sequence_length = max_sequence_length

def pad_inputs(self, inputs):
"""Pad ragged inputs to dense tensors with max sequence length

Parameters
----------
inputs : Dict[str, Tensor]
Dictionary of tensors

Returns
-------
inputs : Dict[str, Tensor]
Padded inputs
"""
batch_max_sequence_length = 0
for name, val in inputs.items():
if name.endswith("__offsets"):
max_row_length = int(torch.max(val[1:] - val[:-1]))
batch_max_sequence_length = max(max_row_length, batch_max_sequence_length)

padding_sequence_length = batch_max_sequence_length
if self.max_sequence_length is not None:
padding_sequence_length = min(self.max_sequence_length, batch_max_sequence_length)

if padding_sequence_length:
padding_lengths = {}
for name in inputs.keys():
if name.endswith("__offsets"):
padding_lengths[name[:-9]] = padding_sequence_length
oliverholworthy marked this conversation as resolved.
Show resolved Hide resolved
if padding_lengths:
inputs = pad_batch(inputs, padding_lengths)

return inputs

def forward(self, inputs: TabularData, targets=None, training=False, testing=False, **kwargs):
# Convert inputs to float32 which is the default type, expected by PyTorch
for name, val in inputs.items():
if torch.is_floating_point(val):
inputs[name] = val.to(torch.float32)

# pad ragged inputs
inputs = self.pad_inputs(inputs)

if isinstance(targets, dict) and len(targets) == 0:
# `pyarrow`` dataloader is returning {} instead of None
# TODO remove this code when `PyarraowDataLoader` is dropped
Expand Down
5 changes: 3 additions & 2 deletions transformers4rec/torch/utils/padding.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,12 @@ def _pad_ragged_tensor(values, offsets, padding_length):
offsets = _squeeze(offsets)
num_rows = len(offsets) - 1
diff_offsets = offsets[1:] - offsets[:-1]
max_length = int(diff_offsets.max())
indices = _get_indices(offsets, diff_offsets)
sparse_tensor = torch.sparse_coo_tensor(
indices.T, values, torch.Size([num_rows, padding_length]), device=values.device
indices.T, values, torch.Size([num_rows, max_length]), device=values.device
)
return sparse_tensor.to_dense()
return _pad_dense_tensor(sparse_tensor.to_dense(), padding_length)


Batch = Dict[str, Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]
Expand Down
20 changes: 13 additions & 7 deletions transformers4rec/torch/utils/schema_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def random_data_from_schema(
max_session_length: Optional[int] = None,
min_session_length: int = 5,
device=None,
ragged=False,
) -> TabularData:
data: Dict[str, Any] = {}

Expand Down Expand Up @@ -87,13 +88,18 @@ def random_data_from_schema(
for key, val in data.items():
if isinstance(val, tuple):
offsets = [0]
for length in val[1][:-1]:
offsets.append(offsets[-1] + length)
vals = (val[0], torch.tensor(offsets, device=device).unsqueeze(dim=1))
values, offsets, diff_offsets, num_rows = _pull_values_offsets(vals, device=device)
indices = _get_indices(offsets, diff_offsets, device=device)
seq_limit = max_session_length or val[1][0]
outputs[key] = _get_sparse_tensor(values, indices, num_rows, seq_limit)
for row_length in val[1]:
offsets.append(offsets[-1] + row_length)

if ragged:
outputs[f"{key}__values"] = val[0]
outputs[f"{key}__offsets"] = torch.tensor(offsets, device=device)
else:
vals = (val[0], torch.tensor(offsets[:-1], device=device).unsqueeze(dim=1))
values, offsets, diff_offsets, num_rows = _pull_values_offsets(vals, device=device)
indices = _get_indices(offsets, diff_offsets, device=device)
seq_limit = max_session_length or val[1][0]
outputs[key] = _get_sparse_tensor(values, indices, num_rows, seq_limit)
else:
outputs[key] = data[key]

Expand Down