Skip to content

Commit

Permalink
ADD: modelwise proba
Browse files Browse the repository at this point in the history
  • Loading branch information
VincentAuriau committed Dec 26, 2024
1 parent 2ede973 commit e15d9ef
Showing 1 changed file with 37 additions and 0 deletions.
37 changes: 37 additions & 0 deletions choice_learn/models/latent_class_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -936,6 +936,43 @@ def predict_probas(self, choice_dataset, batch_size=-1):

return tf.concat(stacked_probabilities, axis=0)

def predict_modelwise_probas(self, choice_dataset, batch_size=-1):
"""Predicts the choice probabilities for each choice and each product of a ChoiceDataset.
Stacks each model probability.
Parameters
----------
choice_dataset : ChoiceDataset
Dataset on which to apply to prediction
batch_size : int, optional
Batch size to use for the prediction, by default -1
Returns
-------
np.ndarray (n_choices, n_items)
Choice probabilties for each choice and each product
"""
modelwise_probabilities = []
for model in self.models:
stacked_probabilities = []
for (
shared_features,
items_features,
available_items,
choices,
) in choice_dataset.iter_batch(batch_size=batch_size):
_, probabilities = model.batch_predict(
shared_features_by_choice=shared_features,
items_features_by_choice=items_features,
available_items_by_choice=available_items,
choices=choices,
)
stacked_probabilities.append(probabilities)
modelwise_probabilities.append(tf.concat(stacked_probabilities, axis=0))

return tf.stack(modelwise_probabilities, axis=0)

def get_latent_classes_weights(self):
"""Return the latent classes weights / probabilities from logits.
Expand Down

0 comments on commit e15d9ef

Please sign in to comment.