diff --git a/.gitignore b/.gitignore
index 4ae42c4..41c772c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,6 +7,8 @@ __pycache__/
 *.so
 
 # Distribution / packaging
+core.*
+*.pth
 9606.protein.aliases.v12.0.txt
 9606.protein.links.v12.0.txt
 *.tgz
diff --git a/flexynesis/__main__.py b/flexynesis/__main__.py
index 9404a14..c13c8fd 100644
--- a/flexynesis/__main__.py
+++ b/flexynesis/__main__.py
@@ -14,7 +14,10 @@ def main():
                                      formatter_class=argparse.ArgumentDefaultsHelpFormatter)
     parser.add_argument("--data_path", help="(Required) Path to the folder with train/test data files", type=str, required = True)
-    parser.add_argument("--model_class", help="(Required) The kind of model class to instantiate", type=str, choices=["DirectPred", "DirectPredGCNN", "supervised_vae", "MultiTripletNetwork"], required = True)
+    parser.add_argument("--model_class", help="(Required) The kind of model class to instantiate", type=str,
+                        choices=["DirectPred", "DirectPredGCNN", "supervised_vae", "MultiTripletNetwork"], required = True)
+    parser.add_argument("--gnn_conv_type", help="If model_class is set to DirectPredGCNN, choose which graph convolution type to use", type=str,
+                        choices=["GC", "GCN", "GAT", "GIN", "SAGE", "CHEB"])
     parser.add_argument("--target_variables", help="(Optional if survival variables are not set to None)."
                                                    "Which variables in 'clin.csv' to use for predictions, comma-separated if multiple",
@@ -56,7 +59,7 @@ def main():
     # 2. Check for required variables for model classes
     if args.model_class != "supervised_vae":
         if not any([args.target_variables, args.surv_event_var, args.batch_variables]):
-            parser.error(''.join(["When selecting a model other than 'supervised_vae',"
+            parser.error(''.join(["When selecting a model other than 'supervised_vae', ",
                                   "you must provide at least one of --target_variables, ",
                                   "survival variables (--surv_event_var and --surv_time_var)",
                                   "or --batch_variables."]))
@@ -82,7 +85,21 @@ def main():
     else:
         device_type = 'cpu'
         torch.set_num_threads(args.threads)
-    
+
+    # 5. check GNN arguments
+    if args.model_class == 'DirectPredGCNN':
+        if not args.gnn_conv_type:
+            warning_message = "\n".join([
+                "\n\n!!! When running DirectPredGCNN, a convolution type can be set",
+                "with the --gnn_conv_type flag. See `flexynesis -h` for the full set of options.",
+                "Falling back on the default convolution type: GC !!!\n\n"
+            ])
+            warnings.warn(warning_message)
+            time.sleep(3)  # wait a bit to capture the user's attention to the warning
+            gnn_conv_type = 'GC'
+        else:
+            gnn_conv_type = args.gnn_conv_type
+
     # Validate paths
     if not os.path.exists(args.data_path):
         raise FileNotFoundError(f"Input --data_path doesn't exist at:", {args.data_path})
@@ -144,7 +161,8 @@ class AvailableModels(NamedTuple):
                                     n_iter=int(args.hpo_iter),
                                     use_loss_weighting = args.use_loss_weighting == 'True',
                                     early_stop_patience = int(args.early_stop_patience),
-                                    device_type = device_type)
+                                    device_type = device_type,
+                                    gnn_conv_type = gnn_conv_type)
 
     # do a hyperparameter search training multiple models and get the best_configuration
     model, best_params = tuner.perform_tuning()
diff --git a/flexynesis/config.py b/flexynesis/config.py
index c4b789d..26555f5 100644
--- a/flexynesis/config.py
+++ b/flexynesis/config.py
@@ -39,6 +39,10 @@
         Integer(64, 512, name="hidden_dim"),
         Real(0.0001, 0.01, prior="log-uniform", name="lr"),
         Integer(32, 128, name="batch_size"),
-        Categorical(epochs, name="epochs")
-    ],
+        Categorical(epochs, name="epochs"),
+        # the parameters below apply to all convolution types except for 'GC'
+        Real(0.1, 0.4, prior="log-uniform", name="dropout"),
+        Integer(1, 3, name="number_layers"),
+        Categorical(['relu', 'sigmoid', 'tanh'], name="activation")
+    ]
 }
diff --git a/flexynesis/main.py b/flexynesis/main.py
index 1217753..ac4f746 100644
--- a/flexynesis/main.py
+++ b/flexynesis/main.py
@@ -50,7 +50,7 @@ def __init__(self, dataset, model_class, config_name, target_variables,
                  batch_variables = None, surv_event_var = None, surv_time_var = None,
                  n_iter = 10, config_path = None, plot_losses = False,
                  val_size = 0.2, use_loss_weighting = True, early_stop_patience = -1,
-                 device_type = None):
+                 device_type = None, gnn_conv_type = None):
         self.dataset = dataset
         self.model_class = model_class
         self.target_variables = target_variables
@@ -71,6 +71,7 @@ def __init__(self, dataset, model_class, config_name, target_variables,
                                                                      progress_bar_finished='red'))
         self.early_stop_patience = early_stop_patience
         self.use_loss_weighting = use_loss_weighting
+        self.gnn_conv_type = gnn_conv_type
 
         # If config_path is provided, use it
         if config_path:
@@ -93,7 +94,8 @@ def objective(self, params, current_step, total_steps):
                                  surv_time_var = self.surv_time_var,
                                  val_size = self.val_size,
                                  use_loss_weighting = self.use_loss_weighting,
-                                 device_type = self.device_type)
+                                 device_type = self.device_type,
+                                 gnn_conv_type = self.gnn_conv_type)
         print(params)
         mycallbacks = [self.progress_bar]
diff --git a/flexynesis/models/direct_pred_gcnn.py b/flexynesis/models/direct_pred_gcnn.py
index 6839879..afe7259 100644
--- a/flexynesis/models/direct_pred_gcnn.py
+++ b/flexynesis/models/direct_pred_gcnn.py
@@ -13,7 +13,7 @@
 
 from captum.attr import IntegratedGradients
 
-from ..modules import GCNN, MLP, cox_ph_loss
+from ..modules import GCNN, MLP, cox_ph_loss, GraphNNs
 
 
 class DirectPredGCNN(pl.LightningModule):
@@ -27,7 +27,8 @@ def __init__(
         surv_time_var=None,
         val_size=0.2,
         use_loss_weighting=True,
-        device_type = None
+        device_type = None,
+        gnn_conv_type = None
     ):
         super().__init__()
         self.config = config
@@ -47,7 +48,8 @@ def __init__(
 
         self.use_loss_weighting = use_loss_weighting
         self.device_type = device_type
-
+        self.gnn_conv_type = gnn_conv_type
+
         if self.use_loss_weighting:
             # Initialize log variance parameters for uncertainty weighting
             self.log_vars = nn.ParameterDict()
@@ -59,16 +61,30 @@ def __init__(
         # NOTE: For now we use matrices, so number of node input features is 1.
         input_dims = [1 for _ in range(len(layers))]
 
-        self.encoders = nn.ModuleList(
-            [
-                GCNN(
-                    input_dim=input_dims[i],
-                    hidden_dim=int(self.config["hidden_dim"]),  # int because of pyg
-                    output_dim=self.config["latent_dim"],
-                )
-                for i in range(len(layers))
-            ]
-        )
+        # 'GC' keeps the original GCNN encoder; any other convolution type is handled by GraphNNs.
+        if self.gnn_conv_type == 'GC':
+            self.encoders = nn.ModuleList(
+                [
+                    GCNN(
+                        input_dim=input_dims[i],
+                        hidden_dim=int(self.config["hidden_dim"]),  # int because of pyg
+                        output_dim=self.config["latent_dim"],
+                    )
+                    for i in range(len(layers))
+                ])
+        else:
+            self.encoders = nn.ModuleList(
+                [
+                    GraphNNs(
+                        input_dim=input_dims[i],
+                        hidden_dim=int(self.config["hidden_dim"]),  # int because of pyg
+                        output_dim=self.config["latent_dim"],
+                        act = self.config['activation'],
+                        number_layers = self.config['number_layers'],
+                        dropout = self.config['dropout'],
+                        conv = self.gnn_conv_type
+                    )
+                    for i in range(len(layers))
+                ])
 
         # Init output layers
         self.MLPs = nn.ModuleDict()
diff --git a/flexynesis/modules.py b/flexynesis/modules.py
index 92726e1..f69972f 100644
--- a/flexynesis/modules.py
+++ b/flexynesis/modules.py
@@ -3,6 +3,8 @@
 import torch
 from torch import nn
 import torch_geometric.nn as gnn
+from torch_geometric.nn import GCNConv, GATConv, GINConv, PNAConv, SAGEConv, ChebConv, GraphConv, \
+    global_mean_pool as gmeanp, global_max_pool as gmaxp, global_add_pool as gap
 
 __all__ = ["Encoder", "Decoder", "MLP", "EmbeddingNetwork", "GCNN", "cox_ph_loss"]
 
@@ -225,6 +227,146 @@ def forward(self, x):
         x = x.squeeze(-1)
         return x
 
+class GraphNNs(nn.Module):
+    def __init__(self, input_dim, hidden_dim, output_dim, conv='CHEB',
+                 dropout=0.1, number_layers=1, deg=None, act=None):
+        """
+        Initializes a Graph Neural Network model with customizable convolution types, activation
+        functions, and a configurable number of layers.
+
+        The model can use various graph convolution types (e.g., GCN, GAT, GIN, SAGE, CHEB) as
+        specified by the user, each followed by batch normalization, the chosen activation
+        function and dropout for regularization. Node embeddings are pooled per graph and passed
+        through a fully connected head to produce the output.
+
+        Args:
+            input_dim (int): The dimensionality of the input node features.
+            hidden_dim (int): The dimensionality of features after the first fully connected layer.
+            output_dim (int): The dimensionality of the output features.
+            conv (str): The type of graph convolution to use. Defaults to 'CHEB'.
+            dropout (float): Dropout rate for regularization. Defaults to 0.1.
+            number_layers (int): The number of convolutional layers. Defaults to 1.
+            deg (torch.Tensor): A tensor of the degrees of the nodes in the input graph, used by
+                specific convolution types such as PNA. Defaults to None.
+            act (str): The activation function to use. Options include 'relu', 'sigmoid', 'tanh',
+                'softmax', 'leakyrelu', 'elu' and 'gelu'.
+
+        Methods:
+            - reset_parameters(): Re-initializes the parameters of all child layers that support it.
+
+        Inputs (forward):
+            - x (torch.Tensor): Node feature matrix of shape (num_nodes, input_dim).
+            - edge_index (torch.LongTensor): A tensor of shape (2, num_edges) holding the indices
+              of the edges in the graph.
+            - batch (torch.LongTensor): A tensor of shape (num_nodes,) indicating the membership
+              of each node in a particular graph in the batch.
+
+        Outputs:
+            - out (torch.Tensor): A tensor of shape (num_graphs, output_dim) with one output
+              vector per graph in the batch.
+        """
+        super().__init__()
+
+        act_options = {
+            'relu': nn.ReLU(),
+            'sigmoid': nn.Sigmoid(),
+            'tanh': nn.Tanh(),
+            'softmax': nn.Softmax(dim=None),
+            'leakyrelu': nn.LeakyReLU(negative_slope=0.01, inplace=False),
+            'elu': nn.ELU(alpha=1.0, inplace=False),
+            'gelu': nn.GELU()
+        }
+        # check that the activation function string is valid
+        if act not in act_options:
+            raise ValueError("Invalid activation function string. Choose from 'relu', 'sigmoid', 'tanh', 'softmax', 'leakyrelu', 'elu', or 'gelu'.")
+
+        # instantiate the activation function
+        self.activation = act_options[act]
+
+        self.dropout = nn.Dropout(dropout)
+
+        self.convs = nn.ModuleList()
+        self.bns = nn.ModuleList()
+
+        # Use factories so that every layer gets its own convolution instance
+        # (re-appending a single instance would tie the weights across layers).
+        conv_options = {
+            'GCN': lambda: GCNConv(input_dim, input_dim),
+            'GAT': lambda: GATConv(input_dim, input_dim),
+            'GIN': lambda: GINConv(nn.Sequential(nn.Linear(input_dim, input_dim),
+                                                 nn.BatchNorm1d(input_dim),
+                                                 nn.ReLU(), nn.Linear(input_dim, input_dim))),
+            'SAGE': lambda: SAGEConv(input_dim, input_dim),
+            'CHEB': lambda: ChebConv(input_dim, input_dim, K=2),
+            'GC': lambda: GraphConv(input_dim, input_dim)
+        }
+        if conv not in conv_options:
+            raise ValueError(f"Unknown convolution type '{conv}'. Choose one of: {list(conv_options.keys())}")
+
+        for i in range(number_layers):
+            self.convs.append(conv_options[conv]())
+            self.bns.append(nn.BatchNorm1d(input_dim))
+
+        self.fc1 = nn.Linear(input_dim, hidden_dim)
+        self.bn_ff1 = nn.BatchNorm1d(hidden_dim)
+        self.fc2 = nn.Linear(hidden_dim, output_dim)
+
+    def reset_parameters(self):
+        for layer in self.children():
+            if hasattr(layer, 'reset_parameters'):
+                layer.reset_parameters()
+
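+    # The forward pass below applies each graph convolution with batch normalization,
+    # the chosen activation and dropout, pools node embeddings per graph with global
+    # add pooling (gap), and maps the pooled vector through the fc1 -> fc2 head.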
+    def forward(self, x, edge_index, batch):
+        for conv, batch_norm in zip(self.convs, self.bns):
+            x = self.dropout(self.activation(batch_norm(conv(x, edge_index))))
+        x = gap(x, batch)
+
+        x = self.activation(self.bn_ff1(self.fc1(x)))
+        x = self.dropout(x)
+        out = self.fc2(x)
+
+        return out
+
+    def extract_embeddings(self, x, edge_index, batch):
+        for conv, batch_norm in zip(self.convs, self.bns):
+            x = self.dropout(self.activation(batch_norm(conv(x, edge_index))))
+        return x
+
+    def get_attention_scores(self, x, edge_index, batch):
+        # attention weights are only available from attention-based convolutions (e.g. GAT);
+        # adapt as per your GAT layer's API
+        attention_scores_list = []
+        for i, (conv, batch_norm) in enumerate(zip(self.convs, self.bns)):
+            x, attention_scores = conv(x, edge_index, return_attention_weights=True)
+            attention_scores_list.append((i, attention_scores))
+        return attention_scores_list
+
 
 class GCNN(nn.Module):
     def __init__(self, input_dim, hidden_dim, output_dim):
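
Usage note (not part of the patch): on the command line, the convolution type for DirectPredGCNN is
selected with --gnn_conv_type (e.g. --model_class DirectPredGCNN --gnn_conv_type GAT); when the flag
is omitted, the run falls back to the default 'GC' encoder with a warning. The GraphNNs module added
in flexynesis/modules.py can also be constructed directly. A minimal sketch, assuming one-dimensional
node features (as DirectPredGCNN uses) and placeholder hidden/output sizes:

    from flexynesis.modules import GraphNNs

    # two GAT layers over a feature graph with 1-dimensional node features;
    # hidden_dim/output_dim values are illustrative placeholders
    encoder = GraphNNs(input_dim=1, hidden_dim=128, output_dim=32,
                       conv='GAT', act='relu', number_layers=2, dropout=0.2)
    # z = encoder(batch.x, batch.edge_index, batch.batch)  # -> [num_graphs, 32]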