diff --git a/.env.example b/.env.example deleted file mode 100644 index a790e320..00000000 --- a/.env.example +++ /dev/null @@ -1,6 +0,0 @@ -# example of file for storing private and user specific environment variables, like keys or system paths -# rename it to ".env" (excluded from version control by default) -# .env is loaded by train.py automatically -# hydra allows you to reference variables in .yaml configs with special syntax: ${oc.env:MY_VAR} - -MY_VAR="/home/user/my/system/path" diff --git a/configs/data/molbind.yaml b/configs/data/molbind.yaml index c507dc3a..4d42cbd8 100644 --- a/configs/data/molbind.yaml +++ b/configs/data/molbind.yaml @@ -1,10 +1,9 @@ -_target_: molbind.data.dataloaders.load_combined_loader central_modality: "smiles" modalities: - "selfies" train_frac : 0.8 -val_frac : 0.2 +valid_frac : 0.2 seed: 42 fraction_data: 1.0 -dataset_path: "subset.csv" +dataset_path: "${paths.data_dir}/subset.csv" batch_size: 64 \ No newline at end of file diff --git a/configs/logger/wandb.yaml b/configs/logger/wandb.yaml index 2af3344f..0640d5ab 100644 --- a/configs/logger/wandb.yaml +++ b/configs/logger/wandb.yaml @@ -1,7 +1,4 @@ -# https://wandb.ai - wandb: - _target_: lightning.pytorch.loggers.wandb.WandbLogger offline: False project: "molbind" - entity: "adrianmirza" + entity: "wandb_username" \ No newline at end of file diff --git a/configs/model/molbind.yaml b/configs/model/molbind.yaml new file mode 100644 index 00000000..38ea079c --- /dev/null +++ b/configs/model/molbind.yaml @@ -0,0 +1,29 @@ +# model architecture +encoders: + smiles: + pretrained: True + freeze_encoder: False + + selfies: + pretrained: True + freeze_encoder: False + +projection_heads: + smiles: + dims: [768, 256, 128] + activation: LeakyReLU + batch_norm: False + selfies: + dims: [768, 256, 128] + activation: LeakyReLU + batch_norm: False + +optimizer: + lr: 0.0001 + weight_decay: 0.0001 + +loss: + temperature: 0.1 + +# compile model for faster training with pytorch 2.0 +compile: false diff --git a/configs/train.yaml b/configs/train.yaml index f71d7cc7..3fe2bf35 100644 --- a/configs/train.yaml +++ b/configs/train.yaml @@ -7,10 +7,9 @@ defaults: - data: molbind - model: molbind - logger: wandb + - paths: default # - callbacks: default - # - logger: null # set logger here or use command line (e.g. 
`python train.py logger=tensorboard`) # - trainer: default - # - paths: default # - extras: default # - hydra: default diff --git a/experiments/subset.csv b/data/subset.csv similarity index 100% rename from experiments/subset.csv rename to data/subset.csv diff --git a/experiments/train.py b/experiments/train.py index 2f524da7..3c86cb17 100644 --- a/experiments/train.py +++ b/experiments/train.py @@ -1,33 +1,95 @@ -from molbind.models.lightning_module import train_molbind +import hydra +import pytorch_lightning as L +import polars as pl +from molbind.data.dataloaders import load_combined_loader +from molbind.models.lightning_module import MolBindModule from omegaconf import DictConfig +import torch +import rootutils +rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) -if __name__ == "__main__": - config = { - "wandb": {"entity": "adrianmirza", "project_name": "embedbind"}, - "model": { - "projection_heads": { - "selfies": {"dims": [256, 128], "activation": "leakyrelu", "batch_norm": False}, - "smiles": {"dims": [256, 128], "activation": "leakyrelu", "batch_norm": False}, - }, - "encoders": { - "smiles": {"pretrained": True, "freeze_encoder": False}, - "selfies": {"pretrained": True, "freeze_encoder": False}, - }, - "optimizer": {"lr": 1e-4, "weight_decay": 1e-4}, - }, - "loss": {"temperature": 0.1}, - "data": { - "central_modality": "smiles", - "modalities": ["selfies"], - "dataset_path": "subset.csv", - "train_frac": 0.8, - "valid_frac": 0.2, - "seed": 42, - "fraction_data": 1, - "batch_size": 64, - }, - } - - config = DictConfig(config) + +def train_molbind(config: DictConfig): + wandb_logger = L.loggers.WandbLogger(**config.logger.wandb) + + device_count = torch.cuda.device_count() + trainer = L.Trainer( + max_epochs=100, + accelerator="cuda", + log_every_n_steps=10, + logger=wandb_logger, + devices=device_count if device_count > 1 else "auto", + strategy="ddp" if device_count > 1 else "auto", + ) + + train_modality_data = {} + valid_modality_data = {} + + # Example SMILES - SELFIES modality pair: + data = pl.read_csv(config.data.dataset_path) + shuffled_data = data.sample( + fraction=config.data.fraction_data, shuffle=True, seed=config.data.seed + ) + dataset_length = len(shuffled_data) + valid_shuffled_data = shuffled_data.tail( + int(config.data.valid_frac * dataset_length) + ) + train_shuffled_data = shuffled_data.head( + int(config.data.train_frac * dataset_length) + ) + + columns = shuffled_data.columns + # extract non-central modalities + non_central_modalities = config.data.modalities + central_modality = config.data.central_modality + + for column in columns: + if column in non_central_modalities: + # drop nan for specific pair + train_modality_pair = train_shuffled_data[ + [central_modality, column] + ].drop_nulls() + valid_modality_pair = valid_shuffled_data[ + [central_modality, column] + ].drop_nulls() + + train_modality_data[column] = [ + train_modality_pair[central_modality].to_list(), + train_modality_pair[column].to_list(), + ] + valid_modality_data[column] = [ + valid_modality_pair[central_modality].to_list(), + valid_modality_pair[column].to_list(), + ] + + combined_loader = load_combined_loader( + central_modality=config.data.central_modality, + data_modalities=train_modality_data, + batch_size=config.data.batch_size, + shuffle=True, + num_workers=1, + ) + + valid_dataloader = load_combined_loader( + central_modality=config.data.central_modality, + data_modalities=valid_modality_data, + batch_size=config.data.batch_size, + shuffle=False, + 
num_workers=1, + ) + + trainer.fit( + MolBindModule(config), + train_dataloaders=combined_loader, + val_dataloaders=valid_dataloader, + ) + + +@hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml") +def main(config: DictConfig): train_molbind(config) + + +if __name__ == "__main__": + main() diff --git a/src/molbind/data/dataloaders.py b/src/molbind/data/dataloaders.py index de2bbd13..5d9d7875 100644 --- a/src/molbind/data/dataloaders.py +++ b/src/molbind/data/dataloaders.py @@ -2,7 +2,7 @@ from lightning.pytorch.utilities.combined_loader import CombinedLoader from molbind.data.components.tokenizers import SMILES_TOKENIZER, SELFIES_TOKENIZER from networkx import Graph -from typing import Tuple +from typing import Tuple, Optional from torch import Tensor @@ -23,9 +23,13 @@ class StringDataset(Dataset): def __init__( - self, dataset: Tuple[Tensor, Tensor], modality: str, context_length=256 + self, + dataset: Tuple[Tensor, Tensor], + modality: str, + central_modality: str = "smiles", + context_length: Optional[int] = 256, ): - """_summary_ + """Dataset for string modalities. Args: dataset (Tuple[Tensor, Tensor]): pair of SMILES and data for the modality (smiles always index 0, modality index 1) @@ -34,14 +38,21 @@ def __init__( """ assert len(dataset) == 2 assert len(dataset[0]) == len(dataset[1]) + self.modality = modality - self.tokenized_smiles = STRING_TOKENIZERS["smiles"]( + self.central_modality = central_modality + + assert MODALITY_DATA_TYPES[modality] == str + assert MODALITY_DATA_TYPES[central_modality] == str + + self.tokenized_central_modality = STRING_TOKENIZERS[central_modality]( dataset[0], padding="max_length", truncation=True, return_tensors="pt", max_length=context_length, ) + self.tokenized_string = STRING_TOKENIZERS[modality]( dataset[1], padding="max_length", @@ -51,13 +62,13 @@ def __init__( ) def __len__(self): - return len(self.tokenized_smiles.input_ids) + return len(self.tokenized_central_modality.input_ids) def __getitem__(self, idx): return { - "smiles": ( - self.tokenized_smiles.input_ids[idx], - self.tokenized_smiles.attention_mask[idx], + self.central_modality: ( + self.tokenized_central_modality.input_ids[idx], + self.tokenized_central_modality.attention_mask[idx], ), self.modality: ( self.tokenized_string.input_ids[idx], @@ -75,17 +86,18 @@ def load_combined_loader( batch_size: int, shuffle: bool, num_workers: int, + central_modality: str = "smiles", drop_last: bool = True, ) -> CombinedLoader: """Combine multiple dataloaders for different modalities into a single dataloader. Args: - data_modalities (dict): data inputs for each modality as pairs of (SMILES, modality) + data_modalities (dict): data inputs for each modality as pairs of (central_modality, modality) batch_size (int): batch size for the dataloader shuffle (bool): shuffle the dataset num_workers (int): number of workers for the dataloader drop_last (bool, optional): whether to drop the last batch; defaults to True. - + central_modality (str, optional): central modality to use for the dataset; defaults to "smiles". 
Returns: CombinedLoader: a combined dataloader for all the modalities """ @@ -93,7 +105,12 @@ def load_combined_loader( for modality in [*data_modalities]: if MODALITY_DATA_TYPES[modality] == str: - dataset_instance = StringDataset(data_modalities[modality], modality) + dataset_instance = StringDataset( + dataset=data_modalities[modality], + modality=modality, + central_modality=central_modality, + context_length=256, + ) loaders[modality] = DataLoader( dataset_instance, batch_size=batch_size, diff --git a/src/molbind/models/components/pooler.py b/src/molbind/models/components/pooler.py deleted file mode 100644 index c4431360..00000000 --- a/src/molbind/models/components/pooler.py +++ /dev/null @@ -1,19 +0,0 @@ -import torch.nn as nn -from typing import Hint - - -class EmbedPooler(nn.Module): - def __init__(self, pooling_function: Hint["max", "min", "last", "first"] = "max"): - super(EmbedPooler, self).__init__() - - def forward(self, x): - if self.pooling_function == "max": - return x.max(dim=1) - elif self.pooling_function == "min": - return x.min(dim=1) - elif self.pooling_function == "last": - return x[:,-1] - elif self.pooling_function == "first": - return x[:,0] - else: - raise ValueError("Pooling function not implemented.") \ No newline at end of file diff --git a/src/molbind/models/lightning_module.py b/src/molbind/models/lightning_module.py index b2d6a8c3..b48a0102 100644 --- a/src/molbind/models/lightning_module.py +++ b/src/molbind/models/lightning_module.py @@ -2,10 +2,8 @@ from pytorch_lightning import LightningModule from molbind.models.model import MolBind import torch -import pytorch_lightning as L -from molbind.data.dataloaders import load_combined_loader +from torch import Tensor from typing import Dict -import polars as pl class MolBindModule(LightningModule): @@ -13,15 +11,17 @@ def __init__(self, cfg): super().__init__() self.model = MolBind(cfg=cfg) self.config = cfg - self.loss = InfoNCE(temperature=cfg.loss.temperature, negative_mode="unpaired") + self.loss = InfoNCE( + temperature=cfg.model.loss.temperature, negative_mode="unpaired" + ) - def forward(self, input): - return self.model(input) + def forward(self, batch: Dict): + return self.model(batch) - def _info_nce_loss(self, z1, z2): + def _info_nce_loss(self, z1: Tensor, z2: Tensor): return self.loss(z1, z2) - def _multimodal_loss(self, embeddings_dict, prefix): + def _multimodal_loss(self, embeddings_dict: Dict, prefix: str): modality_pair = [*embeddings_dict.keys()] loss = self._info_nce_loss( embeddings_dict[modality_pair[0]], embeddings_dict[modality_pair[1]] @@ -29,13 +29,13 @@ def _multimodal_loss(self, embeddings_dict, prefix): self.log(f"{prefix}_loss", loss) return loss - def training_step(self, input): - embeddings_dict = self.forward(input) + def training_step(self, batch: Dict): + embeddings_dict = self.forward(batch) return self._multimodal_loss(embeddings_dict, "train") - def validation_step(self, input): - embeddings_dict = self.forward(input) - return self._multimodal_loss(embeddings_dict, "val") + def validation_step(self, batch: Dict): + embeddings_dict = self.forward(batch) + return self._multimodal_loss(embeddings_dict, "valid") def configure_optimizers(self): return torch.optim.AdamW( @@ -43,78 +43,3 @@ def configure_optimizers(self): lr=self.config.model.optimizer.lr, weight_decay=self.config.model.optimizer.weight_decay, ) - - -# def train_molbind(config: Dict = None): -# wandb_logger = L.loggers.WandbLogger( -# project=config.wandb.project_name, entity=config.wandb.entity -# ) - -# 
device_count = torch.cuda.device_count() -# trainer = L.Trainer( -# max_epochs=100, -# accelerator="cuda", -# log_every_n_steps=10, -# logger=wandb_logger, -# devices=device_count if device_count > 1 else "auto", -# strategy="ddp" if device_count > 1 else "auto", -# ) - -# train_modality_data = {} -# valid_modality_data = {} - -# data = pl.read_csv(config.data.dataset_path) -# shuffled_data = data.sample( -# fraction=config.data.fraction_data, shuffle=True, seed=config.data.seed -# ) -# dataset_length = len(shuffled_data) -# valid_shuffled_data = shuffled_data.tail( -# int(config.data.valid_frac * dataset_length) -# ) -# train_shuffled_data = shuffled_data.head( -# int(config.data.train_frac * dataset_length) -# ) - -# columns = shuffled_data.columns -# # extract non-central modalities (i.e. not the central modality smiles) -# non_central_modalities = config.data.modalities -# central_modality = config.data.central_modality - -# for column in columns: -# if column in non_central_modalities: -# # drop nan for specific pair -# train_modality_pair = train_shuffled_data[ -# [central_modality, column] -# ].drop_nulls() -# valid_modality_pair = valid_shuffled_data[ -# [central_modality, column] -# ].drop_nulls() - -# train_modality_data[column] = [ -# train_modality_pair[central_modality].to_list(), -# train_modality_pair[column].to_list(), -# ] -# valid_modality_data[column] = [ -# valid_modality_pair[central_modality].to_list(), -# valid_modality_pair[column].to_list(), -# ] - -# combined_loader = load_combined_loader( -# data_modalities=train_modality_data, -# batch_size=config.data.batch_size, -# shuffle=True, -# num_workers=1, -# ) - -# valid_dataloader = load_combined_loader( -# data_modalities=valid_modality_data, -# batch_size=config.data.batch_size, -# shuffle=False, -# num_workers=1, -# ) - -# trainer.fit( -# MolBindModule(config), -# train_dataloaders=combined_loader, -# val_dataloaders=valid_dataloader, -# ) diff --git a/src/molbind/models/model.py b/src/molbind/models/model.py index 29fcaa73..967e4bb6 100644 --- a/src/molbind/models/model.py +++ b/src/molbind/models/model.py @@ -27,9 +27,15 @@ def __init__(self, cfg): self.central_modality = central_modality # Instantiate all encoders and projection heads - self.dict_encoders = {central_modality: SmilesEncoder(**cfg.model.encoders[central_modality])} + self.dict_encoders = { + central_modality: AVAILABLE_ENCODERS[central_modality]( + **cfg.model.encoders[central_modality] + ) + } self.dict_projection_heads = { - central_modality: ProjectionHead(**cfg.model.projection_heads[central_modality]) + central_modality: ProjectionHead( + **cfg.model.projection_heads[central_modality] + ) } # Add other modalities to `dict_encoders` and `dict_projection_heads for modality in modalities: @@ -57,11 +63,13 @@ def forward( modality = [*input_data][1] # store embeddings as store_embeddings[modality] = (smiles_embedding, modality_embedding) # forward through respective encoder - smiles_embedding = self.dict_encoders[central_modality].forward(input_data[central_modality]) - modality_embedding = self.dict_encoders[modality].forward(input_data[modality]) - central_modality_embedding_projected = self.dict_projection_heads[central_modality]( - smiles_embedding + smiles_embedding = self.dict_encoders[central_modality].forward( + input_data[central_modality] ) + modality_embedding = self.dict_encoders[modality].forward(input_data[modality]) + central_modality_embedding_projected = self.dict_projection_heads[ + central_modality + ](smiles_embedding) 
modality_embedding_projected = self.dict_projection_heads[modality]( modality_embedding )
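Usage note for the new Hydra entry point (experiments/train.py): runs are normally launched from the command line, e.g. `python experiments/train.py logger.wandb.entity=my_user data.batch_size=32`, with overrides resolved against configs/train.yaml. The snippet below is a minimal sketch of composing the same configuration programmatically through Hydra's compose API; the override values and the `experiments.train` import path are illustrative assumptions, not part of this diff.

    # Sketch only: assumes the configs/ layout introduced above and Hydra >= 1.2.
    from hydra import compose, initialize

    from experiments.train import train_molbind  # hypothetical import path

    with initialize(version_base="1.3", config_path="configs"):
        cfg = compose(
            config_name="train.yaml",
            overrides=["data.batch_size=32", "logger.wandb.entity=wandb_username"],
        )
        train_molbind(cfg)  # same function that @hydra.main wraps in experiments/train.py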
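The new configs/model/molbind.yaml is consumed as ProjectionHead(**cfg.model.projection_heads[modality]) and AVAILABLE_ENCODERS[modality](**cfg.model.encoders[modality]) in src/molbind/models/model.py. The actual ProjectionHead class is not part of this diff; the sketch below only illustrates what the config keys (dims, activation, batch_norm) imply about its constructor and is not the molbind implementation.

    import torch.nn as nn

    class ProjectionHeadSketch(nn.Module):
        """Illustrative stand-in for molbind's ProjectionHead; shapes inferred from the config only."""

        def __init__(self, dims: list, activation: str = "LeakyReLU", batch_norm: bool = False):
            super().__init__()
            layers = []
            for in_dim, out_dim in zip(dims[:-1], dims[1:]):
                layers.append(nn.Linear(in_dim, out_dim))
                if batch_norm:
                    layers.append(nn.BatchNorm1d(out_dim))
                layers.append(getattr(nn, activation)())  # e.g. nn.LeakyReLU for "LeakyReLU"
            self.net = nn.Sequential(*layers)

        def forward(self, x):
            return self.net(x)

    # dims: [768, 256, 128] in the config would map 768-d encoder outputs down to a
    # 128-d joint embedding space shared by the central modality and "selfies".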