Skip to content

Commit

Permalink
added test for link prediction
Browse files Browse the repository at this point in the history
  • Loading branch information
anpolol committed Dec 15, 2023
1 parent 7e141c8 commit 8bcdd70
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 22 deletions.
25 changes: 16 additions & 9 deletions tests/test_general/test_link_prediction.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
import pathlib
import torch_geometric.transforms as T
from stable_gnn.model_link_predict import ModelLinkPrediction
from torch_geometric.datasets import Planetoid
from sklearn.ensemble import GradientBoostingClassifier
import pytest

import torch
@pytest.mark.parametrize("conv", ["SAGE", "GAT", "GCN"])
@pytest.mark.parametrize("loss_name", ["APP", "LINE", "HOPE_AA", "VERSE_Adj"])
def test_linkpredict(loss_name: str, conv: str) -> None:
root = '../tmp/'
name = 'Cora'
dataset = Planetoid(root=root + str(name), name=name, transform=T.NormalizeFeatures())

from stable_gnn.graph import Graph
from stable_gnn.pipelines.graph_classification_pipeline import TrainModelGC
from tests.data_generators import generate_gc_graphs
model = ModelLinkPrediction(number_of_trials=50, loss_name=loss_name, emb_conv_name=conv)

root = str(pathlib.Path(__file__).parent.resolve().joinpath("data_validation/")) + "/"
generate_gc_graphs(root, 30)
train_edges, train_negative, test_edges, test_negative = model.train_test_edges(dataset)

def test_linkpredict():
pass
cl_before = GradientBoostingClassifier(n_estimators=100, learning_rate=0.2, max_depth=5, random_state=0)
cl_after = model.train_cl(train_edges, train_negative)
assert (model.test(cl_before, test_edges, test_negative)) < (model.test(cl_after, test_edges, test_negative))
13 changes: 0 additions & 13 deletions tutorials/link_prediction.py

This file was deleted.

0 comments on commit 8bcdd70

Please sign in to comment.