Skip to content

Commit

Permalink
FIX: removed contexts references
Browse files Browse the repository at this point in the history
  • Loading branch information
VincentAuriau committed Apr 9, 2024
1 parent 55776d0 commit 7911820
Showing 1 changed file with 24 additions and 22 deletions.
46 changes: 24 additions & 22 deletions choice_learn/data/choice_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class ChoiceDataset(object):
def __init__(
self,
choices, # Should not have None as default value ?
shared_features_by_choice=None, # as many context as choices. values or ids (look at key)
shared_features_by_choice=None, # as many as choices. values or ids (look at key)
items_features_by_choice=None,
available_items_by_choice=None,
features_by_ids=[], # list of (name, FeaturesStorage)
Expand All @@ -40,11 +40,11 @@ def __init__(
choices: list or np.ndarray
list of chosen items indexes
shared_features_by_choice : tuple of (array_like, )
matrix of shape (num_choices, num_contexts_features) containing the features of the
different contexts that are common to all items (e.g. store features,
matrix of shape (num_choices, num_shared_features) containing the features of the
different choices that are common to all items (e.g. store features,
customer features, etc...)
items_features_by_choice : tuple of (array_like, ), default is None
matrix of shape (num_choices, num_items, num_contexts_items_features)
matrix of shape (num_choices, num_items, num_items_features)
containing the features
of the items that change over time (e.g. price, promotion, etc...), default is None
available_items_by_choice : array_like
Expand All @@ -55,9 +55,11 @@ def __init__(
among shared_features_by_choice or items_features_by_choice
and their ids must match the values of those features. Default is []
shared_features_by_choice_names : tuple of (array_like, )
list of names of the contexts_features, default is None
list of names of the shared_features_by_choice, default is None
Shapes must match with shared_features_by_choice
items_features_by_choice_names : tuple of (array_like, )
list of names of the contexts_items_features, default is None
list of names of the items_features_by_choice, default is None
Shapes must match with items_features_by_choice
"""
if choices is None:
# Done to keep a logical order of arguments, and has logic: choices have to be specified
Expand Down Expand Up @@ -160,13 +162,13 @@ def __init__(
# Basically it transforms them to be internally stored as np.ndarray and keeps column
# names as feature names

# Handling context features
# Handling shared features
if shared_features_by_choice is not None:
for i, feature in enumerate(shared_features_by_choice):
if isinstance(feature, pd.DataFrame):
# Ordering choices by id ?
if "context_id" in feature.columns:
feature = feature.set_index("context_id")
if "choice_id" in feature.columns:
feature = feature.set_index("choice_id")
shared_features_by_choice = (
shared_features_by_choice[:i]
+ (shared_features_by_choice[i].loc[np.sort(feature.index)].to_numpy(),)
Expand Down Expand Up @@ -245,7 +247,7 @@ def __init__(
)
available_items_by_choice = np.array(temp_availabilities)
else:
feature = feature.set_index("context_id")
feature = feature.set_index("choice_id")
items_features_by_choice = (
items_features_by_choice[:i]
+ (feature.loc[np.sort(feature.index)].to_numpy(),)
Expand Down Expand Up @@ -284,9 +286,9 @@ def __init__(
if "choice_id" in available_items_by_choice.columns:
if "item_id" in available_items_by_choice.columns:
av_array = []
for sess in np.sort(available_items_by_choice.context_id):
for sess in np.sort(available_items_by_choice.choice_id):
sess_df = available_items_by_choice.loc[
available_items_by_choice.context_id == sess
available_items_by_choice.choice_id == sess
]
sess_df = sess_df.set_index("item_id")
av_array.append(sess_df.loc[np.sort(sess_df.index)].to_numpy())
Expand Down Expand Up @@ -359,9 +361,9 @@ def _build_features_by_ids(self):
Returns:
--------
tuple
indexes and features_by_id of contexts_features
indexes and features_by_id of shared_features_by_choice
tuple
indexes and features_by_id of contexts_items_features
indexes and features_by_id of items_features_by_choice
"""
if len(self.features_by_ids) == 0:
return {}, {}
Expand Down Expand Up @@ -670,7 +672,7 @@ def _long_df_to_items_features_array(
items_index=None,
choices_index=None,
):
"""Builds contexts_items_features and contexts_items_availabilities from dataframe.
"""Builds items_features_by_choice and available_items_by_choice from dataframe.
Parameters:
-----------
Expand All @@ -690,7 +692,7 @@ def _long_df_to_items_features_array(
Returns:
-------
np.ndarray of shape (n_choices, n_items, n_features)
Corresponding contexts_items_features
Corresponding items_features_by_choice
np.ndarray of shape (n_choices, n_items)
Corresponding availabilities
"""
Expand Down Expand Up @@ -802,8 +804,8 @@ def from_single_wide_df(
)
if available_items_prefix is not None and available_items_suffix is not None:
raise ValueError(
"You cannot give both contexts_items_availabilities_prefix and\
contexts_items_availabilities_suffix."
"You cannot give both available_items_prefix and\
available_items_suffix."
)
if choice_format not in ["items_index", "items_name"]:
logging.warning("choice_format not undersood, defaulting to 'items_index'")
Expand Down Expand Up @@ -1200,7 +1202,7 @@ def __getitem__(self, choices_indexes):
Parameters:
-----------
choices_indexes : np.ndarray
indexes of the contexts / choices to keep, shape should be (num_choices,)
indexes of the choices to keep, shape should be (num_choices,)
Returns:
-------
Expand Down Expand Up @@ -1319,14 +1321,14 @@ def filter(self, bool_list):
Parameters:
-----------
bool_list : list of boolean
list of booleans of length self.get_n_contexts() to filter contexts.
list of booleans of length self.get_n_choices() to filter choices.
True to keep, False to discard.
"""
indexes = [i for i, keep in enumerate(bool_list) if keep]
return self[indexes]

def get_n_shared_features(self):
"""Method to access the number of contexts features.
"""Method to access the number of shared features.
Returns:
-------
Expand All @@ -1341,7 +1343,7 @@ def get_n_shared_features(self):
return 0

def get_n_items_features(self):
"""Method to access the number of context items features.
"""Method to access the number of items features.
Returns:
-------
Expand Down

0 comments on commit 7911820

Please sign in to comment.