Setup first model version for the SMILES-SELFIES modality pair #5
Changes from all commits
61c923b
babda3a
3a177fb
2f990c7
40decc4
1d400cf
c281245
27c1e79
ffa1f2c
119d247
2663f52
e49d50f
4d6dd3d
2b18858
d5e2846
ac74fce
e920076
575ee50
fd1ed99
1d1cc83
546001f
@@ -0,0 +1,10 @@
_target_: molbind.data.dataloaders.load_combined_loader
central_modality: "smiles"
modalities:
  - "selfies"
train_frac: 0.8
val_frac: 0.2
seed: 42
fraction_data: 1.0
dataset_path: "subset.csv"
batch_size: 64
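The _target_ key follows the Hydra instantiation convention (hydra.utils.instantiate resolves it to molbind.data.dataloaders.load_combined_loader). A minimal sketch of reading this config with OmegaConf, assuming it is saved as data.yaml (the filename is hypothetical):

from omegaconf import OmegaConf

# Load the YAML shown above; attribute access mirrors the keys.
cfg = OmegaConf.load("data.yaml")
print(cfg.central_modality)  # "smiles"
print(list(cfg.modalities))  # ["selfies"]
print(cfg.batch_size)        # 64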
@@ -0,0 +1,33 @@
from molbind.models.lightning_module import train_molbind
from omegaconf import DictConfig


if __name__ == "__main__":
    config = {
        "wandb": {"entity": "adrianmirza", "project_name": "embedbind"},
        "model": {
            "projection_heads": {
                "selfies": {"dims": [256, 128], "activation": "leakyrelu", "batch_norm": False},
                "smiles": {"dims": [256, 128], "activation": "leakyrelu", "batch_norm": False},
            },
            "encoders": {
                "smiles": {"pretrained": True, "freeze_encoder": False},
                "selfies": {"pretrained": True, "freeze_encoder": False},
            },
            "optimizer": {"lr": 1e-4, "weight_decay": 1e-4},
        },
        "loss": {"temperature": 0.1},
        "data": {
            "central_modality": "smiles",
            "modalities": ["selfies"],
            "dataset_path": "subset.csv",
            "train_frac": 0.8,
            "valid_frac": 0.2,
            "seed": 42,
            "fraction_data": 1,
            "batch_size": 64,
        },
    }

    config = DictConfig(config)
    train_molbind(config)
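The loss block only carries a temperature; the loss implementation itself is not part of this diff. As a hedged sketch, a temperature of 0.1 in this kind of setup typically enters a CLIP-style symmetric InfoNCE loss over the projected embeddings, roughly like the following (function name and shapes are illustrative, not MolBind's API):

import torch
import torch.nn.functional as F

def info_nce(z_central: torch.Tensor, z_modality: torch.Tensor, temperature: float = 0.1) -> torch.Tensor:
    # Normalize projected embeddings so the dot product is a cosine similarity.
    z_central = F.normalize(z_central, dim=-1)
    z_modality = F.normalize(z_modality, dim=-1)
    # (batch, batch) similarity matrix, scaled by the temperature.
    logits = z_central @ z_modality.T / temperature
    # Matching pairs sit on the diagonal.
    labels = torch.arange(z_central.shape[0], device=logits.device)
    # Symmetric cross-entropy over rows and columns.
    return (F.cross_entropy(logits, labels) + F.cross_entropy(logits.T, labels)) / 2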
@@ -0,0 +1,4 @@
from transformers import AutoTokenizer

SMILES_TOKENIZER = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
SELFIES_TOKENIZER = AutoTokenizer.from_pretrained("HUBioDataLab/SELFormer")
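Both are standard Hugging Face tokenizers, so they can be exercised directly. A small usage sketch (the example SMILES strings are just illustrative):

# Tokenize a small batch of SMILES the same way the dataloaders below do.
batch = SMILES_TOKENIZER(
    ["CCO", "CCN"],
    padding="max_length",
    truncation=True,
    max_length=256,
    return_tensors="pt",
)
print(batch["input_ids"].shape)       # torch.Size([2, 256])
print(batch["attention_mask"].shape)  # torch.Size([2, 256])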
@@ -1,63 +1,113 @@
This diff replaces the previous StringDataLoader subclass (which re-tokenized every batch in __iter__) with a pre-tokenizing StringDataset wrapped in plain DataLoaders, adds a drop_last flag, switches the CombinedLoader mode from "min_size" to "sequential", and drops the module-level demo that used to sit at the bottom of the file. The resulting module:
from torch.utils.data import DataLoader, Dataset
from lightning.pytorch.utilities.combined_loader import CombinedLoader
from molbind.data.components.tokenizers import SMILES_TOKENIZER, SELFIES_TOKENIZER
from networkx import Graph
from typing import Tuple
from torch import Tensor


MODALITY_DATA_TYPES = {
    "smiles": str,
    "selfies": str,
    "graph": Graph,
    "nmr": str,
    "ir": str,
}

STRING_TOKENIZERS = {
    "smiles": SMILES_TOKENIZER,
    "selfies": SELFIES_TOKENIZER,
    "iupac_name": "iupac_name_tokenizer",
}
class StringDataset(Dataset):
    def __init__(
        self, dataset: Tuple[Tensor, Tensor], modality: str, context_length=256
    ):
        """Dataset pairing the central SMILES strings with one other string modality.

        Args:
            dataset (Tuple[Tensor, Tensor]): pair of SMILES and data for the modality (smiles always index 0, modality index 1)
            modality (str): name of data modality as found in MODALITY_DATA_TYPES
            context_length (int, optional): maximum tokenized sequence length. Defaults to 256.
        """
        assert len(dataset) == 2
        assert len(dataset[0]) == len(dataset[1])
        self.modality = modality

Comment on lines +35 to +36: 👍🏽
        self.tokenized_smiles = STRING_TOKENIZERS["smiles"](
            dataset[0],
            padding="max_length",
            truncation=True,
            return_tensors="pt",
            max_length=context_length,
        )
        self.tokenized_string = STRING_TOKENIZERS[modality](
            dataset[1],
            padding="max_length",
            truncation=True,
            return_tensors="pt",
            max_length=context_length,
        )

Comment on lines +38 to +51: if your data is large, you might need to revisit this and replace it with tokenization on the fly or loading from pre-tokenized datasets. It is okay for now, but I'd keep in mind that this might need to be refactored.
    def __len__(self):
        return len(self.tokenized_smiles.input_ids)
    def __getitem__(self, idx):
        return {
            "smiles": (
                self.tokenized_smiles.input_ids[idx],
                self.tokenized_smiles.attention_mask[idx],
            ),
            self.modality: (
                self.tokenized_string.input_ids[idx],
                self.tokenized_string.attention_mask[idx],
            ),
        }
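As an illustration of the reviewer's suggestion above (tokenize on the fly instead of up front), a dataset along these lines could look roughly as follows. This is only a sketch, not part of the PR, and the class name is made up:

class LazyStringDataset(Dataset):
    """Sketch: keep raw strings and tokenize per item instead of in __init__."""

    def __init__(self, dataset: Tuple[list, list], modality: str, context_length: int = 256):
        self.smiles, self.strings = dataset
        self.modality = modality
        self.context_length = context_length

    def __len__(self):
        return len(self.smiles)

    def _tokenize(self, tokenizer, text: str):
        # Same tokenizer settings as StringDataset, applied to a single string.
        return tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.context_length,
            return_tensors="pt",
        )

    def __getitem__(self, idx):
        smiles = self._tokenize(STRING_TOKENIZERS["smiles"], self.smiles[idx])
        other = self._tokenize(STRING_TOKENIZERS[self.modality], self.strings[idx])
        return {
            "smiles": (smiles.input_ids.squeeze(0), smiles.attention_mask.squeeze(0)),
            self.modality: (other.input_ids.squeeze(0), other.attention_mask.squeeze(0)),
        }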
class GraphDataset(Dataset):
    pass


def load_combined_loader(
    data_modalities: dict,
    batch_size: int,
    shuffle: bool,
    num_workers: int,
    drop_last: bool = True,
) -> CombinedLoader:
    """Combine multiple dataloaders for different modalities into a single dataloader.

    Args:
        data_modalities (dict): data inputs for each modality as pairs of (SMILES, modality)
        batch_size (int): batch size for the dataloader
        shuffle (bool): shuffle the dataset
        num_workers (int): number of workers for the dataloader
        drop_last (bool, optional): whether to drop the last batch; defaults to True.

    Returns:
        CombinedLoader: a combined dataloader for all the modalities
    """
    loaders = {}

    for modality in [*data_modalities]:
        if MODALITY_DATA_TYPES[modality] == str:
            dataset_instance = StringDataset(data_modalities[modality], modality)
            loaders[modality] = DataLoader(
                dataset_instance,
                batch_size=batch_size,
                shuffle=shuffle,
                num_workers=num_workers,
                drop_last=drop_last,
            )
        elif MODALITY_DATA_TYPES[modality] == Graph:
            graph_dataset_instance = GraphDataset(data_modalities[modality])
            loaders[modality] = DataLoader(
                graph_dataset_instance,
                batch_size=batch_size,
                shuffle=shuffle,
                num_workers=num_workers,
                drop_last=drop_last,
            )
    return CombinedLoader(loaders, mode="sequential")
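A usage sketch, adapted from the small demo that was removed from the bottom of this file (the SMILES/SELFIES strings are the same toy examples; with mode="sequential" the loaders are iterated one after another):

smiles = ["CCO", "CCN", "CCO", "CCN"]
selfies = ["[C][C][O]", "[C][C][N]", "[C][C][O]", "[C][C][N]"]

# The central modality (SMILES) is index 0 of each pair, the other modality index 1.
combined_loader = load_combined_loader(
    data_modalities={"selfies": (smiles, selfies)},
    batch_size=2,
    shuffle=True,
    num_workers=0,
)

for batch, batch_idx, dataloader_idx in combined_loader:
    print(f"{batch_idx=}, {dataloader_idx=}")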
why is hydra pinned and the torch stuff only has a lower bound 🤔 ?
that's the default stuff
you mean it was pinned this way by default? But what is the rationale behind doing it this way?
and do we actually now use both requirements.txt and the toml file? It might easily become messy to have it in two places
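For reference, the distinction being discussed looks like this in a pyproject.toml dependency list (version numbers here are illustrative, not the project's actual pins):

[project]
dependencies = [
    "hydra-core==1.3.2",  # pinned: exactly one version allowed
    "torch>=2.0",         # lower bound only: any newer version allowed
]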