Skip to content

Commit

Permalink
account for dataset structure if required model is triplet networks
Browse files Browse the repository at this point in the history
  • Loading branch information
borauyar committed May 1, 2024
1 parent 33f28f5 commit 6b70d40
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion flexynesis/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,10 @@ def __init__(self, model, dataset, n_splits=5, batch_size=32, learning_rates=Non
{'encoders': False, 'supervisors': False}
]

if model.__class__.__name__ == 'MultiTripletNetwork':
# modify dataset structure to accommodate TripletNetworks
self.dataset = TripletMultiOmicDataset(dataset, model.main_var)

def apply_freeze_config(self, config):
# Freeze or unfreeze encoders
for encoder in self.model.encoders:
Expand Down Expand Up @@ -265,7 +269,7 @@ def training_step(self, batch, batch_idx):

def validation_step(self, batch, batch_idx):
# Call the model's validation step without logging
val_loss = self.model.validation_step(batch, batch_idx, log=False) # Assuming you can disable logging
val_loss = self.model.validation_step(batch, batch_idx, log=False)
self.log("val_loss", val_loss, on_epoch=True, prog_bar=True)
return val_loss

Expand Down

0 comments on commit 6b70d40

Please sign in to comment.