Skip to content

Commit

Permalink
Merge pull request #8 from artefactory/filter
Browse files Browse the repository at this point in the history
Filter
  • Loading branch information
VincentAuriau authored Jan 3, 2024
2 parents e429f2c + 3d77145 commit 818507f
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 0 deletions.
13 changes: 13 additions & 0 deletions choice_learn/data/choice_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -976,3 +976,16 @@ def iter_batch(self, batch_size=None, shuffle=None, sample_weight=None):
# Special exit strategy for batch_size = -1
if batch_size == -1:
yielded_size += 2 * num_choices

def filter(self, bool_list):
"""Filter over sessions indexes following bool.
Parameters
----------
bool_list : list of boolean
list of booleans of length self.get_num_sessions() to filter sessions.
True to keep, False to discard.
"""
indexes = list(range(len(bool_list)))
indexes = [i for i, keep in zip(indexes, bool_list) if keep]
return self[indexes]
24 changes: 24 additions & 0 deletions tests/unit_tests/data/test_choice_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,3 +425,27 @@ def test_iter_batch():
assert batch[3].shape[1] == 3
assert batch[4].shape[0] == 2 or batch[4].shape[0] == 1
assert batch_nb == 1

def test_filter():
"""Tests the filter method."""
dataset = ChoiceDataset(
items_features=items_features,
sessions_features=sessions_features,
sessions_items_features=sessions_items_features,
sessions_items_availabilities=sessions_items_availabilities,
choices=choices,
)
filtered_dataset = dataset.filter([True, False, True])
assert len(filtered_dataset) == 2
assert (filtered_dataset.items_features[0] == dataset.items_features[0]).all()
assert (filtered_dataset.sessions_features[0] == dataset.sessions_features[0][[0, 2]]).all()
assert (
filtered_dataset.sessions_items_features[0]
== dataset.sessions_items_features[0][[0, 2]]
).all()
assert (
filtered_dataset.sessions_items_availabilities
== dataset.sessions_items_availabilities[[0, 2]]
).all()
assert (filtered_dataset.choices == dataset.choices[[0, 2]]).all()
assert (filtered_dataset.choices == [0, 1]).all()

0 comments on commit 818507f

Please sign in to comment.