From c8da898a0f86351aaf8721cf1e90d358d04703d3 Mon Sep 17 00:00:00 2001 From: Weihua Hu Date: Tue, 6 Apr 2021 16:43:37 -0700 Subject: [PATCH] add as_tuple in torch.nonzero --- examples/graphproppred/code2/utils.py | 2 +- ogb/linkproppred/evaluate.py | 2 +- ogb/lsc/wikikg90m.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/graphproppred/code2/utils.py b/examples/graphproppred/code2/utils.py index ce58b9ec..0a25abe4 100644 --- a/examples/graphproppred/code2/utils.py +++ b/examples/graphproppred/code2/utils.py @@ -170,7 +170,7 @@ def decode_arr_to_seq(arr, idx2vocab): ''' - eos_idx_list = torch.nonzero(arr == len(idx2vocab) - 1) # find the position of __EOS__ (the last vocab in idx2vocab) + eos_idx_list = torch.nonzero(arr == len(idx2vocab) - 1, as_tuple=False) # find the position of __EOS__ (the last vocab in idx2vocab) if len(eos_idx_list) > 0: clippted_arr = arr[: torch.min(eos_idx_list)] # find the smallest __EOS__ else: diff --git a/ogb/linkproppred/evaluate.py b/ogb/linkproppred/evaluate.py index 6ab4f5b4..67de3df4 100644 --- a/ogb/linkproppred/evaluate.py +++ b/ogb/linkproppred/evaluate.py @@ -232,7 +232,7 @@ def _eval_mrr(self, y_pred_pos, y_pred_neg, type_info): if type_info == 'torch': y_pred = torch.cat([y_pred_pos.view(-1,1), y_pred_neg], dim = 1) argsort = torch.argsort(y_pred, dim = 1, descending = True) - ranking_list = torch.nonzero(argsort == 0) + ranking_list = torch.nonzero(argsort == 0, as_tuple=False) ranking_list = ranking_list[:, 1] + 1 hits1_list = (ranking_list <= 1).to(torch.float) hits3_list = (ranking_list <= 3).to(torch.float) diff --git a/ogb/lsc/wikikg90m.py b/ogb/lsc/wikikg90m.py index 701acb2d..abd616b8 100644 --- a/ogb/lsc/wikikg90m.py +++ b/ogb/lsc/wikikg90m.py @@ -201,7 +201,7 @@ def _calculate_mrr(self, correct_index, pred_top10): - pred_top10: shape (num_eval_triplets, 10) ''' # extract indices where correct_index is within top10 - tmp = torch.nonzero(correct_index.view(-1,1) == pred_top10) + tmp = torch.nonzero(correct_index.view(-1,1) == pred_top10, as_tuple=False) # reciprocal rank # if rank is larger than 10, then set the reciprocal rank to 0.