diff --git a/cornac/eval_methods/next_basket_evaluation.py b/cornac/eval_methods/next_basket_evaluation.py index 99ce57d9a..2c697d5a5 100644 --- a/cornac/eval_methods/next_basket_evaluation.py +++ b/cornac/eval_methods/next_basket_evaluation.py @@ -327,9 +327,13 @@ def _split(self): test_users = safe_indexing(users, test_idx) val_users = safe_indexing(users, val_idx) if len(val_idx) > 0 else None - train_data = [tup for tup in self._data if tup[0] in train_users] - val_data = [tup for tup in self._data if tup[0] in val_users] - test_data = [tup for tup in self._data if tup[0] in test_users] + data_by_user = OrderedDict() + for tup in self._data: + data_by_user.setdefault(tup[0], []) + data_by_user[tup[0]].append(tup) + train_data = [tup for u in train_users for tup in data_by_user[u]] + val_data = [tup for u in val_users for tup in data_by_user[u]] + test_data = [tup for u in test_users for tup in data_by_user[u]] self.build(train_data=train_data, test_data=test_data, val_data=val_data) def _build_datasets(self, train_data, test_data, val_data=None):