Skip to content

Commit

Permalink
fix init dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
lthoang committed Nov 22, 2023
1 parent c504c16 commit dbd59cb
Showing 1 changed file with 16 additions and 10 deletions.
26 changes: 16 additions & 10 deletions cornac/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,17 @@ def __init__(
self.seed = seed
self.rng = get_rng(seed)

r_values = uir_tuple[2] if uir_tuple is not None else None
self.num_ratings = len(r_values) if r_values is not None else 0
self.max_rating = np.max(r_values, 0)
self.min_rating = np.min(r_values, 0)
self.global_mean = np.mean(r_values) if r_values is not None else 0
if uir_tuple is not None:
r_values = uir_tuple[2]
self.num_ratings = len(r_values)
self.max_rating = np.max(r_values)
self.min_rating = np.min(r_values)
self.global_mean = np.mean(r_values)
else:
self.num_ratings = 0
self.max_rating = 0
self.min_rating = 0
self.global_mean = 0

self.__user_ids = None
self.__item_ids = None
Expand Down Expand Up @@ -628,10 +634,10 @@ def __init__(
self.bid_map = bid_map
self.ubi_tuple = ubi_tuple
self.extra_data = extra_data
basket_sizes = list(Counter(ubi_tuple[1]))
self.max_basket_size = np.max(basket_sizes) if basket_sizes is not None else 0
self.min_basket_size = np.min(basket_sizes) if basket_sizes is not None else 0
self.avg_basket_size = np.mean(basket_sizes) if basket_sizes is not None else 0
basket_sizes = list(Counter(ubi_tuple[1]).values())
self.max_basket_size = np.max(basket_sizes)
self.min_basket_size = np.min(basket_sizes)
self.avg_basket_size = np.mean(basket_sizes)
self._build_basket()
self.__user_data = None
self.__chrono_user_data = None
Expand Down Expand Up @@ -776,7 +782,7 @@ def build(
)

extra_data = (
np.fromiter((data[i][4] for i in valid_idx), dtype="object")
np.fromiter((data[i][4] for i in valid_idx), dtype=object)
if fmt in ["UBITJson"]
else None
)
Expand Down

0 comments on commit dbd59cb

Please sign in to comment.