-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #82 from UoA-CARES/dev/dmcs
Dev/dmcs
- Loading branch information
Showing
10 changed files
with
448 additions
and
127 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
180 changes: 180 additions & 0 deletions
180
cares_reinforcement_learning/util/EnvironmentFactory.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,180 @@ | ||
import logging | ||
|
||
import cv2 | ||
|
||
import gym | ||
from gym import spaces | ||
|
||
from dm_control import suite | ||
|
||
import numpy as np | ||
from collections import deque | ||
|
||
# from typing import override | ||
from functools import cached_property | ||
|
||
class EnvironmentFactory:
    """Build the environment wrapper that matches a requested backend."""

    def __init__(self) -> None:
        pass

    def create_environment(self, gym_environment, args):
        """Instantiate the wrapper for ``gym_environment``.

        Args:
            gym_environment: Backend identifier; ``'dmcs'`` and ``'openai'``
                are the recognised values.
            args: Mapping forwarded unchanged to the wrapper constructor. For
                ``'dmcs'`` its ``'image_observation'`` entry selects
                ``DMCSImage`` (truthy) over ``DMCS`` (falsy).

        Returns:
            The constructed environment wrapper.

        Raises:
            ValueError: If ``gym_environment`` is not a recognised identifier.
        """
        logging.info(f"Training Environment: {gym_environment}")
        if gym_environment == 'dmcs':
            env = DMCSImage(args=args) if args['image_observation'] else DMCS(args=args)
        elif gym_environment == "openai":
            env = OpenAIGym(args=args)
        else:
            raise ValueError(f"Unknown environment: {gym_environment}")
        return env
|
||
class OpenAIGym:
    """Thin wrapper around an environment created with ``gym.make``."""

    def __init__(self, args) -> None:
        """Create the environment named by ``args['task']`` and seed it with ``args['seed']``."""
        logging.info(f"Training task {args['task']}")
        self.env = gym.make(args["task"], render_mode="rgb_array")
        self.set_seed(args['seed'])

    @cached_property
    def max_action_value(self):
        """First entry of the action space's ``high`` bound."""
        return self.env.action_space.high[0]

    @cached_property
    def min_action_value(self):
        """First entry of the action space's ``low`` bound."""
        return self.env.action_space.low[0]

    @cached_property
    def observation_space(self):
        """Size of the first dimension of the observation space's shape."""
        return self.env.observation_space.shape[0]

    @cached_property
    def action_num(self):
        """Number of action dimensions (``Box``) or discrete choices (``Discrete``).

        Raises:
            ValueError: If the action space is neither ``Box`` nor ``Discrete``.
        """
        space = self.env.action_space
        if isinstance(space, spaces.Box):
            return space.shape[0]
        if isinstance(space, spaces.Discrete):
            return space.n
        raise ValueError(f"Unhandled action space type: {type(space)}")

    def set_seed(self, seed):
        """Seed both the environment and its action-space sampler.

        Note that this resets the environment: with the reset/step API used
        in this class, the environment's own random generator is seeded
        through ``reset(seed=...)``. Seeding only the action space leaves the
        environment dynamics unseeded.
        """
        self.env.reset(seed=seed)
        self.env.action_space.seed(seed)

    def reset(self):
        """Reset the environment and return the initial state (info is dropped)."""
        state, _ = self.env.reset()
        return state

    def step(self, action):
        """Apply ``action`` and return ``(state, reward, done, truncated)`` (info is dropped)."""
        state, reward, done, truncated, _ = self.env.step(action)
        return state, reward, done, truncated

    def grab_frame(self):
        """Render the current frame and return it with its channel order swapped for OpenCV."""
        frame = self.env.render()
        # The env is created with render_mode="rgb_array", so the frame is
        # expected to be RGB; swap to BGR for use with OpenCV.
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        return frame
|
||
class OpenAIGymImage(OpenAIGym):
    """OpenAI Gym wrapper whose observations are stacks of the last ``k`` rendered frames.

    Each frame is resized to 84x84 and stored channel-first; the ``k`` frames
    are concatenated along the channel axis.
    """

    def __init__(self, args, k=3):
        """Set up the frame buffer, then let ``OpenAIGym`` create and seed the environment.

        Args:
            args: Mapping forwarded to ``OpenAIGym.__init__``.
            k: Number of frames to be stacked.
        """
        self.k = k  # number of frames to be stacked
        self.frames_stacked = deque([], maxlen=k)

        super().__init__(args=args)

    # @override
    @property
    def observation_space(self):
        raise NotImplementedError("Not Implemented Yet")

    def _render_frame(self):
        """Render the current frame as a channel-first 84x84 array."""
        # Assumes render() yields an (H, W, 3) array because the parent
        # creates the env with render_mode="rgb_array".
        frame = self.env.render()
        frame = cv2.resize(frame, (84, 84), interpolation=cv2.INTER_AREA)  # --> shape= (84, 84, 3)
        return np.moveaxis(frame, -1, 0)  # --> shape= (3, 84, 84)

    def _stacked_observation(self):
        """Concatenate the buffered frames along the channel axis."""
        return np.concatenate(list(self.frames_stacked), axis=0)  # --> shape= (3 * k, 84, 84)

    # @override
    def reset(self):
        """Reset the environment and return a stack made of ``k`` copies of the first frame."""
        _ = self.env.reset()
        frame = self._render_frame()
        for _ in range(self.k):
            self.frames_stacked.append(frame)
        return self._stacked_observation()

    # @override
    def step(self, action):
        """Apply ``action`` and return ``(stacked_frames, reward, done, truncated)``."""
        # The vector state and info returned by the env are discarded; the
        # observation is built from rendered frames instead.
        _, reward, done, truncated, _ = self.env.step(action)
        self.frames_stacked.append(self._render_frame())
        return self._stacked_observation(), reward, done, truncated
