Skip to content

Commit

Permalink
Add next-basket recommendation evaluation method
Browse files Browse the repository at this point in the history
  • Loading branch information
lthoang committed Nov 9, 2023
1 parent fbc4b57 commit 0e1a497
Show file tree
Hide file tree
Showing 10 changed files with 1,605 additions and 25 deletions.
2 changes: 2 additions & 0 deletions cornac/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@
from .sentiment import SentimentModality
from .reader import Reader
from .dataset import Dataset
from .basket_dataset import BasketDataset

__all__ = ['FeatureModality',
'TextModality',
'ReviewModality',
'ImageModality',
'GraphModality',
'SentimentModality',
'BasketDataset',
'Dataset',
'Reader']
363 changes: 363 additions & 0 deletions cornac/data/basket_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,363 @@
# Copyright 2023 The Cornac Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

from collections import OrderedDict, defaultdict

import numpy as np

from ..utils import get_rng
from ..utils import validate_format
from ..utils import estimate_batches


class BasketDataset(object):
"""Training set contains history baskets
Parameters
----------
num_users: int, required
Number of users.
num_items: int, required
Number of items.
uid_map: :obj:`OrderDict`, required
The dictionary containing mapping from user original ids to mapped integer indices.
iid_map: :obj:`OrderDict`, required
The dictionary containing mapping from item original ids to mapped integer indices.
ub_tuple: tuple, required
Tuple (user_indices, baskets).
timestamps: numpy.array, optional, default: None
Array of timestamps corresponding to observations in `ub_tuple`.
seed: int, optional, default: None
Random seed for reproducing data sampling.
Attributes
----------
ub_tuple: tuple
Tuple (user_indices, baskets).
timestamps: numpy.array
Numpy array of timestamps corresponding to feedback in `ub_tuple`.
This is only available when input data is in `UTB` format.
"""

def __init__(
self,
num_users,
num_items,
uid_map,
iid_map,
ub_tuple,
timestamps=None,
seed=None,
):
self.num_users = num_users
self.num_items = num_items
self.uid_map = uid_map
self.iid_map = iid_map
self.ub_tuple = ub_tuple
self.timestamps = timestamps
self.seed = seed
self.rng = get_rng(seed)

self.num_baskets = len(ub_tuple[0])
basket_sizes = [len(basket) for basket in ub_tuple[1]]
self.max_basket_size = np.max(basket_sizes)
self.min_basket_size = np.min(basket_sizes)
self.avg_basket_size = np.mean(basket_sizes)

self.__user_ids = None
self.__item_ids = None
self.__user_data = None

@property
def user_ids(self):
"""Return the list of raw user ids"""
if self.__user_ids is None:
self.__user_ids = list(self.uid_map.keys())
return self.__user_ids

@property
def item_ids(self):
"""Return the list of raw item ids"""
if self.__item_ids is None:
self.__item_ids = list(self.iid_map.keys())
return self.__item_ids

@property
def user_data(self):
"""Data organized by user. A dictionary where keys are users,
values are list of baskets purchased by corresponding users.
"""
if self.__user_data is None:
self.__user_data = defaultdict()
for u, basket in zip(*self.ub_tuple):
u_data = self.__user_data.setdefault(u, [])
u_data.append(basket)
return self.__user_data

@property
def chrono_user_data(self):
"""Data organized by user sorted chronologically (timestamps required).
A dictionary where keys are users, values are tuples of three chronologically
sorted lists (items, ratings, timestamps) interacted by the corresponding users.
"""
if self.timestamps is None:
raise ValueError("Timestamps are required but None!")

if self.__chrono_user_data is None:
self.__chrono_user_data = defaultdict()
for u, b, t in zip(*self.ub_tuple, self.timestamps):
u_data = self.__chrono_user_data.setdefault(u, ([], [], []))
u_data[0].append(b)
u_data[1].append(t)
# sorting based on timestamps
for user, (baskets, timestamps) in self.__chrono_user_data.items():
sorted_idx = np.argsort(timestamps)
sorted_baskets = [baskets[i] for i in sorted_idx]
sorted_timestamps = [timestamps[i] for i in sorted_idx]
self.__chrono_user_data[user] = (
sorted_baskets,
sorted_timestamps,
)
return self.__chrono_user_data

@classmethod
def build(
cls,
data,
fmt="UB",
global_uid_map=None,
global_iid_map=None,
seed=None,
exclude_unknowns=False,
):
"""Constructing Dataset from given data of specific format.
Parameters
----------
data: list, required
Data in the form of tuple (user, basket) for UB format,
or tuple (user, timestamps, basket) for UTB format.
fmt: str, default: 'UB'
Format of the input data. Currently, we are supporting:
'UB': User, Basket
'UTB': User, Timestamp, Basket
global_uid_map: :obj:`defaultdict`, optional, default: None
The dictionary containing global mapping from original ids to mapped ids of users.
global_iid_map: :obj:`defaultdict`, optional, default: None
The dictionary containing global mapping from original ids to mapped ids of items.
seed: int, optional, default: None
Random seed for reproducing data sampling.
exclude_unknowns: bool, default: False
Ignore unknown users and items.
Returns
-------
res: :obj:`<cornac.data.BasketDataset>`
BasketDataset object.
"""
fmt = validate_format(fmt, ["UB", "UTB"])

