Commit

added saving for info - saves as tensors...need to consider what is actually required
beardyFace committed Oct 6, 2023
1 parent 4ea4876 commit bd950ea
Showing 6 changed files with 13 additions and 11 deletions.
6 changes: 4 additions & 2 deletions cares_reinforcement_learning/util/Record.py
@@ -24,6 +24,7 @@ def __init__(self, glob_log_dir=None, log_dir=None, network=None, config=None) -

        self.train_data = pd.DataFrame()
        self.eval_data = pd.DataFrame()
+       self.info_data = pd.DataFrame()

        self.network = network

@@ -35,8 +36,9 @@ def __init__(self, glob_log_dir=None, log_dir=None, network=None, config=None) -
        with open(f'{self.directory}/config.yml', 'w') as outfile:
            yaml.dump(config, outfile, default_flow_style=False)

-   def log_info(self, display=False, **logs):
-       pass # TODO implement logger for info from training the network e.g. loss rates etc etc
+   def log_info(self, info, display=False):
+       self.info_data = pd.concat([self.info_data, pd.DataFrame([info])], ignore_index=True)
+       self.save_data(self.info_data, "info", info, display=display)

    def log_train(self, display=False, **logs):
        self.log_count += 1
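The new log_info appends whatever dictionary it is given as one row of info_data and hands the frame to save_data, so the caller decides which training statistics get recorded. A minimal, self-contained sketch of that append behaviour, with hypothetical loss keys and an illustrative tensor-to-float conversion (the commit message notes the values currently arrive as tensors):

```python
import pandas as pd
import torch

# Hypothetical training statistics, still tensor-valued as the commit message notes.
info = {"actor_loss": torch.tensor(0.42), "critic_loss": torch.tensor(1.37)}

# Reducing tensors to plain floats before logging keeps the logged frame (and any
# file written from it) numeric and lightweight; this conversion is not part of the commit.
scalar_info = {k: v.item() if torch.is_tensor(v) else v for k, v in info.items()}

# This mirrors what log_info does internally: one row appended per call.
info_data = pd.DataFrame()
info_data = pd.concat([info_data, pd.DataFrame([scalar_info])], ignore_index=True)
print(info_data)
```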
6 changes: 3 additions & 3 deletions cares_reinforcement_learning/util/plotter.py
@@ -13,12 +13,12 @@
def plot_data(plot_frame, title, label, x_label, y_label, directory, filename, display=True, close_figure=True):
    window_size = plot_frame["window_size"]

    # TODO make font size a parameter
    plt.xlabel(x_label, fontsize=10)
    plt.ylabel(y_label, fontsize=10)
    plt.title(title, fontsize=10)

    ax = sns.lineplot(data=plot_frame, x=plot_frame["steps"], y="avg", label=label)
    # ax.set(xlabel=x_label, ylabel=y_label)

    Z = 1.960 # 95% confidence interval
    confidence_interval = Z * plot_frame["std_dev"] / np.sqrt(window_size)
@@ -122,8 +122,8 @@ def main():
        train_plot_frames.append(train_plot_frame)
        eval_plot_frames.append(eval_plot_frame)

-   plot_comparisons(train_plot_frames, f"{title}", labels, "Steps", "Average Reward", directory, "compare-train", True)
-   plot_comparisons(eval_plot_frames, f"{title}", labels, "Steps", "Average Reward", directory, "compare-eval", True)
+   plot_comparisons(train_plot_frames, f"{title}", labels, "Steps", "Average Reward", directory, f"{title}-compare-train", True)
+   plot_comparisons(eval_plot_frames, f"{title}", labels, "Steps", "Average Reward", directory, f"{title}-compare-eval", True)

if __name__ == '__main__':
    main()
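For context, the shaded band drawn in plot_data is a normal-approximation 95% confidence interval around the windowed average: Z * std_dev / sqrt(window_size) with Z = 1.960. A small sketch of that calculation on made-up reward data (the rolling-window construction here is illustrative; plot_data receives the frame already prepared):

```python
import numpy as np
import pandas as pd

# Made-up episode rewards and smoothing window, purely for illustration.
rewards = pd.Series([10.0, 12.0, 9.0, 15.0, 14.0, 18.0, 17.0, 21.0, 20.0, 24.0])
window_size = 5

plot_frame = pd.DataFrame({
    "steps": np.arange(len(rewards)),
    "avg": rewards.rolling(window_size, min_periods=1).mean(),
    "std_dev": rewards.rolling(window_size, min_periods=1).std().fillna(0.0),
})

Z = 1.960  # z-score for a 95% confidence interval under a normal approximation
confidence_interval = Z * plot_frame["std_dev"] / np.sqrt(window_size)

lower = plot_frame["avg"] - confidence_interval
upper = plot_frame["avg"] + confidence_interval
print(pd.concat([plot_frame["steps"], lower.rename("lower"), plot_frame["avg"], upper.rename("upper")], axis=1))
```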
4 changes: 2 additions & 2 deletions example/example_training_loops.py
@@ -105,10 +105,10 @@ def main():

logging.info(f"Memory: {args['memory']}")

# Train the policy or value based approach
#create the record class - standardised results tracking
record = Record(network=agent, config={'args': args})
# Train the policy or value based approach
if args["algorithm"] == "PPO":
#create the record class
ppe.ppo_train(env, agent, record, args)
env = gym.make(env.spec.id, render_mode="human")
ppe.evaluate_ppo_network(env, agent, args)
4 changes: 2 additions & 2 deletions example/policy_example.py
@@ -95,8 +95,8 @@ def policy_based_train(env, agent, memory, record, args):
experience['next_state'],
experience['done']
))
- memory.update_priorities(experience['indices'], info)
- # TODO add saving info information from train_policy as seperate recording
+ memory.update_priorities(experience['indices'], info)
+ # record.log_info(info, display=False)

if (total_step_counter+1) % number_steps_per_evaluation == 0:
evaluate = True
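The record.log_info(info, display=False) calls stay commented out here and in the other example scripts, presumably because the info dict returned by train_policy still holds tensor values and it is not yet decided what is actually required. A hedged sketch of one way they could be enabled, using a hypothetical helper and illustrative key names:

```python
import torch

def summarise_info(info, keys=("actor_loss", "critic_loss")):
    """Keep only the requested keys and reduce scalar tensors to plain floats.

    Both the key names and the reduction are illustrative; the real contents
    of info depend on the agent's train_policy implementation.
    """
    summary = {}
    for key in keys:
        value = info.get(key)
        if torch.is_tensor(value):
            summary[key] = value.detach().cpu().item()
        elif value is not None:
            summary[key] = value
    return summary

# Inside the training loop, after info = agent.train_policy(...):
# record.log_info(summarise_info(info), display=False)
```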
2 changes: 1 addition & 1 deletion example/ppo_example.py
@@ -91,7 +91,7 @@ def ppo_train(env, agent, record, args):
experience['done'],
experience['log_prob']
))
- # TODO add saving info information from train_policy as seperate recording
+ # record.log_info(info, display=False)

if (total_step_counter+1) % number_steps_per_evaluation == 0:
evaluate = True
2 changes: 1 addition & 1 deletion example/value_example.py
@@ -100,7 +100,7 @@ def value_based_train(env, agent, memory, record, args):
experience['done']
))
memory.update_priorities(experience['indices'], info)
- # TODO add saving info information from train_policy as seperate recording
+ # record.log_info(info, display=False)

if (total_step_counter+1) % number_steps_per_evaluation == 0:
evaluate = True
