From 42d4a0621578f6f9461d35f25c6aa82e106aa930 Mon Sep 17 00:00:00 2001
From: ll7
Date: Thu, 7 Mar 2024 14:20:34 +0100
Subject: [PATCH 1/3] Add wandb_ppo_training script for training and logging PPO robot

---
 scripts/wandb_ppo_training.py | 56 +++++++++++++++++++++++++++++++++++
 1 file changed, 56 insertions(+)
 create mode 100644 scripts/wandb_ppo_training.py

diff --git a/scripts/wandb_ppo_training.py b/scripts/wandb_ppo_training.py
new file mode 100644
index 0000000..9f2ebbf
--- /dev/null
+++ b/scripts/wandb_ppo_training.py
@@ -0,0 +1,56 @@
+"""Train ppo robot and log to wandb"""
+
+import wandb
+
+from stable_baselines3 import PPO
+from stable_baselines3.common.env_util import make_vec_env
+from stable_baselines3.common.vec_env import SubprocVecEnv
+from stable_baselines3.common.callbacks import CheckpointCallback, CallbackList
+
+from robot_sf.robot_env import RobotEnv
+from robot_sf.sim_config import EnvSettings
+from robot_sf.feature_extractor import DynamicsExtractor
+from robot_sf.tb_logging import DrivingMetricsCallback
+
+
+def training():
+    n_envs = 32
+    ped_densities = [0.01, 0.02, 0.04, 0.08]
+    difficulty = 2
+
+
+    def make_env():
+        config = EnvSettings()
+        config.sim_config.ped_density_by_difficulty = ped_densities
+        config.sim_config.difficulty = difficulty
+        return RobotEnv(config)
+
+    env = make_vec_env(make_env, n_envs=n_envs, vec_env_cls=SubprocVecEnv)
+
+    policy_kwargs = dict(features_extractor_class=DynamicsExtractor)
+    model = PPO(
+        "MultiInputPolicy",
+        env,
+        tensorboard_log="./logs/ppo_logs/",
+        policy_kwargs=policy_kwargs
+    )
+    save_model_callback = CheckpointCallback(
+        500_000 // n_envs,
+        "./model/backup",
+        "ppo_model"
+    )
+    collect_metrics_callback = DrivingMetricsCallback(n_envs)
+    combined_callback = CallbackList(
+        [save_model_callback, collect_metrics_callback]
+    )
+
+    model.learn(
+        total_timesteps=10_000_000,
+        progress_bar=True,
+        callback=combined_callback
+    )
+    model.save("./model/ppo_model")
+
+
+if __name__ == '__main__':
+    training()

From 1e9fdce3bfd8b2de932304401686e274cce033fa Mon Sep 17 00:00:00 2001
From: ll7
Date: Thu, 7 Mar 2024 15:02:32 +0100
Subject: [PATCH 2/3] Add wandb to .gitignore

---
 .gitignore | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/.gitignore b/.gitignore
index 65abc2b..3a7a5fe 100644
--- a/.gitignore
+++ b/.gitignore
@@ -14,4 +14,5 @@ images
 pysf_tests
 pysocialforce
 logs
-*training*.zip
\ No newline at end of file
+*training*.zip
+wandb
\ No newline at end of file

From 4e61510f291941f2b209115a7f18c245bf2c7cc4 Mon Sep 17 00:00:00 2001
From: ll7
Date: Thu, 7 Mar 2024 15:03:08 +0100
Subject: [PATCH 3/3] Add wandb integration and update training script

---
 scripts/wandb_ppo_training.py | 118 ++++++++++++++++++++++------------
 1 file changed, 76 insertions(+), 42 deletions(-)

diff --git a/scripts/wandb_ppo_training.py b/scripts/wandb_ppo_training.py
index 9f2ebbf..42dc060 100644
--- a/scripts/wandb_ppo_training.py
+++ b/scripts/wandb_ppo_training.py
@@ -1,6 +1,10 @@
-"""Train ppo robot and log to wandb"""
+"""
+Train ppo robot and log to wandb
+Documentation can be found in `docs/wandb.md`
+"""
 
 import wandb
