diff --git a/flexynesis/models/triplet_encoder.py b/flexynesis/models/triplet_encoder.py
index 33c188e..3632068 100644
--- a/flexynesis/models/triplet_encoder.py
+++ b/flexynesis/models/triplet_encoder.py
@@ -61,7 +61,8 @@ class MultiTripletNetwork(pl.LightningModule):
     """
     """
     def __init__(self, config, dataset, target_variables, batch_variables = None,
-                 surv_event_var = None, surv_time_var = None, val_size = 0.2, use_loss_weighting = True):
+                 surv_event_var = None, surv_time_var = None, val_size = 0.2, use_loss_weighting = True,
+                 device_type = None):
         """
         Initialize the MultiTripletNetwork with the given parameters.
 
@@ -85,7 +86,8 @@ def __init__(self, config, dataset, target_variables, batch_variables = None,
         self.ann = self.dataset.ann
         self.variable_types = self.dataset.variable_types
         self.feature_importances = {}
-
+        self.device_type = device_type
+
         layers = list(dataset.dat.keys())
         input_sizes = [len(dataset.features[layers[i]]) for i in range(len(layers))]
         hidden_sizes = [config['hidden_dim'] for x in range(len(layers))]
@@ -348,6 +350,11 @@ def compute_feature_importance(self, target_var, steps = 5):
             attributions (list of torch.Tensor): The feature importances for each class.
         """
 
+        device = torch.device("cuda" if self.device_type == 'gpu' and torch.cuda.is_available() else 'cpu')
+        self.to(device)
+
+        print("[INFO] Computing feature importance for variable:", target_var, "on device:", device)
+
         # self.dataset is a TripletMultiomicDataset, which has a different
         # structure than the MultiomicDataset. We use data loader to
         # read the triplets and get anchor/positive/negative tensors
@@ -356,6 +363,12 @@ def compute_feature_importance(self, target_var, steps = 5):
         it = iter(dl)
         anchor, positive, negative, y_dict = next(it)
 
+        # Move tensors to the specified device
+        anchor = {k: v.to(device) for k, v in anchor.items()}
+        positive = {k: v.to(device) for k, v in positive.items()}
+        negative = {k: v.to(device) for k, v in negative.items()}
+        y_dict = {k: v.to(device) for k, v in y_dict.items()}
+
         # Initialize the Integrated Gradients method
         ig = IntegratedGradients(self.forward_target)
 
@@ -391,10 +404,14 @@ def compute_feature_importance(self, target_var, steps = 5):
 
         # summarize feature importances
         # Compute absolute attributions
-        abs_attr = [[torch.abs(a) for a in attr_class] for attr_class in attributions]
+        # Move the processed tensors to CPU for further operations that are not supported on GPU
+        abs_attr = [[torch.abs(a).cpu() for a in attr_class] for attr_class in attributions]
         # average over samples
         imp = [[a.mean(dim=1) for a in attr_class] for attr_class in abs_attr]
 
+        # Move the model back to CPU as well (if not already on CPU)
+        self.to('cpu')
+
         # combine into a single data frame
         df_list = []
         layers = list(self.dataset.dataset.dat.keys()) # accessing multiomicdataset within tripletmultiomic dataset here
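
Note on the device-handling pattern introduced above: the sketch below shows the same device selection and dict-of-tensors transfer logic in isolation, for illustration only. It is not part of the patch, and the names select_device, move_to_device, and batch are hypothetical.

import torch

def select_device(device_type=None):
    # Use CUDA only when the caller requested 'gpu' and a GPU is actually
    # available; otherwise fall back to CPU, mirroring the check in the diff.
    return torch.device("cuda" if device_type == "gpu" and torch.cuda.is_available() else "cpu")

def move_to_device(tensor_dict, device):
    # Move every tensor in a {layer_name: tensor} dict to the target device,
    # as done for the anchor/positive/negative dictionaries above.
    return {k: v.to(device) for k, v in tensor_dict.items()}

device = select_device("gpu")
batch = {"mutation": torch.randn(4, 10), "expression": torch.randn(4, 20)}
batch = move_to_device(batch, device)
# After the GPU-side attribution work, results are moved back with .cpu()
# before any CPU-only steps such as building data frames.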