Skip to content

Commit

Permalink
Adjust the device used in synthetic data generation (#486)
Browse files Browse the repository at this point in the history
* Update precommit package versions

* Adjust device for synthetic data generation
  • Loading branch information
karlhigley authored Sep 9, 2022
1 parent 021d6a9 commit bcc9392
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 19 deletions.
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,25 +1,25 @@
repos:
- repo: https://github.com/timothycrosley/isort
rev: 5.6.4
rev: 5.10.1
hooks:
- id: isort
additional_dependencies: [toml]
- repo: https://github.com/python/black
rev: 20.8b1
rev: 22.8.0
hooks:
- id: black
- repo: https://gitlab.com/pycqa/flake8
rev: 3.8.4
rev: 3.9.2
hooks:
- id: flake8
- repo: https://github.com/pre-commit/mirrors-mypy
rev: 'v0.910'
rev: 'v0.971'
hooks:
- id: mypy
language_version: python3
args: [--no-strict-optional, --ignore-missing-imports, --show-traceback, --install-types, --non-interactive]
- repo: https://github.com/codespell-project/codespell
rev: v2.1.0
rev: v2.2.1
hooks:
- id: codespell
# - repo: https://github.com/mgedmin/check-manifest
Expand Down
24 changes: 24 additions & 0 deletions tests/data/testing/test_dataset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import numpy as np

from transformers4rec.data.dataset import ParquetDataset
from transformers4rec.data.testing.dataset import tabular_sequence_testing_data

Expand All @@ -9,3 +11,25 @@ def test_tabular_sequence_testing_data():
"transformers4rec/data/testing/schema.json"
)
assert len(tabular_sequence_testing_data.schema) == 22

torch_yoochoose_like = tabular_sequence_testing_data.torch_synthetic_data(
num_rows=100, min_session_length=5, max_session_length=20
)

t4r_yoochoose_schema = tabular_sequence_testing_data.schema

non_matching_dtypes = {}
for column in t4r_yoochoose_schema:
name = column.name
column_dtype = column.type
schema_dtype = {0: np.float32, 2: np.int64, 3: np.float32}[column_dtype]

value = torch_yoochoose_like[name]
value_dtype = value.numpy().dtype

if schema_dtype != value_dtype:
non_matching_dtypes[name] = (column_dtype, schema_dtype, value_dtype)

assert (
len(non_matching_dtypes) == 0
), f"Found columns whose dtype does not match schema: {non_matching_dtypes}"
5 changes: 4 additions & 1 deletion transformers4rec/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,17 @@ def __init__(self, schema_path: str):
def schema(self) -> Schema:
return self._schema

def torch_synthetic_data(
    self, num_rows=100, min_session_length=5, max_session_length=20, device=None
):
    """Generate random PyTorch data that conforms to this dataset's schema.

    Parameters
    ----------
    num_rows : int, optional
        Number of synthetic rows (sessions) to generate. Default 100.
    min_session_length : int, optional
        Minimum length of each generated session. Default 5.
    max_session_length : int, optional
        Maximum length of each generated session. Default 20.
    device : optional
        Device on which the synthetic tensors are created; forwarded to
        ``random_data_from_schema``. ``None`` keeps that function's default.

    Returns
    -------
    TabularData
        Mapping of feature name to synthetic tensor, as produced by
        ``schema_utils.random_data_from_schema``.
    """
    # Imported inside the method, presumably so importing this module does
    # not require torch (a TF counterpart exists alongside) — mirrors the
    # original lazy-import structure.
    from transformers4rec.torch.utils import schema_utils

    return schema_utils.random_data_from_schema(
        self.schema,
        num_rows=num_rows,
        min_session_length=min_session_length,
        max_session_length=max_session_length,
        device=device,
    )

def tf_synthetic_data(self, num_rows=100, min_session_length=5, max_session_length=20):
Expand Down
27 changes: 14 additions & 13 deletions transformers4rec/torch/utils/schema_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def random_data_from_schema(
num_rows: int,
max_session_length: Optional[int] = None,
min_session_length: int = 5,
device=None,
) -> TabularData:
data: Dict[str, Any] = {}

Expand All @@ -49,16 +50,16 @@ def random_data_from_schema(
max_num = feature.int_domain.max
if is_list_feature:
list_length = session_length or feature.value_count.max
row = torch.randint(1, max_num, (list_length,))
row = torch.randint(1, max_num, (list_length,), device=device)

else:
row = torch.randint(1, max_num, tuple(shape))
row = torch.randint(1, max_num, tuple(shape), device=device)
else:
if is_list_feature:
list_length = session_length or feature.value_count.max
row = torch.rand((list_length,))
row = torch.rand((list_length,), device=device)
else:
row = torch.rand(tuple(shape))
row = torch.rand(tuple(shape), device=device)

if is_list_feature:
row = (row, [len(row)]) # type: ignore
Expand Down Expand Up @@ -88,9 +89,9 @@ def random_data_from_schema(
offsets = [0]
for length in val[1][:-1]:
offsets.append(offsets[-1] + length)
vals = (val[0], torch.tensor(offsets).unsqueeze(dim=1))
values, offsets, diff_offsets, num_rows = _pull_values_offsets(vals)
indices = _get_indices(offsets, diff_offsets)
vals = (val[0], torch.tensor(offsets, device=device).unsqueeze(dim=1))
values, offsets, diff_offsets, num_rows = _pull_values_offsets(vals, device=device)
indices = _get_indices(offsets, diff_offsets, device=device)
seq_limit = max_session_length or val[1][0]
outputs[key] = _get_sparse_tensor(values, indices, num_rows, seq_limit)
else:
Expand All @@ -99,25 +100,25 @@ def random_data_from_schema(
return outputs


def _pull_values_offsets(values_offset):
def _pull_values_offsets(values_offset, device=None):
# pull_values_offsets, return values offsets diff_offsets
if isinstance(values_offset, tuple):
values = values_offset[0].flatten()
offsets = values_offset[1].flatten()
else:
values = values_offset.flatten()
offsets = torch.arange(values.size()[0])
offsets = torch.arange(values.size()[0], device=device)
num_rows = len(offsets)
offsets = torch.cat([offsets, torch.tensor([len(values)])])
offsets = torch.cat([offsets, torch.tensor([len(values)], device=device)])
diff_offsets = offsets[1:] - offsets[:-1]
return values, offsets, diff_offsets, num_rows


def _get_indices(offsets, diff_offsets):
row_ids = torch.arange(len(offsets) - 1)
def _get_indices(offsets, diff_offsets, device=None):
row_ids = torch.arange(len(offsets) - 1, device=device)
row_ids_repeated = torch.repeat_interleave(row_ids, diff_offsets)
row_offset_repeated = torch.repeat_interleave(offsets[:-1], diff_offsets)
col_ids = torch.arange(len(row_offset_repeated)) - row_offset_repeated
col_ids = torch.arange(len(row_offset_repeated), device=device) - row_offset_repeated
indices = torch.cat([row_ids_repeated.unsqueeze(-1), col_ids.unsqueeze(-1)], axis=1)
return indices

Expand Down

0 comments on commit bcc9392

Please sign in to comment.