diff --git a/cornac/data/__init__.py b/cornac/data/__init__.py index 718cd23e0..5b77aa6db 100644 --- a/cornac/data/__init__.py +++ b/cornac/data/__init__.py @@ -22,13 +22,17 @@ from .reader import Reader from .dataset import Dataset from .dataset import BasketDataset +from .dataset import SequentialDataset -__all__ = ['FeatureModality', - 'TextModality', - 'ReviewModality', - 'ImageModality', - 'GraphModality', - 'SentimentModality', - 'BasketDataset', - 'Dataset', - 'Reader'] +__all__ = [ + "FeatureModality", + "TextModality", + "ReviewModality", + "ImageModality", + "GraphModality", + "SentimentModality", + "BasketDataset", + "SequentialDataset", + "Dataset", + "Reader", +] diff --git a/cornac/data/dataset.py b/cornac/data/dataset.py index d9e41c452..c3e3b2709 100644 --- a/cornac/data/dataset.py +++ b/cornac/data/dataset.py @@ -13,8 +13,8 @@ # limitations under the License. # ============================================================================ -import os import copy +import os import pickle import warnings from collections import Counter, OrderedDict, defaultdict @@ -249,9 +249,7 @@ def csc_matrix(self): def dok_matrix(self): """The user-item interaction matrix in DOK sparse format""" if self.__dok_matrix is None: - self.__dok_matrix = dok_matrix( - (self.num_users, self.num_items), dtype="float" - ) + self.__dok_matrix = dok_matrix((self.num_users, self.num_items), dtype="float") for u, i, r in zip(*self.uir_tuple): self.__dok_matrix[u, i] = r return self.__dok_matrix @@ -317,9 +315,7 @@ def build( dup_count = 0 for idx, (uid, iid, rating, *_) in enumerate(data): - if exclude_unknowns and ( - uid not in global_uid_map or iid not in global_iid_map - ): + if exclude_unknowns and (uid not in global_uid_map or iid not in global_iid_map): continue if (uid, iid) in ui_set: @@ -347,11 +343,7 @@ def build( np.asarray(r_values, dtype="float"), ) - timestamps = ( - np.fromiter((int(data[i][3]) for i in valid_idx), dtype="int") - if fmt == "UIRT" - else None - ) + 
timestamps = np.fromiter((int(data[i][3]) for i in valid_idx), dtype="int") if fmt == "UIRT" else None dataset = cls( num_users=len(global_uid_map), @@ -491,9 +483,7 @@ def uir_iter(self, batch_size=1, shuffle=False, binary=False, num_zeros=0): neg_items[i] = j batch_users = np.concatenate((batch_users, repeated_users)) batch_items = np.concatenate((batch_items, neg_items)) - batch_ratings = np.concatenate( - (batch_ratings, np.zeros_like(neg_items)) - ) + batch_ratings = np.concatenate((batch_ratings, np.zeros_like(neg_items))) yield batch_users, batch_items, batch_ratings @@ -521,9 +511,7 @@ def uij_iter(self, batch_size=1, shuffle=False, neg_sampling="uniform"): elif neg_sampling.lower() == "popularity": neg_population = self.uir_tuple[1] else: - raise ValueError( - "Unsupported negative sampling option: {}".format(neg_sampling) - ) + raise ValueError("Unsupported negative sampling option: {}".format(neg_sampling)) for batch_ids in self.idx_iter(len(self.uir_tuple[0]), batch_size, shuffle): batch_users = self.uir_tuple[0][batch_ids] @@ -742,9 +730,7 @@ def chrono_user_basket_data(self): if self.__chrono_user_basket_data is None: assert self.timestamps is not None # we need timestamps - basket_timestamps = [ - self.timestamps[ids[0]] for ids in self.baskets.values() - ] # one-off + basket_timestamps = [self.timestamps[ids[0]] for ids in self.baskets.values()] # one-off self.__chrono_user_basket_data = defaultdict(lambda: ([], [])) for (bid, ids), t in zip(self.baskets.items(), basket_timestamps): @@ -847,9 +833,7 @@ def build( basket_indices = np.asarray(b_indices, dtype="int") timestamps = ( - np.fromiter((int(data[i][3]) for i in valid_idx), dtype="int") - if fmt in ["UBIT", "UBITJson"] - else None + np.fromiter((int(data[i][3]) for i in valid_idx), dtype="int") if fmt in ["UBIT", "UBITJson"] else None ) extra_data = [data[i][4] for i in valid_idx] if fmt == "UBITJson" else None @@ -967,8 +951,7 @@ def ubi_iter(self, batch_size=1, shuffle=False): _, 
item_indices, _ = self.uir_tuple for batch_users, batch_baskets in self.ub_iter(batch_size, shuffle): batch_basket_items = [ - [item_indices[self.baskets[bid]] for bid in user_baskets] - for user_baskets in batch_baskets + [item_indices[self.baskets[bid]] for bid in user_baskets] for user_baskets in batch_baskets ] yield batch_users, batch_baskets, batch_basket_items @@ -990,3 +973,411 @@ def basket_iter(self, batch_size=1, shuffle=False): basket_indices = np.fromiter(set(self.baskets.keys()), dtype="int") for batch_ids in self.idx_iter(len(basket_indices), batch_size, shuffle): yield basket_indices[batch_ids] + + +class SequentialDataset(Dataset): + """Training set contains history sessions + + Parameters + ---------- + num_users: int, required + Number of users. + + num_items: int, required + Number of items. + + uid_map: :obj:`OrderedDict`, required + The dictionary containing mapping from user original ids to mapped integer indices. + + iid_map: :obj:`OrderedDict`, required + The dictionary containing mapping from item original ids to mapped integer indices. + + uir_tuple: tuple, required + Tuple of 3 numpy arrays (user_indices, item_indices, rating_values). + + session_indices: numpy.array, required + Array of session indices corresponding to observations in `uir_tuple`. + + timestamps: numpy.array, optional, default: None + Numpy array of timestamps corresponding to feedback in `uir_tuple`. + This is only available when input data is in `SIT`, `USIT`, `SITJson`, and `USITJson` formats. + + extra_data: numpy.array, optional, default: None + Array of json object corresponding to observations in `uir_tuple`. + + seed: int, optional, default: None + Random seed for reproducing data sampling. + + Attributes + ---------- + timestamps: numpy.array + Numpy array of timestamps corresponding to feedback in `uir_tuple`. + This is only available when input data is in `SIT`, `USIT`, `SITJson`, or `USITJson` format. 
+ """ + + def __init__( + self, + num_users, + num_sessions, + num_items, + uid_map, + sid_map, + iid_map, + uir_tuple, + session_indices=None, + timestamps=None, + extra_data=None, + seed=None, + ): + super().__init__( + num_users=num_users, + num_items=num_items, + uid_map=uid_map, + iid_map=iid_map, + uir_tuple=uir_tuple, + timestamps=timestamps, + seed=seed, + ) + self.num_sessions = num_sessions + self.sid_map = sid_map + self.session_indices = session_indices + self.extra_data = extra_data + session_sizes = list(Counter(session_indices).values()) + self.max_session_size = np.max(session_sizes) + self.min_session_size = np.min(session_sizes) + self.avg_session_size = np.mean(session_sizes) + + self.__sessions = None + self.__session_ids = None + self.__user_session_data = None + self.__chrono_user_session_data = None + + @property + def session_ids(self): + """Return the list of raw session ids""" + if self.__session_ids is None: + self.__session_ids = list(self.sid_map.keys()) + return self.__session_ids + + @property + def sessions(self): + """A dictionary to store indices where session ID appears in the data.""" + if self.__sessions is None: + self.__sessions = OrderedDict() + for idx, sid in enumerate(self.session_indices): + self.__sessions.setdefault(sid, []) + self.__sessions[sid].append(idx) + return self.__sessions + + @property + def user_session_data(self): + """Data organized by user. A dictionary where keys are users, + values are list of sessions purchased by corresponding users. + """ + if self.__user_session_data is None: + self.__user_session_data = defaultdict(list) + for sid, ids in self.sessions.items(): + u = self.uir_tuple[0][ids[0]] + self.__user_session_data[u].append(sid) + return self.__user_session_data + + @property + def chrono_user_session_data(self): + """Data organized by user sorted chronologically (timestamps required). 
+ A dictionary where keys are users, values are tuples of two chronologically + sorted lists (sessions, timestamps) interacted by the corresponding users. + """ + if self.__chrono_user_session_data is None: + assert self.timestamps is not None # we need timestamps + + session_timestamps = [self.timestamps[ids[0]] for ids in self.sessions.values()] # one-off + + self.__chrono_user_session_data = defaultdict(lambda: ([], [])) + for (sid, ids), t in zip(self.sessions.items(), session_timestamps): + u = self.uir_tuple[0][ids[0]] + self.__chrono_user_session_data[u][0].append(sid) + self.__chrono_user_session_data[u][1].append(t) + + # sorting based on timestamps + for user, (sessions, timestamps) in self.__chrono_user_session_data.items(): + sorted_idx = np.argsort(timestamps) + sorted_sessions = [sessions[i] for i in sorted_idx] + sorted_timestamps = [timestamps[i] for i in sorted_idx] + self.__chrono_user_session_data[user] = ( + sorted_sessions, + sorted_timestamps, + ) + + return self.__chrono_user_session_data + + @classmethod + def build( + cls, + data, + fmt="SIT", + global_uid_map=None, + global_sid_map=None, + global_iid_map=None, + seed=None, + exclude_unknowns=False, + ): + """Constructing Dataset from given data of specific format. + + Parameters + ---------- + data: list, required + Data in the form of tuples (session, item, timestamp) for SIT format, + or tuples (user, session, item, timestamp) for USIT format. + + fmt: str, default: 'SIT' + Format of the input data. Currently, we are supporting: + + 'SIT': Session_ID, Item, Timestamp + 'USIT': User, Session_ID, Item, Timestamp + 'SITJson': Session_ID, Item, Timestamp, Extra data in Json format + 'USITJson': User, Session_ID, Item, Timestamp, Extra data in Json format + + global_uid_map: :obj:`defaultdict`, optional, default: None + The dictionary containing global mapping from original ids to mapped ids of users. 
+ + global_sid_map: :obj:`defaultdict`, optional, default: None + The dictionary containing global mapping from original ids to mapped ids of sessions. + + global_iid_map: :obj:`defaultdict`, optional, default: None + The dictionary containing global mapping from original ids to mapped ids of items. + + seed: int, optional, default: None + Random seed for reproducing data sampling. + + exclude_unknowns: bool, default: False + Ignore unknown users and items. + + Returns + ------- + res: :obj:`` + SequentialDataset object. + + """ + fmt = validate_format(fmt, ["SIT", "USIT", "SITJson", "USITJson"]) + + if global_uid_map is None: + global_uid_map = OrderedDict() + if global_sid_map is None: + global_sid_map = OrderedDict() + if global_iid_map is None: + global_iid_map = OrderedDict() + + u_indices = [] + s_indices = [] + i_indices = [] + valid_idx = [] + extra_data = [] + for idx, tup in enumerate(data): + uid, sid, iid, *_ = tup if fmt in ["USIT", "USITJson"] else [None] + list(tup) + if exclude_unknowns and (iid not in global_iid_map): + continue + global_uid_map.setdefault(uid, len(global_uid_map)) + global_sid_map.setdefault(sid, len(global_sid_map)) + global_iid_map.setdefault(iid, len(global_iid_map)) + + u_indices.append(global_uid_map[uid]) + s_indices.append(global_sid_map[sid]) + i_indices.append(global_iid_map[iid]) + valid_idx.append(idx) + + uir_tuple = ( + np.asarray(u_indices, dtype="int"), + np.asarray(i_indices, dtype="int"), + np.ones(len(u_indices), dtype="float"), + ) + + session_indices = np.asarray(s_indices, dtype="int") + + ts_pos = 3 if fmt in ["USIT", "USITJson"] else 2 + timestamps = ( + np.fromiter((int(data[i][ts_pos]) for i in valid_idx), dtype="int") + if fmt in ["SIT", "SITJson", "USIT", "USITJson"] + else None + ) + + extra_pos = ts_pos + 1 + extra_data = [data[i][extra_pos] for i in valid_idx] if fmt in ["SITJson", "USITJson"] else None + + dataset = cls( + num_users=len(global_uid_map), + num_sessions=len(set(session_indices)), + 
num_items=len(global_iid_map), + uid_map=global_uid_map, + sid_map=global_sid_map, + iid_map=global_iid_map, + uir_tuple=uir_tuple, + session_indices=session_indices, + timestamps=timestamps, + extra_data=extra_data, + seed=seed, + ) + + return dataset + + @classmethod + def from_sit(cls, data, seed=None): + """Constructing Dataset from SIT (Session, Item, Timestamp) triples data. + + Parameters + ---------- + data: list + Data in the form of tuples (session, item, timestamp). + + seed: int, optional, default: None + Random seed for reproducing data sampling. + + Returns + ------- + res: :obj:`` + SequentialDataset object. + + """ + return cls.build(data, fmt="SIT", seed=seed) + + @classmethod + def from_usit(cls, data, seed=None): + """Constructing Dataset from USIT format (User, Session, Item, Timestamp) + + Parameters + ---------- + data: tuple + Data in the form of quadruples (user, session, item, timestamp) + + seed: int, optional, default: None + Random seed for reproducing data sampling. + + Returns + ------- + res: :obj:`` + SequentialDataset object. + + """ + return cls.build(data, fmt="USIT", seed=seed) + + @classmethod + def from_sitjson(cls, data, seed=None): + """Constructing Dataset from SITJson format (Session, Item, Timestamp, Json) + + Parameters + ---------- + data: tuple + Data in the form of tuples (session, item, timestamp, json) + + seed: int, optional, default: None + Random seed for reproducing data sampling. + + Returns + ------- + res: :obj:`` + SequentialDataset object. + + """ + return cls.build(data, fmt="SITJson", seed=seed) + + @classmethod + def from_usitjson(cls, data, seed=None): + """Constructing Dataset from USITJson format (User, Session, Item, Timestamp, Json) + + Parameters + ---------- + data: tuple + Data in the form of tuples (user, session, item, timestamp, json) + + seed: int, optional, default: None + Random seed for reproducing data sampling. + + Returns + ------- + res: :obj:`` + SequentialDataset object. 
+ + """ + return cls.build(data, fmt="USITJson", seed=seed) + + def num_batches(self, batch_size): + """Estimate number of batches per epoch""" + return estimate_batches(len(self.sessions), batch_size) + + def session_iter(self, batch_size=1, shuffle=False): + """Create an iterator over session indices + + Parameters + ---------- + batch_size: int, optional, default = 1 + + shuffle: bool, optional, default: False + If `True`, orders of session_ids will be randomized. If `False`, default orders kept. + + Returns + ------- + iterator : batch of session indices (array of 'int') + + """ + session_indices = np.array(list(self.sessions.keys())) + for batch_ids in self.idx_iter(len(session_indices), batch_size, shuffle): + batch_session_indices = session_indices[batch_ids] + yield batch_session_indices + + def s_iter(self, batch_size=1, shuffle=False): + """Create an iterator over data yielding batch of sessions + + Parameters + ---------- + batch_size: int, optional, default = 1 + + shuffle: bool, optional, default: False + If `True`, orders of sessions will be randomized. If `False`, default orders kept. + + Returns + ------- + iterator : batch of session indices, batch of indices corresponding to session indices + + """ + for batch_session_ids in self.session_iter(batch_size, shuffle): + batch_mapped_ids = [self.sessions[sid] for sid in batch_session_ids] + yield batch_session_ids, batch_mapped_ids + + def si_iter(self, batch_size=1, shuffle=False): + """Create an iterator over data yielding batch of session indices, batch of mapped ids, and batch of sessions' items + + Parameters + ---------- + batch_size: int, optional, default = 1 + + shuffle: bool, optional, default: False + If `True`, orders of triplets will be randomized. If `False`, default orders kept. 
+ + Returns + ------- + iterator : batch of session indices, batch mapped ids, batch of sessions' items (list of list) + + """ + for batch_session_indices, batch_mapped_ids in self.s_iter(batch_size, shuffle): + batch_session_items = [[self.uir_tuple[1][i] for i in ids] for ids in batch_mapped_ids] + yield batch_session_indices, batch_mapped_ids, batch_session_items + + def usi_iter(self, batch_size=1, shuffle=False): + """Create an iterator over data yielding batch of user indices, batch of session indices, batch of mapped ids, and batch of sessions' items + + Parameters + ---------- + batch_size: int, optional, default = 1 + + shuffle: bool, optional, default: False + If `True`, orders of triplets will be randomized. If `False`, default orders kept. + + Returns + ------- + iterator : batch of user indices, batch of session indices (list of list), batch mapped ids (list of list of list), batch of sessions' items (list of list of list) + + """ + for user_indices in self.user_iter(batch_size, shuffle): + batch_sids = [[sid for sid in self.user_session_data[uid]] for uid in user_indices] + batch_mapped_ids = [[self.sessions[sid] for sid in self.user_session_data[uid]] for uid in user_indices] + batch_session_items = [[[self.uir_tuple[1][i] for i in ids] for ids in u_batch_mapped_ids] for u_batch_mapped_ids in batch_mapped_ids] + yield user_indices, batch_sids, batch_mapped_ids, batch_session_items diff --git a/cornac/data/reader.py b/cornac/data/reader.py index 060257415..4efedb0ad 100644 --- a/cornac/data/reader.py +++ b/cornac/data/reader.py @@ -61,6 +61,24 @@ def ubitjson_parser(tokens, **kwargs): ] +def sit_parser(tokens, **kwargs): + return [(tokens[0], tokens[1], int(tokens[2]))] + + +def sitjson_parser(tokens, **kwargs): + return [(tokens[0], tokens[1], int(tokens[2]), ast.literal_eval(tokens[3]))] + + +def usit_parser(tokens, **kwargs): + return [(tokens[0], tokens[1], tokens[2], int(tokens[3]))] + + +def usitjson_parser(tokens, **kwargs): + return [ + 
(tokens[0], tokens[1], tokens[2], int(tokens[3]), ast.literal_eval(tokens[4])) + ] + + PARSERS = { "UI": ui_parser, "UIR": uir_parser, @@ -70,6 +88,10 @@ def ubitjson_parser(tokens, **kwargs): "UBI": ubi_parser, "UBIT": ubit_parser, "UBITJson": ubitjson_parser, + "SIT": sit_parser, + "SITJson": sitjson_parser, + "USIT": usit_parser, + "USITJson": usitjson_parser, } @@ -94,6 +116,14 @@ class Reader: The minimum frequency of an item to be retained. If `min_item_freq = 1`, all items will be included. + num_top_freq_user: int, default = 0 + The number of top popular users to be retained. + If `num_top_freq_user = 0`, all users will be included. + + num_top_freq_item: int, default = 0 + The number of top popular items to be retained. + If `num_top_freq_item = 0`, all items will be included. + min_basket_size: int, default = 1 The minimum number of items of a basket to be retained. If `min_basket_size = 1`, all items will be included. @@ -106,6 +136,14 @@ class Reader: The minimum number of baskets of a user to be retained. If `min_basket_sequence = 1`, all baskets will be included. + min_sequence_size: int, default = 1 + The minimum number of items of a sequence to be retained. + If `min_sequence_size = 1`, all sequences will be included. + + max_sequence_size: int, default = -1 + The maximum number of items of a sequence to be retained. + If `max_sequence_size = -1`, all sequences will be included. + bin_threshold: float, default = None The rating threshold to binarize rating values (turn explicit feedback to implicit feedback). 
For example, if `bin_threshold = 3.0`, all rating values >= 3.0 will be set to 1.0, @@ -126,9 +164,13 @@ def __init__( item_set=None, min_user_freq=1, min_item_freq=1, + num_top_freq_user=0, + num_top_freq_item=0, min_basket_size=1, max_basket_size=-1, min_basket_sequence=1, + min_sequence_size=1, + max_sequence_size=-1, bin_threshold=None, encoding="utf-8", errors=None, @@ -145,48 +187,95 @@ def __init__( ) self.min_uf = min_user_freq self.min_if = min_item_freq + self.num_top_freq_user = num_top_freq_user + self.num_top_freq_item = num_top_freq_item self.min_basket_size = min_basket_size self.max_basket_size = max_basket_size self.min_basket_sequence = min_basket_sequence + self.min_sequence_size = min_sequence_size + self.max_sequence_size = max_sequence_size self.bin_threshold = bin_threshold self.encoding = encoding self.errors = errors - def _filter(self, tuples): - if self.bin_threshold is not None: + def _filter(self, tuples, fmt="UIR"): + i_pos = fmt.find("I") + u_pos = fmt.find("U") + r_pos = fmt.find("R") + + if self.bin_threshold is not None and r_pos >= 0: def binarize(t): t = list(t) - t[2] = 1.0 + t[r_pos] = 1.0 return tuple(t) - tuples = [binarize(t) for t in tuples if t[2] >= self.bin_threshold] + tuples = [binarize(t) for t in tuples if t[r_pos] >= self.bin_threshold] + + if self.num_top_freq_user > 0: + user_freq = Counter(t[u_pos] for t in tuples) + top_freq_users = set( + k for (k, _) in user_freq.most_common(self.num_top_freq_user) + ) + tuples = [t for t in tuples if t[u_pos] in top_freq_users] + + if self.num_top_freq_item > 0: + item_freq = Counter(t[i_pos] for t in tuples) + top_freq_items = set( + k for (k, _) in item_freq.most_common(self.num_top_freq_item) + ) + tuples = [t for t in tuples if t[i_pos] in top_freq_items] if self.user_set is not None: - tuples = [t for t in tuples if t[0] in self.user_set] + tuples = [t for t in tuples if t[u_pos] in self.user_set] if self.item_set is not None: - tuples = [t for t in tuples if t[1] in 
self.item_set] + tuples = [t for t in tuples if t[i_pos] in self.item_set] if self.min_uf > 1: - user_freq = Counter(t[0] for t in tuples) - tuples = [t for t in tuples if user_freq[t[0]] >= self.min_uf] + user_freq = Counter(t[u_pos] for t in tuples) + tuples = [t for t in tuples if user_freq[t[u_pos]] >= self.min_uf] if self.min_if > 1: - item_freq = Counter(t[1] for t in tuples) - tuples = [t for t in tuples if item_freq[t[1]] >= self.min_if] + item_freq = Counter(t[i_pos] for t in tuples) + tuples = [t for t in tuples if item_freq[t[i_pos]] >= self.min_if] + + return tuples + + def _filter_basket(self, tuples, fmt="UBI"): + u_pos = fmt.find("U") + b_pos = fmt.find("B") if self.min_basket_size > 1: - basket_size = Counter(t[1] for t in tuples) - tuples = [t for t in tuples if basket_size[t[1]] >= self.min_basket_size] + sizes = Counter(t[b_pos] for t in tuples) + tuples = [t for t in tuples if sizes[t[b_pos]] >= self.min_basket_size] if self.max_basket_size > 1: - basket_size = Counter(t[1] for t in tuples) - tuples = [t for t in tuples if basket_size[t[1]] <= self.max_basket_size] + sizes = Counter(t[b_pos] for t in tuples) + tuples = [t for t in tuples if sizes[t[b_pos]] <= self.max_basket_size] if self.min_basket_sequence > 1: - basket_sequence = Counter(u for (u, _) in set((t[0], t[1]) for t in tuples)) - tuples = [t for t in tuples if basket_sequence[t[0]] >= self.min_basket_sequence] + basket_sequence = Counter( + u for (u, _) in set((t[u_pos], t[b_pos]) for t in tuples) + ) + tuples = [ + t + for t in tuples + if basket_sequence[t[u_pos]] >= self.min_basket_sequence + ] + + return tuples + + def _filter_sequence(self, tuples, fmt="SIT"): + s_pos = fmt.find("S") + + if self.min_sequence_size > 1: + sizes = Counter(t[s_pos] for t in tuples) + tuples = [t for t in tuples if sizes[t[s_pos]] >= self.min_sequence_size] + + if self.max_sequence_size > 1: + sizes = Counter(t[s_pos] for t in tuples) + tuples = [t for t in tuples if sizes[t[s_pos]] <= 
self.max_sequence_size] return tuples @@ -236,7 +325,7 @@ def read( if parser is None: raise ValueError( "Invalid line format: {}\n" - "Only support: {}".format(fmt, PARSERS.keys()) + "Supported formats: {}".format(fmt, PARSERS.keys()) ) with open(fpath, encoding=self.encoding, errors=self.errors) as f: @@ -247,7 +336,12 @@ def read( line.strip().split(sep), line_idx=idx, id_inline=id_inline, **kwargs ) ] - return self._filter(tuples) + tuples = self._filter(tuples=tuples, fmt=fmt) + if fmt in {"UBI", "UBIT", "UBITJson"}: + tuples = self._filter_basket(tuples=tuples, fmt=fmt) + elif fmt in {"SIT", "SITJson", "USIT", "USITJson"}: + tuples = self._filter_sequence(tuples=tuples, fmt=fmt) + return tuples def read_text(fpath, sep=None, encoding="utf-8", errors=None): diff --git a/cornac/datasets/__init__.py b/cornac/datasets/__init__.py index 5d406171b..ff098baf4 100644 --- a/cornac/datasets/__init__.py +++ b/cornac/datasets/__init__.py @@ -19,7 +19,9 @@ from . import citeulike from . import epinions from . import filmtrust +from . import gowalla from . import movielens from . import netflix from . import tafeng -from . import tradesy \ No newline at end of file +from . import tradesy +from . import yoochoose diff --git a/cornac/datasets/gowalla.py b/cornac/datasets/gowalla.py new file mode 100644 index 000000000..c623fe194 --- /dev/null +++ b/cornac/datasets/gowalla.py @@ -0,0 +1,47 @@ +# Copyright 2023 The Cornac Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +This data is built based on the Gowalla check-in dataset, which contains +the time and location information of check-ins made by users. +Downloaded from https://static.preferred.ai/datasets/gowalla/check-ins.zip +""" + +from ..utils import cache +from ..data import Reader +from typing import List + + +def load_checkins(fmt="USITJson", reader: Reader = None) -> List: + """Load the time and location information of check-ins made by users + + Parameters + ---------- + reader: :obj:`cornac.data.Reader`, default: None + Reader object used to read the data. + + Returns + ------- + data: array-like + Data in the form of a list of tuples (user, session, item, timestamp, json). + Location information is stored in `json` format + """ + fpath = cache( + url="https://static.preferred.ai/datasets/gowalla/check-ins.zip", + unzip=True, + relative_path="gowalla/check-ins.txt", + ) + reader = Reader() if reader is None else reader + return reader.read(fpath, fmt=fmt, sep="\t") + diff --git a/cornac/datasets/yoochoose.py b/cornac/datasets/yoochoose.py new file mode 100644 index 000000000..726feff1d --- /dev/null +++ b/cornac/datasets/yoochoose.py @@ -0,0 +1,93 @@ +# Copyright 2023 The Cornac Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +""" +This data is built based on the YooChoose dataset, provided here as +session-based buy events, click events, and a test set. +Downloaded from https://static.preferred.ai/datasets/yoochoose/ (buy.zip, click.zip, test.zip) +""" + +from typing import List + +from ..data import Reader +from ..utils import cache + + +def load_buy(fmt="SITJson", reader: Reader = None) -> List: + """Load the buy events of the YooChoose dataset + + Parameters + ---------- + reader: :obj:`cornac.data.Reader`, default: None + Reader object used to read the data. + + Returns + ------- + data: array-like + Data in the form of a list of tuples (session, item, timestamp, json) for the default `SITJson` format. + Extra information is stored in `json` format + """ + fpath = cache( + url="https://static.preferred.ai/datasets/yoochoose/buy.zip", + unzip=True, + relative_path="yoochoose/buy.txt", + ) + reader = Reader() if reader is None else reader + return reader.read(fpath, fmt=fmt, sep="\t") + + +def load_click(fmt="SITJson", reader: Reader = None) -> List: + """Load the click events of the YooChoose dataset + + Parameters + ---------- + reader: :obj:`cornac.data.Reader`, default: None + Reader object used to read the data. + + Returns + ------- + data: array-like + Data in the form of a list of tuples (session, item, timestamp, json) for the default `SITJson` format. + Extra information is stored in `json` format + """ + fpath = cache( + url="https://static.preferred.ai/datasets/yoochoose/click.zip", + unzip=True, + relative_path="yoochoose/click.txt", + ) + reader = Reader() if reader is None else reader + return reader.read(fpath, fmt=fmt, sep="\t") + + +def load_test(fmt="SITJson", reader: Reader = None) -> List: + """Load the test data of the YooChoose dataset + + Parameters + ---------- + reader: :obj:`cornac.data.Reader`, default: None + Reader object used to read the data. 
+ + Returns + ------- + data: array-like + Data in the form of a list of tuples (session, item, timestamp, json) for the default `SITJson` format. + Extra information is stored in `json` format + """ + fpath = cache( + url="https://static.preferred.ai/datasets/yoochoose/test.zip", + unzip=True, + relative_path="yoochoose/test.txt", + ) + reader = Reader() if reader is None else reader + return reader.read(fpath, fmt=fmt, sep="\t") diff --git a/cornac/eval_methods/__init__.py b/cornac/eval_methods/__init__.py index bd0d3dfbc..24fb729c5 100644 --- a/cornac/eval_methods/__init__.py +++ b/cornac/eval_methods/__init__.py @@ -21,11 +21,15 @@ from .stratified_split import StratifiedSplit from .cross_validation import CrossValidation from .next_basket_evaluation import NextBasketEvaluation +from .next_item_evaluation import NextItemEvaluation from .propensity_stratified_evaluation import PropensityStratifiedEvaluation -__all__ = ['BaseMethod', - 'RatioSplit', - 'StratifiedSplit', - 'CrossValidation', - 'NextBasketEvaluation', - 'PropensityStratifiedEvaluation'] \ No newline at end of file +__all__ = [ + "BaseMethod", + "RatioSplit", + "StratifiedSplit", + "CrossValidation", + "NextBasketEvaluation", + "NextItemEvaluation", + "PropensityStratifiedEvaluation", +] diff --git a/cornac/eval_methods/base_method.py b/cornac/eval_methods/base_method.py index 7540bf40c..7904a9e4d 100644 --- a/cornac/eval_methods/base_method.py +++ b/cornac/eval_methods/base_method.py @@ -275,8 +275,8 @@ def __init__( self.verbose = verbose self.seed = seed self.rng = get_rng(seed) - self.global_uid_map = OrderedDict() - self.global_iid_map = OrderedDict() + self.global_uid_map = kwargs.get("global_uid_map", OrderedDict()) + self.global_iid_map = kwargs.get("global_iid_map", OrderedDict()) self.user_feature = kwargs.get("user_feature", None) self.user_text = kwargs.get("user_text", None) diff --git a/cornac/eval_methods/next_item_evaluation.py b/cornac/eval_methods/next_item_evaluation.py new file mode 
100644 index 
000000000..bff23f331 --- /dev/null +++ b/cornac/eval_methods/next_item_evaluation.py @@ -0,0 +1,350 @@ +# Copyright 2023 The Cornac Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +from collections import OrderedDict, defaultdict + +import numpy as np +from tqdm.auto import tqdm + +from ..data import SequentialDataset +from ..experiment.result import Result +from . import BaseMethod + + +def ranking_eval( + model, + metrics, + train_set, + test_set, + user_based=False, + exclude_unknowns=True, + verbose=False, +): + """Evaluate model on provided ranking metrics. + + Parameters + ---------- + model: :obj:`cornac.models.NextItemRecommender`, required + NextItemRecommender model to be evaluated. + + metrics: :obj:`iterable`, required + List of ranking metrics :obj:`cornac.metrics.RankingMetric`. + + train_set: :obj:`cornac.data.SequentialDataset`, required + SequentialDataset to be used for model training. This will be used to exclude + observations that already appeared during training. + + test_set: :obj:`cornac.data.SequentialDataset`, required + SequentialDataset to be used for evaluation. + + exclude_unknowns: bool, optional, default: True + Ignore unknown users and items during evaluation. + + verbose: bool, optional, default: False + Output evaluation progress. 
+ + Returns + ------- + res: (List, List) + Tuple of two lists: + - average result for each of the metrics + - average result per user for each of the metrics + + """ + + if len(metrics) == 0: + return [], [] + + avg_results = [] + session_results = [{} for _ in enumerate(metrics)] + user_results = [defaultdict(list) for _ in enumerate(metrics)] + + user_sessions = defaultdict(list) + for [sid], [mapped_ids], [session_items] in tqdm( + test_set.si_iter(batch_size=1, shuffle=False), + total=len(test_set.sessions), + desc="Ranking", + disable=not verbose, + miniters=100, + ): + test_pos_items = session_items[-1:] # last item in the session + if len(test_pos_items) == 0: + continue + user_idx = test_set.uir_tuple[0][mapped_ids[0]] + if user_based: + user_sessions[user_idx].append(sid) + # binary mask for ground-truth positive items + u_gt_pos_mask = np.zeros(test_set.num_items, dtype="int") + u_gt_pos_mask[test_pos_items] = 1 + + # binary mask for ground-truth negative items, removing all positive items + u_gt_neg_mask = np.ones(test_set.num_items, dtype="int") + u_gt_neg_mask[test_pos_items] = 0 + + # filter items being considered for evaluation + if exclude_unknowns: + u_gt_pos_mask = u_gt_pos_mask[: train_set.num_items] + u_gt_neg_mask = u_gt_neg_mask[: train_set.num_items] + + u_gt_pos_items = np.nonzero(u_gt_pos_mask)[0] + u_gt_neg_items = np.nonzero(u_gt_neg_mask)[0] + item_indices = np.nonzero(u_gt_pos_mask + u_gt_neg_mask)[0] + + item_rank, item_scores = model.rank( + user_idx, + item_indices, + history_items=session_items[:-1], + history_mapped_ids=mapped_ids[:-1], + sessions=test_set.sessions, + session_indices=test_set.session_indices, + extra_data=test_set.extra_data, + ) + + for i, mt in enumerate(metrics): + mt_score = mt.compute( + gt_pos=u_gt_pos_items, + gt_neg=u_gt_neg_items, + pd_rank=item_rank, + pd_scores=item_scores, + item_indices=item_indices, + ) + if user_based: + user_results[i][user_idx].append(mt_score) + else: + session_results[i][sid] = 
mt_score + + # avg results of ranking metrics + for i, mt in enumerate(metrics): + if user_based: + user_ids = list(user_sessions.keys()) + user_avg_results = [np.mean(user_results[i][user_idx]) for user_idx in user_ids] + avg_results.append(np.mean(user_avg_results)) + else: + avg_results.append(sum(session_results[i].values()) / len(session_results[i])) + return avg_results, user_results + + +class NextItemEvaluation(BaseMethod): + """Next Item Recommendation Evaluation method + + Parameters + ---------- + data: list, required + Raw preference data in the tuple format [(user_id, sessions)]. + + test_size: float, optional, default: 0.2 + The proportion of the test set, \ + if > 1 then it is treated as the size of the test set. + + val_size: float, optional, default: 0.0 + The proportion of the validation set, \ + if > 1 then it is treated as the size of the validation set. + + fmt: str, default: 'SIT' + Format of the input data. Currently, we are supporting: + + 'SIT': Session, Item, Timestamp + 'USIT': User, Session, Item, Timestamp + 'SITJson': Session, Item, Timestamp, Json + 'USITJson': User, Session, Item, Timestamp, Json + + seed: int, optional, default: None + Random seed for reproducibility. + + exclude_unknowns: bool, optional, default: True + If `True`, unknown items will be ignored during model evaluation. + + verbose: bool, optional, default: False + Output running log. 
+ + """ + + def __init__( + self, + data=None, + test_size=0.2, + val_size=0.0, + fmt="SIT", + seed=None, + exclude_unknowns=True, + verbose=False, + **kwargs, + ): + super().__init__( + data=data, + data_size=0 if data is None else len(data), + test_size=test_size, + val_size=val_size, + fmt=fmt, + seed=seed, + exclude_unknowns=exclude_unknowns, + verbose=verbose, + **kwargs, + ) + self.global_sid_map = kwargs.get("global_sid_map", OrderedDict()) + + def _build_datasets(self, train_data, test_data, val_data=None): + self.train_set = SequentialDataset.build( + data=train_data, + fmt=self.fmt, + global_uid_map=self.global_uid_map, + global_iid_map=self.global_iid_map, + global_sid_map=self.global_sid_map, + seed=self.seed, + exclude_unknowns=False, + ) + if self.verbose: + print("---") + print("Training data:") + print("Number of users = {}".format(self.train_set.num_users)) + print("Number of items = {}".format(self.train_set.num_items)) + print("Number of sessions = {}".format(self.train_set.num_sessions)) + + self.test_set = SequentialDataset.build( + data=test_data, + fmt=self.fmt, + global_uid_map=self.global_uid_map, + global_iid_map=self.global_iid_map, + global_sid_map=self.global_sid_map, + seed=self.seed, + exclude_unknowns=self.exclude_unknowns, + ) + if self.verbose: + print("---") + print("Test data:") + print("Number of users = {}".format(len(self.test_set.uid_map))) + print("Number of items = {}".format(len(self.test_set.iid_map))) + print("Number of sessions = {}".format(self.test_set.num_sessions)) + print("Number of unknown users = {}".format(self.test_set.num_users - self.train_set.num_users)) + print("Number of unknown items = {}".format(self.test_set.num_items - self.train_set.num_items)) + + if val_data is not None and len(val_data) > 0: + self.val_set = SequentialDataset.build( + data=val_data, + fmt=self.fmt, + global_uid_map=self.global_uid_map, + global_iid_map=self.global_iid_map, + seed=self.seed, + exclude_unknowns=self.exclude_unknowns, 
+ ) + if self.verbose: + print("---") + print("Validation data:") + print("Number of users = {}".format(len(self.val_set.uid_map))) + print("Number of items = {}".format(len(self.val_set.iid_map))) + print("Number of sessions = {}".format(self.val_set.num_sessions)) + + self.total_sessions = 0 if self.val_set is None else self.val_set.num_sessions + self.total_sessions += self.test_set.num_sessions + self.train_set.num_sessions + if self.verbose: + print("---") + print("Total users = {}".format(self.total_users)) + print("Total items = {}".format(self.total_items)) + print("Total sessions = {}".format(self.total_sessions)) + + @staticmethod + def eval( + model, + train_set, + test_set, + exclude_unknowns, + ranking_metrics, + user_based=False, + verbose=False, + **kwargs, + ): + metric_avg_results = OrderedDict() + metric_user_results = OrderedDict() + + avg_results, user_results = ranking_eval( + model=model, + metrics=ranking_metrics, + train_set=train_set, + test_set=test_set, + user_based=user_based, + exclude_unknowns=exclude_unknowns, + verbose=verbose, + ) + + for i, mt in enumerate(ranking_metrics): + metric_avg_results[mt.name] = avg_results[i] + metric_user_results[mt.name] = user_results[i] + + return Result(model.name, metric_avg_results, metric_user_results) + + @classmethod + def from_splits( + cls, + train_data, + test_data, + val_data=None, + fmt="SIT", + exclude_unknowns=False, + seed=None, + verbose=False, + **kwargs, + ): + """Constructing evaluation method given data. + + Parameters + ---------- + train_data: array-like + Training data + + test_data: array-like + Test data + + val_data: array-like, optional, default: None + Validation data + + fmt: str, default: 'SIT' + Format of the input data. 
Currently, we are supporting: + + 'SIT': Session, Item, Timestamp + 'USIT': User, Session, Item, Timestamp + 'SITJson': Session, Item, Timestamp, Json + 'USITJson': User, Session, Item, Timestamp, Json + + rating_threshold: float, default: 1.0 + Threshold to decide positive or negative preferences. + + exclude_unknowns: bool, default: False + Whether to exclude unknown users/items in evaluation. + + seed: int, optional, default: None + Random seed for reproduce the splitting. + + verbose: bool, default: False + The verbosity flag. + + Returns + ------- + method: :obj:`` + Evaluation method object. + + """ + method = cls( + fmt=fmt, + exclude_unknowns=exclude_unknowns, + seed=seed, + verbose=verbose, + **kwargs, + ) + + return method.build( + train_data=train_data, + test_data=test_data, + val_data=val_data, + ) diff --git a/cornac/models/__init__.py b/cornac/models/__init__.py index e1432c968..40fe1740a 100644 --- a/cornac/models/__init__.py +++ b/cornac/models/__init__.py @@ -15,6 +15,7 @@ from .recommender import Recommender from .recommender import NextBasketRecommender +from .recommender import NextItemRecommender from .amr import AMR from .ann import AnnoyANN @@ -68,6 +69,7 @@ from .sbpr import SBPR from .skm import SKMeans from .sorec import SoRec +from .spop import SPop from .svd import SVD from .tifuknn import TIFUKNN from .trirank import TriRank diff --git a/cornac/models/recommender.py b/cornac/models/recommender.py index 6e174ce1a..62a17f7f4 100644 --- a/cornac/models/recommender.py +++ b/cornac/models/recommender.py @@ -13,20 +13,19 @@ # limitations under the License. # ============================================================================ -import os import copy import inspect +import os import pickle import warnings -from glob import glob from datetime import datetime +from glob import glob import numpy as np from ..exception import ScoreException from ..utils.common import clip - MEASURE_L2 = "l2 distance aka. 
class NextItemRecommender(Recommender):
    """Base class for next-item recommendation models.

    Concrete next-item recommenders are expected to subclass this class and
    override :meth:`score`.

    Parameters
    ----------
    name: str, required
        Name of the recommender model.

    trainable: boolean, optional, default: True
        When False, the model is not trainable.

    verbose: boolean, optional, default: False
        When True, running logs are displayed.

    Attributes
    ----------
    num_users: int
        Number of users in training data.

    num_items: int
        Number of items in training data.

    total_users: int
        Number of users in training, validation, and test data.
        In other words, this includes unknown/unseen users.

    total_items: int
        Number of items in training, validation, and test data.
        In other words, this includes unknown/unseen items.

    uid_map: mapping
        Global mapping of user ID-index.

    iid_map: mapping
        Global mapping of item ID-index.
    """

    def __init__(self, name, trainable=True, verbose=False):
        # Nothing specific to next-item models is set up here; everything is
        # delegated to the parent initializer.
        super(NextItemRecommender, self).__init__(
            name=name,
            trainable=trainable,
            verbose=verbose,
        )

    def score(self, user_idx, history_items, **kwargs):
        """Predict the scores of all items given the history items.

        This base implementation always raises; subclasses provide the logic.

        Parameters
        ----------
        user_idx: int
            Index of the user the prediction is made for.

        history_items: list
            The history items, in sequential order, used for next-item prediction.

        **kwargs:
            Additional model-specific inputs.

        Returns
        -------
        res : a Numpy array
            Relative scores of all known items

        Raises
        ------
        NotImplementedError
            Always, unless overridden by a subclass.

        """
        message = "The algorithm is not able to make score prediction!"
        raise NotImplementedError(message)
class SPop(NextItemRecommender):
    """Recommend most popular items of the current session.

    Parameters
    ----------
    name: string, default: 'SPop'
        The name of the recommender model.

    use_session_popularity: boolean, optional, default: True
        When False, item frequencies from the history items of the current
        session are not used; only training-set popularity is.

    References
    ----------
    Balázs Hidasi, Alexandros Karatzoglou, Linas Baltrunas, Domonkos Tikk:
    Session-based Recommendations with Recurrent Neural Networks, ICLR 2016
    """

    def __init__(self, name="SPop", use_session_popularity=True):
        super().__init__(name=name, trainable=False)
        self.use_session_popularity = use_session_popularity
        # item index -> number of occurrences in the training data
        self.item_freq = Counter()

    def fit(self, train_set, val_set=None):
        """Count how often each item index occurs in the training interactions.

        Returns
        -------
        self : object
        """
        super().fit(train_set=train_set, val_set=val_set)
        self.item_freq = Counter(self.train_set.uir_tuple[1])
        return self

    def score(self, user_idx, history_items, **kwargs):
        """Score every item by global popularity plus in-session frequency.

        Parameters
        ----------
        user_idx: int
            Index of the user; not used by this model.

        history_items: list
            Item indices already seen in the current session.

        Returns
        -------
        res : a Numpy array
            Array of length `total_items`. Each item gets its training
            frequency divided by the maximum training frequency (0 for items
            never seen in training), plus its number of occurrences in
            `history_items` when `use_session_popularity` is True.
        """
        # Start from zero: an item absent from the training data must not be
        # scored like the most popular one (whose normalized score is 1.0).
        item_scores = np.zeros(self.total_items, dtype=np.float32)
        max_item_freq = max(self.item_freq.values(), default=1)
        for iid, freq in self.item_freq.items():
            item_scores[iid] = freq / max_item_freq
        if self.use_session_popularity:
            for iid, cnt in Counter(history_items).items():
                item_scores[iid] += cnt
        return item_scores
class TestSequentialDataset(unittest.TestCase):
    """Checks statistics and ID mappings of a dataset built from USIT tuples."""

    def setUp(self):
        # Fixture rows are read as (user, session, item, timestamp), space separated.
        reader = Reader()
        self.sequential_data = reader.read("./tests/sequence.txt", fmt="USIT", sep=" ")

    def test_init(self):
        dataset = SequentialDataset.from_usit(self.sequential_data)

        # size statistics
        expected_stats = (
            ("num_sessions", 16),
            ("max_session_size", 6),
            ("min_session_size", 2),
            ("num_users", 5),
            ("num_items", 9),
        )
        for attr_name, expected in expected_stats:
            self.assertEqual(getattr(dataset, attr_name), expected)

        # the first raw user/session/item ID is mapped to index 0
        self.assertEqual(dataset.uid_map["1"], 0)
        self.assertEqual(dataset.sid_map["1"], 0)
        self.assertEqual(dataset.iid_map["1"], 0)

        # raw IDs are preserved
        expected_users = {"1", "2", "3", "4", "5"}
        self.assertSetEqual(set(dataset.user_ids), expected_users)

        expected_items = {"1", "2", "3", "4", "5", "6", "7", "8", "9"}
        self.assertSetEqual(set(dataset.item_ids), expected_items)
class TestNextItemEvaluation(unittest.TestCase):
    """Tests for building and running `NextItemEvaluation` from fixed splits."""

    def setUp(self):
        self.data = Reader().read("./tests/sequence.txt", fmt="USIT", sep=" ")

    def test_from_splits(self):
        next_item_eval = NextItemEvaluation.from_splits(train_data=self.data[:50], test_data=self.data[50:], fmt="USIT")

        # dedicated assertions report the offending value on failure
        self.assertIsNotNone(next_item_eval.train_set)
        self.assertIsNotNone(next_item_eval.test_set)
        self.assertIsNone(next_item_eval.val_set)
        self.assertEqual(next_item_eval.total_sessions, 16)

    def test_evaluate(self):
        next_item_eval = NextItemEvaluation.from_splits(train_data=self.data[:50], test_data=self.data[50:], fmt="USIT")
        result = next_item_eval.evaluate(SPop(), [HitRatio(k=2), Recall(k=2)], user_based=False)
        self.assertEqual(result[0].metric_avg_results.get("HitRatio@2"), 0)
        self.assertEqual(result[0].metric_avg_results.get("Recall@2"), 0)

        next_item_eval = NextItemEvaluation.from_splits(train_data=self.data[:50], test_data=self.data[50:], fmt="USIT")
        result = next_item_eval.evaluate(SPop(), [HitRatio(k=5), Recall(k=5)], user_based=True)
        # averages are floats: compare with a tolerance rather than exactly
        self.assertAlmostEqual(result[0].metric_avg_results.get("HitRatio@5"), 2 / 3)
        self.assertAlmostEqual(result[0].metric_avg_results.get("Recall@5"), 2 / 3)
class TestNextItemRecommender(unittest.TestCase):
    """Smoke tests for the `NextItemRecommender` base class and the `SPop` model."""

    def setUp(self):
        self.data = Reader().read("./tests/sequence.txt", fmt="USIT", sep=" ")

    def test_init(self):
        model = NextItemRecommender("test")
        self.assertEqual(model.name, "test")

    def test_fit(self):
        dataset = SequentialDataset.from_usit(self.data)

        # the base class can be fitted even though it cannot score
        model = NextItemRecommender("")
        model.fit(dataset)

        model = SPop()
        model.fit(dataset)
        model.score(0, [])
        result = model.rank(0, history_items=[])
        # compare as lists so that a failure shows the differing ranks
        self.assertListEqual(list(result[0]), [3, 2, 4, 1, 0, 5, 8, 7, 6])
16 2 882606631