Setup first model version for the SMILES-SELFIES modality pair #5

Merged: 21 commits, Apr 11, 2024
Changes from 5 commits
25 changes: 0 additions & 25 deletions configs/model/mnist.yaml

This file was deleted.

2 changes: 1 addition & 1 deletion configs/paths/default.yaml
@@ -15,4 +15,4 @@ log_dir: ${paths.root_dir}/logs/
output_dir: ${hydra:runtime.output_dir}

# path to working directory
work_dir: ${hydra:runtime.cwd}
work_dir: ${hydra:runtime.cwd}
2 changes: 1 addition & 1 deletion environment.yaml
@@ -5,7 +5,7 @@
# - conda allows for installing packages without requiring certain compilers or
# libraries to be available in the system, since it installs precompiled binaries

name: myenv
name: molbind

channels:
- pytorch
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -25,7 +25,7 @@ classifiers = [
"Operating System :: MacOS",
]
requires-python = ">=3.9, <3.13"
dependencies = ["numpy"]
dependencies = ["numpy", "transformers", "pandas", "polars"]

[project.optional-dependencies]
dev = ["codecov-cli>=0.4.1", "pytest>=7.4.0", "pytest-cov>=3.0.0", "ruff>=0.0.285"]
5 changes: 3 additions & 2 deletions requirements.txt
@@ -3,14 +3,15 @@ torch>=2.0.0
torchvision>=0.15.0
lightning>=2.0.0
torchmetrics>=0.11.4
info-nce-pytorch

# --------- hydra --------- #
hydra-core==1.3.2
hydra-colorlog==1.2.0
hydra-optuna-sweeper==1.2.0
Comment on lines 9 to 11
Collaborator:
why is hydra pinned and the torch stuff only has a lower bound 🤔 ?

Collaborator Author:
that's the default stuff

Collaborator:
you mean it was pinned this way by default? But what is the rationale behind doing it this way?

Collaborator:
and do we actually now use both requirements.txt and the toml file? This might easily become messy to have it in two places


# --------- loggers --------- #
# wandb
wandb
# neptune-client
# mlflow
# comet-ml
@@ -21,4 +22,4 @@ rootutils # standardizing the project root setup
pre-commit # hooks for applying linters on commit
rich # beautiful text formatting in terminal
pytest # tests
# sh # for running bash commands in some tests (linux/macos only)
# sh # for running bash commands in some tests (linux/macos only)
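
The new info-nce-pytorch entry above pulls in a contrastive InfoNCE loss for aligning modality embeddings. For orientation only, a minimal plain-PyTorch sketch of the objective such a loss computes (function name, temperature, and in-batch-negatives choice are my assumptions, not the package's exact API):

import torch
import torch.nn.functional as F


def info_nce(query: torch.Tensor, positive: torch.Tensor, temperature: float = 0.1) -> torch.Tensor:
    # InfoNCE with in-batch negatives: the i-th positive is the match for the i-th query
    query = F.normalize(query, dim=-1)
    positive = F.normalize(positive, dim=-1)
    logits = query @ positive.T / temperature          # (batch, batch) similarity matrix
    labels = torch.arange(query.size(0), device=query.device)
    return F.cross_entropy(logits, labels)             # diagonal entries are the positive pairs


# e.g. SMILES embeddings vs. SELFIES embeddings of the same molecules
loss = info_nce(torch.randn(4, 128), torch.randn(4, 128))
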
21 changes: 0 additions & 21 deletions setup.py

This file was deleted.

4 changes: 4 additions & 0 deletions src/molbind/data/components/tokenizers.py
@@ -0,0 +1,4 @@
from transformers import AutoTokenizer

SMILES_TOKENIZER = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
SELFIES_TOKENIZER = AutoTokenizer.from_pretrained("HUBioDataLab/SELFormer")
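
For orientation, a quick sketch of how these module-level tokenizers are used downstream (mirroring the padding/truncation arguments that appear in dataloaders.py; the output keys are the standard Hugging Face ones):

from molbind.data.components.tokenizers import SMILES_TOKENIZER

batch = SMILES_TOKENIZER(
    ["CCO", "c1ccccc1"],        # a couple of SMILES strings
    padding="max_length",
    truncation=True,
    max_length=256,
    return_tensors="pt",
)
print(batch["input_ids"].shape, batch["attention_mask"].shape)  # both (2, 256)
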
168 changes: 121 additions & 47 deletions src/molbind/data/dataloaders.py
@@ -1,5 +1,4 @@
from torch.utils.data import DataLoader
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset
from lightning.pytorch.utilities.combined_loader import CombinedLoader
from molbind.data.components.tokenizers import SMILES_TOKENIZER, SELFIES_TOKENIZER
import networkx as nx
@@ -8,56 +7,131 @@


MODALITY_DATA_TYPES = {
"smiles" : str,
"selfies" : str,
"graph" : Graph,
"nmr" : str,
"ir" : str
"smiles": str,
"selfies": str,
"graph": Graph,
"nmr": str,
"ir": str,
}

STRING_TOKENIZERS = {
"smiles": SMILES_TOKENIZER,
"selfies": SELFIES_TOKENIZER,
"iupac_name": "iupac_name_tokenizer",
}


class StringDataLoader(DataLoader):
def __init__(self, dataset, batch_size, shuffle, num_workers, modality="smiles"):
super(StringDataLoader, self).__init__(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
class StringDataset(Dataset):
def __init__(self, dataset, modality, context_length=256):
self.dataset = dataset
Collaborator:
could we add docstrings? To me it was not clear from the variable name dataset that this is supposed to be indexable (like a tuple)

self.modality = modality
Collaborator:
do we still need it? Otherwise, we can perhaps keep the object leaner by avoiding this attribute

self.tokenized_smiles = STRING_TOKENIZERS["smiles"](
dataset[0],
padding="max_length",
truncation=True,
return_tensors="pt",
max_length=context_length,
)
self.tokenized_string = STRING_TOKENIZERS[modality](
dataset[1],
padding="max_length",
truncation=True,
return_tensors="pt",
max_length=context_length,
)

Comment on lines +38 to +51
Collaborator:
if your data is large, you might need to revisit this and replace this with some tokenization on the fly or loading from pre-tokenized datasets.

it is okay for now, but I'd keep in mind that this might need to be refactored
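
For reference, an on-the-fly variant along the lines of this suggestion would move tokenization into __getitem__ instead of doing it once in __init__. A rough sketch, hypothetical and not part of this PR, reusing the module's STRING_TOKENIZERS:

class LazyStringDataset(Dataset):
    # Hypothetical variant: tokenize one (SMILES, modality string) pair at a time

    def __init__(self, dataset, modality, context_length=256):
        self.smiles, self.strings = dataset[0], dataset[1]
        self.modality = modality
        self.context_length = context_length

    def __len__(self):
        return len(self.smiles)

    def _tokenize(self, tokenizer, text):
        out = tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.context_length,
            return_tensors="pt",
        )
        return out.input_ids.squeeze(0), out.attention_mask.squeeze(0)

    def __getitem__(self, idx):
        return {
            "smiles": self._tokenize(STRING_TOKENIZERS["smiles"], self.smiles[idx]),
            self.modality: self._tokenize(STRING_TOKENIZERS[self.modality], self.strings[idx]),
        }
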

def __len__(self):
return len(self.dataset)

def __iter__(self):
for batch in super(StringDataLoader, self).__iter__():

if self.modality == "smiles":
tokenized_batch = SMILES_TOKENIZER(batch, padding="max_length", truncation=True, return_tensors="pt")
elif self.modality == "selfies":
tokenized_batch = SELFIES_TOKENIZER(batch, padding="max_length", truncation=True, return_tensors="pt")
yield tokenized_batch["input_ids"], tokenized_batch["attention_mask"]


def load_combined_loader(data_modalities : Dict, batch_size : int, shuffle : bool, num_workers : int) -> CombinedLoader:
return len(self.tokenized_smiles.input_ids)

def __getitem__(self, idx):
return {
"smiles": (
self.tokenized_smiles.input_ids[idx],
self.tokenized_smiles.attention_mask[idx],
),
self.modality: (
self.tokenized_string.input_ids[idx],
self.tokenized_string.attention_mask[idx],
),
}
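
For illustration, how the class above is meant to be constructed and indexed; the values are dummies, and the shapes assume the default context_length=256:

pairs = (["CCO", "CCN"], ["[C][C][O]", "[C][C][N]"])  # (smiles, selfies) columns
ds = StringDataset(pairs, modality="selfies")
item = ds[0]
# item["smiles"] and item["selfies"] are (input_ids, attention_mask) tuples of shape (256,)
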


class GraphDataset(Dataset):
def __init__(self, dataset, context_length=128):
self.dataset = dataset
self.graphs = dataset[1]
Collaborator:
similar comments as above

self.smiles = STRING_TOKENIZERS["smiles"](
Collaborator:
Those are assumed to be PyG objects?

dataset[0],
padding="max_length",
truncation=True,
return_tensors="pt",
max_length=context_length,
)

def __len__(self):
return len(self.graphs)

def __getitem__(self, idx):
return {
"smiles": (self.smiles.input_ids[idx], self.smiles.attention_mask[idx]),
"graph": self.graphs[idx],
}
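
If the graphs do end up being PyG objects, as the question above suggests, a single entry of dataset[1] might look like the sketch below (hypothetical; the PR's __main__ only passes dummy strings):

import torch
from torch_geometric.data import Data

# toy 3-node graph; edge_index stores 2 undirected edges as 4 directed ones
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
graph = Data(x=torch.randn(3, 16), edge_index=edge_index)

Note that batching PyG Data objects typically requires torch_geometric.loader.DataLoader rather than the plain torch DataLoader used below.
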


def load_combined_loader(
data_modalities: dict,
batch_size: int,
shuffle: bool,
num_workers: int,
drop_last: bool = True,
) -> CombinedLoader:
"""_summary_

Collaborator:
summary is missing ;)

Args:
data_modalities (dict): data inputs for each modality as pairs of (SMILES, modality)
batch_size (int): batch size for the dataloader
shuffle (bool): shuffle the dataset
num_workers (int): number of workers for the dataloader
drop_last (bool, optional): whether to drop the last batch; defaults to True.

Returns:
CombinedLoader: a combined dataloader for all the modalities
"""
loaders = {}

for modality in data_modalities.keys():
# import pdb; pdb.set_trace()

for modality in [*data_modalities]:
if MODALITY_DATA_TYPES[modality] == str:
loaders[modality] = StringDataLoader(data_modalities[modality], batch_size, shuffle, num_workers, modality)
dataset_instance = StringDataset(data_modalities[modality], modality)
loaders[modality] = DataLoader(
dataset_instance,
batch_size=batch_size,
shuffle=shuffle,
num_workers=num_workers,
drop_last=drop_last,
)
elif MODALITY_DATA_TYPES[modality] == Graph:
loaders[modality] = DataLoader(data_modalities[modality], batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
return CombinedLoader(loaders, "min_size")


smiles = ["CCO", "CCN", "CCO", "CCN"]
selfies = ["[C][C][O]", "[C][C][N]", "[C][C][O]", "[C][C][N]"]
dummy_graphs = ["dummy_graph", "dummy_graph", "dummy_graph", "dummy_graph"]

combined_loader = load_combined_loader(
data_modalities = {
"smiles" : smiles,
"selfies" : selfies,
"graph" : dummy_graphs
},
batch_size=2,
shuffle=True,
num_workers=1)

for batch, batch_idx, dataloader_idx in combined_loader:
print(f"{batch}, {batch_idx=}, {dataloader_idx=}")
graph_dataset_instance = GraphDataset(data_modalities[modality])
loaders[modality] = DataLoader(
graph_dataset_instance,
batch_size=batch_size,
shuffle=shuffle,
num_workers=num_workers,
drop_last=drop_last,
)
return CombinedLoader(loaders, mode="sequential")


if __name__ == "__main__":
smiles = ["CCO", "CCN", "CCON", "CCNO"]
selfies = ["[C][C][O]", "[C][C][N]", "[C][C][O][N]", "[C][C][N][O]"]
dummy_graphs = ["CCO_graph", "CCN_graph", "CCON_graph", "CCNO_graph"]

combined_loader = load_combined_loader(
data_modalities={"selfies": [smiles, selfies], "graph": [smiles, dummy_graphs]},
batch_size=2,
shuffle=False,
num_workers=1,
)

for batch, batch_idx, dataloader_idx in combined_loader:
print(f"{batch=}, {batch_idx=}, {dataloader_idx=}")
22 changes: 8 additions & 14 deletions src/molbind/models/components/base_encoder.py
@@ -1,22 +1,16 @@
from typing import Dict, Literal
from transformers import AutoModelForCausalLM
import torch.nn as nn


class BaseModalityEncoder(nn.Module):
def __init__(self,
projection_head_type : Literal["linear", "non-linear"] = "non-linear",
pretrained=True,
**kwargs):
self.pretrained = pretrained
self.encoder = self.build_encoder()
def __init__(self, freeze_encoder: bool = False, pretrained=True, **kwargs):
super(BaseModalityEncoder, self).__init__()
self.pretrained = pretrained
self.freeze_encoder = freeze_encoder


def build_encoder(self):
pass

def forward(self, x):
x = self.encoder(x)
x = self.projection_head(x)
# pooling
return x
return self.encoder(x)
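
For orientation, a hedged sketch of how a concrete encoder might subclass this base class; the class name, checkpoint choice, and mean-pooling are my assumptions, since only the base class is in this diff:

from transformers import AutoModel


class SmilesEncoder(BaseModalityEncoder):
    # Hypothetical SMILES encoder built on the ChemBERTa checkpoint used by the tokenizer

    def __init__(self, freeze_encoder: bool = False, pretrained: bool = True, **kwargs):
        super().__init__(freeze_encoder=freeze_encoder, pretrained=pretrained, **kwargs)
        self.encoder = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
        if self.freeze_encoder:
            for param in self.encoder.parameters():
                param.requires_grad = False

    def forward(self, input_ids, attention_mask):
        # mean-pool the last hidden state into a fixed-size embedding
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        return out.last_hidden_state.mean(dim=1)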