Run model from hydra setup #16

Merged: 13 commits, Apr 11, 2024
6 changes: 0 additions & 6 deletions .env.example

This file was deleted.

5 changes: 2 additions & 3 deletions configs/data/molbind.yaml
@@ -1,10 +1,9 @@
_target_: molbind.data.dataloaders.load_combined_loader
central_modality: "smiles"
modalities:
- "selfies"
train_frac : 0.8
val_frac : 0.2
valid_frac : 0.2
seed: 42
fraction_data: 1.0
dataset_path: "subset.csv"
dataset_path: "${paths.data_dir}/subset.csv"
batch_size: 64
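
The dataset path now uses OmegaConf interpolation against the paths config group added in configs/train.yaml. A minimal sketch of how that interpolation resolves, with an assumed data_dir value:

from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "paths": {"data_dir": "/home/user/molbind/data"},  # assumed value, provided by the paths group at runtime
    "data": {"dataset_path": "${paths.data_dir}/subset.csv"},
})
print(cfg.data.dataset_path)  # -> /home/user/molbind/data/subset.csv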
5 changes: 1 addition & 4 deletions configs/logger/wandb.yaml
@@ -1,7 +1,4 @@
# https://wandb.ai

wandb:
_target_: lightning.pytorch.loggers.wandb.WandbLogger
offline: False
project: "molbind"
entity: "adrianmirza"
entity: "wandb_username"
29 changes: 29 additions & 0 deletions configs/model/molbind.yaml
@@ -0,0 +1,29 @@
# model architecture
encoders:
smiles:
pretrained: True
freeze_encoder: False

selfies:
pretrained: True
freeze_encoder: False

projection_heads:
smiles:
dims: [768, 256, 128]
activation: LeakyReLU
batch_norm: False
selfies:
dims: [768, 256, 128]
activation: LeakyReLU
batch_norm: False

optimizer:
lr: 0.0001
weight_decay: 0.0001

loss:
temperature: 0.1

# compile model for faster training with pytorch 2.0
compile: false
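
For orientation, dims lists of this shape are commonly consumed as an MLP with one linear layer per consecutive pair of entries; a sketch under that assumption, not MolBind's actual projection-head code:

import torch.nn as nn

def build_projection_head(dims, activation=nn.LeakyReLU, batch_norm=False):
    # dims = [768, 256, 128] -> Linear(768, 256) + act, Linear(256, 128) + act
    layers = []
    for in_dim, out_dim in zip(dims[:-1], dims[1:]):
        layers.append(nn.Linear(in_dim, out_dim))
        if batch_norm:
            layers.append(nn.BatchNorm1d(out_dim))
        layers.append(activation())
    return nn.Sequential(*layers)

head = build_projection_head([768, 256, 128])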
3 changes: 1 addition & 2 deletions configs/train.yaml
@@ -7,10 +7,9 @@ defaults:
- data: molbind
- model: molbind
- logger: wandb
- paths: default
# - callbacks: default
# - logger: null # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
# - trainer: default
# - paths: default
# - extras: default
# - hydra: default

File renamed without changes.
120 changes: 91 additions & 29 deletions experiments/train.py
@@ -1,33 +1,95 @@
from molbind.models.lightning_module import train_molbind
import hydra
import pytorch_lightning as L
import polars as pl
from molbind.data.dataloaders import load_combined_loader
from molbind.models.lightning_module import MolBindModule
from omegaconf import DictConfig
import torch
import rootutils

rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

Collaborator:
Didn't know this module! What didn't work without this?
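
For reference, a rough sketch of what setup_root does here; approximate behavior, not the library's actual code:

import os
import sys
from pathlib import Path

def find_root(start_file: str, indicator: str = ".project-root") -> Path:
    # walk upwards from the script until the indicator file is found
    path = Path(start_file).resolve()
    for parent in [path, *path.parents]:
        if (parent / indicator).exists():
            return parent
    raise FileNotFoundError(f"no {indicator} found above {start_file}")

root = find_root(__file__)
sys.path.insert(0, str(root))           # pythonpath=True
os.environ["PROJECT_ROOT"] = str(root)  # paths configs can read this via ${oc.env:PROJECT_ROOT}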

if __name__ == "__main__":
config = {
"wandb": {"entity": "adrianmirza", "project_name": "embedbind"},
"model": {
"projection_heads": {
"selfies": {"dims": [256, 128], "activation": "leakyrelu", "batch_norm": False},
"smiles": {"dims": [256, 128], "activation": "leakyrelu", "batch_norm": False},
},
"encoders": {
"smiles": {"pretrained": True, "freeze_encoder": False},
"selfies": {"pretrained": True, "freeze_encoder": False},
},
"optimizer": {"lr": 1e-4, "weight_decay": 1e-4},
},
"loss": {"temperature": 0.1},
"data": {
"central_modality": "smiles",
"modalities": ["selfies"],
"dataset_path": "subset.csv",
"train_frac": 0.8,
"valid_frac": 0.2,
"seed": 42,
"fraction_data": 1,
"batch_size": 64,
},
}

config = DictConfig(config)

def train_molbind(config: DictConfig):
wandb_logger = L.loggers.WandbLogger(**config.logger.wandb)

device_count = torch.cuda.device_count()
trainer = L.Trainer(
max_epochs=100,
accelerator="cuda",
Collaborator:
You may want to also make max_epochs and accelerator customizable (for debugging, MPS on a MacBook might be fine, too).
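
One way to pick that up would be to read these values from the config with defaults; a sketch, where the trainer keys are assumed and not part of this PR:

trainer = L.Trainer(
    max_epochs=config.trainer.get("max_epochs", 100),       # assumed config group
    accelerator=config.trainer.get("accelerator", "auto"),  # "cuda", "mps", or "cpu"
    log_every_n_steps=10,
    logger=wandb_logger,
)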

log_every_n_steps=10,
logger=wandb_logger,
devices=device_count if device_count > 1 else "auto",
strategy="ddp" if device_count > 1 else "auto",
)

train_modality_data = {}
valid_modality_data = {}

# Example SMILES - SELFIES modality pair:
data = pl.read_csv(config.data.dataset_path)
shuffled_data = data.sample(
fraction=config.data.fraction_data, shuffle=True, seed=config.data.seed
)
dataset_length = len(shuffled_data)
valid_shuffled_data = shuffled_data.tail(
int(config.data.valid_frac * dataset_length)
)
train_shuffled_data = shuffled_data.head(
int(config.data.train_frac * dataset_length)
)

columns = shuffled_data.columns
# extract non-central modalities
non_central_modalities = config.data.modalities
central_modality = config.data.central_modality

for column in columns:
if column in non_central_modalities:
# drop nan for specific pair
train_modality_pair = train_shuffled_data[
[central_modality, column]
].drop_nulls()
valid_modality_pair = valid_shuffled_data[
[central_modality, column]
].drop_nulls()

train_modality_data[column] = [
train_modality_pair[central_modality].to_list(),
train_modality_pair[column].to_list(),
]
valid_modality_data[column] = [
valid_modality_pair[central_modality].to_list(),
valid_modality_pair[column].to_list(),
]

combined_loader = load_combined_loader(
central_modality=config.data.central_modality,
data_modalities=train_modality_data,
batch_size=config.data.batch_size,
shuffle=True,
num_workers=1,
)
Collaborator:
If you have a more complex loader, you might be happy about having more workers.
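
A possible follow-up, reading num_workers from the data config instead of hard-coding it; the config key is an assumption:

combined_loader = load_combined_loader(
    central_modality=config.data.central_modality,
    data_modalities=train_modality_data,
    batch_size=config.data.batch_size,
    shuffle=True,
    num_workers=config.data.get("num_workers", 4),  # assumed key, falls back to 4
)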


valid_dataloader = load_combined_loader(
central_modality=config.data.central_modality,
data_modalities=valid_modality_data,
batch_size=config.data.batch_size,
shuffle=False,
num_workers=1,
)

trainer.fit(
MolBindModule(config),
train_dataloaders=combined_loader,
val_dataloaders=valid_dataloader,
)


@hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml")
def main(config: DictConfig):
train_molbind(config)


if __name__ == "__main__":
main()
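
With the hydra.main entry point, any value from the composed config groups can be overridden on the command line, for example (illustrative values):

python experiments/train.py data.batch_size=128 model.optimizer.lr=3e-4 logger.wandb.entity=my_team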
39 changes: 28 additions & 11 deletions src/molbind/data/dataloaders.py
@@ -2,7 +2,7 @@
from lightning.pytorch.utilities.combined_loader import CombinedLoader
from molbind.data.components.tokenizers import SMILES_TOKENIZER, SELFIES_TOKENIZER
from networkx import Graph
from typing import Tuple
from typing import Tuple, Optional
from torch import Tensor


@@ -23,9 +23,13 @@

class StringDataset(Dataset):
def __init__(
self, dataset: Tuple[Tensor, Tensor], modality: str, context_length=256
self,
dataset: Tuple[Tensor, Tensor],
modality: str,
central_modality: str = "smiles",
context_length: Optional[int] = 256,
):
"""_summary_
"""Dataset for string modalities.

Args:
dataset (Tuple[Tensor, Tensor]): pair of SMILES and data for the modality (smiles always index 0, modality index 1)
@@ -34,14 +34,21 @@ def __init__(
"""
assert len(dataset) == 2
assert len(dataset[0]) == len(dataset[1])

self.modality = modality
self.tokenized_smiles = STRING_TOKENIZERS["smiles"](
self.central_modality = central_modality

assert MODALITY_DATA_TYPES[modality] == str
assert MODALITY_DATA_TYPES[central_modality] == str

self.tokenized_central_modality = STRING_TOKENIZERS[central_modality](
dataset[0],
padding="max_length",
truncation=True,
return_tensors="pt",
max_length=context_length,
)

self.tokenized_string = STRING_TOKENIZERS[modality](
dataset[1],
padding="max_length",
@@ -51,13 +62,13 @@ def __init__(
)

def __len__(self):
return len(self.tokenized_smiles.input_ids)
return len(self.tokenized_central_modality.input_ids)

def __getitem__(self, idx):
return {
"smiles": (
self.tokenized_smiles.input_ids[idx],
self.tokenized_smiles.attention_mask[idx],
self.central_modality: (
self.tokenized_central_modality.input_ids[idx],
self.tokenized_central_modality.attention_mask[idx],
),
self.modality: (
self.tokenized_string.input_ids[idx],
@@ -75,25 +86,31 @@ def load_combined_loader(
batch_size: int,
shuffle: bool,
num_workers: int,
central_modality: str = "smiles",
drop_last: bool = True,
) -> CombinedLoader:
"""Combine multiple dataloaders for different modalities into a single dataloader.

Args:
data_modalities (dict): data inputs for each modality as pairs of (SMILES, modality)
data_modalities (dict): data inputs for each modality as pairs of (central_modality, modality)
batch_size (int): batch size for the dataloader
shuffle (bool): shuffle the dataset
num_workers (int): number of workers for the dataloader
drop_last (bool, optional): whether to drop the last batch; defaults to True.
central_modality (str, optional): central modality to use for the dataset; defaults to "smiles".

Returns:
CombinedLoader: a combined dataloader for all the modalities
"""
loaders = {}

for modality in [*data_modalities]:
if MODALITY_DATA_TYPES[modality] == str:
dataset_instance = StringDataset(data_modalities[modality], modality)
dataset_instance = StringDataset(
dataset=data_modalities[modality],
modality=modality,
central_modality=central_modality,
context_length=256,
Collaborator:
Should this be fixed here?
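
One option would be to expose it instead of hard-coding the literal; sketched as comments, with context_length as an assumed extra parameter of load_combined_loader:

# add `context_length: int = 256` to the load_combined_loader signature
# and pass it through here instead of the literal:
#     context_length=context_length,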

)
loaders[modality] = DataLoader(
dataset_instance,
batch_size=batch_size,
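
For orientation, the expected shape of the data_modalities argument, shown with illustrative strings:

from molbind.data.dataloaders import load_combined_loader

data_modalities = {
    "selfies": [
        ["CCO", "CC(=O)O"],                           # central modality (SMILES), index 0
        ["[C][C][O]", "[C][C][=Branch1][C][=O][O]"],  # paired modality (SELFIES), index 1
    ],
}
loader = load_combined_loader(
    data_modalities=data_modalities,
    batch_size=2,
    shuffle=True,
    num_workers=1,
    central_modality="smiles",
)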
19 changes: 0 additions & 19 deletions src/molbind/models/components/pooler.py

This file was deleted.
