diff --git a/flexynesis/__main__.py b/flexynesis/__main__.py index ac79dcf..6aa05c6 100644 --- a/flexynesis/__main__.py +++ b/flexynesis/__main__.py @@ -7,7 +7,7 @@ import flexynesis from flexynesis.models import * from lightning.pytorch.callbacks import EarlyStopping - +from .data import STRING, MultiOmicDatasetNW def main(): """ @@ -57,7 +57,7 @@ def main(): parser.add_argument("--data_path", help="(Required) Path to the folder with train/test data files", type=str, required = True) parser.add_argument("--model_class", help="(Required) The kind of model class to instantiate", type=str, - choices=["DirectPred", "DirectPredGCNN", "supervised_vae", "MultiTripletNetwork", "CrossModalPred"], required = True) + choices=["DirectPred", "DirectPredGCNN", "supervised_vae", "MultiTripletNetwork", "CrossModalPred", "GNNEarly"], required = True) parser.add_argument("--gnn_conv_type", help="If model_class is set to DirectPredGCNN, choose which graph convolution type to use", type=str, choices=["GC", "GCN", "GAT", "SAGE"]) parser.add_argument("--target_variables", @@ -109,11 +109,6 @@ def main(): parser.add_argument("--string_organism", help="STRING DB organism id.", type=int, default=9606) parser.add_argument("--string_node_name", help="Type of node name.", type=str, choices=["gene_name", "gene_id"], default="gene_name") - - warnings.filterwarnings("ignore", ".*does not have many workers.*") - warnings.filterwarnings("ignore", "has been removed as a dependency of the") - warnings.filterwarnings("ignore", "The `srun` command is available on your system but is not used") - args = parser.parse_args() # do some sanity checks on input arguments @@ -157,10 +152,10 @@ def main(): torch.set_num_threads(args.threads) # 5. check GNN arguments - if args.model_class == 'DirectPredGCNN': + if args.model_class == 'DirectPredGCNN' or args.model_class == 'GNNEarly': if not args.gnn_conv_type: warning_message = "\n".join([ - "\n\n!!! When running DirectPredGCNN, a convolution type can be set", + "\n\n!!! When running DirectPredGCNN or GNNEarly, a convolution type can be set", "with the --gnn_conv_type flag. 
See `flexynesis -h` for full set of options.",
                 "Falling back on the default convolution type: GC !!!\n\n"
             ])
@@ -202,7 +197,8 @@ class AvailableModels(NamedTuple):
         MultiTripletNetwork: tuple[MultiTripletNetwork, str] = MultiTripletNetwork, "MultiTripletNetwork"
         DirectPredGCNN: tuple[DirectPredGCNN, str] = DirectPredGCNN, "DirectPredGCNN"
         CrossModalPred: tuple[CrossModalPred, str] = CrossModalPred, "CrossModalPred"
-
+        GNNEarly: tuple[GNNEarly, str] = GNNEarly, "GNNEarly"
+
     available_models = AvailableModels()
     model_class = getattr(available_models, args.model_class, None)
     if model_class is None:
@@ -237,6 +233,16 @@ class AvailableModels(NamedTuple):
                                           downsample = args.subsample)
     train_dataset, test_dataset = data_importer.import_data(force = True)
 
+    if args.model_class == 'GNNEarly':
+        # overlay datasets with network info
+        # this is a temporary solution
+        print("[INFO] Overlaying the dataset with network data from STRINGDB")
+        obj = STRING('STRING', "9606", "gene_name")
+        train_dataset = MultiOmicDatasetNW(train_dataset, obj.graph_df)
+        train_dataset.print_stats()
+        test_dataset = MultiOmicDatasetNW(test_dataset, obj.graph_df)
+
+
     # print feature logs to file (we use these tables to track which features are dropped/selected and why)
     feature_logs = data_importer.feature_logs
     for key in feature_logs.keys():
@@ -347,14 +353,20 @@ class AvailableModels(NamedTuple):
         print("[INFO] Computing off-the-shelf method performance on first target variable:",model.target_variables[0])
         var = model.target_variables[0]
         metrics = pd.DataFrame()
+
+        # in the case when GNNEarly was used, we use the initial MultiOmicDataset for train/test,
+        # because GNNEarly requires a modified dataset structure to fit the networks (temporary solution)
+        train = train_dataset.multiomic_dataset if args.model_class == 'GNNEarly' else train_dataset
+        test = test_dataset.multiomic_dataset if args.model_class == 'GNNEarly' else test_dataset
+
         if var != model.surv_event_var:
-            metrics = flexynesis.evaluate_baseline_performance(train_dataset, test_dataset,
+            metrics = flexynesis.evaluate_baseline_performance(train, test,
                                                                variable_name = var,
                                                                n_folds=5, n_jobs = int(args.threads))
         if model.surv_event_var and model.surv_time_var:
             print("[INFO] Computing off-the-shelf method performance on survival variable:",model.surv_time_var)
-            metrics_baseline_survival = flexynesis.evaluate_baseline_survival_performance(train_dataset, test_dataset,
+            metrics_baseline_survival = flexynesis.evaluate_baseline_survival_performance(train, test,
                                                                                           model.surv_time_var,
                                                                                           model.surv_event_var,
                                                                                           n_folds = 5,
diff --git a/flexynesis/config.py b/flexynesis/config.py
index ec39ad5..1fef1c7 100644
--- a/flexynesis/config.py
+++ b/flexynesis/config.py
@@ -39,5 +39,13 @@
         Categorical(epochs, name="epochs"),
         Integer(8, 32, name='supervisor_hidden_dim'),
         Categorical(['relu'], name="activation")
+    ],
+    'GNNEarly': [
+        Integer(16, 128, name='latent_dim'),
+        Real(0.2, 1, name='hidden_dim_factor'), # relative size of the hidden_dim w.r.t. input_dim
+        Real(0.0001, 0.01, prior='log-uniform', name='lr'),
+        Integer(8, 32, name='supervisor_hidden_dim'),
+        Categorical(epochs, name='epochs'),
+        Categorical(['relu'], name="activation")
     ]
 }
diff --git a/flexynesis/data.py b/flexynesis/data.py
index 3033a7b..5a74c6e 100644
--- a/flexynesis/data.py
+++ b/flexynesis/data.py
@@ -200,7 +200,7 @@ def import_data(self, force=False):
             initial_edge_list = graph_df.to_numpy().tolist()
         # Read STRING by default.
elif isinstance(self.graph, STRING): - graph_df = self.graph.df + graph_df = self.graph.graph_df available_features = np.unique(graph_df[["protein1", "protein2"]].to_numpy()).tolist() initial_edge_list = stringdb_links_to_list(graph_df) else: @@ -679,6 +679,7 @@ def get_dataset_stats(self): stats = {': '.join(['feature_count in', x]): self.dat[x].shape[1] for x in self.dat.keys()} stats['sample_count'] = len(self.samples) return(stats) + # given a MultiOmicDataset object, convert to Triplets (anchor,positive,negative) class TripletMultiOmicDataset(Dataset): @@ -714,7 +715,100 @@ def get_label_indices(self, labels): for label in labels_set} return labels_set, label_to_indices +class MultiOmicDatasetNW(Dataset): + def __init__(self, multiomic_dataset, interaction_df): + self.multiomic_dataset = multiomic_dataset + self.interaction_df = interaction_df + + # Precompute common features and edge index + self.common_features = self.find_common_features() + self.gene_to_index = {gene: idx for idx, gene in enumerate(self.common_features)} + self.edge_index = self.create_edge_index() + self.samples = self.multiomic_dataset.samples + self.variable_types = self.multiomic_dataset.variable_types + self.label_mappings = self.multiomic_dataset.label_mappings + self.ann = self.multiomic_dataset.ann + + # Precompute all node features for all samples + self.node_features_tensor = self.precompute_node_features() + + # Store labels for all samples + self.labels = {target_name: labels for target_name, labels in self.multiomic_dataset.ann.items()} + + # Store sample identifiers + self.samples = self.multiomic_dataset.samples + + def find_common_features(self): + common_features = set.intersection(*(set(features) for features in self.multiomic_dataset.features.values())) + interaction_genes = set(self.interaction_df['protein1']).union(set(self.interaction_df['protein2'])) + return list(common_features.intersection(interaction_genes)) + + def create_edge_index(self): + filtered_df = self.interaction_df[ + (self.interaction_df['protein1'].isin(self.common_features)) & + (self.interaction_df['protein2'].isin(self.common_features)) + ] + edge_list = [(self.gene_to_index[row['protein1']], self.gene_to_index[row['protein2']]) for index, row in filtered_df.iterrows()] + return torch.tensor(edge_list, dtype=torch.long).t() + + def precompute_node_features(self): + # Find indices of common features in each data matrix + feature_indices = {data_type: [self.multiomic_dataset.features[data_type].get_loc(gene) + for gene in self.common_features] + for data_type in self.multiomic_dataset.dat} + # Create a tensor to store all features [num_samples, num_nodes, num_data_types] + num_samples = len(self.samples) + num_nodes = len(self.common_features) + num_data_types = len(self.multiomic_dataset.dat) + all_features = torch.empty((num_samples, num_nodes, num_data_types), dtype=torch.float) + + # Extract features for each data type and place them in the tensor + for i, data_type in enumerate(self.multiomic_dataset.dat): + # Get the data matrix + data_matrix = self.multiomic_dataset.dat[data_type] + # Use advanced indexing to extract features for all samples at once + indices = feature_indices[data_type] + if indices: # Ensure there are common features in this data type + all_features[:, :, i] = data_matrix[:, indices] + return all_features + + def __getitem__(self, idx): + node_features_tensor = self.node_features_tensor[idx] + y_dict = {target_name: self.labels[target_name][idx] for target_name in self.labels} + return 
node_features_tensor, y_dict, self.samples[idx] + def __len__(self): + return len(self.samples) + + def print_stats(self): + """ + Prints various statistics about the graph. + """ + num_nodes = len(self.common_features) + num_edges = self.edge_index.size(1) + num_node_features = self.node_features_tensor.size(2) + + # Calculate degree for each node + degrees = torch.zeros(num_nodes, dtype=torch.long) + degrees.index_add_(0, self.edge_index[0], torch.ones_like(self.edge_index[0])) + degrees.index_add_(0, self.edge_index[1], torch.ones_like(self.edge_index[1])) # For undirected graphs + + num_singletons = torch.sum(degrees == 0).item() + non_singletons = degrees[degrees > 0] + + mean_edges_per_node = non_singletons.float().mean().item() if len(non_singletons) > 0 else 0 + median_edges_per_node = non_singletons.median().item() if len(non_singletons) > 0 else 0 + max_edges = degrees.max().item() + + print("Dataset Statistics:") + print(f"Number of nodes: {num_nodes}") + print(f"Total number of edges: {num_edges}") + print(f"Number of node features per node: {num_node_features}") + print(f"Number of singletons (nodes with no edges): {num_singletons}") + print(f"Mean number of edges per node (excluding singletons): {mean_edges_per_node:.2f}") + print(f"Median number of edges per node (excluding singletons): {median_edges_per_node}") + print(f"Max number of edges per node: {max_edges}") + class MultiOmicPYGDataset(PYGDataset): required = ["variable_types", "features", "samples", "label_mappings", "feature_ann"] @@ -811,7 +905,11 @@ def __init__(self, root: str, organism: int = 9606, node_name: str = "gene_name" self.organism = organism self.node_name = node_name super().__init__(os.path.join(root, self.base_folder)) - self.df = read_user_graph(self.processed_paths[0], sep=",", header=0, index_col=0) + + if not os.path.exists(self.processed_paths[0]): + self.download() + self.process() + self.graph_df = pd.read_csv(self.processed_paths[0], sep=",", header=0, index_col=0) def len(self) -> int: return 0 @@ -839,8 +937,8 @@ def download(self) -> None: def process(self) -> None: graph_df = read_stringdb_graph(self.node_name, self.raw_paths[0], self.raw_paths[1]) - # Drop nans and save to disk. 
graph_df.dropna().to_csv(self.processed_paths[0]) + self.graph_df = graph_df # Storing the DataFrame in the instance def read_user_graph(fpath, sep=" ", header=None, **pd_read_csv_kw): @@ -859,6 +957,22 @@ def read_stringdb_links(fname): df[["protein1", "protein2"]] = df[["protein1", "protein2"]].map(lambda a: a.split(".")[-1]) return df +def read_stringdb_links_test(fname): + df = pd.read_csv(fname, header=0, sep=" ") + df = df[df.combined_score > 800] + df = df[df.combined_score > df.combined_score.quantile(0.9)] + df_expanded = pd.concat([ + df.rename(columns={'protein1': 'protein', 'protein2': 'partner'}), + df.rename(columns={'protein2': 'protein', 'protein1': 'partner'}) + ]) + # Sort the expanded DataFrame by 'combined_score' in descending order + df_expanded_sorted = df_expanded.sort_values(by='combined_score', ascending=False) + # Reduce to unique interactions to avoid counting duplicates + df_expanded_unique = df_expanded_sorted.drop_duplicates(subset=['protein', 'partner']) + top_interactions = df_expanded_unique.groupby('protein').head(5) + df = top_interactions.rename(columns={'protein': 'protein1', 'partner': 'protein2'}) + df[["protein1", "protein2"]] = df[["protein1", "protein2"]].map(lambda a: a.split(".")[-1]) + return df def read_stringdb_aliases(fname: str, node_name: str) -> dict[str, str]: if node_name == "gene_id": diff --git a/flexynesis/main.py b/flexynesis/main.py index 55955d3..3976dc2 100644 --- a/flexynesis/main.py +++ b/flexynesis/main.py @@ -19,7 +19,9 @@ import os, yaml from skopt.space import Integer, Categorical, Real +from .data import STRING +torch.set_float32_matmul_precision("medium") class HyperparameterTuning: """ @@ -104,7 +106,7 @@ def __init__(self, dataset, model_class, config_name, target_variables, self.input_layers = input_layers self.output_layers = output_layers - self.DataLoader = DataLoader # use torch data loader by default + self.DataLoader = torch.utils.data.DataLoader # use torch data loader by default if self.model_class.__name__ == 'MultiTripletNetwork': self.loader_dataset = TripletMultiOmicDataset(self.dataset, self.target_variables[0]) @@ -127,7 +129,7 @@ def __init__(self, dataset, model_class, config_name, target_variables, else: raise ValueError(f"'{self.config_name}' not found in the default config.") - def get_batch_space(self, min_size = 16, max_size = 256): + def get_batch_space(self, min_size = 32, max_size = 256): m = int(np.log2(len(self.dataset) * 0.8)) st = int(np.log2(min_size)) end = int(np.log2(max_size)) @@ -147,8 +149,9 @@ def setup_trainer(self, params, current_step, total_steps, full_train = False): if self.early_stop_patience > 0 and full_train == False: early_stop_callback = self.init_early_stopping() mycallbacks.append(early_stop_callback) - + trainer = pl.Trainer( + precision = '16-mixed', # mixed precision training max_epochs=int(params['epochs']), log_every_n_steps=5, callbacks=mycallbacks, @@ -173,7 +176,7 @@ def objective(self, params, current_step, total_steps, full_train = False): "device_type": self.device_type, } - if self.model_class.__name__ == 'DirectPredGCNN': + if self.model_class.__name__ == 'DirectPredGCNN' or self.model_class.__name__ == 'GNNEarly': model_args['gnn_conv_type'] = self.gnn_conv_type if self.model_class.__name__ == 'CrossModalPred': model_args['input_layers'] = self.input_layers @@ -208,11 +211,14 @@ def objective(self, params, current_step, total_steps, full_train = False): print(f"[INFO] {'training cross-validation fold' if self.use_cv else 'training validation split'} {i}") 
train_subset = torch.utils.data.Subset(self.loader_dataset, train_index) val_subset = torch.utils.data.Subset(self.loader_dataset, val_index) - train_loader = self.DataLoader(train_subset, batch_size=int(params['batch_size']), pin_memory=True, shuffle=True, drop_last=True) - val_loader = self.DataLoader(val_subset, batch_size=int(params['batch_size']), pin_memory=True, shuffle=False) + train_loader = self.DataLoader(train_subset, batch_size=int(params['batch_size']), + pin_memory=True, shuffle=True, drop_last=True, num_workers = 4, prefetch_factor = None, persistent_workers = True) + val_loader = self.DataLoader(val_subset, batch_size=int(params['batch_size']), + pin_memory=True, shuffle=False, num_workers = 4, prefetch_factor = None, persistent_workers = True) model = self.model_class(**model_args) trainer, early_stop_callback = self.setup_trainer(params, current_step, total_steps) + print(f"[INFO] hpo config:{params}") trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader) if early_stop_callback.stopped_epoch: epochs.append(early_stop_callback.stopped_epoch) diff --git a/flexynesis/models/__init__.py b/flexynesis/models/__init__.py index 1c1e288..05a3819 100644 --- a/flexynesis/models/__init__.py +++ b/flexynesis/models/__init__.py @@ -3,5 +3,5 @@ from .supervised_vae import supervised_vae from .triplet_encoder import MultiTripletNetwork from .crossmodal_pred import CrossModalPred - -__all__ = ["DirectPred", "DirectPredGCNN", "supervised_vae", "MultiTripletNetwork", "CrossModalPred"] +from .gnn_early import GNNEarly +__all__ = ["DirectPred", "DirectPredGCNN", "supervised_vae", "MultiTripletNetwork", "CrossModalPred", "GNNEarly"] diff --git a/flexynesis/models/gnn_early.py b/flexynesis/models/gnn_early.py new file mode 100644 index 0000000..4b345f7 --- /dev/null +++ b/flexynesis/models/gnn_early.py @@ -0,0 +1,344 @@ +import numpy as np +import pandas as pd + +import torch +from torch import nn +from torch.nn import functional as F +from torch.utils.data import random_split + +import lightning as pl + +from torch.utils.data import DataLoader + +from captum.attr import IntegratedGradients + +from ..modules import MLP, cox_ph_loss, GNNs + + +class GNNEarly(pl.LightningModule): + def __init__( + self, + config, + dataset, # MultiOmicGeometricDataset object + target_variables, + batch_variables=None, + surv_event_var=None, + surv_time_var=None, + use_loss_weighting=True, + device_type = None, + gnn_conv_type = None + ): + super().__init__() + self.config = config + self.target_variables = target_variables + self.surv_event_var = surv_event_var + self.surv_time_var = surv_time_var + # both surv event and time variables are assumed to be numerical variables + # we create only one survival variable for the pair (surv_time_var and surv_event_var) + if self.surv_event_var is not None and self.surv_time_var is not None: + self.target_variables = self.target_variables + [self.surv_event_var] + self.batch_variables = batch_variables + self.variables = self.target_variables + self.batch_variables if self.batch_variables else self.target_variables + self.variable_types = dataset.multiomic_dataset.variable_types + self.ann = dataset.multiomic_dataset.ann + self.edge_index = dataset.edge_index + + self.feature_importances = {} + self.use_loss_weighting = use_loss_weighting + + self.device_type = device_type + self.gnn_conv_type = gnn_conv_type + + device = torch.device("cuda" if self.device_type == 'gpu' and torch.cuda.is_available() else 'cpu') + self.edge_index = 
self.edge_index.to(device) # edge index is re-used across samples, so we keep it in device + + if self.use_loss_weighting: + # Initialize log variance parameters for uncertainty weighting + self.log_vars = nn.ParameterDict() + for var in self.variables: + self.log_vars[var] = nn.Parameter(torch.zeros(1)) + + node_features = dataset[0][0].shape[1] # number of node features + node_count = dataset[0][0].shape[0] #number of nodes + self.encoders = GNNs( + input_dim=node_features, + hidden_dim=int(self.config["hidden_dim_factor"] * node_count), + output_dim=self.config["latent_dim"], + act = self.config['activation'], + conv = self.gnn_conv_type + ) + + # Init output layers + self.MLPs = nn.ModuleDict() + for var in self.variables: + if self.variable_types[var] == "numerical": + num_class = 1 + else: + num_class = len(np.unique(self.ann[var])) + self.MLPs[var] = MLP( + input_dim=self.config["latent_dim"], + hidden_dim=self.config["supervisor_hidden_dim"], + output_dim=num_class + ) + + def forward(self, x, edge_index): + embeddings = self.encoders(x, edge_index) + outputs = {} + for var, mlp in self.MLPs.items(): + outputs[var] = mlp(embeddings) + return outputs + + + def training_step(self, batch): + x, y_dict, samples = batch + outputs = self.forward(x, self.edge_index) + + losses = {} + for var in self.variables: + if var == self.surv_event_var: + durations = y_dict[self.surv_time_var] + events = y_dict[self.surv_event_var] + risk_scores = outputs[var] + loss = cox_ph_loss(risk_scores, durations, events) + else: + y_hat = outputs[var] + y = y_dict[var] + loss = self.compute_loss(var, y, y_hat) + losses[var] = loss + + total_loss = self.compute_total_loss(losses) + losses["train_loss"] = total_loss + self.log_dict(losses, on_step=False, on_epoch=True, prog_bar=True, batch_size=len(batch)) + return total_loss + + def validation_step(self, batch): + x, y_dict, samples = batch + outputs = self.forward(x, self.edge_index) + losses = {} + for var in self.variables: + if var == self.surv_event_var: + durations = y_dict[self.surv_time_var] + events = y_dict[self.surv_event_var] + risk_scores = outputs[var] + loss = cox_ph_loss(risk_scores, durations, events) + else: + y_hat = outputs[var] + y = y_dict[var] + loss = self.compute_loss(var, y, y_hat) + losses[var] = loss + + total_loss = sum(losses.values()) + losses["val_loss"] = total_loss + self.log_dict(losses, on_step=False, on_epoch=True, prog_bar=True, batch_size=len(batch)) + return total_loss + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=self.config["lr"]) + return optimizer + + def compute_loss(self, var, y, y_hat): + if self.variable_types[var] == "numerical": + # Ignore instances with missing labels for numerical variables + valid_indices = ~torch.isnan(y) + if valid_indices.sum() > 0: # only calculate loss if there are valid targets + y_hat = y_hat[valid_indices] + y = y[valid_indices] + loss = F.mse_loss(torch.flatten(y_hat), y.float()) + else: + loss = torch.tensor(0.0, device=y_hat.device, requires_grad=True) # if no valid labels, set loss to 0 + else: + # Ignore instances with missing labels for categorical variables + # Assuming that missing values were encoded as -1 + valid_indices = (y != -1) & (~torch.isnan(y)) + if valid_indices.sum() > 0: # only calculate loss if there are valid targets + y_hat = y_hat[valid_indices] + y = y[valid_indices] + loss = F.cross_entropy(y_hat, y.long()) + else: + loss = torch.tensor(0.0, device=y_hat.device, requires_grad=True) + return loss + + def 
compute_total_loss(self, losses): + if self.use_loss_weighting and len(losses) > 1: + # Compute weighted loss for each loss + # Weighted loss = precision * loss + log-variance + total_loss = sum( + torch.exp(-self.log_vars[name]) * loss + self.log_vars[name] for name, loss in losses.items() + ) + else: + # Compute unweighted total loss + total_loss = sum(losses.values()) + return total_loss + + def predict(self, dataset): + self.eval() # Set the model to evaluation mode + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.to(device) # Move the model to the appropriate device + + # Create a DataLoader with a practical batch size + dataloader = DataLoader(dataset, batch_size=64, shuffle=False) # Adjust the batch size as needed + edge_index = dataset.edge_index.to(device) # Move edge_index to GPU + + predictions = {var: [] for var in self.variables} # Initialize prediction storage + + # Process each batch + for x, y_dict,samples in dataloader: + x = x.to(device) # Move data to GPU + + outputs = self.forward(x, edge_index) + + # Collect predictions for each variable + for var in self.variables: + y_pred = outputs[var].detach().cpu().numpy() # Move outputs back to CPU and convert to numpy + if self.variable_types[var] == "categorical": + predictions[var].extend(np.argmax(y_pred, axis=1)) + else: + predictions[var].extend(y_pred) + + # Convert lists to arrays if necessary, depending on the downstream use-case + predictions = {var: np.array(predictions[var]) for var in predictions} + + return predictions + + def transform(self, dataset): + self.eval() # Set the model to evaluation mode + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.to(device) # Move the model to the appropriate device + edge_index = dataset.edge_index.to(device) # Move edge_index to GPU + + dataloader = DataLoader(dataset, batch_size=64, shuffle=False) # Adjust the batch size as needed + all_embeddings = [] # List to store embeddings from all batches + sample_ids = [] # List to store indices for all samples processed + + # Process each batch + for x, y_dict, samples in dataloader: + x = x.to(device) # Move data to GPU + + embeddings = self.encoders(x, edge_index).detach().cpu().numpy() # Compute embeddings and move to CPU + all_embeddings.append(embeddings) + sample_ids.extend(samples) + + # Concatenate all embeddings into a single numpy array + all_embeddings = np.vstack(all_embeddings) + + # Converting tensor to numpy array and then to DataFrame + embeddings_df = pd.DataFrame( + all_embeddings, + index=sample_ids, # Use the correct indices as row names + columns=[f"E{dim}" for dim in range(all_embeddings.shape[1])], + ) + return embeddings_df + + def optimizer_zero_grad(self, epoch, batch_idx, optimizer): + optimizer.zero_grad(set_to_none=True) + + # Adaptor forward function for captum integrated gradients. + def forward_target(self, *args): + input_data = list(args[:-2]) # expect a single tensor (early integration) + target_var = args[-2] # target variable of interest + steps = args[-1] # number of steps for IntegratedGradients().attribute + outputs_list = [] + for i in range(steps): + x_step = input_data[0][i] + #edges_step = edge_index[i] # although, identical, they get copied. 
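+                # the same edge_index is reused for every integration step; only the interpolated node features change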
+                out = self.forward(x_step, self.dataset_edge_index)
+                outputs_list.append(out[target_var])
+            return torch.cat(outputs_list, dim = 0)
+
+
+    def compute_feature_importance(self, dataset, target_var, steps=5, batch_size = 16):
+        """
+        Computes the feature importance for each variable in the dataset using the Integrated Gradients method.
+        This method measures the importance of each feature by attributing the prediction output to each input feature.
+
+        Args:
+            dataset: The dataset object containing the features and data (MultiOmicDatasetNW object).
+            target_var (str): The target variable for which feature importance is calculated.
+            steps (int, optional): The number of steps to use for integrated gradients approximation. Defaults to 5.
+            batch_size (int, optional): The size of the batch to process the dataset. Defaults to 16.
+
+        Returns:
+            pd.DataFrame: A DataFrame containing feature importances across different variables and data modalities.
+                          Columns include 'target_variable', 'target_class', 'target_class_label', 'layer', 'name',
+                          and 'importance'.
+
+        This function adjusts the device setting based on the availability of GPUs and performs the computation using
+        Integrated Gradients. It processes batches of data, aggregates results across batches, and formats the output
+        into a readable DataFrame which is then stored in the model's attribute for later use or analysis.
+        """
+        def bytes_to_mb(n_bytes):
+            return n_bytes / 1024 ** 2
+        print("Memory before moving model to device: {:.3f} MB".format(bytes_to_mb(torch.cuda.max_memory_reserved())))
+        device = torch.device("cuda" if self.device_type == 'gpu' and torch.cuda.is_available() else 'cpu')
+        self.to(device)
+        print("Memory before edges: {:.3f} MB".format(bytes_to_mb(torch.cuda.max_memory_reserved())))
+        self.dataset_edge_index = dataset.edge_index.to(device)
+        print("Memory after edges: {:.3f} MB".format(bytes_to_mb(torch.cuda.max_memory_reserved())))
+
+        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
+        ig = IntegratedGradients(self.forward_target)
+
+        if dataset.variable_types[target_var] == 'numerical':
+            num_class = 1
+        else:
+            num_class = len(np.unique(dataset.ann[target_var]))
+
+        print("Memory before batch processing: {:.3f} MB".format(bytes_to_mb(torch.cuda.max_memory_reserved())))
+
+        aggregated_attributions = [[] for _ in range(num_class)]
+        for batch in dataloader:
+            x, y_dict, samples = batch
+
+            input_data = x.unsqueeze(0).requires_grad_().to(device)
+            baseline = torch.zeros_like(input_data)
+
+            if num_class == 1:
+                # returns a tuple of tensors (one per data modality)
+                attributions = ig.attribute(input_data, baseline,
+                                            additional_forward_args=(target_var, steps),
+                                            n_steps=steps)
+                aggregated_attributions[0].append(attributions)
+            else:
+                for target_class in range(num_class):
+                    # returns a tuple of tensors (one per data modality)
+                    attributions = ig.attribute(input_data, baseline,
+                                                additional_forward_args=(target_var, steps),
+                                                target=target_class, n_steps=steps)
+                    aggregated_attributions[target_class].append(attributions)
+        # For each target class, concatenate node attributions across batches
+        processed_attributions = []
+        # Process each class
+        for class_idx in range(len(aggregated_attributions)):
+            class_attr = aggregated_attributions[class_idx]
+            # Concatenate tensors along the batch dimension
+            attr_concat = torch.cat([batch_attr for batch_attr in class_attr], dim=1)
+            processed_attributions.append(attr_concat)
+
+        # compute absolute importance and move to cpu
+        abs_attr = [torch.abs(attr_class).cpu() for attr_class in processed_attributions]
+        # average over samples
+        imp = [a.mean(dim=1) for a in abs_attr]
+
+        # move the model also back to cpu (if not already on cpu)
+        self.to('cpu')
+        print("Memory after batch processing: {:.3f} MB".format(bytes_to_mb(torch.cuda.max_memory_reserved())))
+
+        df_list = []
+        layers = list(dataset.multiomic_dataset.dat.keys())
+        for i in range(num_class):
+            features = dataset.common_features
+            target_class_label = dataset.label_mappings[target_var].get(i) if target_var in dataset.label_mappings else ''
+            for l in range(len(layers)):
+                # extracting node feature attributes coming from different omic layers
+                importances = imp[i].squeeze().detach().numpy()[:,l]
+                df_list.append(pd.DataFrame({'target_variable': target_var,
+                                             'target_class': i,
+                                             'target_class_label': target_class_label,
+                                             'layer': layers[l],
+                                             'name': features,
+                                             'importance': importances}))
+        df_imp = pd.concat(df_list, ignore_index=True)
+        # save the computed scores in the model
+        self.feature_importances[target_var] = df_imp
\ No newline at end of file
diff --git a/flexynesis/modules.py b/flexynesis/modules.py
index 905b554..1ea8984 100644
--- a/flexynesis/modules.py
+++ b/flexynesis/modules.py
@@ -171,18 +171,19 @@ def __init__(self, input_dim, hidden_dim, output_dim,
             raise ValueError('Unknown convolution type. Choose one of: ',conv_options.keys())
 
         self.conv = conv_options[conv]
 
-        self.layer_1 = self.conv(input_dim, output_dim)
+        self.layer_1 = self.conv(input_dim, hidden_dim)
         self.act_1 = self.activation
-        #self.layer_2 = self.conv(hidden_dim, output_dim)
-        #self.act_2 = self.activation
-        self.aggregation = aggr.SumAggregation()
+        self.layer_2 = self.conv(hidden_dim, output_dim)
+        self.act_2 = self.activation
 
-    def forward(self, x, edge_index, batch):
+    def forward(self, x, edge_index):
         x = self.layer_1(x, edge_index)
         x = self.act_1(x)
-        #x = self.layer_2(x, edge_index)
-        #x = self.act_2(x)
-        x = self.aggregation(x, batch)
+        x = self.layer_2(x, edge_index)
+        x = self.act_2(x)
+        # mean pooling to get sample embeddings (we use a single graph across samples)
+        # Perform mean pooling across the node dimension
+        x = x.mean(dim=1)
         return x
 
 def cox_ph_loss(outputs, durations, events):
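For orientation, a minimal sketch (not part of the patch) of how the pieces added above fit together when wired up by hand. Here `train_dataset` and `test_dataset` stand in for the MultiOmicDataset objects returned by DataImporter.import_data() in __main__.py; everything else mirrors the GNNEarly branch introduced in this diff.

# hypothetical usage sketch; train_dataset / test_dataset are assumed MultiOmicDataset objects
from flexynesis.data import STRING, MultiOmicDatasetNW

string_db = STRING('STRING', 9606, "gene_name")      # downloads/processes STRING on first use, exposes .graph_df
train_nw = MultiOmicDatasetNW(train_dataset, string_db.graph_df)
test_nw = MultiOmicDatasetNW(test_dataset, string_db.graph_df)
train_nw.print_stats()                               # nodes, edges, singletons of the shared graph

x, y_dict, sample_id = train_nw[0]                   # x: [num_nodes, num_data_types] node-feature matrix for one sample
edge_index = train_nw.edge_index                     # one edge index shared across all samples

baseline_train = train_nw.multiomic_dataset          # baselines still run on the original tabular dataset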