Merge pull request #60 from BIMSBbioinfo/gnn_dev
add different GNN models
borauyar authored Mar 4, 2024
2 parents faa836d + d39bbb9 commit c0ebf36
Showing 6 changed files with 206 additions and 21 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -7,6 +7,8 @@ __pycache__/
*.so

# Distribution / packaging
core.*
*.pth
9606.protein.aliases.v12.0.txt
9606.protein.links.v12.0.txt
*.tgz
26 changes: 22 additions & 4 deletions flexynesis/__main__.py
@@ -14,7 +14,10 @@ def main():
formatter_class=argparse.ArgumentDefaultsHelpFormatter)

parser.add_argument("--data_path", help="(Required) Path to the folder with train/test data files", type=str, required = True)
parser.add_argument("--model_class", help="(Required) The kind of model class to instantiate", type=str, choices=["DirectPred", "DirectPredGCNN", "supervised_vae", "MultiTripletNetwork"], required = True)
parser.add_argument("--model_class", help="(Required) The kind of model class to instantiate", type=str,
choices=["DirectPred", "DirectPredGCNN", "supervised_vae", "MultiTripletNetwork"], required = True)
parser.add_argument("--gnn_conv_type", help="If model_class is set to DirectPredGCNN, choose which graph convolution type to use", type=str,
choices=["GC", "GCN", "GAT", "GIN", "SAGE", "CHEB"])
parser.add_argument("--target_variables",
help="(Optional if survival variables are not set to None)."
"Which variables in 'clin.csv' to use for predictions, comma-separated if multiple",
@@ -56,7 +59,7 @@ def main():
# 2. Check for required variables for model classes
if args.model_class != "supervised_vae":
if not any([args.target_variables, args.surv_event_var, args.batch_variables]):
parser.error(''.join(["When selecting a model other than 'supervised_vae',"
parser.error(''.join(["When selecting a model other than 'supervised_vae',",
"you must provide at least one of --target_variables, ",
"survival variables (--surv_event_var and --surv_time_var)",
"or --batch_variables."]))
@@ -82,7 +85,21 @@ def main():
else:
device_type = 'cpu'
torch.set_num_threads(args.threads)


# 5. check GNN arguments
if args.model_class == 'DirectPredGCNN':
if not args.gnn_conv_type:
warning_message = "\n".join([
"\n\n!!! When running DirectPredGCNN, a convolution type can be set",
"with the --gnn_conv_type flag. See `flexynesis -h` for full set of options.",
"Falling back on the default convolution type: GC !!!\n\n"
])
warnings.warn(warning_message)
time.sleep(3) #wait a bit to capture user's attention to the warning
gnn_conv_type = 'GC'
else:
gnn_conv_type = args.gnn_conv_type

# Validate paths
if not os.path.exists(args.data_path):
raise FileNotFoundError(f"Input --data_path doesn't exist at:", {args.data_path})
@@ -144,7 +161,8 @@ class AvailableModels(NamedTuple):
n_iter=int(args.hpo_iter),
use_loss_weighting = args.use_loss_weighting == 'True',
early_stop_patience = int(args.early_stop_patience),
device_type = device_type)
device_type = device_type,
gnn_conv_type = gnn_conv_type)

# do a hyperparameter search training multiple models and get the best_configuration
model, best_params = tuner.perform_tuning()
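For reference, a hypothetical command line built only from the options visible in this diff (the data path and target variable are placeholders, and any other required flags are omitted) might look like:

flexynesis --data_path ./my_dataset \
           --model_class DirectPredGCNN \
           --gnn_conv_type GAT \
           --target_variables my_target

Per the check added above, omitting --gnn_conv_type with DirectPredGCNN emits a warning and falls back to the default convolution type 'GC'.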
8 changes: 6 additions & 2 deletions flexynesis/config.py
@@ -39,6 +39,10 @@
Integer(64, 512, name="hidden_dim"),
Real(0.0001, 0.01, prior="log-uniform", name="lr"),
Integer(32, 128, name="batch_size"),
Categorical(epochs, name="epochs")
],
Categorical(epochs, name="epochs"),
# below parameters apply to all convolution types except for 'GC'
Real(0.1, 0.4, prior="log-uniform", name="dropout"),
Integer(1, 3, name="number_layers"),
Categorical(['relu', 'sigmoid', 'tanh'], name="activation")
]
}
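As a minimal sketch of how the extended DirectPredGCNN search space behaves (assuming these are the scikit-optimize skopt.space dimension classes used elsewhere in config.py; the epochs list here is a placeholder for the one defined higher up in the file), each dimension can be sampled independently:

from skopt.space import Categorical, Integer, Real

epochs = [200]  # placeholder for the epochs list defined earlier in config.py
space = [
    Integer(64, 512, name="hidden_dim"),
    Real(0.0001, 0.01, prior="log-uniform", name="lr"),
    Integer(32, 128, name="batch_size"),
    Categorical(epochs, name="epochs"),
    # the parameters below apply to all convolution types except for 'GC'
    Real(0.1, 0.4, prior="log-uniform", name="dropout"),
    Integer(1, 3, name="number_layers"),
    Categorical(['relu', 'sigmoid', 'tanh'], name="activation"),
]
# draw one random value per dimension to inspect a candidate configuration
candidate = {dim.name: dim.rvs(n_samples=1, random_state=42)[0] for dim in space}
print(candidate)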
6 changes: 4 additions & 2 deletions flexynesis/main.py
@@ -50,7 +50,7 @@ def __init__(self, dataset, model_class, config_name, target_variables,
batch_variables = None, surv_event_var = None, surv_time_var = None,
n_iter = 10, config_path = None, plot_losses = False,
val_size = 0.2, use_loss_weighting = True, early_stop_patience = -1,
device_type = None):
device_type = None, gnn_conv_type = None):
self.dataset = dataset
self.model_class = model_class
self.target_variables = target_variables
@@ -71,6 +71,7 @@ def __init__(self, dataset, model_class, config_name, target_variables,
progress_bar_finished='red'))
self.early_stop_patience = early_stop_patience
self.use_loss_weighting = use_loss_weighting
self.gnn_conv_type = gnn_conv_type

# If config_path is provided, use it
if config_path:
@@ -93,7 +94,8 @@ def objective(self, params, current_step, total_steps):
surv_time_var = self.surv_time_var,
val_size = self.val_size,
use_loss_weighting = self.use_loss_weighting,
device_type = self.device_type)
device_type = self.device_type,
gnn_conv_type = self.gnn_conv_type)
print(params)

mycallbacks = [self.progress_bar]
42 changes: 29 additions & 13 deletions flexynesis/models/direct_pred_gcnn.py
@@ -13,7 +13,7 @@

from captum.attr import IntegratedGradients

from ..modules import GCNN, MLP, cox_ph_loss
from ..modules import GCNN, MLP, cox_ph_loss, GraphNNs


class DirectPredGCNN(pl.LightningModule):
@@ -27,7 +27,8 @@ def __init__(
surv_time_var=None,
val_size=0.2,
use_loss_weighting=True,
device_type = None
device_type = None,
gnn_conv_type = None
):
super().__init__()
self.config = config
@@ -47,7 +48,8 @@ def __init__(self,
self.use_loss_weighting = use_loss_weighting

self.device_type = device_type

self.gnn_conv_type = gnn_conv_type

if self.use_loss_weighting:
# Initialize log variance parameters for uncertainty weighting
self.log_vars = nn.ParameterDict()
@@ -59,16 +61,30 @@ def __init__(self,
# NOTE: For now we use matrices, so number of node input features is 1.
input_dims = [1 for _ in range(len(layers))]

self.encoders = nn.ModuleList(
[
GCNN(
input_dim=input_dims[i],
hidden_dim=int(self.config["hidden_dim"]), # int because of pyg
output_dim=self.config["latent_dim"],
)
for i in range(len(layers))
]
)
if self.gnn_conv_type == 'GCNN':
self.encoders = nn.ModuleList(
[
GCNN(
input_dim=input_dims[i],
hidden_dim=int(self.config["hidden_dim"]), # int because of pyg
output_dim=self.config["latent_dim"],
)
for i in range(len(layers))
])
else:
self.encoders = nn.ModuleList(
[
GraphNNs(
input_dim=input_dims[i],
hidden_dim=int(self.config["hidden_dim"]), # int because of pyg
output_dim=self.config["latent_dim"],
act = self.config['activation'],
number_layers = self.config['number_layers'],
dropout = self.config['dropout'],
conv = self.gnn_conv_type
)
for i in range(len(layers))
])

# Init output layers
self.MLPs = nn.ModuleDict()
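A standalone sketch of the encoder-selection branch introduced above, with placeholder stand-ins for the config, layer list, and convolution type (only keys that appear in this diff are used; layers mimics the data modalities the model iterates over):

import torch.nn as nn
from flexynesis.modules import GCNN, GraphNNs

layers = ["mutation", "rna"]   # placeholder data modalities
config = {"hidden_dim": 128, "latent_dim": 64, "activation": "relu",
          "number_layers": 2, "dropout": 0.2}
gnn_conv_type = "SAGE"

# one input feature per node, as noted in the diff
input_dims = [1 for _ in range(len(layers))]

if gnn_conv_type == 'GCNN':
    # the original GCNN encoder
    encoders = nn.ModuleList([
        GCNN(input_dim=input_dims[i],
             hidden_dim=int(config["hidden_dim"]),
             output_dim=config["latent_dim"])
        for i in range(len(layers))])
else:
    # new configurable encoder; conv type, depth, dropout and activation come from the config
    encoders = nn.ModuleList([
        GraphNNs(input_dim=input_dims[i],
                 hidden_dim=int(config["hidden_dim"]),
                 output_dim=config["latent_dim"],
                 act=config["activation"],
                 number_layers=config["number_layers"],
                 dropout=config["dropout"],
                 conv=gnn_conv_type)
        for i in range(len(layers))])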
143 changes: 143 additions & 0 deletions flexynesis/modules.py
@@ -3,6 +3,8 @@
import torch
from torch import nn
import torch_geometric.nn as gnn
from torch_geometric.nn import GCNConv, GATConv, GINConv, PNAConv, SAGEConv, ChebConv, GraphConv, \
global_mean_pool as gmeanp, global_max_pool as gmaxp, global_add_pool as gap


__all__ = ["Encoder", "Decoder", "MLP", "EmbeddingNetwork", "GCNN", "cox_ph_loss"]
@@ -225,6 +227,146 @@ def forward(self, x):
x = x.squeeze(-1)
return x

class GraphNNs(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim, conv='CHEB',
dropout=0.1, number_layers=1, deg=None, act = None):
super().__init__()
"""
Initializes a Graph Neural Network model with customizable convolution types, activation functions,
and the ability to specify the number of layers.
This model can use various graph convolution types (e.g., GCN, GAT, GIN, SAGE, CHEB) as specified by the user,
each potentially followed by batch normalization and a specified activation function. Dropout is applied for regularization.
The model concludes with a fully connected layer to produce the output.
Args:
input_dim (int): The dimensionality of input features.
hidden_dim (int): The dimensionality of features after the first fully connected layer.
output_dim (int): The dimensionality of the output features.
conv (str): The type of graph convolution to use. Defaults to 'CHEB'.
dropout (float): Dropout rate for regularization. Defaults to 0.1.
number_layers (int): The number of convolutional layers. Defaults to 1.
deg (torch.Tensor): A tensor of the degrees of the nodes in the
input graph; used only by specific convolution types such as PNA. Defaults to None.
act (str): The activation function to use. Options are 'relu', 'sigmoid', 'tanh',
'softmax', 'leakyrelu', 'elu', and 'gelu'.
Methods:
- `reset_parameters()`: A method that initializes the parameters of
the model.
Inputs (to `forward`):
- `x` (torch.Tensor): A tensor of node features of shape
`(num_nodes, input_dim)`.
- `edge_index` (torch.LongTensor): A tensor of shape `(2, num_edges)` that
represents the indices of the edges in the graph.
- `batch` (torch.LongTensor): A tensor of shape `(num_nodes,)` that
indicates the membership of each node in a particular graph in the
batch.
Outputs:
- `out` (torch.Tensor): A tensor of shape `(num_graphs, output_dim)`
containing the output features for each graph in the batch.
"""
act_options = {
'relu': (torch.nn.ReLU()),
'sigmoid': (torch.nn.Sigmoid()),
'tanh': (torch.nn.Tanh()),
'softmax': (torch.nn.Softmax(dim=None)),
'leakyrelu': (torch.nn.LeakyReLU(negative_slope=0.01, inplace=False)),
'elu': (torch.nn.ELU(alpha=1.0, inplace=False)),
'gelu': (torch.nn.GELU())
}
# check if the activation function string is valid
if act not in act_options:
raise ValueError("Invalid activation function string. Choose from 'relu', 'sigmoid', 'tanh', 'softmax', 'leakyrelu', 'elu', or 'gelu'.")

# instantiate the activation function
self.activation = act_options[act]

self.dropout = torch.nn.Dropout(dropout)

self.convs = torch.nn.ModuleList()
self.bns = torch.nn.ModuleList()

conv_options = {
'GCN': (GCNConv(input_dim, input_dim)),
'GAT': (GATConv(input_dim, input_dim)),
'GIN': (GINConv(nn.Sequential(nn.Linear(input_dim, input_dim),
torch.nn.BatchNorm1d(input_dim),
nn.ReLU(), nn.Linear(input_dim, input_dim)))),
'SAGE': (SAGEConv(input_dim, input_dim)),
'CHEB': (ChebConv(input_dim, input_dim, K=2)),
'GC': (GraphConv(input_dim, input_dim))
}
if conv not in conv_options:
raise ValueError('Unknown convolution type. Choose one of: ',conv_options.keys())

self.conv = conv_options[conv]

for i in range(number_layers):
self.convs.append(self.conv)
self.bns.append(nn.BatchNorm1d(input_dim))

self.fc1 = nn.Linear(input_dim, hidden_dim)
self.bn_ff1 = nn.BatchNorm1d(hidden_dim)
self.fc2 = nn.Linear(hidden_dim, output_dim)


def reset_parameters(self):
for layer in self.children():
if hasattr(layer, 'reset_parameters'):
layer.reset_parameters()

def forward(self, x, edge_index, batch):

for conv, batch_norm in zip(self.convs, self.bns):
if conv == "GCN" or conv == "CHEB" or conv == "GAT":
x = self.dropout(self.activation(batch_norm(conv(x, edge_index))))
else:
x = self.dropout(self.activation(batch_norm(conv(x, edge_index))))
x = gap(x, batch)

x = self.activation(self.bn_ff1(self.fc1(x)))
x = self.dropout(x)
out = self.fc2(x)

return out

def extract_embeddings(self, x, edge_index, batch):
for conv, batch_norm in zip(self.convs, self.bns):
if conv == "GCN" or conv == "CHEB" or conv == "GAT":
x = self.dropout(self.activation(batch_norm(conv(x, edge_index))))
else:
x = self.dropout(self.activation(batch_norm(conv(x, edge_index))))
return x

def get_attention_scores(self, x, edge_index, batch):
attention_scores_list = []
for i, (conv, batch_norm) in enumerate(zip(self.convs, self.bns)):

x, attention_scores = conv(x, edge_index, return_attention_weights=True) # adapt as per your GAT layer's API
attention_scores_list.append((i, attention_scores))
return attention_scores_list


class GCNN(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim):
@@ -261,6 +403,7 @@ def forward(self, x, edge_index, batch):
Returns:
Tensor: The output tensor after processing through the GCNN, with shape [num_nodes, output_dim].
"""
#print('batch:', batch.x)
x = self.layer_1(x, edge_index)
x = self.relu_1(x)
x = self.layer_2(x, edge_index)
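For illustration, a minimal end-to-end use of the new GraphNNs module on a toy batch (dimensions and the import path flexynesis.modules come from this diff; the two small graphs are made-up placeholder data):

import torch
from torch_geometric.data import Batch, Data
from flexynesis.modules import GraphNNs

# two tiny graphs, each node carrying a single input feature
g1 = Data(x=torch.randn(4, 1), edge_index=torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]]))
g2 = Data(x=torch.randn(3, 1), edge_index=torch.tensor([[0, 1, 2], [1, 2, 0]]))
batch = Batch.from_data_list([g1, g2])

model = GraphNNs(input_dim=1, hidden_dim=32, output_dim=16,
                 conv='GAT', act='relu', number_layers=2, dropout=0.2)
model.eval()  # use running statistics in the BatchNorm layers for this toy forward pass

out = model(batch.x, batch.edge_index, batch.batch)
print(out.shape)  # expected: torch.Size([2, 16]), one output vector per graph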
