From b6f423750290af6c73846561f5fc0a4485543486 Mon Sep 17 00:00:00 2001 From: MorganCThomas Date: Wed, 24 Jul 2024 12:35:27 +0200 Subject: [PATCH 1/4] add detail --- tutorials/hyperparameter_optimisation_with_wandb.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tutorials/hyperparameter_optimisation_with_wandb.md b/tutorials/hyperparameter_optimisation_with_wandb.md index 25b49b9..8f7960f 100644 --- a/tutorials/hyperparameter_optimisation_with_wandb.md +++ b/tutorials/hyperparameter_optimisation_with_wandb.md @@ -58,4 +58,5 @@ This command can be run several times to create as many agents as you wish (for - Wandb provides an awesome GUI online to measure and track the best configurations according to some metrics (have you logged them in the script??) - Any hyperparameters not specified in the sweep configuration will be taken from the default configuration path. - This can be changed by adding `--config-name=` to the command section of the sweep in between `${program}` and `${args_no_hyphens}`. - - Or by changing the default path specified at the top of main in the respective script. + - Or by changing the default path specified at the top of main in the respective script. +- Some values may not be interpreted correctly; in particular, `null` may be interpreted as a string, so it is recommended to use `False` or `0` instead. \ No newline at end of file From abe53d776c946fd1248b3fb2c5fefc4610bbe148 Mon Sep 17 00:00:00 2001 From: MorganCThomas Date: Wed, 24 Jul 2024 14:25:18 +0200 Subject: [PATCH 2/4] add agent count --- tutorials/hyperparameter_optimisation_with_wandb.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tutorials/hyperparameter_optimisation_with_wandb.md b/tutorials/hyperparameter_optimisation_with_wandb.md index 8f7960f..43f2be8 100644 --- a/tutorials/hyperparameter_optimisation_with_wandb.md +++ b/tutorials/hyperparameter_optimisation_with_wandb.md @@ -48,7 +48,7 @@ Now we have a sweep setup, we can start agents that sample a set of hyperparamet  To start an agent, simply run the following command, ``` -wandb agent sweep_id +wandb agent --count <max_runs> sweep_id ``` This command can be run several times to create as many agents as you wish (for example, on different GPUs), enabling parallel agents to test different hyperparameters. 
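For context on the two tutorial patches above, a minimal sweep configuration matching the workflow they describe might look as follows. This is a sketch only: the script path, metric name, and swept parameters are illustrative assumptions, not part of the patches.

```
# sweep.yaml -- illustrative wandb sweep configuration (paths and parameter names are assumptions)
program: scripts/reinvent/reinvent.py
method: bayes
metric:
  name: train/reward              # must be logged by the training script
  goal: maximize
parameters:
  lr:
    values: [0.0001, 0.0003, 0.001]
  sigma:
    values: [60, 128, 250]
command:
  - ${env}
  - ${interpreter}
  - ${program}
  - --config-name=config_denovo   # optional: overrides the default Hydra config, placed between ${program} and ${args_no_hyphens}
  - ${args_no_hyphens}
```

The sweep is registered with `wandb sweep sweep.yaml`, which prints a `sweep_id`; each agent then runs a bounded number of configurations, for example:

```
wandb sweep sweep.yaml
wandb agent --count 20 <entity>/<project>/<sweep_id>
```
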
From bc7aaa0edaee1cfdd6c01098ad6afdab5d9ea2e7 Mon Sep 17 00:00:00 2001 From: MorganCThomas Date: Tue, 13 Aug 2024 16:16:01 +0200 Subject: [PATCH 3/4] added augmented-memory --- acegen/data/__init__.py | 2 +- acegen/data/utils.py | 11 + scripts/augmented_memory/augmem.py | 363 ++++++++++++++++++ scripts/augmented_memory/config_denovo.yaml | 40 ++ scripts/augmented_memory/config_fragment.yaml | 44 +++ scripts/augmented_memory/config_scaffold.yaml | 44 +++ tests/check_scripts/run_augmem_denovo.sh | 38 ++ tests/check_scripts/run_augmem_fragment.sh | 38 ++ tests/check_scripts/run_augmem_scaffold.sh | 38 ++ 9 files changed, 617 insertions(+), 1 deletion(-) create mode 100644 scripts/augmented_memory/augmem.py create mode 100644 scripts/augmented_memory/config_denovo.yaml create mode 100644 scripts/augmented_memory/config_fragment.yaml create mode 100644 scripts/augmented_memory/config_scaffold.yaml create mode 100644 tests/check_scripts/run_augmem_denovo.sh create mode 100755 tests/check_scripts/run_augmem_fragment.sh create mode 100755 tests/check_scripts/run_augmem_scaffold.sh diff --git a/acegen/data/__init__.py b/acegen/data/__init__.py index 3609212..afafbfa 100644 --- a/acegen/data/__init__.py +++ b/acegen/data/__init__.py @@ -1,2 +1,2 @@ from acegen.data.smiles_dataset import load_dataset, MolBloomDataset, SMILESDataset -from acegen.data.utils import smiles_to_tensordict +from acegen.data.utils import smiles_to_tensordict, collate_smiles_to_tensordict diff --git a/acegen/data/utils.py b/acegen/data/utils.py index 17b702d..2414bbb 100644 --- a/acegen/data/utils.py +++ b/acegen/data/utils.py @@ -67,3 +67,14 @@ def smiles_to_tensordict( smiles_tensordict.set(("next", "is_init"), next_is_init) return smiles_tensordict + + +def collate_smiles_to_tensordict(arr, max_length: int, reward: torch.Tensor = None, device: str = "cpu"): + """Function to take a list of encoded sequences and turn them into a tensordict.""" + collated_arr = torch.ones(len(arr), max_length) * -1 + for i, seq in enumerate(arr): + collated_arr[i, : seq.size(0)] = seq + data = smiles_to_tensordict(collated_arr, reward=reward, replace_mask_value=0, device=device) + data.set("sequence", data.get("observation")) + data.set("sequence_mask", data.get("mask")) + return data \ No newline at end of file diff --git a/scripts/augmented_memory/augmem.py b/scripts/augmented_memory/augmem.py new file mode 100644 index 0000000..f1c0f05 --- /dev/null +++ b/scripts/augmented_memory/augmem.py @@ -0,0 +1,363 @@ +#! 
/usr/bin/python3 +import datetime +import os +import random +from copy import deepcopy +from pathlib import Path + +import hydra +import numpy as np + +import torch +import tqdm +import yaml + +from acegen.models import adapt_state_dict, models, register_model +from acegen.rl_env import generate_complete_smiles, TokenEnv +from acegen.data import collate_smiles_to_tensordict +from acegen.scoring_functions import ( + custom_scoring_functions, + register_custom_scoring_function, + Task, +) +from acegen.vocabulary import Vocabulary +from omegaconf import OmegaConf, open_dict +from tensordict.utils import isin + +from torchrl.data import ( + LazyTensorStorage, + PrioritizedSampler, + TensorDictMaxValueWriter, + TensorDictReplayBuffer, +) +from torchrl.envs import InitTracker, TransformedEnv +from torchrl.modules.utils import get_primers_from_module +from torchrl.record.loggers import get_logger + +try: + import molscore + from molscore import MolScoreBenchmark, MolScoreCurriculum + from molscore.manager import MolScore + from molscore.utils import augment_smiles + + _has_molscore = True +except ImportError as err: + _has_molscore = False + MOLSCORE_ERR = err + +# hydra outputs saved in /tmp +os.chdir("/tmp") + + +@hydra.main( + config_path=".", + config_name="config_denovo", + version_base="1.2", +) +def main(cfg: "DictConfig"): + + if isinstance(cfg.seed, int): + cfg.seed = [cfg.seed] + + for seed in cfg.seed: + + # Set seed + random.seed(int(seed)) + np.random.seed(int(seed)) + torch.manual_seed(int(seed)) + + # Define save_dir and save config + current_time = datetime.datetime.now() + timestamp_str = current_time.strftime("%Y_%m_%d_%H%M%S") + os.chdir(os.path.dirname(__file__)) + save_dir = ( + f"{cfg.log_dir}/{cfg.experiment_name}_{cfg.agent_name}_{timestamp_str}" + ) + with open_dict(cfg): + cfg.save_dir = save_dir + os.makedirs(save_dir, exist_ok=True) + with open(Path(save_dir) / "config.yaml", "w") as yaml_file: + cfg_dict = OmegaConf.to_container(cfg, resolve=True) + yaml.dump(cfg_dict, yaml_file, default_flow_style=False) + + # Define training task and run + if cfg.get("molscore_task", None): + + if not _has_molscore: + raise RuntimeError( + "MolScore library not found. Unable to create a scoring function. 
" + "To install MolScore, use: `pip install MolScore`" + ) from MOLSCORE_ERR + + if cfg.molscore_mode == "single": + task = MolScore( + model_name=cfg.agent_name, + task_config=cfg.molscore_task, + budget=cfg.total_smiles, + replay_size=cfg.replay_buffer_size, + replay_purge=True, + output_dir=os.path.abspath(save_dir), + add_run_dir=True, + **cfg.get("molscore_kwargs", {}), + ) + run_reinvent(cfg, task) + + if cfg.molscore_mode == "benchmark": + MSB = MolScoreBenchmark( + model_name=cfg.agent_name, + model_parameters=dict(cfg), + benchmark=cfg.molscore_task, + budget=cfg.total_smiles, + replay_size=cfg.replay_buffer_size, + replay_purge=True, + output_dir=os.path.abspath(save_dir), + add_benchmark_dir=False, + **cfg.get("molscore_kwargs", {}), + ) + for task in MSB: + run_reinvent(cfg, task) + + if cfg.molscore_mode == "curriculum": + task = MolScoreCurriculum( + model_name=cfg.agent_name, + model_parameters=dict(cfg), + benchmark=cfg.molscore_task, + budget=cfg.total_smiles, + replay_size=cfg.replay_buffer_size, + replay_purge=True, + output_dir=os.path.abspath(save_dir), + **cfg.get("molscore_kwargs", {}), + ) + run_reinvent(cfg, task) + + elif cfg.get("custom_task", None): + if cfg.custom_task not in custom_scoring_functions: + register_custom_scoring_function(cfg.custom_task, cfg.custom_task) + task = Task( + name=cfg.custom_task, + scoring_function=custom_scoring_functions[cfg.custom_task], + budget=cfg.total_smiles, + output_dir=save_dir, + ) + run_reinvent(cfg, task) + + else: + raise ValueError("No scoring function specified.") + + +def run_reinvent(cfg, task): + + # Get available device + device = ( + torch.device("cuda:0") if torch.cuda.device_count() > 0 else torch.device("cpu") + ) + + # If custom model, register it + if cfg.model not in models and cfg.get("custom_model_factory", None) is not None: + register_model(cfg.model, cfg.model_factory) + + # Check if model is available + if cfg.model not in models: + raise ValueError( + f"Model {cfg.model} not found. For custom models, define a model factory as explain in the tutorials." 
+ ) + + # Get model + (create_actor, _, _, voc_path, ckpt_path, tokenizer) = models[cfg.model] + + # Create vocabulary + #################################################################################################################### + + vocabulary = Vocabulary.load(voc_path, tokenizer=tokenizer) + + # Create models + #################################################################################################################### + + ckpt = torch.load(ckpt_path, map_location=device, weights_only=True) + actor_training, actor_inference = create_actor(vocabulary_size=len(vocabulary)) + actor_inference.load_state_dict( + adapt_state_dict(deepcopy(ckpt), actor_inference.state_dict()) + ) + actor_inference = actor_inference.to(device) + actor_training = actor_training.to(device) + + prior, _ = create_actor(vocabulary_size=len(vocabulary)) + prior.load_state_dict( + adapt_state_dict(deepcopy(ckpt), prior.state_dict()) + ) + prior = prior.to(device) + + # Create RL environment + #################################################################################################################### + + env_kwargs = { + "start_token": vocabulary.start_token_index, + "end_token": vocabulary.end_token_index, + "length_vocabulary": len(vocabulary), + "max_length": cfg.get("max_length", 100), + "batch_size": cfg.num_envs, + "device": device, + } + + def create_env_fn(): + """Create a single RL rl_env.""" + env = TokenEnv(**env_kwargs) + env = TransformedEnv(env) + env.append_transform(InitTracker()) + if primers := get_primers_from_module(actor_inference): + env.append_transform(primers) + return env + + env = create_env_fn() + + # Create optimizer + #################################################################################################################### + + optim = torch.optim.Adam( + actor_training.parameters(), + lr=cfg.lr, + eps=cfg.eps, + weight_decay=cfg.weight_decay, + ) + + # Create logger + #################################################################################################################### + + logger = None + if cfg.logger_backend: + experiment_name = f"{cfg.agent_name}" + try: + experiment_name += f"_{task.cfg.get('task')}" + except AttributeError: + experiment_name += task.name + logger = get_logger( + cfg.logger_backend, + logger_name=cfg.save_dir, + experiment_name=experiment_name, + wandb_kwargs={ + "config": dict(cfg), + "project": cfg.experiment_name, + "group": cfg.agent_name, + "reinit": True, + }, + ) + + # Training loop + #################################################################################################################### + + total_done = 0 + sigma = cfg.sigma + pbar = tqdm.tqdm(total=cfg.total_smiles) + + while not task.finished: + + # Generate data + data = generate_complete_smiles( + policy_sample=actor_inference, + policy_evaluate=actor_training, + vocabulary=vocabulary, + scoring_function=task, + environment=env, + prompt=cfg.get("prompt", None), + promptsmiles=cfg.get("promptsmiles", None), + promptsmiles_optimize=cfg.get("promptsmiles_optimize", True), + promptsmiles_shuffle=cfg.get("promptsmiles_shuffle", True), + promptsmiles_multi=cfg.get("promptsmiles_multi", False), + promptsmiles_scan=cfg.get("promptsmiles_scan", False), + remove_duplicates=True, + ) + + log_info = {} + data_next = data.get("next") + done = data_next.get("done").squeeze(-1) + total_done += done.sum().item() + pbar.update(done.sum().item()) + + # Save info about smiles lengths and rewards + episode_rewards = data_next["reward"][done] + episode_length = 
(data_next["observation"] != 0.0).float().sum(-1).mean() + if len(episode_rewards) > 0: + log_info.update( + { + "train/total_smiles": total_done, + "train/reward": episode_rewards.mean().item(), + "train/min_reward": episode_rewards.min().item(), + "train/max_reward": episode_rewards.max().item(), + "train/episode_length": episode_length.item(), + } + ) + + data, loss, agent_likelihood = compute_loss(data, actor_training, prior, sigma) + + # Average loss over the batch + loss = loss.mean() + + # Add regularizer that penalizes high likelihood for the entire sequence + loss_p = -(1 / agent_likelihood).mean() + loss += 5 * 1e3 * loss_p + + # Calculate gradients and make an update to the network weights + optim.zero_grad() + loss.backward() + optim.step() + + for _ in range(cfg.augmentation_rounds): + # Augment sampled SMILES + sampled_smiles = augment_smiles(data.get("SMILES").cpu().data) + sampled_reward = data.get(("next", "reward")).squeeze(-1).sum(-1) + # Sample replay buffer + replay_smiles, replay_reward = task.replay(cfg.replay_batch_size, augment=True) + replay_reward = torch.tensor(replay_reward, device=device).float() + # Concatenate and create tensor + aug_tokens = [torch.tensor(vocabulary.encode(smi)) for smi in sampled_smiles + replay_smiles] + aug_reward = torch.cat([sampled_reward, replay_reward], dim=0) + aug_data = collate_smiles_to_tensordict(arr=aug_tokens, max_length=env.max_length, reward=aug_reward, device=device) + # Compute loss + aug_data, loss, agent_likelihood = compute_loss(aug_data, actor_training, prior, sigma) + # Average loss over the batch + loss = loss.mean() + # Add regularizer that penalizes high likelihood for the entire sequence + loss_p = -(1 / agent_likelihood).mean() + loss += 5 * 1e3 * loss_p + # Calculate gradients and make an update to the network weights + optim.zero_grad() + loss.backward() + optim.step() + + # Log info + if logger: + for key, value in log_info.items(): + logger.log_scalar(key, value, step=total_done) + + +def get_log_prob(data, model): + actions = data.get("action") + model_in = data.select(*model.in_keys, strict=False) + log_prob = model.get_dist(model_in).log_prob(actions) + return log_prob + + +def compute_loss(data, model, prior, sigma): + + mask = data.get("mask").squeeze(-1) + + if "prior_log_prob" not in data.keys(): + with torch.no_grad(): + prior_log_prob = get_log_prob(data, prior) + data.set("prior_log_prob", prior_log_prob) + else: + prior_log_prob = data.get("prior_log_prob") + + agent_log_prob = get_log_prob(data, model) + agent_likelihood = (agent_log_prob * mask).sum(-1) + prior_likelihood = (prior_log_prob * mask).sum(-1) + score = data.get(("next", "reward")).squeeze(-1).sum(-1) + + augmented_likelihood = prior_likelihood + sigma * score + loss = torch.pow((augmented_likelihood - agent_likelihood), 2) + + return data, loss, agent_likelihood + + +if __name__ == "__main__": + main() diff --git a/scripts/augmented_memory/config_denovo.yaml b/scripts/augmented_memory/config_denovo.yaml new file mode 100644 index 0000000..f1c1479 --- /dev/null +++ b/scripts/augmented_memory/config_denovo.yaml @@ -0,0 +1,40 @@ +# Logging configuration +experiment_name: acegen +agent_name: augmem +log_dir: results # Directory to save the results +logger_backend: null # wandb, tensorboard, or null +seed: 101 # multiple seeds can be provided as a list to multiple experiments sequentially e.g. 
[101, 102, 103] + +# Environment configuration +num_envs: 64 # Number of smiles to generate in parallel +total_smiles: 10_000 # Total number of smiles to generate + +# Scoring function +molscore_mode: single # single, benchmark, or curriculum +molscore_task: MolOpt:Albuterol_similarity # selects the Albuterol_similarity task from the MolOpt benchmark +# molscore_task accepts task configuration files (JSON), benchmark (preset only), or curriculum task (preset only) +# Refer to MolScore documentation for more information +custom_task: null # Requires molscore_task mode to be set to null +# Reefr to the tutorials for more information on how to use custom scoring functions / tasks + +# Promptsmiles configuration +prompt: null # e.g. c1ccccc # Fix the beginning of the generated molecules + +# Model architecture +model: gru # gru, lstm, gpt2, mamba or llama2 +# The default prior varies for each model. Refer to the README file in the root directory for more information. +# The default vocabulary varies for each prior. Refer to the README file in the root directory for more information. +custom_model_factory: null # Path to a custom model factory (e.g. my_module.create_model) + +# Optimizer configuration +lr: 0.0001 +eps: 1.0e-08 +weight_decay: 0.0 + +# Reinvent configuration +sigma: 128 + +# Data replay configuration +augmentation_rounds: 2 +replay_buffer_size: 100 +replay_batch_size: 100 diff --git a/scripts/augmented_memory/config_fragment.yaml b/scripts/augmented_memory/config_fragment.yaml new file mode 100644 index 0000000..267d1de --- /dev/null +++ b/scripts/augmented_memory/config_fragment.yaml @@ -0,0 +1,44 @@ +# Logging configuration +experiment_name: acegen +agent_name: augmem +log_dir: results # Directory to save the results +logger_backend: null # wandb, tensorboard, or null +seed: 101 # multiple seeds can be provided as a list to multiple experiments sequentially e.g. [101, 102, 103] + +# Environment configuration +num_envs: 64 # Number of smiles to generate in parallel +total_smiles: 10_000 # Total number of smiles to generate +max_length: 200 # Maximum length limit of the generated SMILES + +# Scoring function +molscore_mode: single # single, benchmark, or curriculum +molscore_task: MolOpt:Celecoxxib_rediscovery # selects the Celecoxxib_rediscovery task from the MolOpt benchmark +# molscore_task accepts task configuration files (JSON), benchmark (preset only), or curriculum task (preset only) +# Refer to MolScore documentation for more information +custom_task: null # Requires molscore_task mode to be set to null +# Reefr to the tutorials for more information on how to use custom scoring functions / tasks + +# Promptsmiles configuration +promptsmiles: c1(C)ccc(*)cc1.NS(=O)(=O)(*) +promptsmiles_optimize: True +promptsmiles_shuffle: True +promptsmiles_multi: False + +# Model architecture +model: gru # gru, lstm, gpt2, mamba or llama2 +# The default prior varies for each model. Refer to the README file in the root directory for more information. +# The default vocabulary varies for each prior. Refer to the README file in the root directory for more information. +custom_model_factory: null # Path to a custom model factory (e.g. 
my_module.create_model) + +# Optimizer configuration +lr: 0.0001 +eps: 1.0e-08 +weight_decay: 0.0 + +# Reinforce configuration +sigma: 128 + +# Data replay configuration +augmentation_rounds: 2 +replay_buffer_size: 100 +replay_batch_size: 100 diff --git a/scripts/augmented_memory/config_scaffold.yaml b/scripts/augmented_memory/config_scaffold.yaml new file mode 100644 index 0000000..825bd02 --- /dev/null +++ b/scripts/augmented_memory/config_scaffold.yaml @@ -0,0 +1,44 @@ +# Logging configuration +experiment_name: acegen +agent_name: augmem +log_dir: results # Directory to save the results +logger_backend: null # wandb, tensorboard, or null +seed: 101 # multiple seeds can be provided as a list to multiple experiments sequentially e.g. [101, 102, 103] + +# Environment configuration +num_envs: 64 # Number of smiles to generate in parallel +total_smiles: 10_000 # Total number of smiles to generate +max_length: 200 # Maximum length limit of the generated SMILES + +# Scoring function +molscore_mode: single # single, benchmark, or curriculum +molscore_task: LibINVENT_Exp1:DRD2_SelRF_SubFilt_DF # selects the DRD2_SelRF_SubFilt_DF task from the LibINVENT_Exp1 benchmark +# molscore_task accepts task configuration files (JSON), benchmark (preset only), or curriculum task (preset only) +# Refer to MolScore documentation for more information +custom_task: null # Requires molscore_task mode to be set to null +# Reefr to the tutorials for more information on how to use custom scoring functions / tasks + +# Promptsmiles configuration +promptsmiles: N1(*)CCN(CC1)CCCCN(*) +promptsmiles_optimize: True +promptsmiles_shuffle: True +promptsmiles_multi: False + +# Model architecture +model: gru # gru, lstm, gpt2, mamba or llama2 +# The default prior varies for each model. Refer to the README file in the root directory for more information. +# The default vocabulary varies for each prior. Refer to the README file in the root directory for more information. +custom_model_factory: null # Path to a custom model factory (e.g. my_module.create_model) + +# Optimizer configuration +lr: 0.0001 +eps: 1.0e-08 +weight_decay: 0.0 + +# Reinforce configuration +sigma: 128 + +# Data replay configuration +augmentation_rounds: 2 +replay_buffer_size: 100 +replay_batch_size: 100 diff --git a/tests/check_scripts/run_augmem_denovo.sh b/tests/check_scripts/run_augmem_denovo.sh new file mode 100644 index 0000000..7e37c72 --- /dev/null +++ b/tests/check_scripts/run_augmem_denovo.sh @@ -0,0 +1,38 @@ +#!/bin/bash + +#SBATCH --job-name=augmem_denovo +#SBATCH --ntasks=6 +#SBATCH --cpus-per-task=1 +#SBATCH --gres=gpu:1 +#SBATCH --output=slurm_logs/augmem_denovo%j.txt +#SBATCH --error=slurm_errors/augmem_denovo%j.txt + +current_commit=$(git rev-parse --short HEAD) +project_name="acegen-scripts-check-$current_commit" +agent_name="augmem_denovo" +if [ -z "$N_RUN" ]; then + echo "N_RUN is not set. Setting to default value of 1." + N_RUN=1 +fi +if [ -z "$ACEGEN_MODEL" ]; then + echo "ACEGEN_MODEL is not set. Setting to default value of gru." + ACEGEN_MODEL="gru" +fi + +export PYTHONPATH=$(dirname $(dirname $PWD)) +python $PYTHONPATH/scripts/augmented_memory/augmem.py \ + logger_backend=wandb \ + experiment_name="$project_name" \ + agent_name="$agent_name" \ + seed=$N_RUN \ + log_dir=/tmp/"$agent_name"_seed"$N_RUN" \ + model=$ACEGEN_MODEL + +# Capture the exit status of the Python command +exit_status=$? 
+# Write the exit status to a file +if [ $exit_status -eq 0 ]; then + echo "${agent_name}_${SLURM_JOB_ID}=success" >> report.log +else + echo "${agent_name}_${SLURM_JOB_ID}=error" >> report.log +fi diff --git a/tests/check_scripts/run_augmem_fragment.sh b/tests/check_scripts/run_augmem_fragment.sh new file mode 100755 index 0000000..5d7f33a --- /dev/null +++ b/tests/check_scripts/run_augmem_fragment.sh @@ -0,0 +1,38 @@ +#!/bin/bash + +#SBATCH --job-name=augmem_fragment +#SBATCH --ntasks=6 +#SBATCH --cpus-per-task=1 +#SBATCH --gres=gpu:1 +#SBATCH --output=slurm_logs/augmem_fragment%j.txt +#SBATCH --error=slurm_errors/augmem_fragment%j.txt + +current_commit=$(git rev-parse --short HEAD) +project_name="acegen-scripts-check-$current_commit" +agent_name="augmem_fragment" +if [ -z "$N_RUN" ]; then + echo "N_RUN is not set. Setting to default value of 1." + N_RUN=1 +fi +if [ -z "$ACEGEN_MODEL" ]; then + echo "ACEGEN_MODEL is not set. Setting to default value of gru." + ACEGEN_MODEL="gru" +fi + +export PYTHONPATH=$(dirname $(dirname $PWD)) +python $PYTHONPATH/scripts/augmented_memory/augmem.py --config-name config_fragment \ + logger_backend=wandb \ + experiment_name="$project_name" \ + agent_name="$agent_name" \ + seed=$N_RUN \ + log_dir=/tmp/"$agent_name"_seed"$N_RUN" \ + model=$ACEGEN_MODEL + +# Capture the exit status of the Python command +exit_status=$? +# Write the exit status to a file +if [ $exit_status -eq 0 ]; then + echo "${agent_name}_${SLURM_JOB_ID}=success" >> report.log +else + echo "${agent_name}_${SLURM_JOB_ID}=error" >> report.log +fi diff --git a/tests/check_scripts/run_augmem_scaffold.sh b/tests/check_scripts/run_augmem_scaffold.sh new file mode 100755 index 0000000..60ba049 --- /dev/null +++ b/tests/check_scripts/run_augmem_scaffold.sh @@ -0,0 +1,38 @@ +#!/bin/bash + +#SBATCH --job-name=augmem_scaffold +#SBATCH --ntasks=6 +#SBATCH --cpus-per-task=1 +#SBATCH --gres=gpu:1 +#SBATCH --output=slurm_logs/augmem_scaffold%j.txt +#SBATCH --error=slurm_errors/augmem_scaffold%j.txt + +current_commit=$(git rev-parse --short HEAD) +project_name="acegen-scripts-check-$current_commit" +agent_name="augmem_scaffold" +if [ -z "$N_RUN" ]; then + echo "N_RUN is not set. Setting to default value of 1." + N_RUN=1 +fi +if [ -z "$ACEGEN_MODEL" ]; then + echo "ACEGEN_MODEL is not set. Setting to default value of gru." + ACEGEN_MODEL="gru" +fi + +export PYTHONPATH=$(dirname $(dirname $PWD)) +python $PYTHONPATH/scripts/augmented_memory/augmem.py --config-name config_scaffold \ + logger_backend=wandb \ + experiment_name="$project_name" \ + agent_name="$agent_name" \ + seed=$N_RUN \ + log_dir=/tmp/"$agent_name"_seed"$N_RUN" \ + model=$ACEGEN_MODEL + +# Capture the exit status of the Python command +exit_status=$? 
+# Write the exit status to a file +if [ $exit_status -eq 0 ]; then + echo "${agent_name}_${SLURM_JOB_ID}=success" >> report.log +else + echo "${agent_name}_${SLURM_JOB_ID}=error" >> report.log +fi From 967cbfb46c75fbe700595a85efb2c8b51ccb0e35 Mon Sep 17 00:00:00 2001 From: MorganCThomas Date: Tue, 13 Aug 2024 16:20:12 +0200 Subject: [PATCH 4/4] pre-commit formatting --- acegen/data/__init__.py | 2 +- acegen/data/utils.py | 10 +++++++--- scripts/augmented_memory/augmem.py | 28 +++++++++++++++++++--------- 3 files changed, 27 insertions(+), 13 deletions(-) diff --git a/acegen/data/__init__.py b/acegen/data/__init__.py index afafbfa..f7a73d1 100644 --- a/acegen/data/__init__.py +++ b/acegen/data/__init__.py @@ -1,2 +1,2 @@ from acegen.data.smiles_dataset import load_dataset, MolBloomDataset, SMILESDataset -from acegen.data.utils import smiles_to_tensordict, collate_smiles_to_tensordict +from acegen.data.utils import collate_smiles_to_tensordict, smiles_to_tensordict diff --git a/acegen/data/utils.py b/acegen/data/utils.py index 2414bbb..c1b7149 100644 --- a/acegen/data/utils.py +++ b/acegen/data/utils.py @@ -69,12 +69,16 @@ def smiles_to_tensordict( return smiles_tensordict -def collate_smiles_to_tensordict(arr, max_length: int, reward: torch.Tensor = None, device: str = "cpu"): +def collate_smiles_to_tensordict( + arr, max_length: int, reward: torch.Tensor = None, device: str = "cpu" +): """Function to take a list of encoded sequences and turn them into a tensordict.""" collated_arr = torch.ones(len(arr), max_length) * -1 for i, seq in enumerate(arr): collated_arr[i, : seq.size(0)] = seq - data = smiles_to_tensordict(collated_arr, reward=reward, replace_mask_value=0, device=device) + data = smiles_to_tensordict( + collated_arr, reward=reward, replace_mask_value=0, device=device + ) data.set("sequence", data.get("observation")) data.set("sequence_mask", data.get("mask")) - return data \ No newline at end of file + return data diff --git a/scripts/augmented_memory/augmem.py b/scripts/augmented_memory/augmem.py index f1c0f05..bcdbe29 100644 --- a/scripts/augmented_memory/augmem.py +++ b/scripts/augmented_memory/augmem.py @@ -11,10 +11,10 @@ import torch import tqdm import yaml +from acegen.data import collate_smiles_to_tensordict from acegen.models import adapt_state_dict, models, register_model from acegen.rl_env import generate_complete_smiles, TokenEnv -from acegen.data import collate_smiles_to_tensordict from acegen.scoring_functions import ( custom_scoring_functions, register_custom_scoring_function, @@ -182,9 +182,7 @@ def run_reinvent(cfg, task): actor_training = actor_training.to(device) prior, _ = create_actor(vocabulary_size=len(vocabulary)) - prior.load_state_dict( - adapt_state_dict(deepcopy(ckpt), prior.state_dict()) - ) + prior.load_state_dict(adapt_state_dict(deepcopy(ckpt), prior.state_dict())) prior = prior.to(device) # Create RL environment @@ -288,7 +286,7 @@ def create_env_fn(): ) data, loss, agent_likelihood = compute_loss(data, actor_training, prior, sigma) - + # Average loss over the batch loss = loss.mean() @@ -306,14 +304,26 @@ def create_env_fn(): sampled_smiles = augment_smiles(data.get("SMILES").cpu().data) sampled_reward = data.get(("next", "reward")).squeeze(-1).sum(-1) # Sample replay buffer - replay_smiles, replay_reward = task.replay(cfg.replay_batch_size, augment=True) + replay_smiles, replay_reward = task.replay( + cfg.replay_batch_size, augment=True + ) replay_reward = torch.tensor(replay_reward, device=device).float() # Concatenate and create tensor - 
aug_tokens = [torch.tensor(vocabulary.encode(smi)) for smi in sampled_smiles + replay_smiles] + aug_tokens = [ + torch.tensor(vocabulary.encode(smi)) + for smi in sampled_smiles + replay_smiles + ] aug_reward = torch.cat([sampled_reward, replay_reward], dim=0) - aug_data = collate_smiles_to_tensordict(arr=aug_tokens, max_length=env.max_length, reward=aug_reward, device=device) + aug_data = collate_smiles_to_tensordict( + arr=aug_tokens, + max_length=env.max_length, + reward=aug_reward, + device=device, + ) # Compute loss - aug_data, loss, agent_likelihood = compute_loss(aug_data, actor_training, prior, sigma) + aug_data, loss, agent_likelihood = compute_loss( + aug_data, actor_training, prior, sigma + ) # Average loss over the batch loss = loss.mean() # Add regularizer that penalizes high likelihood for the entire sequence
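
For orientation, the update that `augmem.py` applies to the sampled batch and, in each augmentation round, to the augmented and replayed SMILES is the REINVENT-style squared difference between the reward-augmented prior likelihood and the agent likelihood (see `compute_loss` above). A minimal self-contained sketch, assuming per-token log-probabilities, a padding mask, and per-molecule rewards as plain tensors (tensor names are illustrative):

```
import torch


def augmented_memory_loss(agent_log_prob, prior_log_prob, mask, reward, sigma=128.0):
    """Sketch of compute_loss(): log-probs and mask are (batch, seq_len), reward is (batch,)."""
    # Sequence log-likelihoods: sum per-token log-probabilities over non-padding positions.
    agent_likelihood = (agent_log_prob * mask).sum(-1)
    prior_likelihood = (prior_log_prob * mask).sum(-1)
    # Reward-augmented target likelihood, scaled by sigma.
    augmented_likelihood = prior_likelihood + sigma * reward
    loss = (augmented_likelihood - agent_likelihood).pow(2).mean()
    # Regulariser used in the script: penalises excessively high agent likelihoods.
    loss = loss + 5e3 * (-(1.0 / agent_likelihood)).mean()
    return loss


# Toy usage with random tensors (shapes only, no real SMILES).
batch, seq_len = 4, 20
agent_lp = -torch.rand(batch, seq_len)  # stand-in for per-token log-probabilities
prior_lp = -torch.rand(batch, seq_len)
mask = torch.ones(batch, seq_len)
reward = torch.rand(batch)
print(augmented_memory_loss(agent_lp, prior_lp, mask, reward))
```

The squared difference keeps the agent close to the prior while shifting probability mass toward high-reward sequences; the augmentation rounds then recompute the same loss on randomised (augmented) SMILES of both the sampled batch and the replay buffer, which is the "augmented memory" part of the algorithm.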