diff --git a/stoix/configs/env/navix/door_key_8x8.yaml b/stoix/configs/env/navix/door_key_8x8.yaml
new file mode 100644
index 00000000..67f30d2b
--- /dev/null
+++ b/stoix/configs/env/navix/door_key_8x8.yaml
@@ -0,0 +1,15 @@
+# ---Environment Configs---
+env_name: navix
+scenario:
+  name: Navix-DoorKey-8x8-v0
+  task_name: navix-door_key-8x8-v0
+
+kwargs: {}
+
+# Defines the metric that will be used to evaluate the performance of the agent.
+# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
+eval_metric: episode_return
+
+# Optional wrapper to flatten the observation space.
+wrapper:
+  _target_: stoix.wrappers.transforms.FlattenObservationWrapper
diff --git a/stoix/configs/env/navix/empty_5x5.yaml b/stoix/configs/env/navix/empty_5x5.yaml
new file mode 100644
index 00000000..91f47be4
--- /dev/null
+++ b/stoix/configs/env/navix/empty_5x5.yaml
@@ -0,0 +1,15 @@
+# ---Environment Configs---
+env_name: navix
+scenario:
+  name: Navix-Empty-5x5-v0
+  task_name: navix-empty-5x5-v0
+
+kwargs: {}
+
+# Defines the metric that will be used to evaluate the performance of the agent.
+# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
+eval_metric: episode_return
+
+# Optional wrapper to flatten the observation space.
+wrapper:
+  _target_: stoix.wrappers.transforms.FlattenObservationWrapper
diff --git a/stoix/configs/env/minigrid/minigrid_doorkey_5x5.yaml b/stoix/configs/env/xland_minigrid/doorkey_5x5.yaml
similarity index 94%
rename from stoix/configs/env/minigrid/minigrid_doorkey_5x5.yaml
rename to stoix/configs/env/xland_minigrid/doorkey_5x5.yaml
index 6c848151..af2e555b 100644
--- a/stoix/configs/env/minigrid/minigrid_doorkey_5x5.yaml
+++ b/stoix/configs/env/xland_minigrid/doorkey_5x5.yaml
@@ -1,5 +1,5 @@
 # ---Environment Configs---
-env_name: minigrid
+env_name: xland_minigrid
 scenario:
   name: MiniGrid-DoorKey-5x5
   task_name: minigrid_doorkey_5x5
diff --git a/stoix/configs/env/minigrid/minigrid_empty_6x6.yaml b/stoix/configs/env/xland_minigrid/empty_6x6.yaml
similarity index 94%
rename from stoix/configs/env/minigrid/minigrid_empty_6x6.yaml
rename to stoix/configs/env/xland_minigrid/empty_6x6.yaml
index fd11326b..566ed5cc 100644
--- a/stoix/configs/env/minigrid/minigrid_empty_6x6.yaml
+++ b/stoix/configs/env/xland_minigrid/empty_6x6.yaml
@@ -1,5 +1,5 @@
 # ---Environment Configs---
-env_name: minigrid
+env_name: xland_minigrid
 scenario:
   name: MiniGrid-Empty-6x6
   task_name: minigrid_empty_6x6
diff --git a/stoix/utils/make_env.py b/stoix/utils/make_env.py
index 7471aca5..67ebed25 100644
--- a/stoix/utils/make_env.py
+++ b/stoix/utils/make_env.py
@@ -6,6 +6,7 @@
 import jax.numpy as jnp
 import jaxmarl
 import jumanji
+import navix
 import pgx
 import popjym
 import xminigrid
@@ -18,6 +19,7 @@
 from jumanji.registration import _REGISTRY as JUMANJI_REGISTRY
 from jumanji.specs import BoundedArray, MultiDiscreteArray
 from jumanji.wrappers import AutoResetWrapper, MultiToSingleWrapper
+from navix import registry as navix_registry
 from omegaconf import DictConfig
 from popjym.registration import REGISTERED_ENVS as POPJYM_REGISTRY
 from xminigrid.registration import _REGISTRY as XMINIGRID_REGISTRY
@@ -26,6 +28,7 @@
 from stoix.wrappers import GymnaxWrapper, JumanjiWrapper, RecordEpisodeMetrics
 from stoix.wrappers.brax import BraxJumanjiWrapper
 from stoix.wrappers.jaxmarl import JaxMarlWrapper, MabraxWrapper, SmaxWrapper
+from stoix.wrappers.navix import NavixWrapper
 from stoix.wrappers.pgx import PGXWrapper
 from stoix.wrappers.transforms import (
     AddStartFlagAndPrevAction,
@@ -343,6 +346,31 @@ def make_popjym_env(env_name: str, config: DictConfig) -> Tuple[Environment, Env
     return env, eval_env
 
 
+def make_navix_env(env_name: str, config: DictConfig) -> Tuple[Environment, Environment]:
+    """
+    Create Navix environments for training and evaluation.
+
+    Args:
+        env_name (str): The name of the environment to create.
+        config (Dict): The configuration of the environment.
+
+    Returns:
+        A tuple of the environments.
+    """
+
+    # Create envs.
+    env = navix.make(env_name, **config.env.kwargs)
+    eval_env = navix.make(env_name, **config.env.kwargs)
+
+    env = NavixWrapper(env)
+    eval_env = NavixWrapper(eval_env)
+
+    env = AutoResetWrapper(env, next_obs_in_extras=True)
+    env = RecordEpisodeMetrics(env)
+
+    return env, eval_env
+
+
 def make(config: DictConfig) -> Tuple[Environment, Environment]:
     """
     Create environments for training and evaluation..
@@ -373,6 +401,8 @@ def make(config: DictConfig) -> Tuple[Environment, Environment]:
         envs = make_pgx_env(env_name, config)
     elif env_name in POPJYM_REGISTRY:
        envs = make_popjym_env(env_name, config)
+    elif env_name in navix_registry():
+        envs = make_navix_env(env_name, config)
     else:
         raise ValueError(f"{env_name} is not a supported environment.")
 
diff --git a/stoix/wrappers/navix.py b/stoix/wrappers/navix.py
new file mode 100644
index 00000000..5b548007
--- /dev/null
+++ b/stoix/wrappers/navix.py
@@ -0,0 +1,97 @@
+from typing import TYPE_CHECKING, Tuple
+
+import chex
+import jax
+import jax.numpy as jnp
+from jumanji import specs
+from jumanji.specs import Array, DiscreteArray, Spec
+from jumanji.types import StepType, TimeStep, restart
+from jumanji.wrappers import Wrapper
+from navix.environments import Environment
+from navix.environments import Timestep as NavixState
+
+from stoix.base_types import Observation
+
+if TYPE_CHECKING:  # https://github.com/python/mypy/issues/6239
+    from dataclasses import dataclass
+else:
+    from chex import dataclass
+
+
+@dataclass
+class NavixEnvState:
+    key: chex.PRNGKey
+    navix_state: NavixState
+
+
+class NavixWrapper(Wrapper):
+    def __init__(self, env: Environment):
+        self._env = env
+        self._n_actions = len(self._env.action_set)
+
+    def reset(self, key: chex.PRNGKey) -> Tuple[NavixEnvState, TimeStep]:
+        key, key_reset = jax.random.split(key)
+        navix_state = self._env.reset(key_reset)
+        agent_view = navix_state.observation.astype(float)
+        legal_action_mask = jnp.ones((self._n_actions,), dtype=float)
+        step_count = navix_state.t.astype(int)
+        obs = Observation(agent_view, legal_action_mask, step_count)
+        timestep = restart(obs, extras={})
+        state = NavixEnvState(key=key, navix_state=navix_state)
+        return state, timestep
+
+    def step(self, state: NavixEnvState, action: chex.Array) -> Tuple[NavixEnvState, TimeStep]:
+        key, key_step = jax.random.split(state.key)
+
+        navix_state = self._env.step(state.navix_state, action)
+
+        agent_view = navix_state.observation.astype(float)
+        legal_action_mask = jnp.ones((self._n_actions,), dtype=float)
+        step_count = navix_state.t.astype(int)
+        next_obs = Observation(agent_view, legal_action_mask, step_count)
+
+        reward = navix_state.reward.astype(float)
+        terminal = navix_state.is_termination()
+        truncated = navix_state.is_truncation()
+
+        discount = jnp.array(1.0 - terminal, dtype=float)
+        final_step = jnp.logical_or(terminal, truncated)
+
+        timestep = TimeStep(
+            observation=next_obs,
+            reward=reward,
+            discount=discount,
+            step_type=jax.lax.select(final_step, StepType.LAST, StepType.MID),
+            extras={},
+        )
+        next_state = NavixEnvState(key=key_step, navix_state=navix_state)
+        return next_state, timestep
+
+    def reward_spec(self) -> specs.Array:
+        return specs.Array(shape=(), dtype=float, name="reward")
+
+    def discount_spec(self) -> specs.BoundedArray:
+        return specs.BoundedArray(shape=(), dtype=float, minimum=0.0, maximum=1.0, name="discount")
+
+    def action_spec(self) -> Spec:
+        return DiscreteArray(num_values=self._n_actions)
+
+    def observation_spec(self) -> Spec:
+        agent_view_shape = self._env.observation_space.shape
+        agent_view_min = self._env.observation_space.minimum
+        agent_view_max = self._env.observation_space.maximum
+        agent_view_spec = specs.BoundedArray(
+            shape=agent_view_shape,
+            dtype=float,
+            minimum=agent_view_min,
+            maximum=agent_view_max,
+        )
+        action_mask_spec = Array(shape=(self._n_actions,), dtype=float)
+
+        return specs.Spec(
+            Observation,
+            "ObservationSpec",
+            agent_view=agent_view_spec,
+            action_mask=action_mask_spec,
+            step_count=Array(shape=(), dtype=int),
+        )
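
Below is a quick smoke test of the new wrapper. It is not part of the diff, just a minimal sketch: it assumes the `Navix-Empty-5x5-v0` id used by the new config, the Jumanji-style `reset`/`step` interface implemented in `stoix/wrappers/navix.py`, and the `agent_view` field of Stoix's `Observation`. A full training run would instead pick up one of the new configs via a Hydra override such as `env=navix/empty_5x5` on whichever system script is being launched (the exact entry point is not shown in this diff).

import jax
import navix

from stoix.wrappers.navix import NavixWrapper

# Wrap the raw Navix environment so it exposes the Jumanji-style
# reset/step interface that the rest of Stoix expects.
env = NavixWrapper(navix.make("Navix-Empty-5x5-v0"))

key = jax.random.PRNGKey(0)
state, timestep = env.reset(key)

# Take one random action and inspect the resulting TimeStep.
num_actions = env.action_spec().num_values
action = jax.random.randint(key, (), 0, num_actions)
state, timestep = env.step(state, action)
print(timestep.reward, timestep.discount, timestep.observation.agent_view.shape)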