diff --git a/experiments/train.py b/experiments/train.py index f0250e1e..76795e0e 100644 --- a/experiments/train.py +++ b/experiments/train.py @@ -65,6 +65,7 @@ def train_molbind(config: DictConfig): ] combined_loader = load_combined_loader( + central_modality=config.data.central_modality, data_modalities=train_modality_data, batch_size=config.data.batch_size, shuffle=True, @@ -72,6 +73,7 @@ def train_molbind(config: DictConfig): ) valid_dataloader = load_combined_loader( + central_modality=config.data.central_modality, data_modalities=valid_modality_data, batch_size=config.data.batch_size, shuffle=False, @@ -87,9 +89,7 @@ def train_molbind(config: DictConfig): @hydra.main(config_path="../configs", config_name="train.yaml") def main(config: DictConfig): - # train_molbind(config) - import pdb; pdb.set_trace() - config = instantiate(config) + train_molbind(config) if __name__ == "__main__":