-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Ll7/issue48 retrain the robot with a reward that penalizes quick action changes #54
Merged
ll7
merged 23 commits into
main
from
ll7/issue48-Retrain-the-robot-with-a-reward-that-penalizes-quick-action-changes
Sep 23, 2024
Merged
Changes from 19 commits
Commits
Show all changes
23 commits
Select commit
Hold shift + click to select a range
d40acce
fix: adjust input and output parameters, delete deprecated seed method
JuliusMiller 2eba09b
add punish_action_reward
ll7 13ef7ec
extend the reward dictionary to include information about the action …
ll7 b4e4188
create new train script for ppo with action punish
ll7 d876549
remove setting last action too early
ll7 e85ac16
increase timesteps to be trained
ll7 eec41b0
Add git_submodule.sh script for recursive submodule update
ll7 9778259
chore: Add SLURM user queue script
ll7 904d4fe
chore: Add SLURM training script
ll7 1e1c3a4
chore: Update SLURM training script to include PPO training for robot_sf
ll7 0a76b3e
feat: Update SLURM training script to specify slurm server for robot_…
ll7 964b073
chore: Add .gitignore file for SLURM directory for *.out files
ll7 0e3ba9d
chore: Add SLURM load_module.sh script
ll7 09d7269
train 10m steps
ll7 c280e3b
add sbatch note
ll7 39c81ae
chore: Add .cursorrules file
ll7 155252e
chore: Update SLURM training script to specify job name and partition…
ll7 42aa3a2
add retrained model for evaluation
ll7 41e8e34
do not ignore global models
ll7 0e46529
chore: Refactor SLURM training script for robot_sf
ll7 defd7a5
chore: Update env_test.py to handle termination and truncation in tes…
ll7 825c9ff
Update sb3_test.py to include additional information in the observati…
ll7 f4f0ce0
chore: Update robot_env.py
ll7 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
You are an AI assistant specialized in Python development. Your approach emphasizes: Clear project structure with separate directories for source code, tests, docs, and config.Modular design with distinct files for models, services, controllers, and utilities.Configuration management using environment variables. Robust error handling and logging, including context capture. Comprehensive testing with pytest. Detailed documentation using docstrings and README files.Dependency management via requirements.txt and virtual environments. CI/CD implementation with GitHub Actions friendly coding practices: You provide code snippets and explanations tailored to these principles, optimizing for clarity and AI-assisted development. Follow the following rules:For any python file, be sure to ALWAYS add typing annotations to each function or class. Be sure to include return types when necessary. Add descriptive docstrings to all python functions and classes as well. Please use pep257 convention. Update existing docstrings if need be. Make sure you keep any comments that exist in a file. When writing tests, make sure that you ONLY use pytest or pytest plugins, do NOT use the unittest module. All tests should have typing annotations as well. All tests should be in ./tests. Be sure to create all necessary files and folders. If you are creating files inside of ./tests, be sure to make a init.py file if one does not exist. All tests should be fully annotated and should contain docstrings. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,7 +11,8 @@ build | |
file.log | ||
images | ||
logs | ||
model | ||
# model | ||
SLURM/model | ||
profile.json | ||
profiles | ||
pysf_tests | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
*.out |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
#! /bin/bash | ||
|
||
module purge | ||
|
||
module load anaconda cuda | ||
|
||
conda activate conda_env | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# SLURM | ||
|
||
```bash | ||
sbatch slurm_train.sl | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
"train a robot in robot_sf on a slurm server" | ||
|
||
""" | ||
This script is used to train a PPO model on the CartPole-v1 environment. | ||
""" | ||
import sys | ||
from datetime import datetime | ||
|
||
import gymnasium as gym | ||
from loguru import logger | ||
|
||
from stable_baselines3 import PPO | ||
from stable_baselines3.common.env_util import make_vec_env | ||
from stable_baselines3.common.vec_env import SubprocVecEnv | ||
from stable_baselines3.common.callbacks import CheckpointCallback, CallbackList | ||
|
||
from robot_sf.gym_env.robot_env import RobotEnv | ||
from robot_sf.gym_env.env_config import EnvSettings | ||
from robot_sf.feature_extractor import DynamicsExtractor | ||
from robot_sf.tb_logging import DrivingMetricsCallback | ||
|
||
|
||
def training(): | ||
n_envs = 64 | ||
ped_densities = [0.01, 0.02, 0.04, 0.08] | ||
difficulty = 2 | ||
|
||
|
||
def make_env(): | ||
config = EnvSettings() | ||
config.sim_config.ped_density_by_difficulty = ped_densities | ||
config.sim_config.difficulty = difficulty | ||
return RobotEnv(config) | ||
|
||
env = make_vec_env(make_env, n_envs=n_envs, vec_env_cls=SubprocVecEnv) | ||
|
||
policy_kwargs = dict(features_extractor_class=DynamicsExtractor) | ||
model = PPO( | ||
"MultiInputPolicy", | ||
env, | ||
tensorboard_log="./logs/ppo_logs/", | ||
policy_kwargs=policy_kwargs | ||
) | ||
save_model_callback = CheckpointCallback( | ||
500_000 // n_envs, | ||
"./model/backup", | ||
"ppo_model" | ||
) | ||
collect_metrics_callback = DrivingMetricsCallback(n_envs) | ||
combined_callback = CallbackList( | ||
[save_model_callback, collect_metrics_callback] | ||
) | ||
|
||
logger.info("start learning") | ||
model.learn( | ||
total_timesteps=10_000_000, | ||
progress_bar=True, | ||
callback=combined_callback | ||
) | ||
logger.info("save model") | ||
model.save("./model/ppo_model") | ||
|
||
|
||
if __name__ == '__main__': | ||
logger.info(f"python path: {sys.executable}") | ||
logger.info(f"python version: {sys.version}") | ||
|
||
logger.info("start training") | ||
training() | ||
logger.info("end training") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
#!/usr/bin/env bash | ||
|
||
#SBATCH --job-name=robot-sf | ||
#SBATCH --partition=epyc-gpu | ||
#SBATCH --time=10:00:00 | ||
|
||
# Request memory per CPU | ||
#SBATCH --mem-per-cpu=2G | ||
# Request n CPUs for your task. | ||
#SBATCH --cpus-per-task=64 | ||
# Request GPU Ressources (model:number) | ||
#SBATCH --gpus=a100:1 | ||
|
||
# Clear all interactively loaded modules | ||
module purge | ||
|
||
# Load a python package manager | ||
module load cuda anaconda # or micromamba or condaforge | ||
|
||
# Activate a certain environment | ||
conda activate conda_env | ||
|
||
# set number of OpenMP threads (i.e. for numpy, etc...) | ||
# export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK | ||
# if you are adding your own level of parallelzation, you | ||
# probably want to set OMP_NUM_THREADS=1 instead, in order | ||
# to prevent the creation of too many threads (massive slowdown!) | ||
|
||
# No need to pass number of tasks to srun | ||
srun python3 slurm_PPO_robot_sf.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
#! /bin/bash | ||
|
||
squeue -u $USER |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
#!/bin/bash | ||
# if you forgot to clone recursively, run this script | ||
|
||
git submodule update --init --recursive |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
from stable_baselines3 import PPO | ||
from stable_baselines3.common.env_util import make_vec_env | ||
from stable_baselines3.common.vec_env import SubprocVecEnv | ||
from stable_baselines3.common.callbacks import CheckpointCallback, CallbackList | ||
|
||
from robot_sf.gym_env.robot_env import RobotEnv | ||
from robot_sf.gym_env.env_config import EnvSettings | ||
from robot_sf.feature_extractor import DynamicsExtractor | ||
from robot_sf.tb_logging import DrivingMetricsCallback | ||
from robot_sf.gym_env.reward import punish_action_reward | ||
|
||
|
||
def training(): | ||
n_envs = 32 | ||
ped_densities = [0.01, 0.02, 0.04, 0.08] | ||
difficulty = 2 | ||
|
||
|
||
def make_env(): | ||
config = EnvSettings() | ||
config.sim_config.ped_density_by_difficulty = ped_densities | ||
config.sim_config.difficulty = difficulty | ||
return RobotEnv( | ||
config, | ||
reward_func=punish_action_reward | ||
) | ||
|
||
env = make_vec_env(make_env, n_envs=n_envs, vec_env_cls=SubprocVecEnv) | ||
|
||
policy_kwargs = dict(features_extractor_class=DynamicsExtractor) | ||
model = PPO( | ||
"MultiInputPolicy", | ||
env, | ||
tensorboard_log="./logs/ppo_logs/", | ||
policy_kwargs=policy_kwargs | ||
) | ||
save_model_callback = CheckpointCallback( | ||
500_000 // n_envs, | ||
"./model/backup", | ||
"ppo_model" | ||
) | ||
collect_metrics_callback = DrivingMetricsCallback(n_envs) | ||
combined_callback = CallbackList( | ||
[save_model_callback, collect_metrics_callback] | ||
) | ||
|
||
model.learn( | ||
total_timesteps=10_000_000, | ||
progress_bar=True, | ||
callback=combined_callback | ||
) | ||
model.save("./model/ppo_model") | ||
|
||
|
||
if __name__ == '__main__': | ||
training() |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Address the static analysis hints.
The static analysis tool suggests removing unused imports:
datetime
import is indeed unused in the code and can be safely removed.gymnasium
import is used indirectly through themake_vec_env
function fromstable_baselines3.common.env_util
. Verify if this import is necessary for themake_vec_env
function to work correctly before removing it.Apply this diff to remove the unused
datetime
import:-from datetime import datetime
Tools
Ruff