diff --git a/tests/sb3_test.py b/tests/sb3_test.py index 52c8fa7..054c529 100644 --- a/tests/sb3_test.py +++ b/tests/sb3_test.py @@ -23,7 +23,7 @@ def test_can_load_model_snapshot(): inf_env = RobotEnv() model2 = PPO.load(MODEL_PATH, env=inf_env) - obs = inf_env.reset() + obs, info = inf_env.reset() action, _ = model2.predict(obs, deterministic=True) assert action.shape == inf_env.action_space.shape