Run model from hydra
setup
#16
Changes from all commits
48f28b1
5e0c33d
843e122
abf9b42
de65592
0f6fb79
f339ee2
4e391d4
11e7cb9
4a11e4a
b5d5943
a533a09
95f5b99
This file was deleted.
@@ -1,10 +1,9 @@
_target_: molbind.data.dataloaders.load_combined_loader
central_modality: "smiles"
modalities:
  - "selfies"
train_frac : 0.8
-val_frac : 0.2
+valid_frac : 0.2
seed: 42
fraction_data: 1.0
-dataset_path: "subset.csv"
+dataset_path: "${paths.data_dir}/subset.csv"
batch_size: 64
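The new `dataset_path` value relies on Hydra/OmegaConf interpolation: `${paths.data_dir}` only resolves once this data config is composed together with a `paths` config group. A minimal sketch of how that resolution behaves with plain OmegaConf (the `data_dir` value below is a made-up example, not taken from this repository):

```python
from omegaconf import OmegaConf

# Hand-built stand-in for the composed Hydra config; in the repo, Hydra merges
# the `paths` and `data` groups from the configs/ tree instead.
cfg = OmegaConf.create(
    {
        "paths": {"data_dir": "/tmp/molbind_data"},  # assumed example value
        "data": {"dataset_path": "${paths.data_dir}/subset.csv"},
    }
)

print(cfg.data.dataset_path)  # -> /tmp/molbind_data/subset.csv
```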
@@ -1,7 +1,4 @@
# https://wandb.ai

wandb:
  _target_: lightning.pytorch.loggers.wandb.WandbLogger
  offline: False
  project: "molbind"
-  entity: "adrianmirza"
+  entity: "wandb_username"
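Because this logger config carries a `_target_` key, it can be turned into an actual logger object with `hydra.utils.instantiate`. A minimal standalone sketch (values are illustrative; `offline: True` here only so the example runs without a W&B account):

```python
from hydra.utils import instantiate
from omegaconf import OmegaConf

# Stand-in for the wandb logger block above.
logger_cfg = OmegaConf.create(
    {
        "_target_": "lightning.pytorch.loggers.wandb.WandbLogger",
        "offline": True,
        "project": "molbind",
        "entity": "wandb_username",
    }
)

wandb_logger = instantiate(logger_cfg)  # builds a WandbLogger from _target_
```

Note that the training script in this PR builds the logger with `L.loggers.WandbLogger(**config.logger.wandb)` instead, so the `instantiate` call above is only an alternative pattern, not what the code currently does.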
@@ -0,0 +1,29 @@
# model architecture
encoders:
  smiles:
    pretrained: True
    freeze_encoder: False

  selfies:
    pretrained: True
    freeze_encoder: False

projection_heads:
  smiles:
    dims: [768, 256, 128]
    activation: LeakyReLU
    batch_norm: False
  selfies:
    dims: [768, 256, 128]
    activation: LeakyReLU
    batch_norm: False

optimizer:
  lr: 0.0001
  weight_decay: 0.0001

loss:
  temperature: 0.1

# compile model for faster training with pytorch 2.0
compile: false
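The projection-head implementation itself is not part of this diff, so the following is only an illustrative reading of the `dims`/`activation`/`batch_norm` fields above as a small torch MLP (768 → 256 → 128); the actual molbind code may differ:

```python
import torch.nn as nn

def build_projection_head(dims, activation="LeakyReLU", batch_norm=False):
    """Hypothetical helper mirroring one projection_heads entry from the config."""
    layers = []
    for in_dim, out_dim in zip(dims[:-1], dims[1:]):
        layers.append(nn.Linear(in_dim, out_dim))
        if batch_norm:
            layers.append(nn.BatchNorm1d(out_dim))
        layers.append(getattr(nn, activation)())  # e.g. nn.LeakyReLU()
    return nn.Sequential(*layers)

head = build_projection_head([768, 256, 128])  # maps 768-d encoder output to a 128-d embedding
```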
@@ -1,33 +1,95 @@
-from molbind.models.lightning_module import train_molbind
import hydra
import pytorch_lightning as L
import polars as pl
from molbind.data.dataloaders import load_combined_loader
from molbind.models.lightning_module import MolBindModule
from omegaconf import DictConfig
import torch
import rootutils

rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

-if __name__ == "__main__":
-    config = {
-        "wandb": {"entity": "adrianmirza", "project_name": "embedbind"},
-        "model": {
-            "projection_heads": {
-                "selfies": {"dims": [256, 128], "activation": "leakyrelu", "batch_norm": False},
-                "smiles": {"dims": [256, 128], "activation": "leakyrelu", "batch_norm": False},
-            },
-            "encoders": {
-                "smiles": {"pretrained": True, "freeze_encoder": False},
-                "selfies": {"pretrained": True, "freeze_encoder": False},
-            },
-            "optimizer": {"lr": 1e-4, "weight_decay": 1e-4},
-        },
-        "loss": {"temperature": 0.1},
-        "data": {
-            "central_modality": "smiles",
-            "modalities": ["selfies"],
-            "dataset_path": "subset.csv",
-            "train_frac": 0.8,
-            "valid_frac": 0.2,
-            "seed": 42,
-            "fraction_data": 1,
-            "batch_size": 64,
-        },
-    }
-
-    config = DictConfig(config)

def train_molbind(config: DictConfig):
    wandb_logger = L.loggers.WandbLogger(**config.logger.wandb)

    device_count = torch.cuda.device_count()
    trainer = L.Trainer(
        max_epochs=100,
        accelerator="cuda",

Reviewer comment: you maybe want to also make […]
        log_every_n_steps=10,
        logger=wandb_logger,
        devices=device_count if device_count > 1 else "auto",
        strategy="ddp" if device_count > 1 else "auto",
    )

    train_modality_data = {}
    valid_modality_data = {}

    # Example SMILES - SELFIES modality pair:
    data = pl.read_csv(config.data.dataset_path)
    shuffled_data = data.sample(
        fraction=config.data.fraction_data, shuffle=True, seed=config.data.seed
    )
    dataset_length = len(shuffled_data)
    valid_shuffled_data = shuffled_data.tail(
        int(config.data.valid_frac * dataset_length)
    )
    train_shuffled_data = shuffled_data.head(
        int(config.data.train_frac * dataset_length)
    )

    columns = shuffled_data.columns
    # extract non-central modalities
    non_central_modalities = config.data.modalities
    central_modality = config.data.central_modality

    for column in columns:
        if column in non_central_modalities:
            # drop nan for specific pair
            train_modality_pair = train_shuffled_data[
                [central_modality, column]
            ].drop_nulls()
            valid_modality_pair = valid_shuffled_data[
                [central_modality, column]
            ].drop_nulls()

            train_modality_data[column] = [
                train_modality_pair[central_modality].to_list(),
                train_modality_pair[column].to_list(),
            ]
            valid_modality_data[column] = [
                valid_modality_pair[central_modality].to_list(),
                valid_modality_pair[column].to_list(),
            ]

    combined_loader = load_combined_loader(
        central_modality=config.data.central_modality,
        data_modalities=train_modality_data,
        batch_size=config.data.batch_size,
        shuffle=True,
        num_workers=1,
    )

Reviewer comment: if you have a more complex loader, you might be happy about having more workers

    valid_dataloader = load_combined_loader(
        central_modality=config.data.central_modality,
        data_modalities=valid_modality_data,
        batch_size=config.data.batch_size,
        shuffle=False,
        num_workers=1,
    )

    trainer.fit(
        MolBindModule(config),
        train_dataloaders=combined_loader,
        val_dataloaders=valid_dataloader,
    )


@hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml")
def main(config: DictConfig):
    train_molbind(config)


if __name__ == "__main__":
    main()
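The top-level `train.yaml` referenced by `@hydra.main` is not shown in this diff, but assuming it composes the `data`, `logger`, and `model` groups from the files above, the same config can be inspected outside of `@hydra.main` with Hydra's compose API. A rough sketch (the group names and the relative `config_path` are assumptions, not confirmed by this PR):

```python
from hydra import compose, initialize

# Assumes this snippet lives next to the training script, so ../configs points
# at the same config tree that @hydra.main uses.
with initialize(version_base="1.3", config_path="../configs"):
    cfg = compose(config_name="train.yaml")

print(cfg.data.batch_size)    # 64, from the data config above
print(cfg.data.dataset_path)  # resolves ${paths.data_dir}/subset.csv if a paths group is defined
```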
@@ -2,7 +2,7 @@
from lightning.pytorch.utilities.combined_loader import CombinedLoader
from molbind.data.components.tokenizers import SMILES_TOKENIZER, SELFIES_TOKENIZER
from networkx import Graph
-from typing import Tuple
+from typing import Tuple, Optional
from torch import Tensor


@@ -23,9 +23,13 @@
class StringDataset(Dataset):
    def __init__(
-        self, dataset: Tuple[Tensor, Tensor], modality: str, context_length=256
+        self,
+        dataset: Tuple[Tensor, Tensor],
+        modality: str,
+        central_modality: str = "smiles",
+        context_length: Optional[int] = 256,
    ):
-        """_summary_
+        """Dataset for string modalities.

        Args:
            dataset (Tuple[Tensor, Tensor]): pair of SMILES and data for the modality (smiles always index 0, modality index 1)
@@ -34,14 +38,21 @@ def __init__(
        """
        assert len(dataset) == 2
        assert len(dataset[0]) == len(dataset[1])

        self.modality = modality
-        self.tokenized_smiles = STRING_TOKENIZERS["smiles"](
+        self.central_modality = central_modality
+
+        assert MODALITY_DATA_TYPES[modality] == str
+        assert MODALITY_DATA_TYPES[central_modality] == str
+
+        self.tokenized_central_modality = STRING_TOKENIZERS[central_modality](
            dataset[0],
            padding="max_length",
            truncation=True,
            return_tensors="pt",
            max_length=context_length,
        )

        self.tokenized_string = STRING_TOKENIZERS[modality](
            dataset[1],
            padding="max_length",
@@ -51,13 +62,13 @@ def __init__(
        )

    def __len__(self):
-        return len(self.tokenized_smiles.input_ids)
+        return len(self.tokenized_central_modality.input_ids)

    def __getitem__(self, idx):
        return {
-            "smiles": (
-                self.tokenized_smiles.input_ids[idx],
-                self.tokenized_smiles.attention_mask[idx],
+            self.central_modality: (
+                self.tokenized_central_modality.input_ids[idx],
+                self.tokenized_central_modality.attention_mask[idx],
            ),
            self.modality: (
                self.tokenized_string.input_ids[idx],
@@ -75,25 +86,31 @@ def load_combined_loader(
    batch_size: int,
    shuffle: bool,
    num_workers: int,
+    central_modality: str = "smiles",
    drop_last: bool = True,
) -> CombinedLoader:
    """Combine multiple dataloaders for different modalities into a single dataloader.

    Args:
-        data_modalities (dict): data inputs for each modality as pairs of (SMILES, modality)
+        data_modalities (dict): data inputs for each modality as pairs of (central_modality, modality)
        batch_size (int): batch size for the dataloader
        shuffle (bool): shuffle the dataset
        num_workers (int): number of workers for the dataloader
        drop_last (bool, optional): whether to drop the last batch; defaults to True.
+        central_modality (str, optional): central modality to use for the dataset; defaults to "smiles".

    Returns:
        CombinedLoader: a combined dataloader for all the modalities
    """
    loaders = {}

    for modality in [*data_modalities]:
        if MODALITY_DATA_TYPES[modality] == str:
-            dataset_instance = StringDataset(data_modalities[modality], modality)
+            dataset_instance = StringDataset(
+                dataset=data_modalities[modality],
+                modality=modality,
+                central_modality=central_modality,
+                context_length=256,

Reviewer comment: should this be fixed here?

+            )
            loaders[modality] = DataLoader(
                dataset_instance,
                batch_size=batch_size,
This file was deleted.
Reviewer comment: did not know this module! what didn't work without this?
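To round this off, a minimal usage sketch of the refactored `load_combined_loader` with a toy SMILES/SELFIES pair. The molecules are made up; in the PR, the two lists come from `subset.csv` exactly as the training script above builds them:

```python
from molbind.data.dataloaders import load_combined_loader

# Toy (central_modality, modality) pair in the same [SMILES_list, SELFIES_list]
# layout that the training script builds from the CSV.
toy_data = {
    "selfies": [
        ["CCO", "CC(=O)O"],                           # central modality: SMILES
        ["[C][C][O]", "[C][C][=Branch1][C][=O][O]"],  # paired SELFIES strings
    ]
}

loader = load_combined_loader(
    data_modalities=toy_data,
    central_modality="smiles",
    batch_size=2,
    shuffle=False,
    num_workers=0,
)
# Each element yielded for the "selfies" loader is a dict with two keys,
# "smiles" and "selfies", each holding (input_ids, attention_mask) tensors,
# matching StringDataset.__getitem__ above; the resulting CombinedLoader can be
# passed straight to trainer.fit(..., train_dataloaders=loader).
```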