Commit

only pass input layers to forward; print decoded layers to disk
borauyar committed Apr 3, 2024 (1 parent 221961b, commit 2496e11)
Showing 2 changed files with 20 additions and 13 deletions.
flexynesis/__main__.py (10 additions, 0 deletions)
@@ -257,6 +257,16 @@ class AvailableModels(NamedTuple):
embeddings_train_filtered.to_csv(os.path.join(args.outdir, '.'.join([args.prefix, 'embeddings_train.filtered.csv'])), header=True)
embeddings_test_filtered.to_csv(os.path.join(args.outdir, '.'.join([args.prefix, 'embeddings_test.filtered.csv'])), header=True)

+# for architectures with decoders; print decoded output layers
+if args.model_class == 'CrossModalPred':
+    output_layers_train = model.decode(train_dataset)
+    output_layers_test = model.decode(test_dataset)
+    for layer in output_layers_train.keys():
+        output_layers_train[layer].to_csv(os.path.join(args.outdir, '.'.join([args.prefix, 'train_decoded', layer, 'csv'])), header=True)
+    for layer in output_layers_test.keys():
+        output_layers_test[layer].to_csv(os.path.join(args.outdir, '.'.join([args.prefix, 'test_decoded', layer, 'csv'])), header=True)


# evaluate off-the-shelf methods on the main target variable
if args.evaluate_baseline_performance == 'True':
print("[INFO] Computing off-the-shelf method performance on first target variable:",model.target_variables[0])
flexynesis/models/crossmodal_pred.py (10 additions, 13 deletions)
@@ -121,7 +121,7 @@ def multi_encoder(self, x_list):
log_var = self.FC_log_var(torch.cat(log_vars, dim=1))
return mean, log_var

-def forward(self, x_list_input, x_list_output):
+def forward(self, x_list_input):
"""
Forward pass through the model.
@@ -142,7 +142,7 @@ def forward(self, x_list_input, x_list_output):
z = self.reparameterization(mean, log_var)

# decode the latent space to target output layer(s)
-x_hat_list = [self.decoders[i](z) for i in range(len(x_list_output))]
+x_hat_list = [self.decoders[i](z) for i in range(len(self.output_layers))]

#run the supervisor heads using the latent layer as input
outputs = {}
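
With this change, callers pass only the input modalities; the decoder loop now ranges over self.output_layers, so the model reconstructs every configured output layer on its own. A minimal call sketch, assuming a fitted CrossModalPred model and a dataset whose .dat dict is keyed by layer name (as in transform and predict below):

    x_list_input = [dataset.dat[x] for x in model.input_layers]
    x_hat_list, z, mean, log_var, outputs = model.forward(x_list_input)
    # x_hat_list[i] reconstructs the layer named model.output_layers[i]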
@@ -214,11 +214,10 @@ def training_step(self, train_batch, batch_idx):

# get input omics modalities and encode them; decode them to output layers
x_list_input = [dat[x] for x in self.input_layers]
-x_list_output = [dat[x] for x in self.output_layers]

-x_hat_list, z, mean, log_var, outputs = self.forward(x_list_input, x_list_output)
+x_hat_list, z, mean, log_var, outputs = self.forward(x_list_input)

# compute mmd loss for the latent space + reconstruction loss for each target/output layer
+x_list_output = [dat[x] for x in self.output_layers]
mmd_loss_list = [self.MMD_loss(z.shape[1], z, x_hat_list[i], x_list_output[i]) for i in range(len(self.output_layers))]
mmd_loss = torch.mean(torch.stack(mmd_loss_list))

@@ -248,9 +247,7 @@ def validation_step(self, val_batch, batch_idx):

# get input omics modalities and encode them
x_list_input = [dat[x] for x in self.input_layers]
-x_list_output = [dat[x] for x in self.output_layers]

-x_hat_list, z, mean, log_var, outputs = self.forward(x_list_input, x_list_output)
+x_hat_list, z, mean, log_var, outputs = self.forward(x_list_input)

# compute mmd loss for the latent space + reconstruction loss for each target/output layer
+x_list_output = [dat[x] for x in self.output_layers]
@@ -304,8 +301,8 @@ def transform(self, dataset):
pd.DataFrame: Transformed dataset as a pandas DataFrame.
"""
self.eval()
-x_list = [dataset.dat[x] for x in self.input_layers]
-M = self.forward(x_list)[1].detach().numpy()
+x_list_input = [dataset.dat[x] for x in self.input_layers]
+M = self.forward(x_list_input)[1].detach().numpy()
z = pd.DataFrame(M)
z.columns = [''.join(['E', str(x)]) for x in z.columns]
z.index = dataset.samples
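
A matching usage sketch for transform after this rename (dataset name illustrative):

    embeddings = model.transform(test_dataset)
    # DataFrame indexed by dataset.samples with columns E0, E1, ...

This is presumably the table behind the embeddings_train/test CSVs written in __main__.py.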
@@ -323,8 +320,8 @@ def predict(self, dataset):
"""
self.eval()

-x_list = [dataset.dat[x] for x in self.input_layers]
-X_hat, z, mean, log_var, outputs = self.forward(x_list)
+x_list_input = [dataset.dat[x] for x in self.input_layers]
+X_hat, z, mean, log_var, outputs = self.forward(x_list_input)

predictions = {}
for var in self.variables:
@@ -344,7 +341,7 @@ def decode(self, dataset):
self.eval()
x_list_input = [dataset.dat[x] for x in self.input_layers]
x_list_output = [dataset.dat[x] for x in self.output_layers]
-x_hat_list, z, mean, log_var, outputs = self.forward(x_list_input, x_list_output)
+x_hat_list, z, mean, log_var, outputs = self.forward(x_list_input)
X = {}
for i in range(len(self.output_layers)):
x = pd.DataFrame(x_hat_list[i].detach().numpy()).transpose()
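
decode thus returns a dict of one pandas DataFrame per output layer, keyed by layer name, which is what the new block in __main__.py iterates over. A minimal inspection sketch (dataset name illustrative):

    decoded = model.decode(test_dataset)
    for layer, df in decoded.items():
        print(layer, df.shape)  # one reconstructed data matrix per output layer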
