diff --git a/cares_reinforcement_learning/util/EnvironmentFactory.py b/cares_reinforcement_learning/util/EnvironmentFactory.py
index d92a1944..f24a5336 100644
--- a/cares_reinforcement_learning/util/EnvironmentFactory.py
+++ b/cares_reinforcement_learning/util/EnvironmentFactory.py
@@ -108,7 +108,9 @@ def __init__(self, args) -> None:
         logging.info(f"Training on Domain {args['domain']}")
         logging.info(f"Training with Task {args['task']}")
 
-        self.env = suite.load(args['domain'], args['task'], task_kwargs={'random': args['seed']})
+        self.domain = args['domain']
+        self.task = args['task']
+        self.env = suite.load(self.domain, self.task, task_kwargs={'random': args['seed']})
 
     @cached_property
     def min_action_value(self):
@@ -129,7 +131,7 @@ def action_num(self):
         return self.env.action_spec().shape[0]
 
     def set_seed(self, seed):
-        self.env = suite.load(self.env.domain, self.env.task, task_kwargs={'random': seed})
+        self.env = suite.load(self.domain, self.task, task_kwargs={'random': seed})
 
     def reset(self):
         time_step = self.env.reset()
diff --git a/example/example_training_loops.py b/example/example_training_loops.py
index 99d1c2fe..573b08a8 100644
--- a/example/example_training_loops.py
+++ b/example/example_training_loops.py
@@ -21,6 +21,7 @@ import random
 import numpy as np
 from pathlib import Path
+from datetime import datetime
 
 
 def set_seed(seed):
     torch.manual_seed(seed)
@@ -62,16 +63,18 @@ def main():
     logging.info(f"Memory: {args['memory']}")
 
-    seed = args['seed']
+    iterations_folder = f"{args['algorithm']}-{args['task']}-{datetime.now().strftime('%y_%m_%d_%H:%M:%S')}"
+    glob_log_dir = f'{Path.home()}/cares_rl_logs/{iterations_folder}'
 
 
     training_iterations = args['number_training_iterations']
     for training_iteration in range(0, training_iterations):
-        logging.info(f"Training iteration {training_iteration+1}/{training_iterations} with Seed: {seed}")
-        set_seed(seed)
-        env.set_seed(seed)
+        logging.info(f"Training iteration {training_iteration+1}/{training_iterations} with Seed: {args['seed']}")
+        set_seed(args['seed'])
+        env.set_seed(args['seed'])
 
         #create the record class - standardised results tracking
-        record = Record(network=agent, config={'args': args})
+        log_dir = args['seed']
+        record = Record(glob_log_dir=glob_log_dir, log_dir=log_dir, network=agent, config={'args': args})
 
         # Train the policy or value based approach
         if args["algorithm"] == "PPO":
@@ -82,7 +85,7 @@
             vbe.value_based_train(env, agent, memory, record, args)
         else:
             raise ValueError(f"Agent type is unkown: {agent.type}")
-        seed += 10
+        args['seed'] += 10
 
 
     record.save()
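
Illustration, not part of the diff: a minimal runnable sketch of the log layout the example/example_training_loops.py changes produce. The path-building lines mirror the added code; the algorithm, task, and seed values below are hypothetical.

# Sketch only: how glob_log_dir and the per-seed log_dir compose result folders.
from datetime import datetime
from pathlib import Path

args = {'algorithm': 'TD3', 'task': 'walk', 'seed': 10, 'number_training_iterations': 3}

iterations_folder = f"{args['algorithm']}-{args['task']}-{datetime.now().strftime('%y_%m_%d_%H:%M:%S')}"
glob_log_dir = f'{Path.home()}/cares_rl_logs/{iterations_folder}'

for _ in range(args['number_training_iterations']):
    print(f"{glob_log_dir}/{args['seed']}")  # the log_dir each Record instance writes into
    args['seed'] += 10
# e.g. ~/cares_rl_logs/TD3-walk-24_05_01_12:00:00/10, .../20, .../30 --
# the timestamped parent folder keeps repeated seeds from separate runs apart.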