|
20 | 20 | """Test public dataset definitions."""
|
21 | 21 | from __future__ import annotations
|
22 | 22 |
|
| 23 | +from pathlib import Path |
| 24 | + |
23 | 25 | import pytest
|
24 | 26 |
|
25 | 27 | import pymovements as pm
|
26 | 28 |
|
27 | 29 |
|
28 | 30 | @pytest.mark.parametrize(
|
29 |
| - ('definition_class', 'dataset_name'), |
| 31 | + ('public_dataset', 'dataset_name'), |
30 | 32 | [
|
31 |
| - pytest.param(pm.datasets.ToyDataset, 'ToyDataset', id='ToyDataset'), |
32 | 33 | pytest.param(pm.datasets.GazeBase, 'GazeBase', id='GazeBase'),
|
33 | 34 | pytest.param(pm.datasets.GazeBaseVR, 'GazeBaseVR', id='GazeBaseVR'),
|
34 | 35 | pytest.param(pm.datasets.GazeOnFaces, 'GazeOnFaces', id='GazeOnFaces'),
|
35 | 36 | pytest.param(pm.datasets.HBN, 'HBN', id='HBN'),
|
36 | 37 | pytest.param(pm.datasets.JuDo1000, 'JuDo1000', id='JuDo1000'),
|
37 | 38 | pytest.param(pm.datasets.PoTeC, 'PoTeC', id='PoTeC'),
|
38 | 39 | pytest.param(pm.datasets.SBSAT, 'SBSAT', id='SBSAT'),
|
| 40 | + pytest.param(pm.datasets.ToyDataset, 'ToyDataset', id='ToyDataset'), |
| 41 | + pytest.param(pm.datasets.ToyDatasetEyeLink, 'ToyDatasetEyeLink', id='ToyDatasetEyeLink'), |
39 | 42 | ],
|
40 | 43 | )
|
41 |
| -def test_public_dataset_registered(definition_class, dataset_name): |
42 |
| - assert dataset_name in pm.DatasetLibrary.definitions |
43 |
| - assert pm.DatasetLibrary.get(dataset_name) == definition_class |
44 |
| - assert pm.DatasetLibrary.get(dataset_name)().name == dataset_name |
45 |
| - |
| 44 | +@pytest.mark.parametrize( |
| 45 | + ('dataset_path'), |
| 46 | + [ |
| 47 | + pytest.param( |
| 48 | + None, |
| 49 | + id='dataset_path_None', |
| 50 | + ), |
| 51 | + pytest.param( |
| 52 | + '.', |
| 53 | + id='dataset_path_dot', |
| 54 | + ), |
| 55 | + pytest.param( |
| 56 | + 'dataset_path', |
| 57 | + id='dataset_path_dataset', |
| 58 | + ), |
| 59 | + ], |
| 60 | +) |
| 61 | +@pytest.mark.parametrize( |
| 62 | + ('downloads'), |
| 63 | + [ |
| 64 | + pytest.param( |
| 65 | + 'downloads', |
| 66 | + id='downloads_None', |
| 67 | + ), |
| 68 | + pytest.param( |
| 69 | + 'custom_downloads', |
| 70 | + id='downloads_custom_downloads', |
| 71 | + ), |
46 | 72 |
|
| 73 | + ], |
| 74 | +) |
47 | 75 | @pytest.mark.parametrize(
|
48 |
| - 'dataset_definition_class', |
| 76 | + ('str_root'), |
49 | 77 | [
|
50 |
| - pytest.param(pm.datasets.ToyDataset, id='ToyDataset'), |
51 |
| - pytest.param(pm.datasets.GazeBase, id='GazeBase'), |
52 |
| - pytest.param(pm.datasets.GazeBaseVR, id='GazeBaseVR'), |
53 |
| - pytest.param(pm.datasets.GazeOnFaces, id='GazeOnFaces'), |
54 |
| - pytest.param(pm.datasets.HBN, id='HBN'), |
55 |
| - pytest.param(pm.datasets.JuDo1000, id='JuDo1000'), |
56 |
| - pytest.param(pm.datasets.PoTeC, id='PoTeC'), |
57 |
| - pytest.param(pm.datasets.SBSAT, id='SBSAT'), |
| 78 | + pytest.param( |
| 79 | + True, |
| 80 | + id='path_str', |
| 81 | + ), |
| 82 | + pytest.param( |
| 83 | + False, |
| 84 | + id='path_DatasetPaths', |
| 85 | + ), |
58 | 86 | ],
|
59 | 87 | )
|
60 |
| -def test_public_dataset_registered_correct_attributes(dataset_definition_class): |
61 |
| - dataset_definition = dataset_definition_class() |
| 88 | +def test_public_dataset_registered(public_dataset, dataset_name, dataset_path, downloads, str_root): |
| 89 | + assert dataset_name in pm.DatasetLibrary.definitions |
| 90 | + assert pm.DatasetLibrary.get(dataset_name) == public_dataset |
| 91 | + assert pm.DatasetLibrary.get(dataset_name)().name == dataset_name |
62 | 92 |
|
| 93 | + dataset_definition = public_dataset() |
63 | 94 | registered_definition = pm.DatasetLibrary.get(dataset_definition.name)()
|
64 |
| - |
65 | 95 | assert dataset_definition.mirrors == registered_definition.mirrors
|
66 | 96 | assert dataset_definition.resources == registered_definition.resources
|
67 | 97 | assert dataset_definition.experiment == registered_definition.experiment
|
68 | 98 | assert dataset_definition.filename_format == registered_definition.filename_format
|
69 | 99 | assert dataset_definition.filename_format_dtypes == registered_definition.filename_format_dtypes
|
70 | 100 | assert dataset_definition.custom_read_kwargs == registered_definition.custom_read_kwargs
|
| 101 | + |
| 102 | + dataset, expected_paths = construct_public_dataset( |
| 103 | + public_dataset, |
| 104 | + dataset_path, |
| 105 | + downloads, |
| 106 | + str_root, |
| 107 | + ) |
| 108 | + assert dataset.paths.root == expected_paths['root'] |
| 109 | + assert dataset.path == expected_paths['dataset'] |
| 110 | + assert dataset.paths.dataset == expected_paths['dataset'] |
| 111 | + assert dataset.paths.downloads == expected_paths['downloads'] |
| 112 | + |
| 113 | + |
| 114 | +def construct_public_dataset( |
| 115 | + public_dataset, |
| 116 | + dataset_path, |
| 117 | + downloads, |
| 118 | + str_root, |
| 119 | +): |
| 120 | + expected = {} |
| 121 | + expected['root'] = Path('/data/set/path') |
| 122 | + |
| 123 | + if str_root: |
| 124 | + init_path = '/data/set/path' |
| 125 | + expected['dataset'] = Path('/data/set/path') |
| 126 | + expected['downloads'] = Path('/data/set/path/downloads') |
| 127 | + |
| 128 | + dataset = pm.Dataset(public_dataset, path=init_path) |
| 129 | + return dataset, expected |
| 130 | + init_path = pm.DatasetPaths( |
| 131 | + root='/data/set/path', |
| 132 | + dataset=dataset_path, |
| 133 | + downloads=downloads, |
| 134 | + ) |
| 135 | + |
| 136 | + if dataset_path == '.': |
| 137 | + expected['dataset'] = Path('/data/set/path') |
| 138 | + elif dataset_path == 'dataset_path': |
| 139 | + expected['dataset'] = Path('/data/set/path/dataset_path') |
| 140 | + else: |
| 141 | + expected['dataset'] = Path(f'/data/set/path/{public_dataset.__name__}') |
| 142 | + expected['downloads'] = expected['dataset'] / Path(downloads) |
| 143 | + |
| 144 | + dataset = pm.Dataset(public_dataset, path=init_path) |
| 145 | + return dataset, expected |
0 commit comments