|
||
class DMCS:
    """Wrapper around a DeepMind Control Suite environment loaded with ``suite.load``."""

    def __init__(self, args) -> None:
        """Load the environment for ``args['domain']`` / ``args['task']`` seeded with ``args['seed']``."""
        logging.info(f"Training on Domain {args['domain']}")
        logging.info(f"Training with Task {args['task']}")

        # Keep the names around: set_seed reloads the suite and needs the
        # original strings, which the loaded environment object does not
        # expose as ``env.domain`` / ``env.task`` names.
        self._domain_name = args['domain']
        self._task_name = args['task']

        self.env = suite.load(self._domain_name, self._task_name, task_kwargs={'random': args['seed']})

    @staticmethod
    def _flatten_observation(time_step):
        """Stack every entry of ``time_step.observation`` into one flat array."""
        # e.g. position, orientation, joint_angles
        return np.hstack(list(time_step.observation.values()))

    @cached_property
    def min_action_value(self):
        """First entry of the action spec's ``minimum`` bound."""
        return self.env.action_spec().minimum[0]

    @cached_property
    def max_action_value(self):
        """First entry of the action spec's ``maximum`` bound."""
        return self.env.action_spec().maximum[0]

    @cached_property
    def observation_space(self):
        """Length of the flattened observation vector.

        The value is measured from a real observation, so the first access
        resets the environment; later accesses use the cached value.
        """
        time_step = self.env.reset()
        return len(self._flatten_observation(time_step))

    @cached_property
    def action_num(self):
        """Size of the first dimension of the action spec's shape."""
        return self.env.action_spec().shape[0]

    def set_seed(self, seed):
        """Reload the same domain/task with ``seed`` as the task's random seed."""
        self.env = suite.load(self._domain_name, self._task_name, task_kwargs={'random': seed})

    def reset(self):
        """Reset the environment and return the flattened initial observation."""
        time_step = self.env.reset()
        return self._flatten_observation(time_step)

    def step(self, action):
        """Apply ``action`` and return ``(state, reward, done, truncated)``."""
        time_step = self.env.step(action)
        state = self._flatten_observation(time_step)
        reward, done = time_step.reward, time_step.last()
        return state, reward, done, False  # for consistency with open ai gym just add false for truncated

    def grab_frame(self):
        """Render camera 0 at 300x240 and return the frame with its channel order swapped for OpenCV."""
        frame = self.env.physics.render(camera_id=0, height=240, width=300)
        # Assumes the renderer returns RGB; swap to BGR for use with OpenCV.
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        return frame
|
||
class DMCSImage(DMCS):
    """DeepMind Control Suite wrapper whose observations are stacks of the last ``k`` rendered frames.

    Frames are rendered from camera 0 at ``frame_height`` x ``frame_width``
    (84x84 by default), stored channel-first and concatenated along the
    channel axis.
    """

    def __init__(self, args, k=3, frame_width=84, frame_height=84):
        """Set up the frame buffer, then let ``DMCS`` load the environment.

        Args:
            args: Mapping forwarded to ``DMCS.__init__``.
            k: Number of frames to be stacked.
            frame_width: Width in pixels of each rendered frame.
            frame_height: Height in pixels of each rendered frame.
        """
        self.k = k  # number of frames to be stacked
        self.frames_stacked = deque([], maxlen=k)
        self.frame_width = frame_width
        self.frame_height = frame_height

        super().__init__(args=args)

    # @override
    @property
    def observation_space(self):
        raise NotImplementedError("Not Implemented Yet")

    def _render_frame(self):
        """Render camera 0 and return the frame channel-first."""
        frame = self.env.physics.render(
            height=self.frame_height, width=self.frame_width, camera_id=0
        )  # --> shape= (height, width, 3)
        return np.moveaxis(frame, -1, 0)  # --> shape= (3, height, width)

    def _stacked_observation(self):
        """Concatenate the buffered frames along the channel axis."""
        return np.concatenate(list(self.frames_stacked), axis=0)  # --> shape= (3 * k, height, width)

    # @override
    def reset(self):
        """Reset the environment and return a stack made of ``k`` copies of the first frame."""
        _ = self.env.reset()
        frame = self._render_frame()
        for _ in range(self.k):
            self.frames_stacked.append(frame)
        return self._stacked_observation()

    # @override
    def step(self, action):
        """Apply ``action`` and return ``(stacked_frames, reward, done, truncated)``."""
        time_step = self.env.step(action)
        reward, done = time_step.reward, time_step.last()
        self.frames_stacked.append(self._render_frame())
        return self._stacked_observation(), reward, done, False  # for consistency with open ai gym just add false for truncated
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from .NetworkFactory import NetworkFactory | ||
from .Record import Record | ||
from .EnvironmentFactory import EnvironmentFactory | ||
|
Oops, something went wrong.