Skip to content

Commit

Permalink
Update train and test scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
amacati committed Feb 5, 2025
1 parent 7c2ad96 commit 1236f12
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
8 changes: 5 additions & 3 deletions tutorials/ppo/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import gymnasium
import gymnasium.wrappers.vector.jax_to_torch
import jax
import numpy as np
import torch
from agent import Agent
Expand Down Expand Up @@ -44,15 +45,16 @@
test_env = gymnasium.wrappers.vector.jax_to_torch.JaxToTorch(norm_test_env, device=device)

# Load checkpoint
checkpoint = torch.load(Path(__file__).parent / "ppo_checkpoint.pt")
checkpoint = torch.load(Path(__file__).parent / "ppo_checkpoint.pt", weights_only=True)

# Create agent and load state
agent = Agent(test_env).to(device)
agent.load_state_dict(checkpoint["model_state_dict"])

# Set normalization parameters
norm_test_env.obs_rms.mean = checkpoint["obs_mean"]
norm_test_env.obs_rms.var = checkpoint["obs_var"]
jax_device = jax.devices(env_device)[0]
norm_test_env.obs_rms.mean = jax.dlpack.from_dlpack(checkpoint["obs_mean"], jax_device)
norm_test_env.obs_rms.var = jax.dlpack.from_dlpack(checkpoint["obs_var"], jax_device)

# Test for 10 episodes
n_episodes = 10
Expand Down
8 changes: 4 additions & 4 deletions tutorials/ppo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ def save_model(agent, optimizer, train_envs, path):
save_dict = {"model_state_dict": agent.state_dict(), "optim_state_dict": optimizer.state_dict()}
# Unwrap the environment to find normalization wrapper
if (norm_env := unwrap_norm_env(train_envs)) is not None:
save_dict["obs_mean"] = norm_env.obs_rms.mean
save_dict["obs_var"] = norm_env.obs_rms.var
save_dict["obs_mean"] = torch.utils.dlpack.from_dlpack(norm_env.obs_rms.mean)
save_dict["obs_var"] = torch.utils.dlpack.from_dlpack(norm_env.obs_rms.var)
torch.save(save_dict, path)


Expand Down Expand Up @@ -414,10 +414,10 @@ def train_ppo(config: ConfigDict, wandb_log: bool = False):
"seed": 0,
"n_eval_envs": 64,
"n_eval_steps": 1_000,
"save_model": False,
"save_model": True,
"eval_interval": 999_000,
"lr_decay": False,
}
)

train_ppo(config, wandb_log=True)
train_ppo(config, wandb_log=False)

0 comments on commit 1236f12

Please sign in to comment.