Skip to content

Commit

Permalink
promptsmiles bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
MorganCThomas committed May 30, 2024
1 parent a79dc1e commit 32b4184
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions acegen/rl_env/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def generate_complete_smiles(
replace_mask_value=vocabulary.end_token_index,
device=env_device,
)
# Add final complete smiles for logging
# Add final complete smiles for logging and scoring
_output_data.set(
"promptsmiles",
enc_smiles[-1][:, :-1].to(env_device),
Expand Down Expand Up @@ -249,7 +249,7 @@ def generate_complete_smiles(
failed_mask.unsqueeze(-1).expand_as(output_data["action"]), 0
)

# Add final completed promptsmiles for logging
# Add final completed promptsmiles for logging and scoring
output_data.set(
"promptsmiles",
enc_smiles[-1][:, :-1].to(env_device),
Expand All @@ -266,7 +266,7 @@ def generate_complete_smiles(

# Compute rewards
if scoring_function:
smiles = output_data.get("action").cpu()
smiles = output_data.get("promptsmiles").cpu()
next_output_data = output_data.get("next")
done = next_output_data.get("done").squeeze(-1)
smiles_str = [vocabulary.decode(smi.numpy()) for smi in smiles]
Expand Down

0 comments on commit 32b4184

Please sign in to comment.