Skip to content

Commit

Permalink
refactor code
Browse files Browse the repository at this point in the history
  • Loading branch information
lthoang committed Nov 23, 2023
1 parent 1262444 commit cf78908
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 57 deletions.
111 changes: 67 additions & 44 deletions cornac/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def __init__(
num_items,
uid_map,
iid_map,
uir_tuple=None,
uir_tuple,
timestamps=None,
seed=None,
):
Expand All @@ -89,17 +89,11 @@ def __init__(
self.seed = seed
self.rng = get_rng(seed)

if uir_tuple is not None:
r_values = uir_tuple[2]
self.num_ratings = len(r_values)
self.max_rating = np.max(r_values)
self.min_rating = np.min(r_values)
self.global_mean = np.mean(r_values)
else:
self.num_ratings = 0
self.max_rating = 0
self.min_rating = 0
self.global_mean = 0
(_, _, r_values) = uir_tuple
self.num_ratings = len(r_values)
self.max_rating = np.max(r_values)
self.min_rating = np.min(r_values)
self.global_mean = np.mean(r_values)

self.__user_ids = None
self.__item_ids = None
Expand Down Expand Up @@ -587,14 +581,18 @@ class BasketDataset(Dataset):
iid_map: :obj:`OrderDict`, required
The dictionary containing mapping from item original ids to mapped integer indices.
ubi_tuple: tuple, required
Tuple (user_indices, basket_indices, item_indices).
uir_tuple: tuple, required
Tuple of 3 numpy arrays (user_indices, item_indices, rating_values).
basket_ids: numpy.array, required
Array of basket indices corresponding to observations in `uir_tuple`.
timestamps: numpy.array, optional, default: None
Array of timestamps corresponding to observations in `ubi_tuple`.
Numpy array of timestamps corresponding to feedback in `uir_tuple`.
This is only available when input data is in `UBIT` and `UBITJson` formats.
extra_data: numpy.array, optional, default: None
Array of JSON objects corresponding to observations in `ubi_tuple`.
Array of JSON objects corresponding to observations in `uir_tuple`.
seed: int, optional, default: None
Random seed for reproducing data sampling.
Expand All @@ -617,7 +615,8 @@ def __init__(
uid_map,
bid_map,
iid_map,
ubi_tuple=None,
uir_tuple,
basket_ids=None,
timestamps=None,
extra_data=None,
seed=None,
Expand All @@ -627,20 +626,23 @@ def __init__(
num_items=num_items,
uid_map=uid_map,
iid_map=iid_map,
uir_tuple=uir_tuple,
timestamps=timestamps,
seed=seed,
)
self.num_baskets = num_baskets
self.bid_map = bid_map
self.ubi_tuple = ubi_tuple
self.basket_ids = basket_ids
self.extra_data = extra_data
basket_sizes = list(Counter(ubi_tuple[1]).values())
basket_sizes = list(Counter(basket_ids).values())
self.max_basket_size = np.max(basket_sizes)
self.min_basket_size = np.min(basket_sizes)
self.avg_basket_size = np.mean(basket_sizes)
self._build_basket()
self.__user_data = None
self.__chrono_user_data = None

self.__baskets = None
self.__basket_timestamps = None
self.__user_basket_data = None
self.__chrono_user_basket_data = None

def _build_basket(self):
baskets = OrderedDict()
Expand All @@ -656,44 +658,62 @@ def _build_basket(self):
self.basket_timestamps = basket_timestamps

@property
def user_data(self):
def baskets(self):
if self.__baskets is None:
self.__baskets = OrderedDict()
for idx, bid in enumerate(self.basket_ids):
self.__baskets.setdefault(bid, [])
self.__baskets[bid].append(idx)
return self.__baskets

@property
def basket_timestamps(self):
if self.__basket_timestamps is None:
if self.timestamps is not None:
self.__basket_timestamps = []
for _, ids in self.baskets.items():
self.__basket_timestamps.append(self.timestamps[ids[0]])
return self.__basket_timestamps

@property
def user_basket_data(self):
"""Data organized by user. A dictionary where keys are users,
values are lists of baskets purchased by the corresponding users.
"""
if self.__user_data is None:
self.__user_data = defaultdict()
if self.__user_basket_data is None:
self.__user_basket_data = defaultdict()
for bid, ids in self.baskets.items():
u = self.ubi_tuple[0][ids[0]]
u_data = self.__user_data.setdefault(u, [])
u_data.append(bid)
return self.__user_data
u = self.uir_tuple[0][ids[0]]
self.__user_basket_data.setdefault(u, [])
self.__user_basket_data[u].append(bid)
return self.__user_basket_data

@property
def chrono_user_data(self):
def chrono_user_basket_data(self):
"""Data organized by user sorted chronologically (timestamps required).
A dictionary where keys are users, values are tuples of two chronologically
sorted lists (baskets, timestamps) interacted by the corresponding users.
"""
if self.basket_timestamps is None:
raise ValueError("Basket Timestamps are required but None!")

if self.__chrono_user_data is None:
self.__chrono_user_data = defaultdict()
if self.__chrono_user_basket_data is None:
self.__chrono_user_basket_data = defaultdict()
for (bid, ids), t in zip(*self.baskets.values(), self.basket_timestamps):
u = self.ubi_tuple[0][ids[0]]
u_data = self.__chrono_user_data.setdefault(u, ([], []))
u = self.uir_tuple[0][ids[0]]
u_data = self.__chrono_user_basket_data.setdefault(u, ([], []))
u_data[0].append(bid)
u_data[1].append(t)
# sorting based on timestamps
for user, (baskets, timestamps) in self.__chrono_user_data.items():
for user, (baskets, timestamps) in self.__chrono_user_basket_data.items():
sorted_idx = np.argsort(timestamps)
sorted_baskets = [baskets[i] for i in sorted_idx]
sorted_timestamps = [timestamps[i] for i in sorted_idx]
self.__chrono_user_data[user] = (
self.__chrono_user_basket_data[user] = (
sorted_baskets,
sorted_timestamps,
)
return self.__chrono_user_data
return self.__chrono_user_basket_data

@classmethod
def build(
Expand Down Expand Up @@ -769,12 +789,14 @@ def build(
i_indices.append(global_iid_map[iid])
valid_idx.append(idx)

ubi_tuple = (
uir_tuple = (
np.asarray(u_indices, dtype="int"),
np.asarray(b_indices, dtype="int"),
np.asarray(i_indices, dtype="int"),
np.ones(len(u_indices), dtype="float"),
)

basket_ids = np.asarray(b_indices, dtype="int")

timestamps = (
np.fromiter((int(data[i][3]) for i in valid_idx), dtype="int")
if fmt in ["UBIT", "UBITJson"]
Expand All @@ -790,7 +812,8 @@ def build(
uid_map=global_uid_map,
bid_map=global_bid_map,
iid_map=global_iid_map,
ubi_tuple=ubi_tuple,
uir_tuple=uir_tuple,
basket_ids=basket_ids,
timestamps=timestamps,
extra_data=extra_data,
seed=seed,
Expand Down Expand Up @@ -862,7 +885,7 @@ def num_batches(self, batch_size):
"""Estimate number of batches per epoch"""
return estimate_batches(len(self.user_data), batch_size)

def user_data_iter(self, batch_size=1, shuffle=False):
def user_basket_data_iter(self, batch_size=1, shuffle=False):
"""Create an iterator over data yielding batch of user indices and batch of basket indices of those users
Parameters
Expand All @@ -877,12 +900,12 @@ def user_data_iter(self, batch_size=1, shuffle=False):
iterator : batch of user indices, batch of user data corresponding to user indices
"""
user_indices = np.array(list(self.user_data.keys()))
user_indices = np.asarray(list(self.user_basket_data.keys()), dtype="int")
for batch_ids in self.idx_iter(
len(self.user_data), batch_size=batch_size, shuffle=shuffle
len(self.user_basket_data), batch_size=batch_size, shuffle=shuffle
):
batch_users = user_indices[batch_ids]
batch_basket_ids = [self.user_data[uid] for uid in batch_users]
batch_basket_ids = np.asarray([self.user_basket_data[uid] for uid in batch_users], dtype="int")
yield batch_users, batch_basket_ids

def basket_iter(self, batch_size=1, shuffle=False):
Expand Down
13 changes: 6 additions & 7 deletions cornac/eval_methods/next_basket_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,14 @@ def ranking_eval(
def pos_items(baskets):
return [item_idx for basket in baskets for item_idx in basket]

test_user_indices = set(test_set.ubi_tuple[0])
test_user_indices = set(test_set.uir_tuple[0])
for user_idx in tqdm(
test_user_indices, desc="Ranking", disable=not verbose, miniters=100
):
test_pos_items = pos_items(
[
[test_set.ubi_tuple[2][idx] for idx in test_set.baskets[bid]]
for bid in test_set.user_data[user_idx][-1:]
[test_set.uir_tuple[1][idx] for idx in test_set.baskets[bid]]
for bid in test_set.user_basket_data[user_idx][-1:]
]
)
if len(test_pos_items) == 0:
Expand All @@ -109,13 +109,12 @@ def pos_items(baskets):
item_rank, item_scores = model.rank(
user_idx,
[
[test_set.ubi_tuple[2][idx] for idx in test_set.baskets[bid]]
for bid in test_set.user_data[user_idx][:-1]
[test_set.uir_tuple[1][idx] for idx in test_set.baskets[bid]]
for bid in test_set.user_basket_data[user_idx][:-1]
],
item_indices,
baskets=test_set.baskets,
user_data=test_set.user_data[user_idx],
ubi_tuple=test_set.ubi_tuple,
basket_ids=test_set.basket_ids,
extra_data=test_set.extra_data,
)

Expand Down
8 changes: 2 additions & 6 deletions cornac/models/gp_top/recom_gp_top.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,7 @@ def __init__(
def fit(self, train_set, val_set=None):
super().fit(train_set=train_set, val_set=val_set)
if self.use_global_popularity:
self.item_freq = Counter(
self.train_set.ubi_tuple[2]
)
self.item_freq = Counter(self.train_set.uir_tuple[1])
return self

def score(self, user_idx, history_baskets, **kwargs):
Expand All @@ -64,9 +62,7 @@ def score(self, user_idx, history_baskets, **kwargs):
item_scores[iid] = freq

if self.use_personalized_popularity:
p_item_freq = Counter(
[iid for iids in history_baskets for iid in iids]
)
p_item_freq = Counter([iid for iids in history_baskets for iid in iids])

max_item_freq = (
max(self.item_freq.values()) if len(self.item_freq) > 0 else 1
Expand Down

0 comments on commit cf78908

Please sign in to comment.