diff --git a/ogb/lsc/wikikg90mv2.py b/ogb/lsc/wikikg90mv2.py index 3b26924c..d560299c 100644 --- a/ogb/lsc/wikikg90mv2.py +++ b/ogb/lsc/wikikg90mv2.py @@ -236,7 +236,7 @@ def save_test_submission(self, input_dict: Dict, dir_path: str, mode: str): t_pred_top10 = input_dict['h,r->t']['t_pred_top10'] for i in range(len(t_pred_top10)): - assert len(pd.unique(t_pred_top10[i])) == len(t_pred_top10[i]), 'Found duplicated tail prediction for some triplets!' + assert len(pd.unique(t_pred_top10[i][t_pred_top10[i] >= 0])) == len(t_pred_top10[i][t_pred_top10[i] >= 0]), 'Found duplicated tail prediction for some triplets!' if mode == 'test-dev': assert t_pred_top10.shape == (15000, 10)