Setup first model version for the SMILES-SELFIES modality pair #5
(Two files were deleted in this PR; their contents are not shown in the diff view.)
New file:

@@ -0,0 +1,4 @@
+from transformers import AutoTokenizer
+
+SMILES_TOKENIZER = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
+SELFIES_TOKENIZER = AutoTokenizer.from_pretrained("HUBioDataLab/SELFormer")
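As a quick sanity check, a minimal usage sketch for these tokenizers; the example SMILES string is an illustrative assumption, and the context length of 256 is assumed here to match the StringDataset default further down:

from transformers import AutoTokenizer

SMILES_TOKENIZER = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")

# Tokenize a single (placeholder) SMILES string the same way the datasets below do.
tokens = SMILES_TOKENIZER(
    "CCO",  # ethanol, illustrative input
    padding="max_length",
    truncation=True,
    return_tensors="pt",
    max_length=256,  # assumed context length
)
print(tokens.input_ids.shape)       # torch.Size([1, 256])
print(tokens.attention_mask.shape)  # torch.Size([1, 256])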
Next file:

@@ -1,5 +1,4 @@
-from torch.utils.data import DataLoader
-from pytorch_lightning import LightningDataModule
+from torch.utils.data import DataLoader, Dataset
 from lightning.pytorch.utilities.combined_loader import CombinedLoader
 from molbind.data.components.tokenizers import SMILES_TOKENIZER, SELFIES_TOKENIZER
 import networkx as nx
@@ -8,56 +7,131 @@
 MODALITY_DATA_TYPES = {
-    "smiles" : str,
-    "selfies" : str,
-    "graph" : Graph,
-    "nmr" : str,
-    "ir" : str
+    "smiles": str,
+    "selfies": str,
+    "graph": Graph,
+    "nmr": str,
+    "ir": str,
 }

 STRING_TOKENIZERS = {
     "smiles": SMILES_TOKENIZER,
     "selfies": SELFIES_TOKENIZER,
     "iupac_name": "iupac_name_tokenizer",
 }


-class StringDataLoader(DataLoader):
-    def __init__(self, dataset, batch_size, shuffle, num_workers, modality="smiles"):
-        super(StringDataLoader, self).__init__(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
+class StringDataset(Dataset):
+    def __init__(self, dataset, modality, context_length=256):
+        self.dataset = dataset

Reviewer comment on `self.dataset = dataset` (marked as resolved): could we add docstrings? To me it was not clear from the variable name.
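In the spirit of that comment, one possible docstring for the new dataset class; the wording is an illustrative suggestion, not text from the PR:

from torch.utils.data import Dataset


class StringDataset(Dataset):
    """Paired dataset of tokenized SMILES strings and a second string modality.

    Illustrative docstring only; names follow the diff above.

    Args:
        dataset: pair of sequences where dataset[0] holds SMILES strings and
            dataset[1] holds the corresponding strings of the other modality
            (e.g. SELFIES).
        modality: key into STRING_TOKENIZERS selecting the tokenizer for dataset[1].
        context_length: maximum token length used for padding and truncation.
    """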
+        self.modality = modality

Reviewer comment on `self.modality = modality`: do we still need it? Otherwise, we can perhaps keep the object leaner by avoiding this attribute.

+        self.tokenized_smiles = STRING_TOKENIZERS["smiles"](

AdrianM0 marked this conversation as resolved.

+            dataset[0],
+            padding="max_length",
+            truncation=True,
+            return_tensors="pt",
+            max_length=context_length,
+        )
+        self.tokenized_string = STRING_TOKENIZERS[modality](
+            dataset[1],
+            padding="max_length",
+            truncation=True,
+            return_tensors="pt",
+            max_length=context_length,
+        )

Reviewer comment on lines +38 to +51: if your data is large, you might need to revisit this and replace this with some tokenization on the fly or loading from pre-tokenized datasets. It is okay for now, but I'd keep in mind that this might need to be refactored.
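A minimal sketch of the on-the-fly variant that comment hints at; the class and parameter names are illustrative, and the tokenizers are passed in explicitly rather than looked up in STRING_TOKENIZERS:

from torch.utils.data import Dataset


class LazyStringDataset(Dataset):
    """Illustrative sketch: tokenize per sample in __getitem__ instead of upfront."""

    def __init__(self, dataset, modality, smiles_tokenizer, modality_tokenizer, context_length=256):
        self.smiles_strings = dataset[0]
        self.modality_strings = dataset[1]
        self.modality = modality
        self.smiles_tokenizer = smiles_tokenizer
        self.modality_tokenizer = modality_tokenizer
        self.context_length = context_length

    def __len__(self):
        return len(self.smiles_strings)

    def __getitem__(self, idx):
        # Tokenize only the requested sample; drop the batch dimension that
        # return_tensors="pt" adds so shapes match the eager variant above.
        kwargs = dict(
            padding="max_length",
            truncation=True,
            return_tensors="pt",
            max_length=self.context_length,
        )
        smiles = self.smiles_tokenizer(self.smiles_strings[idx], **kwargs)
        other = self.modality_tokenizer(self.modality_strings[idx], **kwargs)
        return {
            "smiles": (smiles.input_ids.squeeze(0), smiles.attention_mask.squeeze(0)),
            self.modality: (other.input_ids.squeeze(0), other.attention_mask.squeeze(0)),
        }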
     def __len__(self):
-        return len(self.dataset)
-
-    def __iter__(self):
-        for batch in super(StringDataLoader, self).__iter__():
-            if self.modality == "smiles":
-                tokenized_batch = SMILES_TOKENIZER(batch, padding="max_length", truncation=True, return_tensors="pt")
-            elif self.modality == "selfies":
-                tokenized_batch = SELFIES_TOKENIZER(batch, padding="max_length", truncation=True, return_tensors="pt")
-            yield tokenized_batch["input_ids"], tokenized_batch["attention_mask"]
-
-
-def load_combined_loader(data_modalities : Dict, batch_size : int, shuffle : bool, num_workers : int) -> CombinedLoader:
+        return len(self.tokenized_smiles.input_ids)
+
+    def __getitem__(self, idx):
+        return {
+            "smiles": (
+                self.tokenized_smiles.input_ids[idx],
+                self.tokenized_smiles.attention_mask[idx],
+            ),
+            self.modality: (
+                self.tokenized_string.input_ids[idx],
+                self.tokenized_string.attention_mask[idx],
+            ),
+        }
+
+
+class GraphDataset(Dataset):
+    def __init__(self, dataset, context_length=128):
+        self.dataset = dataset
+        self.graphs = dataset[1]

Reviewer comment on `self.graphs = dataset[1]`: similar comments as above.

+        self.smiles = STRING_TOKENIZERS["smiles"](

Reviewer comment: Those are assumed to be PyG objects?

+            dataset[0],
+            padding="max_length",
+            truncation=True,
+            return_tensors="pt",
+            max_length=context_length,
+        )
+
+    def __len__(self):
+        return len(self.graphs)
+
+    def __getitem__(self, idx):
+        return {
+            "smiles": (self.smiles.input_ids[idx], self.smiles.attention_mask[idx]),
+            "graph": self.graphs[idx],
+        }
+
+
+def load_combined_loader(
+    data_modalities: dict,
+    batch_size: int,
+    shuffle: bool,
+    num_workers: int,
+    drop_last: bool = True,
+) -> CombinedLoader:
+    """_summary_

Reviewer comment on the docstring: summary is missing ;)

+    Args:
+        data_modalities (dict): data inputs for each modality as pairs of (SMILES, modality)
+        batch_size (int): batch size for the dataloader
+        shuffle (bool): shuffle the dataset
+        num_workers (int): number of workers for the dataloader
+        drop_last (bool, optional): whether to drop the last batch; defaults to True.
+
+    Returns:
+        CombinedLoader: a combined dataloader for all the modalities
+    """
     loaders = {}
-    for modality in data_modalities.keys():
-        # import pdb; pdb.set_trace()
+    for modality in [*data_modalities]:
         if MODALITY_DATA_TYPES[modality] == str:
-            loaders[modality] = StringDataLoader(data_modalities[modality], batch_size, shuffle, num_workers, modality)
+            dataset_instance = StringDataset(data_modalities[modality], modality)
+            loaders[modality] = DataLoader(
+                dataset_instance,
+                batch_size=batch_size,
+                shuffle=shuffle,
+                num_workers=num_workers,
+                drop_last=drop_last,
+            )
         elif MODALITY_DATA_TYPES[modality] == Graph:
-            loaders[modality] = DataLoader(data_modalities[modality], batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
-    return CombinedLoader(loaders, "min_size")
-
-
-smiles = ["CCO", "CCN", "CCO", "CCN"]
-selfies = ["[C][C][O]", "[C][C][N]", "[C][C][O]", "[C][C][N]"]
-dummy_graphs = ["dummy_graph", "dummy_graph", "dummy_graph", "dummy_graph"]
-
-combined_loader = load_combined_loader(
-    data_modalities = {
-        "smiles" : smiles,
-        "selfies" : selfies,
-        "graph" : dummy_graphs
-    },
-    batch_size=2,
-    shuffle=True,
-    num_workers=1)
-
-for batch, batch_idx, dataloader_idx in combined_loader:
-    print(f"{batch}, {batch_idx=}, {dataloader_idx=}")
+            graph_dataset_instance = GraphDataset(data_modalities[modality])
+            loaders[modality] = DataLoader(
+                graph_dataset_instance,
+                batch_size=batch_size,
+                shuffle=shuffle,
+                num_workers=num_workers,
+                drop_last=drop_last,
+            )
+    return CombinedLoader(loaders, mode="sequential")
+
+
+if __name__ == "__main__":
+    smiles = ["CCO", "CCN", "CCON", "CCNO"]
+    selfies = ["[C][C][O]", "[C][C][N]", "[C][C][O][N]", "[C][C][N][O]"]
+    dummy_graphs = ["CCO_graph", "CCN_graph", "CCON_graph", "CCNO_graph"]
+
+    combined_loader = load_combined_loader(
+        data_modalities={"selfies": [smiles, selfies], "graph": [smiles, dummy_graphs]},
+        batch_size=2,
+        shuffle=False,
+        num_workers=1,
+    )
+
+    for batch, batch_idx, dataloader_idx in combined_loader:
+        print(f"{batch=}, {batch_idx=}, {dataloader_idx=}")

AdrianM0 marked this conversation as resolved.
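Regarding the question above about PyG objects: if the graphs are torch_geometric Data objects, the plain torch DataLoader will not collate them into a batched graph. A hedged sketch of the alternative, assuming torch_geometric is available (it is not shown as a dependency in this PR):

# Illustrative only: batching PyG graphs is usually done with torch_geometric's
# own DataLoader, which merges Data objects into a single disjoint batch graph.
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader as PyGDataLoader

# Two tiny placeholder graphs (node features x, edge_index in COO format).
graphs = [
    Data(x=torch.randn(3, 8), edge_index=torch.tensor([[0, 1, 2], [1, 2, 0]])),
    Data(x=torch.randn(2, 8), edge_index=torch.tensor([[0, 1], [1, 0]])),
]

loader = PyGDataLoader(graphs, batch_size=2, shuffle=False)
batch = next(iter(loader))
print(batch.num_graphs)  # 2
print(batch.x.shape)     # torch.Size([5, 8]), nodes of both graphs stacked (3 + 2)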
Next file:

@@ -1,22 +1,16 @@
 from typing import Dict, Literal
 from transformers import AutoModelForCausalLM
 import torch.nn as nn


 class BaseModalityEncoder(nn.Module):
-    def __init__(self,
-                 projection_head_type : Literal["linear", "non-linear"] = "non-linear",
-                 pretrained=True,
-                 **kwargs):
-        self.pretrained = pretrained
-        self.encoder = self.build_encoder()
+    def __init__(self, freeze_encoder: bool = False, pretrained=True, **kwargs):
+        super(BaseModalityEncoder, self).__init__()
+        self.pretrained = pretrained
+        self.freeze_encoder = freeze_encoder

     def build_encoder(self):
         pass

     def forward(self, x):
-        x = self.encoder(x)
-        x = self.projection_head(x)
-        # pooling
-        return x
+        return self.encoder(x)
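To illustrate how this base class might be specialized, a hedged sketch of a concrete encoder; the subclass name, the checkpoint, the use of AutoModel (rather than the AutoModelForCausalLM imported above), and the freezing logic are assumptions for illustration, not code from this PR:

from transformers import AutoModel


class SmilesEncoder(BaseModalityEncoder):
    """Illustrative subclass wrapping a pretrained ChemBERTa backbone."""

    def __init__(self, freeze_encoder: bool = False, **kwargs):
        super().__init__(freeze_encoder=freeze_encoder, **kwargs)
        # Assumed checkpoint; the PR only shows the tokenizer side of ChemBERTa.
        self.encoder = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
        if self.freeze_encoder:
            # Freeze all backbone weights so only downstream heads are trained.
            for param in self.encoder.parameters():
                param.requires_grad = False

    def forward(self, input_ids, attention_mask):
        # Return the last hidden state; pooling/projection is left to the caller,
        # matching the base class, which simply forwards through the encoder.
        return self.encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state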
Comment: why is hydra pinned and the torch stuff only has a lower bound 🤔?

Reply: that's the default stuff.

Comment: you mean it was pinned this way by default? But what is the rationale behind doing it this way?

Comment: and do we actually now use both requirements.txt and the toml file? This might easily become messy to have it in two places.