Skip to content

Commit

Permalink
feat: code is running on hydra 🎉
Browse files Browse the repository at this point in the history
  • Loading branch information
AdrianM0 committed Apr 11, 2024
1 parent a533a09 commit 95f5b99
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 17 deletions.
3 changes: 1 addition & 2 deletions experiments/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from omegaconf import DictConfig
import torch
import rootutils
from hydra.utils import instantiate

rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

Expand Down Expand Up @@ -87,7 +86,7 @@ def train_molbind(config: DictConfig):
)


@hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml")
def main(config: DictConfig):
    """Hydra CLI entry point: compose the config and launch training.

    Args:
        config (DictConfig): configuration composed by Hydra from
            ``../configs/train.yaml`` plus any command-line overrides.
    """
    # NOTE: the diff left both the old decorator (no version_base) and the
    # new one stacked; only the post-commit decorator is kept. Pinning
    # version_base="1.3" opts into Hydra 1.3 semantics and silences the
    # unset-version-base warning.
    # This wrapper only exists so Hydra owns config composition; all real
    # work happens in train_molbind.
    train_molbind(config)

Expand Down
30 changes: 15 additions & 15 deletions src/molbind/data/dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from lightning.pytorch.utilities.combined_loader import CombinedLoader
from molbind.data.components.tokenizers import SMILES_TOKENIZER, SELFIES_TOKENIZER
from networkx import Graph
from typing import Tuple
from typing import Tuple, Optional
from torch import Tensor


Expand All @@ -25,9 +25,9 @@ class StringDataset(Dataset):
def __init__(
self,
dataset: Tuple[Tensor, Tensor],
central_modality: str,
modality: str,
context_length=256,
central_modality: str = "smiles",
context_length: Optional[int] = 256,
):
"""Dataset for string modalities.
Expand All @@ -38,13 +38,13 @@ def __init__(
"""
assert len(dataset) == 2
assert len(dataset[0]) == len(dataset[1])
assert (
MODALITY_DATA_TYPES[modality] == str
), "This dataset supports string modalities only."

self.modality = modality
self.central_modality = central_modality

assert MODALITY_DATA_TYPES[modality] == str
assert MODALITY_DATA_TYPES[central_modality] == str

self.tokenized_central_modality = STRING_TOKENIZERS[central_modality](
dataset[0],
padding="max_length",
Expand All @@ -62,13 +62,13 @@ def __init__(
)

def __len__(self):
    """Return the number of paired examples in the dataset.

    One example per tokenized central-modality sequence; the modality
    side has the same length (enforced by the ``len(dataset[0]) ==
    len(dataset[1])`` assertion in ``__init__``).
    """
    # NOTE: the diff left both the removed line (which referenced a
    # nonexistent ``self.tokenized_smiles``) and the added one; only the
    # post-commit line is kept — the attribute actually created in
    # __init__ is ``tokenized_central_modality``.
    return len(self.tokenized_central_modality.input_ids)

def __getitem__(self, idx):
return {
"smiles": (
self.tokenized_smiles.input_ids[idx],
self.tokenized_smiles.attention_mask[idx],
self.central_modality: (
self.tokenized_central_modality.input_ids[idx],
self.tokenized_central_modality.attention_mask[idx],
),
self.modality: (
self.tokenized_string.input_ids[idx],
Expand All @@ -82,22 +82,22 @@ class GraphDataset(Dataset):


def load_combined_loader(
central_modality: str,
data_modalities: dict,
batch_size: int,
shuffle: bool,
num_workers: int,
central_modality: str = "smiles",
drop_last: bool = True,
) -> CombinedLoader:
"""Combine multiple dataloaders for different modalities into a single dataloader.
Args:
data_modalities (dict): data inputs for each modality as pairs of (SMILES, modality)
data_modalities (dict): data inputs for each modality as pairs of (central_modality, modality)
batch_size (int): batch size for the dataloader
shuffle (bool): shuffle the dataset
num_workers (int): number of workers for the dataloader
drop_last (bool, optional): whether to drop the last batch; defaults to True.
central_modality (str, optional): central modality to use for the dataset; defaults to "smiles".
Returns:
CombinedLoader: a combined dataloader for all the modalities
"""
Expand All @@ -106,9 +106,9 @@ def load_combined_loader(
for modality in [*data_modalities]:
if MODALITY_DATA_TYPES[modality] == str:
dataset_instance = StringDataset(
dataset=data_modalities,
central_modality=central_modality,
dataset=data_modalities[modality],
modality=modality,
central_modality=central_modality,
context_length=256,
)
loaders[modality] = DataLoader(
Expand Down

0 comments on commit 95f5b99

Please sign in to comment.