From 93b0e88f524bdb3036ad074e4c54fcce8dc9edff Mon Sep 17 00:00:00 2001 From: VincentAuriau Date: Tue, 9 Apr 2024 11:19:53 +0200 Subject: [PATCH] FIX: stacking and typing orders in RUMnet --- choice_learn/models/rumnet.py | 54 +++++++++++++++++++++++------------ 1 file changed, 36 insertions(+), 18 deletions(-) diff --git a/choice_learn/models/rumnet.py b/choice_learn/models/rumnet.py index d56fd65e..0c5e22e7 100644 --- a/choice_learn/models/rumnet.py +++ b/choice_learn/models/rumnet.py @@ -628,14 +628,20 @@ def compute_batch_utility( Shape must be (n_choices, n_items) """ (_, _) = available_items_by_choice, choices - # Restacking of the item features + # Restacking and dtyping of the item features if isinstance(shared_features_by_choice, tuple): - shared_features_by_choice = tf.concat([*shared_features_by_choice], axis=-1) + shared_features_by_choice = tf.concat( + [ + tf.cast(shared_feature, tf.float32) + for shared_feature in shared_features_by_choice + ], + axis=-1, + ) if isinstance(items_features_by_choice, tuple): - items_features_by_choice = tf.concat([*items_features_by_choice], axis=-1) - - shared_features_by_choice = tf.cast(shared_features_by_choice, tf.float32) - items_features_by_choice = tf.cast(items_features_by_choice, tf.float32) + items_features_by_choice = tf.concat( + [tf.cast(items_feature, tf.float32) for items_feature in items_features_by_choice], + axis=-1, + ) # Computation of utilities utilities = [] @@ -859,14 +865,20 @@ def compute_batch_utility( Shape must be (n_choices, n_items) """ (_, _) = available_items_by_choice, choices - # Restacking of the item features + # Restacking and dtyping of the item features if isinstance(shared_features_by_choice, tuple): - shared_features_by_choice = tf.concat([*shared_features_by_choice], axis=-1) + shared_features_by_choice = tf.concat( + [ + tf.cast(shared_feature, tf.float32) + for shared_feature in shared_features_by_choice + ], + axis=-1, + ) if isinstance(items_features_by_choice, tuple): - items_features_by_choice = tf.concat([*items_features_by_choice], axis=-1) - - shared_features_by_choice = tf.cast(shared_features_by_choice, tf.float32) - items_features_by_choice = tf.cast(items_features_by_choice, tf.float32) + items_features_by_choice = tf.concat( + [tf.cast(items_feature, tf.float32) for items_feature in items_features_by_choice], + axis=-1, + ) # Computation of utilities utilities = [] @@ -986,14 +998,20 @@ def compute_batch_utility( """ (_, _) = available_items_by_choice, choices - # Restacking of the item features + # Restacking and dtyping of the item features if isinstance(shared_features_by_choice, tuple): - shared_features_by_choice = tf.concat([*shared_features_by_choice], axis=-1) + shared_features_by_choice = tf.concat( + [ + tf.cast(shared_feature, tf.float32) + for shared_feature in shared_features_by_choice + ], + axis=-1, + ) if isinstance(items_features_by_choice, tuple): - items_features_by_choice = tf.concat([*items_features_by_choice], axis=-1) - - shared_features_by_choice = tf.cast(shared_features_by_choice, tf.float32) - items_features_by_choice = tf.cast(items_features_by_choice, tf.float32) + items_features_by_choice = tf.concat( + [tf.cast(items_feature, tf.float32) for items_feature in items_features_by_choice], + axis=-1, + ) item_utility_by_choice = []