-
Notifications
You must be signed in to change notification settings - Fork 102
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Closes: #75
- Loading branch information
Showing
5 changed files
with
205 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
|
||
from gymfc_nf.policies.pidpolicy import PidPolicy | ||
__all__ = ['PidPolicy'] | ||
from gymfc_nf.policies.baselinespolicy import PpoBaselinesPolicy | ||
__all__ = ['PidPolicy', 'PpoBaselinesPolicy'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
import numpy as np | ||
import tensorflow as tf | ||
from .policy import Policy | ||
class PpoBaselinesPolicy(Policy): | ||
def __init__(self, sess): | ||
graph = tf.get_default_graph() | ||
self.x = graph.get_tensor_by_name('pi/ob:0') | ||
self.y = graph.get_tensor_by_name('pi/pol/final/BiasAdd:0') | ||
self.sess = sess | ||
|
||
def action(self, state, sim_time=0, desired=np.zeros(3), actual=np.zeros(3) ): | ||
|
||
y_out = self.sess.run(self.y, feed_dict={self.x:[state] }) | ||
return y_out[0] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import tensorflow as tf | ||
import os.path | ||
import time | ||
|
||
|
||
class CheckpointMonitor: | ||
"""Helper class to monitor the Tensorflow checkpoints and call a callback | ||
when a new checkpoint has been created.""" | ||
|
||
def __init__(self, checkpoint_dir, callback): | ||
""" | ||
Args: | ||
checkpoint_dir: Directory to monitor where new checkpoint | ||
directories will be created | ||
callback: A callback for when a new checkpoint is created. | ||
""" | ||
self.checkpoint_dir = checkpoint_dir | ||
self.callback = callback | ||
# Track which checkpoints have already been called. | ||
self.processed = [] | ||
|
||
self.watching = True | ||
|
||
def _check_new_checkpoint(self): | ||
"""Update the queue with newly found checkpoints. | ||
When a checkpoint directory is created a 'checkpoint' file is created | ||
containing a list of all the checkpoints. We can monitor this file to | ||
determine when new checkpoints have been created. | ||
""" | ||
# TODO (wfk) check if there is a way to get a callback when a file has | ||
# changed. | ||
|
||
ckpt = tf.train.get_checkpoint_state(self.checkpoint_dir) | ||
for path in ckpt.all_model_checkpoint_paths: | ||
checkpoint_filename = os.path.split(path)[-1] | ||
if tf.train.checkpoint_exists(path): | ||
# Make sure there is a checkpoint meta file before allowing it | ||
# to be processed | ||
meta_file = path + ".meta" | ||
if os.path.isfile(meta_file): | ||
if (checkpoint_filename not in self.processed): | ||
self.callback(checkpoint_filename) | ||
self.processed.append(checkpoint_filename) | ||
else: | ||
print ("Meta file {} doesn't exist.".format(meta_file)) | ||
|
||
def start(self): | ||
|
||
# Sit and wait until the checkpoint directory is created, otherwise we | ||
# can't monitor it. If it never gets created this could be an indicator | ||
# something is wrong with the trainer. | ||
c=0 | ||
while not os.path.isdir(self.checkpoint_dir): | ||
print("[WARN {}] Directory {} doesn't exist yet, waiting until " | ||
"created...".format(c, self.checkpoint_dir)) | ||
time.sleep(30) | ||
c+=1 | ||
|
||
while self.watching: | ||
self._check_new_checkpoint() | ||
time.sleep(10) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
import argparse | ||
from pathlib import Path | ||
import os.path | ||
import numpy as np | ||
import tensorflow as tf | ||
import gym | ||
from gymfc_nf.envs import * | ||
from gymfc_nf.utils.monitor import CheckpointMonitor | ||
from gymfc_nf.policies import PpoBaselinesPolicy | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser("Evaluate OpenAI Baseline PPO1 checkpoints.") | ||
parser.add_argument('ckpt_dir', help="Directory where checkpoints are saved. ") | ||
parser.add_argument('--twin', default="./gymfc_nf/twins/nf1/model.sdf", | ||
help="File path of the aircraft digitial twin/model SDF.") | ||
parser.add_argument('--eval-dir', | ||
help="Directory where evaluation logs are saved.") | ||
parser.add_argument('--gym-id', default="gymfc_nf-step-v1") | ||
parser.add_argument('--num-trials', type=int, default=1) | ||
# Provide a seed so the same setpoint will be created. Useful for debugging | ||
parser.add_argument('--seed', help='RNG seed', type=int, default=-1) | ||
|
||
args = parser.parse_args() | ||
|
||
seed = np.random.randint(0, 1e6) if args.seed < 0 else args.seed | ||
gym_id = args.gym_id | ||
ckpt_dir = args.ckpt_dir | ||
model_dir = Path(ckpt_dir).parent | ||
eval_dir = args.eval_dir if args.eval_dir else os.path.join(model_dir, | ||
"evaluations") | ||
num_trials = args.num_trials | ||
print ("Saving evaluations to {}".format(eval_dir)) | ||
|
||
env = gym.make(gym_id) | ||
env.seed(seed) | ||
env.set_aircraft_model(args.twin) | ||
|
||
log_header = "" | ||
def make_header(ob_size): | ||
"""Make the log header. | ||
This needs to be done dynamically because the observations which are | ||
used as input to the NN may differ. | ||
""" | ||
entries = [] | ||
entries.append("t") | ||
for i in range(ob_size): | ||
entries.append("ob{}".format(i)) | ||
for i in range(4): | ||
entries.append("ac{}".format(i)) | ||
for i in range(4): | ||
entries.append("y{}".format(i)) | ||
entries.append("p") # roll rate | ||
entries.append("q") # pitch rate | ||
entries.append("r") # yaw rate | ||
entries.append("p-sp") # roll rate setpoint | ||
entries.append("q-sp") # pitch rate setpoint | ||
entries.append("r-sp") # yaw rate setpoint | ||
for i in range(4): | ||
entries.append("w{}".format(i)) # ESC rpms | ||
entries.append("reward") | ||
|
||
log_header = ",".join(entries) | ||
|
||
def callback(checkpoint): | ||
print ("Callback ", checkpoint) | ||
|
||
ckpt_eval_dir = os.path.join(eval_dir, checkpoint) | ||
Path(ckpt_eval_dir).mkdir(parents=True, exist_ok=True) | ||
|
||
# TODO (wfk) I'm pretty sure this just takes the last checkpoint | ||
# written defined by 'model_checkpoint_path' in the checkpoint file | ||
# should look at how to specify the exact one. | ||
checkpoint = tf.train.get_checkpoint_state(ckpt_dir) | ||
input_checkpoint = checkpoint.model_checkpoint_path | ||
print ("Using checkpoint=", input_checkpoint) | ||
with tf.Session() as sess: | ||
saver = tf.train.import_meta_graph(input_checkpoint + '.meta', | ||
clear_devices=True) | ||
saver.restore(sess, input_checkpoint) | ||
pi = PpoBaselinesPolicy(sess) | ||
|
||
|
||
for i in range(num_trials): | ||
|
||
pi.reset() | ||
ob = env.reset() | ||
if len(log_header) == 0: | ||
make_header(len(ob)) | ||
|
||
log_file = os.path.join(ckpt_eval_dir, "trial-{}.csv".format(i)) | ||
|
||
sim_time = 0 | ||
actual = np.zeros(3) | ||
|
||
logs = [] | ||
while True: | ||
ac = pi.action(ob, env.sim_time, env.angular_rate_sp, | ||
env.imu_angular_velocity_rpy) | ||
ob, reward, done, _ = env.step(ac) | ||
|
||
log = ([env.sim_time] + | ||
ob.tolist() + # The observations are the NN input | ||
ac.tolist() + # The actions are the NN output | ||
env.y.tolist() + # Y is the output sent to the ESC | ||
|
||
env.imu_angular_velocity_rpy.tolist() + # Angular velocites | ||
env.angular_rate_sp.tolist() + # | ||
env.esc_motor_angular_velocity.tolist() + | ||
[reward])# The reward that would have been given for the action, can be helpful for debugging | ||
|
||
logs.append(log) | ||
|
||
if done: | ||
break | ||
np.savetxt(log_file, logs, delimiter=",", header=log_header) | ||
|
||
monitor = CheckpointMonitor(args.ckpt_dir, callback) | ||
monitor.start() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters