Merge pull request #101 from EdanToledo/feat/add_navix
Feat/add navix
EdanToledo authored Jul 9, 2024
2 parents 7c06299 + 0d6b05f commit 54e0f4d
Showing 9 changed files with 184 additions and 3 deletions.
9 changes: 8 additions & 1 deletion README.md
@@ -78,7 +78,7 @@ Stoix currently offers the following building blocks for Single-Agent RL researc
- **Sampled Alpha/Mu-Zero** - [Paper](https://arxiv.org/abs/2104.06303)

### Environment Wrappers 🍬
-Stoix offers wrappers for [Gymnax][gymnax], [Jumanji][jumanji], [Brax][brax], [XMinigrid][xminigrid], [Craftax][craftax], [POPJym][popjym] and even [JAXMarl][jaxmarl] (although using Centralised Controllers).
+Stoix offers wrappers for [Gymnax][gymnax], [Jumanji][jumanji], [Brax][brax], [XMinigrid][xminigrid], [Craftax][craftax], [POPJym][popjym], [Navix][navix] and even [JAXMarl][jaxmarl] (although using Centralised Controllers).

### Statistically Robust Evaluation 🧪
Stoix natively supports logging to json files which adhere to the standard suggested by [Gorsane et al. (2022)][toward_standard_eval]. This enables easy downstream experiment plotting and aggregation using the tools found in the [MARL-eval][marl_eval] library.
@@ -140,6 +140,12 @@ or if you wanted to do dueling C51, you could do:
python stoix/systems/q_learning/ff_c51.py network=mlp_dueling_c51
```

## Important Considerations

1. If your environment does not have a timestep limit or is not guaranteed to end through some game mechanic, the evaluation can appear to hang forever and stall training, when in fact your agent is simply so good _or bad_ that the episode never finishes. Keep this in mind if you see this behaviour. One solution is simply to add a timestep limit or, potentially, action masking.

2. Due to the way Stoix is set up, you are not guaranteed to run for exactly the number of timesteps you set. A warning is given at the beginning of a run stating the actual number of timesteps that will be run; this value is always less than or equal to the specified sample budget. To run an exact number of transitions, ensure that the total number of timesteps is divisible by rollout length * total_num_envs, and that the number of evaluations spaced throughout training evenly divides the number of updates to be performed. For the exact calculation, see the file total_timestep_checker.py (a simplified sketch follows below). This shows how the actual number of timesteps is computed and how to set a run up for the exact amount you desire. It is relatively trivial to do, but important to keep in mind.
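
To make the rounding behaviour in point 2 concrete, here is a simplified sketch of the kind of calculation `total_timestep_checker.py` performs; the function and argument names below are illustrative assumptions, not the exact identifiers used in that file:

```python
def expected_actual_timesteps(
    total_timesteps: int,
    rollout_length: int,
    total_num_envs: int,
    num_evaluation: int,
) -> int:
    """Return the number of timesteps a run will actually perform.

    Each update consumes rollout_length * total_num_envs transitions, and
    the evaluations must divide the updates evenly, so the requested
    budget is rounded down to the nearest compatible value.
    """
    steps_per_update = rollout_length * total_num_envs
    num_updates = total_timesteps // steps_per_update
    num_updates_per_eval = num_updates // num_evaluation
    return num_updates_per_eval * num_evaluation * steps_per_update


# e.g. a budget of 1_000_000 with rollout_length=16, total_num_envs=256 and
# 10 evaluations yields 240 updates, i.e. 983_040 timesteps, not 1_000_000.
```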

## Contributing 🤝

Please read our [contributing docs](docs/CONTRIBUTING.md) for details on how to submit pull requests, our Contributor License Agreement and community guidelines.
@@ -210,5 +216,6 @@ We would like to thank the authors and developers of [Mava](mava) as this was es
[xminigrid]: https://github.com/corl-team/xland-minigrid/
[craftax]: https://github.com/MichaelTMatthews/Craftax
[popjym]: https://github.com/FLAIROx/popjym
[navix]: https://github.com/epignatelli/navix

Disclaimer: This is not an official InstaDeep product nor is any of the work put forward associated with InstaDeep in any official capacity.
1 change: 1 addition & 0 deletions requirements/requirements.txt
@@ -14,6 +14,7 @@ jaxlib
jaxmarl
jumanji==1.0.0
mctx
navix
neptune
numpy
omegaconf
15 changes: 15 additions & 0 deletions stoix/configs/env/navix/door_key_8x8.yaml
@@ -0,0 +1,15 @@
# ---Environment Configs---
env_name: navix
scenario:
  name: Navix-DoorKey-8x8-v0
  task_name: navix-door_key-8x8-v0

kwargs: {}

# Defines the metric that will be used to evaluate the performance of the agent.
# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
eval_metric: episode_return

# Optional wrapper to flatten the observation space.
wrapper:
  _target_: stoix.wrappers.transforms.FlattenObservationWrapper
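
With this config in place, the scenario can be selected via the usual Hydra-style override, e.g. `python stoix/systems/q_learning/ff_c51.py env=navix/door_key_8x8`, following the same pattern as the network override shown in the README; the choice of system script here is purely illustrative.
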
15 changes: 15 additions & 0 deletions stoix/configs/env/navix/empty_5x5.yaml
@@ -0,0 +1,15 @@
# ---Environment Configs---
env_name: navix
scenario:
  name: Navix-Empty-5x5-v0
  task_name: navix-empty-5x5-v0

kwargs: {}

# Defines the metric that will be used to evaluate the performance of the agent.
# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
eval_metric: episode_return

# Optional wrapper to flatten the observation space.
wrapper:
  _target_: stoix.wrappers.transforms.FlattenObservationWrapper
@@ -1,5 +1,5 @@
# ---Environment Configs---
-env_name: minigrid
+env_name: xland_minigrid
scenario:
  name: MiniGrid-DoorKey-5x5
  task_name: minigrid_doorkey_5x5
16 changes: 16 additions & 0 deletions stoix/configs/env/xland_minigrid/empty_5x5.yaml
@@ -0,0 +1,16 @@
# ---Environment Configs---
env_name: xland_minigrid
scenario:
  name: MiniGrid-Empty-5x5
  task_name: minigrid_empty_5x5

kwargs: {}

# Defines the metric that will be used to evaluate the performance of the agent.
# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
eval_metric: episode_return


# Optional wrapper to flatten the observation space.
wrapper:
  _target_: stoix.wrappers.transforms.FlattenObservationWrapper
@@ -1,5 +1,5 @@
# ---Environment Configs---
-env_name: minigrid
+env_name: xland_minigrid
scenario:
  name: MiniGrid-Empty-6x6
  task_name: minigrid_empty_6x6
30 changes: 30 additions & 0 deletions stoix/utils/make_env.py
@@ -6,6 +6,7 @@
import jax.numpy as jnp
import jaxmarl
import jumanji
import navix
import pgx
import popjym
import xminigrid
@@ -18,6 +19,7 @@
from jumanji.registration import _REGISTRY as JUMANJI_REGISTRY
from jumanji.specs import BoundedArray, MultiDiscreteArray
from jumanji.wrappers import AutoResetWrapper, MultiToSingleWrapper
from navix import registry as navix_registry
from omegaconf import DictConfig
from popjym.registration import REGISTERED_ENVS as POPJYM_REGISTRY
from xminigrid.registration import _REGISTRY as XMINIGRID_REGISTRY
@@ -26,6 +28,7 @@
from stoix.wrappers import GymnaxWrapper, JumanjiWrapper, RecordEpisodeMetrics
from stoix.wrappers.brax import BraxJumanjiWrapper
from stoix.wrappers.jaxmarl import JaxMarlWrapper, MabraxWrapper, SmaxWrapper
from stoix.wrappers.navix import NavixWrapper
from stoix.wrappers.pgx import PGXWrapper
from stoix.wrappers.transforms import (
    AddStartFlagAndPrevAction,
@@ -343,6 +346,31 @@ def make_popjym_env(env_name: str, config: DictConfig) -> Tuple[Environment, Env
    return env, eval_env


def make_navix_env(env_name: str, config: DictConfig) -> Tuple[Environment, Environment]:
    """
    Create Navix environments for training and evaluation.

    Args:
        env_name (str): The name of the environment to create.
        config (DictConfig): The configuration of the environment.

    Returns:
        A tuple of the environments.
    """

    # Create envs.
    env = navix.make(env_name, **config.env.kwargs)
    eval_env = navix.make(env_name, **config.env.kwargs)

    env = NavixWrapper(env)
    eval_env = NavixWrapper(eval_env)

    # Only the training env is auto-reset and has its episode metrics recorded.
    env = AutoResetWrapper(env, next_obs_in_extras=True)
    env = RecordEpisodeMetrics(env)

    return env, eval_env


def make(config: DictConfig) -> Tuple[Environment, Environment]:
    """
    Create environments for training and evaluation.
@@ -373,6 +401,8 @@ def make(config: DictConfig) -> Tuple[Environment, Environment]:
        envs = make_pgx_env(env_name, config)
    elif env_name in POPJYM_REGISTRY:
        envs = make_popjym_env(env_name, config)
    elif env_name in navix_registry():
        envs = make_navix_env(env_name, config)
    else:
        raise ValueError(f"{env_name} is not a supported environment.")

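For orientation, a minimal sketch of driving the new factory directly; the inline config below simply mirrors stoix/configs/env/navix/empty_5x5.yaml and is illustrative rather than the canonical way to construct a Stoix config:

```python
from omegaconf import OmegaConf

from stoix.utils.make_env import make_navix_env

# Minimal config mirroring stoix/configs/env/navix/empty_5x5.yaml.
config = OmegaConf.create(
    {
        "env": {
            "env_name": "navix",
            "scenario": {
                "name": "Navix-Empty-5x5-v0",
                "task_name": "navix-empty-5x5-v0",
            },
            "kwargs": {},
        }
    }
)

# The dispatcher in make() routes here when the scenario name is found in
# navix's registry; calling the factory directly is equivalent.
env, eval_env = make_navix_env("Navix-Empty-5x5-v0", config)
```
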
97 changes: 97 additions & 0 deletions stoix/wrappers/navix.py
@@ -0,0 +1,97 @@
from typing import TYPE_CHECKING, Tuple

import chex
import jax
import jax.numpy as jnp
from jumanji import specs
from jumanji.specs import Array, DiscreteArray, Spec
from jumanji.types import StepType, TimeStep, restart
from jumanji.wrappers import Wrapper
from navix.environments import Environment
from navix.environments import Timestep as NavixState

from stoix.base_types import Observation

if TYPE_CHECKING:  # https://github.com/python/mypy/issues/6239
    from dataclasses import dataclass
else:
    from chex import dataclass


@dataclass
class NavixEnvState:
    key: chex.PRNGKey
    navix_state: NavixState


class NavixWrapper(Wrapper):
    def __init__(self, env: Environment):
        self._env = env
        self._n_actions = len(self._env.action_set)

    def reset(self, key: chex.PRNGKey) -> Tuple[NavixEnvState, TimeStep]:
        key, key_reset = jax.random.split(key)
        navix_state = self._env.reset(key_reset)
        agent_view = navix_state.observation.astype(float)
        # No action masking is applied: every action is marked legal.
        legal_action_mask = jnp.ones((self._n_actions,), dtype=float)
        step_count = navix_state.t.astype(int)
        obs = Observation(agent_view, legal_action_mask, step_count)
        timestep = restart(obs, extras={})
        state = NavixEnvState(key=key, navix_state=navix_state)
        return state, timestep

    def step(self, state: NavixEnvState, action: chex.Array) -> Tuple[NavixEnvState, TimeStep]:
        key, key_step = jax.random.split(state.key)

        navix_state = self._env.step(state.navix_state, action)

        agent_view = navix_state.observation.astype(float)
        legal_action_mask = jnp.ones((self._n_actions,), dtype=float)
        step_count = navix_state.t.astype(int)
        next_obs = Observation(agent_view, legal_action_mask, step_count)

        reward = navix_state.reward.astype(float)
        terminal = navix_state.is_termination()
        truncated = navix_state.is_truncation()

        # Only a true termination zeroes the discount; a truncated episode
        # keeps a discount of 1.0 so that value bootstrapping remains valid.
        discount = jnp.array(1.0 - terminal, dtype=float)
        final_step = jnp.logical_or(terminal, truncated)

        timestep = TimeStep(
            observation=next_obs,
            reward=reward,
            discount=discount,
            step_type=jax.lax.select(final_step, StepType.LAST, StepType.MID),
            extras={},
        )
        next_state = NavixEnvState(key=key_step, navix_state=navix_state)
        return next_state, timestep

    def reward_spec(self) -> specs.Array:
        return specs.Array(shape=(), dtype=float, name="reward")

    def discount_spec(self) -> specs.BoundedArray:
        return specs.BoundedArray(shape=(), dtype=float, minimum=0.0, maximum=1.0, name="discount")

    def action_spec(self) -> Spec:
        return DiscreteArray(num_values=self._n_actions)

    def observation_spec(self) -> Spec:
        agent_view_shape = self._env.observation_space.shape
        agent_view_min = self._env.observation_space.minimum
        agent_view_max = self._env.observation_space.maximum
        agent_view_spec = specs.BoundedArray(
            shape=agent_view_shape,
            dtype=float,
            minimum=agent_view_min,
            maximum=agent_view_max,
        )
        action_mask_spec = Array(shape=(self._n_actions,), dtype=float)

        return specs.Spec(
            Observation,
            "ObservationSpec",
            agent_view=agent_view_spec,
            action_mask=action_mask_spec,
            step_count=Array(shape=(), dtype=int),
        )
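
As a quick sanity check, the wrapper can also be exercised on its own; a minimal sketch, assuming one of the Navix scenarios registered above:

```python
import jax
import jax.numpy as jnp
import navix

from stoix.wrappers.navix import NavixWrapper

env = NavixWrapper(navix.make("Navix-Empty-5x5-v0"))

key = jax.random.PRNGKey(0)
state, timestep = env.reset(key)
print(timestep.observation.agent_view.shape)  # the raw Navix observation

# Actions are discrete indices in [0, n_actions).
state, timestep = env.step(state, jnp.asarray(0))
print(timestep.reward, timestep.step_type)
```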
