Merge pull request #16 from lamalab-org/hydra
Run model from `hydra` setup
AdrianM0 authored Apr 11, 2024
2 parents 25b3bf5 + 95f5b99 commit 0d5a79b
Showing 11 changed files with 179 additions and 168 deletions.
6 changes: 0 additions & 6 deletions .env.example

This file was deleted.

5 changes: 2 additions & 3 deletions configs/data/molbind.yaml
@@ -1,10 +1,9 @@
_target_: molbind.data.dataloaders.load_combined_loader
central_modality: "smiles"
modalities:
- "selfies"
train_frac : 0.8
val_frac : 0.2
valid_frac : 0.2
seed: 42
fraction_data: 1.0
dataset_path: "subset.csv"
dataset_path: "${paths.data_dir}/subset.csv"
batch_size: 64
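
The switch from a bare "subset.csv" to "${paths.data_dir}/subset.csv" delegates the file location to the paths config group via OmegaConf interpolation. A minimal sketch of how the interpolation resolves, with a made-up data_dir (in the real run it is supplied by the `paths: default` entry in configs/train.yaml):

from omegaconf import OmegaConf

# hypothetical values for illustration only
cfg = OmegaConf.create(
    {
        "paths": {"data_dir": "/path/to/data"},
        "data": {"dataset_path": "${paths.data_dir}/subset.csv"},
    }
)
print(cfg.data.dataset_path)  # -> /path/to/data/subset.csv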
5 changes: 1 addition & 4 deletions configs/logger/wandb.yaml
@@ -1,7 +1,4 @@
# https://wandb.ai

wandb:
  _target_: lightning.pytorch.loggers.wandb.WandbLogger
  offline: False
  project: "molbind"
  entity: "adrianmirza"
  entity: "wandb_username"
29 changes: 29 additions & 0 deletions configs/model/molbind.yaml
@@ -0,0 +1,29 @@
# model architecture
encoders:
  smiles:
    pretrained: True
    freeze_encoder: False

  selfies:
    pretrained: True
    freeze_encoder: False

projection_heads:
  smiles:
    dims: [768, 256, 128]
    activation: LeakyReLU
    batch_norm: False
  selfies:
    dims: [768, 256, 128]
    activation: LeakyReLU
    batch_norm: False

optimizer:
  lr: 0.0001
  weight_decay: 0.0001

loss:
  temperature: 0.1

# compile model for faster training with pytorch 2.0
compile: false
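
Each `dims` list describes the layer widths of a projection head, mapping the 768-dimensional encoder output down to a 128-dimensional shared embedding space. The head implementation itself is not part of this diff; a hedged sketch of how such a config might be turned into a module (build_projection_head is a hypothetical helper, and leaving the final layer without an activation is an assumption):

import torch.nn as nn

def build_projection_head(dims, activation="LeakyReLU", batch_norm=False):
    # hypothetical helper: [768, 256, 128] -> Linear(768, 256) -> LeakyReLU -> Linear(256, 128)
    layers = []
    for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
        layers.append(nn.Linear(in_dim, out_dim))
        if i < len(dims) - 2:  # assume no activation/norm after the final projection
            if batch_norm:
                layers.append(nn.BatchNorm1d(out_dim))
            layers.append(getattr(nn, activation)())
    return nn.Sequential(*layers)

head = build_projection_head([768, 256, 128])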
3 changes: 1 addition & 2 deletions configs/train.yaml
@@ -7,10 +7,9 @@ defaults:
- data: molbind
- model: molbind
- logger: wandb
- paths: default
# - callbacks: default
# - logger: null # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
# - trainer: default
# - paths: default
# - extras: default
# - hydra: default

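Activating `paths: default` in the defaults list is what makes the `${paths.data_dir}` interpolation in configs/data/molbind.yaml resolvable. The composed config can be inspected without launching a run via hydra's compose API (a sketch; config_path is relative to the calling file):

from hydra import compose, initialize
from omegaconf import OmegaConf

with initialize(version_base="1.3", config_path="../configs"):
    cfg = compose(config_name="train", overrides=["data.batch_size=128"])
    print(OmegaConf.to_yaml(cfg))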
File renamed without changes.
120 changes: 91 additions & 29 deletions experiments/train.py
@@ -1,33 +1,95 @@
from molbind.models.lightning_module import train_molbind
import hydra
import pytorch_lightning as L
import polars as pl
from molbind.data.dataloaders import load_combined_loader
from molbind.models.lightning_module import MolBindModule
from omegaconf import DictConfig
import torch
import rootutils

rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

if __name__ == "__main__":
    config = {
        "wandb": {"entity": "adrianmirza", "project_name": "embedbind"},
        "model": {
            "projection_heads": {
                "selfies": {"dims": [256, 128], "activation": "leakyrelu", "batch_norm": False},
                "smiles": {"dims": [256, 128], "activation": "leakyrelu", "batch_norm": False},
            },
            "encoders": {
                "smiles": {"pretrained": True, "freeze_encoder": False},
                "selfies": {"pretrained": True, "freeze_encoder": False},
            },
            "optimizer": {"lr": 1e-4, "weight_decay": 1e-4},
        },
        "loss": {"temperature": 0.1},
        "data": {
            "central_modality": "smiles",
            "modalities": ["selfies"],
            "dataset_path": "subset.csv",
            "train_frac": 0.8,
            "valid_frac": 0.2,
            "seed": 42,
            "fraction_data": 1,
            "batch_size": 64,
        },
    }

    config = DictConfig(config)

def train_molbind(config: DictConfig):
    wandb_logger = L.loggers.WandbLogger(**config.logger.wandb)

    device_count = torch.cuda.device_count()
    trainer = L.Trainer(
        max_epochs=100,
        accelerator="cuda",
        log_every_n_steps=10,
        logger=wandb_logger,
        devices=device_count if device_count > 1 else "auto",
        strategy="ddp" if device_count > 1 else "auto",
    )

    train_modality_data = {}
    valid_modality_data = {}

    # Example SMILES - SELFIES modality pair:
    data = pl.read_csv(config.data.dataset_path)
    shuffled_data = data.sample(
        fraction=config.data.fraction_data, shuffle=True, seed=config.data.seed
    )
    dataset_length = len(shuffled_data)
    valid_shuffled_data = shuffled_data.tail(
        int(config.data.valid_frac * dataset_length)
    )
    train_shuffled_data = shuffled_data.head(
        int(config.data.train_frac * dataset_length)
    )

    columns = shuffled_data.columns
    # extract non-central modalities
    non_central_modalities = config.data.modalities
    central_modality = config.data.central_modality

    for column in columns:
        if column in non_central_modalities:
            # drop nan for specific pair
            train_modality_pair = train_shuffled_data[
                [central_modality, column]
            ].drop_nulls()
            valid_modality_pair = valid_shuffled_data[
                [central_modality, column]
            ].drop_nulls()

            train_modality_data[column] = [
                train_modality_pair[central_modality].to_list(),
                train_modality_pair[column].to_list(),
            ]
            valid_modality_data[column] = [
                valid_modality_pair[central_modality].to_list(),
                valid_modality_pair[column].to_list(),
            ]

    combined_loader = load_combined_loader(
        central_modality=config.data.central_modality,
        data_modalities=train_modality_data,
        batch_size=config.data.batch_size,
        shuffle=True,
        num_workers=1,
    )

    valid_dataloader = load_combined_loader(
        central_modality=config.data.central_modality,
        data_modalities=valid_modality_data,
        batch_size=config.data.batch_size,
        shuffle=False,
        num_workers=1,
    )

    trainer.fit(
        MolBindModule(config),
        train_dataloaders=combined_loader,
        val_dataloaders=valid_dataloader,
    )


@hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml")
def main(config: DictConfig):
    train_molbind(config)


if __name__ == "__main__":
    main()
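
With the hydra entry point in place, the hard-coded config dictionary is gone and a run is launched as plain `python experiments/train.py`; any field in the composed config can then be overridden from the command line, e.g. `python experiments/train.py data.batch_size=128 model.optimizer.lr=0.001`.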
39 changes: 28 additions & 11 deletions src/molbind/data/dataloaders.py
@@ -2,7 +2,7 @@
from lightning.pytorch.utilities.combined_loader import CombinedLoader
from molbind.data.components.tokenizers import SMILES_TOKENIZER, SELFIES_TOKENIZER
from networkx import Graph
from typing import Tuple
from typing import Tuple, Optional
from torch import Tensor


@@ -23,9 +23,13 @@

class StringDataset(Dataset):
    def __init__(
        self, dataset: Tuple[Tensor, Tensor], modality: str, context_length=256
        self,
        dataset: Tuple[Tensor, Tensor],
        modality: str,
        central_modality: str = "smiles",
        context_length: Optional[int] = 256,
    ):
        """_summary_
        """Dataset for string modalities.
        Args:
            dataset (Tuple[Tensor, Tensor]): pair of SMILES and data for the modality (smiles always index 0, modality index 1)
@@ -34,14 +38,21 @@ def __init__(
"""
assert len(dataset) == 2
assert len(dataset[0]) == len(dataset[1])

self.modality = modality
self.tokenized_smiles = STRING_TOKENIZERS["smiles"](
self.central_modality = central_modality

assert MODALITY_DATA_TYPES[modality] == str
assert MODALITY_DATA_TYPES[central_modality] == str

self.tokenized_central_modality = STRING_TOKENIZERS[central_modality](
dataset[0],
padding="max_length",
truncation=True,
return_tensors="pt",
max_length=context_length,
)

self.tokenized_string = STRING_TOKENIZERS[modality](
dataset[1],
padding="max_length",
@@ -51,13 +62,13 @@ def __init__(
        )

    def __len__(self):
        return len(self.tokenized_smiles.input_ids)
        return len(self.tokenized_central_modality.input_ids)

    def __getitem__(self, idx):
        return {
            "smiles": (
                self.tokenized_smiles.input_ids[idx],
                self.tokenized_smiles.attention_mask[idx],
            self.central_modality: (
                self.tokenized_central_modality.input_ids[idx],
                self.tokenized_central_modality.attention_mask[idx],
            ),
            self.modality: (
                self.tokenized_string.input_ids[idx],
@@ -75,25 +86,31 @@ def load_combined_loader(
    batch_size: int,
    shuffle: bool,
    num_workers: int,
    central_modality: str = "smiles",
    drop_last: bool = True,
) -> CombinedLoader:
    """Combine multiple dataloaders for different modalities into a single dataloader.
    Args:
        data_modalities (dict): data inputs for each modality as pairs of (SMILES, modality)
        data_modalities (dict): data inputs for each modality as pairs of (central_modality, modality)
        batch_size (int): batch size for the dataloader
        shuffle (bool): shuffle the dataset
        num_workers (int): number of workers for the dataloader
        drop_last (bool, optional): whether to drop the last batch; defaults to True.
        central_modality (str, optional): central modality to use for the dataset; defaults to "smiles".
    Returns:
        CombinedLoader: a combined dataloader for all the modalities
    """
    loaders = {}

    for modality in [*data_modalities]:
        if MODALITY_DATA_TYPES[modality] == str:
            dataset_instance = StringDataset(data_modalities[modality], modality)
            dataset_instance = StringDataset(
                dataset=data_modalities[modality],
                modality=modality,
                central_modality=central_modality,
                context_length=256,
            )
            loaders[modality] = DataLoader(
                dataset_instance,
                batch_size=batch_size,
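Based on the signature above and the way experiments/train.py builds its inputs, a hedged usage sketch of load_combined_loader with a toy SMILES–SELFIES pair (the strings are illustrative and assume the registered tokenizers accept them):

from molbind.data.dataloaders import load_combined_loader

# toy data: index 0 is the central modality (SMILES), index 1 the paired modality
data_modalities = {
    "selfies": [
        ["CCO", "c1ccccc1"],  # SMILES
        ["[C][C][O]", "[C][=C][C][=C][C][=C][Ring1][=Branch1]"],  # SELFIES
    ],
}

combined_loader = load_combined_loader(
    data_modalities=data_modalities,
    batch_size=2,
    shuffle=True,
    num_workers=1,
    central_modality="smiles",
)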
19 changes: 0 additions & 19 deletions src/molbind/models/components/pooler.py

This file was deleted.
