
Commit

Fix retrieval of observation and action spaces when using wrappers (#129)

* Fix use of unwrapped observation and action spaces

* Clean ModelEnv termination function
LucasAlegre authored Dec 4, 2024
1 parent b39f316 commit 92572aa
Showing 3 changed files with 19 additions and 20 deletions.
morl_baselines/common/model_based/utils.py (14 changes: 6 additions & 8 deletions)
@@ -86,20 +86,18 @@ def __init__(self, model, env_id=None, rew_dim=1):
"""
self.model = model
self.rew_dim = rew_dim
if env_id == "Hopper-v2" or env_id == "Hopper-v4" or env_id == "mo-hopper-v4" or env_id == "mo-hopper-2d-v4":
if "hopper" in env_id:
self.termination_func = termination_fn_hopper
elif env_id == "HalfCheetah-v2" or env_id == "mo-halfcheetah-v4":
elif "halfcheetah" in env_id:
self.termination_func = termination_fn_false
elif env_id == "LunarLanderContinuous-v2" or env_id.startswith("mo-lunar-lander"):
elif "lunar-lander" in env_id:
self.termination_func = termination_fn_false
elif env_id == "ReacherMultiTask-v0" or env_id.startswith("mo-reacher-v"):
elif "mo-reacher" in env_id:
self.termination_func = termination_fn_false
elif env_id == "MountainCarContinuous-v0" or env_id.startswith("mo-mountaincar"):
elif "mountaincar" in env_id:
self.termination_func = termination_fn_mountaincar
elif env_id == "minecart-v0":
elif "minecart" in env_id:
self.termination_func = termination_fn_minecart
elif env_id == "SEIRsingle-v0":
self.termination_func = termination_fn_false
elif env_id == "mo-highway-fast-v0" or env_id == "mo-highway-v0":
self.termination_func = termination_fn_false
elif env_id == "deep-sea-treasure-v0":
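The cleanup above replaces the exact env-id matching with substring checks. A small self-contained check (illustrative only; the `old_ids` set and the `mo-hopper-v5` id are examples, the latter appearing in the updated tests below) shows one consequence: later registrations of the same task are matched without editing this dispatch again.

```python
# Illustrative only: the old exact-match set vs. the new substring check.
old_ids = {"Hopper-v2", "Hopper-v4", "mo-hopper-v4", "mo-hopper-2d-v4"}
env_id = "mo-hopper-v5"

print(env_id in old_ids)   # False: the old exact-match chain would not recognize this id
print("hopper" in env_id)  # True:  the new substring check still selects termination_fn_hopper
```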
morl_baselines/common/morl_algorithm.py (17 changes: 9 additions & 8 deletions)
@@ -211,18 +211,19 @@ def extract_env_info(self, env: Optional[gym.Env]) -> None:
             self.env = env
             if isinstance(self.env.observation_space, spaces.Discrete):
                 self.observation_shape = (1,)
-                self.observation_dim = self.env.unwrapped.observation_space.n
+                self.observation_dim = self.env.observation_space.n
             else:
-                self.observation_shape = self.env.unwrapped.observation_space.shape
-                self.observation_dim = self.env.unwrapped.observation_space.shape[0]
+                self.observation_shape = self.env.observation_space.shape
+                self.observation_dim = self.env.observation_space.shape[0]

-            self.action_space = env.unwrapped.action_space
-            if isinstance(self.env.unwrapped.action_space, (spaces.Discrete, spaces.MultiBinary)):
+            self.action_space = env.action_space
+            if isinstance(self.env.action_space, (spaces.Discrete, spaces.MultiBinary)):
                 self.action_shape = (1,)
-                self.action_dim = self.env.unwrapped.action_space.n
+                self.action_dim = self.env.action_space.n
             else:
-                self.action_shape = self.env.unwrapped.action_space.shape
-                self.action_dim = self.env.unwrapped.action_space.shape[0]
+                self.action_shape = self.env.action_space.shape
+                self.action_dim = self.env.action_space.shape[0]

             self.reward_dim = self.env.unwrapped.reward_space.shape[0]

     @abstractmethod
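The change above is the core of the fix: a wrapper can redefine `observation_space` or `action_space`, and reading them through `env.unwrapped` bypasses the wrapper entirely. A minimal sketch of the mismatch, assuming a toy `ScaledObs` wrapper written here only for illustration (not part of the repository):

```python
import gymnasium as gym
import numpy as np


class ScaledObs(gym.ObservationWrapper):
    """Toy wrapper that rescales observations to [-1, 1] and redefines the space accordingly."""

    def __init__(self, env):
        super().__init__(env)
        self._low = env.observation_space.low
        self._high = env.observation_space.high
        # The wrapper advertises a different observation space than the base env.
        self.observation_space = gym.spaces.Box(
            low=-1.0, high=1.0, shape=env.observation_space.shape, dtype=np.float32
        )

    def observation(self, obs):
        return 2.0 * (obs - self._low) / (self._high - self._low) - 1.0


env = ScaledObs(gym.make("MountainCarContinuous-v0"))
print(env.observation_space)            # Box(-1.0, 1.0, (2,), float32): what the agent actually receives
print(env.unwrapped.observation_space)  # the base env's original Box: the wrapper is bypassed
```

With the diff applied, `observation_shape`, `observation_dim`, `action_shape`, and `action_dim` are taken from the wrapped env, so they match what the environment actually returns to the agent.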
tests/test_algos.py (8 changes: 4 additions & 4 deletions)
@@ -196,8 +196,8 @@ def test_gpi_pd():


 def test_gpi_pd_continuous_action():
-    env = mo_gym.make("mo-hopper-v4", cost_objective=False, max_episode_steps=500)
-    eval_env = mo_gym.make("mo-hopper-v4", cost_objective=False, max_episode_steps=500)
+    env = mo_gym.make("mo-hopper-v5", cost_objective=False, max_episode_steps=500)
+    eval_env = mo_gym.make("mo-hopper-v5", cost_objective=False, max_episode_steps=500)

     agent = GPIPDContinuousAction(
         env,
@@ -278,8 +278,8 @@ def test_pcn():


 def test_capql():
-    env = mo_gym.make("mo-hopper-v4", cost_objective=False, max_episode_steps=500)
-    eval_env = mo_gym.make("mo-hopper-v4", cost_objective=False, max_episode_steps=500)
+    env = mo_gym.make("mo-hopper-v5", cost_objective=False, max_episode_steps=500)
+    eval_env = mo_gym.make("mo-hopper-v5", cost_objective=False, max_episode_steps=500)

     agent = CAPQL(
         env,
