Skip to content

Commit

Permalink
Adapte MultiTripletNetwork to device management
Browse files Browse the repository at this point in the history
  • Loading branch information
borauyar committed Mar 2, 2024
1 parent 9fedee1 commit e8a5c92
Showing 1 changed file with 20 additions and 3 deletions.
23 changes: 20 additions & 3 deletions flexynesis/models/triplet_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ class MultiTripletNetwork(pl.LightningModule):
"""
"""
def __init__(self, config, dataset, target_variables, batch_variables = None,
surv_event_var = None, surv_time_var = None, val_size = 0.2, use_loss_weighting = True):
surv_event_var = None, surv_time_var = None, val_size = 0.2, use_loss_weighting = True,
device_type = None):
"""
Initialize the MultiTripletNetwork with the given parameters.
Expand All @@ -85,7 +86,8 @@ def __init__(self, config, dataset, target_variables, batch_variables = None,
self.ann = self.dataset.ann
self.variable_types = self.dataset.variable_types
self.feature_importances = {}

self.device_type = device_type

layers = list(dataset.dat.keys())
input_sizes = [len(dataset.features[layers[i]]) for i in range(len(layers))]
hidden_sizes = [config['hidden_dim'] for x in range(len(layers))]
Expand Down Expand Up @@ -348,6 +350,11 @@ def compute_feature_importance(self, target_var, steps = 5):
attributions (list of torch.Tensor): The feature importances for each class.
"""

device = torch.device("cuda" if self.device_type == 'gpu' and torch.cuda.is_available() else 'cpu')
self.to(device)

print("[INFO] Computing feature importance for variable:",target_var,"on device:",device)

# self.dataset is a TripletMultiomicDataset, which has a different
# structure than the MultiomicDataset. We use data loader to
# read the triplets and get anchor/positive/negative tensors
Expand All @@ -356,6 +363,12 @@ def compute_feature_importance(self, target_var, steps = 5):
it = iter(dl)
anchor, positive, negative, y_dict = next(it)

# Move tensors to the specified device
anchor = {k: v.to(device) for k, v in anchor.items()}
positive = {k: v.to(device) for k, v in positive.items()}
negative = {k: v.to(device) for k, v in negative.items()}
y_dict = {k: v.to(device) for k, v in y_dict.items()}

# Initialize the Integrated Gradients method
ig = IntegratedGradients(self.forward_target)

Expand Down Expand Up @@ -391,10 +404,14 @@ def compute_feature_importance(self, target_var, steps = 5):

# summarize feature importances
# Compute absolute attributions
abs_attr = [[torch.abs(a) for a in attr_class] for attr_class in attributions]
# Move the processed tensors to CPU for further operations that are not supported on GPU
abs_attr = [[torch.abs(a).cpu() for a in attr_class] for attr_class in attributions]
# average over samples
imp = [[a.mean(dim=1) for a in attr_class] for attr_class in abs_attr]

# move the model also back to cpu (if not already on cpu)
self.to('cpu')

# combine into a single data frame
df_list = []
layers = list(self.dataset.dataset.dat.keys()) # accessing multiomicdataset within tripletmultiomic dataset here
Expand Down

0 comments on commit e8a5c92

Please sign in to comment.