Skip to content

Commit

Permalink
fix extract rnn modules
Browse files Browse the repository at this point in the history
  • Loading branch information
albertbou92 committed Jun 5, 2024
1 parent 61e6928 commit 395d06e
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 5 deletions.
3 changes: 2 additions & 1 deletion scripts/hill_climb/hill_climb.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,8 @@ def create_env_fn():
env = SMILESEnv(**env_kwargs)
env = TransformedEnv(env)
env.append_transform(InitTracker())
env.append_transform(get_primers_from_module(actor_training))
if primers := get_primers_from_module(actor_inference):
env.append_transform(primers)
return env

env = create_env_fn()
Expand Down
5 changes: 3 additions & 2 deletions scripts/sac/pretrain_sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,8 +345,9 @@ def create_env_fn():
"train/reward": episode_rewards.mean().item(),
"train/min_reward": episode_rewards.min().item(),
"train/max_reward": episode_rewards.max().item(),
"train/episode_length": episode_length.sum().item()
/ len(episode_length),
"train/episode_length": episode_length.sum().item() / len(
episode_length
),
}
)
if logger:
Expand Down
5 changes: 3 additions & 2 deletions scripts/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,8 +345,9 @@ def create_env_fn():
"train/reward": episode_rewards.mean().item(),
"train/min_reward": episode_rewards.min().item(),
"train/max_reward": episode_rewards.max().item(),
"train/episode_length": episode_length.sum().item()
/ len(episode_length),
"train/episode_length": episode_length.sum().item() / len(
episode_length
),
}
)
if logger:
Expand Down

0 comments on commit 395d06e

Please sign in to comment.