Commit
Merge pull request #54 from ll7/ll7/issue48-Retrain-the-robot-with-a-…
…reward-that-penalizes-quick-action-changes Ll7/issue48 retrain the robot with a reward that penalizes quick action changes
Showing 15 changed files with 247 additions and 33 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
You are an AI assistant specialized in Python development. Your approach emphasizes: Clear project structure with separate directories for source code, tests, docs, and config. Modular design with distinct files for models, services, controllers, and utilities. Configuration management using environment variables. Robust error handling and logging, including context capture. Comprehensive testing with pytest. Detailed documentation using docstrings and README files. Dependency management via requirements.txt and virtual environments. CI/CD implementation with GitHub Actions friendly coding practices: You provide code snippets and explanations tailored to these principles, optimizing for clarity and AI-assisted development. Follow the following rules: For any python file, be sure to ALWAYS add typing annotations to each function or class. Be sure to include return types when necessary. Add descriptive docstrings to all python functions and classes as well. Please use pep257 convention. Update existing docstrings if need be. Make sure you keep any comments that exist in a file. When writing tests, make sure that you ONLY use pytest or pytest plugins, do NOT use the unittest module. All tests should have typing annotations as well. All tests should be in ./tests. Be sure to create all necessary files and folders. If you are creating files inside of ./tests, be sure to make a __init__.py file if one does not exist. All tests should be fully annotated and should contain docstrings. |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,7 +11,8 @@ build | |
file.log | ||
images | ||
logs | ||
model | ||
# model | ||
SLURM/model | ||
profile.json | ||
profiles | ||
pysf_tests | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
*.out |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
#! /bin/bash | ||
|
||
module purge | ||
|
||
module load anaconda cuda | ||
|
||
conda activate conda_env | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# SLURM | ||
|
||
```bash | ||
sbatch slurm_train.sl | ||
``` |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
"""train a robot in robot_sf on a slurm server""" | ||
|
||
import sys | ||
|
||
from loguru import logger | ||
|
||
from stable_baselines3 import PPO | ||
from stable_baselines3.common.env_util import make_vec_env | ||
from stable_baselines3.common.vec_env import SubprocVecEnv | ||
from stable_baselines3.common.callbacks import CheckpointCallback, CallbackList | ||
|
||
from robot_sf.gym_env.robot_env import RobotEnv | ||
from robot_sf.gym_env.env_config import EnvSettings | ||
from robot_sf.feature_extractor import DynamicsExtractor | ||
from robot_sf.tb_logging import DrivingMetricsCallback | ||
|
||
|
||
def training(): | ||
n_envs = 64 | ||
ped_densities = [0.01, 0.02, 0.04, 0.08] | ||
difficulty = 2 | ||
|
||
|
||
def make_env(): | ||
config = EnvSettings() | ||
config.sim_config.ped_density_by_difficulty = ped_densities | ||
config.sim_config.difficulty = difficulty | ||
return RobotEnv(config) | ||
|
||
env = make_vec_env(make_env, n_envs=n_envs, vec_env_cls=SubprocVecEnv) | ||
|
||
policy_kwargs = dict(features_extractor_class=DynamicsExtractor) | ||
model = PPO( | ||
"MultiInputPolicy", | ||
env, | ||
tensorboard_log="./logs/ppo_logs/", | ||
policy_kwargs=policy_kwargs | ||
) | ||
save_model_callback = CheckpointCallback( | ||
500_000 // n_envs, | ||
"./model/backup", | ||
"ppo_model" | ||
) | ||
collect_metrics_callback = DrivingMetricsCallback(n_envs) | ||
combined_callback = CallbackList( | ||
[save_model_callback, collect_metrics_callback] | ||
) | ||
|
||
logger.info("start learning") | ||
model.learn( | ||
total_timesteps=10_000_000, | ||
progress_bar=True, | ||
callback=combined_callback | ||
) | ||
logger.info("save model") | ||
model.save("./model/ppo_model") | ||
|
||
|
||
if __name__ == '__main__': | ||
logger.info(f"python path: {sys.executable}") | ||
logger.info(f"python version: {sys.version}") | ||
|
||
logger.info("start training") | ||
training() | ||
logger.info("end training") |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
#!/usr/bin/env bash | ||
|
||
#SBATCH --job-name=robot-sf | ||
#SBATCH --partition=epyc-gpu | ||
#SBATCH --time=10:00:00 | ||
|
||
# Request memory per CPU | ||
#SBATCH --mem-per-cpu=2G | ||
# Request n CPUs for your task. | ||
#SBATCH --cpus-per-task=64 | ||
# Request GPU Resources (model:number) | ||
#SBATCH --gpus=a100:1 | ||
|
||
# Clear all interactively loaded modules | ||
module purge | ||
|
||
# Load a python package manager | ||
module load cuda anaconda # or micromamba or condaforge | ||
|
||
# Activate a certain environment | ||
conda activate conda_env | ||
|
||
# set number of OpenMP threads (e.g. for numpy, etc...) | ||
# export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK | ||
# if you are adding your own level of parallelization, you | ||
# probably want to set OMP_NUM_THREADS=1 instead, in order | ||
# to prevent the creation of too many threads (massive slowdown!) | ||
|
||
# No need to pass number of tasks to srun | ||
srun python3 slurm_PPO_robot_sf.py |
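The batch script requests 64 CPUs per task (at 2G per CPU, 128G in total), and the training script hard-codes `n_envs = 64` to match. Slurm exports `SLURM_CPUS_PER_TASK` when `--cpus-per-task` is set, so one possible way to keep the two values in sync is to read that variable. This is a sketch under that assumption, not something the commit does; `n_envs_from_slurm` is a hypothetical helper name.

```python
import os
from typing import Mapping, Optional


def n_envs_from_slurm(
    default: int = 1,
    environ: Optional[Mapping[str, str]] = None,
) -> int:
    """Derive the number of parallel environments from the Slurm allocation.

    Hypothetical helper: falls back to `default` when the variable is
    missing, not an integer, or not positive (e.g. when run locally).
    """
    env = os.environ if environ is None else environ
    raw = env.get("SLURM_CPUS_PER_TASK", "")
    try:
        value = int(raw)
    except ValueError:
        return default
    return value if value > 0 else default


if __name__ == "__main__":
    # Outside of a Slurm job this prints the fallback value.
    print(n_envs_from_slurm(default=4))
```

Since each `SubprocVecEnv` worker is its own process, this pairs naturally with the script's own hint to set `OMP_NUM_THREADS=1` when adding a separate level of parallelism.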
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
#! /bin/bash | ||
|
||
squeue -u $USER |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
#!/bin/bash | ||
# if you forgot to clone recursively, run this script | ||
|
||
git submodule update --init --recursive |
Binary file not shown.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
from stable_baselines3 import PPO | ||
from stable_baselines3.common.env_util import make_vec_env | ||
from stable_baselines3.common.vec_env import SubprocVecEnv | ||
from stable_baselines3.common.callbacks import CheckpointCallback, CallbackList | ||
|
||
from robot_sf.gym_env.robot_env import RobotEnv | ||
from robot_sf.gym_env.env_config import EnvSettings | ||
from robot_sf.feature_extractor import DynamicsExtractor | ||
from robot_sf.tb_logging import DrivingMetricsCallback | ||
from robot_sf.gym_env.reward import punish_action_reward | ||
|
||
|
||
def training(): | ||
n_envs = 32 | ||
ped_densities = [0.01, 0.02, 0.04, 0.08] | ||
difficulty = 2 | ||
|
||
|
||
def make_env(): | ||
config = EnvSettings() | ||
config.sim_config.ped_density_by_difficulty = ped_densities | ||
config.sim_config.difficulty = difficulty | ||
return RobotEnv( | ||
config, | ||
reward_func=punish_action_reward | ||
) | ||
|
||
env = make_vec_env(make_env, n_envs=n_envs, vec_env_cls=SubprocVecEnv) | ||
|
||
policy_kwargs = dict(features_extractor_class=DynamicsExtractor) | ||
model = PPO( | ||
"MultiInputPolicy", | ||
env, | ||
tensorboard_log="./logs/ppo_logs/", | ||
policy_kwargs=policy_kwargs | ||
) | ||
save_model_callback = CheckpointCallback( | ||
500_000 // n_envs, | ||
"./model/backup", | ||
"ppo_model" | ||
) | ||
collect_metrics_callback = DrivingMetricsCallback(n_envs) | ||
combined_callback = CallbackList( | ||
[save_model_callback, collect_metrics_callback] | ||
) | ||
|
||
model.learn( | ||
total_timesteps=10_000_000, | ||
progress_bar=True, | ||
callback=combined_callback | ||
) | ||
model.save("./model/ppo_model") | ||
|
||
|
||
if __name__ == '__main__': | ||
training() |
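The script above plugs `punish_action_reward` into `RobotEnv` via `reward_func`, but the body of that function is not visible in this part of the diff. As a rough illustration of what "penalizing quick action changes" can mean, the sketch below computes a penalty proportional to the Euclidean distance between the current and the previous action. The function name, its signature, and the default weight are all assumptions made for illustration, and the actual implementation in `robot_sf.gym_env.reward` may differ.

```python
import math
from typing import Optional, Sequence


def action_change_penalty(
    action: Sequence[float],
    last_action: Optional[Sequence[float]],
    weight: float = 0.1,
) -> float:
    """Return a non-positive penalty that grows with the action change.

    Hypothetical sketch, not the repository's punish_action_reward.
    """
    # No previous action (first step of an episode): nothing to penalize.
    if last_action is None:
        return 0.0
    # Euclidean distance between consecutive actions.
    change = math.sqrt(sum((a - b) ** 2 for a, b in zip(action, last_action)))
    return -weight * change


if __name__ == "__main__":
    # A large jump in the commanded action yields a larger penalty.
    print(action_change_penalty((1.0, 0.5), (0.0, 0.0)))
    print(action_change_penalty((0.1, 0.0), (0.0, 0.0)))
```

A term like this would typically be added to the base reward, so the weight controls the trade-off between smooth control and task progress.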