Skip to content

Commit

Permalink
add as_tuple in torch.nonzero
Browse files Browse the repository at this point in the history
  • Loading branch information
weihua916 committed Apr 6, 2021
1 parent a593a61 commit c8da898
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion examples/graphproppred/code2/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def decode_arr_to_seq(arr, idx2vocab):
'''


eos_idx_list = torch.nonzero(arr == len(idx2vocab) - 1) # find the position of __EOS__ (the last vocab in idx2vocab)
eos_idx_list = torch.nonzero(arr == len(idx2vocab) - 1, as_tuple=False) # find the position of __EOS__ (the last vocab in idx2vocab)
if len(eos_idx_list) > 0:
clippted_arr = arr[: torch.min(eos_idx_list)] # find the smallest __EOS__
else:
Expand Down
2 changes: 1 addition & 1 deletion ogb/linkproppred/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def _eval_mrr(self, y_pred_pos, y_pred_neg, type_info):
if type_info == 'torch':
y_pred = torch.cat([y_pred_pos.view(-1,1), y_pred_neg], dim = 1)
argsort = torch.argsort(y_pred, dim = 1, descending = True)
ranking_list = torch.nonzero(argsort == 0)
ranking_list = torch.nonzero(argsort == 0, as_tuple=False)
ranking_list = ranking_list[:, 1] + 1
hits1_list = (ranking_list <= 1).to(torch.float)
hits3_list = (ranking_list <= 3).to(torch.float)
Expand Down
2 changes: 1 addition & 1 deletion ogb/lsc/wikikg90m.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def _calculate_mrr(self, correct_index, pred_top10):
- pred_top10: shape (num_eval_triplets, 10)
'''
# extract indices where correct_index is within top10
tmp = torch.nonzero(correct_index.view(-1,1) == pred_top10)
tmp = torch.nonzero(correct_index.view(-1,1) == pred_top10, as_tuple=False)

# reciprocal rank
# if rank is larger than 10, then set the reciprocal rank to 0.
Expand Down

0 comments on commit c8da898

Please sign in to comment.