Skip to content

Commit

Permalink
Merge pull request #6 from artefactory/examples
Browse files Browse the repository at this point in the history
ADD: Data & Model base examples
  • Loading branch information
VincentAuriau authored Jan 2, 2024
2 parents 45f9f84 + 16d1856 commit 69049d9
Show file tree
Hide file tree
Showing 6 changed files with 966 additions and 11 deletions.
4 changes: 4 additions & 0 deletions choice_learn/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
"""Data handling classes and functions."""
from .choice_dataset import ChoiceDataset
from .store import FeaturesStore, OneHotStore

__all__ = ["ChoiceDataset", "FeaturesStore", "OneHotStore"]
8 changes: 8 additions & 0 deletions choice_learn/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""Init file for datasets module."""

from .base import load_modecanada, load_swissmetro

__all__ = [
"load_modecanada",
"load_swissmetro",
]
5 changes: 5 additions & 0 deletions choice_learn/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,6 @@
"""Models classes and functions."""

from .conditional_mnl import ConditionalMNL, ModelSpecification
from .rumnet import PaperRUMnet as RUMnet

__all__ = ["ModelSpecification", "ConditionalMNL", "RUMnet"]
22 changes: 11 additions & 11 deletions choice_learn/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
import numpy as np
import tensorflow as tf
import tqdm
from choice_modeling.tf_ops import (

from choice_learn.tf_ops import (
CustomCategoricalCrossEntropy,
availability_softmax,
custom_softmax,
)

Expand Down Expand Up @@ -163,7 +163,7 @@ def train_step(
# Probabilities of selected product
available_utilities = tf.gather_nd(indices=choices_nd, params=final_utilities)
"""
probabilities = availability_softmax(all_u, availabilities_batch, axis=-1)
# probabilities = availability_softmax(all_u, availabilities_batch, axis=-1)
probabilities = custom_softmax(
all_u, availabilities_batch, normalize_exit=self.normalize_non_buy, axis=-1
)
Expand Down Expand Up @@ -226,15 +226,15 @@ def fit(
if sample_weight is not None:
if verbose > 0:
inner_range = tqdm.tqdm(
choice_dataset.batch(
choice_dataset.iter_batch(
shuffle=True, sample_weight=sample_weight, batch_size=batch_size
),
total=int(len(choice_dataset) / np.max([1, batch_size])),
position=1,
leave=False,
)
else:
inner_range = choice_dataset.batch(
inner_range = choice_dataset.iter_batch(
shuffle=True, sample_weight=sample_weight, batch_size=batch_size
)

Expand Down Expand Up @@ -270,13 +270,13 @@ def fit(
else:
if verbose > 0:
inner_range = tqdm.tqdm(
choice_dataset.batch(shuffle=True, batch_size=batch_size),
choice_dataset.iter_batch(shuffle=True, batch_size=batch_size),
total=int(len(choice_dataset) / np.max([batch_size, 1])),
position=1,
leave=False,
)
else:
inner_range = choice_dataset.batch(shuffle=True, batch_size=batch_size)
inner_range = choice_dataset.iter_batch(shuffle=True, batch_size=batch_size)
for batch_nb, (
items_batch,
sessions_batch,
Expand Down Expand Up @@ -329,7 +329,7 @@ def fit(
sessions_items_batch,
availabilities_batch,
choices_batch,
) in enumerate(val_dataset.batch(shuffle=False, batch_size=batch_size)):
) in enumerate(val_dataset.iter_batch(shuffle=False, batch_size=batch_size)):
self.callbacks.on_batch_begin(batch_nb)
self.callbacks.on_test_batch_begin(batch_nb)
test_losses.append(
Expand Down Expand Up @@ -407,7 +407,7 @@ def batch_predict(
items_batch, sessions_batch, sessions_items_batch, availabilities_batch, choices_batch
)
# Compute probabilities from utilities & availabilities
probabilities = availability_softmax(utilities, availabilities_batch, axis=-1)
# probabilities = availability_softmax(utilities, availabilities_batch, axis=-1)
probabilities = custom_softmax(
utilities, availabilities_batch, normalize_exit=self.normalize_non_buy, axis=-1
)
Expand Down Expand Up @@ -492,7 +492,7 @@ def predict_probas(self, choice_dataset):
sessions_items_batch,
availabilities_batch,
choices_batch,
) in choice_dataset.batch():
) in choice_dataset.iter_batch():
_, probabilities = self.batch_predict(
items_batch,
sessions_batch,
Expand Down Expand Up @@ -529,7 +529,7 @@ def evaluate(self, choice_dataset, batch_size=None):
sessions_items_batch,
availabilities_batch,
choices_batch,
) in choice_dataset.batch(batch_size=batch_size):
) in choice_dataset.iter_batch(batch_size=batch_size):
loss, _ = self.batch_predict(
items_batch,
sessions_batch,
Expand Down
Loading

0 comments on commit 69049d9

Please sign in to comment.