Skip to content

Commit

Permalink
Merge branch 'main' into llama2
Browse files Browse the repository at this point in the history
  • Loading branch information
albertbou92 committed Jun 13, 2024
2 parents b4ce11b + c67f09f commit 1f2bb8b
Show file tree
Hide file tree
Showing 33 changed files with 63 additions and 144 deletions.
31 changes: 24 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,22 @@ The full paper can be found [here](https://arxiv.org/abs/2405.04657).

## Features

- **Multiple Generative Modes:** ACEGEN facilitates the generation of chemical libraries with different modes: de novo generation, scaffold decoration, and fragment linking.
- **RL Algorithms:** ACEGEN offers task optimization with various reinforcement learning algorithms such as [Proximal Policy Optimization (PPO)][1], [Advantage Actor-Critic (A2C)][2], [Reinforce][3], [Reinvent][4], and [Augmented Hill-Climb (AHC)][5].
- **Other Algorithms:** ACEGEN also includes [Direct Preference Optimization (DPO)][8] and Hill Climbing.
- **Pre-trained Models:** ACEGEN contains pre-trained models including Gated Recurrent Unit (GRU), Long Short-Term Memory (LSTM), GPT-2, LLama2 and Mamba.
- **Scoring Functions :** ACEGEN relies on MolScore, a comprehensive scoring function suite for generative chemistry, to evaluate the quality of the generated molecules.
- **Customization Support:** ACEGEN provides tutorials for integrating custom models and custom scoring functions, ensuring flexibility for advanced users.
- __**Multiple Generative Modes:**__
ACEGEN facilitates the generation of chemical libraries with different modes: de novo generation, scaffold decoration, and fragment linking.

- __**RL Algorithms:**__
ACEGEN offers task optimization with various reinforcement learning algorithms such as [Proximal Policy Optimization (PPO)][1], [Advantage Actor-Critic (A2C)][2], [Reinforce][3], [Reinvent][4], and [Augmented Hill-Climb (AHC)][5].

- __**Other Algorithms:**__
ACEGEN also includes [Direct Preference Optimization (DPO)][8] and Hill Climbing.

- __**Pre-trained Models:**__ ACEGEN contains pre-trained models including Gated Recurrent Unit (GRU), Long Short-Term Memory (LSTM), GPT-2, LLama2 and Mamba.

- __**Scoring Functions :**__
ACEGEN defaults to MolScore, a comprehensive scoring function suite for generative chemistry, to evaluate the quality of the generated molecules. MolScore allows to train agents on single scoring functions, on entire benchmarks containing multiple scoring functions (e.g., MolOpt, GuacaMol), or using curriculum learning where the same agent is optimized on a sequence of different scoring functions.

- __**Customization Support:**__
ACEGEN provides tutorials for integrating custom models and custom scoring functions, ensuring flexibility for advanced users.

---

Expand Down Expand Up @@ -94,6 +104,9 @@ To run the training scripts for denovo generation, run the following commands:
python scripts/ppo/ppo.py --config-name config_denovo
python scripts/reinvent/reinvent.py --config-name config_denovo
python scripts/ahc/ahc.py --config-name config_denovo
python scripts/dpo/dpo.py --config-name config_denovo
python scripts/hill_climb/hill_climb.py --config-name config_denovo


To run the training scripts for scaffold decoration, run the following commands (requires installation of promptsmiles):

Expand All @@ -102,6 +115,8 @@ To run the training scripts for scaffold decoration, run the following commands
python scripts/ppo/ppo.py --config-name config_scaffold
python scripts/reinvent/reinvent.py --config-name config_scaffold
python scripts/ahc/ahc.py --config-name config_scaffold
python scripts/dpo/dpo.py --config-name config_scaffold
python scripts/hill_climb/hill_climb.py --config-name config_scaffold

To run the training scripts for fragment linking, run the following commands (requires installation of promptsmiles):

Expand All @@ -110,6 +125,8 @@ To run the training scripts for fragment linking, run the following commands (re
python scripts/ppo/ppo.py --config-name config_linking
python scripts/reinvent/reinvent.py --config-name config_linking
python scripts/ahc/ahc.py --config-name config_linking
python scripts/dpo/dpo.py --config-name config_linking
python scripts/hill_climb/hill_climb.py --config-name config_linking

### Advanced usage

Expand Down Expand Up @@ -161,7 +178,7 @@ We provide a variety of default priors that can be selected in the configuration
- to select set the field `model` to `gpt2` in any configuration file


- A Mamba model (requires installation of mamba-ssm library)
- A Mamba model (requires installation of `mamba-ssm` library)
- pre-training dataset: [ChEMBL](https://www.ebi.ac.uk/chembl/)
- number of parameters: 2,809,216
- to select set the field `model` to `mamba` in any configuration file
Expand Down
13 changes: 3 additions & 10 deletions scripts/pretrain/pretrain_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from acegen.vocabulary import SMILESVocabulary, tokenizer_options
from rdkit import Chem
from tensordict.utils import remove_duplicates
from tokenizer import Tokenizer
from torch.distributed import barrier, destroy_process_group, init_process_group
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
Expand All @@ -28,12 +27,6 @@
from tqdm import tqdm


logging.basicConfig(
level=logging.INFO,
filename="pretraining.log",
format="%(asctime)s - %(levelname)s - %(message)s",
)

# hydra outputs saved in /tmp
os.chdir("/tmp")

Expand Down Expand Up @@ -108,7 +101,7 @@ def main(cfg: "DictConfig"):
# Load vocabulary from a file
vocabulary = SMILESVocabulary()
vocabulary.load_state_dict(torch.load(save_path))
vocabulary.tokenizer = Tokenizer()
vocabulary.tokenizer = tokenizer_options[cfg.tokenizer]()

logging.info("\nPreparing dataset and dataloader...")
if master:
Expand Down Expand Up @@ -188,11 +181,11 @@ def valid_smiles(smiles_list):
)

logger = None
if cfg.logger_backend:
if cfg.logger_backend and master:
logging.info("\nCreating logger...")
logger = get_logger(
cfg.logger_backend,
logger_name=Path.cwd(),
logger_name=cfg.model_log_dir,
experiment_name=cfg.agent_name,
wandb_kwargs={
"config": dict(cfg),
Expand Down
8 changes: 1 addition & 7 deletions scripts/pretrain/pretrain_single_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,6 @@
except:
_has_wandb = False

logging.basicConfig(
level=logging.INFO,
filename="pretraining.log",
format="%(asctime)s - %(levelname)s - %(message)s",
)

# hydra outputs saved in /tmp
os.chdir("/tmp")

Expand Down Expand Up @@ -127,7 +121,7 @@ def main(cfg: "DictConfig"):
logging.info("\nCreating logger...")
logger = get_logger(
cfg.logger_backend,
logger_name=Path.cwd(),
logger_name=cfg.model_log_dir,
experiment_name=cfg.agent_name,
wandb_kwargs={
"config": dict(cfg),
Expand Down
35 changes: 0 additions & 35 deletions scripts/pretrain/tokenizer.py

This file was deleted.

5 changes: 3 additions & 2 deletions scripts/sac/pretrain_sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,8 +345,9 @@ def create_env_fn():
"train/reward": episode_rewards.mean().item(),
"train/min_reward": episode_rewards.min().item(),
"train/max_reward": episode_rewards.max().item(),
"train/episode_length": episode_length.sum().item()
/ len(episode_length),
"train/episode_length": episode_length.sum().item() / len(
episode_length
),
}
)
if logger:
Expand Down
5 changes: 3 additions & 2 deletions scripts/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,8 +345,9 @@ def create_env_fn():
"train/reward": episode_rewards.mean().item(),
"train/min_reward": episode_rewards.min().item(),
"train/max_reward": episode_rewards.max().item(),
"train/episode_length": episode_length.sum().item()
/ len(episode_length),
"train/episode_length": episode_length.sum().item() / len(
episode_length
),
}
)
if logger:
Expand Down
4 changes: 1 addition & 3 deletions tests/check_scripts/run_a2c_denovo.sh
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ python $PYTHONPATH/scripts/a2c/a2c.py \
experiment_name="$project_name" \
agent_name="$agent_name" \
seed=$N_RUN \
log_dir="$agent_name"_seed"$N_RUN" \
log_dir=/tmp/"$agent_name"_seed"$N_RUN" \
model=$ACEGEN_MODEL

# Capture the exit status of the Python command
Expand All @@ -36,5 +36,3 @@ if [ $exit_status -eq 0 ]; then
else
echo "${agent_name}_${SLURM_JOB_ID}=error" >> report.log
fi

mv "$agent_name"_seed"$N_RUN"* slurm_logs/
4 changes: 1 addition & 3 deletions tests/check_scripts/run_a2c_fragment.sh
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ python $PYTHONPATH/scripts/a2c/a2c.py --config-name config_fragment \
experiment_name="$project_name" \
agent_name="$agent_name" \
seed=$N_RUN \
log_dir="$agent_name"_seed"$N_RUN" \
log_dir=/tmp/"$agent_name"_seed"$N_RUN" \
model=$ACEGEN_MODEL

# Capture the exit status of the Python command
Expand All @@ -36,5 +36,3 @@ if [ $exit_status -eq 0 ]; then
else
echo "${agent_name}_${SLURM_JOB_ID}=error" >> report.log
fi

mv "$agent_name"_seed"$N_RUN"* slurm_logs/
4 changes: 1 addition & 3 deletions tests/check_scripts/run_a2c_scaffold.sh
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ python $PYTHONPATH/scripts/a2c/a2c.py --config-name config_scaffold \
experiment_name="$project_name" \
agent_name="$agent_name" \
seed=$N_RUN \
log_dir="$agent_name"_seed"$N_RUN" \
log_dir=/tmp/"$agent_name"_seed"$N_RUN" \
model=$ACEGEN_MODEL

# Capture the exit status of the Python command
Expand All @@ -36,5 +36,3 @@ if [ $exit_status -eq 0 ]; then
else
echo "${agent_name}_${SLURM_JOB_ID}=error" >> report.log
fi

mv "$agent_name"_seed"$N_RUN"* slurm_logs/
4 changes: 1 addition & 3 deletions tests/check_scripts/run_ahc_denovo.sh
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ python $PYTHONPATH/scripts/ahc/ahc.py \
experiment_name="$project_name" \
agent_name="$agent_name" \
seed=$N_RUN \
log_dir="$agent_name"_seed"$N_RUN" \
log_dir=/tmp/"$agent_name"_seed"$N_RUN" \
model=$ACEGEN_MODEL

# Capture the exit status of the Python command
Expand All @@ -36,5 +36,3 @@ if [ $exit_status -eq 0 ]; then
else
echo "${agent_name}_${SLURM_JOB_ID}=error" >> report.log
fi

mv "$agent_name"_seed"$N_RUN"* slurm_logs/
4 changes: 1 addition & 3 deletions tests/check_scripts/run_ahc_fragment.sh
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ python $PYTHONPATH/scripts/ahc/ahc.py --config-name config_fragment \
experiment_name="$project_name" \
agent_name="$agent_name" \
seed=$N_RUN \
log_dir="$agent_name"_seed"$N_RUN" \
log_dir=/tmp/"$agent_name"_seed"$N_RUN" \
model=$ACEGEN_MODEL

# Capture the exit status of the Python command
Expand All @@ -36,5 +36,3 @@ if [ $exit_status -eq 0 ]; then
else
echo "${agent_name}_${SLURM_JOB_ID}=error" >> report.log
fi

mv "$agent_name"_seed"$N_RUN"* slurm_logs/
4 changes: 1 addition & 3 deletions tests/check_scripts/run_ahc_scaffold.sh
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ python $PYTHONPATH/scripts/ahc/ahc.py --config-name config_scaffold \
experiment_name="$project_name" \
agent_name="$agent_name" \
seed=$N_RUN \
log_dir="$agent_name"_seed"$N_RUN" \
log_dir=/tmp/"$agent_name"_seed"$N_RUN" \
model=$ACEGEN_MODEL

# Capture the exit status of the Python command
Expand All @@ -36,5 +36,3 @@ if [ $exit_status -eq 0 ]; then
else
echo "${agent_name}_${SLURM_JOB_ID}=error" >> report.log
fi

mv "$agent_name"_seed"$N_RUN"* slurm_logs/
4 changes: 1 addition & 3 deletions tests/check_scripts/run_dpo_denovo.sh
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ python $PYTHONPATH/scripts/dpo/dpo.py \
experiment_name="$project_name" \
agent_name="$agent_name" \
seed=$N_RUN \
log_dir="$agent_name"_seed"$N_RUN" \
log_dir=/tmp/"$agent_name"_seed"$N_RUN" \
model=$ACEGEN_MODEL

# Capture the exit status of the Python command
Expand All @@ -36,5 +36,3 @@ if [ $exit_status -eq 0 ]; then
else
echo "${agent_name}_${SLURM_JOB_ID}=error" >> report.log
fi

mv "$agent_name"_seed"$N_RUN"* slurm_logs/
4 changes: 1 addition & 3 deletions tests/check_scripts/run_dpo_fragment.sh
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ python $PYTHONPATH/scripts/dpo/dpo.py --config-name config_fragment \
experiment_name="$project_name" \
agent_name="$agent_name" \
seed=$N_RUN \
log_dir="$agent_name"_seed"$N_RUN" \
log_dir=/tmp/"$agent_name"_seed"$N_RUN" \
model=$ACEGEN_MODEL

# Capture the exit status of the Python command
Expand All @@ -36,5 +36,3 @@ if [ $exit_status -eq 0 ]; then
else
echo "${agent_name}_${SLURM_JOB_ID}=error" >> report.log
fi

mv "$agent_name"_seed"$N_RUN"* slurm_logs/
4 changes: 1 addition & 3 deletions tests/check_scripts/run_dpo_scaffold.sh
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ python $PYTHONPATH/scripts/dpo/dpo.py --config-name config_scaffold \
experiment_name="$project_name" \
agent_name="$agent_name" \
seed=$N_RUN \
log_dir="$agent_name"_seed"$N_RUN" \
log_dir=/tmp/"$agent_name"_seed"$N_RUN" \
model=$ACEGEN_MODEL

# Capture the exit status of the Python command
Expand All @@ -36,5 +36,3 @@ if [ $exit_status -eq 0 ]; then
else
echo "${agent_name}_${SLURM_JOB_ID}=error" >> report.log
fi

mv "$agent_name"_seed"$N_RUN"* slurm_logs/
4 changes: 1 addition & 3 deletions tests/check_scripts/run_hill_climb_denovo.sh
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ python $PYTHONPATH/scripts/hill_climb/hill_climb.py \
experiment_name="$project_name" \
agent_name="$agent_name" \
seed=$N_RUN \
log_dir="$agent_name"_seed"$N_RUN" \
log_dir=/tmp/"$agent_name"_seed"$N_RUN" \
model=$ACEGEN_MODEL

# Capture the exit status of the Python command
Expand All @@ -36,5 +36,3 @@ if [ $exit_status -eq 0 ]; then
else
echo "${agent_name}_${SLURM_JOB_ID}=error" >> report.log
fi

mv "$agent_name"_seed"$N_RUN"* slurm_logs/
4 changes: 1 addition & 3 deletions tests/check_scripts/run_hill_climb_fragment.sh
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ python $PYTHONPATH/scripts/hill_climb/hill_climb.py --config-name config_fragmen
experiment_name="$project_name" \
agent_name="$agent_name" \
seed=$N_RUN \
log_dir="$agent_name"_seed"$N_RUN" \
log_dir=/tmp/"$agent_name"_seed"$N_RUN" \
model=$ACEGEN_MODEL

# Capture the exit status of the Python command
Expand All @@ -36,5 +36,3 @@ if [ $exit_status -eq 0 ]; then
else
echo "${agent_name}_${SLURM_JOB_ID}=error" >> report.log
fi

mv "$agent_name"_seed"$N_RUN"* slurm_logs/
4 changes: 1 addition & 3 deletions tests/check_scripts/run_hill_climb_scaffold.sh
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ python $PYTHONPATH/scripts/hill_climb/hill_climb.py --config-name config_scaffol
experiment_name="$project_name" \
agent_name="$agent_name" \
seed=$N_RUN \
log_dir="$agent_name"_seed"$N_RUN" \
log_dir=/tmp/"$agent_name"_seed"$N_RUN" \
model=$ACEGEN_MODEL

# Capture the exit status of the Python command
Expand All @@ -36,5 +36,3 @@ if [ $exit_status -eq 0 ]; then
else
echo "${agent_name}_${SLURM_JOB_ID}=error" >> report.log
fi

mv "$agent_name"_seed"$N_RUN"* slurm_logs/
4 changes: 1 addition & 3 deletions tests/check_scripts/run_ppo_denovo.sh
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ python $PYTHONPATH/scripts/ppo/ppo.py \
experiment_name="$project_name" \
agent_name="$agent_name" \
seed=$N_RUN \
log_dir="$agent_name"_seed"$N_RUN" \
log_dir=/tmp/"$agent_name"_seed"$N_RUN" \
model=$ACEGEN_MODEL

# Capture the exit status of the Python command
Expand All @@ -36,5 +36,3 @@ if [ $exit_status -eq 0 ]; then
else
echo "${agent_name}_${SLURM_JOB_ID}=error" >> report.log
fi

mv "$agent_name"_seed"$N_RUN"* slurm_logs/
Loading

0 comments on commit 1f2bb8b

Please sign in to comment.