From 2496e116748172def442ddbc6c72976efc05303f Mon Sep 17 00:00:00 2001
From: Bora Uyar
Date: Wed, 3 Apr 2024 12:34:49 +0200
Subject: [PATCH] only pass input layers to forward; print decoded layers to disk

---
 flexynesis/__main__.py               | 10 ++++++++++
 flexynesis/models/crossmodal_pred.py | 23 ++++++++++-------------
 2 files changed, 20 insertions(+), 13 deletions(-)

diff --git a/flexynesis/__main__.py b/flexynesis/__main__.py
index b49f7fc..37338ef 100644
--- a/flexynesis/__main__.py
+++ b/flexynesis/__main__.py
@@ -257,6 +257,16 @@ class AvailableModels(NamedTuple):
     embeddings_train_filtered.to_csv(os.path.join(args.outdir, '.'.join([args.prefix, 'embeddings_train.filtered.csv'])), header=True)
     embeddings_test_filtered.to_csv(os.path.join(args.outdir, '.'.join([args.prefix, 'embeddings_test.filtered.csv'])), header=True)
 
+    # for architectures with decoders; print decoded output layers
+    if args.model_class == 'CrossModalPred':
+        output_layers_train = model.decode(train_dataset)
+        output_layers_test = model.decode(test_dataset)
+        for layer in output_layers_train.keys():
+            output_layers_train[layer].to_csv(os.path.join(args.outdir, '.'.join([args.prefix, 'train_decoded', layer, 'csv'])), header=True)
+        for layer in output_layers_test.keys():
+            output_layers_test[layer].to_csv(os.path.join(args.outdir, '.'.join([args.prefix, 'test_decoded', layer, 'csv'])), header=True)
+
+
     # evaluate off-the-shelf methods on the main target variable
     if args.evaluate_baseline_performance == 'True':
         print("[INFO] Computing off-the-shelf method performance on first target variable:",model.target_variables[0])
diff --git a/flexynesis/models/crossmodal_pred.py b/flexynesis/models/crossmodal_pred.py
index 9e8691c..0554b8f 100644
--- a/flexynesis/models/crossmodal_pred.py
+++ b/flexynesis/models/crossmodal_pred.py
@@ -121,7 +121,7 @@ def multi_encoder(self, x_list):
         log_var = self.FC_log_var(torch.cat(log_vars, dim=1))
         return mean, log_var
 
-    def forward(self, x_list_input, x_list_output):
+    def forward(self, x_list_input):
         """
         Forward pass through the model.
 
@@ -142,7 +142,7 @@ def forward(self, x_list_input, x_list_output):
         z = self.reparameterization(mean, log_var)
 
         # decode the latent space to target output layer(s)
-        x_hat_list = [self.decoders[i](z) for i in range(len(x_list_output))]
+        x_hat_list = [self.decoders[i](z) for i in range(len(self.output_layers))]
 
         #run the supervisor heads using the latent layer as input
         outputs = {}
@@ -214,11 +214,10 @@ def training_step(self, train_batch, batch_idx):
 
         # get input omics modalities and encode them; decode them to output layers
         x_list_input = [dat[x] for x in self.input_layers]
-        x_list_output = [dat[x] for x in self.output_layers]
-
-        x_hat_list, z, mean, log_var, outputs = self.forward(x_list_input, x_list_output)
+        x_hat_list, z, mean, log_var, outputs = self.forward(x_list_input)
 
         # compute mmd loss for the latent space + reconsruction loss for each target/output layer
+        x_list_output = [dat[x] for x in self.output_layers]
         mmd_loss_list = [self.MMD_loss(z.shape[1], z, x_hat_list[i], x_list_output[i]) for i in range(len(self.output_layers))]
         mmd_loss = torch.mean(torch.stack(mmd_loss_list))
 
@@ -248,9 +247,7 @@ def validation_step(self, val_batch, batch_idx):
 
         # get input omics modalities and encode them
         x_list_input = [dat[x] for x in self.input_layers]
-        x_list_output = [dat[x] for x in self.output_layers]
-
-        x_hat_list, z, mean, log_var, outputs = self.forward(x_list_input, x_list_output)
+        x_hat_list, z, mean, log_var, outputs = self.forward(x_list_input)
 
         # compute mmd loss for the latent space + reconsruction loss for each target/output layer
         x_list_output = [dat[x] for x in self.output_layers]
@@ -304,8 +301,8 @@ def transform(self, dataset):
             pd.DataFrame: Transformed dataset as a pandas DataFrame.
         """
         self.eval()
-        x_list = [dataset.dat[x] for x in self.input_layers]
-        M = self.forward(x_list)[1].detach().numpy()
+        x_list_input = [dataset.dat[x] for x in self.input_layers]
+        M = self.forward(x_list_input)[1].detach().numpy()
         z = pd.DataFrame(M)
         z.columns = [''.join(['E', str(x)]) for x in z.columns]
         z.index = dataset.samples
@@ -323,8 +320,8 @@ def predict(self, dataset):
         """
         self.eval()
 
-        x_list = [dataset.dat[x] for x in self.input_layers]
-        X_hat, z, mean, log_var, outputs = self.forward(x_list)
+        x_list_input = [dataset.dat[x] for x in self.input_layers]
+        X_hat, z, mean, log_var, outputs = self.forward(x_list_input)
 
         predictions = {}
         for var in self.variables:
@@ -344,7 +341,7 @@ def decode(self, dataset):
         self.eval()
         x_list_input = [dataset.dat[x] for x in self.input_layers]
         x_list_output = [dataset.dat[x] for x in self.output_layers]
-        x_hat_list, z, mean, log_var, outputs = self.forward(x_list_input, x_list_output)
+        x_hat_list, z, mean, log_var, outputs = self.forward(x_list_input)
         X = {}
         for i in range(len(self.output_layers)):
             x = pd.DataFrame(x_hat_list[i].detach().numpy()).transpose()
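
Note: the decoded layers written by the new CrossModalPred branch in __main__.py
follow the naming pattern <prefix>.{train,test}_decoded.<layer>.csv. A minimal
sketch of reading them back for downstream analysis; the function name, paths,
and layer names here are hypothetical illustrations, not part of the patch:

    import os
    import pandas as pd

    def load_decoded_layers(outdir, prefix, layers, split="train"):
        """Hypothetical helper: read each decoded omics layer back into a
        dict of DataFrames, mirroring the file-naming scheme used in the
        CrossModalPred branch of __main__.py above."""
        decoded = {}
        for layer in layers:
            # matches '.'.join([args.prefix, split + '_decoded', layer, 'csv'])
            fname = ".".join([prefix, split + "_decoded", layer, "csv"])
            # to_csv writes the DataFrame index by default, so read it back as the index
            decoded[layer] = pd.read_csv(os.path.join(outdir, fname), index_col=0)
        return decoded

    # usage (illustrative names):
    # decoded = load_decoded_layers("results", "job1", ["mutation", "cnv"], split="test")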