From 2d1be3640e0f1f7cca3b61ddda501788d63f19b2 Mon Sep 17 00:00:00 2001 From: Bora Uyar Date: Sun, 17 Mar 2024 13:07:03 +0100 Subject: [PATCH 1/6] add a first draft implementation of cross-modality prediction class --- flexynesis/models/crossmodal_pred.py | 465 +++++++++++++++++++++++++++ 1 file changed, 465 insertions(+) create mode 100644 flexynesis/models/crossmodal_pred.py diff --git a/flexynesis/models/crossmodal_pred.py b/flexynesis/models/crossmodal_pred.py new file mode 100644 index 0000000..fdea447 --- /dev/null +++ b/flexynesis/models/crossmodal_pred.py @@ -0,0 +1,465 @@ +import torch +import itertools +from torch import nn +from torch.nn import functional as F +from torch.utils.data import Dataset, DataLoader, random_split + +import pandas as pd +import numpy as np + +import lightning as pl +from scipy import stats + +from captum.attr import IntegratedGradients + +from ..modules import * + + +class CrossModalPred(pl.LightningModule): + """ + A Cross-Modality Prediction Architecture that encodes user-specified input data modalities and + tries to reconstruct user-specificed output data modalities. In the case where input/output data modalities + are the same, this behaves like an auto-encoder. + The network also can be connected to one or more MLPs for outcome variable prediction. + + dataset: dictionary of data matrices + input_layers: which data modalities from `dataset` to encode (use a subset of keys from `dataset`) + output_layers: which data modalities are aimed to be reconsructed via decoders (use a subset of keys from `dataset`). + + """ + def __init__(self, config, dataset, target_variables = None, batch_variables = None, + surv_event_var = None, surv_time_var = None, + input_layers = None, output_layers = None, + val_size = 0.2, use_loss_weighting = True, + device_type = None): + super(CrossModalPred, self).__init__() + self.config = config + self.dataset = dataset + self.target_variables = target_variables + self.surv_event_var = surv_event_var + self.surv_time_var = surv_time_var + # both surv event and time variables are assumed to be numerical variables + # we create only one survival variable for the pair (surv_time_var and surv_event_var) + if self.surv_event_var is not None and self.surv_time_var is not None: + self.target_variables = self.target_variables + [self.surv_event_var] + self.batch_variables = batch_variables + self.variables = self.target_variables + batch_variables if batch_variables else self.target_variables + + self.input_layers = input_layers if input_layers else list(dataset.dat.keys()) + self.output_layers = output_layers if output_layers else list(dataset.dat.keys()) + + self.val_size = val_size + + self.dat_train, self.dat_val = self.prepare_data() + self.feature_importances = {} + + self.device_type = device_type + + self.use_loss_weighting = use_loss_weighting + + if self.use_loss_weighting: + # Initialize log variance parameters for uncertainty weighting + self.log_vars = nn.ParameterDict() + for loss_type in itertools.chain(self.variables, ['mmd_loss']): + self.log_vars[loss_type] = nn.Parameter(torch.zeros(1)) + + # create a list of Encoder instances for separately encoding each input omics layer + input_dims = [len(dataset.features[self.input_layers[i]]) for i in range(len(self.input_layers))] + self.encoders = nn.ModuleList([Encoder(input_dims[i], [config['hidden_dim']], config['latent_dim']) for i in range(len(self.input_layers))]) + + # Fully connected layers for concatenated means and log_vars + self.FC_mean = 
nn.Linear(len(self.input_layers) * config['latent_dim'], config['latent_dim']) + self.FC_log_var = nn.Linear(len(self.input_layers) * config['latent_dim'], config['latent_dim']) + + # list of decoders to decode the latent layer into the target/output layers + output_dims = [len(dataset.features[self.output_layers[i]]) for i in range(len(self.output_layers))] + self.decoders = nn.ModuleList([Decoder(config['latent_dim'], [config['hidden_dim']], output_dims[i]) for i in range(len(self.output_layers))]) + + # define supervisor heads + # using ModuleDict to store multiple MLPs + self.MLPs = nn.ModuleDict() + for var in self.variables: + if self.dataset.variable_types[var] == 'numerical': + num_class = 1 + else: + num_class = len(np.unique(self.dataset.ann[var])) + self.MLPs[var] = MLP(input_dim = config['latent_dim'], + hidden_dim = config['supervisor_hidden_dim'], + output_dim = num_class) + + def multi_encoder(self, x_list): + """ + Encode each input matrix separately using the corresponding Encoder. + + Args: + x_list (list of torch.Tensor): List of input matrices for each omics layer. + + Returns: + tuple: Tuple containing: + - mean (torch.Tensor): Concatenated mean values from each encoder. + - log_var (torch.Tensor): Concatenated log variance values from each encoder. + """ + means, log_vars = [], [] + # Process each input matrix with its corresponding Encoder + for i, x in enumerate(x_list): + mean, log_var = self.encoders[i](x) + means.append(mean) + log_vars.append(log_var) + + # Concatenate means and log_vars + # Push concatenated means and log_vars through the fully connected layers + mean = self.FC_mean(torch.cat(means, dim=1)) + log_var = self.FC_log_var(torch.cat(log_vars, dim=1)) + return mean, log_var + + def forward(self, x_list): + """ + Forward pass through the model. + + Args: + x_list (list of torch.Tensor): List of input matrices for each omics layer. + + Returns: + tuple: Tuple containing: + - x_hat_list (list of torch.Tensor): List of reconstructed matrices for each omics layer. + - z (torch.Tensor): Latent representation. + - mean (torch.Tensor): Concatenated mean values from each encoder. + - log_var (torch.Tensor): Concatenated log variance values from each encoder. + - y_pred (torch.Tensor): Predicted output. + """ + mean, log_var = self.multi_encoder(x_list) + + # generate latent layer + z = self.reparameterization(mean, log_var) + + # Decode each latent variable with its corresponding Decoder + x_hat_list = [self.decoders[i](z) for i in range(len(x_list))] + + #run the supervisor heads using the latent layer as input + outputs = {} + for var, mlp in self.MLPs.items(): + outputs[var] = mlp(z) + + return x_hat_list, z, mean, log_var, outputs + + def reparameterization(self, mean, var): + """ + Reparameterize the mean and variance values. + + Args: + mean (torch.Tensor): Mean values from the encoders. + var (torch.Tensor): Variance values from the encoders. + + Returns: + torch.Tensor: Latent representation. + """ + epsilon = torch.randn_like(var) + z = mean + var*epsilon + return z + + def configure_optimizers(self): + """ + Configure the optimizer for the model. + + Returns: + torch.optim.Adam: Adam optimizer with learning rate 1e-3. 
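+            In practice the learning rate is read from self.config['lr'] rather than fixed at 1e-3.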
+ """ + optimizer = torch.optim.Adam(self.parameters(), lr=self.config['lr']) + return optimizer + + def compute_loss(self, var, y, y_hat): + if self.dataset.variable_types[var] == 'numerical': + # Ignore instances with missing labels for numerical variables + valid_indices = ~torch.isnan(y) + if valid_indices.sum() > 0: # only calculate loss if there are valid targets + y_hat = y_hat[valid_indices] + y = y[valid_indices] + loss = F.mse_loss(torch.flatten(y_hat), y.float()) + else: + loss = torch.tensor(0.0, device=y_hat.device, requires_grad=True) # if no valid labels, set loss to 0 + else: + # Ignore instances with missing labels for categorical variables + # Assuming that missing values were encoded as -1 + valid_indices = (y != -1) & (~torch.isnan(y)) + if valid_indices.sum() > 0: # only calculate loss if there are valid targets + y_hat = y_hat[valid_indices] + y = y[valid_indices] + loss = F.cross_entropy(y_hat, y.long()) + else: + loss = torch.tensor(0.0, device=y_hat.device, requires_grad=True) + return loss + + def compute_total_loss(self, losses): + if self.use_loss_weighting and len(losses) > 1: + # Compute weighted loss for each loss + # Weighted loss = precision * loss + log-variance + total_loss = sum(torch.exp(-self.log_vars[name]) * loss + self.log_vars[name] for name, loss in losses.items()) + else: + # Compute unweighted total loss + total_loss = sum(losses.values()) + return total_loss + + + def training_step(self, train_batch, batch_idx): + dat, y_dict = train_batch + + # get input omics modalities and encode them + x_list_input = [dat[x] for x in self.input_layers] + x_hat_list, z, mean, log_var, outputs = self.forward(x_list_input) + + # compute mmd loss for the latent space + reconsruction loss for each target/output layer + x_list_output = [dat[x] for x in self.output_layers] + mmd_loss_list = [self.MMD_loss(z.shape[1], z, x_hat_list[i], x_list_output[i]) for i in range(len(self.output_layers))] + mmd_loss = torch.mean(torch.stack(mmd_loss_list)) + + # compute loss values for the supervisor heads + losses = {'mmd_loss': mmd_loss} + + for var in self.variables: + if var == self.surv_event_var: + durations = y_dict[self.surv_time_var] + events = y_dict[self.surv_event_var] + risk_scores = outputs[var] #output of MLP + loss = cox_ph_loss(risk_scores, durations, events) + else: + y_hat = outputs[var] + y = y_dict[var] + loss = self.compute_loss(var, y, y_hat) + losses[var] = loss + + total_loss = self.compute_total_loss(losses) + # add total loss for logging + losses['train_loss'] = total_loss + self.log_dict(losses, on_step=False, on_epoch=True, prog_bar=True) + return total_loss + + def validation_step(self, val_batch, batch_idx): + dat, y_dict = val_batch + + # get input omics modalities and encode them + x_list_input = [dat[x] for x in self.input_layers] + x_hat_list, z, mean, log_var, outputs = self.forward(x_list_input) + + # compute mmd loss for the latent space + reconsruction loss for each target/output layer + x_list_output = [dat[x] for x in self.output_layers] + mmd_loss_list = [self.MMD_loss(z.shape[1], z, x_hat_list[i], x_list_output[i]) for i in range(len(self.output_layers))] + mmd_loss = torch.mean(torch.stack(mmd_loss_list)) + + # compute loss values for the supervisor heads + losses = {'mmd_loss': mmd_loss} + for var in self.variables: + if var == self.surv_event_var: + durations = y_dict[self.surv_time_var] + events = y_dict[self.surv_event_var] + risk_scores = outputs[var] #output of MLP + loss = cox_ph_loss(risk_scores, durations, events) + else: + 
y_hat = outputs[var] + y = y_dict[var] + loss = self.compute_loss(var, y, y_hat) + losses[var] = loss + + total_loss = sum(losses.values()) + losses['val_loss'] = total_loss + self.log_dict(losses, on_step=False, on_epoch=True, prog_bar=True) + return total_loss + + def prepare_data(self): + lt = int(len(self.dataset)*(1-self.val_size)) + lv = len(self.dataset)-lt + dat_train, dat_val = random_split(self.dataset, [lt, lv], + generator=torch.Generator().manual_seed(42)) + return dat_train, dat_val + + def train_dataloader(self): + return DataLoader(self.dat_train, batch_size=int(self.config['batch_size']), num_workers=0, pin_memory=True, shuffle=True, drop_last=True) + + def val_dataloader(self): + return DataLoader(self.dat_val, batch_size=int(self.config['batch_size']), num_workers=0, pin_memory=True, shuffle=False) + + def transform(self, dataset): + """ + Transform the input dataset to latent representation. + + Args: + dataset (MultiOmicDataset): MultiOmicDataset containing input matrices for each omics layer. + + Returns: + pd.DataFrame: Transformed dataset as a pandas DataFrame. + """ + self.eval() + x_list = [dataset.dat[x] for x in self.input_layers] + M = self.forward(x_list)[1].detach().numpy() + z = pd.DataFrame(M) + z.columns = [''.join(['E', str(x)]) for x in z.columns] + z.index = dataset.samples + return z + + def predict(self, dataset): + """ + Evaluate the model on a dataset. + + Args: + dataset (CustomDataset): Custom dataset containing input matrices for each omics layer. + + Returns: + predicted values. + """ + self.eval() + + x_list = [dataset.dat[x] for x in self.input_layers] + X_hat, z, mean, log_var, outputs = self.forward(x_list) + + predictions = {} + for var in self.variables: + y_pred = outputs[var].detach().numpy() + if self.dataset.variable_types[var] == 'categorical': + predictions[var] = np.argmax(y_pred, axis=1) + else: + predictions[var] = y_pred + + return predictions + + + def decode(self, dataset): + """ + Extract the decoded values of the target/output layers + """ + + self.eval() + x_list = [dataset.dat[x] for x in self.input_layers] + X_hat, z, mean, log_var, outputs = self.forward(x_list) + return X_hat + + + def compute_kernel(self, x, y): + """ + Compute the Gaussian kernel matrix between two sets of vectors. + + Args: + x (torch.Tensor): A tensor of shape (x_size, dim) representing the first set of vectors. + y (torch.Tensor): A tensor of shape (y_size, dim) representing the second set of vectors. + + Returns: + torch.Tensor: The Gaussian kernel matrix of shape (x_size, y_size) computed between x and y. + """ + x_size = x.size(0) + y_size = y.size(0) + dim = x.size(1) + x = x.unsqueeze(1) # (x_size, 1, dim) + y = y.unsqueeze(0) # (1, y_size, dim) + tiled_x = x.expand(x_size, y_size, dim) + tiled_y = y.expand(x_size, y_size, dim) + kernel_input = (tiled_x - tiled_y).pow(2).mean(2)/float(dim) + return torch.exp(-kernel_input) # (x_size, y_size) + + def compute_mmd(self, x, y): + """ + Compute the maximum mean discrepancy (MMD) between two sets of vectors. + + Args: + x (torch.Tensor): A tensor of shape (x_size, dim) representing the first set of vectors. + y (torch.Tensor): A tensor of shape (y_size, dim) representing the second set of vectors. + + Returns: + torch.Tensor: A scalar tensor representing the MMD between x and y. 
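+            Computed as x_kernel.mean() + y_kernel.mean() - 2 * xy_kernel.mean(),
+            where each kernel matrix is the Gaussian kernel returned by compute_kernel.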
+ """ + x_kernel = self.compute_kernel(x, x) + y_kernel = self.compute_kernel(y, y) + xy_kernel = self.compute_kernel(x, y) + mmd = x_kernel.mean() + y_kernel.mean() - 2*xy_kernel.mean() + return mmd + + def MMD_loss(self, latent_dim, z, xhat, x): + """ + Compute the loss function based on maximum mean discrepancy (MMD) and negative log likelihood (NLL). + + Args: + latent_dim (int): The dimensionality of the latent space. + z (torch.Tensor): A tensor of shape (batch_size, latent_dim) representing the latent codes. + xhat (torch.Tensor): A tensor of shape (batch_size, dim) representing the reconstructed data. + x (torch.Tensor): A tensor of shape (batch_size, dim) representing the original data. + + Returns: + torch.Tensor: A scalar tensor representing the MMD loss. + """ + true_samples = torch.randn(200, latent_dim, device = self.device) + mmd = self.compute_mmd(true_samples, z) # compute maximum mean discrepancy (MMD) + nll = (xhat - x).pow(2).mean() #negative log likelihood + return mmd+nll + + # Adaptor forward function for captum integrated gradients. + def forward_target(self, *args): + input_data = list(args[:-2]) # one or more tensors (one per omics layer) + target_var = args[-2] # target variable of interest + steps = args[-1] # number of steps for IntegratedGradients().attribute + outputs_list = [] + for i in range(steps): + # get list of tensors for each step into a list of tensors + x_step = [input_data[j][i] for j in range(len(input_data))] + x_hat_list, z, mean, log_var, outputs = self.forward(x_step) + outputs_list.append(outputs[target_var]) + return torch.cat(outputs_list, dim = 0) + + def compute_feature_importance(self, target_var, steps = 5): + """ + Compute the feature importance. + + Args: + input_data (torch.Tensor): The input data to compute the feature importance for. + target_var (str): The target variable to compute the feature importance for. + Returns: + attributions (list of torch.Tensor): The feature importances for each class. 
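+            The scores are also stored in self.feature_importances[target_var] as a pandas
+            DataFrame with columns 'target_variable', 'target_class', 'layer', 'name', and 'importance'.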
+ """ + device = torch.device("cuda" if self.device_type == 'gpu' and torch.cuda.is_available() else 'cpu') + self.to(device) + + print("[INFO] Computing feature importance for variable:",target_var,"on device:",device) + x_list = [self.dataset.dat[x].to(device) for x in self.dataset.dat.keys()] + + # Initialize the Integrated Gradients method + ig = IntegratedGradients(self.forward_target) + + input_data = tuple([data.unsqueeze(0).requires_grad_() for data in x_list]) + + # Define a baseline (you might need to adjust this depending on your actual data) + baseline = tuple([torch.zeros_like(data) for data in input_data]) + + # Get the number of classes for the target variable + if self.dataset.variable_types[target_var] == 'numerical': + num_class = 1 + else: + num_class = len(np.unique(self.dataset.ann[target_var])) + + # Compute the feature importance for each class + attributions = [] + if num_class > 1: + for target_class in range(num_class): + attributions.append(ig.attribute(input_data, baseline, additional_forward_args=(target_var, steps), target=target_class, n_steps=steps)) + else: + attributions.append(ig.attribute(input_data, baseline, additional_forward_args=(target_var, steps), n_steps=steps)) + + # summarize feature importances + # Compute absolute attributions + # Move the processed tensors to CPU for further operations that are not supported on GPU + abs_attr = [[torch.abs(a).cpu() for a in attr_class] for attr_class in attributions] + # average over samples + imp = [[a.mean(dim=1) for a in attr_class] for attr_class in abs_attr] + + # move the model also back to cpu (if not already on cpu) + self.to('cpu') + + # combine into a single data frame + df_list = [] + layers = self.input_layers + for i in range(num_class): + for j in range(len(layers)): + features = self.dataset.features[layers[j]] + importances = imp[i][j][0].detach().numpy() + df_list.append(pd.DataFrame({'target_variable': target_var, 'target_class': i, 'layer': layers[j], 'name': features, 'importance': importances})) + df_imp = pd.concat(df_list, ignore_index = True) + + # save scores in model + self.feature_importances[target_var] = df_imp + + From 7e47be69582b731db178fca620b77a78c2ba5eda Mon Sep 17 00:00:00 2001 From: Bora Uyar Date: Sun, 17 Mar 2024 13:07:21 +0100 Subject: [PATCH 2/6] export the class --- flexynesis/models/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/flexynesis/models/__init__.py b/flexynesis/models/__init__.py index 52d79bb..1c1e288 100644 --- a/flexynesis/models/__init__.py +++ b/flexynesis/models/__init__.py @@ -2,5 +2,6 @@ from .direct_pred_gcnn import DirectPredGCNN from .supervised_vae import supervised_vae from .triplet_encoder import MultiTripletNetwork +from .crossmodal_pred import CrossModalPred -__all__ = ["DirectPred", "DirectPredGCNN", "supervised_vae", "MultiTripletNetwork"] +__all__ = ["DirectPred", "DirectPredGCNN", "supervised_vae", "MultiTripletNetwork", "CrossModalPred"] From ae379aeae2f8409d4bf9238857ee3432799446ad Mon Sep 17 00:00:00 2001 From: Bora Uyar Date: Sun, 17 Mar 2024 13:44:53 +0100 Subject: [PATCH 3/6] define default hpo space for crossmodalpred --- flexynesis/config.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/flexynesis/config.py b/flexynesis/config.py index 0a0baf0..bddad50 100644 --- a/flexynesis/config.py +++ b/flexynesis/config.py @@ -19,6 +19,14 @@ Integer(32, 128, name='batch_size'), Categorical(epochs, name='epochs') ], + 'CrossModalPred': [ + Integer(16, 128, name='latent_dim'), + Integer(64, 512, 
name='hidden_dim'), + Integer(8, 32, name='supervisor_hidden_dim'), + Real(0.0001, 0.01, prior='log-uniform', name='lr'), + Integer(32, 128, name='batch_size'), + Categorical(epochs, name='epochs') + ], 'MultiTripletNetwork': [ Integer(16, 128, name='latent_dim'), Integer(64, 512, name='hidden_dim'), From c03aac4e6629f6730777b1c8669e98361ce0b7ba Mon Sep 17 00:00:00 2001 From: Bora Uyar Date: Sun, 17 Mar 2024 13:45:30 +0100 Subject: [PATCH 4/6] define feature importance for input_layers only --- flexynesis/models/crossmodal_pred.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flexynesis/models/crossmodal_pred.py b/flexynesis/models/crossmodal_pred.py index fdea447..5bb123c 100644 --- a/flexynesis/models/crossmodal_pred.py +++ b/flexynesis/models/crossmodal_pred.py @@ -415,7 +415,7 @@ def compute_feature_importance(self, target_var, steps = 5): self.to(device) print("[INFO] Computing feature importance for variable:",target_var,"on device:",device) - x_list = [self.dataset.dat[x].to(device) for x in self.dataset.dat.keys()] + x_list = [self.dataset.dat[x].to(device) for x in self.input_layers] # Initialize the Integrated Gradients method ig = IntegratedGradients(self.forward_target) @@ -451,7 +451,7 @@ def compute_feature_importance(self, target_var, steps = 5): # combine into a single data frame df_list = [] - layers = self.input_layers + layers = self.input_layers for i in range(num_class): for j in range(len(layers)): features = self.dataset.features[layers[j]] From 883bd1a7dbafb13f506a9a2f35a777225847f21b Mon Sep 17 00:00:00 2001 From: Bora Uyar Date: Sun, 17 Mar 2024 13:45:54 +0100 Subject: [PATCH 5/6] adapt main CLI to use CrossModalPred --- flexynesis/__main__.py | 47 ++++++++++++++++++++++++++++++++++++------ flexynesis/main.py | 8 ++++++- 2 files changed, 48 insertions(+), 7 deletions(-) diff --git a/flexynesis/__main__.py b/flexynesis/__main__.py index b6be416..358c485 100644 --- a/flexynesis/__main__.py +++ b/flexynesis/__main__.py @@ -16,7 +16,7 @@ def main(): parser.add_argument("--data_path", help="(Required) Path to the folder with train/test data files", type=str, required = True) parser.add_argument("--model_class", help="(Required) The kind of model class to instantiate", type=str, - choices=["DirectPred", "DirectPredGCNN", "supervised_vae", "MultiTripletNetwork"], required = True) + choices=["DirectPred", "DirectPredGCNN", "supervised_vae", "MultiTripletNetwork", "CrossModalPred"], required = True) parser.add_argument("--gnn_conv_type", help="If model_class is set to DirectPredGCNN, choose which graph convolution type to use", type=str, choices=["GC", "GCN", "GAT", "SAGE"]) parser.add_argument("--target_variables", @@ -36,6 +36,16 @@ def main(): parser.add_argument("--features_min", help="Minimum number of features to retain after feature selection", type=int, default = 500) parser.add_argument("--features_top_percentile", help="Top percentile features to retain after feature selection", type=float, default = 20) parser.add_argument("--data_types", help="(Required) Which omic data matrices to work on, comma-separated: e.g. 
'gex,cnv'", type=str, required = True) + parser.add_argument("--input_layers", + help="If model_class is set to CrossModalPred, choose which data types to use as input/encoded layers" + "Comma-separated if multiple", + type=str, default = None + ) + parser.add_argument("--output_layers", + help="If model_class is set to CrossModalPred, choose which data types to use as output/decoded layers" + "Comma-separated if multiple", + type=str, default = None + ) parser.add_argument("--outdir", help="Path to the output folder to save the model outputs", type=str, default = os.getcwd()) parser.add_argument("--prefix", help="Job prefix to use for output files", type=str, default = 'job') parser.add_argument("--log_transform", help="whether to apply log-transformation to input data matrices", type=str, choices=['True', 'False'], default = 'False') @@ -71,9 +81,14 @@ def main(): "or --batch_variables."])) # 3. Check for compatibility of fusion_type with DirectPredGCNN - if args.fusion_type == "early" and args.model_class == "DirectPredGCNN": - parser.error("The 'DirectPredGCNN' model cannot be used with early fusion type. " - "Use --fusion_type intermediate instead.") + if args.fusion_type == "early": + if args.model_class == "DirectPredGCNN": + parser.error("The 'DirectPredGCNN' model cannot be used with early fusion type. " + "Use --fusion_type intermediate instead.") + if args.model_class == 'CrossModalPred': + parser.error("The 'CrossModalPred' model cannot be used with early fusion type. " + "Use --fusion_type intermediate instead.") + # 4. Check for device availability if --accelerator is set. if args.use_gpu: @@ -108,6 +123,23 @@ def main(): else: gnn_conv_type = None + # 6. Check CrossModalPred arguments + input_layers = args.input_layers + output_layers = args.output_layers + datatypes = args.data_types.strip().split(',') + if args.model_class == 'CrossModalPred': + # check if input output layers are matching the requested data types + if args.input_layers: + input_layers = input_layers.strip().split(',') + # Check if input_layers are a subset of datatypes + if not all(layer in datatypes for layer in input_layers): + raise ValueError(f"Input layers {input_layers} are not a valid subset of the data types: ({datatypes}).") + # check if output_layers are a subset of datatypes + if args.output_layers: + output_layers = output_layers.strip().split(',') + if not all(layer in datatypes for layer in output_layers): + raise ValueError(f"Output layers {output_layers} are not a valid subset of the data types: ({datatypes}).") + # Validate paths if not os.path.exists(args.data_path): raise FileNotFoundError(f"Input --data_path doesn't exist at:", {args.data_path}) @@ -120,6 +152,7 @@ class AvailableModels(NamedTuple): supervised_vae: tuple[supervised_vae, str] = supervised_vae, "supervised_vae" MultiTripletNetwork: tuple[MultiTripletNetwork, str] = MultiTripletNetwork, "MultiTripletNetwork" DirectPredGCNN: tuple[DirectPredGCNN, str] = DirectPredGCNN, "DirectPredGCNN" + CrossModalPred: tuple[CrossModalPred, str] = CrossModalPred, "CrossModalPred" available_models = AvailableModels() model_class = getattr(available_models, args.model_class, None) @@ -140,7 +173,7 @@ class AvailableModels(NamedTuple): concatenate = True data_importer = flexynesis.DataImporter(path = args.data_path, - data_types = args.data_types.strip().split(','), + data_types = datatypes, concatenate = concatenate, log_transform = args.log_transform == 'True', correlation_threshold = args.correlation_threshold, @@ -173,7 +206,9 @@ class 
AvailableModels(NamedTuple): use_loss_weighting = args.use_loss_weighting == 'True', early_stop_patience = int(args.early_stop_patience), device_type = device_type, - gnn_conv_type = gnn_conv_type) + gnn_conv_type = gnn_conv_type, + input_layers = input_layers, + output_layers = output_layers) # do a hyperparameter search training multiple models and get the best_configuration model, best_params = tuner.perform_tuning() diff --git a/flexynesis/main.py b/flexynesis/main.py index 4fbd77f..9f993f1 100644 --- a/flexynesis/main.py +++ b/flexynesis/main.py @@ -50,7 +50,8 @@ def __init__(self, dataset, model_class, config_name, target_variables, batch_variables = None, surv_event_var = None, surv_time_var = None, n_iter = 10, config_path = None, plot_losses = False, val_size = 0.2, use_loss_weighting = True, early_stop_patience = -1, - device_type = None, gnn_conv_type = None): + device_type = None, gnn_conv_type = None, + input_layers = None, output_layers = None): self.dataset = dataset self.model_class = model_class self.target_variables = target_variables @@ -72,6 +73,8 @@ def __init__(self, dataset, model_class, config_name, target_variables, self.early_stop_patience = early_stop_patience self.use_loss_weighting = use_loss_weighting self.gnn_conv_type = gnn_conv_type + self.input_layers = input_layers + self.output_layers = output_layers # If config_path is provided, use it if config_path: @@ -95,6 +98,9 @@ def objective(self, params, current_step, total_steps): "use_loss_weighting": self.use_loss_weighting, "device_type": self.device_type} if self.model_class.__name__ == 'DirectPredGCNN': model_args["gnn_conv_type"] = self.gnn_conv_type + if self.model_class.__name__ == 'CrossModalPred': + model_args["input_layers"] = self.input_layers + model_args["output_layers"] = self.output_layers model = self.model_class(**model_args) print(params) From 379276e44646b2f45216ed5123ba4528623617a3 Mon Sep 17 00:00:00 2001 From: Bora Uyar Date: Sun, 17 Mar 2024 13:47:57 +0100 Subject: [PATCH 6/6] define GHA for CrossModalPred --- .github/workflows/benchmarks.yml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.github/workflows/benchmarks.yml b/.github/workflows/benchmarks.yml index ca8ebc7..7a4a956 100644 --- a/.github/workflows/benchmarks.yml +++ b/.github/workflows/benchmarks.yml @@ -89,6 +89,12 @@ jobs: conda activate my_env flexynesis --data_path dataset1 --model_class supervised_vae --target_variables Erlotinib,Crizotinib --fusion_type early --hpo_iter 1 --features_min 50 --features_top_percentile 5 --log_transform False --data_types gex,cnv --outdir . --prefix erlotinib_svae --early_stop_patience 3 --use_loss_weighting True --evaluate_baseline_performance False + - name: Run CrossModalPred + shell: bash -l {0} + run: | + conda activate my_env + flexynesis --data_path dataset1 --model_class CrossModalPred --target_variables Erlotinib --fusion_type intermediate --hpo_iter 1 --features_min 50 --features_top_percentile 5 --log_transform False --data_types gex,cnv --input_layers gex --output_layers cnv --outdir . --prefix erlotinib_crossmodal --early_stop_patience 3 --use_loss_weighting True --evaluate_baseline_performance False + - name: Run MultiTripletNetwork shell: bash -l {0} run: |