Skip to content

Commit

Permalink
Merge pull request #53 from artefactory/batch
Browse files
FIX:  subsampling, dtype & weights addition
  • Loading branch information
VincentAuriau authored Apr 8, 2024
2 parents 914ce0b + bdc2a86 commit ef4e577
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 9 deletions.
15 changes: 8 additions & 7 deletions choice_learn/data/choice_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1225,6 +1225,8 @@ def __getitem__(self, choices_indexes):
self.shared_features_by_choice[i][choices_indexes]
for i in range(len(self.shared_features_by_choice))
)
+ if not self._return_shared_features_by_choice_tuple:
+ shared_features_by_choice = shared_features_by_choice[0]
except TypeError:
shared_features_by_choice = None

Expand All @@ -1236,6 +1238,8 @@ def __getitem__(self, choices_indexes):
self.items_features_by_choice[i][choices_indexes]
for i in range(len(self.items_features_by_choice))
)
+ if not self._return_items_features_by_choice_tuple:
+ items_features_by_choice = items_features_by_choice[0]
except TypeError:
items_features_by_choice = None

Expand Down Expand Up @@ -1300,19 +1304,16 @@ def iter_batch(self, batch_size, shuffle=False, sample_weight=None):
yielded_size = 0
while yielded_size < num_choices:
# Return sample_weight if not None, for index matching
+ batch_indexes = indexes[yielded_size : yielded_size + batch_size].tolist()
if sample_weight is not None:
yield (
- self.batch[indexes[yielded_size : yielded_size + batch_size].tolist()],
- sample_weight[indexes[yielded_size : yielded_size + batch_size].tolist()],
+ self.batch[batch_indexes],
+ sample_weight[batch_indexes],
)
else:
- yield self.batch[indexes[yielded_size : yielded_size + batch_size].tolist()]
+ yield self.batch[batch_indexes]
yielded_size += batch_size

- # Special exit strategy for batch_size = -1
- if batch_size == -1:
- yielded_size += 2 * num_choices

def filter(self, bool_list):
"""Filter over sessions indexes following bool.
Expand Down
4 changes: 3 additions & 1 deletion choice_learn/data/indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,9 @@ def __getitem__(self, choices_indexes):
available_items_by_choice = self.choice_dataset.available_items_by_choice[
choices_indexes
]
- # .astype(self._return_types[3])
+ available_items_by_choice = available_items_by_choice.astype(
+ self.choice_dataset._return_types[2]
+ )

choices = self.choice_dataset.choices[choices_indexes].astype(
self.choice_dataset._return_types[3]
Expand Down
2 changes: 1 addition & 1 deletion choice_learn/models/simple_mnl.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def instantiate(self, n_items, n_shared_features, n_items_features):
["shared_features", "items_features"],
):
if n_feat > 0:
- weights = [
+ weights += [
tf.Variable(
tf.random_normal_initializer(0.0, 0.02, seed=42)(shape=(n_feat,)),
name=f"Weights_{feat_name}",
Expand Down

0 comments on commit ef4e577

Please sign in to comment.