if global_uid_map is None:
global_uid_map = OrderedDict()
if global_iid_map is None:
global_iid_map = OrderedDict()

uid_map = OrderedDict()
iid_map = OrderedDict()

u_indices = []
i_indices = []
baskets = []
valid_idx = []

for idx, basket_tuples in enumerate(data):
uid = basket_tuples[0]
raw_basket = basket_tuples[-1]
if exclude_unknowns:
raw_basket = [tup for tup in raw_basket if tup[0] in global_iid_map]

if len(raw_basket) == 0:
continue

uid_map[uid] = global_uid_map.setdefault(uid, len(global_uid_map))
for (iid, *_) in raw_basket:
iid_map[iid] = global_iid_map.setdefault(iid, len(global_iid_map))

u_indices.append(uid_map[uid])
i_indices.append(iid_map[iid])
baskets.append([tuple((iid_map[tup[0]], *tup[1:])) for tup in raw_basket])
valid_idx.append(idx)

ub_tuple = (np.asarray(u_indices, dtype="int"), baskets)

timestamps = (
np.fromiter((int(data[i][1]) for i in valid_idx), dtype="int")
if fmt == "UTB"
else None
)

dataset = cls(
num_users=len(global_uid_map),
num_items=len(global_iid_map),
uid_map=global_uid_map,
iid_map=global_iid_map,
ub_tuple=ub_tuple,
timestamps=timestamps,
seed=seed,
)

return dataset

@classmethod
def from_ub(cls, data, seed=None):
"""Constructing Dataset from UB (User, Basket) tuple data.
Parameters
----------
data: list
Data in the form of tuple (user, basket).
Each basket is a list of items [(item, <optional attributes>),...].
seed: int, optional, default: None
Random seed for reproducing data sampling.
Returns
-------
res: :obj:`<cornac.data.BasketDataset>`
BasketDataset object.
"""
return cls.build(data, fmt="UB", seed=seed)

@classmethod
def from_utb(cls, data, seed=None):
"""Constructing Dataset from UTB format (User, Timestamp, Basket)
Parameters
----------
data: tuple
Data in the form of triplets (user, timestamp, basket)
seed: int, optional, default: None
Random seed for reproducing data sampling.
Returns
-------
res: :obj:`<cornac.data.BasketDataset>`
BasketDataset object.
"""
return cls.build(data, fmt="UTB", seed=seed)

def reset(self):
"""Reset the random number generator for reproducibility"""
self.rng = get_rng(self.seed)
return self

def num_batches(self, batch_size):
"""Estimate number of batches per epoch"""
return estimate_batches(len(self.ub_tuple[0]), batch_size)

def idx_iter(self, idx_range, batch_size=1, shuffle=False):
"""Create an iterator over batch of indices
Parameters
----------
batch_size: int, optional, default = 1
shuffle: bool, optional
If True, orders of triplets will be randomized. If False, default orders kept
Returns
-------
iterator : batch of indices (array of 'int')
"""
indices = np.arange(idx_range)
if shuffle:
self.rng.shuffle(indices)

n_batches = estimate_batches(len(indices), batch_size)
for b in range(n_batches):
start_offset = batch_size * b
end_offset = batch_size * b + batch_size
end_offset = min(end_offset, len(indices))
batch_ids = indices[start_offset:end_offset]
yield batch_ids

def ub_iter(self, batch_size=1, shuffle=False):
"""Create an iterator over data yielding batch of users, baskets
Parameters
----------
batch_size: int, optional, default = 1
shuffle: bool, optional, default: False
If `True`, orders of triplets will be randomized. If `False`, default orders kept.
Returns
-------
iterator : batch of users (array of 'int'), batch of baskets (list of list)
"""
for batch_ids in self.idx_iter(len(self.ub_tuple[0]), batch_size, shuffle):
batch_users = self.ub_tuple[0][batch_ids]
batch_baskets = [self.ub_tuple[-1][idx] for idx in batch_ids]

yield batch_users, batch_baskets

def user_iter(self, batch_size=1, shuffle=False):
"""Create an iterator over user indices
Parameters
----------
batch_size : int, optional, default = 1
shuffle : bool, optional
If True, orders of triplets will be randomized. If False, default orders kept
Returns
-------
iterator : batch of user indices (array of 'int')
"""
user_indices = np.fromiter(set(self.ub_tuple[0]), dtype="int")
for batch_ids in self.idx_iter(len(user_indices), batch_size, shuffle):
yield user_indices[batch_ids]

def add_modalities(self, **kwargs):
self.user_feature = kwargs.get("user_feature", None)
self.item_feature = kwargs.get("item_feature", None)
self.user_text = kwargs.get("user_text", None)
self.item_text = kwargs.get("item_text", None)
self.user_image = kwargs.get("user_image", None)
self.item_image = kwargs.get("item_image", None)
self.user_graph = kwargs.get("user_graph", None)
self.item_graph = kwargs.get("item_graph", None)
self.sentiment = kwargs.get("sentiment", None)
self.review_text = kwargs.get("review_text", None)
Loading

0 comments on commit 0e1a497

Please sign in to comment.