Skip to content

Commit

Permalink
feat(dataset): Allow range-objects for subset argument (#429)
Browse files Browse the repository at this point in the history
  • Loading branch information
dkrako authored May 30, 2023
1 parent eab9804 commit 79ec59c
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 8 deletions.
8 changes: 5 additions & 3 deletions docs/source/tutorials/pymovements-in-10-minutes.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,9 @@
"The last two columns refer to the pixel coordinates at the timestep specified by `time`.\n",
"\n",
"\n",
"We are also able to just take a subset of the data by specifying values of the fileinfo columns:\n"
"We are also able to just take a subset of the data by specifying values of the fileinfo columns.\n",
"The key refers to the column in the `fileinfo` dataframe.\n",
"The values in the dictionary can be of type `bool`, `int`, `float` or `str`, but also lists and ranges \n"
]
},
{
Expand All @@ -214,8 +216,8 @@
"outputs": [],
"source": [
"subset = {\n",
" 'text_id': [0],\n",
" 'page_id': [0, 1, 2],\n",
" 'text_id': 0,\n",
" 'page_id': range(3),\n",
"}\n",
"dataset.load(subset=subset)\n",
"\n",
Expand Down
6 changes: 3 additions & 3 deletions src/pymovements/dataset/dataset_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,12 +496,12 @@ def take_subset(

if isinstance(subset_value, (bool, float, int, str)):
column_values = [subset_value]
elif isinstance(subset_value, (list, tuple)):
elif isinstance(subset_value, (list, tuple, range)):
column_values = subset_value
else:
raise TypeError(
f'subset value must be of type bool, float, int, str or a list of these but'
f' key-value pair {subset_key}: {subset_value} is of type {type(subset_value)}',
f'subset values must be of type bool, float, int, str, range, or list, '
f'but value of pair {subset_key}: {subset_value} is of type {type(subset_value)}',
)

fileinfo = fileinfo.filter(pl.col(subset_key).is_in(column_values))
Expand Down
9 changes: 7 additions & 2 deletions tests/dataset/dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,12 +316,17 @@ def test_load_correct_event_dfs(dataset_configuration):
pytest.param(
{'subject_id': 1},
[0],
id='subset_key_not_in_fileinfo',
id='subset_int',
),
pytest.param(
{'subject_id': [1, 11, 12]},
[0, 2, 3],
id='subset_key_not_in_fileinfo',
id='subset_list',
),
pytest.param(
{'subject_id': range(3)},
[0, 11],
id='subset_range',
),
],
)
Expand Down

0 comments on commit 79ec59c

Please sign in to comment.