Merge pull request #30 from Acellera/pretrain
Improve pre-training defaults for GRU
albertbou92 authored May 27, 2024
2 parents a79dc1e + 8ae55ed commit a57bdab
Showing 7 changed files with 40 additions and 27 deletions.
19 changes: 14 additions & 5 deletions acegen/data/chem_utils.py
@@ -1,3 +1,5 @@
+import warnings
+
 import numpy as np
 
 from rdkit.Chem import AllChem as Chem, Draw
@@ -37,12 +39,19 @@ def fraction_valid(mol_list):
 def randomize_smiles(smiles, random_type="restricted"):
     """Randomize a SMILES string using restricted or unrestricted randomization."""
     mol = get_mol(smiles)
-    if random_type == "restricted":
-        return Chem.MolToSmiles(mol, doRandom=True, canonical=True)
-    elif random_type == "unrestricted":
-        return Chem.MolToSmiles(mol, doRandom=True, canonical=False)
+    if mol:
+        if random_type == "restricted":
+            new_atom_order = list(range(mol.GetNumAtoms()))
+            np.random.shuffle(new_atom_order)
+            random_mol = Chem.RenumberAtoms(mol, newOrder=new_atom_order)
+            return Chem.MolToSmiles(random_mol, canonical=False)
+        elif random_type == "unrestricted":
+            return Chem.MolToSmiles(mol, doRandom=True, canonical=False)
+        else:
+            raise ValueError(f"Invalid randomization type: {random_type}")
     else:
-        raise ValueError(f"Invalid randomization type: {random_type}")
+        warnings.warn(f"Could not randomize SMILES string: {smiles}")
+        return smiles
 
 
 def draw(mol_list, molsPerRow=5, subImgSize=(300, 300)):
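Note on the new restricted mode: instead of asking RDKit for a random canonical SMILES, it explicitly shuffles the atom order and serializes without canonicalization, and invalid inputs now warn and fall back to the original string instead of raising. A minimal usage sketch (the example molecule and outputs are illustrative):

    from acegen.data.chem_utils import randomize_smiles

    smiles = "c1ccccc1O"  # phenol
    randomize_smiles(smiles, random_type="restricted")    # e.g. "c1ccc(O)cc1"
    randomize_smiles(smiles, random_type="unrestricted")  # e.g. "Oc1ccccc1"
    randomize_smiles("not-a-smiles")                      # warns, returns "not-a-smiles"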
12 changes: 9 additions & 3 deletions acegen/data/smiles_dataset.py
@@ -1,3 +1,4 @@
import gzip
import logging
import os
from pathlib import Path
@@ -22,9 +23,14 @@
 def load_dataset(file_path):
     """Reads a list of SMILES from file_path."""
     smiles_list = []
-    with open(file_path, "r") as f:
-        for line in tqdm(f, desc="Load Samples"):
-            smiles_list.append(line.split()[0])
+    if any(["gz" in ext for ext in os.path.basename(file_path).split(".")[1:]]):
+        with gzip.open(file_path) as f:
+            for line in tqdm(f, desc="Load Samples"):
+                smiles_list.append(line.decode("utf-8").split()[0])
+    else:
+        with open(file_path, "r") as f:
+            for line in tqdm(f, desc="Load Samples"):
+                smiles_list.append(line.split()[0])
 
     return smiles_list

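The loader now transparently handles gzip-compressed SMILES files: any extension component after the first dot that contains "gz" routes through gzip.open, with bytes decoded as UTF-8. A usage sketch (file names are illustrative):

    from acegen.data.smiles_dataset import load_dataset

    smiles = load_dataset("chembl.smi")     # plain text, first token per line
    smiles = load_dataset("chembl.smi.gz")  # gzip-compressed, same result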
14 changes: 7 additions & 7 deletions scripts/pretrain/config.yaml
@@ -14,15 +14,15 @@ dataset_log_dir: /tmp/pretrain # if recomputing dataset, save it here
 # Model configuration
 model: gru # gru, lstm, or gpt2
 custom_model_factory: null # Path to a custom model factory (e.g. my_module.create_model)
-model_log_dir: /tmp/pretrain # save model here
+model_log_dir: test #/tmp/pretrain # save model here
 
 # Training configuration
-lr: 0.0001
+lr: 0.001
 lr_scheduler: StepLR
 lr_scheduler_kwargs:
-  step_size: 1
-  gamma: 1.0 # no decay
-epochs: 10
-batch_size: 8
-randomize_smiles: False
+  step_size: 500
+  gamma: 0.97 # 1.0 = no decay
+epochs: 50
+batch_size: 128
+randomize_smiles: True # Sample a random SMILES variant each epoch; for 10-fold augmentation over 5 effective epochs, run 10*5=50 epochs.
 num_test_smiles: 100
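With StepLR under these defaults, the learning rate is multiplied by gamma after every step_size calls to lr_scheduler.step(). A sketch of the arithmetic (not code from this repo):

    lr0, gamma, step_size = 0.001, 0.97, 500

    def lr_after(n):
        return lr0 * gamma ** (n // step_size)

    lr_after(500)   # 0.00097
    lr_after(5000)  # ~0.000737, i.e. 0.001 * 0.97**10

Combined with the script changes below, which move lr_scheduler.step() next to the optimizer update, this decays the rate every 500 batches rather than once per epoch.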
6 changes: 3 additions & 3 deletions scripts/pretrain/pretrain_distributed.py
@@ -226,6 +226,9 @@ def valid_smiles(smiles_list):
             actor_optimizer.step()
             actor_losses[step] = loss_actor.item()
 
+            # Decay learning rate
+            lr_scheduler.step()
+
         # Generate test smiles
         smiles = generate_complete_smiles(test_env, actor_inference, max_length=100)
         num_valid_smiles = valid_smiles(
@@ -242,9 +245,6 @@ def valid_smiles(smiles_list):
         )
         logger.log_scalar("lr", lr_scheduler.get_lr()[0], step=epoch)
 
-        # Decay learning rate
-        lr_scheduler.step()
-
         save_path = Path(cfg.model_log_dir) / f"pretrained_actor_epoch_{epoch}.pt"
         torch.save(actor_training.state_dict(), save_path)

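Moving lr_scheduler.step() from the end of the epoch to just after the optimizer update makes the scheduler tick once per batch, which is what the new step_size of 500 assumes. A minimal skeleton of the resulting loop (compute_loss is a hypothetical stand-in for the actual loss computation):

    for epoch in range(cfg.epochs):
        for step, batch in enumerate(dataloader):
            loss_actor = compute_loss(batch)  # hypothetical helper
            actor_optimizer.zero_grad()
            loss_actor.backward()
            actor_optimizer.step()
            lr_scheduler.step()  # decay once per batch
        # per-epoch: generate test SMILES, log metrics, save checkpoint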
6 changes: 3 additions & 3 deletions scripts/pretrain/pretrain_single_node.py
@@ -217,11 +217,11 @@ def main(cfg: "DictConfig"):
"mols", wandb.Image(image), step=total_smiles
)
logger.log_scalar(
"lr", lr_scheduler.get_lr()[0], step=total_smiles
"lr", lr_scheduler.get_last_lr()[0], step=total_smiles
)

# Decay learning rate
lr_scheduler.step()
# Decay learning rate
lr_scheduler.step()

save_path = Path(cfg.model_log_dir) / f"pretrained_actor_epoch_{epoch}.pt"
torch.save(actor_training.state_dict(), save_path)
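The switch from get_lr() to get_last_lr() matters in recent PyTorch: get_lr() is the hook schedulers use internally to compute the next rate and warns when called outside step(), while get_last_lr() is the public accessor for the most recently applied rates. A quick sketch:

    import torch

    params = [torch.nn.Parameter(torch.zeros(1))]
    opt = torch.optim.Adam(params, lr=0.001)
    sched = torch.optim.lr_scheduler.StepLR(opt, step_size=500, gamma=0.97)
    sched.get_last_lr()  # [0.001] -- safe to call for logging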
5 changes: 2 additions & 3 deletions scripts/sac/pretrain_sac.py
@@ -345,9 +345,8 @@ def create_env_fn():
"train/reward": episode_rewards.mean().item(),
"train/min_reward": episode_rewards.min().item(),
"train/max_reward": episode_rewards.max().item(),
"train/episode_length": episode_length.sum().item() / len(
episode_length
),
"train/episode_length": episode_length.sum().item()
/ len(episode_length),
}
)
if logger:
5 changes: 2 additions & 3 deletions scripts/sac/sac.py
@@ -354,9 +354,8 @@ def create_env_fn():
"train/reward": episode_rewards.mean().item(),
"train/min_reward": episode_rewards.min().item(),
"train/max_reward": episode_rewards.max().item(),
"train/episode_length": episode_length.sum().item() / len(
episode_length
),
"train/episode_length": episode_length.sum().item()
/ len(episode_length),
}
)
if logger:
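In both SAC scripts the change is purely Black-style reformatting of the same expression, which computes the mean episode length. If episode_length is a numeric tensor, an equivalent one-liner would be (a sketch, not what the scripts use):

    mean_episode_length = episode_length.float().mean().item()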
