From 8bcdd70609028623acf6f2909401a940ae5a5232 Mon Sep 17 00:00:00 2001 From: anpolol Date: Fri, 15 Dec 2023 19:04:37 +0300 Subject: [PATCH] added test for link prediction --- tests/test_general/test_link_prediction.py | 25 ++++++++++++++-------- tutorials/link_prediction.py | 13 ----------- 2 files changed, 16 insertions(+), 22 deletions(-) delete mode 100644 tutorials/link_prediction.py diff --git a/tests/test_general/test_link_prediction.py b/tests/test_general/test_link_prediction.py index eed870a..202013d 100644 --- a/tests/test_general/test_link_prediction.py +++ b/tests/test_general/test_link_prediction.py @@ -1,13 +1,20 @@ -import pathlib +import torch_geometric.transforms as T +from stable_gnn.model_link_predict import ModelLinkPrediction +from torch_geometric.datasets import Planetoid +from sklearn.ensemble import GradientBoostingClassifier +import pytest -import torch +@pytest.mark.parametrize("conv", ["SAGE", "GAT", "GCN"]) +@pytest.mark.parametrize("loss_name", ["APP", "LINE", "HOPE_AA", "VERSE_Adj"]) +def test_linkpredict(loss_name: str, conv: str) -> None: + root = '../tmp/' + name = 'Cora' + dataset = Planetoid(root=root + str(name), name=name, transform=T.NormalizeFeatures()) -from stable_gnn.graph import Graph -from stable_gnn.pipelines.graph_classification_pipeline import TrainModelGC -from tests.data_generators import generate_gc_graphs + model = ModelLinkPrediction(number_of_trials=50, loss_name=loss_name, emb_conv_name=conv) -root = str(pathlib.Path(__file__).parent.resolve().joinpath("data_validation/")) + "/" -generate_gc_graphs(root, 30) + train_edges, train_negative, test_edges, test_negative = model.train_test_edges(dataset) -def test_linkpredict(): - pass \ No newline at end of file + cl_before = GradientBoostingClassifier(n_estimators=100, learning_rate=0.2, max_depth=5, random_state=0) + cl_after = model.train_cl(train_edges, train_negative) + assert (model.test(cl_before, test_edges, test_negative)) < (model.test(cl_after, test_edges, test_negative)) \ No newline at end of file diff --git a/tutorials/link_prediction.py b/tutorials/link_prediction.py deleted file mode 100644 index fffc03a..0000000 --- a/tutorials/link_prediction.py +++ /dev/null @@ -1,13 +0,0 @@ -import torch_geometric.transforms as T -from torch_geometric.datasets import Planetoid -from stable_gnn.model_link_predict import ModelLinkPrediction -from stable_gnn.graph import Graph -import pathlib - - -if __name__ == "__main__": - - data = Planetoid(root="/tmp/" + str("name"), name="Cora", transform=T.NormalizeFeatures()) - model = ModelLinkPrediction(data, number_of_trials=1) - clf = model.train_cl() - print("f1 measure", (model.test())) \ No newline at end of file