From 73782c926be66fd110a38d5a3297394e8a1718f5 Mon Sep 17 00:00:00 2001 From: sungho-ham <19978686+sungho-ham@users.noreply.github.com> Date: Wed, 7 Aug 2024 05:10:34 +0900 Subject: [PATCH] Fix recall at k when batch size = 1 (#779) Co-authored-by: sungho-ham --- transformers4rec/torch/ranking_metric.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformers4rec/torch/ranking_metric.py b/transformers4rec/torch/ranking_metric.py index 5495b98a4..a281dd83b 100644 --- a/transformers4rec/torch/ranking_metric.py +++ b/transformers4rec/torch/ranking_metric.py @@ -131,7 +131,7 @@ def _metric(self, ks: torch.Tensor, scores: torch.Tensor, labels: torch.Tensor) # Compute recalls at K num_relevant = torch.sum(labels, dim=-1) - rel_indices = (num_relevant != 0).nonzero().squeeze() + rel_indices = (num_relevant != 0).nonzero().squeeze(dim=1) rel_count = num_relevant[rel_indices] if rel_indices.shape[0] > 0: