From 95f5b99985621d2c1f9f92d9bf845b8a75acf4c9 Mon Sep 17 00:00:00 2001 From: AdrianM0 Date: Thu, 11 Apr 2024 12:58:38 +0200 Subject: [PATCH] feat: code is running on hydra :tada: --- experiments/train.py | 3 +-- src/molbind/data/dataloaders.py | 30 +++++++++++++++--------------- 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/experiments/train.py b/experiments/train.py index 76795e0e..3c86cb17 100644 --- a/experiments/train.py +++ b/experiments/train.py @@ -6,7 +6,6 @@ from omegaconf import DictConfig import torch import rootutils -from hydra.utils import instantiate rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) @@ -87,7 +86,7 @@ def train_molbind(config: DictConfig): ) -@hydra.main(config_path="../configs", config_name="train.yaml") +@hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml") def main(config: DictConfig): train_molbind(config) diff --git a/src/molbind/data/dataloaders.py b/src/molbind/data/dataloaders.py index 21cfa575..5d9d7875 100644 --- a/src/molbind/data/dataloaders.py +++ b/src/molbind/data/dataloaders.py @@ -2,7 +2,7 @@ from lightning.pytorch.utilities.combined_loader import CombinedLoader from molbind.data.components.tokenizers import SMILES_TOKENIZER, SELFIES_TOKENIZER from networkx import Graph -from typing import Tuple +from typing import Tuple, Optional from torch import Tensor @@ -25,9 +25,9 @@ class StringDataset(Dataset): def __init__( self, dataset: Tuple[Tensor, Tensor], - central_modality: str, modality: str, - context_length=256, + central_modality: str = "smiles", + context_length: Optional[int] = 256, ): """Dataset for string modalities. @@ -38,13 +38,13 @@ def __init__( """ assert len(dataset) == 2 assert len(dataset[0]) == len(dataset[1]) - assert ( - MODALITY_DATA_TYPES[modality] == str - ), "This dataset supports string modalities only." self.modality = modality self.central_modality = central_modality + assert MODALITY_DATA_TYPES[modality] == str + assert MODALITY_DATA_TYPES[central_modality] == str + self.tokenized_central_modality = STRING_TOKENIZERS[central_modality]( dataset[0], padding="max_length", @@ -62,13 +62,13 @@ def __init__( ) def __len__(self): - return len(self.tokenized_smiles.input_ids) + return len(self.tokenized_central_modality.input_ids) def __getitem__(self, idx): return { - "smiles": ( - self.tokenized_smiles.input_ids[idx], - self.tokenized_smiles.attention_mask[idx], + self.central_modality: ( + self.tokenized_central_modality.input_ids[idx], + self.tokenized_central_modality.attention_mask[idx], ), self.modality: ( self.tokenized_string.input_ids[idx], @@ -82,22 +82,22 @@ class GraphDataset(Dataset): def load_combined_loader( - central_modality: str, data_modalities: dict, batch_size: int, shuffle: bool, num_workers: int, + central_modality: str = "smiles", drop_last: bool = True, ) -> CombinedLoader: """Combine multiple dataloaders for different modalities into a single dataloader. Args: - data_modalities (dict): data inputs for each modality as pairs of (SMILES, modality) + data_modalities (dict): data inputs for each modality as pairs of (central_modality, modality) batch_size (int): batch size for the dataloader shuffle (bool): shuffle the dataset num_workers (int): number of workers for the dataloader drop_last (bool, optional): whether to drop the last batch; defaults to True. - + central_modality (str, optional): central modality to use for the dataset; defaults to "smiles". Returns: CombinedLoader: a combined dataloader for all the modalities """ @@ -106,9 +106,9 @@ def load_combined_loader( for modality in [*data_modalities]: if MODALITY_DATA_TYPES[modality] == str: dataset_instance = StringDataset( - dataset=data_modalities, - central_modality=central_modality, + dataset=data_modalities[modality], modality=modality, + central_modality=central_modality, context_length=256, ) loaders[modality] = DataLoader(