diff --git a/tests/sb3_test.py b/tests/sb3_test.py index 054c529..9df038f 100644 --- a/tests/sb3_test.py +++ b/tests/sb3_test.py @@ -7,10 +7,13 @@ from robot_sf.gym_env.robot_env import RobotEnv from robot_sf.feature_extractor import DynamicsExtractor - def test_can_load_model_snapshot(): MODEL_PATH = "./temp/ppo_model" MODEL_FILE = f"{MODEL_PATH}.zip" + + # Create the directory if it doesn't exist + os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True) + if os.path.exists(MODEL_FILE) and os.path.isfile(MODEL_FILE): os.remove(MODEL_FILE)