From 89f0be14f1855f9ed7f34713cff619b92a5dbee3 Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Wed, 5 Apr 2023 10:53:17 +0100 Subject: [PATCH 01/15] Update padding to handle truncation to smaller sequence length --- tests/unit/utils/test_padding.py | 32 +++++++++++++++++++++++++ transformers4rec/torch/utils/padding.py | 5 ++-- 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/tests/unit/utils/test_padding.py b/tests/unit/utils/test_padding.py index 1aec66dc50..063dad4c24 100644 --- a/tests/unit/utils/test_padding.py +++ b/tests/unit/utils/test_padding.py @@ -68,6 +68,38 @@ def test_pad_values_offsets_dict(): ) +def test_pad_with_truncation(): + data = [[1, 2], [], [3, 4, 5, 4, 7]] + values, offsets = _get_values_offsets(data) + + x = { + "a__values": values, + "a__offsets": offsets, + "b": torch.tensor([[1, 2, 3, 4], [6, 7, 8, 9]]), + } + + padded_x = pad_batch(x, {"a": 3, "b": 2}) + assert torch.equal( + padded_x["a"], + torch.tensor( + [ + [1, 2, 0], + [0, 0, 0], + [3, 4, 5], + ] + ), + ) + assert torch.equal( + padded_x["b"], + torch.tensor( + [ + [1, 2], + [6, 7], + ] + ), + ) + + def test_pad_values_dense(): data = [[1, 2], [], [3, 4, 5]] values, offsets = _get_values_offsets(data) diff --git a/transformers4rec/torch/utils/padding.py b/transformers4rec/torch/utils/padding.py index c86ffefe72..28347d1304 100644 --- a/transformers4rec/torch/utils/padding.py +++ b/transformers4rec/torch/utils/padding.py @@ -47,11 +47,12 @@ def _pad_ragged_tensor(values, offsets, padding_length): offsets = _squeeze(offsets) num_rows = len(offsets) - 1 diff_offsets = offsets[1:] - offsets[:-1] + max_length = int(diff_offsets.max()) indices = _get_indices(offsets, diff_offsets) sparse_tensor = torch.sparse_coo_tensor( - indices.T, values, torch.Size([num_rows, padding_length]), device=values.device + indices.T, values, torch.Size([num_rows, max_length]), device=values.device ) - return sparse_tensor.to_dense() + return _pad_dense_tensor(sparse_tensor.to_dense(), padding_length) Batch = Dict[str, Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] From 11c02ef567edccb1176f7b226cc428b178d4f962 Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Wed, 5 Apr 2023 10:57:05 +0100 Subject: [PATCH 02/15] Add ragged argument to enable returning ragged representation --- transformers4rec/data/dataset.py | 3 ++- transformers4rec/torch/utils/schema_utils.py | 20 +++++++++++++------- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/transformers4rec/data/dataset.py b/transformers4rec/data/dataset.py index 8744f1ce44..bb081f0f00 100644 --- a/transformers4rec/data/dataset.py +++ b/transformers4rec/data/dataset.py @@ -41,7 +41,7 @@ def merlin_schema(self) -> CoreSchema: return TensorflowMetadata.from_json(self.schema.to_json()).to_merlin_schema() def torch_synthetic_data( - self, num_rows=100, min_session_length=5, max_session_length=20, device=None + self, num_rows=100, min_session_length=5, max_session_length=20, device=None, ragged=False ): from transformers4rec.torch.utils import schema_utils @@ -51,6 +51,7 @@ def torch_synthetic_data( min_session_length=min_session_length, max_session_length=max_session_length, device=device, + ragged=ragged, ) def tf_synthetic_data(self, num_rows=100, min_session_length=5, max_session_length=20): diff --git a/transformers4rec/torch/utils/schema_utils.py b/transformers4rec/torch/utils/schema_utils.py index c5a2f93c0f..cf3d9d7f4f 100644 --- a/transformers4rec/torch/utils/schema_utils.py +++ b/transformers4rec/torch/utils/schema_utils.py @@ -31,6 +31,7 @@ def random_data_from_schema( max_session_length: Optional[int] = None, min_session_length: int = 5, device=None, + ragged=False, ) -> TabularData: data: Dict[str, Any] = {} @@ -87,13 +88,18 @@ def random_data_from_schema( for key, val in data.items(): if isinstance(val, tuple): offsets = [0] - for length in val[1][:-1]: - offsets.append(offsets[-1] + length) - vals = (val[0], torch.tensor(offsets, device=device).unsqueeze(dim=1)) - values, offsets, diff_offsets, num_rows = _pull_values_offsets(vals, device=device) - indices = _get_indices(offsets, diff_offsets, device=device) - seq_limit = max_session_length or val[1][0] - outputs[key] = _get_sparse_tensor(values, indices, num_rows, seq_limit) + for row_length in val[1]: + offsets.append(offsets[-1] + row_length) + + if ragged: + outputs[f"{key}__values"] = val[0] + outputs[f"{key}__offsets"] = torch.tensor(offsets, device=device) + else: + vals = (val[0], torch.tensor(offsets[:-1], device=device).unsqueeze(dim=1)) + values, offsets, diff_offsets, num_rows = _pull_values_offsets(vals, device=device) + indices = _get_indices(offsets, diff_offsets, device=device) + seq_limit = max_session_length or val[1][0] + outputs[key] = _get_sparse_tensor(values, indices, num_rows, seq_limit) else: outputs[key] = data[key] From d86e638d1910275cc48710377f9daab02dca9c7f Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Wed, 5 Apr 2023 10:57:37 +0100 Subject: [PATCH 03/15] Add support for ragged inputs in model and add test for model --- tests/unit/torch/model/test_model.py | 40 ++++++++++++++++++++++++++++ transformers4rec/torch/model/base.py | 16 +++++++++++ 2 files changed, 56 insertions(+) diff --git a/tests/unit/torch/model/test_model.py b/tests/unit/torch/model/test_model.py index fa562edb90..716a5d9bde 100644 --- a/tests/unit/torch/model/test_model.py +++ b/tests/unit/torch/model/test_model.py @@ -55,6 +55,46 @@ def test_simple_model(torch_tabular_features, torch_tabular_data): assert all(loss.min() >= 0 and loss.max() <= 1 for loss in losses) +@pytest.mark.parametrize("task", [tr.BinaryClassificationTask, tr.RegressionTask]) +def test_sequential_prediction_model_with_ragged_inputs( + torch_yoochoose_tabular_transformer_features, torch_yoochoose_like, task +): + inputs = torch_yoochoose_tabular_transformer_features + + transformer_config = tconf.XLNetConfig.build( + d_model=64, n_head=4, n_layer=2, total_seq_length=20 + ) + body = tr.SequentialBlock(inputs, tr.MLPBlock([64]), tr.TransformerBlock(transformer_config)) + + head_1 = tr.Head( + body, + tr.NextItemPredictionTask(weight_tying=True), + inputs=inputs, + ) + head_2 = task("target", summary_type="mean").to_head(body, inputs) + + bc_targets = torch.randint(2, (100,)).float() + + model = tr.Model(head_1, head_2) + output = model(torch_yoochoose_like, training=True, targets=bc_targets) + + assert isinstance(output, dict) + assert len(list(output.keys())) == 3 + assert len(list(output["predictions"])) == 2 + assert set(list(output.keys())) == set(["loss", "labels", "predictions"]) + + # test inference inputs with only one item + inference_inputs = tr.data.tabular_sequence_testing_data.torch_synthetic_data( + num_rows=10, min_session_length=1, max_session_length=4, ragged=True + ) + _ = model(inference_inputs) + + inference_inputs_2 = tr.data.tabular_sequence_testing_data.torch_synthetic_data( + num_rows=20, min_session_length=1, max_session_length=10, ragged=True + ) + _ = model(inference_inputs_2) + + @pytest.mark.parametrize("task", [tr.BinaryClassificationTask, tr.RegressionTask]) def test_sequential_prediction_model( torch_yoochoose_tabular_transformer_features, torch_yoochoose_like, task diff --git a/transformers4rec/torch/model/base.py b/transformers4rec/torch/model/base.py index ef8aa12bd0..b180adf54e 100644 --- a/transformers4rec/torch/model/base.py +++ b/transformers4rec/torch/model/base.py @@ -37,6 +37,7 @@ from ..features.base import InputBlock from ..features.sequence import TabularFeaturesType from ..typing import TabularData +from ..utils.padding import pad_batch from ..utils.torch_utils import LossMixin, MetricsMixin @@ -518,6 +519,21 @@ def forward(self, inputs: TabularData, targets=None, training=False, testing=Fal if torch.is_floating_point(val): inputs[name] = val.to(torch.float32) + # Pad ragged inputs + max_sequence_length = 0 + for name, val in inputs.items(): + if name.endswith("__offsets"): + max_row_length = int(torch.max(val[1:] - val[:-1])) + max_sequence_length = max(max_row_length, max_sequence_length) + + if max_sequence_length: + padding_lengths = {} + for name in inputs.keys(): + if name.endswith("__offsets"): + padding_lengths[name[:-9]] = max_sequence_length + if padding_lengths: + inputs = pad_batch(inputs, padding_lengths) + if isinstance(targets, dict) and len(targets) == 0: # `pyarrow`` dataloader is returning {} instead of None # TODO remove this code when `PyarraowDataLoader` is dropped From c867dee2af5e8deae5e342805ac848803f841135 Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Wed, 5 Apr 2023 11:16:54 +0100 Subject: [PATCH 04/15] Add `max_sequence_length` to Model and move input padding to method --- transformers4rec/torch/model/base.py | 46 ++++++++++++++++++++++------ 1 file changed, 36 insertions(+), 10 deletions(-) diff --git a/transformers4rec/torch/model/base.py b/transformers4rec/torch/model/base.py index b180adf54e..2bd38f0eaf 100644 --- a/transformers4rec/torch/model/base.py +++ b/transformers4rec/torch/model/base.py @@ -482,6 +482,7 @@ def __init__( head_reduction: str = "mean", optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam, name: str = None, + max_sequence_length: Optional[int] = None, ): """Model class that can aggregate one or multiple heads. Parameters @@ -496,6 +497,9 @@ def __init__( Optimizer-class to use during fitting name: str, optional Name of the model. + max_sequence_length : int, optional + The maximum sequence length supported by the model. + Used to truncate sequence inputs longer than this value. """ if head_weights: if not isinstance(head_weights, list): @@ -512,28 +516,50 @@ def __init__( self.head_weights = head_weights or [1.0] * len(head) self.head_reduction = head_reduction self.optimizer = optimizer + self.max_sequence_length = max_sequence_length - def forward(self, inputs: TabularData, targets=None, training=False, testing=False, **kwargs): - # Convert inputs to float32 which is the default type, expected by PyTorch - for name, val in inputs.items(): - if torch.is_floating_point(val): - inputs[name] = val.to(torch.float32) + def pad_inputs(self, inputs): + """Pad ragged inputs to dense tensors with max sequence length - # Pad ragged inputs - max_sequence_length = 0 + Parameters + ---------- + inputs : Dict[str, Tensor] + Dictionary of tensors + + Returns + ------- + inputs : Dict[str, Tensor] + Padded inputs + """ + batch_max_sequence_length = 0 for name, val in inputs.items(): if name.endswith("__offsets"): max_row_length = int(torch.max(val[1:] - val[:-1])) - max_sequence_length = max(max_row_length, max_sequence_length) + batch_max_sequence_length = max(max_row_length, batch_max_sequence_length) + + padding_sequence_length = batch_max_sequence_length + if self.max_sequence_length is not None: + padding_sequence_length = min(self.max_sequence_length, batch_max_sequence_length) - if max_sequence_length: + if padding_sequence_length: padding_lengths = {} for name in inputs.keys(): if name.endswith("__offsets"): - padding_lengths[name[:-9]] = max_sequence_length + padding_lengths[name[:-9]] = padding_sequence_length if padding_lengths: inputs = pad_batch(inputs, padding_lengths) + return inputs + + def forward(self, inputs: TabularData, targets=None, training=False, testing=False, **kwargs): + # Convert inputs to float32 which is the default type, expected by PyTorch + for name, val in inputs.items(): + if torch.is_floating_point(val): + inputs[name] = val.to(torch.float32) + + # pad ragged inputs + inputs = self.pad_inputs(inputs) + if isinstance(targets, dict) and len(targets) == 0: # `pyarrow`` dataloader is returning {} instead of None # TODO remove this code when `PyarraowDataLoader` is dropped From 919a570acae082d50bd4b3eeaa67ea9f2a449f27 Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Wed, 5 Apr 2023 11:26:59 +0100 Subject: [PATCH 05/15] Use len("__offsets") to get feature name Co-authored-by: Marc Romeyn --- transformers4rec/torch/model/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformers4rec/torch/model/base.py b/transformers4rec/torch/model/base.py index 2bd38f0eaf..c077991d0d 100644 --- a/transformers4rec/torch/model/base.py +++ b/transformers4rec/torch/model/base.py @@ -545,7 +545,7 @@ def pad_inputs(self, inputs): padding_lengths = {} for name in inputs.keys(): if name.endswith("__offsets"): - padding_lengths[name[:-9]] = padding_sequence_length + padding_lengths[name[:-len("__offsets")]] = padding_sequence_length if padding_lengths: inputs = pad_batch(inputs, padding_lengths) From eec9df0cc713aa0c63cee7b49691110ce7e93c9b Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Wed, 5 Apr 2023 11:41:18 +0100 Subject: [PATCH 06/15] Reformat padding_lengths line --- transformers4rec/torch/model/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformers4rec/torch/model/base.py b/transformers4rec/torch/model/base.py index c077991d0d..d970a57588 100644 --- a/transformers4rec/torch/model/base.py +++ b/transformers4rec/torch/model/base.py @@ -545,7 +545,7 @@ def pad_inputs(self, inputs): padding_lengths = {} for name in inputs.keys(): if name.endswith("__offsets"): - padding_lengths[name[:-len("__offsets")]] = padding_sequence_length + padding_lengths[name[: -len("__offsets")]] = padding_sequence_length if padding_lengths: inputs = pad_batch(inputs, padding_lengths) From cc1d6fbcf45d8c8e6461dcde50729dcbd3150e50 Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Wed, 5 Apr 2023 14:16:10 +0100 Subject: [PATCH 07/15] Add torch.jit.script decorator to pad_batch --- tests/unit/utils/test_padding.py | 19 ---------- transformers4rec/torch/utils/padding.py | 47 ++++++++++++++----------- 2 files changed, 26 insertions(+), 40 deletions(-) diff --git a/tests/unit/utils/test_padding.py b/tests/unit/utils/test_padding.py index 063dad4c24..4c41a79532 100644 --- a/tests/unit/utils/test_padding.py +++ b/tests/unit/utils/test_padding.py @@ -30,25 +30,6 @@ def _get_values_offsets(data): return torch.tensor(values), torch.tensor(offsets) -def test_pad_values_offsets_tuple(): - data = [[1, 2], [], [3, 4, 5]] - values, offsets = _get_values_offsets(data) - - x = {"a": (values, offsets)} - - padded_x = pad_batch(x, {"a": 5}) - assert torch.equal( - padded_x["a"], - torch.tensor( - [ - [1, 2, 0, 0, 0], - [0, 0, 0, 0, 0], - [3, 4, 5, 0, 0], - ] - ), - ) - - def test_pad_values_offsets_dict(): data = [[1, 2], [], [3, 4, 5]] values, offsets = _get_values_offsets(data) diff --git a/transformers4rec/torch/utils/padding.py b/transformers4rec/torch/utils/padding.py index 28347d1304..c35b1b5092 100644 --- a/transformers4rec/torch/utils/padding.py +++ b/transformers4rec/torch/utils/padding.py @@ -13,36 +13,35 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from typing import Dict, Optional, Tuple, Union +from typing import Dict import torch import torch.nn.functional as F -from merlin.table import TensorTable -def _pad_dense_tensor(t: torch.Tensor, length: Optional[int]) -> torch.Tensor: - if length and len(t.shape) == 2: +def _pad_dense_tensor(t: torch.Tensor, length: int) -> torch.Tensor: + if len(t.shape) == 2: pad_diff = length - t.shape[1] return F.pad(input=t, pad=(0, pad_diff, 0, 0)) return t -def _squeeze(tensor): +def _squeeze(tensor: torch.Tensor): if len(tensor.shape) == 2: return tensor.squeeze(1) return tensor -def _get_indices(offsets, diff_offsets): +def _get_indices(offsets: torch.Tensor, diff_offsets: torch.Tensor): row_ids = torch.arange(len(offsets) - 1, device=offsets.device) row_ids_repeated = torch.repeat_interleave(row_ids, diff_offsets) row_offset_repeated = torch.repeat_interleave(offsets[:-1], diff_offsets) col_ids = torch.arange(len(row_offset_repeated), device=offsets.device) - row_offset_repeated - indices = torch.cat([row_ids_repeated.unsqueeze(-1), col_ids.unsqueeze(-1)], axis=1) + indices = torch.cat([row_ids_repeated.unsqueeze(-1), col_ids.unsqueeze(-1)], dim=1) return indices -def _pad_ragged_tensor(values, offsets, padding_length): +def _pad_ragged_tensor(values: torch.Tensor, offsets: torch.Tensor, padding_length: int): values = _squeeze(values) offsets = _squeeze(offsets) num_rows = len(offsets) - 1 @@ -55,15 +54,15 @@ def _pad_ragged_tensor(values, offsets, padding_length): return _pad_dense_tensor(sparse_tensor.to_dense(), padding_length) -Batch = Dict[str, Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] - - -def pad_batch(batch: Batch, padding_lengths: Dict[str, int]) -> Batch: +@torch.jit.script +def pad_batch( + batch: Dict[str, torch.Tensor], padding_lengths: Dict[str, int] +) -> Dict[str, torch.Tensor]: """Pad list features in a batch to padding length specified Parameters ---------- - X : Batch + batch : Batch dictionary of tensors in batch padding_lengths : Dict[str, int] dictionary mapping list column name to padding length @@ -78,15 +77,15 @@ def pad_batch(batch: Batch, padding_lengths: Dict[str, int]) -> Batch: ValueError If ragged column found with no padding length provided """ - if batch is None or not isinstance(batch, dict): - return batch - batch_padded = {} - for col_name, col in TensorTable(batch).items(): - if col.offsets is not None: + for key, value in batch.items(): + if key.endswith("__offsets"): + col_name = key[: -len("__offsets")] padding_length = padding_lengths.get(col_name) - if padding_length: - padded_values = _pad_ragged_tensor(col.values, col.offsets, padding_length) + if padding_length is not None: + padded_values = _pad_ragged_tensor( + batch[f"{col_name}__values"], value, padding_length + ) batch_padded[col_name] = padded_values else: # Note: This exception can be removed if the model is @@ -96,8 +95,14 @@ def pad_batch(batch: Batch, padding_lengths: Dict[str, int]) -> Batch: "Please provide a padding length for this feature " "to be converted to a dense tensor. " ) + elif key.endswith("__values"): + continue else: + col_name = key padding_length = padding_lengths.get(col_name) - batch_padded[col_name] = _pad_dense_tensor(col.values, padding_length) + if padding_length is not None: + batch_padded[col_name] = _pad_dense_tensor(value, padding_length) + else: + batch_padded[col_name] = value return batch_padded From ab4019ee3e8b552147016b4ad1605413c86b0be6 Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Wed, 5 Apr 2023 14:17:10 +0100 Subject: [PATCH 08/15] Move pad_inputs to function and make jit scriptable --- transformers4rec/torch/model/base.py | 72 +++++++++++++++------------- 1 file changed, 38 insertions(+), 34 deletions(-) diff --git a/transformers4rec/torch/model/base.py b/transformers4rec/torch/model/base.py index d970a57588..0a5b94b688 100644 --- a/transformers4rec/torch/model/base.py +++ b/transformers4rec/torch/model/base.py @@ -474,6 +474,43 @@ def to_model(self, **kwargs) -> "Model": return Model(self, **kwargs) +@torch.jit.script +def pad_inputs(inputs: Dict[str, torch.Tensor], max_sequence_length: Optional[int]): + """Pad ragged inputs to dense tensors with max sequence length + + Parameters + ---------- + inputs : Dict[str, Tensor] + Dictionary of tensors + + Returns + ------- + inputs : Dict[str, Tensor] + Padded inputs + """ + batch_max_sequence_length = 0 + for key, val in inputs.items(): + if key.endswith("__offsets"): + offsets = val + max_row_length = int(torch.max(offsets[1:] - offsets[:-1])) + batch_max_sequence_length = max(max_row_length, batch_max_sequence_length) + + padding_sequence_length = batch_max_sequence_length + if max_sequence_length is not None: + padding_sequence_length = min(max_sequence_length, batch_max_sequence_length) + + if padding_sequence_length > 0: + padding_lengths: Dict[str, int] = {} + for key in inputs.keys(): + if key.endswith("__offsets"): + col_name: str = key[: -len("__offsets")] + padding_lengths[col_name] = padding_sequence_length + if padding_lengths: + inputs = pad_batch(inputs, padding_lengths) + + return inputs + + class Model(torch.nn.Module, LossMixin, MetricsMixin): def __init__( self, @@ -518,39 +555,6 @@ def __init__( self.optimizer = optimizer self.max_sequence_length = max_sequence_length - def pad_inputs(self, inputs): - """Pad ragged inputs to dense tensors with max sequence length - - Parameters - ---------- - inputs : Dict[str, Tensor] - Dictionary of tensors - - Returns - ------- - inputs : Dict[str, Tensor] - Padded inputs - """ - batch_max_sequence_length = 0 - for name, val in inputs.items(): - if name.endswith("__offsets"): - max_row_length = int(torch.max(val[1:] - val[:-1])) - batch_max_sequence_length = max(max_row_length, batch_max_sequence_length) - - padding_sequence_length = batch_max_sequence_length - if self.max_sequence_length is not None: - padding_sequence_length = min(self.max_sequence_length, batch_max_sequence_length) - - if padding_sequence_length: - padding_lengths = {} - for name in inputs.keys(): - if name.endswith("__offsets"): - padding_lengths[name[: -len("__offsets")]] = padding_sequence_length - if padding_lengths: - inputs = pad_batch(inputs, padding_lengths) - - return inputs - def forward(self, inputs: TabularData, targets=None, training=False, testing=False, **kwargs): # Convert inputs to float32 which is the default type, expected by PyTorch for name, val in inputs.items(): @@ -558,7 +562,7 @@ def forward(self, inputs: TabularData, targets=None, training=False, testing=Fal inputs[name] = val.to(torch.float32) # pad ragged inputs - inputs = self.pad_inputs(inputs) + inputs = pad_inputs(inputs, self.max_sequence_length) if isinstance(targets, dict) and len(targets) == 0: # `pyarrow`` dataloader is returning {} instead of None From 8bafdf5871b41e69efb417805c4584781bd9b2f0 Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Wed, 5 Apr 2023 14:18:05 +0100 Subject: [PATCH 09/15] Only call pad_batch if is a dict of tensors --- transformers4rec/torch/utils/data_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/transformers4rec/torch/utils/data_utils.py b/transformers4rec/torch/utils/data_utils.py index 30a587817f..d33621f57d 100644 --- a/transformers4rec/torch/utils/data_utils.py +++ b/transformers4rec/torch/utils/data_utils.py @@ -372,7 +372,10 @@ def __init__( def _get_pad_fn(padding_lengths): def pad_fn(x, y): new_x = pad_batch(x, padding_lengths) - new_y = pad_batch(y, padding_lengths) + if y is not None and isinstance(y, dict): + new_y = pad_batch(y, padding_lengths) + else: + new_y = y return new_x, new_y return pad_fn From ee28762356407060a2330bb74637024e5726797f Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Wed, 5 Apr 2023 14:18:42 +0100 Subject: [PATCH 10/15] Add test of model tracing to ragged inputs test --- tests/unit/torch/model/test_model.py | 51 +++++++++++++--------------- 1 file changed, 23 insertions(+), 28 deletions(-) diff --git a/tests/unit/torch/model/test_model.py b/tests/unit/torch/model/test_model.py index 716a5d9bde..073b47e272 100644 --- a/tests/unit/torch/model/test_model.py +++ b/tests/unit/torch/model/test_model.py @@ -55,44 +55,39 @@ def test_simple_model(torch_tabular_features, torch_tabular_data): assert all(loss.min() >= 0 and loss.max() <= 1 for loss in losses) -@pytest.mark.parametrize("task", [tr.BinaryClassificationTask, tr.RegressionTask]) -def test_sequential_prediction_model_with_ragged_inputs( - torch_yoochoose_tabular_transformer_features, torch_yoochoose_like, task -): - inputs = torch_yoochoose_tabular_transformer_features - - transformer_config = tconf.XLNetConfig.build( - d_model=64, n_head=4, n_layer=2, total_seq_length=20 +def test_sequential_prediction_model_with_ragged_inputs(torch_yoochoose_like, yoochoose_schema): + input_module = tr.TabularSequenceFeatures.from_schema( + yoochoose_schema, + max_sequence_length=20, + d_output=64, + masking="causal", ) - body = tr.SequentialBlock(inputs, tr.MLPBlock([64]), tr.TransformerBlock(transformer_config)) - - head_1 = tr.Head( - body, - tr.NextItemPredictionTask(weight_tying=True), - inputs=inputs, + prediction_task = tr.NextItemPredictionTask(weight_tying=True) + transformer_config = tconf.XLNetConfig.build( + d_model=64, n_head=8, n_layer=2, total_seq_length=20 ) - head_2 = task("target", summary_type="mean").to_head(body, inputs) - - bc_targets = torch.randint(2, (100,)).float() + model = transformer_config.to_torch_model(input_module, prediction_task) - model = tr.Model(head_1, head_2) - output = model(torch_yoochoose_like, training=True, targets=bc_targets) + _ = model(torch_yoochoose_like) - assert isinstance(output, dict) - assert len(list(output.keys())) == 3 - assert len(list(output["predictions"])) == 2 - assert set(list(output.keys())) == set(["loss", "labels", "predictions"]) + model.eval() - # test inference inputs with only one item inference_inputs = tr.data.tabular_sequence_testing_data.torch_synthetic_data( num_rows=10, min_session_length=1, max_session_length=4, ragged=True ) - _ = model(inference_inputs) - inference_inputs_2 = tr.data.tabular_sequence_testing_data.torch_synthetic_data( - num_rows=20, min_session_length=1, max_session_length=10, ragged=True + num_rows=10, min_session_length=1, max_session_length=10, ragged=True ) - _ = model(inference_inputs_2) + model_output = model(inference_inputs) + + # if model is traced with ragged inputs it must be called with ragged inputs + traced_model = torch.jit.trace(model, inference_inputs, strict=False) + traced_model_output = traced_model(inference_inputs) + assert torch.equal(model_output, traced_model_output) + + model_output = model(inference_inputs_2) + traced_model_output = traced_model(inference_inputs_2) + assert torch.equal(model_output, traced_model_output) @pytest.mark.parametrize("task", [tr.BinaryClassificationTask, tr.RegressionTask]) From 92925b23661d5f528a0c5ae444231f2782bf38e5 Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Wed, 5 Apr 2023 17:40:34 +0100 Subject: [PATCH 11/15] Move `pad_inputs` into padding module --- transformers4rec/torch/model/base.py | 39 +------------------------ transformers4rec/torch/utils/padding.py | 39 ++++++++++++++++++++++++- 2 files changed, 39 insertions(+), 39 deletions(-) diff --git a/transformers4rec/torch/model/base.py b/transformers4rec/torch/model/base.py index 0a5b94b688..4ae7683e50 100644 --- a/transformers4rec/torch/model/base.py +++ b/transformers4rec/torch/model/base.py @@ -37,7 +37,7 @@ from ..features.base import InputBlock from ..features.sequence import TabularFeaturesType from ..typing import TabularData -from ..utils.padding import pad_batch +from ..utils.padding import pad_inputs from ..utils.torch_utils import LossMixin, MetricsMixin @@ -474,43 +474,6 @@ def to_model(self, **kwargs) -> "Model": return Model(self, **kwargs) -@torch.jit.script -def pad_inputs(inputs: Dict[str, torch.Tensor], max_sequence_length: Optional[int]): - """Pad ragged inputs to dense tensors with max sequence length - - Parameters - ---------- - inputs : Dict[str, Tensor] - Dictionary of tensors - - Returns - ------- - inputs : Dict[str, Tensor] - Padded inputs - """ - batch_max_sequence_length = 0 - for key, val in inputs.items(): - if key.endswith("__offsets"): - offsets = val - max_row_length = int(torch.max(offsets[1:] - offsets[:-1])) - batch_max_sequence_length = max(max_row_length, batch_max_sequence_length) - - padding_sequence_length = batch_max_sequence_length - if max_sequence_length is not None: - padding_sequence_length = min(max_sequence_length, batch_max_sequence_length) - - if padding_sequence_length > 0: - padding_lengths: Dict[str, int] = {} - for key in inputs.keys(): - if key.endswith("__offsets"): - col_name: str = key[: -len("__offsets")] - padding_lengths[col_name] = padding_sequence_length - if padding_lengths: - inputs = pad_batch(inputs, padding_lengths) - - return inputs - - class Model(torch.nn.Module, LossMixin, MetricsMixin): def __init__( self, diff --git a/transformers4rec/torch/utils/padding.py b/transformers4rec/torch/utils/padding.py index c35b1b5092..2ea1a15809 100644 --- a/transformers4rec/torch/utils/padding.py +++ b/transformers4rec/torch/utils/padding.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from typing import Dict +from typing import Dict, Optional import torch import torch.nn.functional as F @@ -106,3 +106,40 @@ def pad_batch( batch_padded[col_name] = value return batch_padded + + +@torch.jit.script +def pad_inputs(inputs: Dict[str, torch.Tensor], max_sequence_length: Optional[int]): + """Pad ragged inputs to dense tensors with max sequence length + + Parameters + ---------- + inputs : Dict[str, Tensor] + Dictionary of tensors + + Returns + ------- + inputs : Dict[str, Tensor] + Padded inputs + """ + batch_max_sequence_length = 0 + for key, val in inputs.items(): + if key.endswith("__offsets"): + offsets = val + max_row_length = int(torch.max(offsets[1:] - offsets[:-1])) + batch_max_sequence_length = max(max_row_length, batch_max_sequence_length) + + padding_sequence_length = batch_max_sequence_length + if max_sequence_length is not None: + padding_sequence_length = min(max_sequence_length, batch_max_sequence_length) + + if padding_sequence_length > 0: + padding_lengths: Dict[str, int] = {} + for key in inputs.keys(): + if key.endswith("__offsets"): + col_name: str = key[: -len("__offsets")] + padding_lengths[col_name] = padding_sequence_length + if padding_lengths: + inputs = pad_batch(inputs, padding_lengths) + + return inputs From b3b81d5790a67a743dc5c79f84d296884670c9a9 Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Wed, 5 Apr 2023 17:54:49 +0100 Subject: [PATCH 12/15] update example inputs to demonstrate different batch size --- tests/unit/torch/model/test_model.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/unit/torch/model/test_model.py b/tests/unit/torch/model/test_model.py index 073b47e272..1f16c2f251 100644 --- a/tests/unit/torch/model/test_model.py +++ b/tests/unit/torch/model/test_model.py @@ -76,11 +76,12 @@ def test_sequential_prediction_model_with_ragged_inputs(torch_yoochoose_like, yo num_rows=10, min_session_length=1, max_session_length=4, ragged=True ) inference_inputs_2 = tr.data.tabular_sequence_testing_data.torch_synthetic_data( - num_rows=10, min_session_length=1, max_session_length=10, ragged=True + num_rows=20, min_session_length=1, max_session_length=10, ragged=True ) model_output = model(inference_inputs) - # if model is traced with ragged inputs it must be called with ragged inputs + # if the model is traced with ragged inputs it can only be called with ragged inputs + # if the model is traced with padded inputs it can only be called with padded inputs traced_model = torch.jit.trace(model, inference_inputs, strict=False) traced_model_output = traced_model(inference_inputs) assert torch.equal(model_output, traced_model_output) From 29fa413a10e62978be9507fbf5b026fc9d3bec51 Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Wed, 5 Apr 2023 18:12:31 +0100 Subject: [PATCH 13/15] Update docstring for pad_inputs --- transformers4rec/torch/utils/padding.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/transformers4rec/torch/utils/padding.py b/transformers4rec/torch/utils/padding.py index 2ea1a15809..ec10945c14 100644 --- a/transformers4rec/torch/utils/padding.py +++ b/transformers4rec/torch/utils/padding.py @@ -109,13 +109,18 @@ def pad_batch( @torch.jit.script -def pad_inputs(inputs: Dict[str, torch.Tensor], max_sequence_length: Optional[int]): - """Pad ragged inputs to dense tensors with max sequence length +def pad_inputs(inputs: Dict[str, torch.Tensor], max_sequence_length: Optional[int] = None): + """Pad ragged inputs to fixed size tensors. + + Pads all the sequence features in the inputs to the same length. + The minimum of max_sequence_length and the maximum sequence length in the inputs. Parameters ---------- inputs : Dict[str, Tensor] Dictionary of tensors + max_sequence_length: int, optional + The maximum sequence length to limit sequence features to Returns ------- From 08c79d69b2fa9e8064d3eb70bbb3d2d8bddd8ffd Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Wed, 5 Apr 2023 18:17:37 +0100 Subject: [PATCH 14/15] Add tests for pad_inputs --- tests/unit/utils/test_padding.py | 204 +++++++++++++++++++------------ 1 file changed, 125 insertions(+), 79 deletions(-) diff --git a/tests/unit/utils/test_padding.py b/tests/unit/utils/test_padding.py index 4c41a79532..949036f2a1 100644 --- a/tests/unit/utils/test_padding.py +++ b/tests/unit/utils/test_padding.py @@ -17,7 +17,7 @@ import torch -from transformers4rec.torch.utils.padding import pad_batch +from transformers4rec.torch.utils.padding import pad_batch, pad_inputs def _get_values_offsets(data): @@ -30,81 +30,127 @@ def _get_values_offsets(data): return torch.tensor(values), torch.tensor(offsets) -def test_pad_values_offsets_dict(): - data = [[1, 2], [], [3, 4, 5]] - values, offsets = _get_values_offsets(data) - - x = {"a__values": values, "a__offsets": offsets} - - padded_x = pad_batch(x, {"a": 7}) - assert torch.equal( - padded_x["a"], - torch.tensor( - [ - [1, 2, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0], - [3, 4, 5, 0, 0, 0, 0], - ] - ), - ) - - -def test_pad_with_truncation(): - data = [[1, 2], [], [3, 4, 5, 4, 7]] - values, offsets = _get_values_offsets(data) - - x = { - "a__values": values, - "a__offsets": offsets, - "b": torch.tensor([[1, 2, 3, 4], [6, 7, 8, 9]]), - } - - padded_x = pad_batch(x, {"a": 3, "b": 2}) - assert torch.equal( - padded_x["a"], - torch.tensor( - [ - [1, 2, 0], - [0, 0, 0], - [3, 4, 5], - ] - ), - ) - assert torch.equal( - padded_x["b"], - torch.tensor( - [ - [1, 2], - [6, 7], - ] - ), - ) - - -def test_pad_values_dense(): - data = [[1, 2], [], [3, 4, 5]] - values, offsets = _get_values_offsets(data) - - x = {"a__values": values, "a__offsets": offsets, "b": torch.tensor([[3, 6], [4, 1], [8, 4]])} - - padded_x = pad_batch(x, {"a": 7, "b": 3}) - assert torch.equal( - padded_x["a"], - torch.tensor( - [ - [1, 2, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0], - [3, 4, 5, 0, 0, 0, 0], - ] - ), - ) - assert torch.equal( - padded_x["b"], - torch.tensor( - [ - [3, 6, 0], - [4, 1, 0], - [8, 4, 0], - ] - ), - ) +class TestPadBatch: + def test_pad_values_offsets(self): + data = [[1, 2], [], [3, 4, 5]] + values, offsets = _get_values_offsets(data) + + x = {"a__values": values, "a__offsets": offsets} + + padded_x = pad_batch(x, {"a": 7}) + assert torch.equal( + padded_x["a"], + torch.tensor( + [ + [1, 2, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [3, 4, 5, 0, 0, 0, 0], + ] + ), + ) + + def test_with_truncation(self): + data = [[1, 2], [], [3, 4, 5, 4, 7]] + values, offsets = _get_values_offsets(data) + + x = { + "a__values": values, + "a__offsets": offsets, + "b": torch.tensor([[1, 2, 3, 4], [6, 7, 8, 9]]), + } + + padded_x = pad_batch(x, {"a": 3, "b": 2}) + assert torch.equal( + padded_x["a"], + torch.tensor( + [ + [1, 2, 0], + [0, 0, 0], + [3, 4, 5], + ] + ), + ) + assert torch.equal( + padded_x["b"], + torch.tensor( + [ + [1, 2], + [6, 7], + ] + ), + ) + + def test_ragged_and_dense(self): + data = [[1, 2], [], [3, 4, 5]] + values, offsets = _get_values_offsets(data) + + x = { + "a__values": values, + "a__offsets": offsets, + "b": torch.tensor([[3, 6], [4, 1], [8, 4]]), + } + + padded_x = pad_batch(x, {"a": 7, "b": 3}) + assert torch.equal( + padded_x["a"], + torch.tensor( + [ + [1, 2, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [3, 4, 5, 0, 0, 0, 0], + ] + ), + ) + assert torch.equal( + padded_x["b"], + torch.tensor( + [ + [3, 6, 0], + [4, 1, 0], + [8, 4, 0], + ] + ), + ) + + +class TestPadInputs: + def test_ragged_inputs(self): + data = [[1, 2, 3, 4, 5], [6, 7, 8]] + values, offsets = _get_values_offsets(data) + + inputs = {"a__values": values, "a__offsets": offsets} + padded_inputs = pad_inputs(inputs) + assert torch.equal( + padded_inputs["a"], + torch.tensor( + [ + [1, 2, 3, 4, 5], + [6, 7, 8, 0, 0], + ] + ), + ) + + def test_with_max_sequence_length(self): + data = [[1, 2, 3, 4, 5], [6, 7, 8, 9]] + values, offsets = _get_values_offsets(data) + + inputs = {"a__values": values, "a__offsets": offsets, "b": torch.tensor([[3, 6], [4, 1]])} + padded_inputs = pad_inputs(inputs, max_sequence_length=3) + assert torch.equal( + padded_inputs["a"], + torch.tensor( + [ + [1, 2, 3], + [6, 7, 8], + ] + ), + ) + assert torch.equal( + padded_inputs["b"], + torch.tensor( + [ + [3, 6], + [4, 1], + ] + ), + ) From 686696c4c38bb6633af903f941df7dd2591d1f31 Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Wed, 5 Apr 2023 18:24:48 +0100 Subject: [PATCH 15/15] Update test of pad_inputs to make test clearer --- tests/unit/utils/test_padding.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/tests/unit/utils/test_padding.py b/tests/unit/utils/test_padding.py index 949036f2a1..028504117f 100644 --- a/tests/unit/utils/test_padding.py +++ b/tests/unit/utils/test_padding.py @@ -147,10 +147,5 @@ def test_with_max_sequence_length(self): ) assert torch.equal( padded_inputs["b"], - torch.tensor( - [ - [3, 6], - [4, 1], - ] - ), + inputs["b"], )