generated from ashleve/lightning-hydra-template
Merge pull request #5 from lamalab-org/selfies_smiles_example
SELFIES - SMILES multimodal example
Showing 25 changed files with 10,460 additions and 281 deletions.
New file (+10 lines): the data config for the SELFIES–SMILES pair.

```yaml
_target_: molbind.data.dataloaders.load_combined_loader
central_modality: "smiles"
modalities:
  - "selfies"
train_frac: 0.8
val_frac: 0.2
seed: 42
fraction_data: 1.0
dataset_path: "subset.csv"
batch_size: 64
```
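The `_target_` key follows the Hydra instantiation convention (the repo is generated from a lightning-hydra template): `hydra.utils.instantiate` imports the dotted path and calls it with the remaining keys as keyword arguments. A minimal sketch of that mechanism, assuming a hypothetical config path; the repo's training code may adapt keys such as `train_frac` before the loader is actually built:

```python
from hydra.utils import instantiate
from omegaconf import OmegaConf

# Load the YAML shown above (the path is an assumption).
cfg = OmegaConf.load("configs/data/selfies_smiles.yaml")

# instantiate() imports molbind.data.dataloaders.load_combined_loader and
# calls it with the remaining keys as keyword arguments.
loader = instantiate(cfg)
```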
New file (+33 lines): a standalone training example that builds the config in Python.

```python
from molbind.models.lightning_module import train_molbind
from omegaconf import DictConfig


if __name__ == "__main__":
    config = {
        "wandb": {"entity": "adrianmirza", "project_name": "embedbind"},
        "model": {
            "projection_heads": {
                "selfies": {"dims": [256, 128], "activation": "leakyrelu", "batch_norm": False},
                "smiles": {"dims": [256, 128], "activation": "leakyrelu", "batch_norm": False},
            },
            "encoders": {
                "smiles": {"pretrained": True, "freeze_encoder": False},
                "selfies": {"pretrained": True, "freeze_encoder": False},
            },
            "optimizer": {"lr": 1e-4, "weight_decay": 1e-4},
        },
        "loss": {"temperature": 0.1},
        "data": {
            "central_modality": "smiles",
            "modalities": ["selfies"],
            "dataset_path": "subset.csv",
            "train_frac": 0.8,
            "valid_frac": 0.2,
            "seed": 42,
            "fraction_data": 1,
            "batch_size": 64,
        },
    }

    config = DictConfig(config)
    train_molbind(config)
```
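Since `train_molbind` takes an OmegaConf `DictConfig`, the same settings could equally live in a YAML file; a sketch of that variant (the config path is hypothetical):

```python
from omegaconf import OmegaConf

from molbind.models.lightning_module import train_molbind

# Equivalent to building the dict inline, assuming the YAML mirrors it.
config = OmegaConf.load("configs/selfies_smiles.yaml")  # hypothetical path
train_molbind(config)
```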
New file (+4 lines): shared tokenizers for the string modalities.

```python
from transformers import AutoTokenizer

SMILES_TOKENIZER = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
SELFIES_TOKENIZER = AutoTokenizer.from_pretrained("HUBioDataLab/SELFormer")
```
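Both are standard Hugging Face tokenizers, so they return `input_ids` and `attention_mask` tensors; a quick sketch with toy molecules (the example strings are illustrative):

```python
from molbind.data.components.tokenizers import SMILES_TOKENIZER

# Tokenize a small batch the same way the datasets below do.
batch = SMILES_TOKENIZER(
    ["CCO", "c1ccccc1"],
    padding="max_length",
    truncation=True,
    return_tensors="pt",
    max_length=256,
)
print(batch["input_ids"].shape)       # torch.Size([2, 256])
print(batch["attention_mask"].shape)  # torch.Size([2, 256])
```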
Modified (63 → 113 lines): the `molbind.data.dataloaders` module. The old `StringDataLoader`, which re-tokenized every batch inside `__iter__`, is replaced by a `StringDataset` that tokenizes the full (SMILES, modality) pair once up front; tokenizers now come from `molbind.data.components.tokenizers`; a `GraphDataset` stub and a `drop_last` option are added; the `CombinedLoader` mode changes from `"min_size"` to `"sequential"`; and the demo code that ran at the bottom of the module is removed. The new version:

```python
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import LightningDataModule
from lightning.pytorch.utilities.combined_loader import CombinedLoader
from molbind.data.components.tokenizers import SMILES_TOKENIZER, SELFIES_TOKENIZER
import networkx as nx
from networkx import Graph
from typing import Tuple
from torch import Tensor


MODALITY_DATA_TYPES = {
    "smiles": str,
    "selfies": str,
    "graph": Graph,
    "nmr": str,
    "ir": str,
}

STRING_TOKENIZERS = {
    "smiles": SMILES_TOKENIZER,
    "selfies": SELFIES_TOKENIZER,
    "iupac_name": "iupac_name_tokenizer",
}


class StringDataset(Dataset):
    def __init__(
        self, dataset: Tuple[Tensor, Tensor], modality: str, context_length=256
    ):
        """Paired dataset of tokenized SMILES and a second string modality.

        Args:
            dataset (Tuple[Tensor, Tensor]): pair of SMILES and data for the modality
                (SMILES always at index 0, the modality at index 1)
            modality (str): name of the data modality as found in MODALITY_DATA_TYPES
            context_length (int, optional): maximum tokenized length. Defaults to 256.
        """
        assert len(dataset) == 2
        assert len(dataset[0]) == len(dataset[1])
        self.modality = modality
        self.tokenized_smiles = STRING_TOKENIZERS["smiles"](
            dataset[0],
            padding="max_length",
            truncation=True,
            return_tensors="pt",
            max_length=context_length,
        )
        self.tokenized_string = STRING_TOKENIZERS[modality](
            dataset[1],
            padding="max_length",
            truncation=True,
            return_tensors="pt",
            max_length=context_length,
        )

    def __len__(self):
        return len(self.tokenized_smiles.input_ids)

    def __getitem__(self, idx):
        return {
            "smiles": (
                self.tokenized_smiles.input_ids[idx],
                self.tokenized_smiles.attention_mask[idx],
            ),
            self.modality: (
                self.tokenized_string.input_ids[idx],
                self.tokenized_string.attention_mask[idx],
            ),
        }


class GraphDataset(Dataset):
    pass


def load_combined_loader(
    data_modalities: dict,
    batch_size: int,
    shuffle: bool,
    num_workers: int,
    drop_last: bool = True,
) -> CombinedLoader:
    """Combine multiple dataloaders for different modalities into a single dataloader.

    Args:
        data_modalities (dict): data inputs for each modality as pairs of (SMILES, modality)
        batch_size (int): batch size for the dataloader
        shuffle (bool): shuffle the dataset
        num_workers (int): number of workers for the dataloader
        drop_last (bool, optional): whether to drop the last batch; defaults to True.

    Returns:
        CombinedLoader: a combined dataloader for all the modalities
    """
    loaders = {}

    for modality in [*data_modalities]:
        if MODALITY_DATA_TYPES[modality] == str:
            dataset_instance = StringDataset(data_modalities[modality], modality)
            loaders[modality] = DataLoader(
                dataset_instance,
                batch_size=batch_size,
                shuffle=shuffle,
                num_workers=num_workers,
                drop_last=drop_last,
            )
        elif MODALITY_DATA_TYPES[modality] == Graph:
            graph_dataset_instance = GraphDataset(data_modalities[modality])
            loaders[modality] = DataLoader(
                graph_dataset_instance,
                batch_size=batch_size,
                shuffle=shuffle,
                num_workers=num_workers,
                drop_last=drop_last,
            )
    return CombinedLoader(loaders, mode="sequential")
```
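The demo the old module ran at import time is gone; in its place, a usage sketch under the new API (the toy SMILES/SELFIES pairs are illustrative, and `num_workers=0` keeps it runnable in a single process). With `mode="sequential"`, iterating the `CombinedLoader` yields `(batch, batch_idx, dataloader_idx)` triples, as in the old demo:

```python
from molbind.data.dataloaders import load_combined_loader

smiles = ["CCO", "CCN", "CCO", "CCN"]
selfies = ["[C][C][O]", "[C][C][N]", "[C][C][O]", "[C][C][N]"]

# Each modality now maps to a (SMILES, modality) pair of aligned lists.
combined_loader = load_combined_loader(
    data_modalities={"selfies": (smiles, selfies)},
    batch_size=2,
    shuffle=True,
    num_workers=0,
    drop_last=True,
)

for batch, batch_idx, dataloader_idx in combined_loader:
    # batch is a dict: {"smiles": (input_ids, attention_mask), "selfies": (...)}
    print(f"{batch_idx=}, {dataloader_idx=}")
```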