Skip to content

Commit

Permalink
feat: add navix
Browse files Browse the repository at this point in the history
  • Loading branch information
EdanToledo committed Jul 9, 2024
1 parent 7c06299 commit 6505baf
Show file tree
Hide file tree
Showing 6 changed files with 159 additions and 2 deletions.
15 changes: 15 additions & 0 deletions stoix/configs/env/navix/door_key_8x8.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# ---Environment Configs---
env_name: navix
scenario:
name: Navix-DoorKey-8x8-v0
task_name: navix-door_key-8x8-v0

kwargs: {}

# Defines the metric that will be used to evaluate the performance of the agent.
# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
eval_metric: episode_return

# Optional wrapper to flatten the observation space.
wrapper:
_target_: stoix.wrappers.transforms.FlattenObservationWrapper
15 changes: 15 additions & 0 deletions stoix/configs/env/navix/empty_5x5.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# ---Environment Configs---
env_name: navix
scenario:
name: Navix-Empty-5x5-v0
task_name: navix-empty-5x5-v0

kwargs: {}

# Defines the metric that will be used to evaluate the performance of the agent.
# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
eval_metric: episode_return

# Optional wrapper to flatten the observation space.
wrapper:
_target_: stoix.wrappers.transforms.FlattenObservationWrapper
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# ---Environment Configs---
env_name: minigrid
env_name: xland_minigrid
scenario:
name: MiniGrid-DoorKey-5x5
task_name: minigrid_doorkey_5x5
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# ---Environment Configs---
env_name: minigrid
env_name: xland_minigrid
scenario:
name: MiniGrid-Empty-6x6
task_name: minigrid_empty_6x6
Expand Down
30 changes: 30 additions & 0 deletions stoix/utils/make_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import jax.numpy as jnp
import jaxmarl
import jumanji
import navix
import pgx
import popjym
import xminigrid
Expand All @@ -18,6 +19,7 @@
from jumanji.registration import _REGISTRY as JUMANJI_REGISTRY
from jumanji.specs import BoundedArray, MultiDiscreteArray
from jumanji.wrappers import AutoResetWrapper, MultiToSingleWrapper
from navix import registry as navix_registry
from omegaconf import DictConfig
from popjym.registration import REGISTERED_ENVS as POPJYM_REGISTRY
from xminigrid.registration import _REGISTRY as XMINIGRID_REGISTRY
Expand All @@ -26,6 +28,7 @@
from stoix.wrappers import GymnaxWrapper, JumanjiWrapper, RecordEpisodeMetrics
from stoix.wrappers.brax import BraxJumanjiWrapper
from stoix.wrappers.jaxmarl import JaxMarlWrapper, MabraxWrapper, SmaxWrapper
from stoix.wrappers.navix import NavixWrapper
from stoix.wrappers.pgx import PGXWrapper
from stoix.wrappers.transforms import (
AddStartFlagAndPrevAction,
Expand Down Expand Up @@ -343,6 +346,31 @@ def make_popjym_env(env_name: str, config: DictConfig) -> Tuple[Environment, Env
return env, eval_env


def make_navix_env(env_name: str, config: DictConfig) -> Tuple[Environment, Environment]:
    """Build the training and evaluation Navix environments.

    Two independent Navix environments are created from the same name and
    keyword arguments and each is wrapped in ``NavixWrapper``. Only the training
    environment is additionally wrapped with ``AutoResetWrapper`` and
    ``RecordEpisodeMetrics``; the evaluation environment is returned with the
    ``NavixWrapper`` only.

    Args:
        env_name (str): The Navix environment id passed to ``navix.make``.
        config (DictConfig): The experiment configuration; ``config.env.kwargs``
            is forwarded to ``navix.make`` as keyword arguments.

    Returns:
        A ``(train_env, eval_env)`` tuple.
    """
    env_kwargs = config.env.kwargs

    # One fresh underlying environment per role, each adapted to the common API.
    train_env, eval_env = (
        NavixWrapper(navix.make(env_name, **env_kwargs)) for _ in range(2)
    )

    # Training-only wrappers: auto-reset first, then episode metric recording on top.
    train_env = RecordEpisodeMetrics(AutoResetWrapper(train_env, next_obs_in_extras=True))

    return train_env, eval_env


def make(config: DictConfig) -> Tuple[Environment, Environment]:
"""
Create environments for training and evaluation..
Expand Down Expand Up @@ -373,6 +401,8 @@ def make(config: DictConfig) -> Tuple[Environment, Environment]:
envs = make_pgx_env(env_name, config)
elif env_name in POPJYM_REGISTRY:
envs = make_popjym_env(env_name, config)
elif env_name in navix_registry():
envs = make_navix_env(env_name, config)
else:
raise ValueError(f"{env_name} is not a supported environment.")

Expand Down
97 changes: 97 additions & 0 deletions stoix/wrappers/navix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
from typing import TYPE_CHECKING, Tuple

import chex
import jax
import jax.numpy as jnp
from jumanji import specs
from jumanji.specs import Array, DiscreteArray, Spec
from jumanji.types import StepType, TimeStep, restart
from jumanji.wrappers import Wrapper
from navix.environments import Environment
from navix.environments import Timestep as NavixState

from stoix.base_types import Observation

if TYPE_CHECKING: # https://github.com/python/mypy/issues/6239
from dataclasses import dataclass
else:
from chex import dataclass


@dataclass
class NavixEnvState:
    """State carried by ``NavixWrapper`` between ``reset`` and ``step`` calls."""

    # PRNG key stored alongside the Navix state; it is split again on every step.
    key: chex.PRNGKey
    # Most recent Navix timestep returned by the wrapped environment.
    navix_state: NavixState


class NavixWrapper(Wrapper):
    """Adapter exposing a Navix environment through the ``reset``/``step``/spec API.

    Observations are packed into ``Observation`` (agent view, action mask, step
    count). Every action is always reported as legal.
    """

    def __init__(self, env: Environment):
        """Store the wrapped Navix environment and cache its number of actions."""
        self._env = env
        self._n_actions = len(env.action_set)

    def _build_observation(self, navix_state: NavixState) -> Observation:
        """Convert a Navix timestep into an ``Observation``."""
        return Observation(
            navix_state.observation.astype(float),
            # No action masking: all actions are marked as available.
            jnp.ones((self._n_actions,), dtype=float),
            navix_state.t.astype(int),
        )

    def reset(self, key: chex.PRNGKey) -> Tuple[NavixEnvState, TimeStep]:
        """Reset the wrapped environment.

        Args:
            key: PRNG key; one half seeds the Navix reset, the other half is
                kept in the returned state.

        Returns:
            The initial wrapper state and the first timestep (built with ``restart``).
        """
        carry_key, reset_key = jax.random.split(key)
        navix_state = self._env.reset(reset_key)
        first_timestep = restart(self._build_observation(navix_state), extras={})
        return NavixEnvState(key=carry_key, navix_state=navix_state), first_timestep

    def step(self, state: NavixEnvState, action: chex.Array) -> Tuple[NavixEnvState, TimeStep]:
        """Advance the wrapped environment by one action.

        Args:
            state: The wrapper state from the previous ``reset``/``step``.
            action: The action forwarded to the Navix environment.

        Returns:
            The next wrapper state and the resulting timestep. The step type is
            ``LAST`` on termination or truncation and ``MID`` otherwise; the
            discount is zero only on termination.
        """
        # Only the second half of the split is carried forward in the state.
        _, carry_key = jax.random.split(state.key)

        navix_state = self._env.step(state.navix_state, action)

        is_terminal = navix_state.is_termination()
        episode_over = jnp.logical_or(is_terminal, navix_state.is_truncation())

        timestep = TimeStep(
            step_type=jax.lax.select(episode_over, StepType.LAST, StepType.MID),
            reward=navix_state.reward.astype(float),
            # Truncation does not zero the discount; only true termination does.
            discount=jnp.array(1.0 - is_terminal, dtype=float),
            observation=self._build_observation(navix_state),
            extras={},
        )
        return NavixEnvState(key=carry_key, navix_state=navix_state), timestep

    def reward_spec(self) -> specs.Array:
        """Return the spec of the scalar reward."""
        return specs.Array(shape=(), dtype=float, name="reward")

    def discount_spec(self) -> specs.BoundedArray:
        """Return the spec of the scalar discount, bounded to ``[0, 1]``."""
        return specs.BoundedArray(shape=(), dtype=float, minimum=0.0, maximum=1.0, name="discount")

    def action_spec(self) -> Spec:
        """Return a discrete action spec with one value per Navix action."""
        return DiscreteArray(num_values=self._n_actions)

    def observation_spec(self) -> Spec:
        """Return the spec describing the ``Observation`` structure.

        The agent-view bounds and shape are taken from the wrapped environment's
        ``observation_space``.
        """
        obs_space = self._env.observation_space
        return specs.Spec(
            Observation,
            "ObservationSpec",
            agent_view=specs.BoundedArray(
                shape=obs_space.shape,
                dtype=float,
                minimum=obs_space.minimum,
                maximum=obs_space.maximum,
            ),
            action_mask=Array(shape=(self._n_actions,), dtype=float),
            step_count=Array(shape=(), dtype=int),
        )

0 comments on commit 6505baf

Please sign in to comment.