Add option to evaluate every next item in next-item evaluation (#580)
lthoang authored Jan 12, 2024
1 parent f2d44ce commit aed8ece
Showing 2 changed files with 175 additions and 42 deletions.
202 changes: 160 additions & 42 deletions cornac/eval_methods/next_item_evaluation.py
@@ -13,15 +13,22 @@
# limitations under the License.
# ============================================================================

import time
import warnings
from collections import OrderedDict, defaultdict

import numpy as np
from tqdm.auto import tqdm

from ..data import SequentialDataset
from ..experiment.result import Result
from ..models import NextItemRecommender
from . import BaseMethod

EVALUATION_MODES = frozenset([
"last",
"next",
])

def ranking_eval(
model,
@@ -30,6 +37,7 @@ def ranking_eval(
test_set,
user_based=False,
exclude_unknowns=True,
mode="last",
verbose=False,
):
"""Evaluate model on provided ranking metrics.
@@ -68,62 +76,68 @@ def ranking_eval(
return [], []

avg_results = []
session_results = [{} for _ in enumerate(metrics)]
session_results = [defaultdict(list) for _ in enumerate(metrics)]
user_results = [defaultdict(list) for _ in enumerate(metrics)]

user_sessions = defaultdict(list)
session_ids = []
for [sid], [mapped_ids], [session_items] in tqdm(
test_set.si_iter(batch_size=1, shuffle=False),
total=len(test_set.sessions),
desc="Ranking",
disable=not verbose,
miniters=100,
):
test_pos_items = session_items[-1:] # last item in the session
if len(test_pos_items) == 0:
if len(session_items) < 2: # exclude all sessions with fewer than 2 items
continue
user_idx = test_set.uir_tuple[0][mapped_ids[0]]
if user_based:
user_sessions[user_idx].append(sid)
# binary mask for ground-truth positive items
u_gt_pos_mask = np.zeros(test_set.num_items, dtype="int")
u_gt_pos_mask[test_pos_items] = 1

# binary mask for ground-truth negative items, removing all positive items
u_gt_neg_mask = np.ones(test_set.num_items, dtype="int")
u_gt_neg_mask[test_pos_items] = 0

# filter items being considered for evaluation
if exclude_unknowns:
u_gt_pos_mask = u_gt_pos_mask[: train_set.num_items]
u_gt_neg_mask = u_gt_neg_mask[: train_set.num_items]

u_gt_pos_items = np.nonzero(u_gt_pos_mask)[0]
u_gt_neg_items = np.nonzero(u_gt_neg_mask)[0]
item_indices = np.nonzero(u_gt_pos_mask + u_gt_neg_mask)[0]

item_rank, item_scores = model.rank(
user_idx,
item_indices,
history_items=session_items[:-1],
history_mapped_ids=mapped_ids[:-1],
sessions=test_set.sessions,
session_indices=test_set.session_indices,
extra_data=test_set.extra_data,
)

for i, mt in enumerate(metrics):
mt_score = mt.compute(
gt_pos=u_gt_pos_items,
gt_neg=u_gt_neg_items,
pd_rank=item_rank,
pd_scores=item_scores,
item_indices=item_indices,
session_ids.append(sid)

start_pos = 1 if mode == "next" else len(session_items) - 1
for test_pos in range(start_pos, len(session_items), 1):
test_pos_items = session_items[test_pos]

# binary mask for ground-truth positive items
u_gt_pos_mask = np.zeros(test_set.num_items, dtype="int")
u_gt_pos_mask[test_pos_items] = 1

# binary mask for ground-truth negative items, removing all positive items
u_gt_neg_mask = np.ones(test_set.num_items, dtype="int")
u_gt_neg_mask[test_pos_items] = 0

# filter items being considered for evaluation
if exclude_unknowns:
u_gt_pos_mask = u_gt_pos_mask[: train_set.num_items]
u_gt_neg_mask = u_gt_neg_mask[: train_set.num_items]

u_gt_pos_items = np.nonzero(u_gt_pos_mask)[0]
u_gt_neg_items = np.nonzero(u_gt_neg_mask)[0]
item_indices = np.nonzero(u_gt_pos_mask + u_gt_neg_mask)[0]

item_rank, item_scores = model.rank(
user_idx,
item_indices,
history_items=session_items[:test_pos],
history_mapped_ids=mapped_ids[:test_pos],
sessions=test_set.sessions,
session_indices=test_set.session_indices,
extra_data=test_set.extra_data,
)
if user_based:
user_results[i][user_idx].append(mt_score)
else:
session_results[i][sid] = mt_score

for i, mt in enumerate(metrics):
mt_score = mt.compute(
gt_pos=u_gt_pos_items,
gt_neg=u_gt_neg_items,
pd_rank=item_rank,
pd_scores=item_scores,
item_indices=item_indices,
)
if user_based:
user_results[i][user_idx].append(mt_score)
else:
session_results[i][sid].append(mt_score)

# avg results of ranking metrics
for i, mt in enumerate(metrics):
@@ -132,7 +146,8 @@ def ranking_eval(
user_avg_results = [np.mean(user_results[i][user_idx]) for user_idx in user_ids]
avg_results.append(np.mean(user_avg_results))
else:
avg_results.append(sum(session_results[i].values()) / len(session_results[i]))
session_result = [score for sid in session_ids for score in session_results[i][sid]]
avg_results.append(np.mean(session_result))
return avg_results, user_results
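
To make the new mode behavior concrete, here is a minimal standalone sketch (not part of the diff) of how the start_pos logic above chooses prediction targets; evaluated_positions is a hypothetical helper, and a session is assumed to be a plain list of item indices.

def evaluated_positions(session_items, mode="last"):
    """Return the positions of a test session that are scored as prediction targets."""
    if len(session_items) < 2:  # sessions shorter than 2 items are skipped, as above
        return []
    # mirrors: start_pos = 1 if mode == "next" else len(session_items) - 1
    start_pos = 1 if mode == "next" else len(session_items) - 1
    return list(range(start_pos, len(session_items)))

# Example with a session of four items:
session = [10, 42, 7, 99]
assert evaluated_positions(session, mode="last") == [3]        # only the last item
assert evaluated_positions(session, mode="next") == [1, 2, 3]  # every next item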


@@ -163,6 +178,11 @@ class NextItemEvaluation(BaseMethod):
seed: int, optional, default: None
Random seed for reproducibility.
mode: str, optional, default: 'last'
Evaluation mode is either 'next' or 'last'.
If 'last', only evaluate the last item of each sequence.
If 'next', evaluate every next item in the sequence.
exclude_unknowns: bool, optional, default: True
If `True`, unknown items will be ignored during model evaluation.
@@ -178,6 +198,7 @@ def __init__(
val_size=0.0,
fmt="SIT",
seed=None,
mode="last",
exclude_unknowns=True,
verbose=False,
**kwargs,
@@ -191,8 +212,14 @@ def __init__(
seed=seed,
exclude_unknowns=exclude_unknowns,
verbose=verbose,
mode=mode,
**kwargs,
)

if mode not in EVALUATION_MODES:
raise ValueError(f"{mode} is not supported. ({EVALUATION_MODES})")

self.mode = mode
self.global_sid_map = kwargs.get("global_sid_map", OrderedDict())

def _build_datasets(self, train_data, test_data, val_data=None):
@@ -263,6 +290,7 @@ def eval(
ranking_metrics,
user_based=False,
verbose=False,
mode="last",
**kwargs,
):
metric_avg_results = OrderedDict()
@@ -275,6 +303,7 @@ def eval(
test_set=test_set,
user_based=user_based,
exclude_unknowns=exclude_unknowns,
mode=mode,
verbose=verbose,
)

@@ -284,6 +313,95 @@ def eval(

return Result(model.name, metric_avg_results, metric_user_results)

def evaluate(self, model, metrics, user_based, show_validation=True):
"""Evaluate given models according to given metrics. Supposed to be called by Experiment.
Parameters
----------
model: :obj:`cornac.models.NextItemRecommender`
NextItemRecommender model to be evaluated.
metrics: :obj:`iterable`
List of metrics.
user_based: bool, required
Evaluation strategy for the rating metrics. Whether results
are averaging based on number of users or number of ratings.
show_validation: bool, optional, default: True
Whether to show the results on validation set (if exists).
Returns
-------
res: :obj:`cornac.experiment.Result`
"""
if not isinstance(model, NextItemRecommender):
raise ValueError("model must be a NextItemRecommender but '%s' is provided" % type(model))

if self.train_set is None:
raise ValueError("train_set is required but None!")
if self.test_set is None:
raise ValueError("test_set is required but None!")

self._reset()

###########
# FITTING #
###########
if self.verbose:
print("\n[{}] Training started!".format(model.name))

start = time.time()
model.fit(self.train_set, self.val_set)
train_time = time.time() - start

##############
# EVALUATION #
##############
if self.verbose:
print("\n[{}] Evaluation started!".format(model.name))

rating_metrics, ranking_metrics = self.organize_metrics(metrics)
if len(rating_metrics) > 0:
warnings.warn("NextItemEvaluation only supports ranking metrics. The given rating metrics {} will be ignored!".format([mt.name for mt in rating_metrics]))

start = time.time()
model.transform(self.test_set)
test_result = self.eval(
model=model,
train_set=self.train_set,
test_set=self.test_set,
val_set=self.val_set,
exclude_unknowns=self.exclude_unknowns,
ranking_metrics=ranking_metrics,
user_based=user_based,
mode=self.mode,
verbose=self.verbose,
)
test_time = time.time() - start
test_result.metric_avg_results["Train (s)"] = train_time
test_result.metric_avg_results["Test (s)"] = test_time

val_result = None
if show_validation and self.val_set is not None:
start = time.time()
model.transform(self.val_set)
val_result = self.eval(
model=model,
train_set=self.train_set,
test_set=self.val_set,
val_set=None,
exclude_unknowns=self.exclude_unknowns,
ranking_metrics=ranking_metrics,
user_based=user_based,
mode=self.mode,
verbose=self.verbose,
)
val_time = time.time() - start
val_result.metric_avg_results["Time (s)"] = val_time

return test_result, val_result

@classmethod
def from_splits(
cls,
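
Before the updated tests below, a hedged usage sketch of the new option, assuming the public API exercised by those tests (NextItemEvaluation.from_splits, SPop, HitRatio, Recall); data is a placeholder list of USIT tuples, not a real dataset, and the import paths are assumptions.

from cornac.eval_methods import NextItemEvaluation
from cornac.metrics import HitRatio, Recall
from cornac.models import SPop

# `data` is assumed to be a list of (user, session, item, timestamp) tuples.
next_item_eval = NextItemEvaluation.from_splits(
    train_data=data[:50],
    test_data=data[50:],
    fmt="USIT",
    mode="next",  # score every next item instead of only the last one
)
test_result, val_result = next_item_eval.evaluate(
    SPop(), [HitRatio(k=2), Recall(k=2)], user_based=False
)
print(test_result.metric_avg_results)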
15 changes: 15 additions & 0 deletions tests/cornac/eval_methods/test_next_item_evaluation.py
@@ -40,13 +40,28 @@ def test_evaluate(self):
)
self.assertEqual(result[0].metric_avg_results.get('HitRatio@2'), 0)
self.assertEqual(result[0].metric_avg_results.get('Recall@2'), 0)

next_item_eval = NextItemEvaluation.from_splits(train_data=self.data[:50], test_data=self.data[50:], fmt="USIT")
result = next_item_eval.evaluate(
SPop(), [HitRatio(k=5), Recall(k=5)], user_based=True
)
self.assertEqual(result[0].metric_avg_results.get('HitRatio@5'), 2/3)
self.assertEqual(result[0].metric_avg_results.get('Recall@5'), 2/3)

next_item_eval = NextItemEvaluation.from_splits(train_data=self.data[:50], test_data=self.data[50:], fmt="USIT", mode="next")
result = next_item_eval.evaluate(
SPop(), [HitRatio(k=2), Recall(k=2)], user_based=False
)

self.assertEqual(result[0].metric_avg_results.get('HitRatio@2'), 1/8)
self.assertEqual(result[0].metric_avg_results.get('Recall@2'), 1/8)

next_item_eval = NextItemEvaluation.from_splits(train_data=self.data[:50], test_data=self.data[50:], fmt="USIT", mode="next")
result = next_item_eval.evaluate(
SPop(), [HitRatio(k=5), Recall(k=5)], user_based=True
)
self.assertEqual(result[0].metric_avg_results.get('HitRatio@5'), 3/4)
self.assertEqual(result[0].metric_avg_results.get('Recall@5'), 3/4)

if __name__ == "__main__":
unittest.main()
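
For readers checking the new expected values, a small sketch (with hypothetical scores, not the test data) of how ranking_eval aggregates per-position results: with user_based=False all position-level scores are pooled and averaged, while with user_based=True each user's scores are averaged first and those per-user means are then averaged.

import numpy as np

# Hypothetical per-position HitRatio scores, pooled per session and per user.
scores_by_session = {"s1": [1.0, 0.0, 0.0], "s2": [0.0]}
pooled_avg = np.mean([s for scores in scores_by_session.values() for s in scores])  # 0.25

scores_by_user = {"u1": [1.0, 0.0], "u2": [1.0]}
user_avg = np.mean([np.mean(scores) for scores in scores_by_user.values()])  # 0.75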
