Skip to content

Commit

Permalink
now embeddings can be constructed for geom gcn and for general purpos…
Browse files Browse the repository at this point in the history
…e + saving of embeddings added
  • Loading branch information
anpolol committed Dec 7, 2023
1 parent bebc389 commit 1e5efda
Show file tree
Hide file tree
Showing 8 changed files with 54 additions and 149 deletions.
11 changes: 6 additions & 5 deletions stable_gnn/embedding/embedding_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ class EmbeddingFactory:
"""Producing unsupervised embeddings for a given dataset"""

@staticmethod
def _build_embeddings(loss: Dict[str, Any], data: Graph, conv: str, device: device,number_of_trials: int) -> NDArray:
optuna_training = OptunaTrainEmbeddings(data=data, conv=conv, device=device, loss_function=loss)
def _build_embeddings(loss: Dict[str, Any], data: Graph, conv: str, device: device,number_of_trials: int, tune_out: bool=False) -> NDArray:
optuna_training = OptunaTrainEmbeddings(data=data, conv=conv, device=device, loss_function=loss,tune_out=tune_out)
best_values = optuna_training.run(number_of_trials=number_of_trials)

loss_trgt = dict()
Expand All @@ -28,7 +28,7 @@ def _build_embeddings(loss: Dict[str, Any], data: Graph, conv: str, device: devi
if "lmbda" in loss_trgt:
loss_trgt["lmbda"] = best_values["lmbda"]

model_training = ModelTrainEmbeddings(data=data, conv=conv, device=device, loss_function=loss_trgt)
model_training = ModelTrainEmbeddings(data=data, conv=conv, device=device, loss_function=loss_trgt,tune_out=tune_out)
out = model_training.run(best_values)
torch.cuda.empty_cache()
return out.detach().cpu().numpy()
Expand Down Expand Up @@ -77,16 +77,17 @@ def _get_emb_settings(loss_name: str) -> Dict[str, Any]:
else:
raise NameError

def build_embeddings(self, loss_name: str, conv: str, data: Graph, device: device, number_of_trials: int) -> NDArray:
def build_embeddings(self, loss_name: str, conv: str, data: Graph, device: device, number_of_trials: int, tune_out: bool=False) -> NDArray:
"""Build embeddings based on passed dataset and settings
:param loss_name: (str): Name of loss function for embedding learning in GeomGCN layer
:param conv: (str): Name of convolution used in unsupervised embeddings
:param data: (Graph): Input Graph
:param device: (device): Device 'cuda' or 'cpu'
:param number_of_trials: (int): Number of trials for optuna tuning embeddings
:param tune_out: (bool): Whether to tune the size of the output layer of the embeddings (default: False)
:returns: (NDArray) embeddings NumPy array of (N_nodes) x (N_emb_dim)
"""
loss_params = self._get_emb_settings(loss_name)
emb = self._build_embeddings(loss=loss_params, data=data[0], conv=conv, device=device, number_of_trials=number_of_trials)
emb = self._build_embeddings(loss=loss_params, data=data[0], conv=conv, device=device, number_of_trials=number_of_trials, tune_out=tune_out)
return emb
14 changes: 11 additions & 3 deletions stable_gnn/embedding/model_train_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,18 @@ class ModelTrainEmbeddings:
:param loss_function: (dict): Dict of parameters of unsupervised loss function
:param conv: (str): Name of convolution (default:'GCN')
:param device: (device): Either 'cuda' or 'cpu' (default:'cuda')
:param tune_out: (bool): Whether to tune the size of the output layer; if False, the output size is fixed at 2, as used for GeomGCN (default: False)
"""

def __init__(self, data: Graph, loss_function: Dict, device: device, conv: str = "GCN") -> None:
def __init__(self, data: Graph, loss_function: Dict, device: device, conv: str = "GCN", tune_out: bool=False) -> None:
self.conv = conv
self.device = device
self.x = data.x
self.y = data.y.squeeze()
self.data = data.to(device)
self.train_mask = torch.Tensor([True] * data.num_nodes)
self.loss = loss_function
self.tune_out = tune_out
super(ModelTrainEmbeddings, self).__init__()

def _sampling(self, sampler: BaseSampler, epoch: int, nodes: Tensor) -> None:
Expand Down Expand Up @@ -76,7 +78,10 @@ def run(self, params: Dict) -> Tensor:
:return: (Tensor): The output embeddings
"""
hidden_layer = params["hidden_layer"]
out_layer = params["out_layer"]
if self.tune_out:
out_layer = params["out_layer"]
else:
out_layer=2
dropout = params["dropout"]
size = params["size of network, number of convs"]
learning_rate = params["lr"]
Expand Down Expand Up @@ -121,7 +126,10 @@ def _objective(self, trial: Trial) -> Tensor:
dropout = trial.suggest_float("dropout", 0.0, 0.5, step=0.1)
size = trial.suggest_categorical("size of network, number of convs", [1, 2, 3])
learning_rate = trial.suggest_float("lr", 5e-3, 1e-2)
out_layer = trial.suggest_categorical("out_layer", [32, 64, 128])
if self.tune_out:
out_layer = trial.suggest_categorical("out_layer", [32, 64, 128])
else:
out_layer = 2

loss_to_train = {}
for name in self.loss:
Expand Down
1 change: 0 additions & 1 deletion stable_gnn/embedding/sampling/abstract_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ def _sample_negative(self, batch: Tensor, num_negative_samples: int) -> Tensor:
:param num_negative_samples: (int): number of negative samples for each edge
:return: (Tensor): Negative samples
"""
print("self device", self.device, batch)
a, _ = subgraph(batch, self.data.edge_index.to(self.device))
adj = self._adj_list(a)
g = dict()
Expand Down
1 change: 0 additions & 1 deletion stable_gnn/geom_gcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ def _virtual_vertex(self, edge_index: Tensor, x: Union[Tensor, OptPairTensor]) -
x = (x, x)
graph_size = max(edge_index[0].max(), edge_index[1].max()) + 1
deg = degree(edge_index[0], graph_size)
print("degree", deg)
(
edge_index_s_ur,
edge_index_s_ul,
Expand Down
2 changes: 1 addition & 1 deletion stable_gnn/model_link_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(
self.neg_samples_train = self._neg_samples(train_edges, self.data)

self.embeddings = EmbeddingFactory().build_embeddings(
loss_name=loss_name, conv=emb_conv_name, data=dataset, device=device, number_of_trials=number_of_trials
loss_name=loss_name, conv=emb_conv_name, data=dataset, device=device, number_of_trials=number_of_trials, tune_out=True
)

def _train_test_edges(self, data: Graph) -> (List[int], List[int]):
Expand Down
15 changes: 11 additions & 4 deletions stable_gnn/model_nc.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from stable_gnn.embedding import EmbeddingFactory
from stable_gnn.geom_gcn import GeomGCN
from stable_gnn.graph import Graph

import os

class ModelNodeClassification(torch.nn.Module):
"""
Expand Down Expand Up @@ -52,9 +52,15 @@ def __init__(
if self.ssl_flag:
self.deg = degree(self.data[0].edge_index[0], self.data[0].num_nodes)

embeddings = EmbeddingFactory().build_embeddings(
loss_name=loss_name, conv=emb_conv_name, data=dataset, device=device
)
path = '../tutorials/embeddings_'+str(loss_name)+'_'+str(emb_conv_name)+'.npy'
if os.path.exists(path):
embeddings = np.load(path)
else:

embeddings = EmbeddingFactory().build_embeddings(
loss_name=loss_name, conv=emb_conv_name, data=dataset, device=device, number_of_trials=50
)
np.save(path,embeddings)

if self.num_layers == 1:
self.convs.append(
Expand Down Expand Up @@ -109,6 +115,7 @@ def inference(self, data: Graph) -> Tuple[Tensor, Tensor]:
x = conv(x, edge_index)
if i != self.num_layers - 1:
x = x.relu()
x = self.linear(x)
x = self.linear_classifier(x)
deg_pred = 0
if self.ssl_flag:
Expand Down
16 changes: 1 addition & 15 deletions tests/test_general/test_link_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,4 @@
generate_gc_graphs(root, 30)

def test_linkpredict():
name = "ba_gc"
data = Graph(root=root + name + "/", name=name, adjust_flag=False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

ssl_flag = False
extrapolate_flag = True

#######
best_values = {"hidden_layer": 64, "size of network, number of convs": 2, "dropout": 0.0, "lr": 0.01, "coef": 10}
model_training = TrainModelGC(data=data, device=device, ssl_flag=ssl_flag, extrapolate_flag=extrapolate_flag)

model, train_acc_mi, train_acc_ma, test_acc_mi, test_acc_ma = model_training.run(best_values)
torch.save(model, "model.pt")
assert train_acc_mi >= test_acc_mi
assert train_acc_ma >= test_acc_ma
pass
Loading

0 comments on commit 1e5efda

Please sign in to comment.