Add a class `MyGraphLayer()` in the file `my_graph_layer.py` in the `layers/` directory. A standard template is:
```python
import torch
import torch.nn as nn


class MyGraphLayer(nn.Module):

    def __init__(self, in_dim, out_dim, dropout):
        super().__init__()
        # write your code here

    def forward(self, x_with_all_info):
        # write your code here
        # which operates on the dense
        # input tensor x_with_all_info
        return x_out
```
The `layers/` directory contains the layer classes of all graph networks, as well as standard layers such as the MLP used for readout.
For instance, the RingGNN layer class `RingGNNEquivLayer()` is defined in `layers/ring_gnn_equiv_layer.py`.
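As an illustration only, a minimal dense layer could look like the following. It assumes `x_with_all_info` has shape `(batch, in_dim, n, n)` (for example, node features on the diagonal and edge features off it) and sums three terms that are all equivariant to node relabelling: the entry itself, its row mean and its column mean, each mixed across channels by a 1x1 convolution. This is a simplified stand-in, not the RingGNN layer.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MyGraphLayer(nn.Module):
    """Illustrative dense layer; assumes input of shape (batch, in_dim, n, n)."""

    def __init__(self, in_dim, out_dim, dropout):
        super().__init__()
        # 1x1 convolutions mix channels only, treating every (i, j) entry alike.
        self.mix_self = nn.Conv2d(in_dim, out_dim, kernel_size=1)
        self.mix_rows = nn.Conv2d(in_dim, out_dim, kernel_size=1, bias=False)
        self.mix_cols = nn.Conv2d(in_dim, out_dim, kernel_size=1, bias=False)
        self.dropout = dropout

    def forward(self, x_with_all_info):
        x = x_with_all_info
        row_mean = x.mean(dim=3, keepdim=True)  # (batch, in_dim, n, 1)
        col_mean = x.mean(dim=2, keepdim=True)  # (batch, in_dim, 1, n)
        # Broadcasting adds the row and column terms to every entry of the n x n grid.
        x_out = self.mix_self(x) + self.mix_rows(row_mean) + self.mix_cols(col_mean)
        x_out = F.relu(x_out)
        return F.dropout(x_out, p=self.dropout, training=self.training)
```

Permuting the nodes of the input permutes the rows and columns of the output in the same way, which is the property a WL-style dense layer needs.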
Add a class `MyGraphNetwork()` in the file `my_gcn_net.py` in the `nets/` directory. The `loss()` function of the network is also defined in `MyGraphNetwork()`.
```python
import torch
import torch.nn as nn

from layers.my_graph_layer import MyGraphLayer


class MyGraphNetwork(nn.Module):

    def __init__(self, net_params):
        super().__init__()
        # write your code here
        # (placeholder keys; use the ones defined in your net_params)
        self.layer = MyGraphLayer(net_params['in_dim'], net_params['out_dim'], net_params['dropout'])

    def forward(self, x_with_all_info):
        # write your code here
        # which operates on the dense
        # input tensor x_with_all_info
        return x_out

    def loss(self, pred, label):
        # write your loss function here
        return loss
```
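A minimal sketch of a complete network, assuming the same `(batch, in_dim, n, n)` input: it stacks `L` layers, averages over all n x n entries for a graph-level readout, and uses an L1 (mean absolute error) loss as one might for a regression task such as ZINC. A trivial stand-in `MyGraphLayer` is included so the block runs on its own, and the `net_params` keys `in_dim` and `dropout` are hypothetical (they do not appear in the example config).

```python
import torch
import torch.nn as nn


class MyGraphLayer(nn.Module):
    # Minimal stand-in so this sketch runs on its own; use your real layer instead.
    def __init__(self, in_dim, out_dim, dropout):
        super().__init__()
        self.conv = nn.Conv2d(in_dim, out_dim, kernel_size=1)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        return self.drop(torch.relu(self.conv(x)))


class MyGraphNetwork(nn.Module):
    def __init__(self, net_params):
        super().__init__()
        # Hypothetical keys: 'in_dim' and 'dropout' must be added to net_params.
        in_dim = net_params['in_dim']
        hidden_dim = net_params['hidden_dim']
        dropout = net_params.get('dropout', 0.0)
        dims = [in_dim] + [hidden_dim] * net_params['L']
        self.layers = nn.ModuleList(
            MyGraphLayer(d_in, d_out, dropout) for d_in, d_out in zip(dims[:-1], dims[1:])
        )
        self.readout = nn.Sequential(
            nn.Linear(hidden_dim, net_params['out_dim']),
            nn.ReLU(),
            nn.Linear(net_params['out_dim'], 1),
        )

    def forward(self, x_with_all_info):
        x = x_with_all_info
        for layer in self.layers:
            x = layer(x)
        # Graph-level readout: average over all n x n entries, then an MLP.
        hg = x.mean(dim=(2, 3))
        return self.readout(hg).squeeze(-1)

    def loss(self, pred, label):
        # Mean absolute error, a natural choice for graph regression.
        return nn.functional.l1_loss(pred, label)
```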
Register the name `MyGNN` for the new graph network class in the file `load_net.py` in the `nets/` directory.
```python
from nets.my_gcn_net import MyGraphNetwork


def MyGNN(net_params):
    return MyGraphNetwork(net_params)


def gnn_model(MODEL_NAME, net_params):
    models = {
        'MyGNN': MyGNN
    }
    return models[MODEL_NAME](net_params)
```
For the ZINC example, `RingGNNNet()` in `nets/molecules_graph_regression/ring_gnn_net.py` is registered under the GNN name `RingGNN` in `nets/molecules_graph_regression/load_net.py`.
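The registry is a plain dictionary from the `"model"` string of the config to a constructor. Below is a sketch of the same pattern with an explicit error for unregistered names; the `ValueError` check is an addition for illustration, and the stub class only stands in for the real `MyGraphNetwork`.

```python
class MyGraphNetwork:
    # Stand-in for nets.my_gcn_net.MyGraphNetwork so this sketch runs on its own.
    def __init__(self, net_params):
        self.net_params = net_params


def MyGNN(net_params):
    return MyGraphNetwork(net_params)


def gnn_model(MODEL_NAME, net_params):
    models = {
        'MyGNN': MyGNN,
    }
    # Fail with a readable message instead of a bare KeyError.
    if MODEL_NAME not in models:
        raise ValueError(f"Unknown model {MODEL_NAME!r}; registered: {sorted(models)}")
    return models[MODEL_NAME](net_params)
```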
Add a file `train_data_my_new_task.py` in the `train/` directory that defines the training and evaluation loops.
```python
def train_epoch_dense(model, optimizer, device, data_loader, epoch, batch_size):
    model.train()
    # write your code here
    # Note: we use gradient accumulation w.r.t. the
    # batch_size during training, since the usual
    # batching approach for MP-GCNs operating on
    # sparse tensors does not apply to WL-GNNs
    return train_loss, train_acc


def evaluate_network_dense(model, device, data_loader):
    model.eval()
    # write your code here
    return test_loss, test_acc
```
For ZINC, these loops are defined in `train/train_molecules_graph_regression.py`.
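Because the loader yields one dense graph at a time, a mini-batch can be emulated by accumulating gradients over `batch_size` graphs before each optimizer step. A minimal sketch of both loops, assuming the loader yields `(x_with_all_info, label)` pairs and the model is a regression network exposing `loss()`; the returned "accuracy" is taken here to be the MAE. The actual ZINC loops may unpack each batch differently.

```python
import torch


def train_epoch_dense(model, optimizer, device, data_loader, epoch, batch_size):
    """Sketch: assumes data_loader yields (x_with_all_info, label), one graph per step."""
    model.train()
    optimizer.zero_grad()
    epoch_loss, epoch_mae, n_steps, n_graphs = 0.0, 0.0, 0, 0
    for x, label in data_loader:  # `epoch` is unused here; kept for signature parity
        x, label = x.to(device), label.to(device)
        pred = model(x)
        loss = model.loss(pred, label)
        # Scale so the accumulated gradient is the average over batch_size graphs.
        (loss / batch_size).backward()
        n_steps += 1
        if n_steps % batch_size == 0:
            optimizer.step()
            optimizer.zero_grad()
        epoch_loss += loss.item()
        epoch_mae += (pred - label).abs().sum().item()
        n_graphs += label.numel()
    # Apply the leftover gradient of a trailing partial batch (still scaled by 1/batch_size).
    if n_steps % batch_size != 0:
        optimizer.step()
        optimizer.zero_grad()
    return epoch_loss / max(n_steps, 1), epoch_mae / max(n_graphs, 1)


@torch.no_grad()
def evaluate_network_dense(model, device, data_loader):
    model.eval()
    total_loss, total_mae, n_steps, n_graphs = 0.0, 0.0, 0, 0
    for x, label in data_loader:
        x, label = x.to(device), label.to(device)
        pred = model(x)
        total_loss += model.loss(pred, label).item()
        total_mae += (pred - label).abs().sum().item()
        n_steps += 1
        n_graphs += label.numel()
    return total_loss / max(n_steps, 1), total_mae / max(n_graphs, 1)
```

Dividing each per-graph loss by `batch_size` makes the accumulated gradient match what a single step on the averaged batch loss would produce, while only one graph tensor is held in memory at a time.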
Add a new notebook `main_my_new_task.ipynb` or Python script `main_my_new_task.py` for the new task.
```python
import torch.optim as optim
from torch.utils.data import DataLoader

from nets.load_net import gnn_model
from data.data import LoadData
from train.train_data_my_new_task import train_epoch_dense as train_epoch, evaluate_network_dense as evaluate_network

DATASET_NAME = 'MY_DATASET'
dataset = LoadData(DATASET_NAME)

MODEL_NAME = 'MyGNN'
# net_params, params and device are set up from the JSON config in configs/
model = gnn_model(MODEL_NAME, net_params).to(device)
optimizer = optim.Adam(model.parameters(), lr=params['init_lr'])

train_loader = DataLoader(dataset.train, shuffle=True, collate_fn=dataset.collate_dense_gnn)
for epoch in range(params['epochs']):
    epoch_train_loss, epoch_train_acc = train_epoch(model, optimizer, device, train_loader, epoch, params['batch_size'])
```
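The main script needs a `device` and a fixed random seed before the model is built. A minimal sketch of one way to derive both from the `"gpu"` and `"params"` blocks of the JSON config; the helper name `setup_run` is hypothetical and not part of the repository.

```python
import random

import torch


def setup_run(config):
    """Hypothetical helper: seed RNGs and pick a device from a parsed config dict."""
    seed = config['params']['seed']
    random.seed(seed)
    torch.manual_seed(seed)  # seeds the CPU and all CUDA devices

    gpu = config['gpu']
    if gpu['use'] and torch.cuda.is_available():
        return torch.device(f"cuda:{gpu['id']}")
    # Fall back to the CPU when the GPU is disabled or unavailable.
    return torch.device('cpu')
```

For example, `device = setup_run(config)` would be called once, before `gnn_model(...)`.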
The Python file `main_my_new_task.py` can be generated by saving the notebook `main_my_new_task.ipynb` as a regular Python file. (We actually developed the new graph network in the notebook and then converted the `.ipynb` to `.py`, but it can also be written directly as a `.py` file.)
For ZINC, the main file is `main_molecules_graph_regression.ipynb` or `main_molecules_graph_regression.py`.
The code can be executed in the notebook `main_my_new_task.ipynb` or in a terminal with the command:
```
python main_my_new_task.py --dataset DATASET_NAME --gpu_id 0 --config 'configs/my_new_task_MyGNN_DATASET_NAME.json'
```
The training and network parameters for the dataset and the network are stored in a JSON file in the `configs/` directory.
```json
{
    "gpu": {
        "use": true,
        "id": 0
    },

    "model": "MyGNN",
    "dataset": "DATASET_NAME",

    "out_dir": "out/my_new_task/",

    "params": {
        "seed": 41,
        "epochs": 1000,
        "batch_size": 128,
        "init_lr": 0.001
    },

    "net_params": {
        "L": 4,
        "hidden_dim": 70,
        "out_dim": 70,
        "residual": true
    }
}
```
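A sketch of how a main script could read this file and let the `--dataset` and `--gpu_id` flags from the command line override it. The flag names come from the commands shown in this section; the merging logic and the `load_config` name are assumptions, and the actual main files may handle this differently.

```python
import argparse
import json


def load_config(argv):
    """Sketch: read the JSON config, then let command-line flags override it."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', required=True, help='path to the JSON config file')
    parser.add_argument('--dataset', help='overrides "dataset" in the config')
    parser.add_argument('--gpu_id', type=int, help='overrides gpu.id in the config')
    args = parser.parse_args(argv)

    with open(args.config) as f:
        config = json.load(f)
    # Only flags that were actually given replace config values.
    if args.dataset is not None:
        config['dataset'] = args.dataset
    if args.gpu_id is not None:
        config['gpu']['id'] = args.gpu_id
    return config
```

In a script, `config = load_config(None)` makes `argparse` read `sys.argv`; the returned dict then supplies `net_params`, `params` and the GPU settings.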
For ZINC, the config is `molecules_graph_regression_RingGNN_ZINC_100k.json` and the code is run with:
```
python main_molecules_graph_regression.py --dataset ZINC --gpu_id 0 --config 'configs/molecules_graph_regression_RingGNN_ZINC_100k.json'
```