Merge pull request #5 from lamalab-org/selfies_smiles_example
SELFIES - SMILES multimodal example
AdrianM0 authored Apr 11, 2024
2 parents b46ae05 + 546001f commit 25b3bf5
Showing 25 changed files with 10,460 additions and 281 deletions.
6 changes: 0 additions & 6 deletions configs/data/mnist.yaml

This file was deleted.

10 changes: 10 additions & 0 deletions configs/data/molbind.yaml
@@ -0,0 +1,10 @@
_target_: molbind.data.dataloaders.load_combined_loader
central_modality: "smiles"
modalities:
- "selfies"
train_frac : 0.8
val_frac : 0.2
seed: 42
fraction_data: 1.0
dataset_path: "subset.csv"
batch_size: 64
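
Here `_target_` follows the usual Hydra convention: `hydra.utils.instantiate` imports the dotted path and calls it with the remaining keys as keyword arguments. A minimal sketch of that mechanism (assuming the target accepts these keys as-is):

from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/data/molbind.yaml")
# Roughly equivalent to:
# molbind.data.dataloaders.load_combined_loader(central_modality="smiles", ...)
loader = instantiate(cfg)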
13 changes: 2 additions & 11 deletions configs/logger/wandb.yaml
@@ -2,15 +2,6 @@

 wandb:
   _target_: lightning.pytorch.loggers.wandb.WandbLogger
-  # name: "" # name of the run (normally generated by wandb)
   save_dir: "${paths.output_dir}"
-  offline: False
-  id: null # pass correct id to resume experiment!
-  anonymous: null # enable anonymous logging
-  project: "lightning-hydra-template"
-  log_model: False # upload lightning ckpts
-  prefix: "" # a string to put at the beginning of metric keys
-  # entity: "" # set to name of your wandb team
-  group: ""
-  tags: []
-  job_type: ""
+  project: "molbind"
+  entity: "adrianmirza"
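
Instantiated, this config amounts to roughly the following (a sketch; `WandbLogger` forwards extra keyword arguments such as `entity` to `wandb.init`):

from lightning.pytorch.loggers import WandbLogger

logger = WandbLogger(
    save_dir="logs/",  # resolved from ${paths.output_dir} at runtime
    project="molbind",
    entity="adrianmirza",
)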
25 changes: 0 additions & 25 deletions configs/model/mnist.yaml

This file was deleted.

2 changes: 1 addition & 1 deletion configs/paths/default.yaml
@@ -15,4 +15,4 @@ log_dir: ${paths.root_dir}/logs/
 output_dir: ${hydra:runtime.output_dir}

 # path to working directory
-work_dir: ${hydra:runtime.cwd}
\ No newline at end of file
+work_dir: ${hydra:runtime.cwd}
19 changes: 10 additions & 9 deletions configs/train.yaml
@@ -4,14 +4,15 @@

 # order of defaults determines the order in which configs override each other
 defaults:
   - _self_
-  - data: mnist
-  - model: mnist
-  - callbacks: default
-  - logger: null # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
-  - trainer: default
-  - paths: default
-  - extras: default
-  - hydra: default
+  - data: molbind
+  - model: molbind
+  - logger: wandb
+  # - callbacks: default
+  # - logger: null # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
+  # - trainer: default
+  # - paths: default
+  # - extras: default
+  # - hydra: default

 # experiment configs allow for version control of specific hyperparameters
 # e.g. best hyperparameters for given model and datamodule
@@ -46,4 +47,4 @@ test: True
 ckpt_path: null

 # seed for random number generators in pytorch, numpy and python.random
-seed: null
+seed: 42
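
Because these defaults are ordinary Hydra config groups, any entry can still be overridden without editing the file. A minimal sketch using Hydra's Compose API (the relative `config_path` is an assumption):

from hydra import compose, initialize

with initialize(config_path="../configs", version_base=None):
    cfg = compose(config_name="train", overrides=["data=molbind", "seed=7"])
print(cfg.seed)  # 7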
2 changes: 1 addition & 1 deletion environment.yaml
@@ -5,7 +5,7 @@
 # - conda allows for installing packages without requiring certain compilers or
 #   libraries to be available in the system, since it installs precompiled binaries

-name: myenv
+name: molbind

 channels:
   - pytorch
10,000 changes: 10,000 additions & 0 deletions experiments/subset.csv

Large diffs are not rendered by default.

33 changes: 33 additions & 0 deletions experiments/train.py
@@ -0,0 +1,33 @@
from molbind.models.lightning_module import train_molbind
from omegaconf import DictConfig


if __name__ == "__main__":
    config = {
        "wandb": {"entity": "adrianmirza", "project_name": "embedbind"},
        "model": {
            "projection_heads": {
                "selfies": {"dims": [256, 128], "activation": "leakyrelu", "batch_norm": False},
                "smiles": {"dims": [256, 128], "activation": "leakyrelu", "batch_norm": False},
            },
            "encoders": {
                "smiles": {"pretrained": True, "freeze_encoder": False},
                "selfies": {"pretrained": True, "freeze_encoder": False},
            },
            "optimizer": {"lr": 1e-4, "weight_decay": 1e-4},
        },
        "loss": {"temperature": 0.1},
        "data": {
            "central_modality": "smiles",
            "modalities": ["selfies"],
            "dataset_path": "subset.csv",
            "train_frac": 0.8,
            "valid_frac": 0.2,
            "seed": 42,
            "fraction_data": 1,
            "batch_size": 64,
        },
    }

    config = DictConfig(config)
    train_molbind(config)
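
Wrapping the plain dict in `omegaconf.DictConfig` gives `train_molbind` attribute-style access to nested keys:

from omegaconf import DictConfig

cfg = DictConfig({"model": {"optimizer": {"lr": 1e-4}}})
assert cfg.model.optimizer.lr == 1e-4  # same as cfg["model"]["optimizer"]["lr"]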
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -25,7 +25,7 @@ classifiers = [
     "Operating System :: MacOS",
 ]
 requires-python = ">=3.9, <3.13"
-dependencies = ["numpy"]
+dependencies = ["numpy", "transformers", "pandas", "polars"]

 [project.optional-dependencies]
 dev = ["codecov-cli>=0.4.1", "pytest>=7.4.0", "pytest-cov>=3.0.0", "ruff>=0.0.285"]
5 changes: 3 additions & 2 deletions requirements.txt
@@ -3,14 +3,15 @@ torch>=2.0.0
 torchvision>=0.15.0
 lightning>=2.0.0
 torchmetrics>=0.11.4
+info-nce-pytorch

 # --------- hydra --------- #
 hydra-core==1.3.2
 hydra-colorlog==1.2.0
 hydra-optuna-sweeper==1.2.0

 # --------- loggers --------- #
-# wandb
+wandb
 # neptune-client
 # mlflow
 # comet-ml

@@ -21,4 +22,4 @@ rootutils # standardizing the project root setup
 pre-commit # hooks for applying linters on commit
 rich # beautiful text formatting in terminal
 pytest # tests
-# sh # for running bash commands in some tests (linux/macos only)
\ No newline at end of file
+# sh # for running bash commands in some tests (linux/macos only)
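
The newly added `info-nce-pytorch` package supplies the contrastive objective; a minimal sketch of its API, with `temperature` mirroring the `loss.temperature` value in the training config (batch and embedding sizes are made up):

import torch
from info_nce import InfoNCE  # from the info-nce-pytorch package

loss_fn = InfoNCE(temperature=0.1)
query = torch.randn(64, 128)         # e.g. SMILES projections
positive_key = torch.randn(64, 128)  # paired SELFIES projections
loss = loss_fn(query, positive_key)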
21 changes: 0 additions & 21 deletions setup.py

This file was deleted.

4 changes: 4 additions & 0 deletions src/molbind/data/components/tokenizers.py
@@ -0,0 +1,4 @@
from transformers import AutoTokenizer

SMILES_TOKENIZER = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
SELFIES_TOKENIZER = AutoTokenizer.from_pretrained("HUBioDataLab/SELFormer")
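
Both are standard Hugging Face tokenizers, so they can be called directly on a batch of strings; for example:

from molbind.data.components.tokenizers import SMILES_TOKENIZER

batch = SMILES_TOKENIZER(
    ["CCO", "c1ccccc1"],  # ethanol and benzene as SMILES
    padding="max_length",
    truncation=True,
    max_length=256,
    return_tensors="pt",
)
print(batch["input_ids"].shape)  # torch.Size([2, 256])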
148 changes: 99 additions & 49 deletions src/molbind/data/dataloaders.py
@@ -1,63 +1,113 @@
-from torch.utils.data import DataLoader
-from pytorch_lightning import LightningDataModule
+from torch.utils.data import DataLoader, Dataset
 from lightning.pytorch.utilities.combined_loader import CombinedLoader
 from molbind.data.components.tokenizers import SMILES_TOKENIZER, SELFIES_TOKENIZER
-import networkx as nx
 from networkx import Graph
-from typing import List, Dict
+from typing import Tuple
+from torch import Tensor


 MODALITY_DATA_TYPES = {
-    "smiles" : str,
-    "selfies" : str,
-    "graph" : Graph,
-    "nmr" : str,
-    "ir" : str
+    "smiles": str,
+    "selfies": str,
+    "graph": Graph,
+    "nmr": str,
+    "ir": str,
 }

+STRING_TOKENIZERS = {
+    "smiles": SMILES_TOKENIZER,
+    "selfies": SELFIES_TOKENIZER,
+    "iupac_name": "iupac_name_tokenizer",
+}

-class StringDataLoader(DataLoader):
-    def __init__(self, dataset, batch_size, shuffle, num_workers, modality="smiles"):
-        super(StringDataLoader, self).__init__(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
-        self.modality = modality
-
-    def __len__(self):
-        return len(self.dataset)
-
-    def __iter__(self):
-        for batch in super(StringDataLoader, self).__iter__():
-            if self.modality == "smiles":
-                tokenized_batch = SMILES_TOKENIZER(batch, padding="max_length", truncation=True, return_tensors="pt")
-            elif self.modality == "selfies":
-                tokenized_batch = SELFIES_TOKENIZER(batch, padding="max_length", truncation=True, return_tensors="pt")
-            yield tokenized_batch["input_ids"], tokenized_batch["attention_mask"]
+class StringDataset(Dataset):
+    def __init__(
+        self, dataset: Tuple[Tensor, Tensor], modality: str, context_length=256
+    ):
+        """Tokenize a pair of string collections: SMILES and one other modality.
+
+        Args:
+            dataset (Tuple[Tensor, Tensor]): pair of SMILES and data for the modality
+                (SMILES always at index 0, the modality at index 1)
+            modality (str): name of the data modality as found in MODALITY_DATA_TYPES
+            context_length (int, optional): maximum tokenized length. Defaults to 256.
+        """
+        assert len(dataset) == 2
+        assert len(dataset[0]) == len(dataset[1])
+        self.modality = modality
+        self.tokenized_smiles = STRING_TOKENIZERS["smiles"](
+            dataset[0],
+            padding="max_length",
+            truncation=True,
+            return_tensors="pt",
+            max_length=context_length,
+        )
+        self.tokenized_string = STRING_TOKENIZERS[modality](
+            dataset[1],
+            padding="max_length",
+            truncation=True,
+            return_tensors="pt",
+            max_length=context_length,
+        )
+
+    def __len__(self):
+        return len(self.tokenized_smiles.input_ids)
+
+    def __getitem__(self, idx):
+        return {
+            "smiles": (
+                self.tokenized_smiles.input_ids[idx],
+                self.tokenized_smiles.attention_mask[idx],
+            ),
+            self.modality: (
+                self.tokenized_string.input_ids[idx],
+                self.tokenized_string.attention_mask[idx],
+            ),
+        }
+
+
+class GraphDataset(Dataset):
+    pass


-def load_combined_loader(data_modalities : Dict, batch_size : int, shuffle : bool, num_workers : int) -> CombinedLoader:
+def load_combined_loader(
+    data_modalities: dict,
+    batch_size: int,
+    shuffle: bool,
+    num_workers: int,
+    drop_last: bool = True,
+) -> CombinedLoader:
     """Combine multiple dataloaders for different modalities into a single dataloader.
+
     Args:
         data_modalities (dict): data inputs for each modality as pairs of (SMILES, modality)
         batch_size (int): batch size for the dataloader
         shuffle (bool): shuffle the dataset
         num_workers (int): number of workers for the dataloader
+        drop_last (bool, optional): whether to drop the last batch; defaults to True.
+
     Returns:
         CombinedLoader: a combined dataloader for all the modalities
     """
     loaders = {}

-    for modality in data_modalities.keys():
+    for modality in [*data_modalities]:
         if MODALITY_DATA_TYPES[modality] == str:
-            loaders[modality] = StringDataLoader(data_modalities[modality], batch_size, shuffle, num_workers, modality)
+            dataset_instance = StringDataset(data_modalities[modality], modality)
+            loaders[modality] = DataLoader(
+                dataset_instance,
+                batch_size=batch_size,
+                shuffle=shuffle,
+                num_workers=num_workers,
+                drop_last=drop_last,
+            )
         elif MODALITY_DATA_TYPES[modality] == Graph:
-            loaders[modality] = DataLoader(data_modalities[modality], batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
-    return CombinedLoader(loaders, "min_size")
-
-
-smiles = ["CCO", "CCN", "CCO", "CCN"]
-selfies = ["[C][C][O]", "[C][C][N]", "[C][C][O]", "[C][C][N]"]
-dummy_graphs = ["dummy_graph", "dummy_graph", "dummy_graph", "dummy_graph"]
-
-combined_loader = load_combined_loader(
-    data_modalities = {
-        "smiles" : smiles,
-        "selfies" : selfies,
-        "graph" : dummy_graphs
-    },
-    batch_size=2,
-    shuffle=True,
-    num_workers=1)
-
-for batch, batch_idx, dataloader_idx in combined_loader:
-    print(f"{batch}, {batch_idx=}, {dataloader_idx=}")
+            graph_dataset_instance = GraphDataset(data_modalities[modality])
+            loaders[modality] = DataLoader(
+                graph_dataset_instance,
+                batch_size=batch_size,
+                shuffle=shuffle,
+                num_workers=num_workers,
+                drop_last=drop_last,
+            )
+    return CombinedLoader(loaders, mode="sequential")
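
A usage sketch of the new interface, adapted from the module-level demo this commit removes (the exact collated batch structure is an assumption):

from molbind.data.dataloaders import load_combined_loader

smiles = ["CCO", "CCN", "CCO", "CCN"]
selfies = ["[C][C][O]", "[C][C][N]", "[C][C][O]", "[C][C][N]"]

combined_loader = load_combined_loader(
    data_modalities={"selfies": (smiles, selfies)},  # (SMILES, modality) pair
    batch_size=2,
    shuffle=True,
    num_workers=1,
)

# In "sequential" mode, CombinedLoader yields (batch, batch_idx, dataloader_idx);
# each batch holds tokenized (input_ids, attention_mask) pairs per modality.
for batch, batch_idx, dataloader_idx in combined_loader:
    print(f"{batch_idx=}, {dataloader_idx=}")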