diff --git a/acegen/data/__init__.py b/acegen/data/__init__.py index 3609212d..f7a73d18 100644 --- a/acegen/data/__init__.py +++ b/acegen/data/__init__.py @@ -1,2 +1,2 @@ from acegen.data.smiles_dataset import load_dataset, MolBloomDataset, SMILESDataset -from acegen.data.utils import smiles_to_tensordict +from acegen.data.utils import collate_smiles_to_tensordict, smiles_to_tensordict diff --git a/acegen/data/utils.py b/acegen/data/utils.py index 17b702dd..c1b71496 100644 --- a/acegen/data/utils.py +++ b/acegen/data/utils.py @@ -67,3 +67,18 @@ def smiles_to_tensordict( smiles_tensordict.set(("next", "is_init"), next_is_init) return smiles_tensordict + + +def collate_smiles_to_tensordict( + arr, max_length: int, reward: torch.Tensor = None, device: str = "cpu" +): + """Function to take a list of encoded sequences and turn them into a tensordict.""" + collated_arr = torch.ones(len(arr), max_length) * -1 + for i, seq in enumerate(arr): + collated_arr[i, : seq.size(0)] = seq + data = smiles_to_tensordict( + collated_arr, reward=reward, replace_mask_value=0, device=device + ) + data.set("sequence", data.get("observation")) + data.set("sequence_mask", data.get("mask")) + return data diff --git a/scripts/augmented_memory/augmem.py b/scripts/augmented_memory/augmem.py new file mode 100644 index 00000000..bcdbe292 --- /dev/null +++ b/scripts/augmented_memory/augmem.py @@ -0,0 +1,373 @@ +#! /usr/bin/python3 +import datetime +import os +import random +from copy import deepcopy +from pathlib import Path + +import hydra +import numpy as np + +import torch +import tqdm +import yaml +from acegen.data import collate_smiles_to_tensordict + +from acegen.models import adapt_state_dict, models, register_model +from acegen.rl_env import generate_complete_smiles, TokenEnv +from acegen.scoring_functions import ( + custom_scoring_functions, + register_custom_scoring_function, + Task, +) +from acegen.vocabulary import Vocabulary +from omegaconf import OmegaConf, open_dict +from tensordict.utils import isin + +from torchrl.data import ( + LazyTensorStorage, + PrioritizedSampler, + TensorDictMaxValueWriter, + TensorDictReplayBuffer, +) +from torchrl.envs import InitTracker, TransformedEnv +from torchrl.modules.utils import get_primers_from_module +from torchrl.record.loggers import get_logger + +try: + import molscore + from molscore import MolScoreBenchmark, MolScoreCurriculum + from molscore.manager import MolScore + from molscore.utils import augment_smiles + + _has_molscore = True +except ImportError as err: + _has_molscore = False + MOLSCORE_ERR = err + +# hydra outputs saved in /tmp +os.chdir("/tmp") + + +@hydra.main( + config_path=".", + config_name="config_denovo", + version_base="1.2", +) +def main(cfg: "DictConfig"): + + if isinstance(cfg.seed, int): + cfg.seed = [cfg.seed] + + for seed in cfg.seed: + + # Set seed + random.seed(int(seed)) + np.random.seed(int(seed)) + torch.manual_seed(int(seed)) + + # Define save_dir and save config + current_time = datetime.datetime.now() + timestamp_str = current_time.strftime("%Y_%m_%d_%H%M%S") + os.chdir(os.path.dirname(__file__)) + save_dir = ( + f"{cfg.log_dir}/{cfg.experiment_name}_{cfg.agent_name}_{timestamp_str}" + ) + with open_dict(cfg): + cfg.save_dir = save_dir + os.makedirs(save_dir, exist_ok=True) + with open(Path(save_dir) / "config.yaml", "w") as yaml_file: + cfg_dict = OmegaConf.to_container(cfg, resolve=True) + yaml.dump(cfg_dict, yaml_file, default_flow_style=False) + + # Define training task and run + if cfg.get("molscore_task", None): + + if not 
_has_molscore:
+                raise RuntimeError(
+                    "MolScore library not found. Unable to create a scoring function. "
+                    "To install MolScore, use: `pip install MolScore`"
+                ) from MOLSCORE_ERR
+
+            if cfg.molscore_mode == "single":
+                task = MolScore(
+                    model_name=cfg.agent_name,
+                    task_config=cfg.molscore_task,
+                    budget=cfg.total_smiles,
+                    replay_size=cfg.replay_buffer_size,
+                    replay_purge=True,
+                    output_dir=os.path.abspath(save_dir),
+                    add_run_dir=True,
+                    **cfg.get("molscore_kwargs", {}),
+                )
+                run_reinvent(cfg, task)
+
+            if cfg.molscore_mode == "benchmark":
+                MSB = MolScoreBenchmark(
+                    model_name=cfg.agent_name,
+                    model_parameters=dict(cfg),
+                    benchmark=cfg.molscore_task,
+                    budget=cfg.total_smiles,
+                    replay_size=cfg.replay_buffer_size,
+                    replay_purge=True,
+                    output_dir=os.path.abspath(save_dir),
+                    add_benchmark_dir=False,
+                    **cfg.get("molscore_kwargs", {}),
+                )
+                for task in MSB:
+                    run_reinvent(cfg, task)
+
+            if cfg.molscore_mode == "curriculum":
+                task = MolScoreCurriculum(
+                    model_name=cfg.agent_name,
+                    model_parameters=dict(cfg),
+                    benchmark=cfg.molscore_task,
+                    budget=cfg.total_smiles,
+                    replay_size=cfg.replay_buffer_size,
+                    replay_purge=True,
+                    output_dir=os.path.abspath(save_dir),
+                    **cfg.get("molscore_kwargs", {}),
+                )
+                run_reinvent(cfg, task)
+
+        elif cfg.get("custom_task", None):
+            if cfg.custom_task not in custom_scoring_functions:
+                register_custom_scoring_function(cfg.custom_task, cfg.custom_task)
+            task = Task(
+                name=cfg.custom_task,
+                scoring_function=custom_scoring_functions[cfg.custom_task],
+                budget=cfg.total_smiles,
+                output_dir=save_dir,
+            )
+            run_reinvent(cfg, task)
+
+        else:
+            raise ValueError("No scoring function specified.")
+
+
+def run_reinvent(cfg, task):
+
+    # Get available device
+    device = (
+        torch.device("cuda:0") if torch.cuda.device_count() > 0 else torch.device("cpu")
+    )
+
+    # If custom model, register it
+    if cfg.model not in models and cfg.get("custom_model_factory", None) is not None:
+        register_model(cfg.model, cfg.custom_model_factory)
+
+    # Check if model is available
+    if cfg.model not in models:
+        raise ValueError(
+            f"Model {cfg.model} not found. For custom models, define a model factory as explained in the tutorials."
+        )
+
+    # Get model
+    (create_actor, _, _, voc_path, ckpt_path, tokenizer) = models[cfg.model]
+
+    # Create vocabulary
+    ####################################################################################################################
+
+    vocabulary = Vocabulary.load(voc_path, tokenizer=tokenizer)
+
+    # Create models
+    ####################################################################################################################
+
+    ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
+    actor_training, actor_inference = create_actor(vocabulary_size=len(vocabulary))
+    actor_inference.load_state_dict(
+        adapt_state_dict(deepcopy(ckpt), actor_inference.state_dict())
+    )
+    actor_inference = actor_inference.to(device)
+    actor_training = actor_training.to(device)
+
+    prior, _ = create_actor(vocabulary_size=len(vocabulary))
+    prior.load_state_dict(adapt_state_dict(deepcopy(ckpt), prior.state_dict()))
+    prior = prior.to(device)
+
+    # Create RL environment
+    ####################################################################################################################
+
+    env_kwargs = {
+        "start_token": vocabulary.start_token_index,
+        "end_token": vocabulary.end_token_index,
+        "length_vocabulary": len(vocabulary),
+        "max_length": cfg.get("max_length", 100),
+        "batch_size": cfg.num_envs,
+        "device": device,
+    }
+
+    def create_env_fn():
+        """Create a single RL environment."""
+        env = TokenEnv(**env_kwargs)
+        env = TransformedEnv(env)
+        env.append_transform(InitTracker())
+        if primers := get_primers_from_module(actor_inference):
+            env.append_transform(primers)
+        return env
+
+    env = create_env_fn()
+
+    # Create optimizer
+    ####################################################################################################################
+
+    optim = torch.optim.Adam(
+        actor_training.parameters(),
+        lr=cfg.lr,
+        eps=cfg.eps,
+        weight_decay=cfg.weight_decay,
+    )
+
+    # Create logger
+    ####################################################################################################################
+
+    logger = None
+    if cfg.logger_backend:
+        experiment_name = f"{cfg.agent_name}"
+        try:
+            experiment_name += f"_{task.cfg.get('task')}"
+        except AttributeError:
+            experiment_name += f"_{task.name}"
+        logger = get_logger(
+            cfg.logger_backend,
+            logger_name=cfg.save_dir,
+            experiment_name=experiment_name,
+            wandb_kwargs={
+                "config": dict(cfg),
+                "project": cfg.experiment_name,
+                "group": cfg.agent_name,
+                "reinit": True,
+            },
+        )
+
+    # Training loop
+    ####################################################################################################################
+
+    total_done = 0
+    sigma = cfg.sigma
+    pbar = tqdm.tqdm(total=cfg.total_smiles)
+
+    while not task.finished:
+
+        # Generate data
+        data = generate_complete_smiles(
+            policy_sample=actor_inference,
+            policy_evaluate=actor_training,
+            vocabulary=vocabulary,
+            scoring_function=task,
+            environment=env,
+            prompt=cfg.get("prompt", None),
+            promptsmiles=cfg.get("promptsmiles", None),
+            promptsmiles_optimize=cfg.get("promptsmiles_optimize", True),
+            promptsmiles_shuffle=cfg.get("promptsmiles_shuffle", True),
+            promptsmiles_multi=cfg.get("promptsmiles_multi", False),
+            promptsmiles_scan=cfg.get("promptsmiles_scan", False),
+            remove_duplicates=True,
+        )
+
+        log_info = {}
+        data_next = data.get("next")
+        done = data_next.get("done").squeeze(-1)
+        total_done += done.sum().item()
+        pbar.update(done.sum().item())
+
+        # Save info about smiles lengths and rewards
+        episode_rewards = data_next["reward"][done]
+        episode_length = 
(data_next["observation"] != 0.0).float().sum(-1).mean() + if len(episode_rewards) > 0: + log_info.update( + { + "train/total_smiles": total_done, + "train/reward": episode_rewards.mean().item(), + "train/min_reward": episode_rewards.min().item(), + "train/max_reward": episode_rewards.max().item(), + "train/episode_length": episode_length.item(), + } + ) + + data, loss, agent_likelihood = compute_loss(data, actor_training, prior, sigma) + + # Average loss over the batch + loss = loss.mean() + + # Add regularizer that penalizes high likelihood for the entire sequence + loss_p = -(1 / agent_likelihood).mean() + loss += 5 * 1e3 * loss_p + + # Calculate gradients and make an update to the network weights + optim.zero_grad() + loss.backward() + optim.step() + + for _ in range(cfg.augmentation_rounds): + # Augment sampled SMILES + sampled_smiles = augment_smiles(data.get("SMILES").cpu().data) + sampled_reward = data.get(("next", "reward")).squeeze(-1).sum(-1) + # Sample replay buffer + replay_smiles, replay_reward = task.replay( + cfg.replay_batch_size, augment=True + ) + replay_reward = torch.tensor(replay_reward, device=device).float() + # Concatenate and create tensor + aug_tokens = [ + torch.tensor(vocabulary.encode(smi)) + for smi in sampled_smiles + replay_smiles + ] + aug_reward = torch.cat([sampled_reward, replay_reward], dim=0) + aug_data = collate_smiles_to_tensordict( + arr=aug_tokens, + max_length=env.max_length, + reward=aug_reward, + device=device, + ) + # Compute loss + aug_data, loss, agent_likelihood = compute_loss( + aug_data, actor_training, prior, sigma + ) + # Average loss over the batch + loss = loss.mean() + # Add regularizer that penalizes high likelihood for the entire sequence + loss_p = -(1 / agent_likelihood).mean() + loss += 5 * 1e3 * loss_p + # Calculate gradients and make an update to the network weights + optim.zero_grad() + loss.backward() + optim.step() + + # Log info + if logger: + for key, value in log_info.items(): + logger.log_scalar(key, value, step=total_done) + + +def get_log_prob(data, model): + actions = data.get("action") + model_in = data.select(*model.in_keys, strict=False) + log_prob = model.get_dist(model_in).log_prob(actions) + return log_prob + + +def compute_loss(data, model, prior, sigma): + + mask = data.get("mask").squeeze(-1) + + if "prior_log_prob" not in data.keys(): + with torch.no_grad(): + prior_log_prob = get_log_prob(data, prior) + data.set("prior_log_prob", prior_log_prob) + else: + prior_log_prob = data.get("prior_log_prob") + + agent_log_prob = get_log_prob(data, model) + agent_likelihood = (agent_log_prob * mask).sum(-1) + prior_likelihood = (prior_log_prob * mask).sum(-1) + score = data.get(("next", "reward")).squeeze(-1).sum(-1) + + augmented_likelihood = prior_likelihood + sigma * score + loss = torch.pow((augmented_likelihood - agent_likelihood), 2) + + return data, loss, agent_likelihood + + +if __name__ == "__main__": + main() diff --git a/scripts/augmented_memory/config_denovo.yaml b/scripts/augmented_memory/config_denovo.yaml new file mode 100644 index 00000000..f1c14798 --- /dev/null +++ b/scripts/augmented_memory/config_denovo.yaml @@ -0,0 +1,40 @@ +# Logging configuration +experiment_name: acegen +agent_name: augmem +log_dir: results # Directory to save the results +logger_backend: null # wandb, tensorboard, or null +seed: 101 # multiple seeds can be provided as a list to multiple experiments sequentially e.g. 
[101, 102, 103]
+
+# Environment configuration
+num_envs: 64 # Number of smiles to generate in parallel
+total_smiles: 10_000 # Total number of smiles to generate
+
+# Scoring function
+molscore_mode: single # single, benchmark, or curriculum
+molscore_task: MolOpt:Albuterol_similarity # selects the Albuterol_similarity task from the MolOpt benchmark
+# molscore_task accepts task configuration files (JSON), benchmark (preset only), or curriculum task (preset only)
+# Refer to MolScore documentation for more information
+custom_task: null # Requires molscore_task to be set to null
+# Refer to the tutorials for more information on how to use custom scoring functions / tasks
+
+# Promptsmiles configuration
+prompt: null # e.g. c1ccccc # Fix the beginning of the generated molecules
+
+# Model architecture
+model: gru # gru, lstm, gpt2, mamba or llama2
+# The default prior varies for each model. Refer to the README file in the root directory for more information.
+# The default vocabulary varies for each prior. Refer to the README file in the root directory for more information.
+custom_model_factory: null # Path to a custom model factory (e.g. my_module.create_model)
+
+# Optimizer configuration
+lr: 0.0001
+eps: 1.0e-08
+weight_decay: 0.0
+
+# Reinvent configuration
+sigma: 128
+
+# Data replay configuration
+augmentation_rounds: 2
+replay_buffer_size: 100
+replay_batch_size: 100
diff --git a/scripts/augmented_memory/config_fragment.yaml b/scripts/augmented_memory/config_fragment.yaml
new file mode 100644
index 00000000..267d1dee
--- /dev/null
+++ b/scripts/augmented_memory/config_fragment.yaml
@@ -0,0 +1,44 @@
+# Logging configuration
+experiment_name: acegen
+agent_name: augmem
+log_dir: results # Directory to save the results
+logger_backend: null # wandb, tensorboard, or null
+seed: 101 # multiple seeds can be provided as a list to multiple experiments sequentially e.g. [101, 102, 103]
+
+# Environment configuration
+num_envs: 64 # Number of smiles to generate in parallel
+total_smiles: 10_000 # Total number of smiles to generate
+max_length: 200 # Maximum length limit of the generated SMILES
+
+# Scoring function
+molscore_mode: single # single, benchmark, or curriculum
+molscore_task: MolOpt:Celecoxxib_rediscovery # selects the Celecoxxib_rediscovery task from the MolOpt benchmark
+# molscore_task accepts task configuration files (JSON), benchmark (preset only), or curriculum task (preset only)
+# Refer to MolScore documentation for more information
+custom_task: null # Requires molscore_task to be set to null
+# Refer to the tutorials for more information on how to use custom scoring functions / tasks
+
+# Promptsmiles configuration
+promptsmiles: c1(C)ccc(*)cc1.NS(=O)(=O)(*)
+promptsmiles_optimize: True
+promptsmiles_shuffle: True
+promptsmiles_multi: False
+
+# Model architecture
+model: gru # gru, lstm, gpt2, mamba or llama2
+# The default prior varies for each model. Refer to the README file in the root directory for more information.
+# The default vocabulary varies for each prior. Refer to the README file in the root directory for more information.
+custom_model_factory: null # Path to a custom model factory (e.g. 
my_module.create_model)
+
+# Optimizer configuration
+lr: 0.0001
+eps: 1.0e-08
+weight_decay: 0.0
+
+# Reinvent configuration
+sigma: 128
+
+# Data replay configuration
+augmentation_rounds: 2
+replay_buffer_size: 100
+replay_batch_size: 100
diff --git a/scripts/augmented_memory/config_scaffold.yaml b/scripts/augmented_memory/config_scaffold.yaml
new file mode 100644
index 00000000..825bd020
--- /dev/null
+++ b/scripts/augmented_memory/config_scaffold.yaml
@@ -0,0 +1,44 @@
+# Logging configuration
+experiment_name: acegen
+agent_name: augmem
+log_dir: results # Directory to save the results
+logger_backend: null # wandb, tensorboard, or null
+seed: 101 # multiple seeds can be provided as a list to multiple experiments sequentially e.g. [101, 102, 103]
+
+# Environment configuration
+num_envs: 64 # Number of smiles to generate in parallel
+total_smiles: 10_000 # Total number of smiles to generate
+max_length: 200 # Maximum length limit of the generated SMILES
+
+# Scoring function
+molscore_mode: single # single, benchmark, or curriculum
+molscore_task: LibINVENT_Exp1:DRD2_SelRF_SubFilt_DF # selects the DRD2_SelRF_SubFilt_DF task from the LibINVENT_Exp1 benchmark
+# molscore_task accepts task configuration files (JSON), benchmark (preset only), or curriculum task (preset only)
+# Refer to MolScore documentation for more information
+custom_task: null # Requires molscore_task to be set to null
+# Refer to the tutorials for more information on how to use custom scoring functions / tasks
+
+# Promptsmiles configuration
+promptsmiles: N1(*)CCN(CC1)CCCCN(*)
+promptsmiles_optimize: True
+promptsmiles_shuffle: True
+promptsmiles_multi: False
+
+# Model architecture
+model: gru # gru, lstm, gpt2, mamba or llama2
+# The default prior varies for each model. Refer to the README file in the root directory for more information.
+# The default vocabulary varies for each prior. Refer to the README file in the root directory for more information.
+custom_model_factory: null # Path to a custom model factory (e.g. my_module.create_model)
+
+# Optimizer configuration
+lr: 0.0001
+eps: 1.0e-08
+weight_decay: 0.0
+
+# Reinvent configuration
+sigma: 128
+
+# Data replay configuration
+augmentation_rounds: 2
+replay_buffer_size: 100
+replay_batch_size: 100
diff --git a/tests/check_scripts/run_augmem_denovo.sh b/tests/check_scripts/run_augmem_denovo.sh
new file mode 100644
index 00000000..7e37c72d
--- /dev/null
+++ b/tests/check_scripts/run_augmem_denovo.sh
@@ -0,0 +1,38 @@
+#!/bin/bash
+
+#SBATCH --job-name=augmem_denovo
+#SBATCH --ntasks=6
+#SBATCH --cpus-per-task=1
+#SBATCH --gres=gpu:1
+#SBATCH --output=slurm_logs/augmem_denovo%j.txt
+#SBATCH --error=slurm_errors/augmem_denovo%j.txt
+
+current_commit=$(git rev-parse --short HEAD)
+project_name="acegen-scripts-check-$current_commit"
+agent_name="augmem_denovo"
+if [ -z "$N_RUN" ]; then
+  echo "N_RUN is not set. Setting to default value of 1."
+  N_RUN=1
+fi
+if [ -z "$ACEGEN_MODEL" ]; then
+  echo "ACEGEN_MODEL is not set. Setting to default value of gru."
+  ACEGEN_MODEL="gru"
+fi
+
+export PYTHONPATH=$(dirname $(dirname $PWD))
+python $PYTHONPATH/scripts/augmented_memory/augmem.py \
+  logger_backend=wandb \
+  experiment_name="$project_name" \
+  agent_name="$agent_name" \
+  seed=$N_RUN \
+  log_dir=/tmp/"$agent_name"_seed"$N_RUN" \
+  model=$ACEGEN_MODEL
+
+# Capture the exit status of the Python command
+exit_status=$?
+# Write the exit status to a file +if [ $exit_status -eq 0 ]; then + echo "${agent_name}_${SLURM_JOB_ID}=success" >> report.log +else + echo "${agent_name}_${SLURM_JOB_ID}=error" >> report.log +fi diff --git a/tests/check_scripts/run_augmem_fragment.sh b/tests/check_scripts/run_augmem_fragment.sh new file mode 100755 index 00000000..5d7f33a8 --- /dev/null +++ b/tests/check_scripts/run_augmem_fragment.sh @@ -0,0 +1,38 @@ +#!/bin/bash + +#SBATCH --job-name=augmem_fragment +#SBATCH --ntasks=6 +#SBATCH --cpus-per-task=1 +#SBATCH --gres=gpu:1 +#SBATCH --output=slurm_logs/augmem_fragment%j.txt +#SBATCH --error=slurm_errors/augmem_fragment%j.txt + +current_commit=$(git rev-parse --short HEAD) +project_name="acegen-scripts-check-$current_commit" +agent_name="augmem_fragment" +if [ -z "$N_RUN" ]; then + echo "N_RUN is not set. Setting to default value of 1." + N_RUN=1 +fi +if [ -z "$ACEGEN_MODEL" ]; then + echo "ACEGEN_MODEL is not set. Setting to default value of gru." + ACEGEN_MODEL="gru" +fi + +export PYTHONPATH=$(dirname $(dirname $PWD)) +python $PYTHONPATH/scripts/augmented_memory/augmem.py --config-name config_fragment \ + logger_backend=wandb \ + experiment_name="$project_name" \ + agent_name="$agent_name" \ + seed=$N_RUN \ + log_dir=/tmp/"$agent_name"_seed"$N_RUN" \ + model=$ACEGEN_MODEL + +# Capture the exit status of the Python command +exit_status=$? +# Write the exit status to a file +if [ $exit_status -eq 0 ]; then + echo "${agent_name}_${SLURM_JOB_ID}=success" >> report.log +else + echo "${agent_name}_${SLURM_JOB_ID}=error" >> report.log +fi diff --git a/tests/check_scripts/run_augmem_scaffold.sh b/tests/check_scripts/run_augmem_scaffold.sh new file mode 100755 index 00000000..60ba049e --- /dev/null +++ b/tests/check_scripts/run_augmem_scaffold.sh @@ -0,0 +1,38 @@ +#!/bin/bash + +#SBATCH --job-name=augmem_scaffold +#SBATCH --ntasks=6 +#SBATCH --cpus-per-task=1 +#SBATCH --gres=gpu:1 +#SBATCH --output=slurm_logs/augmem_scaffold%j.txt +#SBATCH --error=slurm_errors/augmem_scaffold%j.txt + +current_commit=$(git rev-parse --short HEAD) +project_name="acegen-scripts-check-$current_commit" +agent_name="augmem_scaffold" +if [ -z "$N_RUN" ]; then + echo "N_RUN is not set. Setting to default value of 1." + N_RUN=1 +fi +if [ -z "$ACEGEN_MODEL" ]; then + echo "ACEGEN_MODEL is not set. Setting to default value of gru." + ACEGEN_MODEL="gru" +fi + +export PYTHONPATH=$(dirname $(dirname $PWD)) +python $PYTHONPATH/scripts/augmented_memory/augmem.py --config-name config_scaffold \ + logger_backend=wandb \ + experiment_name="$project_name" \ + agent_name="$agent_name" \ + seed=$N_RUN \ + log_dir=/tmp/"$agent_name"_seed"$N_RUN" \ + model=$ACEGEN_MODEL + +# Capture the exit status of the Python command +exit_status=$? 
+# Write the exit status to a file
+if [ $exit_status -eq 0 ]; then
+  echo "${agent_name}_${SLURM_JOB_ID}=success" >> report.log
+else
+  echo "${agent_name}_${SLURM_JOB_ID}=error" >> report.log
+fi
diff --git a/tutorials/hyperparameter_optimisation_with_wandb.md b/tutorials/hyperparameter_optimisation_with_wandb.md
index 3bfec376..43f2be84 100644
--- a/tutorials/hyperparameter_optimisation_with_wandb.md
+++ b/tutorials/hyperparameter_optimisation_with_wandb.md
@@ -48,7 +48,7 @@ Now we have a sweep setup, we can start agents that sample a set of hyperparamet
 To start an agent, simply run the following command,
 
 ```
-wandb agent sweep_id
+wandb agent --count <number_of_runs> sweep_id
 ```
 
 This command can be run several times to create as many agents as you wish (for example, on different GPUs), enabling parallel agents testing different hyperparameters.
@@ -59,3 +59,4 @@ This command can be run several times to create as many agents as you wish (for
 - Any hyperparameters not specified in the sweep configuration will be taken from the default configuration path.
 - This can be changed by adding `--config-name=` to the command section of the sweep inbetween `${program}` and `${args_no_hyphens}`.
 - Or by changing the default path specified at the top of main in the respective script.
+- Some values may not be interpreted correctly; in particular, `null` may be interpreted as a string, so it is recommended to use `False` or `0` instead.
\ No newline at end of file