-
Notifications
You must be signed in to change notification settings - Fork 102
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Closes: #75. Disable gravity to match thesis results. Fix bug where RNG seed was ignored.
- Loading branch information
Showing
15 changed files
with
687 additions
and
190 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
|
||
from gymfc_nf.policies.pidpolicy import PidPolicy | ||
__all__ = ['PidPolicy'] | ||
from gymfc_nf.policies.baselinespolicy import PpoBaselinesPolicy | ||
__all__ = ['PidPolicy', 'PpoBaselinesPolicy'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
import numpy as np | ||
import tensorflow as tf | ||
from .policy import Policy | ||
class PpoBaselinesPolicy(Policy): | ||
def __init__(self, sess): | ||
graph = tf.get_default_graph() | ||
self.x = graph.get_tensor_by_name('pi/ob:0') | ||
self.y = graph.get_tensor_by_name('pi/pol/final/BiasAdd:0') | ||
self.sess = sess | ||
|
||
def action(self, state, sim_time=0, desired=np.zeros(3), actual=np.zeros(3) ): | ||
|
||
y_out = self.sess.run(self.y, feed_dict={self.x:[state] }) | ||
return y_out[0] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
|
||
def make_header(ob_size): | ||
"""Make the log header. | ||
This needs to be done dynamically because the observations used as input | ||
to the NN may differ. | ||
""" | ||
entries = [] | ||
entries.append("t") | ||
for i in range(ob_size): | ||
entries.append("ob{}".format(i)) | ||
for i in range(4): | ||
entries.append("ac{}".format(i)) | ||
entries.append("p") # roll rate | ||
entries.append("q") # pitch rate | ||
entries.append("r") # yaw rate | ||
entries.append("p-sp") # roll rate setpoint | ||
entries.append("q-sp") # pitch rate setpoint | ||
entries.append("r-sp") # yaw rate setpoint | ||
for i in range(4): | ||
entries.append("y{}".format(i)) | ||
for i in range(4): | ||
entries.append("w{}".format(i)) # ESC rpms | ||
entries.append("reward") | ||
|
||
return ",".join(entries) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import tensorflow as tf | ||
import os.path | ||
import time | ||
|
||
|
||
class CheckpointMonitor: | ||
"""Helper class to monitor the Tensorflow checkpoints and call a callback | ||
when a new checkpoint has been created.""" | ||
|
||
def __init__(self, checkpoint_dir, callback): | ||
""" | ||
Args: | ||
checkpoint_dir: Directory to monitor where new checkpoint | ||
directories will be created | ||
callback: A callback for when a new checkpoint is created. | ||
""" | ||
self.checkpoint_dir = checkpoint_dir | ||
self.callback = callback | ||
# Track which checkpoints have already been called. | ||
self.processed = [] | ||
|
||
self.watching = True | ||
|
||
def _check_new_checkpoint(self): | ||
"""Update the queue with newly found checkpoints. | ||
When a checkpoint directory is created a 'checkpoint' file is created | ||
containing a list of all the checkpoints. We can monitor this file to | ||
determine when new checkpoints have been created. | ||
""" | ||
# TODO (wfk) check if there is a way to get a callback when a file has | ||
# changed. | ||
|
||
ckpt = tf.train.get_checkpoint_state(self.checkpoint_dir) | ||
for path in ckpt.all_model_checkpoint_paths: | ||
checkpoint_filename = os.path.split(path)[-1] | ||
if tf.train.checkpoint_exists(path): | ||
# Make sure there is a checkpoint meta file before allowing it | ||
# to be processed | ||
meta_file = path + ".meta" | ||
if os.path.isfile(meta_file): | ||
if (checkpoint_filename not in self.processed): | ||
self.callback(path) | ||
self.processed.append(checkpoint_filename) | ||
else: | ||
print ("Meta file {} doesn't exist.".format(meta_file)) | ||
|
||
def start(self): | ||
|
||
# Sit and wait until the checkpoint directory is created, otherwise we | ||
# can't monitor it. If it never gets created this could be an indicator | ||
# something is wrong with the trainer. | ||
c=0 | ||
while not os.path.isdir(self.checkpoint_dir): | ||
print("[WARN {}] Directory {} doesn't exist yet, waiting until " | ||
"created...".format(c, self.checkpoint_dir)) | ||
time.sleep(30) | ||
c+=1 | ||
|
||
while self.watching: | ||
self._check_new_checkpoint() | ||
time.sleep(10) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import argparse | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
|
||
from gymfc.tools.plot import * | ||
|
||
if __name__ == '__main__': | ||
parser = argparse.ArgumentParser("Plot recorded flight data.") | ||
parser.add_argument("log_file", help="Log file.") | ||
parser.add_argument("--title", help="Title for the plot.", | ||
default="Aircraft Response") | ||
args = parser.parse_args() | ||
|
||
fdata = np.loadtxt(args.log_file, delimiter=",") | ||
|
||
# Plot the response | ||
f, ax = plt.subplots(5, sharex=True, sharey=False) | ||
plt.suptitle(args.title) | ||
plt.setp([a.get_xticklabels() for a in f.axes[:-1]], visible=False) | ||
t = fdata[:, 0] | ||
pqr = fdata[:, 11:14] | ||
pqr_sp = fdata[:, 14:17] | ||
plot_rates(ax[:3], t, pqr_sp, pqr) | ||
|
||
us = fdata[:, 17:21] | ||
plot_u(ax[3], t, us) | ||
|
||
rpms = fdata[:, 21:25] | ||
plot_motor_rpms(ax[4], t, rpms) | ||
|
||
ax[-1].set_xlabel("Time (s)") | ||
plt.show() | ||
|
Oops, something went wrong.