diff --git a/scripts/ppo/ppo.py b/scripts/ppo/ppo.py
index 362ea9e..3340942 100644
--- a/scripts/ppo/ppo.py
+++ b/scripts/ppo/ppo.py
@@ -14,13 +14,13 @@
 import yaml
 from acegen.models import adapt_state_dict, models, register_model
-from acegen.rl_env import generate_complete_smiles, SMILESEnv
+from acegen.rl_env import generate_complete_smiles, TokenEnv
 from acegen.scoring_functions import (
     custom_scoring_functions,
     register_custom_scoring_function,
     Task,
 )
-from acegen.vocabulary import SMILESVocabulary
+from acegen.vocabulary import Vocabulary
 from omegaconf import OmegaConf, open_dict
 from tensordict import TensorDict
 from tensordict.utils import isin
@@ -171,7 +171,7 @@ def run_ppo(cfg, task):
     # Create vocabulary
     ####################################################################################################################
 
-    vocabulary = SMILESVocabulary.load(voc_path, tokenizer=tokenizer)
+    vocabulary = Vocabulary.load(voc_path, tokenizer=tokenizer)
 
     # Create models
     ####################################################################################################################
@@ -220,7 +220,7 @@ def run_ppo(cfg, task):
     # Define environment creation function
     def create_env_fn():
         """Create a single RL rl_env."""
-        env = SMILESEnv(**env_kwargs)
+        env = TokenEnv(**env_kwargs)
         env = TransformedEnv(env)
         env.append_transform(InitTracker())
         if primers := get_primers_from_module(actor_inference):
@@ -338,6 +338,7 @@ def create_env_fn():
             promptsmiles_scan=cfg.get("promptsmiles_scan", False),
             remove_duplicates=True,
         )
+        data.pop("SMILES")
 
         # Update progress bar
         data_next = data.get("next")
@@ -392,17 +393,14 @@ def create_env_fn():
                 loss = loss_module(batch)
                 loss = loss.named_apply(
                     lambda name, value: (
-                        (value * mask).mean()
-                        if name.startswith("loss_")
-                        else value
-                        # (value * mask).sum(-1).mean(-1) if name.startswith("loss_") else value
+                        (value * mask).mean() if name.startswith("loss_") else value
                     ),
                     batch_size=[],
                 )
                 loss_sum = (
                     loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"]
                 )
-                losses[j, i] = loss.select(
+                losses[j, 0] = loss.select(
                     "loss_critic", "loss_entropy", "loss_objective"
                 ).detach()
 
@@ -412,7 +410,7 @@ def create_env_fn():
                     kl_div = kl_divergence(actor_training.get_dist(batch), prior_dist)
                     kl_div = (kl_div * mask.squeeze()).sum(-1).mean(-1)
                     loss_sum += kl_div * kl_coef
-                    losses[j, i] = TensorDict(
+                    losses[j, 0] = TensorDict(
                         {"kl_div": kl_div.detach().item()}, batch_size=[]
                     )
 