diff --git a/tutorials/ppo/test.py b/tutorials/ppo/test.py
index cc71bbe..2fbd244 100644
--- a/tutorials/ppo/test.py
+++ b/tutorials/ppo/test.py
@@ -4,6 +4,7 @@
 import gymnasium
 import gymnasium.wrappers.vector.jax_to_torch
+import jax
 import numpy as np
 import torch
 
 from agent import Agent
@@ -44,15 +45,16 @@
 test_env = gymnasium.wrappers.vector.jax_to_torch.JaxToTorch(norm_test_env, device=device)
 
 # Load checkpoint
-checkpoint = torch.load(Path(__file__).parent / "ppo_checkpoint.pt")
+checkpoint = torch.load(Path(__file__).parent / "ppo_checkpoint.pt", weights_only=True)
 
 # Create agent and load state
 agent = Agent(test_env).to(device)
 agent.load_state_dict(checkpoint["model_state_dict"])
 
 # Set normalization parameters
-norm_test_env.obs_rms.mean = checkpoint["obs_mean"]
-norm_test_env.obs_rms.var = checkpoint["obs_var"]
+jax_device = jax.devices(env_device)[0]
+norm_test_env.obs_rms.mean = jax.dlpack.from_dlpack(checkpoint["obs_mean"], jax_device)
+norm_test_env.obs_rms.var = jax.dlpack.from_dlpack(checkpoint["obs_var"], jax_device)
 
 # Test for 10 episodes
 n_episodes = 10
diff --git a/tutorials/ppo/train.py b/tutorials/ppo/train.py
index d836a2c..c2dbc03 100644
--- a/tutorials/ppo/train.py
+++ b/tutorials/ppo/train.py
@@ -87,8 +87,8 @@ def save_model(agent, optimizer, train_envs, path):
     save_dict = {"model_state_dict": agent.state_dict(), "optim_state_dict": optimizer.state_dict()}
     # Unwrap the environment to find normalization wrapper
    if (norm_env := unwrap_norm_env(train_envs)) is not None:
-        save_dict["obs_mean"] = norm_env.obs_rms.mean
-        save_dict["obs_var"] = norm_env.obs_rms.var
+        save_dict["obs_mean"] = torch.utils.dlpack.from_dlpack(norm_env.obs_rms.mean)
+        save_dict["obs_var"] = torch.utils.dlpack.from_dlpack(norm_env.obs_rms.var)
     torch.save(save_dict, path)
 
 
@@ -414,10 +414,10 @@ def train_ppo(config: ConfigDict, wandb_log: bool = False):
         "seed": 0,
         "n_eval_envs": 64,
         "n_eval_steps": 1_000,
-        "save_model": False,
+        "save_model": True,
         "eval_interval": 999_000,
         "lr_decay": False,
     }
 )
 
-train_ppo(config, wandb_log=True)
+train_ppo(config, wandb_log=False)
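
Note: the core of this patch is a DLPack round-trip so the JAX running-statistics arrays can live inside a standard PyTorch checkpoint. Below is a minimal standalone sketch of that round-trip (not part of the patch); "ckpt.pt", the array shape, and the "cpu" device are illustrative placeholders standing in for the checkpoint path and the env_device used in the tutorial.

    import jax
    import jax.numpy as jnp
    import torch

    # Save side (mirrors save_model in train.py): convert the JAX array to a
    # torch tensor via DLPack (zero-copy where possible) so torch.save can
    # serialize it.
    obs_mean_jax = jnp.zeros(8)
    save_dict = {"obs_mean": torch.utils.dlpack.from_dlpack(obs_mean_jax)}
    torch.save(save_dict, "ckpt.pt")

    # Load side (mirrors test.py): weights_only=True restricts unpickling to
    # tensors and primitive containers, which is why the stats are stored as
    # torch tensors rather than raw JAX arrays.
    checkpoint = torch.load("ckpt.pt", weights_only=True)

    # Convert back to a JAX array on the target device ("cpu" stands in for
    # env_device in test.py).
    jax_device = jax.devices("cpu")[0]
    obs_mean = jax.dlpack.from_dlpack(checkpoint["obs_mean"], jax_device)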