diff --git a/README.md b/README.md index ee3a460..17d5cad 100644 --- a/README.md +++ b/README.md @@ -172,7 +172,7 @@ Performance on `v1.0` val (trained on `v1.0` train): | R@1 | R@5 | R@10 | MeanR | MRR | | ------ | ------ | ------ | ------ | ------ | -| 0.4194 | 0.7345 | 0.8387 | 5.9876 | 0.5650 | +| 0.4298 | 0.7464 | 0.8491 | 5.4874 | 0.5757 | Acknowledgements diff --git a/utils/eval_utils.py b/utils/eval_utils.py index 547bf1c..568773d 100644 --- a/utils/eval_utils.py +++ b/utils/eval_utils.py @@ -3,15 +3,9 @@ def get_gt_ranks(ranks, ans_ind): ans_ind = ans_ind.view(-1) - num_opts = 100 - ranks = ranks.view(-1, num_opts) gt_ranks = torch.LongTensor(ans_ind.size(0)) for i in range(ans_ind.size(0)): - gt_binary = torch.zeros(num_opts) - gt_binary[ans_ind[i]] = 1 - sorted_gt = gt_binary.index_select(0, ranks[i].sort()[1].cpu()) - gt_rank = (sorted_gt == 1).nonzero() + 1 - gt_ranks[i] = int(gt_rank) # gt_rank is 1x1 LongTensor + gt_ranks[i] = int(ranks[i, ans_ind[i]]) return gt_ranks