+from wandb.integration.sb3 import WandbCallback
 
 from stable_baselines3 import PPO
 from stable_baselines3.common.env_util import make_vec_env
@@ -12,45 +16,75 @@
 from robot_sf.feature_extractor import DynamicsExtractor
 from robot_sf.tb_logging import DrivingMetricsCallback
 
+wandb_config = {
+    "env": "robot_sf",
+    "algorithm": "ppo",
+    "difficulty": 2,
+    "ped_densities": [0.01, 0.02, 0.04, 0.08],
+    "n_envs": 32,
+    "total_timesteps": 10_000_000
+}
 
-def training():
-    n_envs = 32
-    ped_densities = [0.01, 0.02, 0.04, 0.08]
-    difficulty = 2
-
-
-    def make_env():
-        config = EnvSettings()
-        config.sim_config.ped_density_by_difficulty = ped_densities
-        config.sim_config.difficulty = difficulty
-        return RobotEnv(config)
-
-    env = make_vec_env(make_env, n_envs=n_envs, vec_env_cls=SubprocVecEnv)
-
-    policy_kwargs = dict(features_extractor_class=DynamicsExtractor)
-    model = PPO(
-        "MultiInputPolicy",
-        env,
-        tensorboard_log="./logs/ppo_logs/",
-        policy_kwargs=policy_kwargs
-    )
-    save_model_callback = CheckpointCallback(
-        500_000 // n_envs,
-        "./model/backup",
-        "ppo_model"
-    )
-    collect_metrics_callback = DrivingMetricsCallback(n_envs)
-    combined_callback = CallbackList(
-        [save_model_callback, collect_metrics_callback]
-    )
-
-    model.learn(
-        total_timesteps=10_000_000,
-        progress_bar=True,
-        callback=combined_callback
-    )
-    model.save("./model/ppo_model")
-
-
-if __name__ == '__main__':
-    training()
+# Start a new run to track and log to W&B.
+wandb_run = wandb.init(
+    project="robot_sf",
+    config=wandb_config,
+    save_code=True,
+    group="ppo robot_sf",
+    job_type="initial training",
+    tags=["ppo", "robot_sf"],
+    name="init ppo robot_sf",
+    notes="Initial training of ppo robot_sf",
+    resume="allow",
+    magic=True,
+    mode="online",
+    sync_tensorboard=True,
+    monitor_gym=True
+)
+
+
+N_ENVS = wandb_config["n_envs"]
+ped_densities = wandb_config["ped_densities"]
+DIFFICULTY = wandb_config["difficulty"]
+
+
+def make_env():
+    config = EnvSettings()
+    config.sim_config.ped_density_by_difficulty = ped_densities
+    config.sim_config.difficulty = DIFFICULTY
+    return RobotEnv(config)
+
+env = make_vec_env(make_env, n_envs=N_ENVS, vec_env_cls=SubprocVecEnv)
+
+policy_kwargs = dict(features_extractor_class=DynamicsExtractor)
+model = PPO(
+    "MultiInputPolicy",
+    env,
+    tensorboard_log="./logs/ppo_logs/",
+    policy_kwargs=policy_kwargs
+)
+save_model_callback = CheckpointCallback(
+    500_000 // N_ENVS,
+    "./model/backup",
+    "ppo_model"
+)
+collect_metrics_callback = DrivingMetricsCallback(N_ENVS)
+
+wandb_callback = WandbCallback(
+    gradient_save_freq=20_000,
+    model_save_path=f"models/{wandb_run.id}",
+    verbose=2,
+)
+
+combined_callback = CallbackList(
+    [save_model_callback, collect_metrics_callback, wandb_callback]
+)
+
+model.learn(
+    total_timesteps=wandb_config["total_timesteps"],
+    progress_bar=True,
+    callback=combined_callback
+)
+model.save("./model/ppo_model")
+
+wandb_run.finish()