Commit: final cleanups
AdrianM0 committed Apr 11, 2024
1 parent 1d1cc83 commit 546001f
Showing 7 changed files with 118 additions and 114 deletions.
14 changes: 9 additions & 5 deletions configs/data/molbind.yaml
@@ -1,6 +1,10 @@
_target_: molbind.data.dataloaders.load_combined_loader
data_dir: ${paths.data_dir}
batch_size: 128 # Needs to be divisible by the number of devices (e.g., if in a distributed setup)
train_val_test_split: [55_000, 5_000, 10_000]
num_workers: 0
pin_memory: False
central_modality: "smiles"
modalities:
- "selfies"
train_frac : 0.8
val_frac : 0.2
seed: 42
fraction_data: 1.0
dataset_path: "subset.csv"
batch_size: 64
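Since `_target_` points at `molbind.data.dataloaders.load_combined_loader`, Hydra can build the dataloader straight from this file. A minimal sketch of that instantiation, assuming the file path below and that the loader accepts exactly the keys listed above (neither is verified against this commit):

    # Sketch: build the dataloader described by configs/data/molbind.yaml.
    # Assumes load_combined_loader accepts the keys shown above.
    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    data_cfg = OmegaConf.load("configs/data/molbind.yaml")
    # instantiate() calls molbind.data.dataloaders.load_combined_loader(**data_cfg)
    combined_loader = instantiate(data_cfg)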
13 changes: 2 additions & 11 deletions configs/logger/wandb.yaml
@@ -2,15 +2,6 @@

wandb:
_target_: lightning.pytorch.loggers.wandb.WandbLogger
# name: "" # name of the run (normally generated by wandb)
save_dir: "${paths.output_dir}"
offline: False
id: null # pass correct id to resume experiment!
anonymous: null # enable anonymous logging
project: "lightning-hydra-template"
log_model: False # upload lightning ckpts
prefix: "" # a string to put at the beginning of metric keys
# entity: "" # set to name of your wandb team
group: ""
tags: []
job_type: ""
project: "molbind"
entity: "adrianmirza"
19 changes: 10 additions & 9 deletions configs/train.yaml
@@ -4,14 +4,15 @@
# order of defaults determines the order in which configs override each other
defaults:
- _self_
- data: mnist
- model: mnist
- callbacks: default
- logger: null # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
- trainer: default
- paths: default
- extras: default
- hydra: default
- data: molbind
- model: molbind
- logger: wandb
# - callbacks: default
# - logger: null # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
# - trainer: default
# - paths: default
# - extras: default
# - hydra: default

# experiment configs allow for version control of specific hyperparameters
# e.g. best hyperparameters for given model and datamodule
@@ -46,4 +47,4 @@ test: True
ckpt_path: null

# seed for random number generators in pytorch, numpy and python.random
seed: null
seed: 42
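With the defaults switched to `data: molbind`, `model: molbind`, and `logger: wandb`, the composed config can also be reproduced outside the command line. A small sketch using Hydra's compose API; the relative `config_path` is an assumption about where the caller lives:

    # Sketch: compose the same train.yaml configuration programmatically.
    from hydra import compose, initialize

    with initialize(config_path="configs", version_base=None):
        cfg = compose(config_name="train")
        print(cfg.seed)                   # 42, per the change above
        print(cfg.data.central_modality)  # "smiles", from data/molbind.yaml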
11 changes: 7 additions & 4 deletions src/molbind/models/components/head.py
@@ -16,10 +16,13 @@ def __init__(
):
super(ProjectionHead, self).__init__()
# build projection head
self.projection_head = self.build_projection_head(dims, activation, batch_norm)
self.projection_head = self._build_projection_head(dims, activation, batch_norm)

def build_projection_head(
self, dims, activation, batch_norm=False
def _build_projection_head(
self,
dims: List[int],
activation: Union[str, List[str]],
batch_norm: bool = False,
) -> nn.Sequential:
# Build projection head dynamically based on the length of dims
layers = []
@@ -35,7 +38,7 @@ def build_projection_head(

def _get_activation(self, activation: Union[str, List[str]]):
if isinstance(activation, str):
return ACTIVATION_RESOLVER.resolve(activation)
return ACTIVATION_RESOLVER.make(activation)
elif isinstance(activation, list):
# In case you want multiple activation functions in sequence
return nn.Sequential(*[self._get_activation(act) for act in activation])
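The hunk above only shows the new `_build_projection_head` signature; the loop that assembles the layers is collapsed. A rough sketch of a dims-driven head built from `Linear → (optional BatchNorm) → activation` blocks; the fixed `nn.ReLU` stands in for the resolver-based activation lookup (`ACTIVATION_RESOLVER.make`) used in the actual code:

    # Sketch of a dims-driven projection head; an assumption about the collapsed
    # body, not the verbatim code from this commit.
    from typing import List

    import torch
    from torch import nn


    def build_projection_head(dims: List[int], batch_norm: bool = False) -> nn.Sequential:
        layers = []
        for in_dim, out_dim in zip(dims[:-1], dims[1:]):
            layers.append(nn.Linear(in_dim, out_dim))
            if batch_norm:
                layers.append(nn.BatchNorm1d(out_dim))
            layers.append(nn.ReLU())  # the real code resolves the activation by name
        return nn.Sequential(*layers)


    head = build_projection_head([512, 256, 128], batch_norm=True)
    out = head(torch.randn(4, 512))  # -> shape (4, 128)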
146 changes: 73 additions & 73 deletions src/molbind/models/lightning_module.py
@@ -45,76 +45,76 @@ def configure_optimizers(self):
)


def train_molbind(config: Dict = None):
wandb_logger = L.loggers.WandbLogger(
project=config.wandb.project_name, entity=config.wandb.entity
)

device_count = torch.cuda.device_count()
trainer = L.Trainer(
max_epochs=100,
accelerator="cuda",
log_every_n_steps=10,
logger=wandb_logger,
devices=device_count if device_count > 1 else "auto",
strategy="ddp" if device_count > 1 else "auto",
)

train_modality_data = {}
valid_modality_data = {}

# Example SMILES - SELFIES modality pair:
data = pl.read_csv(config.data.dataset_path)
shuffled_data = data.sample(
fraction=config.data.fraction_data, shuffle=True, seed=config.data.seed
)
dataset_length = len(shuffled_data)
valid_shuffled_data = shuffled_data.tail(
int(config.data.valid_frac * dataset_length)
)
train_shuffled_data = shuffled_data.head(
int(config.data.train_frac * dataset_length)
)

columns = shuffled_data.columns
# extract non-central modalities (i.e. not the central modality smiles)
non_central_modalities = config.data.modalities

for column in columns:
if column in non_central_modalities:
# drop nan for specific pair
train_modality_smiles_pair = train_shuffled_data[
["smiles", column]
].drop_nulls()
valid_modality_smiles_pair = valid_shuffled_data[
["smiles", column]
].drop_nulls()

train_modality_data[column] = [
train_modality_smiles_pair["smiles"].to_list(),
train_modality_smiles_pair[column].to_list(),
]
valid_modality_data[column] = [
valid_modality_smiles_pair["smiles"].to_list(),
valid_modality_smiles_pair[column].to_list(),
]

combined_loader = load_combined_loader(
data_modalities=train_modality_data,
batch_size=config.data.batch_size,
shuffle=True,
num_workers=1,
)

valid_dataloader = load_combined_loader(
data_modalities=valid_modality_data,
batch_size=config.data.batch_size,
shuffle=False,
num_workers=1,
)

trainer.fit(
MolBindModule(config),
train_dataloaders=combined_loader,
val_dataloaders=valid_dataloader,
)
# def train_molbind(config: Dict = None):
# wandb_logger = L.loggers.WandbLogger(
# project=config.wandb.project_name, entity=config.wandb.entity
# )

# device_count = torch.cuda.device_count()
# trainer = L.Trainer(
# max_epochs=100,
# accelerator="cuda",
# log_every_n_steps=10,
# logger=wandb_logger,
# devices=device_count if device_count > 1 else "auto",
# strategy="ddp" if device_count > 1 else "auto",
# )

# train_modality_data = {}
# valid_modality_data = {}

# data = pl.read_csv(config.data.dataset_path)
# shuffled_data = data.sample(
# fraction=config.data.fraction_data, shuffle=True, seed=config.data.seed
# )
# dataset_length = len(shuffled_data)
# valid_shuffled_data = shuffled_data.tail(
# int(config.data.valid_frac * dataset_length)
# )
# train_shuffled_data = shuffled_data.head(
# int(config.data.train_frac * dataset_length)
# )

# columns = shuffled_data.columns
# # extract non-central modalities (i.e. not the central modality smiles)
# non_central_modalities = config.data.modalities
# central_modality = config.data.central_modality

# for column in columns:
# if column in non_central_modalities:
# # drop nan for specific pair
# train_modality_pair = train_shuffled_data[
# [central_modality, column]
# ].drop_nulls()
# valid_modality_pair = valid_shuffled_data[
# [central_modality, column]
# ].drop_nulls()

# train_modality_data[column] = [
# train_modality_pair[central_modality].to_list(),
# train_modality_pair[column].to_list(),
# ]
# valid_modality_data[column] = [
# valid_modality_pair[central_modality].to_list(),
# valid_modality_pair[column].to_list(),
# ]

# combined_loader = load_combined_loader(
# data_modalities=train_modality_data,
# batch_size=config.data.batch_size,
# shuffle=True,
# num_workers=1,
# )

# valid_dataloader = load_combined_loader(
# data_modalities=valid_modality_data,
# batch_size=config.data.batch_size,
# shuffle=False,
# num_workers=1,
# )

# trainer.fit(
# MolBindModule(config),
# train_dataloaders=combined_loader,
# val_dataloaders=valid_dataloader,
# )
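The commented-out `train_molbind` above still documents the intended data flow: shuffle the CSV, split it by fraction, and collect one (central modality, modality) pair of column lists per non-central modality. A condensed sketch of that split against the new config keys (`val_frac` as in `configs/data/molbind.yaml`; the commented code itself still reads `valid_frac`):

    # Sketch of the train/validation pairing from the commented-out train_molbind,
    # rewritten against the new config keys; not live code in this commit.
    import polars as pl


    def split_modality_pairs(cfg):
        data = pl.read_csv(cfg.data.dataset_path)
        shuffled = data.sample(fraction=cfg.data.fraction_data, shuffle=True, seed=cfg.data.seed)
        n = len(shuffled)
        train = shuffled.head(int(cfg.data.train_frac * n))
        valid = shuffled.tail(int(cfg.data.val_frac * n))

        central = cfg.data.central_modality  # e.g. "smiles"
        train_pairs, valid_pairs = {}, {}
        for column in cfg.data.modalities:   # e.g. ["selfies"]
            train_pair = train[[central, column]].drop_nulls()
            valid_pair = valid[[central, column]].drop_nulls()
            train_pairs[column] = [train_pair[central].to_list(), train_pair[column].to_list()]
            valid_pairs[column] = [valid_pair[central].to_list(), valid_pair[column].to_list()]
        return train_pairs, valid_pairs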
21 changes: 12 additions & 9 deletions src/molbind/models/model.py
@@ -23,13 +23,15 @@ def __init__(self, cfg):
super(MolBind, self).__init__()

modalities = cfg.data.modalities
# Instantiate all encoders in modalities
central_modality = cfg.data.central_modality
self.central_modality = central_modality

self.dict_encoders = {"smiles": SmilesEncoder(**cfg.model.encoders["smiles"])}
# Instantiate all encoders and projection heads
self.dict_encoders = {central_modality: SmilesEncoder(**cfg.model.encoders[central_modality])}
self.dict_projection_heads = {
"smiles": ProjectionHead(**cfg.model.projection_heads["smiles"])
central_modality: ProjectionHead(**cfg.model.projection_heads[central_modality])
}

# Add other modalities to `dict_encoders` and `dict_projection_heads`
for modality in modalities:
if modality not in [*AVAILABLE_ENCODERS]:
raise ValueError(f"Modality {modality} not supported yet.")
@@ -40,7 +42,7 @@ def __init__(self, cfg):
**cfg.model.projection_heads[modality]
)

# convert to nn.ModuleDict
# convert dicts to nn.ModuleDict
self.dict_encoders = nn.ModuleDict(self.dict_encoders)
self.dict_projection_heads = nn.ModuleDict(self.dict_projection_heads)

@@ -51,19 +53,20 @@ def forward(
# input_data = [data, batch_index, dataloader_index]
input_data, _, _ = input_data
# input_data is a dictionary with (smiles, modality) pairs (where the central modality is at index 0)
central_modality = [*input_data][0]
modality = [*input_data][1]
# store embeddings as store_embeddings[modality] = (smiles_embedding, modality_embedding)
# forward through respective encoder
smiles_embedding = self.dict_encoders["smiles"].forward(input_data["smiles"])
smiles_embedding = self.dict_encoders[central_modality].forward(input_data[central_modality])
modality_embedding = self.dict_encoders[modality].forward(input_data[modality])
smiles_embedding_projected = self.dict_projection_heads["smiles"](
central_modality_embedding_projected = self.dict_projection_heads[central_modality](
smiles_embedding
)
modality_embedding_projected = self.dict_projection_heads[modality](
modality_embedding
)
# projection head
store_embeddings["smiles"] = smiles_embedding_projected
# projection heads
store_embeddings[central_modality] = central_modality_embedding_projected
store_embeddings[modality] = modality_embedding_projected
return store_embeddings

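After this change the model no longer hard-codes "smiles": every encoder and projection-head lookup goes through the configured central modality, and the batch dict is expected to carry the central modality at index 0 and the paired modality at index 1. A compressed sketch of the resulting forward pass, written as a free function over a MolBind-like module (dict keys and interfaces as shown in the diff, everything else assumed):

    # Sketch of the generalized forward pass after this commit; key names and
    # encoder/projection-head interfaces follow the diff, the rest is assumed.
    def molbind_forward(model, input_data, store_embeddings=None):
        store_embeddings = store_embeddings or {}
        data, _, _ = input_data  # [data, batch_index, dataloader_index]
        central_modality, modality = list(data)[0], list(data)[1]

        central_emb = model.dict_encoders[central_modality](data[central_modality])
        other_emb = model.dict_encoders[modality](data[modality])

        store_embeddings[central_modality] = model.dict_projection_heads[central_modality](central_emb)
        store_embeddings[modality] = model.dict_projection_heads[modality](other_emb)
        return store_embeddings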
8 changes: 5 additions & 3 deletions src/molbind/train.py
@@ -13,7 +13,7 @@
# the setup_root above is equivalent to:
# - adding project root dir to PYTHONPATH
# (so you don't need to force user to install project as a package)
# (necessary before importing any local modules e.g. `from src import utils`)
# (necessary before importing any local modules e.g. `from molbind import utils`)
# - setting up PROJECT_ROOT environment variable
# (which is used as a base for paths in "configs/paths/default.yaml")
# (this way all filepaths are the same no matter where you run the code)
@@ -26,7 +26,7 @@
# more info: https://github.com/ashleve/rootutils
# ------------------------------------------------------------------------------------ #

from src.utils import (
from molbind.utils import (
RankedLogger,
extras,
get_metric_value,
@@ -67,7 +67,9 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
logger: List[Logger] = instantiate_loggers(cfg.get("logger"))

log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger)
trainer: Trainer = hydra.utils.instantiate(
cfg.trainer, callbacks=callbacks, logger=logger
)

object_dict = {
"cfg": cfg,
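With the import switched from `src.utils` to `molbind.utils`, the entry point keeps the usual lightning-hydra-template shape around the `train()` function shown above. A schematic sketch of that wiring; the `hydra.main` arguments and config path are assumptions, not taken from this commit:

    # Schematic entry point around the train() function in src/molbind/train.py.
    # Decorator arguments and the config path are assumed, not verified here.
    import hydra
    import rootutils
    from omegaconf import DictConfig

    rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)


    @hydra.main(version_base="1.3", config_path="../../configs", config_name="train.yaml")
    def main(cfg: DictConfig) -> None:
        train(cfg)  # instantiates the datamodule, model, loggers, and Trainer


    if __name__ == "__main__":
        main()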
