
Commit 224d3a4
fix
albertbou92 committed Jun 27, 2024
1 parent 43f82e0 commit 224d3a4
Showing 1 changed file with 8 additions and 10 deletions.
scripts/ppo/ppo.py (18 changes: 8 additions & 10 deletions)
@@ -14,13 +14,13 @@
 import yaml
 
 from acegen.models import adapt_state_dict, models, register_model
-from acegen.rl_env import generate_complete_smiles, SMILESEnv
+from acegen.rl_env import generate_complete_smiles, TokenEnv
 from acegen.scoring_functions import (
     custom_scoring_functions,
     register_custom_scoring_function,
     Task,
 )
-from acegen.vocabulary import SMILESVocabulary
+from acegen.vocabulary import Vocabulary
 from omegaconf import OmegaConf, open_dict
 from tensordict import TensorDict
 from tensordict.utils import isin
@@ -171,7 +171,7 @@ def run_ppo(cfg, task):
     # Create vocabulary
     ####################################################################################################################
 
-    vocabulary = SMILESVocabulary.load(voc_path, tokenizer=tokenizer)
+    vocabulary = Vocabulary.load(voc_path, tokenizer=tokenizer)
 
     # Create models
     ####################################################################################################################
@@ -220,7 +220,7 @@ def run_ppo(cfg, task):
     # Define environment creation function
     def create_env_fn():
         """Create a single RL rl_env."""
-        env = SMILESEnv(**env_kwargs)
+        env = TokenEnv(**env_kwargs)
         env = TransformedEnv(env)
         env.append_transform(InitTracker())
         if primers := get_primers_from_module(actor_inference):
@@ -338,6 +338,7 @@ def create_env_fn():
            promptsmiles_scan=cfg.get("promptsmiles_scan", False),
            remove_duplicates=True,
        )
+        data.pop("SMILES")
 
        # Update progress bar
        data_next = data.get("next")
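The added data.pop("SMILES") call removes that key from the collected batch and returns the entry, leaving the remaining keys untouched for the rest of the training loop. A minimal sketch of the same operation on toy data (the key names and shapes below are illustrative, not taken from the script):

    # Popping a key from a TensorDict returns the entry and leaves the rest intact.
    import torch
    from tensordict import TensorDict

    data = TensorDict(
        {
            "observation": torch.randn(4, 10),
            "reward": torch.zeros(4, 1),
            "extra": torch.arange(4),  # stand-in for a key the loss code does not need
        },
        batch_size=[4],
    )

    extra = data.pop("extra")   # entry is removed and returned
    print(sorted(data.keys()))  # ['observation', 'reward']
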
@@ -392,17 +393,14 @@ def create_env_fn():
                loss = loss_module(batch)
                loss = loss.named_apply(
                    lambda name, value: (
-                        (value * mask).mean()
-                        if name.startswith("loss_")
-                        else value
-                        # (value * mask).sum(-1).mean(-1) if name.startswith("loss_") else value
+                        (value * mask).mean() if name.startswith("loss_") else value
                    ),
                    batch_size=[],
                )
                loss_sum = (
                    loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"]
                )
-                losses[j, i] = loss.select(
+                losses[j, 0] = loss.select(
                    "loss_critic", "loss_entropy", "loss_objective"
                ).detach()
 
@@ -412,7 +410,7 @@ def create_env_fn():
                    kl_div = kl_divergence(actor_training.get_dist(batch), prior_dist)
                    kl_div = (kl_div * mask.squeeze()).sum(-1).mean(-1)
                    loss_sum += kl_div * kl_coef
-                    losses[j, i] = TensorDict(
+                    losses[j, 0] = TensorDict(
                        {"kl_div": kl_div.detach().item()}, batch_size=[]
                    )
 
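The kept one-liner, (value * mask).mean(), and the alternative left in the removed comment, (value * mask).sum(-1).mean(-1), reduce the masked per-token loss differently: the first averages over every element of the tensor, padded positions included in the denominator, while the second sums each sequence first and then averages across the batch. A minimal sketch on toy tensors (shapes chosen for illustration only, not taken from the script):

    import torch

    value = torch.tensor([[1.0, 2.0, 3.0],
                          [4.0, 0.0, 0.0]])   # per-token loss, [batch, time]
    mask = torch.tensor([[1.0, 1.0, 1.0],
                         [1.0, 0.0, 0.0]])    # 1 = valid token, 0 = padding

    print((value * mask).mean())              # tensor(1.6667): 10 / 6 elements
    print((value * mask).sum(-1).mean(-1))    # tensor(5.): (6 + 4) / 2 sequences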
