update docstrings for direct_pred class
borauyar committed Jun 27, 2024
1 parent b4f4c01 commit 5ffaec0
Showing 1 changed file with 72 additions and 17 deletions.
89 changes: 72 additions & 17 deletions flexynesis/models/direct_pred.py
@@ -15,6 +15,20 @@
from ..modules import *

class DirectPred(pl.LightningModule):
"""
A fully connected network for multi-omics integration with supervisor heads.
Attributes:
config (dict): Configuration settings for the model, including learning rates and dimensions.
dataset: The MultiOmicDataset object containing the data and metadata.
target_variables (list): A list of target variable names that the model aims to predict.
batch_variables (list, optional): A list of variables used for batch correction. Defaults to None.
surv_event_var (str, optional): The name of the survival event variable. Defaults to None.
surv_time_var (str, optional): The name of the survival time variable. Defaults to None.
use_loss_weighting (bool, optional): Whether to use loss weighting in the model. Defaults to True.
device_type (str, optional): Type of device to run the model ('gpu' or 'cpu'). Defaults to None.
"""

def __init__(self, config, dataset, target_variables, batch_variables = None,
surv_event_var = None, surv_time_var = None, use_loss_weighting = True,
device_type = None):
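
To illustrate the constructor described above, here is a minimal, hypothetical instantiation sketch; the config keys and the dataset object are assumptions for illustration, not part of this commit:

```python
# Hypothetical usage; config keys and dataset preparation are illustrative only.
from flexynesis.models.direct_pred import DirectPred

config = {"latent_dim": 64, "hidden_dim": 128, "lr": 1e-4}  # assumed keys
model = DirectPred(
    config=config,
    dataset=train_dataset,               # an assumed MultiOmicDataset, prepared elsewhere
    target_variables=["drug_response"],  # illustrative target variable name
    use_loss_weighting=True,
    device_type="gpu",
)
```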
@@ -116,6 +130,26 @@ def compute_loss(self, var, y, y_hat):
return loss

def compute_total_loss(self, losses):
"""
Computes the total loss from a dictionary of individual losses. This method can compute
either weighted or unweighted total loss based on the model configuration. If loss weighting
is enabled and there are multiple loss components, it uses uncertainty-based weighting.
See Kendall A. et al, https://arxiv.org/abs/1705.07115.
Args:
losses (dict of torch.Tensor): A dictionary where each key is a variable name and
each value is the loss tensor associated with that variable.
Returns:
torch.Tensor: The total loss computed across all inputs, either weighted or unweighted.
The method checks if loss weighting is used (`use_loss_weighting`) and if there are multiple
losses to weight. If so, it computes the weighted sum of losses, where the weight involves
the exponential of the negative log variance (acting as precision) associated with each loss,
added to the log variance itself. This approach helps in balancing the contribution of each
loss component based on its uncertainty. If loss weighting is not used, or there is only one
loss component, it sums up the losses directly.
"""
if self.use_loss_weighting and len(losses) > 1:
# Compute weighted loss for each loss
# Weighted loss = precision * loss + log-variance
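
The comment above states the weighting rule: weighted loss = precision * loss + log-variance. A minimal standalone sketch of that rule, assuming one learnable log-variance per loss component (the actual implementation lives in the elided code):

```python
import torch

def weighted_total_loss(losses: dict, log_vars: dict) -> torch.Tensor:
    """Sketch of uncertainty-based loss weighting (Kendall et al.)."""
    total = torch.tensor(0.0)
    for var, loss in losses.items():
        precision = torch.exp(-log_vars[var])  # exp(-log_var) acts as a precision term
        total = total + precision * loss + log_vars[var]  # log_var penalizes overconfidence
    return total

# Example: two loss components with equal (unit) uncertainty.
losses = {"a": torch.tensor(1.0), "b": torch.tensor(2.0)}
log_vars = {"a": torch.tensor(0.0), "b": torch.tensor(0.0)}
print(weighted_total_loss(losses, log_vars))  # tensor(3.)
```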
@@ -127,12 +161,15 @@ def compute_total_loss(self, losses):

def training_step(self, train_batch, batch_idx, log = True):
"""
Perform a single training step.
Executes one training step using a single batch from the training dataset.
Args:
train_batch (tuple): A tuple containing the input data and labels for the current batch.
batch_idx (int): The index of the current batch.
train_batch (tuple): The batch to train on, which includes input data and targets.
batch_idx (int): Index of the current batch in the sequence.
log (bool, optional): Whether to log the loss metrics to TensorBoard. Defaults to True.
Returns:
torch.Tensor: The total loss for the current training step.
torch.Tensor: The total loss computed for the batch.
"""

dat, y_dict = train_batch
@@ -161,14 +198,15 @@ def training_step(self, train_batch, batch_idx, log = True):

def validation_step(self, val_batch, batch_idx, log = True):
"""
Perform a single validation step.
Executes one validation step using a single batch from the validation dataset.
Args:
val_batch (tuple): A tuple containing the input data and labels for the current batch.
batch_idx (int): The index of the current batch.
val_batch (tuple): The batch to validate on, which includes input data and targets.
batch_idx (int): Index of the current batch in the sequence.
log (bool, optional): Whether to log the loss metrics to TensorBoard. Defaults to True.
Returns:
torch.Tensor: The total loss for the current validation step.
torch.Tensor: The total loss computed for the batch.
"""
dat, y_dict = val_batch
layers = dat.keys()
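
Because training_step and validation_step are PyTorch Lightning hooks, they are not called directly; the trainer invokes them. A hedged sketch of how training might be driven, assuming the dataloaders are built from MultiOmicDataset objects:

```python
import pytorch_lightning as pl  # alias assumed to match the module's own `pl`
from torch.utils.data import DataLoader

# train_dataset and val_dataset are assumed MultiOmicDataset objects.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)

trainer = pl.Trainer(max_epochs=50)
trainer.fit(model, train_loader, val_loader)  # calls training_step / validation_step per batch
```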
@@ -194,13 +232,13 @@ def validation_step(self, val_batch, batch_idx, log = True):

def predict(self, dataset):
"""
Evaluate the DirectPred model on a given dataset.
Make predictions on an entire dataset.
Args:
dataset: The dataset to evaluate the model on.
dataset: The MultiOmicDataset object to evaluate the model on.
Returns:
A dictionary where each key is a target variable and the corresponding value is the predicted output for that variable.
dict: Predictions mapped by target variable names.
"""
self.eval()
layers = dataset.dat.keys()
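
A short usage sketch for predict; the dataset and target variable name are illustrative:

```python
# Hypothetical call; 'drug_response' stands in for one of model.target_variables.
predictions = model.predict(test_dataset)
y_hat = predictions["drug_response"]  # one entry per target variable
```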
@@ -218,15 +256,13 @@ def predict(self, dataset):

def transform(self, dataset):
"""
Transform the input data into a lower-dimensional space using the trained encoders.
Transforms the input data into a lower-dimensional representation using trained encoders.
Args:
dataset: The input dataset containing the omics data.
dataset: The dataset containing the input data.
Returns:
pd.DataFrame: A dataframe of embeddings where the row indices are
dataset.samples and the column names are created by appending
the substring "E" to each dimension index.
pd.DataFrame: DataFrame containing the transformed data.
"""
self.eval()
embeddings_list = []
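
A matching sketch for transform; per the pre-commit docstring above, rows are indexed by dataset.samples and columns are named "E0", "E1", and so on:

```python
embeddings = model.transform(test_dataset)  # test_dataset is an assumed MultiOmicDataset
print(embeddings.shape)              # (n_samples, latent_dim)
print(list(embeddings.columns[:3]))  # e.g. ['E0', 'E1', 'E2']
```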
@@ -255,6 +291,25 @@ def forward_target(self, *args):
return torch.cat(outputs_list, dim = 0)

def compute_feature_importance(self, dataset, target_var, steps=5, batch_size = 64):
"""
Computes the feature importance for each variable in the dataset using the Integrated Gradients method.
This method measures the importance of each feature by attributing the prediction output to each input feature.
Args:
dataset: The dataset object containing the features and data.
target_var (str): The target variable for which feature importance is calculated.
steps (int, optional): The number of steps to use for integrated gradients approximation. Defaults to 5.
batch_size (int, optional): The size of the batch to process the dataset. Defaults to 64.
Returns:
pd.DataFrame: A DataFrame containing feature importances across different variables and data modalities.
Columns include 'target_variable', 'target_class', 'target_class_label', 'layer', 'name',
and 'importance'.
This function adjusts the device setting based on the availability of GPUs and performs the computation using
Integrated Gradients. It processes batches of data, aggregates results across batches, and formats the output
into a readable DataFrame which is then stored in the model's attribute for later use or analysis.
"""
device = torch.device("cuda" if self.device_type == 'gpu' and torch.cuda.is_available() else 'cpu')
self.to(device)
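
A usage sketch for compute_feature_importance, reading the result back from the feature_importances attribute named in the docstring; the target variable name is illustrative:

```python
model.compute_feature_importance(test_dataset, target_var="drug_response", steps=5, batch_size=64)
df_imp = model.feature_importances["drug_response"]
top10 = df_imp.sort_values("importance", ascending=False).head(10)
print(top10[["layer", "name", "importance"]])
```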

@@ -327,4 +382,4 @@ def compute_feature_importance(self, dataset, target_var, steps=5, batch_size = 64):
df_imp = pd.concat(df_list, ignore_index=True)
# save the computed scores in the model
self.feature_importances[target_var] = df_imp

