Dev/configurations generic #101

Merged (4 commits, Oct 26, 2023)
README.md (2 changes: 1 addition, 1 deletion)
@@ -28,7 +28,7 @@ An example is found below for running on the OpenAI and DMCS environments with TD3
 python example_training_loops.py run --gym openai --task HalfCheetah-v4 TD3
 
 
-python3 example_training_loops.py run dmcs --domain ball_in_cup --task catch TD3
+python3 example_training_loops.py run --gym dmcs --domain ball_in_cup --task catch TD3
 ```
 
 An example is found below for running using pre-defined configuration files
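
The README's config-file example is truncated here. For reference, the `config` command (see `_config()` in RLParser.py below) expects one JSON file per registered configuration and flattens them into a single argument dict. A minimal sketch of what those files and the invocation might look like; all file names, field values, and the TrainingConfig field shown are assumptions, not taken from this PR:

```python
# Hypothetical helper that writes the JSON files consumed by the `config`
# command; every value below is illustrative only.
import json

env_config = {"gym": "dmcs", "domain": "ball_in_cup", "task": "catch"}
training_config = {"seed": 42}            # assumed TrainingConfig field
algorithm_config = {"algorithm": "TD3"}   # parse_args() looks up f"{algorithm}Config"

for path, payload in [
    ("env_config.json", env_config),
    ("training_config.json", training_config),
    ("algorithm_config.json", algorithm_config),
]:
    with open(path, "w") as f:
        json.dump(payload, f)

# Assumed invocation, mirroring the _config() parser in RLParser.py below:
#   python3 example_training_loops.py config --env_config env_config.json \
#       --training_config training_config.json --algorithm_config algorithm_config.json
```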
cares_reinforcement_learning/util/EnvironmentFactory.py (12 changes: 6 additions, 6 deletions)
@@ -13,13 +13,13 @@
 # from typing import override
 from functools import cached_property
 
-from cares_reinforcement_learning.util.configurations import EnvironmentConfig
+from cares_reinforcement_learning.util.configurations import GymEnvironmentConfig
 
 class EnvironmentFactory:
     def __init__(self) -> None:
         pass
 
-    def create_environment(self, config: EnvironmentConfig):
+    def create_environment(self, config: GymEnvironmentConfig):
         logging.info(f"Training Environment: {config.gym}")
         if config.gym == 'dmcs':
             env = DMCSImage(config) if config.image_observation else DMCS(config)
@@ -30,7 +30,7 @@ def create_environment(self, config: EnvironmentConfig):
         return env
 
 class OpenAIGym:
-    def __init__(self, config: EnvironmentConfig) -> None:
+    def __init__(self, config: GymEnvironmentConfig) -> None:
         logging.info(f"Training task {config.task}")
         self.env = gym.make(config.task, render_mode="rgb_array")
 
@@ -74,7 +74,7 @@ def grab_frame(self, height=240, width=300):
         return frame
 
 class OpenAIGymImage(OpenAIGym):
-    def __init__(self, config: EnvironmentConfig, k=3):
+    def __init__(self, config: GymEnvironmentConfig, k=3):
         self.k = k  # number of frames to be stacked
         self.frames_stacked = deque([], maxlen=k)
 
@@ -109,7 +109,7 @@ def step(self, action):
         return stacked_frames, reward, done, False  # for consistency with OpenAI Gym, add False for truncated
 
 class DMCS:
-    def __init__(self, config: EnvironmentConfig) -> None:
+    def __init__(self, config: GymEnvironmentConfig) -> None:
         logging.info(f"Training on Domain {config.domain}")
         logging.info(f"Training with Task {config.task}")
 
@@ -155,7 +155,7 @@ def grab_frame(self, camera_id=0, height=240, width=300):
 
 # TODO parameterise the observation size 3x84x84
 class DMCSImage(DMCS):
-    def __init__(self, config: EnvironmentConfig, k=3):
+    def __init__(self, config: GymEnvironmentConfig, k=3):
         self.k = k  # number of frames to be stacked
         self.frames_stacked = deque([], maxlen=k)
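
A sketch of how the renamed config type flows through the factory above; the field values are illustrative, and the `image_observation` switch is inferred from `create_environment`:

```python
# Illustrative usage only; values are made up.
from cares_reinforcement_learning.util.EnvironmentFactory import EnvironmentFactory
from cares_reinforcement_learning.util.configurations import GymEnvironmentConfig

config = GymEnvironmentConfig(gym="dmcs", domain="ball_in_cup", task="catch")
env = EnvironmentFactory().create_environment(config)  # DMCS, since image_observation defaults to False

# With image_observation=True the factory returns the frame-stacking variant,
# which keeps the k most recent frames in a deque(maxlen=k) as shown above.
image_config = GymEnvironmentConfig(gym="dmcs", domain="ball_in_cup", task="catch",
                                    image_observation=True)
image_env = EnvironmentFactory().create_environment(image_config)  # DMCSImage
```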
cares_reinforcement_learning/util/RLParser.py (145 changes: 87 additions, 58 deletions)
@@ -8,7 +8,7 @@
 import logging
 
 import cares_reinforcement_learning.util.configurations as configurations
-from cares_reinforcement_learning.util.configurations import TrainingConfig, AlgorithmConfig, EnvironmentConfig
+from cares_reinforcement_learning.util.configurations import TrainingConfig, AlgorithmConfig, GymEnvironmentConfig, EnvironmentConfig, SubscriptableClass
 import json
 
 import pydantic
@@ -17,47 +17,54 @@
 import inspect
 from typing import get_origin
 
-def add_model(parser, model):
-    "Add Pydantic model to an ArgumentParser"
-    fields = model.__fields__
-    for name, field in fields.items():
-        nargs = '+' if get_origin(field.annotation) is list else None
-        parser.add_argument(
-            f"--{name}",
-            dest=name,
-            type=field.type_,
-            default=field.default,
-            help=field.field_info.description,
-            required=field.required,
-            nargs=nargs
-        )
-
-def get_algorithm_parser():
-    alg_parser = argparse.ArgumentParser()
-    alg_parsers = alg_parser.add_subparsers(help='Select which RL algorithm you want to use', dest='algorithm', required=True)
-
-    for name, cls in inspect.getmembers(configurations, inspect.isclass):
-        if issubclass(cls, AlgorithmConfig) and cls != AlgorithmConfig:
-            name = name.replace('Config', '')
-            cls_parser = alg_parsers.add_parser(name, help=name)
-            add_model(cls_parser, cls)
-
-    return alg_parser, alg_parsers
-
 class RLParser:
-    def __init__(self) -> None:
-        self.algorithm_parser, self.algorithm_parsers = get_algorithm_parser()
+    def __init__(self, EnvironmentConfig=GymEnvironmentConfig) -> None:
+        self.configurations = {}
+
+        self.algorithm_parser, self.algorithm_parsers = self._get_algorithm_parser()
+
+        self.algorithm_configurations = {}
+        for name, cls in inspect.getmembers(configurations, inspect.isclass):
+            if issubclass(cls, AlgorithmConfig) and cls != AlgorithmConfig:
+                self.algorithm_configurations[name] = cls
+
+        self.add_configuration("env_config", EnvironmentConfig)
+        self.add_configuration("training_config", TrainingConfig)
+
+    def add_model(self, parser, model):
+        fields = model.__fields__
+        for name, field in fields.items():
+            nargs = '+' if get_origin(field.annotation) is list else None
+            parser.add_argument(
+                f"--{name}",
+                dest=name,
+                type=field.type_,
+                default=field.default,
+                help=field.field_info.description,
+                required=field.required,
+                nargs=nargs
+            )
+
+    def _get_algorithm_parser(self):
+        alg_parser = argparse.ArgumentParser()
+        alg_parsers = alg_parser.add_subparsers(help='Select which RL algorithm you want to use', dest='algorithm', required=True)
 
-        self.algorithm_configs = {}
         for name, cls in inspect.getmembers(configurations, inspect.isclass):
             if issubclass(cls, AlgorithmConfig) and cls != AlgorithmConfig:
-                self.algorithm_configs[name] = cls
-
-    def add_algorithm(self, algorithm_model):
-        name = algorithm_model.__name__.replace('Config', '')
+                name = name.replace('Config', '')
+                cls_parser = alg_parsers.add_parser(name, help=name)
+                self.add_model(cls_parser, cls)
+
+        return alg_parser, alg_parsers
+
+    def add_algorithm_config(self, AlgorithmConfig):
+        name = AlgorithmConfig.__name__.replace('Config', '')
         parser = self.algorithm_parsers.add_parser(f"{name}", help=f"{name}")
-        add_model(parser, algorithm_model)
-        self.algorithm_configs[algorithm_model.__name__] = algorithm_model
+        self.add_model(parser, AlgorithmConfig)
+        self.algorithm_configurations[AlgorithmConfig.__name__] = AlgorithmConfig
+
+    def add_configuration(self, name, Configuration):
+        self.configurations[name] = Configuration
 
     def parse_args(self):
         parser = argparse.ArgumentParser(usage="<command> [<args>]")
Expand All @@ -73,46 +80,60 @@ def parse_args(self):
# use dispatch pattern to invoke method with same name
self.args = getattr(self, f"_{cmd_arg.command}")()
print(self.args)
env_config = EnvironmentConfig(**self.args)
training_config = TrainingConfig(**self.args)
algorithm_config = self.algorithm_configs[f"{self.args['algorithm']}Config"](**self.args)
return env_config, training_config, algorithm_config

configurations = {}

for name, Configuration in self.configurations.items():
configuration = Configuration(**self.args)
configurations[name] = configuration

algorithm_config = self.algorithm_configurations[f"{self.args['algorithm']}Config"](**self.args)
configurations['algorithm_config'] = algorithm_config

return configurations

def _config(self):
parser = argparse.ArgumentParser()
required = parser.add_argument_group('required arguments')
required.add_argument("--env_config", type=str, required=True, help='Configuration path for the environment')
required.add_argument("--training_config", type=str, required=True, help='Configuration path that defines the training parameters')
required.add_argument("--algorithm_config", type=str, required=True, help='Configuration path that defines the algorithm and its learning parameters')

for name, configuration in self.configurations.items():
required.add_argument(f"--{name}", required=True, help=f"Configuration path for {name}")

config_args = parser.parse_args(sys.argv[2:])

args = {}
with open(config_args.env_config) as f:
env_args = json.load(f)

with open(config_args.training_config) as f:
training_config = json.load(f)

with open(config_args.algorithm_config) as f:
algorithm_config = json.load(f)

args.update(env_args)
args.update(training_config)
args.update(algorithm_config)

config_args = vars(config_args)
for name, configuration in self.configurations.items():
with open(config_args[name]) as f:
config = json.load(f)
args.update(config)

return args

def _run(self):
parser = argparse.ArgumentParser()

add_model(parser, EnvironmentConfig)
add_model(parser, TrainingConfig)
for _, Configuration in self.configurations.items():
self.add_model(parser, Configuration)

firt_args, rest = parser.parse_known_args(sys.argv[2:])

alg_args, rest = self.algorithm_parser.parse_known_args(rest)

if len(rest) > 0:
logging.warn(f"Arugements not being passed properly and have been left over: {rest}")

args = Namespace(**vars(firt_args), **vars(alg_args))
return vars(args)

## Example of how to use the RLParser for custom environments - in this case the LAMO task
from pydantic import BaseModel, Field
from typing import List, Optional, Literal
class LMAOConfig(AlgorithmConfig):
@@ -124,10 +145,18 @@ class LMAOConfig(AlgorithmConfig):
     exploration_min: Optional[float] = 1e-3
     exploration_decay: Optional[float] = 0.95
 
+class LMAOEnvironmentConfig(EnvironmentConfig):
+    gym: str = Field("LMAO-Gym", const=True)
+    task: str
+    domain: Optional[str] = None
+    image_observation: Optional[bool] = False
+
+class LMAOHardwareConfig(SubscriptableClass):
+    value: str = "rofl-copter"
+
 if __name__ == '__main__':
-    parser = RLParser()
-    parser.add_algorithm(LMAOConfig)
-    env_config, training_config, algorithm_config = parser.parse_args()
-    print(env_config)
-    print(training_config)
-    print(algorithm_config)
+    parser = RLParser(LMAOEnvironmentConfig)
+    parser.add_configuration("lmao_config", LMAOHardwareConfig)
+    parser.add_algorithm_config(LMAOConfig)
+    configurations = parser.parse_args()
+    print(configurations)
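
To make the new return shape concrete: `parse_args()` now hands back one dict keyed by configuration name instead of a fixed tuple. A sketch for the LMAO example above; the command-line syntax and the field values are assumptions based on `_run()` and the README examples, not output captured from this PR:

```python
# Assumed invocation (config flags first, algorithm subcommand and its flags last):
#   python3 RLParser.py run --task catch --value rofl-copter LMAO --exploration_decay 0.9
#
# Assumed shape of the dict returned by parser.parse_args():
# {
#     "env_config": LMAOEnvironmentConfig(task="catch", ...),
#     "training_config": TrainingConfig(...),
#     "lmao_config": LMAOHardwareConfig(value="rofl-copter"),
#     "algorithm_config": LMAOConfig(exploration_decay=0.9, ...),
# }
```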
cares_reinforcement_learning/util/configurations.py (5 changes: 4 additions, 1 deletion)
@@ -12,8 +12,11 @@
 class SubscriptableClass(BaseModel):
     def __getitem__(self, item):
         return getattr(self, item)
 
 class EnvironmentConfig(SubscriptableClass):
+    task: str
+
+class GymEnvironmentConfig(EnvironmentConfig):
     gym: str = Field(description='Gym Environment <openai, dmcs>')
     task: str
     domain: Optional[str] = None
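
The split means `EnvironmentConfig` now carries only the field every environment shares (`task`), while gym-specific fields move to `GymEnvironmentConfig`. A small sketch of the dict-style access that `SubscriptableClass` provides, assuming the fields shown above:

```python
from cares_reinforcement_learning.util.configurations import GymEnvironmentConfig

config = GymEnvironmentConfig(gym="openai", task="HalfCheetah-v4")

# __getitem__ delegates to getattr, so dict-style and attribute access agree:
assert config["task"] == config.task == "HalfCheetah-v4"
```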
example/example_training_loops.py (8 changes: 6 additions, 2 deletions)
@@ -12,7 +12,7 @@
 from cares_reinforcement_learning.util import helpers as hlp
 
 import cares_reinforcement_learning.util.configurations as configurations
-from cares_reinforcement_learning.util.configurations import TrainingConfig, AlgorithmConfig, EnvironmentConfig
+from cares_reinforcement_learning.util.configurations import TrainingConfig, AlgorithmConfig, GymEnvironmentConfig
 
 import cares_reinforcement_learning.train_loops.policy_loop as pbe
 import cares_reinforcement_learning.train_loops.value_loop as vbe
@@ -30,7 +30,11 @@
 
 def main():
     parser = RLParser()
-    env_config, training_config, alg_config = parser.parse_args()
+
+    configurations = parser.parse_args()
+    env_config = configurations["env_config"]
+    training_config = configurations["training_config"]
+    alg_config = configurations["algorithm_config"]
 
     env_factory = EnvironmentFactory()
     network_factory = NetworkFactory()