Skip to content

Commit

Permalink
Merge pull request #13 from artefactory/unmerged_new_sig
Browse files Browse the repository at this point in the history
Unmerged new sig
  • Loading branch information
VincentAuriau authored Jan 15, 2024
2 parents 4197a63 + a510991 commit 118f864
Show file tree
Hide file tree
Showing 6 changed files with 305 additions and 215 deletions.
45 changes: 35 additions & 10 deletions choice_learn/data/choice_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def __init__(
if choices is None:
# Done to keep a logical order of arguments, and has logic: choices have to be specified
raise ValueError("Choices must be specified, got None")
choices = np.array(choices)

# --------- [ Handling features type given as tuples or not ] --------- #
# If items_features is not given as tuple, transform it internally as a tuple
Expand Down Expand Up @@ -766,7 +767,7 @@ def from_single_df(
choices = choices.set_index(contexts_id_column)
choices = choices.loc[sessions].to_numpy()
# items is the value (str) of the item
choices = [np.where(items == c)[0] for c in choices]
choices = np.squeeze([np.where(items == c)[0] for c in choices])
elif choice_mode == "one_zero":
choices = df[[items_id_column, choices_column, contexts_id_column]]
choices = choices.loc[choices[choices_column] == 1]
Expand Down Expand Up @@ -1048,21 +1049,45 @@ def __getitem__(self, choices_indexes):
elif isinstance(choices_indexes, slice):
return self.__getitem__(list(range(*choices_indexes.indices(len(self.choices)))))

return ChoiceDataset(
fixed_items_features=self.fixed_items_features,
contexts_features=tuple(
if self.fixed_items_features[0] is None:
fixed_items_features = None
else:
fixed_items_features = self.fixed_items_features
if self.contexts_features[0] is None:
contexts_features = None
else:
contexts_features = tuple(
self.contexts_features[i][choices_indexes]
for i in range(len(self.contexts_features))
),
contexts_items_features=tuple(
)
if self.contexts_items_features[0] is None:
contexts_items_features = None
else:
contexts_items_features = tuple(
self.contexts_items_features[i][choices_indexes]
for i in range(len(self.contexts_items_features))
),
)
if self.fixed_items_features_names[0] is None:
fixed_items_features_names = None
else:
fixed_items_features_names = self.fixed_items_features_names
if self.contexts_features_names[0] is None:
contexts_features_names = None
else:
contexts_features_names = self.contexts_features_names
if self.contexts_items_features_names[0] is None:
contexts_items_features_names = None
else:
contexts_items_features_names = self.contexts_items_features_names
return ChoiceDataset(
fixed_items_features=fixed_items_features,
contexts_features=contexts_features,
contexts_items_features=contexts_items_features,
contexts_items_availabilities=self.contexts_items_availabilities[choices_indexes],
choices=[self.choices[i] for i in choices_indexes],
fixed_items_features_names=self.fixed_items_features_names,
contexts_features_names=self.contexts_features_names,
contexts_items_features_names=self.contexts_items_features_names,
fixed_items_features_names=fixed_items_features_names,
contexts_features_names=contexts_features_names,
contexts_items_features_names=contexts_items_features_names,
features_by_ids=self.features_by_ids,
)

Expand Down
3 changes: 1 addition & 2 deletions choice_learn/data/indexer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Indexer classes for data classes."""
from abc import abstractmethod
from collections.abc import Iterable

import numpy as np

Expand Down Expand Up @@ -92,7 +91,7 @@ def __getitem__(self, sequence_keys):
array_like
features corresponding to the sequence_keys
"""
if isinstance(sequence_keys, Iterable):
if isinstance(sequence_keys, list) or isinstance(sequence_keys, np.ndarray):
return np.array([self.storage.storage[key] for key in sequence_keys])
if isinstance(sequence_keys, slice):
raise ValueError("Slicing is not supported for storage")
Expand Down
8 changes: 6 additions & 2 deletions choice_learn/data/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,10 @@ def get_element_from_index(self, index):
array_like
features corresponding to the index index in self.store
"""
keys = list(self.storage.keys())[index]
return self.storage[keys]
if isinstance(index, int):
index = [index]
keys = [list(self.storage.keys())[i] for i in index]
return self.batch[keys]

def __len__(self):
"""Returns the length of the sequence of apparition of the features."""
Expand All @@ -135,6 +137,8 @@ def __getitem__(self, id_keys):
FeaturesStorage
Subset of the FeaturesStorage, with only the features whose id is in id_keys
"""
if not isinstance(id_keys, list):
id_keys = [id_keys]
sub_storage = {k: self.storage[k] for k in id_keys}
return FeaturesStorage(values=sub_storage, values_names=self.values_names, name=self.name)

Expand Down
1 change: 0 additions & 1 deletion notebooks/features_storage_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
"outputs": [],
"source": [
"features = {\"customerA\": [1, 2, 3], \"customerB\": [4, 5, 6], \"customerC\": [7, 8, 9]}\n",
"\n",
"storage = FeaturesStorage(values=features, values_names=[\"age\", \"income\", \"children_nb\"], name=\"customers\")"
]
},
Expand Down
Loading

0 comments on commit 118f864

Please sign in to comment.