Skip to content

Commit

Permalink
Feat/add jax env factory (#118)
Browse files Browse the repository at this point in the history
* feat: add env factory for jax envs
  • Loading branch information
EdanToledo authored Sep 20, 2024
1 parent ce449a2 commit e10dc3e
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 8 deletions.
3 changes: 0 additions & 3 deletions stoix/base_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,3 @@ class EvaluationOutput(NamedTuple, Generic[StoixState]):
[FrozenDict, HiddenState, RNNObservation, chex.PRNGKey], Tuple[HiddenState, chex.Array]
]
RecCriticApply = Callable[[FrozenDict, HiddenState, RNNObservation], Tuple[HiddenState, Value]]


EnvFactory = Callable[[int], Any]
2 changes: 1 addition & 1 deletion stoix/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from stoix.base_types import (
ActFn,
ActorApply,
EnvFactory,
EvalFn,
EvalState,
EvaluationOutput,
Expand All @@ -25,6 +24,7 @@
RNNObservation,
SebulbaEvalFn,
)
from stoix.utils.env_factory import EnvFactory
from stoix.utils.jax_utils import unreplicate_batch_dim


Expand Down
9 changes: 5 additions & 4 deletions stoix/utils/make_env.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import copy
from typing import Tuple, Union
from typing import Tuple

import gymnax
import hydra
Expand All @@ -25,9 +25,10 @@
from xminigrid.registration import _REGISTRY as XMINIGRID_REGISTRY

from stoix.utils.debug_env import IdentityGame, SequenceGame
from stoix.utils.env_factory import EnvPoolFactory, GymnasiumFactory
from stoix.utils.env_factory import EnvFactory, EnvPoolFactory, GymnasiumFactory
from stoix.wrappers import GymnaxWrapper, JumanjiWrapper, RecordEpisodeMetrics
from stoix.wrappers.brax import BraxJumanjiWrapper
from stoix.wrappers.jax_to_factory import JaxEnvFactory
from stoix.wrappers.jaxmarl import JaxMarlWrapper, MabraxWrapper, SmaxWrapper
from stoix.wrappers.navix import NavixWrapper
from stoix.wrappers.pgx import PGXWrapper
Expand Down Expand Up @@ -426,7 +427,7 @@ def make(config: DictConfig) -> Tuple[Environment, Environment]:
return envs


def make_factory(config: DictConfig) -> Union[GymnasiumFactory, EnvPoolFactory]:
def make_factory(config: DictConfig) -> EnvFactory:
"""
Create a env_factory for sebulba systems.
Expand All @@ -444,4 +445,4 @@ def make_factory(config: DictConfig) -> Union[GymnasiumFactory, EnvPoolFactory]:
elif "gymnasium" in suite_name:
return make_gymnasium_factory(env_name, config)
else:
raise ValueError(f"{suite_name} is not a supported suite.")
return JaxEnvFactory(make(config)[0], init_seed=config.arch.seed)
129 changes: 129 additions & 0 deletions stoix/wrappers/jax_to_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import threading
from typing import Optional

import jax
import numpy as np
from jumanji.env import Environment
from jumanji.specs import Spec
from jumanji.types import TimeStep

from stoix.utils.env_factory import EnvFactory


class JaxToStateful:
"""Converts a Stoix-ready JAX environment to a stateful one to be used by Sebulba systems."""

def __init__(self, env: Environment, num_envs: int, device: jax.Device, init_seed: int):
self.env = env
self.num_envs = num_envs
self.device = device

# Create the metrics
self.running_count_episode_return = np.zeros(self.num_envs, dtype=float)
self.running_count_episode_length = np.zeros(self.num_envs, dtype=int)
self.episode_return = np.zeros(self.num_envs, dtype=float)
self.episode_length = np.zeros(self.num_envs, dtype=int)

# Create the seeds
max_int = np.iinfo(np.int32).max
min_int = np.iinfo(np.int32).min
init_seeds = jax.random.randint(
jax.random.PRNGKey(init_seed), (num_envs,), min_int, max_int
)
self.rng_keys = jax.vmap(jax.random.PRNGKey)(init_seeds)

# Vmap and compile the reset and step functions
self.vmapped_reset = jax.jit(jax.vmap(self.env.reset), device=self.device)
self.vmapped_step = jax.jit(jax.vmap(self.env.step, in_axes=(0, 0)), device=self.device)

def reset(
self, *, seed: Optional[list[int]] = None, options: Optional[list[dict]] = None
) -> TimeStep:
with jax.default_device(self.device):

self.state, timestep = self.vmapped_reset(self.rng_keys)

# Reset the metrics
self.running_count_episode_return = np.zeros(self.num_envs, dtype=float)
self.running_count_episode_length = np.zeros(self.num_envs, dtype=int)
self.episode_return = np.zeros(self.num_envs, dtype=float)
self.episode_length = np.zeros(self.num_envs, dtype=int)

# Create the metrics dict
metrics = {
"episode_return": np.zeros(self.num_envs, dtype=float),
"episode_length": np.zeros(self.num_envs, dtype=int),
"is_terminal_step": np.zeros(self.num_envs, dtype=bool),
}

timestep_extras = timestep.extras

timestep_extras["metrics"] = metrics

timestep = timestep.replace(extras=timestep_extras)

return timestep

def step(self, action: list) -> TimeStep:
with jax.default_device(self.device):
self.state, timestep = self.vmapped_step(self.state, action)

ep_done = timestep.last()
not_done = ~ep_done

# Counting episode return and length.
new_episode_return = self.running_count_episode_return + timestep.reward
new_episode_length = self.running_count_episode_length + 1

# Update the episode return and length if the episode is done otherwise
# keep the previous values
episode_return_info = self.episode_return * not_done + new_episode_return * ep_done
episode_length_info = self.episode_length * not_done + new_episode_length * ep_done
# Update the running count
self.running_count_episode_return = new_episode_return * not_done
self.running_count_episode_length = new_episode_length * not_done

self.episode_return = episode_return_info
self.episode_length = episode_length_info

# Create the metrics dict
metrics = {
"episode_return": episode_return_info,
"episode_length": episode_length_info,
"is_terminal_step": ep_done,
}

timestep_extras = timestep.extras
timestep_extras["metrics"] = metrics
timestep = timestep.replace(extras=timestep_extras)

return timestep

def observation_spec(self) -> Spec:
return self.env.observation_spec()

def action_spec(self) -> Spec:
return self.env.action_spec()

def close(self) -> None:
pass


class JaxEnvFactory(EnvFactory):
"""
Create environments using stoix-ready JAX environments
"""

def __init__(self, jax_env: Environment, init_seed: int):
self.jax_env = jax_env
self.cpu = jax.devices("cpu")[0]
self.seed = init_seed
# a lock is needed because this object will be used from different threads.
# We want to make sure all seeds are unique
self.lock = threading.Lock()

def __call__(self, num_envs: int) -> JaxToStateful:
with self.lock:
seed = self.seed
self.seed += num_envs
return JaxToStateful(self.jax_env, num_envs, self.cpu, seed)

0 comments on commit e10dc3e

Please sign in to comment.