Skip to content

Commit 48b29c6

Browse files
committed
condense even more, only need to change a single line for new datasets
1 parent 1a7782d commit 48b29c6

File tree

2 files changed

+94
-177
lines changed

2 files changed

+94
-177
lines changed

tests/unit/datasets/datasets_test.py

+94-19
Original file line numberDiff line numberDiff line change
@@ -20,51 +20,126 @@
2020
"""Test public dataset definitions."""
2121
from __future__ import annotations
2222

23+
from pathlib import Path
24+
2325
import pytest
2426

2527
import pymovements as pm
2628

2729

2830
@pytest.mark.parametrize(
29-
('definition_class', 'dataset_name'),
31+
('public_dataset', 'dataset_name'),
3032
[
31-
pytest.param(pm.datasets.ToyDataset, 'ToyDataset', id='ToyDataset'),
3233
pytest.param(pm.datasets.GazeBase, 'GazeBase', id='GazeBase'),
3334
pytest.param(pm.datasets.GazeBaseVR, 'GazeBaseVR', id='GazeBaseVR'),
3435
pytest.param(pm.datasets.GazeOnFaces, 'GazeOnFaces', id='GazeOnFaces'),
3536
pytest.param(pm.datasets.HBN, 'HBN', id='HBN'),
3637
pytest.param(pm.datasets.JuDo1000, 'JuDo1000', id='JuDo1000'),
3738
pytest.param(pm.datasets.PoTeC, 'PoTeC', id='PoTeC'),
3839
pytest.param(pm.datasets.SBSAT, 'SBSAT', id='SBSAT'),
40+
pytest.param(pm.datasets.ToyDataset, 'ToyDataset', id='ToyDataset'),
41+
pytest.param(pm.datasets.ToyDatasetEyeLink, 'ToyDatasetEyeLink', id='ToyDatasetEyeLink'),
3942
],
4043
)
41-
def test_public_dataset_registered(definition_class, dataset_name):
42-
assert dataset_name in pm.DatasetLibrary.definitions
43-
assert pm.DatasetLibrary.get(dataset_name) == definition_class
44-
assert pm.DatasetLibrary.get(dataset_name)().name == dataset_name
45-
44+
@pytest.mark.parametrize(
45+
('dataset_path'),
46+
[
47+
pytest.param(
48+
None,
49+
id='dataset_path_None',
50+
),
51+
pytest.param(
52+
'.',
53+
id='dataset_path_dot',
54+
),
55+
pytest.param(
56+
'dataset_path',
57+
id='dataset_path_dataset',
58+
),
59+
],
60+
)
61+
@pytest.mark.parametrize(
62+
('downloads'),
63+
[
64+
pytest.param(
65+
'downloads',
66+
id='downloads_None',
67+
),
68+
pytest.param(
69+
'custom_downloads',
70+
id='downloads_custom_downloads',
71+
),
4672
73+
],
74+
)
4775
@pytest.mark.parametrize(
48-
'dataset_definition_class',
76+
('str_root'),
4977
[
50-
pytest.param(pm.datasets.ToyDataset, id='ToyDataset'),
51-
pytest.param(pm.datasets.GazeBase, id='GazeBase'),
52-
pytest.param(pm.datasets.GazeBaseVR, id='GazeBaseVR'),
53-
pytest.param(pm.datasets.GazeOnFaces, id='GazeOnFaces'),
54-
pytest.param(pm.datasets.HBN, id='HBN'),
55-
pytest.param(pm.datasets.JuDo1000, id='JuDo1000'),
56-
pytest.param(pm.datasets.PoTeC, id='PoTeC'),
57-
pytest.param(pm.datasets.SBSAT, id='SBSAT'),
78+
pytest.param(
79+
True,
80+
id='path_str',
81+
),
82+
pytest.param(
83+
False,
84+
id='path_DatasetPaths',
85+
),
5886
],
5987
)
60-
def test_public_dataset_registered_correct_attributes(dataset_definition_class):
61-
dataset_definition = dataset_definition_class()
88+
def test_public_dataset_registered(public_dataset, dataset_name, dataset_path, downloads, str_root):
89+
assert dataset_name in pm.DatasetLibrary.definitions
90+
assert pm.DatasetLibrary.get(dataset_name) == public_dataset
91+
assert pm.DatasetLibrary.get(dataset_name)().name == dataset_name
6292

93+
dataset_definition = public_dataset()
6394
registered_definition = pm.DatasetLibrary.get(dataset_definition.name)()
64-
6595
assert dataset_definition.mirrors == registered_definition.mirrors
6696
assert dataset_definition.resources == registered_definition.resources
6797
assert dataset_definition.experiment == registered_definition.experiment
6898
assert dataset_definition.filename_format == registered_definition.filename_format
6999
assert dataset_definition.filename_format_dtypes == registered_definition.filename_format_dtypes
70100
assert dataset_definition.custom_read_kwargs == registered_definition.custom_read_kwargs
101+
102+
dataset, expected_paths = construct_public_dataset(
103+
public_dataset,
104+
dataset_path,
105+
downloads,
106+
str_root,
107+
)
108+
assert dataset.paths.root == expected_paths['root']
109+
assert dataset.path == expected_paths['dataset']
110+
assert dataset.paths.dataset == expected_paths['dataset']
111+
assert dataset.paths.downloads == expected_paths['downloads']
112+
113+
114+
def construct_public_dataset(
115+
public_dataset,
116+
dataset_path,
117+
downloads,
118+
str_root,
119+
):
120+
expected = {}
121+
expected['root'] = Path('/data/set/path')
122+
123+
if str_root:
124+
init_path = '/data/set/path'
125+
expected['dataset'] = Path('/data/set/path')
126+
expected['downloads'] = Path('/data/set/path/downloads')
127+
128+
dataset = pm.Dataset(public_dataset, path=init_path)
129+
return dataset, expected
130+
init_path = pm.DatasetPaths(
131+
root='/data/set/path',
132+
dataset=dataset_path,
133+
downloads=downloads,
134+
)
135+
136+
if dataset_path == '.':
137+
expected['dataset'] = Path('/data/set/path')
138+
elif dataset_path == 'dataset_path':
139+
expected['dataset'] = Path('/data/set/path/dataset_path')
140+
else:
141+
expected['dataset'] = Path(f'/data/set/path/{public_dataset.__name__}')
142+
expected['downloads'] = expected['dataset'] / Path(downloads)
143+
144+
dataset = pm.Dataset(public_dataset, path=init_path)
145+
return dataset, expected

tests/unit/datasets/public_datasets_test.py

-158
This file was deleted.

0 commit comments

Comments
 (0)