Skip to content

Commit

Permalink
Cleaning up MPC and LQR example scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
Federico-PizarroBejarano committed Dec 3, 2024
1 parent 3313d2a commit 3085883
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 68 deletions.
22 changes: 4 additions & 18 deletions examples/lqr/lqr_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
from safe_control_gym.utils.registration import make


def run(gui=True, n_episodes=1, n_steps=None, save_data=False):
def run(gui=False, plot=True, n_episodes=2, n_steps=None, save_data=False):
'''The main function running LQR and iLQR experiments.
Args:
gui (bool): Whether to display the gui and plot graphs.
gui (bool): Whether to display the gui.
plot (bool): Whether to plot graphs.
n_episodes (int): The number of episodes to execute.
n_steps (int): The total number of steps to execute.
save_data (bool): Whether to save the collected experiment data.
Expand Down Expand Up @@ -61,7 +62,7 @@ def run(gui=True, n_episodes=1, n_steps=None, save_data=False):
else:
trajs_data, _ = experiment.run_evaluation(training=True, n_steps=n_steps)

if gui:
if plot:
post_analysis(trajs_data['obs'][0], trajs_data['action'][0], ctrl.env)

# Close environments
Expand Down Expand Up @@ -132,20 +133,5 @@ def post_analysis(state_stack, input_stack, env):
plt.show()


def wrap2pi_vec(angle_vec):
    '''Wraps a vector of angles into the half-open interval (-pi, pi].

    Each out-of-range entry is shifted by whole turns (multiples of 2*pi) so the
    wrapped value represents the same angle. Entries already in range are left
    untouched.

    Args:
        angle_vec (ndarray): A vector of angles in radians. Modified in place.

    Returns:
        angle_vec (ndarray): The same object that was passed in, with wrapped entries.
    '''
    # A full turn is 2*pi; shifting by only pi would map to a different angle.
    two_pi = 2 * np.pi
    for k, angle in enumerate(angle_vec):
        # Non-finite entries can never be brought into range; skip them
        # instead of looping forever (inf - 2*pi is still inf).
        if not np.isfinite(angle):
            continue
        while angle > np.pi:
            angle -= two_pi
        while angle <= -np.pi:
            angle += two_pi
        angle_vec[k] = angle
    return angle_vec


if __name__ == '__main__':
run()
61 changes: 11 additions & 50 deletions examples/mpc/mpc_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import os
import pickle
from collections import defaultdict
from functools import partial

import matplotlib.pyplot as plt
Expand All @@ -15,11 +14,12 @@
from safe_control_gym.utils.registration import make


def run(gui=True, n_episodes=1, n_steps=None, save_data=False):
def run(gui=False, plot=True, n_episodes=2, n_steps=None, save_data=False):
'''The main function running MPC and Linear MPC experiments.
Args:
gui (bool): Whether to display the gui and plot graphs.
gui (bool): Whether to display the gui.
plot (bool): Whether to plot graphs.
n_episodes (int): The number of episodes to execute.
n_steps (int): The total number of steps to execute.
save_data (bool): Whether to save the collected experiment data.
Expand All @@ -34,51 +34,27 @@ def run(gui=True, n_episodes=1, n_steps=None, save_data=False):
config.task,
**config.task_config
)
random_env = env_func(gui=False)
env = env_func(gui=gui)

# Create controller.
ctrl = make(config.algo,
env_func,
**config.algo_config
)

all_trajs = defaultdict(list)
n_episodes = 1 if n_episodes is None else n_episodes

# Run the experiment.
for _ in range(n_episodes):
# Get initial state and create environments
init_state, _ = random_env.reset()
static_env = env_func(gui=gui, randomized_init=False, init_state=init_state)
static_train_env = env_func(gui=False, randomized_init=False, init_state=init_state)

# Create experiment, train, and run evaluation
experiment = BaseExperiment(env=static_env, ctrl=ctrl, train_env=static_train_env)
experiment.launch_training()

if n_steps is None:
trajs_data, _ = experiment.run_evaluation(training=True, n_episodes=1)
else:
trajs_data, _ = experiment.run_evaluation(training=True, n_steps=n_steps)

if gui:
post_analysis(trajs_data['obs'][0], trajs_data['action'][0], ctrl.env)

# Close environments
static_env.close()
static_train_env.close()
experiment = BaseExperiment(env=env, ctrl=ctrl)
trajs_data, metrics = experiment.run_evaluation(training=True, n_episodes=n_episodes, n_steps=n_steps)

# Merge in new trajectory data
for key, value in trajs_data.items():
all_trajs[key] += value
if plot:
for i in range(len(trajs_data['obs'])):
post_analysis(trajs_data['obs'][i], trajs_data['action'][i], ctrl.env)

ctrl.close()
random_env.close()
metrics = experiment.compute_metrics(all_trajs)
all_trajs = dict(all_trajs)
env.close()

if save_data:
results = {'trajs_data': all_trajs, 'metrics': metrics}
results = {'trajs_data': trajs_data, 'metrics': metrics}
path_dir = os.path.dirname('./temp-data/')
os.makedirs(path_dir, exist_ok=True)
with open(f'./temp-data/{config.algo}_data_{config.task}_{config.task_config.task}.pkl', 'wb') as file:
Expand Down Expand Up @@ -132,20 +108,5 @@ def post_analysis(state_stack, input_stack, env):
plt.show()


def wrap2pi_vec(angle_vec):
    '''Wraps a vector of angles into the half-open interval (-pi, pi].

    Each out-of-range entry is shifted by whole turns (multiples of 2*pi) so the
    wrapped value represents the same angle. Entries already in range are left
    untouched.

    Args:
        angle_vec (ndarray): A vector of angles in radians. Modified in place.

    Returns:
        angle_vec (ndarray): The same object that was passed in, with wrapped entries.
    '''
    # A full turn is 2*pi; shifting by only pi would map to a different angle.
    two_pi = 2 * np.pi
    for k, angle in enumerate(angle_vec):
        # Non-finite entries can never be brought into range; skip them
        # instead of looping forever (inf - 2*pi is still inf).
        if not np.isfinite(angle):
            continue
        while angle > np.pi:
            angle -= two_pi
        while angle <= -np.pi:
            angle += two_pi
        angle_vec[k] = angle
    return angle_vec


if __name__ == '__main__':
run()

0 comments on commit 3085883

Please sign in to comment.