Skip to content

Commit

Permalink
FIX: stacking and typing orders in RUMnet
Browse files Browse the repository at this point in the history
  • Loading branch information
VincentAuriau committed Apr 9, 2024
1 parent 7911820 commit 93b0e88
Showing 1 changed file with 36 additions and 18 deletions.
54 changes: 36 additions & 18 deletions choice_learn/models/rumnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,14 +628,20 @@ def compute_batch_utility(
Shape must be (n_choices, n_items)
"""
(_, _) = available_items_by_choice, choices
# Restacking of the item features
# Restacking and dtyping of the item features
if isinstance(shared_features_by_choice, tuple):
shared_features_by_choice = tf.concat([*shared_features_by_choice], axis=-1)
shared_features_by_choice = tf.concat(
[
tf.cast(shared_feature, tf.float32)
for shared_feature in shared_features_by_choice
],
axis=-1,
)
if isinstance(items_features_by_choice, tuple):
items_features_by_choice = tf.concat([*items_features_by_choice], axis=-1)

shared_features_by_choice = tf.cast(shared_features_by_choice, tf.float32)
items_features_by_choice = tf.cast(items_features_by_choice, tf.float32)
items_features_by_choice = tf.concat(
[tf.cast(items_feature, tf.float32) for items_feature in items_features_by_choice],
axis=-1,
)

# Computation of utilities
utilities = []
Expand Down Expand Up @@ -859,14 +865,20 @@ def compute_batch_utility(
Shape must be (n_choices, n_items)
"""
(_, _) = available_items_by_choice, choices
# Restacking of the item features
# Restacking and dtyping of the item features
if isinstance(shared_features_by_choice, tuple):
shared_features_by_choice = tf.concat([*shared_features_by_choice], axis=-1)
shared_features_by_choice = tf.concat(
[
tf.cast(shared_feature, tf.float32)
for shared_feature in shared_features_by_choice
],
axis=-1,
)
if isinstance(items_features_by_choice, tuple):
items_features_by_choice = tf.concat([*items_features_by_choice], axis=-1)

shared_features_by_choice = tf.cast(shared_features_by_choice, tf.float32)
items_features_by_choice = tf.cast(items_features_by_choice, tf.float32)
items_features_by_choice = tf.concat(
[tf.cast(items_feature, tf.float32) for items_feature in items_features_by_choice],
axis=-1,
)

# Computation of utilities
utilities = []
Expand Down Expand Up @@ -986,14 +998,20 @@ def compute_batch_utility(
"""
(_, _) = available_items_by_choice, choices

# Restacking of the item features
# Restacking and dtyping of the item features
if isinstance(shared_features_by_choice, tuple):
shared_features_by_choice = tf.concat([*shared_features_by_choice], axis=-1)
shared_features_by_choice = tf.concat(
[
tf.cast(shared_feature, tf.float32)
for shared_feature in shared_features_by_choice
],
axis=-1,
)
if isinstance(items_features_by_choice, tuple):
items_features_by_choice = tf.concat([*items_features_by_choice], axis=-1)

shared_features_by_choice = tf.cast(shared_features_by_choice, tf.float32)
items_features_by_choice = tf.cast(items_features_by_choice, tf.float32)
items_features_by_choice = tf.concat(
[tf.cast(items_feature, tf.float32) for items_feature in items_features_by_choice],
axis=-1,
)

item_utility_by_choice = []

Expand Down

0 comments on commit 93b0e88

Please sign in to comment.