Setup first model version for the SMILES-SELFIES modality pair #5

Merged
21 commits merged on Apr 11, 2024
Changes from all commits
6 changes: 0 additions & 6 deletions configs/data/mnist.yaml

This file was deleted.

10 changes: 10 additions & 0 deletions configs/data/molbind.yaml
@@ -0,0 +1,10 @@
_target_: molbind.data.dataloaders.load_combined_loader
central_modality: "smiles"
modalities:
- "selfies"
train_frac : 0.8
val_frac : 0.2
seed: 42
fraction_data: 1.0
dataset_path: "subset.csv"
batch_size: 64
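Note (not part of the diff): this file is composed by Hydra through the new `data: molbind` default in configs/train.yaml. A minimal sketch of loading just this config with OmegaConf, assuming it lives at configs/data/molbind.yaml relative to the repo root:

from omegaconf import OmegaConf

# Load the data config on its own (path assumed relative to the repository root).
data_cfg = OmegaConf.load("configs/data/molbind.yaml")

print(data_cfg.central_modality)   # smiles
print(list(data_cfg.modalities))   # ['selfies']
print(data_cfg.batch_size)         # 64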
13 changes: 2 additions & 11 deletions configs/logger/wandb.yaml
@@ -2,15 +2,6 @@

wandb:
_target_: lightning.pytorch.loggers.wandb.WandbLogger
# name: "" # name of the run (normally generated by wandb)
save_dir: "${paths.output_dir}"
offline: False
id: null # pass correct id to resume experiment!
anonymous: null # enable anonymous logging
project: "lightning-hydra-template"
log_model: False # upload lightning ckpts
prefix: "" # a string to put at the beginning of metric keys
# entity: "" # set to name of your wandb team
group: ""
tags: []
job_type: ""
project: "molbind"
entity: "adrianmirza"
25 changes: 0 additions & 25 deletions configs/model/mnist.yaml

This file was deleted.

2 changes: 1 addition & 1 deletion configs/paths/default.yaml
@@ -15,4 +15,4 @@ log_dir: ${paths.root_dir}/logs/
output_dir: ${hydra:runtime.output_dir}

# path to working directory
work_dir: ${hydra:runtime.cwd}
work_dir: ${hydra:runtime.cwd}
19 changes: 10 additions & 9 deletions configs/train.yaml
@@ -4,14 +4,15 @@
# order of defaults determines the order in which configs override each other
defaults:
- _self_
- data: mnist
- model: mnist
- callbacks: default
- logger: null # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
- trainer: default
- paths: default
- extras: default
- hydra: default
- data: molbind
- model: molbind
- logger: wandb
# - callbacks: default
# - logger: null # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
# - trainer: default
# - paths: default
# - extras: default
# - hydra: default

# experiment configs allow for version control of specific hyperparameters
# e.g. best hyperparameters for given model and datamodule
@@ -46,4 +47,4 @@ test: True
ckpt_path: null

# seed for random number generators in pytorch, numpy and python.random
seed: null
seed: 42
2 changes: 1 addition & 1 deletion environment.yaml
@@ -5,7 +5,7 @@
# - conda allows for installing packages without requiring certain compilers or
# libraries to be available in the system, since it installs precompiled binaries

name: myenv
name: molbind

channels:
- pytorch
10,000 changes: 10,000 additions & 0 deletions experiments/subset.csv

Large diffs are not rendered by default.

33 changes: 33 additions & 0 deletions experiments/train.py
@@ -0,0 +1,33 @@
from molbind.models.lightning_module import train_molbind
from omegaconf import DictConfig


if __name__ == "__main__":
    config = {
        "wandb": {"entity": "adrianmirza", "project_name": "embedbind"},
        "model": {
            "projection_heads": {
                "selfies": {"dims": [256, 128], "activation": "leakyrelu", "batch_norm": False},
                "smiles": {"dims": [256, 128], "activation": "leakyrelu", "batch_norm": False},
            },
            "encoders": {
                "smiles": {"pretrained": True, "freeze_encoder": False},
                "selfies": {"pretrained": True, "freeze_encoder": False},
            },
            "optimizer": {"lr": 1e-4, "weight_decay": 1e-4},
        },
        "loss": {"temperature": 0.1},
        "data": {
            "central_modality": "smiles",
            "modalities": ["selfies"],
            "dataset_path": "subset.csv",
            "train_frac": 0.8,
            "valid_frac": 0.2,
            "seed": 42,
            "fraction_data": 1,
            "batch_size": 64,
        },
    }

    config = DictConfig(config)
    train_molbind(config)
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -25,7 +25,7 @@ classifiers = [
"Operating System :: MacOS",
]
requires-python = ">=3.9, <3.13"
dependencies = ["numpy"]
dependencies = ["numpy", "transformers", "pandas", "polars"]

[project.optional-dependencies]
dev = ["codecov-cli>=0.4.1", "pytest>=7.4.0", "pytest-cov>=3.0.0", "ruff>=0.0.285"]
5 changes: 3 additions & 2 deletions requirements.txt
@@ -3,14 +3,15 @@ torch>=2.0.0
torchvision>=0.15.0
lightning>=2.0.0
torchmetrics>=0.11.4
info-nce-pytorch

# --------- hydra --------- #
hydra-core==1.3.2
hydra-colorlog==1.2.0
hydra-optuna-sweeper==1.2.0
Comment on lines 9 to 11

Collaborator:
why is hydra pinned and the torch stuff only has a lower bound 🤔?

Collaborator (Author):
that's the default stuff

Collaborator:
you mean it was pinned this way by default? But what is the rationale behind doing it this way?

Collaborator:
and do we actually now use both requirements.txt and the toml file? This might easily become messy to have it in two places


# --------- loggers --------- #
# wandb
wandb
# neptune-client
# mlflow
# comet-ml
@@ -21,4 +22,4 @@ rootutils # standardizing the project root setup
pre-commit # hooks for applying linters on commit
rich # beautiful text formatting in terminal
pytest # tests
# sh # for running bash commands in some tests (linux/macos only)
# sh # for running bash commands in some tests (linux/macos only)
21 changes: 0 additions & 21 deletions setup.py

This file was deleted.

4 changes: 4 additions & 0 deletions src/molbind/data/components/tokenizers.py
@@ -0,0 +1,4 @@
from transformers import AutoTokenizer

SMILES_TOKENIZER = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
SELFIES_TOKENIZER = AutoTokenizer.from_pretrained("HUBioDataLab/SELFormer")
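For orientation (not part of the change): these are plain Hugging Face tokenizers, so they can be called the way the new dataloaders call them; a quick sanity check might look like:

from molbind.data.components.tokenizers import SMILES_TOKENIZER, SELFIES_TOKENIZER

# Tokenize one SMILES and one SELFIES string into fixed-length tensors,
# mirroring the call made in src/molbind/data/dataloaders.py.
smiles_tokens = SMILES_TOKENIZER(
    "CCO", padding="max_length", truncation=True, return_tensors="pt", max_length=256
)
selfies_tokens = SELFIES_TOKENIZER(
    "[C][C][O]", padding="max_length", truncation=True, return_tensors="pt", max_length=256
)
print(smiles_tokens.input_ids.shape)   # torch.Size([1, 256])
print(selfies_tokens.input_ids.shape)  # torch.Size([1, 256])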
148 changes: 99 additions & 49 deletions src/molbind/data/dataloaders.py
@@ -1,63 +1,113 @@
from torch.utils.data import DataLoader
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset
from lightning.pytorch.utilities.combined_loader import CombinedLoader
from molbind.data.components.tokenizers import SMILES_TOKENIZER, SELFIES_TOKENIZER
import networkx as nx
from networkx import Graph
from typing import List, Dict
from typing import Tuple
from torch import Tensor


MODALITY_DATA_TYPES = {
"smiles" : str,
"selfies" : str,
"graph" : Graph,
"nmr" : str,
"ir" : str
"smiles": str,
"selfies": str,
"graph": Graph,
"nmr": str,
"ir": str,
}

STRING_TOKENIZERS = {
"smiles": SMILES_TOKENIZER,
"selfies": SELFIES_TOKENIZER,
"iupac_name": "iupac_name_tokenizer",
}


class StringDataset(Dataset):
def __init__(
self, dataset: Tuple[Tensor, Tensor], modality: str, context_length=256
):
"""_summary_

Collaborator:
_summary_ is strange :)

class StringDataLoader(DataLoader):
def __init__(self, dataset, batch_size, shuffle, num_workers, modality="smiles"):
super(StringDataLoader, self).__init__(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
Args:
dataset (Tuple[Tensor, Tensor]): pair of SMILES and data for the modality (smiles always index 0, modality index 1)
modality (str): name of data modality as found in MODALITY_DATA_TYPES
context_length (int, optional): _description_. Defaults to 256.
"""
assert len(dataset) == 2
assert len(dataset[0]) == len(dataset[1])
self.modality = modality
Comment on lines +35 to +36

Collaborator:
👍🏽

self.tokenized_smiles = STRING_TOKENIZERS["smiles"](
dataset[0],
padding="max_length",
truncation=True,
return_tensors="pt",
max_length=context_length,
)
self.tokenized_string = STRING_TOKENIZERS[modality](
dataset[1],
padding="max_length",
truncation=True,
return_tensors="pt",
max_length=context_length,
)

Comment on lines +38 to +51

Collaborator:
if your data is large, you might need to revisit this and replace this with some tokenization on the fly or loading from pre-tokenized datasets.

it is okay for now, but I'd keep in mind that this might need to be refactored
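To illustrate the suggestion above, a minimal sketch of on-the-fly tokenization (not part of this PR; the class name LazyStringDataset is hypothetical, STRING_TOKENIZERS is the dict defined in this file, and the (SMILES, modality) pair layout matches StringDataset):

from torch.utils.data import Dataset

from molbind.data.dataloaders import STRING_TOKENIZERS


class LazyStringDataset(Dataset):
    """Hypothetical variant of StringDataset that tokenizes items on the fly."""

    def __init__(self, dataset, modality: str, context_length: int = 256):
        # dataset is a (SMILES list, modality list) pair, as in StringDataset.
        self.smiles, self.other = dataset
        self.modality = modality
        self.context_length = context_length

    def __len__(self):
        return len(self.smiles)

    def _tokenize(self, tokenizer, text: str):
        return tokenizer(
            text,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
            max_length=self.context_length,
        )

    def __getitem__(self, idx):
        smiles_tokens = self._tokenize(STRING_TOKENIZERS["smiles"], self.smiles[idx])
        other_tokens = self._tokenize(STRING_TOKENIZERS[self.modality], self.other[idx])
        # Drop the leading batch dimension added by return_tensors="pt".
        return {
            "smiles": (smiles_tokens.input_ids[0], smiles_tokens.attention_mask[0]),
            self.modality: (other_tokens.input_ids[0], other_tokens.attention_mask[0]),
        }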

def __len__(self):
return len(self.dataset)

def __iter__(self):
for batch in super(StringDataLoader, self).__iter__():

if self.modality == "smiles":
tokenized_batch = SMILES_TOKENIZER(batch, padding="max_length", truncation=True, return_tensors="pt")
elif self.modality == "selfies":
tokenized_batch = SELFIES_TOKENIZER(batch, padding="max_length", truncation=True, return_tensors="pt")
yield tokenized_batch["input_ids"], tokenized_batch["attention_mask"]


def load_combined_loader(data_modalities : Dict, batch_size : int, shuffle : bool, num_workers : int) -> CombinedLoader:
return len(self.tokenized_smiles.input_ids)

def __getitem__(self, idx):
return {
"smiles": (
self.tokenized_smiles.input_ids[idx],
self.tokenized_smiles.attention_mask[idx],
),
self.modality: (
self.tokenized_string.input_ids[idx],
self.tokenized_string.attention_mask[idx],
),
}


class GraphDataset(Dataset):
pass


def load_combined_loader(
data_modalities: dict,
batch_size: int,
shuffle: bool,
num_workers: int,
drop_last: bool = True,
) -> CombinedLoader:
"""Combine multiple dataloaders for different modalities into a single dataloader.

Args:
data_modalities (dict): data inputs for each modality as pairs of (SMILES, modality)
batch_size (int): batch size for the dataloader
shuffle (bool): shuffle the dataset
num_workers (int): number of workers for the dataloader
drop_last (bool, optional): whether to drop the last batch; defaults to True.

Returns:
CombinedLoader: a combined dataloader for all the modalities
"""
loaders = {}

for modality in data_modalities.keys():
# import pdb; pdb.set_trace()

for modality in [*data_modalities]:
if MODALITY_DATA_TYPES[modality] == str:
loaders[modality] = StringDataLoader(data_modalities[modality], batch_size, shuffle, num_workers, modality)
dataset_instance = StringDataset(data_modalities[modality], modality)
loaders[modality] = DataLoader(
dataset_instance,
batch_size=batch_size,
shuffle=shuffle,
num_workers=num_workers,
drop_last=drop_last,
)
elif MODALITY_DATA_TYPES[modality] == Graph:
loaders[modality] = DataLoader(data_modalities[modality], batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
return CombinedLoader(loaders, "min_size")


smiles = ["CCO", "CCN", "CCO", "CCN"]
selfies = ["[C][C][O]", "[C][C][N]", "[C][C][O]", "[C][C][N]"]
dummy_graphs = ["dummy_graph", "dummy_graph", "dummy_graph", "dummy_graph"]

combined_loader = load_combined_loader(
data_modalities = {
"smiles" : smiles,
"selfies" : selfies,
"graph" : dummy_graphs
},
batch_size=2,
shuffle=True,
num_workers=1)

for batch, batch_idx, dataloader_idx in combined_loader:
print(f"{batch}, {batch_idx=}, {dataloader_idx=}")
graph_dataset_instance = GraphDataset(data_modalities[modality])
loaders[modality] = DataLoader(
graph_dataset_instance,
batch_size=batch_size,
shuffle=shuffle,
num_workers=num_workers,
drop_last=drop_last,
)
return CombinedLoader(loaders, mode="sequential")
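For context (not part of the diff), the inline example that was removed above maps onto the new API roughly like this; each string modality now takes a (SMILES, modality) pair, and a CombinedLoader in "sequential" mode yields (batch, batch_idx, dataloader_idx) tuples:

from molbind.data.dataloaders import load_combined_loader

smiles = ["CCO", "CCN", "CCO", "CCN"]
selfies = ["[C][C][O]", "[C][C][N]", "[C][C][O]", "[C][C][N]"]

# Each non-central modality maps to a (central modality, modality) pair of string lists.
combined_loader = load_combined_loader(
    data_modalities={"selfies": (smiles, selfies)},
    batch_size=2,
    shuffle=True,
    num_workers=1,
)

for batch, batch_idx, dataloader_idx in combined_loader:
    # batch is a dict of tokenized (input_ids, attention_mask) pairs per modality.
    print(batch_idx, dataloader_idx, batch["smiles"][0].shape)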