-
Notifications
You must be signed in to change notification settings - Fork 149
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add next-basket recommendation evaluation method
- Loading branch information
Showing
10 changed files
with
1,605 additions
and
25 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,363 @@ | ||
# Copyright 2023 The Cornac Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================ | ||
|
||
from collections import OrderedDict, defaultdict | ||
|
||
import numpy as np | ||
|
||
from ..utils import get_rng | ||
from ..utils import validate_format | ||
from ..utils import estimate_batches | ||
|
||
|
||
class BasketDataset(object):
    """Training set containing the history baskets of users.

    Parameters
    ----------
    num_users: int, required
        Number of users.

    num_items: int, required
        Number of items.

    uid_map: :obj:`OrderedDict`, required
        The dictionary containing mapping from user original ids to mapped integer indices.

    iid_map: :obj:`OrderedDict`, required
        The dictionary containing mapping from item original ids to mapped integer indices.

    ub_tuple: tuple, required
        Tuple (user_indices, baskets).

    timestamps: numpy.array, optional, default: None
        Array of timestamps corresponding to observations in `ub_tuple`.

    seed: int, optional, default: None
        Random seed for reproducing data sampling.

    Attributes
    ----------
    ub_tuple: tuple
        Tuple (user_indices, baskets).

    timestamps: numpy.array
        Numpy array of timestamps corresponding to baskets in `ub_tuple`.
        This is only available when input data is in `UTB` format.

    num_baskets: int
        Number of baskets, i.e. the length of `ub_tuple[0]`.

    max_basket_size, min_basket_size, avg_basket_size: number
        Largest, smallest and mean basket length over `ub_tuple[1]`.
    """

    def __init__(
        self,
        num_users,
        num_items,
        uid_map,
        iid_map,
        ub_tuple,
        timestamps=None,
        seed=None,
    ):
        self.num_users = num_users
        self.num_items = num_items
        self.uid_map = uid_map
        self.iid_map = iid_map
        self.ub_tuple = ub_tuple
        self.timestamps = timestamps
        self.seed = seed
        self.rng = get_rng(seed)

        self.num_baskets = len(ub_tuple[0])
        basket_sizes = [len(basket) for basket in ub_tuple[1]]
        self.max_basket_size = np.max(basket_sizes)
        self.min_basket_size = np.min(basket_sizes)
        self.avg_basket_size = np.mean(basket_sizes)

        # Lazily computed caches backing the properties below.
        self.__user_ids = None
        self.__item_ids = None
        self.__user_data = None
        # Must be initialised here, otherwise the first access of
        # `chrono_user_data` fails with AttributeError.
        self.__chrono_user_data = None

    @property
    def user_ids(self):
        """Return the list of raw user ids"""
        if self.__user_ids is None:
            self.__user_ids = list(self.uid_map.keys())
        return self.__user_ids

    @property
    def item_ids(self):
        """Return the list of raw item ids"""
        if self.__item_ids is None:
            self.__item_ids = list(self.iid_map.keys())
        return self.__item_ids

    @property
    def user_data(self):
        """Data organized by user. A dictionary where keys are users,
        values are list of baskets purchased by corresponding users
        (in the order they appear in `ub_tuple`).
        """
        if self.__user_data is None:
            self.__user_data = defaultdict()
            for u, basket in zip(*self.ub_tuple):
                u_data = self.__user_data.setdefault(u, [])
                u_data.append(basket)
        return self.__user_data

    @property
    def chrono_user_data(self):
        """Data organized by user sorted chronologically (timestamps required).
        A dictionary where keys are users, values are tuples of two chronologically
        sorted lists (baskets, timestamps) of the corresponding users.

        Raises
        ------
        ValueError
            If the dataset has no timestamps.
        """
        if self.timestamps is None:
            raise ValueError("Timestamps are required but None!")

        if self.__chrono_user_data is None:
            chrono_data = defaultdict()
            for u, b, t in zip(*self.ub_tuple, self.timestamps):
                baskets, timestamps = chrono_data.setdefault(u, ([], []))
                baskets.append(b)
                timestamps.append(t)

            # Sort each user's baskets by timestamp. A stable sort keeps the
            # input order for baskets sharing the same timestamp.
            # Re-assigning existing keys does not resize the dict, so it is
            # safe to do while iterating over it.
            for user, (baskets, timestamps) in chrono_data.items():
                sorted_idx = np.argsort(timestamps, kind="stable")
                chrono_data[user] = (
                    [baskets[i] for i in sorted_idx],
                    [timestamps[i] for i in sorted_idx],
                )
            self.__chrono_user_data = chrono_data
        return self.__chrono_user_data

    @classmethod
    def build(
        cls,
        data,
        fmt="UB",
        global_uid_map=None,
        global_iid_map=None,
        seed=None,
        exclude_unknowns=False,
    ):
        """Constructing Dataset from given data of specific format.

        Parameters
        ----------
        data: list, required
            Data in the form of tuple (user, basket) for UB format,
            or tuple (user, timestamp, basket) for UTB format.
            Each basket is a list of items [(item, <optional attributes>),...].

        fmt: str, default: 'UB'
            Format of the input data. Currently, we are supporting:

            'UB': User, Basket
            'UTB': User, Timestamp, Basket

        global_uid_map: :obj:`defaultdict`, optional, default: None
            The dictionary containing global mapping from original ids to mapped ids of users.
            New users found in `data` are added to it in place.

        global_iid_map: :obj:`defaultdict`, optional, default: None
            The dictionary containing global mapping from original ids to mapped ids of items.
            New items found in `data` are added to it in place.

        seed: int, optional, default: None
            Random seed for reproducing data sampling.

        exclude_unknowns: bool, default: False
            If True, items that are not in `global_iid_map` are removed from
            their baskets. Baskets that are empty (after this filtering, if
            any) are always skipped.

        Returns
        -------
        res: :obj:`<cornac.data.BasketDataset>`
            BasketDataset object.
        """
        fmt = validate_format(fmt, ["UB", "UTB"])

        if global_uid_map is None:
            global_uid_map = OrderedDict()
        if global_iid_map is None:
            global_iid_map = OrderedDict()

        u_indices = []
        baskets = []
        valid_idx = []  # positions in `data` of the baskets that were kept

        for idx, basket_tuples in enumerate(data):
            uid = basket_tuples[0]
            raw_basket = basket_tuples[-1]
            if exclude_unknowns:
                raw_basket = [tup for tup in raw_basket if tup[0] in global_iid_map]

            if len(raw_basket) == 0:
                continue

            u_indices.append(global_uid_map.setdefault(uid, len(global_uid_map)))
            # Replace the raw item id by its mapped index, keep other attributes.
            baskets.append(
                [
                    (global_iid_map.setdefault(tup[0], len(global_iid_map)), *tup[1:])
                    for tup in raw_basket
                ]
            )
            valid_idx.append(idx)

        ub_tuple = (np.asarray(u_indices, dtype="int"), baskets)

        timestamps = (
            np.fromiter((int(data[i][1]) for i in valid_idx), dtype="int")
            if fmt == "UTB"
            else None
        )

        dataset = cls(
            num_users=len(global_uid_map),
            num_items=len(global_iid_map),
            uid_map=global_uid_map,
            iid_map=global_iid_map,
            ub_tuple=ub_tuple,
            timestamps=timestamps,
            seed=seed,
        )

        return dataset

    @classmethod
    def from_ub(cls, data, seed=None):
        """Constructing Dataset from UB (User, Basket) tuple data.

        Parameters
        ----------
        data: list
            Data in the form of tuple (user, basket).
            Each basket is a list of items [(item, <optional attributes>),...].

        seed: int, optional, default: None
            Random seed for reproducing data sampling.

        Returns
        -------
        res: :obj:`<cornac.data.BasketDataset>`
            BasketDataset object.
        """
        return cls.build(data, fmt="UB", seed=seed)

    @classmethod
    def from_utb(cls, data, seed=None):
        """Constructing Dataset from UTB format (User, Timestamp, Basket)

        Parameters
        ----------
        data: tuple
            Data in the form of triplets (user, timestamp, basket)

        seed: int, optional, default: None
            Random seed for reproducing data sampling.

        Returns
        -------
        res: :obj:`<cornac.data.BasketDataset>`
            BasketDataset object.
        """
        return cls.build(data, fmt="UTB", seed=seed)

    def reset(self):
        """Reset the random number generator for reproducibility"""
        self.rng = get_rng(self.seed)
        return self

    def num_batches(self, batch_size):
        """Estimate number of batches per epoch"""
        return estimate_batches(len(self.ub_tuple[0]), batch_size)

    def idx_iter(self, idx_range, batch_size=1, shuffle=False):
        """Create an iterator over batch of indices

        Parameters
        ----------
        idx_range: int
            Number of indices; batches are drawn from `np.arange(idx_range)`.

        batch_size: int, optional, default = 1

        shuffle: bool, optional
            If True, order of indices will be randomized. If False, default order kept

        Returns
        -------
        iterator : batch of indices (array of 'int')
        """
        indices = np.arange(idx_range)
        if shuffle:
            self.rng.shuffle(indices)

        n_batches = estimate_batches(len(indices), batch_size)
        for b in range(n_batches):
            start_offset = batch_size * b
            # Slicing clamps the end offset to len(indices) for the last batch.
            yield indices[start_offset : start_offset + batch_size]

    def ub_iter(self, batch_size=1, shuffle=False):
        """Create an iterator over data yielding batch of users, baskets

        Parameters
        ----------
        batch_size: int, optional, default = 1

        shuffle: bool, optional, default: False
            If `True`, order of baskets will be randomized. If `False`, default order kept.

        Returns
        -------
        iterator : batch of users (array of 'int'), batch of baskets (list of list)
        """
        for batch_ids in self.idx_iter(len(self.ub_tuple[0]), batch_size, shuffle):
            batch_users = self.ub_tuple[0][batch_ids]
            batch_baskets = [self.ub_tuple[-1][idx] for idx in batch_ids]

            yield batch_users, batch_baskets

    def user_iter(self, batch_size=1, shuffle=False):
        """Create an iterator over user indices

        Parameters
        ----------
        batch_size : int, optional, default = 1

        shuffle : bool, optional
            If True, order of users will be randomized. If False, default order kept

        Returns
        -------
        iterator : batch of user indices (array of 'int')
        """
        # Distinct user indices that own at least one basket.
        user_indices = np.fromiter(set(self.ub_tuple[0]), dtype="int")
        for batch_ids in self.idx_iter(len(user_indices), batch_size, shuffle):
            yield user_indices[batch_ids]

    def add_modalities(self, **kwargs):
        """Attach modalities to the dataset.

        Each recognised keyword (`user_feature`, `item_feature`, `user_text`,
        `item_text`, `user_image`, `item_image`, `user_graph`, `item_graph`,
        `sentiment`, `review_text`) is stored as an attribute of the same
        name; keywords that are not given are set to None.
        """
        self.user_feature = kwargs.get("user_feature", None)
        self.item_feature = kwargs.get("item_feature", None)
        self.user_text = kwargs.get("user_text", None)
        self.item_text = kwargs.get("item_text", None)
        self.user_image = kwargs.get("user_image", None)
        self.item_image = kwargs.get("item_image", None)
        self.user_graph = kwargs.get("user_graph", None)
        self.item_graph = kwargs.get("item_graph", None)
        self.sentiment = kwargs.get("sentiment", None)
        self.review_text = kwargs.get("review_text", None)
Oops, something went wrong.