diff --git a/README.md b/README.md index e23fe793..3f53dcfe 100644 --- a/README.md +++ b/README.md @@ -32,11 +32,20 @@ ## Welcome to Stoix! 🏛️ -Stoix provides simplified code for quickly iterating on ideas in single-agent reinforcement learning with useful implementations of popular single-agent RL algorithms in JAX allowing for easy parallelisation across devices with JAX's `pmap`. All implementations are fully compilable with JAX's `jit` thus making training and environment execution very fast. However, this requires environments written in JAX. Algorithms and their default hyperparameters have not been hyper-optimised for any specific environment and are useful as a starting point for research and/or for initial baselines. +Stoix provides simplified code for quickly iterating on ideas in single-agent reinforcement learning, with useful implementations of popular single-agent RL algorithms in JAX, allowing for easy parallelisation across devices with JAX's `pmap`. All implementations are fully compiled with JAX's `jit`, thus making training and environment execution very fast. However, this does require environments written in JAX. For environments not written in JAX, Stoix offers Sebulba systems (see below). Algorithms and their default hyperparameters have not been hyper-optimised for any specific environment and are useful as a starting point for research and/or for initial baselines. To join us in these efforts, please feel free to reach out, raise issues or read our [contribution guidelines](#contributing-) (or just star 🌟 to stay up to date with the latest developments)! -Stoix is fully in JAX with substantial speed improvement compared to other popular libraries. We currently provide native support for the [Jumanji][jumanji] environment API and wrappers for popular JAX-based RL environments. +Stoix is fully in JAX with substantial speed improvements compared to other popular libraries. We currently provide native support for the [Jumanji][jumanji] environment API and wrappers for popular RL environments. + +## System Design Paradigms +Stoix offers two primary system design paradigms (Podracer Architectures) to cater to different research and deployment needs: + +- **Anakin:** Traditional Stoix implementations are fully end-to-end compiled with JAX, focusing on speed and simplicity with native JAX environments. This design paradigm is ideal for setups where all components, including environments, can be optimised using JAX, leveraging the full power of JAX's `pmap` and `jit`. For an illustration of the Anakin architecture, see this [figure](docs/images/anakin_arch.jpg) from the [Mava][mava] technical report. + +- **Sebulba:** The Sebulba system introduces flexibility by allowing different devices to be assigned specifically for learning and acting. In this setup, acting devices serve as inference servers for multiple parallel environments, which can be written in any framework, not just JAX. This enables Stoix to be used with a broader range of environments while still benefiting from JAX's speed. For an illustration of the Sebulba architecture, see this [animation](docs/images/sebulba_arch.gif) from the [InstaDeep Sebulba implementation](https://github.com/instadeepai/sebulba/). + +Not all algorithms have both Anakin and Sebulba implementations, but effort has gone into making the two as similar as possible to allow for easy conversion.
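To make the Anakin paradigm concrete, below is a minimal, self-contained JAX sketch of the pattern (a toy example, not Stoix's actual learner code): the entire update step, including stepping a batch of vectorised environments, is a pure function that can be `vmap`-ed over environments, replicated over devices with `pmap`, and compiled end to end.

```python
import jax
import jax.numpy as jnp

# Toy stand-ins for a JAX environment transition and a parameterised policy.
def env_step(env_state, params):
    obs = env_state                   # pretend the state is the observation
    reward = jnp.dot(obs, params)     # pretend reward from a linear "policy"
    return env_state + 1.0, reward

def update_step(params, env_states):
    # Act in a batch of environments with vmap, then take a dummy gradient step.
    def loss_fn(p):
        _, rewards = jax.vmap(env_step, in_axes=(0, None))(env_states, p)
        return -rewards.mean()
    grads = jax.grad(loss_fn)(params)
    return params - 1e-3 * grads

# One learner replica per device; pmap jit-compiles and replicates the update.
n_devices = jax.local_device_count()
params = jnp.zeros((n_devices, 4))
env_states = jnp.zeros((n_devices, 8, 4))  # (devices, envs_per_device, obs_dim)
new_params = jax.pmap(update_step)(params, env_states)
```

In Stoix's Anakin systems the same structure is applied to the real actor/critic update, with config values such as `update_batch_size` and the number of environments roughly controlling the vectorisation axes.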
## Code Philosophy 🧘 @@ -47,9 +56,11 @@ The current code in Stoix was initially **largely** taken and subsequently adapt ### Stoix TLDR 1. **Algorithms:** Stoix offers easily hackable, single-file implementations of popular algorithms in pure JAX. You can vectorize algorithm training on a single device using `vmap` as well as distribute training across multiple devices with `pmap` (or both). Multi-host support (i.e., vmap/pmap over multiple devices **and** machines) is coming soon! All implementations include checkpointing to save and resume parameters and training runs. -2. **Hydra Config System:** Leverage the Hydra configuration system for efficient and consistent management of experiments, network architectures, and environments. Hydra facilitates the easy addition of new hyperparameters and supports multi-runs and Optuna hyperparameter optimization. No more need to create large bash scripts to run a series of experiments with differing hyperparameters, network architectures or environments. +2. **System Designs:** Choose between Anakin systems for fully JAX-optimized workflows or Sebulba systems for flexibility with non-JAX environments. + +3. **Hydra Config System:** Leverage the Hydra configuration system for efficient and consistent management of experiments, network architectures, and environments. Hydra facilitates the easy addition of new hyperparameters and supports multi-runs and Optuna hyperparameter optimization. No more need to create large bash scripts to run a series of experiments with differing hyperparameters, network architectures or environments. -3. **Advanced Logging:** Stoix features advanced and configurable logging, ready for output to the terminal, TensorBoard, and other ML tracking dashboards (WandB and Neptune). It also supports logging experiments in JSON format ready for statistical tests and generating RLiable plots (see the plotting notebook). This enables statistically confident comparisons of algorithms natively. +4. **Advanced Logging:** Stoix features advanced and configurable logging, ready for output to the terminal, TensorBoard, and other ML tracking dashboards (WandB and Neptune). It also supports logging experiments in JSON format ready for statistical tests and generating RLiable plots (see the plotting notebook). This enables statistically confident comparisons of algorithms natively. Stoix currently offers the following building blocks for Single-Agent RL research: @@ -78,14 +89,17 @@ Stoix currently offers the following building blocks for Single-Agent RL researc - **Sampled Alpha/Mu-Zero** - [Paper](https://arxiv.org/abs/2104.06303) ### Environment Wrappers 🍬 -Stoix offers wrappers for [Gymnax][gymnax], [Jumanji][jumanji], [Brax][brax], [XMinigrid][xminigrid], [Craftax][craftax], [POPJym][popjym], [Navix][navix] and even [JAXMarl][jaxmarl] (although using Centralised Controllers). +Stoix offers wrappers for: + +- **JAX environments:** [Gymnax][gymnax], [Jumanji][jumanji], [Brax][brax], [XMinigrid][xminigrid], [Craftax][craftax], [POPJym][popjym], [Navix][navix] and even [JAXMarl][jaxmarl] (although using Centralised Controllers). +- **Non-JAX environments:** [Envpool][envpool] and [Gymnasium][gymnasium]. ### Statistically Robust Evaluation 🧪 Stoix natively supports logging to json files which adhere to the standard suggested by [Gorsane et al. (2022)][toward_standard_eval]. This enables easy downstream experiment plotting and aggregation using the tools found in the [MARL-eval][marl_eval] library.
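Because the Hydra config system (point 3 above) does most of the experiment bookkeeping, it can be useful to inspect a fully composed config. Below is a minimal sketch using Hydra's compose API, assuming it is run from the repository root and that the `stoix/configs/default/anakin/default_ff_ppo.yaml` entry point introduced in this PR is present; Stoix itself launches runs through `@hydra.main`, so this is purely illustrative.

```python
from hydra import compose, initialize
from omegaconf import OmegaConf

# Compose the same config tree that @hydra.main would build for an Anakin PPO run.
# Note: initialize() resolves config_path relative to the calling file, so this
# assumes the script lives in the repository root.
with initialize(version_base="1.2", config_path="stoix/configs/default/anakin"):
    cfg = compose(
        config_name="default_ff_ppo",
        overrides=["env=gymnax/pendulum", "system.rollout_length=32"],  # same syntax as CLI overrides
    )

print(OmegaConf.to_yaml(cfg))  # inspect the fully resolved config
```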
## Performance and Speed 🚀 -As the code in Stoix (at the time of creation) was in essence a port of [Mava][mava], for further speed comparisons we point to their repo. Additionally, we refer to the PureJaxRL blog post [here](https://chrislu.page/blog/meta-disco/) where the speed benefits of end-to-end JAX systems are discussed. +As the code in Stoix (at the time of creation) was in essence a port of [Mava][mava], for further speed comparisons we point to their repo. Additionally, we refer to the PureJaxRL blog post [here](https://chrislu.page/blog/meta-disco/) where the speed benefits of end-to-end JAX systems are discussed. Lastly, we point to the Podracer architectures paper [here][anakin_paper] where these ideas were first discussed and benchmarked. Below we provide some plots illustrating that Stoix performs on par with [PureJaxRL][purejaxrl] but with the added benefit of the code being already set up for `pmap` distribution over devices as well as the other features provided (algorithm implementations, logging, config system, etc).
@@ -118,14 +132,22 @@ we advise users to explicitly install the correct JAX version (see the [official To get started with training your first Stoix system, simply run one of the system files. e.g., +For an Anakin system: + +```bash +python stoix/systems/ppo/anakin/ff_ppo.py +``` + +or for a Sebulba system: + ```bash -python stoix/systems/ppo/ff_ppo.py +python stoix/systems/ppo/sebulba/ff_ppo.py arch=sebulba env=envpool/pong network=visual_resnet ``` -Stoix makes use of Hydra for config management. In order to see our default system configs please see the `stoix/configs/` directory. A benefit of Hydra is that configs can either be set in config yaml files or overwritten from the terminal on the fly. For an example of running a system on the CartPole environment, the above code can simply be adapted as follows: +Stoix makes use of Hydra for config management. In order to see our default system configs please see the `stoix/configs/` directory. A benefit of Hydra is that configs can either be set in config yaml files or overwritten from the terminal on the fly. For an example of running a system on the CartPole environment and changing some hyperparameters, the above code can simply be adapted as follows: ```bash -python stoix/systems/ppo/ff_ppo.py env=gymnax/cartpole +python stoix/systems/ppo/ff_ppo.py env=gymnax/cartpole system.rollout_length=32 system.decay_learning_rates=True ``` Additionally, certain implementations such as Dueling DQN are decided by the network architecture but the underlying algorithm stays the same. For example, if you wanted to run Dueling DQN you would simply do: @@ -146,6 +168,8 @@ python stoix/systems/q_learning/ff_c51.py network=mlp_dueling_c51 2. Due to the way Stoix is set up, you are not guaranteed to run for exactly the number of timesteps you set. A warning is given at the beginning of a run on the actual number of timesteps that will be run. This value will always be less than or equal to the specified sample budget. To get the exact number of transitions to run, ensure that the number of timesteps is divisible by the rollout length * total_num_envs and additionally ensure that the number of evaluations spaced out throughout training perfectly divides the number of updates to be performed. To see the exact calculation, see the file total_timestep_checker.py. This will give an indication of how the actual number of timesteps is calculated and how you can easily set it up to run the exact amount you desire. It's relatively trivial to do so, but it is important to keep in mind. +3. Optimising the performance and speed of Sebulba systems can be a little tricky, as you need to balance the pipeline size, the number of actor threads, etc., so keep this in mind when applying an algorithm to a new problem. + ## Contributing 🤝 Please read our [contributing docs](docs/CONTRIBUTING.md) for details on how to submit pull requests, our Contributor License Agreement and community guidelines. @@ -217,5 +241,7 @@ We would like to thank the authors and developers of [Mava](mava) as this was es [craftax]: https://github.com/MichaelTMatthews/Craftax [popjym]: https://github.com/FLAIROx/popjym [navix]: https://github.com/epignatelli/navix +[envpool]: https://github.com/sail-sg/envpool/ +[gymnasium]: https://github.com/Farama-Foundation/Gymnasium Disclaimer: This is not an official InstaDeep product nor is any of the work put forward associated with InstaDeep in any official capacity.
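The timestep rounding described in note 2 above is easier to see with a small worked example. The function below is only a rough sketch of that adjustment (the authoritative logic is in `total_timestep_checker.py`, and the argument names here are illustrative rather than the exact config fields): the budget is rounded down to a whole number of updates, and the number of updates is then rounded down so the evaluations divide it evenly.

```python
def actual_timesteps(total_timesteps: int, rollout_length: int,
                     total_num_envs: int, num_evaluation: int) -> int:
    # Rough sketch of the adjustment described in note 2; see total_timestep_checker.py
    # for the real calculation.
    steps_per_update = rollout_length * total_num_envs
    num_updates = total_timesteps // steps_per_update   # keep whole updates only
    num_updates -= num_updates % num_evaluation         # evaluations must divide the updates
    return num_updates * steps_per_update               # always <= the requested budget


# e.g. a 1e7 budget with rollout_length=32, 256 envs and 20 evaluations
print(actual_timesteps(10_000_000, 32, 256, 20))  # 9994240 timesteps actually run
```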
diff --git a/docs/images/anakin_arch.jpg b/docs/images/anakin_arch.jpg new file mode 100644 index 00000000..ee8a3461 Binary files /dev/null and b/docs/images/anakin_arch.jpg differ diff --git a/docs/images/sebulba_arch.gif b/docs/images/sebulba_arch.gif new file mode 100644 index 00000000..009a2564 Binary files /dev/null and b/docs/images/sebulba_arch.gif differ diff --git a/requirements/requirements.txt b/requirements/requirements.txt index a94587d7..18c118b7 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -3,8 +3,10 @@ chex colorama craftax distrax @ git+https://github.com/google-deepmind/distrax # distrax release doesn't support jax > 0.4.13 +envpool flashbax @ git+https://github.com/instadeepai/flashbax flax +gymnasium gymnax>=0.0.6 huggingface_hub hydra-core==1.3.2 diff --git a/stoix/base_types.py b/stoix/base_types.py index 6f7436d4..00ddcf43 100644 --- a/stoix/base_types.py +++ b/stoix/base_types.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Tuple, TypeVar +from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, Tuple, TypeVar import chex from distrax import DistributionLike @@ -36,7 +36,7 @@ class Observation(NamedTuple): agent_view: chex.Array # (num_obs_features,) action_mask: chex.Array # (num_actions,) - step_count: chex.Array # (,) + step_count: Optional[chex.Array] = None # (,) class ObservationGlobalState(NamedTuple): @@ -106,8 +106,18 @@ class ActorCriticHiddenStates(NamedTuple): critic_hidden_state: HiddenState -class LearnerState(NamedTuple): - """State of the learner.""" +class CoreLearnerState(NamedTuple): + """Base state of the learner. Can be used for both on-policy and off-policy learners. + Mainly used for Sebulba systems since we don't store the env state.""" + + params: Parameters + opt_states: OptStates + key: chex.PRNGKey + timestep: TimeStep + + +class OnPolicyLearnerState(NamedTuple): + """State of the learner. Used for on-policy learners.""" params: Parameters opt_states: OptStates @@ -146,6 +156,9 @@ class OnlineAndTarget(NamedTuple): StoixState = TypeVar( "StoixState", ) +StoixTransition = TypeVar( + "StoixTransition", +) class ExperimentOutput(NamedTuple, Generic[StoixState]): @@ -158,6 +171,7 @@ class ExperimentOutput(NamedTuple, Generic[StoixState]): RNNObservation: TypeAlias = Tuple[Observation, Done] LearnerFn = Callable[[StoixState], ExperimentOutput[StoixState]] +SebulbaLearnerFn = Callable[[StoixState, StoixTransition], ExperimentOutput[StoixState]] EvalFn = Callable[[FrozenDict, chex.PRNGKey], ExperimentOutput[StoixState]] ActorApply = Callable[..., DistributionLike] @@ -174,3 +188,6 @@ class ExperimentOutput(NamedTuple, Generic[StoixState]): [FrozenDict, HiddenState, RNNObservation, chex.PRNGKey], Tuple[HiddenState, chex.Array] ] RecCriticApply = Callable[[FrozenDict, HiddenState, RNNObservation], Tuple[HiddenState, Value]] + + +EnvFactory = Callable[[int], Any] diff --git a/stoix/configs/arch/anakin.yaml b/stoix/configs/arch/anakin.yaml index f6092512..8d2c40b8 100644 --- a/stoix/configs/arch/anakin.yaml +++ b/stoix/configs/arch/anakin.yaml @@ -1,5 +1,5 @@ # --- Anakin config --- - +architecture_name: anakin # --- Training --- seed: 42 # RNG seed. update_batch_size: 1 # Number of vectorised gradient updates per device.
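The new `EnvFactory = Callable[[int], Any]` alias is how Sebulba components request environments: a factory receives the number of parallel environments and returns a vectorised environment (the evaluator changes further down call it as `env_factory(n_parallel_envs)`). Below is a minimal sketch of such a factory built on EnvPool; note that Stoix's actual factories are assumed to additionally wrap the environment so that `reset`/`step` return timesteps carrying a `metrics` extras dict, which is omitted here.

```python
from typing import Any

import envpool  # added to requirements in this PR


def make_cartpole_factory(seed: int = 0):
    """Return an EnvFactory: num_envs -> vectorised environment (illustrative only)."""

    def factory(num_envs: int) -> Any:
        # EnvPool builds num_envs synchronised CartPole instances behind a gym-style API.
        return envpool.make("CartPole-v1", env_type="gym", num_envs=num_envs, seed=seed)

    return factory


envs = make_cartpole_factory()(8)  # mirrors env_factory(n_parallel_envs) in the evaluator
```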
diff --git a/stoix/configs/arch/sebulba.yaml b/stoix/configs/arch/sebulba.yaml new file mode 100644 index 00000000..78f55a1b --- /dev/null +++ b/stoix/configs/arch/sebulba.yaml @@ -0,0 +1,29 @@ +# --- Sebulba config --- +architecture_name : sebulba +# --- Training --- +seed: 42 # RNG seed. +total_num_envs: 1024 # Total Number of vectorised environments across all actors. Needs to be divisible by the number of actor devices and actors per device. +total_timesteps: 1e7 # Set the total environment steps. +# If unspecified, it's derived from num_updates; otherwise, num_updates adjusts based on this value. +num_updates: ~ # Number of updates + +# Define the number of actors per device and which devices to use. +actor: + device_ids: [0,1] # Define which devices to use for the actors. + actor_per_device: 2 # number of different threads per actor device. + +# Define which devices to use for the learner. +learner: + device_ids: [2,3] # Define which devices to use for the learner. + +# Size of the queue for the pipeline where actors push data and the learner pulls data. +pipeline_queue_size: 10 + +# --- Evaluation --- +evaluation_greedy: False # Evaluate the policy greedily. If True the policy will select + # an action which corresponds to the greatest logit. If false, the policy will sample + # from the logits. +num_eval_episodes: 128 # Number of episodes to evaluate per evaluation. +num_evaluation: 20 # Number of evenly spaced evaluations to perform during training. +absolute_metric: True # Whether the absolute metric should be computed. For more details + # on the absolute metric please see: https://arxiv.org/abs/2209.10485 diff --git a/stoix/configs/default_ff_awr.yaml b/stoix/configs/default/anakin/default_ff_awr.yaml similarity index 71% rename from stoix/configs/default_ff_awr.yaml rename to stoix/configs/default/anakin/default_ff_awr.yaml index 2dbdcf9f..f0d5b9e5 100644 --- a/stoix/configs/default_ff_awr.yaml +++ b/stoix/configs/default/anakin/default_ff_awr.yaml @@ -5,3 +5,7 @@ defaults: - network: mlp - env: gymnax/cartpole - _self_ + +hydra: + searchpath: + - file://stoix/configs diff --git a/stoix/configs/default_ff_awr_continuous.yaml b/stoix/configs/default/anakin/default_ff_awr_continuous.yaml similarity index 72% rename from stoix/configs/default_ff_awr_continuous.yaml rename to stoix/configs/default/anakin/default_ff_awr_continuous.yaml index 85b712b8..5144a14d 100644 --- a/stoix/configs/default_ff_awr_continuous.yaml +++ b/stoix/configs/default/anakin/default_ff_awr_continuous.yaml @@ -5,3 +5,7 @@ defaults: - network: mlp_continuous - env: brax/ant - _self_ + +hydra: + searchpath: + - file://stoix/configs diff --git a/stoix/configs/default_ff_az.yaml b/stoix/configs/default/anakin/default_ff_az.yaml similarity index 71% rename from stoix/configs/default_ff_az.yaml rename to stoix/configs/default/anakin/default_ff_az.yaml index 30263b05..1a7b603e 100644 --- a/stoix/configs/default_ff_az.yaml +++ b/stoix/configs/default/anakin/default_ff_az.yaml @@ -5,3 +5,7 @@ defaults: - network: mlp - env: gymnax/cartpole - _self_ + +hydra: + searchpath: + - file://stoix/configs diff --git a/stoix/configs/default_ff_c51.yaml b/stoix/configs/default/anakin/default_ff_c51.yaml similarity index 72% rename from stoix/configs/default_ff_c51.yaml rename to stoix/configs/default/anakin/default_ff_c51.yaml index 02d4a589..0a98336f 100644 --- a/stoix/configs/default_ff_c51.yaml +++ b/stoix/configs/default/anakin/default_ff_c51.yaml @@ -5,3 +5,7 @@ defaults: - network: mlp_c51 - env: gymnax/cartpole - 
_self_ + +hydra: + searchpath: + - file://stoix/configs diff --git a/stoix/configs/default_ff_d4pg.yaml b/stoix/configs/default/anakin/default_ff_d4pg.yaml similarity index 71% rename from stoix/configs/default_ff_d4pg.yaml rename to stoix/configs/default/anakin/default_ff_d4pg.yaml index 3eb9bac3..7d6e8445 100644 --- a/stoix/configs/default_ff_d4pg.yaml +++ b/stoix/configs/default/anakin/default_ff_d4pg.yaml @@ -5,3 +5,7 @@ defaults: - network: mlp_d4pg - env: brax/ant - _self_ + +hydra: + searchpath: + - file://stoix/configs diff --git a/stoix/configs/default_ff_ddpg.yaml b/stoix/configs/default/anakin/default_ff_ddpg.yaml similarity index 71% rename from stoix/configs/default_ff_ddpg.yaml rename to stoix/configs/default/anakin/default_ff_ddpg.yaml index cf55a2f8..26a6c30c 100644 --- a/stoix/configs/default_ff_ddpg.yaml +++ b/stoix/configs/default/anakin/default_ff_ddpg.yaml @@ -5,3 +5,7 @@ defaults: - network: mlp_ddpg - env: brax/ant - _self_ + +hydra: + searchpath: + - file://stoix/configs diff --git a/stoix/configs/default_ff_ddqn.yaml b/stoix/configs/default/anakin/default_ff_ddqn.yaml similarity index 72% rename from stoix/configs/default_ff_ddqn.yaml rename to stoix/configs/default/anakin/default_ff_ddqn.yaml index 3f3fa8b1..adb112bd 100644 --- a/stoix/configs/default_ff_ddqn.yaml +++ b/stoix/configs/default/anakin/default_ff_ddqn.yaml @@ -5,3 +5,7 @@ defaults: - network: mlp_dqn - env: gymnax/cartpole - _self_ + +hydra: + searchpath: + - file://stoix/configs diff --git a/stoix/configs/default_ff_dpo_continuous.yaml b/stoix/configs/default/anakin/default_ff_dpo_continuous.yaml similarity index 72% rename from stoix/configs/default_ff_dpo_continuous.yaml rename to stoix/configs/default/anakin/default_ff_dpo_continuous.yaml index a3da231f..9f443132 100644 --- a/stoix/configs/default_ff_dpo_continuous.yaml +++ b/stoix/configs/default/anakin/default_ff_dpo_continuous.yaml @@ -5,3 +5,7 @@ defaults: - network: mlp_continuous - env: brax/ant - _self_ + +hydra: + searchpath: + - file://stoix/configs diff --git a/stoix/configs/default_ff_dqn.yaml b/stoix/configs/default/anakin/default_ff_dqn.yaml similarity index 72% rename from stoix/configs/default_ff_dqn.yaml rename to stoix/configs/default/anakin/default_ff_dqn.yaml index 3f3fa8b1..adb112bd 100644 --- a/stoix/configs/default_ff_dqn.yaml +++ b/stoix/configs/default/anakin/default_ff_dqn.yaml @@ -5,3 +5,7 @@ defaults: - network: mlp_dqn - env: gymnax/cartpole - _self_ + +hydra: + searchpath: + - file://stoix/configs diff --git a/stoix/configs/default_ff_dqn_reg.yaml b/stoix/configs/default/anakin/default_ff_dqn_reg.yaml similarity index 72% rename from stoix/configs/default_ff_dqn_reg.yaml rename to stoix/configs/default/anakin/default_ff_dqn_reg.yaml index 85437607..20768e7c 100644 --- a/stoix/configs/default_ff_dqn_reg.yaml +++ b/stoix/configs/default/anakin/default_ff_dqn_reg.yaml @@ -5,3 +5,7 @@ defaults: - network: mlp_dqn - env: gymnax/cartpole - _self_ + +hydra: + searchpath: + - file://stoix/configs diff --git a/stoix/configs/default_ff_mdqn.yaml b/stoix/configs/default/anakin/default_ff_mdqn.yaml similarity index 72% rename from stoix/configs/default_ff_mdqn.yaml rename to stoix/configs/default/anakin/default_ff_mdqn.yaml index c67b1547..78803765 100644 --- a/stoix/configs/default_ff_mdqn.yaml +++ b/stoix/configs/default/anakin/default_ff_mdqn.yaml @@ -5,3 +5,7 @@ defaults: - network: mlp_dqn - env: gymnax/cartpole - _self_ + +hydra: + searchpath: + - file://stoix/configs diff --git a/stoix/configs/default_ff_mpo.yaml 
b/stoix/configs/default/anakin/default_ff_mpo.yaml similarity index 72% rename from stoix/configs/default_ff_mpo.yaml rename to stoix/configs/default/anakin/default_ff_mpo.yaml index b327eda3..37fdacf3 100644 --- a/stoix/configs/default_ff_mpo.yaml +++ b/stoix/configs/default/anakin/default_ff_mpo.yaml @@ -5,3 +5,7 @@ defaults: - network: mlp_mpo - env: gymnax/cartpole - _self_ + +hydra: + searchpath: + - file://stoix/configs diff --git a/stoix/configs/default_ff_mpo_continuous.yaml b/stoix/configs/default/anakin/default_ff_mpo_continuous.yaml similarity index 74% rename from stoix/configs/default_ff_mpo_continuous.yaml rename to stoix/configs/default/anakin/default_ff_mpo_continuous.yaml index 875100eb..da6c092c 100644 --- a/stoix/configs/default_ff_mpo_continuous.yaml +++ b/stoix/configs/default/anakin/default_ff_mpo_continuous.yaml @@ -5,3 +5,7 @@ defaults: - network: mlp_mpo_continuous - env: brax/ant - _self_ + +hydra: + searchpath: + - file://stoix/configs diff --git a/stoix/configs/default_ff_mz.yaml b/stoix/configs/default/anakin/default_ff_mz.yaml similarity index 71% rename from stoix/configs/default_ff_mz.yaml rename to stoix/configs/default/anakin/default_ff_mz.yaml index 7832e120..f2111f07 100644 --- a/stoix/configs/default_ff_mz.yaml +++ b/stoix/configs/default/anakin/default_ff_mz.yaml @@ -5,3 +5,7 @@ defaults: - network: muzero - env: gymnax/cartpole - _self_ + +hydra: + searchpath: + - file://stoix/configs diff --git a/stoix/configs/default_ff_ppo.yaml b/stoix/configs/default/anakin/default_ff_ppo.yaml similarity index 71% rename from stoix/configs/default_ff_ppo.yaml rename to stoix/configs/default/anakin/default_ff_ppo.yaml index c2481cd4..ff09b91e 100644 --- a/stoix/configs/default_ff_ppo.yaml +++ b/stoix/configs/default/anakin/default_ff_ppo.yaml @@ -5,3 +5,7 @@ defaults: - network: mlp - env: gymnax/cartpole - _self_ + +hydra: + searchpath: + - file://stoix/configs diff --git a/stoix/configs/default_ff_ppo_continuous.yaml b/stoix/configs/default/anakin/default_ff_ppo_continuous.yaml similarity index 72% rename from stoix/configs/default_ff_ppo_continuous.yaml rename to stoix/configs/default/anakin/default_ff_ppo_continuous.yaml index 3cb9f874..95fb5249 100644 --- a/stoix/configs/default_ff_ppo_continuous.yaml +++ b/stoix/configs/default/anakin/default_ff_ppo_continuous.yaml @@ -5,3 +5,7 @@ defaults: - network: mlp_continuous - env: brax/ant - _self_ + +hydra: + searchpath: + - file://stoix/configs diff --git a/stoix/configs/default_ff_qr_dqn.yaml b/stoix/configs/default/anakin/default_ff_qr_dqn.yaml similarity index 73% rename from stoix/configs/default_ff_qr_dqn.yaml rename to stoix/configs/default/anakin/default_ff_qr_dqn.yaml index f9216b9f..3a54234b 100644 --- a/stoix/configs/default_ff_qr_dqn.yaml +++ b/stoix/configs/default/anakin/default_ff_qr_dqn.yaml @@ -5,3 +5,7 @@ defaults: - network: mlp_qr_dqn - env: gymnax/cartpole - _self_ + +hydra: + searchpath: + - file://stoix/configs diff --git a/stoix/configs/default_ff_rainbow.yaml b/stoix/configs/default/anakin/default_ff_rainbow.yaml similarity index 74% rename from stoix/configs/default_ff_rainbow.yaml rename to stoix/configs/default/anakin/default_ff_rainbow.yaml index f6918367..49d3f99a 100644 --- a/stoix/configs/default_ff_rainbow.yaml +++ b/stoix/configs/default/anakin/default_ff_rainbow.yaml @@ -5,3 +5,7 @@ defaults: - network: mlp_noisy_dueling_c51 - env: gymnax/cartpole - _self_ + +hydra: + searchpath: + - file://stoix/configs diff --git a/stoix/configs/default_ff_reinforce.yaml 
b/stoix/configs/default/anakin/default_ff_reinforce.yaml similarity index 72% rename from stoix/configs/default_ff_reinforce.yaml rename to stoix/configs/default/anakin/default_ff_reinforce.yaml index b07db485..5bb1c5a0 100644 --- a/stoix/configs/default_ff_reinforce.yaml +++ b/stoix/configs/default/anakin/default_ff_reinforce.yaml @@ -5,3 +5,7 @@ defaults: - network: mlp - env: gymnax/cartpole - _self_ + +hydra: + searchpath: + - file://stoix/configs diff --git a/stoix/configs/default_ff_reinforce_continuous.yaml b/stoix/configs/default/anakin/default_ff_reinforce_continuous.yaml similarity index 74% rename from stoix/configs/default_ff_reinforce_continuous.yaml rename to stoix/configs/default/anakin/default_ff_reinforce_continuous.yaml index 82a08bd8..c5ee5f3f 100644 --- a/stoix/configs/default_ff_reinforce_continuous.yaml +++ b/stoix/configs/default/anakin/default_ff_reinforce_continuous.yaml @@ -5,3 +5,7 @@ defaults: - network: mlp_continuous - env: gymnax/pendulum - _self_ + +hydra: + searchpath: + - file://stoix/configs diff --git a/stoix/configs/default_ff_sac.yaml b/stoix/configs/default/anakin/default_ff_sac.yaml similarity index 71% rename from stoix/configs/default_ff_sac.yaml rename to stoix/configs/default/anakin/default_ff_sac.yaml index d63a34db..42b64277 100644 --- a/stoix/configs/default_ff_sac.yaml +++ b/stoix/configs/default/anakin/default_ff_sac.yaml @@ -5,3 +5,7 @@ defaults: - network: mlp_sac - env: brax/ant - _self_ + +hydra: + searchpath: + - file://stoix/configs diff --git a/stoix/configs/default_ff_sampled_az.yaml b/stoix/configs/default/anakin/default_ff_sampled_az.yaml similarity index 73% rename from stoix/configs/default_ff_sampled_az.yaml rename to stoix/configs/default/anakin/default_ff_sampled_az.yaml index f7e2749e..33c58f7e 100644 --- a/stoix/configs/default_ff_sampled_az.yaml +++ b/stoix/configs/default/anakin/default_ff_sampled_az.yaml @@ -5,3 +5,7 @@ defaults: - network: mlp_continuous - env: brax/ant - _self_ + +hydra: + searchpath: + - file://stoix/configs diff --git a/stoix/configs/default_ff_sampled_mz.yaml b/stoix/configs/default/anakin/default_ff_sampled_mz.yaml similarity index 74% rename from stoix/configs/default_ff_sampled_mz.yaml rename to stoix/configs/default/anakin/default_ff_sampled_mz.yaml index 3ba4b51b..cded3d25 100644 --- a/stoix/configs/default_ff_sampled_mz.yaml +++ b/stoix/configs/default/anakin/default_ff_sampled_mz.yaml @@ -5,3 +5,7 @@ defaults: - network: sampled_muzero - env: gymnax/pendulum - _self_ + +hydra: + searchpath: + - file://stoix/configs diff --git a/stoix/configs/default_ff_td3.yaml b/stoix/configs/default/anakin/default_ff_td3.yaml similarity index 71% rename from stoix/configs/default_ff_td3.yaml rename to stoix/configs/default/anakin/default_ff_td3.yaml index 257f43c9..15e82e6c 100644 --- a/stoix/configs/default_ff_td3.yaml +++ b/stoix/configs/default/anakin/default_ff_td3.yaml @@ -5,3 +5,7 @@ defaults: - network: mlp_ddpg - env: brax/ant - _self_ + +hydra: + searchpath: + - file://stoix/configs diff --git a/stoix/configs/default_ff_vmpo.yaml b/stoix/configs/default/anakin/default_ff_vmpo.yaml similarity index 71% rename from stoix/configs/default_ff_vmpo.yaml rename to stoix/configs/default/anakin/default_ff_vmpo.yaml index af900d24..803a4229 100644 --- a/stoix/configs/default_ff_vmpo.yaml +++ b/stoix/configs/default/anakin/default_ff_vmpo.yaml @@ -5,3 +5,7 @@ defaults: - network: mlp - env: gymnax/cartpole - _self_ + +hydra: + searchpath: + - file://stoix/configs diff --git 
a/stoix/configs/default_ff_vmpo_continuous.yaml b/stoix/configs/default/anakin/default_ff_vmpo_continuous.yaml similarity index 73% rename from stoix/configs/default_ff_vmpo_continuous.yaml rename to stoix/configs/default/anakin/default_ff_vmpo_continuous.yaml index 1680acd4..1ab5747a 100644 --- a/stoix/configs/default_ff_vmpo_continuous.yaml +++ b/stoix/configs/default/anakin/default_ff_vmpo_continuous.yaml @@ -5,3 +5,7 @@ defaults: - network: mlp_continuous - env: brax/ant - _self_ + +hydra: + searchpath: + - file://stoix/configs diff --git a/stoix/configs/default_rec_ppo.yaml b/stoix/configs/default/anakin/default_rec_ppo.yaml similarity index 71% rename from stoix/configs/default_rec_ppo.yaml rename to stoix/configs/default/anakin/default_rec_ppo.yaml index 22e49afe..a85398f8 100644 --- a/stoix/configs/default_rec_ppo.yaml +++ b/stoix/configs/default/anakin/default_rec_ppo.yaml @@ -5,3 +5,7 @@ defaults: - network: rnn - env: gymnax/cartpole - _self_ + +hydra: + searchpath: + - file://stoix/configs diff --git a/stoix/configs/hyperparameter_sweep.yaml b/stoix/configs/default/anakin/hyperparameter_sweep.yaml similarity index 93% rename from stoix/configs/hyperparameter_sweep.yaml rename to stoix/configs/default/anakin/hyperparameter_sweep.yaml index 8e03dddf..418cee32 100644 --- a/stoix/configs/hyperparameter_sweep.yaml +++ b/stoix/configs/default/anakin/hyperparameter_sweep.yaml @@ -11,6 +11,8 @@ defaults: - _self_ hydra: + searchpath: + - file://stoix/configs mode: MULTIRUN sweeper: direction: maximize diff --git a/stoix/configs/default/sebulba/default_ff_ppo.yaml b/stoix/configs/default/sebulba/default_ff_ppo.yaml new file mode 100644 index 00000000..f0501dc3 --- /dev/null +++ b/stoix/configs/default/sebulba/default_ff_ppo.yaml @@ -0,0 +1,11 @@ +defaults: + - logger: base_logger + - arch: sebulba + - system: ff_ppo + - network: mlp + - env: envpool/cartpole + - _self_ + +hydra: + searchpath: + - file://stoix/configs diff --git a/stoix/configs/env/envpool/breakout.yaml b/stoix/configs/env/envpool/breakout.yaml new file mode 100644 index 00000000..135d9560 --- /dev/null +++ b/stoix/configs/env/envpool/breakout.yaml @@ -0,0 +1,22 @@ +# ---Environment Configs--- +env_name: envpool # Used for logging purposes and selection of the corresponding wrapper. + +scenario: + name: Breakout-v5 + task_name: breakout # For logging purposes. + +kwargs: + episodic_life: True + repeat_action_probability: 0 + noop_max: 30 + full_action_space: False + max_episode_steps: 27000 + + +# Defines the metric that will be used to evaluate the performance of the agent. +# This metric is returned at the end of an experiment and can be used for hyperparameter tuning. +eval_metric: episode_return + +# optional - defines the threshold that needs to be reached in order to consider the environment solved. +# if present then solve rate will be logged. +solved_return_threshold: 400.0 diff --git a/stoix/configs/env/envpool/cartpole.yaml b/stoix/configs/env/envpool/cartpole.yaml new file mode 100644 index 00000000..a466bf31 --- /dev/null +++ b/stoix/configs/env/envpool/cartpole.yaml @@ -0,0 +1,16 @@ +# ---Environment Configs--- +env_name: envpool # Used for logging purposes and selection of the corresponding wrapper. + +scenario: + name: CartPole-v1 + task_name: cartpole # For logging purposes. + +kwargs: {} + +# Defines the metric that will be used to evaluate the performance of the agent. +# This metric is returned at the end of an experiment and can be used for hyperparameter tuning. 
+eval_metric: episode_return + +# optional - defines the threshold that needs to be reached in order to consider the environment solved. +# if present then solve rate will be logged. +solved_return_threshold: 500.0 diff --git a/stoix/configs/env/envpool/lunarlander.yaml b/stoix/configs/env/envpool/lunarlander.yaml new file mode 100644 index 00000000..eb161d02 --- /dev/null +++ b/stoix/configs/env/envpool/lunarlander.yaml @@ -0,0 +1,16 @@ +# ---Environment Configs--- +env_name: envpool # Used for logging purposes and selection of the corresponding wrapper. + +scenario: + name: LunarLander-v2 + task_name: lunarlander # For logging purposes. + +kwargs: {} + +# Defines the metric that will be used to evaluate the performance of the agent. +# This metric is returned at the end of an experiment and can be used for hyperparameter tuning. +eval_metric: episode_return + +# optional - defines the threshold that needs to be reached in order to consider the environment solved. +# if present then solve rate will be logged. +solved_return_threshold: 200.0 diff --git a/stoix/configs/env/envpool/pong.yaml b/stoix/configs/env/envpool/pong.yaml new file mode 100644 index 00000000..b95cd010 --- /dev/null +++ b/stoix/configs/env/envpool/pong.yaml @@ -0,0 +1,22 @@ +# ---Environment Configs--- +env_name: envpool # Used for logging purposes and selection of the corresponding wrapper. + +scenario: + name: Pong-v5 + task_name: pong # For logging purposes. + +kwargs: + episodic_life: True + repeat_action_probability: 0 + noop_max: 30 + full_action_space: False + max_episode_steps: 27000 + + +# Defines the metric that will be used to evaluate the performance of the agent. +# This metric is returned at the end of an experiment and can be used for hyperparameter tuning. +eval_metric: episode_return + +# optional - defines the threshold that needs to be reached in order to consider the environment solved. +# if present then solve rate will be logged. +solved_return_threshold: 20.0 diff --git a/stoix/configs/env/envpool/vizdoom_basic.yaml b/stoix/configs/env/envpool/vizdoom_basic.yaml new file mode 100644 index 00000000..e2e588b7 --- /dev/null +++ b/stoix/configs/env/envpool/vizdoom_basic.yaml @@ -0,0 +1,19 @@ +# ---Environment Configs--- +env_name: envpool # Used for logging purposes and selection of the corresponding wrapper. + +scenario: + name: Basic-v1 + task_name: vizdoom_basic # For logging purposes. + +kwargs: + episodic_life: True + use_combined_action : True + + +# Defines the metric that will be used to evaluate the performance of the agent. +# This metric is returned at the end of an experiment and can be used for hyperparameter tuning. +eval_metric: episode_return + +# optional - defines the threshold that needs to be reached in order to consider the environment solved. +# if present then solve rate will be logged. +solved_return_threshold: 100.0 diff --git a/stoix/configs/env/gymnasium/cartpole.yaml b/stoix/configs/env/gymnasium/cartpole.yaml new file mode 100644 index 00000000..5b0e071d --- /dev/null +++ b/stoix/configs/env/gymnasium/cartpole.yaml @@ -0,0 +1,16 @@ +# ---Environment Configs--- +env_name: gymnasium # Used for logging purposes and selection of the corresponding wrapper. + +scenario: + name: CartPole-v1 + task_name: cartpole # For logging purposes. + +kwargs: {} + +# Defines the metric that will be used to evaluate the performance of the agent. +# This metric is returned at the end of an experiment and can be used for hyperparameter tuning. 
+eval_metric: episode_return + +# optional - defines the threshold that needs to be reached in order to consider the environment solved. +# if present then solve rate will be logged. +solved_return_threshold: 500.0 diff --git a/stoix/configs/network/cnn.yaml b/stoix/configs/network/cnn.yaml index 6635f53b..f55200e7 100644 --- a/stoix/configs/network/cnn.yaml +++ b/stoix/configs/network/cnn.yaml @@ -7,6 +7,7 @@ actor_network: strides: [1, 1] use_layer_norm: False activation: silu + channel_first: True action_head: _target_: stoix.networks.heads.CategoricalHead @@ -18,5 +19,6 @@ critic_network: strides: [1, 1] use_layer_norm: False activation: silu + channel_first: True critic_head: _target_: stoix.networks.heads.ScalarCriticHead diff --git a/stoix/configs/network/visual_resnet.yaml b/stoix/configs/network/visual_resnet.yaml index d1a5ce3d..3495bfb4 100644 --- a/stoix/configs/network/visual_resnet.yaml +++ b/stoix/configs/network/visual_resnet.yaml @@ -6,6 +6,7 @@ actor_network: blocks_per_group: [2, 2, 2] use_layer_norm: False activation: silu + channel_first: True action_head: _target_: stoix.networks.heads.CategoricalHead @@ -16,5 +17,6 @@ critic_network: blocks_per_group: [2, 2, 2] use_layer_norm: False activation: silu + channel_first: True critic_head: _target_: stoix.networks.heads.ScalarCriticHead diff --git a/stoix/evaluator.py b/stoix/evaluator.py index 78838866..2ed3a100 100644 --- a/stoix/evaluator.py +++ b/stoix/evaluator.py @@ -1,9 +1,13 @@ -from typing import Dict, Optional, Tuple, Union +import math +import time +from typing import Any, Callable, Dict, Optional, Tuple, Union import chex import flax.linen as nn import jax import jax.numpy as jnp +import numpy as np +from colorama import Fore, Style from flax.core.frozen_dict import FrozenDict from jumanji.env import Environment from omegaconf import DictConfig @@ -11,6 +15,7 @@ from stoix.base_types import ( ActFn, ActorApply, + EnvFactory, EvalFn, EvalState, ExperimentOutput, @@ -349,3 +354,119 @@ def evaluator_setup( eval_keys = jnp.stack(eval_keys).reshape(n_devices, -1) return evaluator, absolute_metric_evaluator, (trained_params, eval_keys) + + +##### THIS IS TEMPORARY + +SebulbaEvalFn = Callable[[FrozenDict, chex.PRNGKey], Dict[str, chex.Array]] + + +def get_sebulba_eval_fn( + env_factory: EnvFactory, + act_fn: ActFn, + config: DictConfig, + np_rng: np.random.Generator, + device: jax.Device, + eval_multiplier: float = 1.0, +) -> Tuple[SebulbaEvalFn, Any]: + """Creates a function that can be used to evaluate agents on a given environment. + + Args: + ---- + env: an environment that conforms to the mava environment spec. + act_fn: a function that takes in params, timestep, key and optionally a state + and returns actions and optionally a state (see `EvalActFn`). + config: the system config. + np_rng: a numpy random number generator. + eval_multiplier: a scalar that will increase the number of evaluation episodes + by a fixed factor. + """ + eval_episodes = config.arch.num_eval_episodes * eval_multiplier + + # We calculate here the number of parallel envs we can run in parallel. + # If the total number of episodes is less than the number of parallel envs + # we will run all episodes in parallel. + # Otherwise we will run `num_envs` parallel envs and loop enough times + # so that we do at least `eval_episodes` number of episodes. 
+ n_parallel_envs = int(min(eval_episodes, config.arch.num_envs)) + episode_loops = math.ceil(eval_episodes / n_parallel_envs) + envs = env_factory(n_parallel_envs) + cpu = jax.devices("cpu")[0] + act_fn = jax.jit(act_fn, device=device) + + # Warnings if num eval episodes is not divisible by num parallel envs. + if eval_episodes % n_parallel_envs != 0: + msg = ( + f"Please note that the number of evaluation episodes ({eval_episodes}) is not " + f"evenly divisible by `num_envs`. As a result, some additional evaluations will be " + f"conducted. The adjusted number of evaluation episodes is now " + f"{episode_loops * n_parallel_envs}." + ) + print(f"{Fore.YELLOW}{Style.BRIGHT}{msg}{Style.RESET_ALL}") + + def eval_fn(params: FrozenDict, key: chex.PRNGKey) -> Dict: + """Evaluates the given params on an environment and returns relevant metrics. + + Metrics are collected by the `RecordEpisodeMetrics` wrapper: episode return and length, + also win rate for environments that support it. + + Returns: Dict[str, Array] - dictionary of metric name to metric values for each episode. + """ + + def _run_episodes(key: chex.PRNGKey) -> Tuple[chex.PRNGKey, Dict]: + """Simulates `num_envs` episodes.""" + with jax.default_device(device): + # Reset the environment. + seeds = np_rng.integers(np.iinfo(np.int32).max, size=n_parallel_envs).tolist() + timestep = envs.reset(seed=seeds) + + all_metrics = [timestep.extras["metrics"]] + all_dones = [timestep.last()] + finished_eps = timestep.last() + + # Loop until all episodes are done. + while not finished_eps.all(): + key, act_key = jax.random.split(key) + action = act_fn(params, timestep.observation, act_key) + action_cpu = np.asarray(jax.device_put(action, cpu)) + timestep = envs.step(action_cpu) + all_metrics.append(timestep.extras["metrics"]) + all_dones.append(timestep.last()) + finished_eps = np.logical_or(finished_eps, timestep.last()) + + metrics = jax.tree.map(lambda *x: np.stack(x), *all_metrics) + dones = np.stack(all_dones) + + # find the first instance of done to get the metrics at that timestep, we don't + # care about subsequent steps because we only the results from the first episode + done_idx = np.argmax(dones, axis=0) + metrics = jax.tree_map(lambda m: m[done_idx, np.arange(n_parallel_envs)], metrics) + del metrics["is_terminal_step"] # unneeded for logging + + return key, metrics + + # This loop is important because we don't want too many parallel envs. + # So in evaluation we have num_envs parallel envs and loop enough times + # so that we do at least `eval_episodes` number of episodes. 
+ metrics = [] + for _ in range(episode_loops): + key, metric = _run_episodes(key) + metrics.append(metric) + + metrics: Dict = jax.tree_map( + lambda *x: np.array(x).reshape(-1), *metrics + ) # flatten metrics + return metrics + + def timed_eval_fn(params: FrozenDict, key: chex.PRNGKey) -> Any: + """Wrapper around eval function to time it and add in steps per second metric.""" + start_time = time.perf_counter() + metrics = eval_fn(params, key) + end_time = time.perf_counter() + + total_timesteps = jnp.sum(metrics["episode_length"]) + metrics["steps_per_second"] = total_timesteps / (end_time - start_time) + metrics["evaluator_run_time"] = end_time - start_time + return metrics + + return timed_eval_fn, envs diff --git a/stoix/networks/postprocessors.py b/stoix/networks/postprocessors.py index 309aabb4..a2c5d518 100644 --- a/stoix/networks/postprocessors.py +++ b/stoix/networks/postprocessors.py @@ -66,7 +66,9 @@ class ScalePostProcessor(nn.Module): @nn.compact def __call__(self, distribution: Distribution) -> Distribution: - post_processor = partial(self.scale_fn, minimum=self.minimum, maximum=self.maximum) + post_processor = partial( + self.scale_fn, minimum=self.minimum, maximum=self.maximum + ) # type: ignore return PostProcessedDistribution(distribution, post_processor) diff --git a/stoix/networks/resnet.py b/stoix/networks/resnet.py index 25bd8378..a5308007 100644 --- a/stoix/networks/resnet.py +++ b/stoix/networks/resnet.py @@ -109,8 +109,10 @@ class VisualResNetTorso(nn.Module): channels_per_group: Sequence[int] = (16, 32, 32) blocks_per_group: Sequence[int] = (2, 2, 2) downsampling_strategies: Sequence[DownsamplingStrategy] = (DownsamplingStrategy.CONV,) * 3 + hidden_sizes: Sequence[int] = (256,) use_layer_norm: bool = False activation: str = "relu" + channel_first: bool = False @nn.compact def __call__(self, observation: chex.Array) -> chex.Array: @@ -118,6 +120,10 @@ def __call__(self, observation: chex.Array) -> chex.Array: if observation.ndim > 4: return nn.batch_apply.BatchApply(self.__call__)(observation) + # If the input is in the form of [B, C, H, W], we need to transpose it to [B, H, W, C] + if self.channel_first: + observation = observation.transpose((0, 2, 3, 1)) + assert ( observation.ndim == 4 ), f"Expected inputs to have shape [B, H, W, C] but got shape {observation.shape}." 
@@ -139,7 +145,12 @@ def __call__(self, observation: chex.Array) -> chex.Array: non_linearity=parse_activation_fn(self.activation), )(output) - return output.reshape(*observation.shape[:-3], -1) + output = output.reshape(*observation.shape[:-3], -1) + for num_hidden_units in self.hidden_sizes: + output = nn.Dense(features=num_hidden_units)(output) + output = parse_activation_fn(self.activation)(output) + + return output class ResNetTorso(nn.Module): diff --git a/stoix/networks/torso.py b/stoix/networks/torso.py index 59caae6b..8a46d4ad 100644 --- a/stoix/networks/torso.py +++ b/stoix/networks/torso.py @@ -63,11 +63,15 @@ class CNNTorso(nn.Module): activation: str = "relu" use_layer_norm: bool = False kernel_init: Initializer = orthogonal(np.sqrt(2.0)) + channel_first: bool = False @nn.compact def __call__(self, observation: chex.Array) -> chex.Array: """Forward pass.""" x = observation + # Move channels to the last dimension if they are first + if self.channel_first: + x = x.transpose((0, 2, 3, 1)) for channel, kernel, stride in zip(self.channel_sizes, self.kernel_sizes, self.strides): x = nn.Conv(channel, (kernel, kernel), (stride, stride))(x) if self.use_layer_norm: diff --git a/stoix/systems/awr/ff_awr.py b/stoix/systems/awr/ff_awr.py index a43e5e55..1b500336 100644 --- a/stoix/systems/awr/ff_awr.py +++ b/stoix/systems/awr/ff_awr.py @@ -645,7 +645,11 @@ def run_experiment(_config: DictConfig) -> float: return eval_performance -@hydra.main(config_path="../../configs", config_name="default_ff_awr.yaml", version_base="1.2") +@hydra.main( + config_path="../../configs/default/anakin", + config_name="default_ff_awr.yaml", + version_base="1.2", +) def hydra_entry_point(cfg: DictConfig) -> float: """Experiment entry point.""" # Allow dynamic attributes. diff --git a/stoix/systems/awr/ff_awr_continuous.py b/stoix/systems/awr/ff_awr_continuous.py index a4b3392b..b83686f5 100644 --- a/stoix/systems/awr/ff_awr_continuous.py +++ b/stoix/systems/awr/ff_awr_continuous.py @@ -651,7 +651,9 @@ def run_experiment(_config: DictConfig) -> float: @hydra.main( - config_path="../../configs", config_name="default_ff_awr_continuous.yaml", version_base="1.2" + config_path="../../configs/default/anakin", + config_name="default_ff_awr_continuous.yaml", + version_base="1.2", ) def hydra_entry_point(cfg: DictConfig) -> float: """Experiment entry point.""" diff --git a/stoix/systems/ddpg/ff_d4pg.py b/stoix/systems/ddpg/ff_d4pg.py index a504768e..74b759e3 100644 --- a/stoix/systems/ddpg/ff_d4pg.py +++ b/stoix/systems/ddpg/ff_d4pg.py @@ -695,7 +695,11 @@ def run_experiment(_config: DictConfig) -> float: return eval_performance -@hydra.main(config_path="../../configs", config_name="default_ff_d4pg.yaml", version_base="1.2") +@hydra.main( + config_path="../../configs/default/anakin", + config_name="default_ff_d4pg.yaml", + version_base="1.2", +) def hydra_entry_point(cfg: DictConfig) -> float: """Experiment entry point.""" # Allow dynamic attributes. 
diff --git a/stoix/systems/ddpg/ff_ddpg.py b/stoix/systems/ddpg/ff_ddpg.py index 3a17aa4f..313934b1 100644 --- a/stoix/systems/ddpg/ff_ddpg.py +++ b/stoix/systems/ddpg/ff_ddpg.py @@ -648,7 +648,11 @@ def run_experiment(_config: DictConfig) -> float: return eval_performance -@hydra.main(config_path="../../configs", config_name="default_ff_ddpg.yaml", version_base="1.2") +@hydra.main( + config_path="../../configs/default/anakin", + config_name="default_ff_ddpg.yaml", + version_base="1.2", +) def hydra_entry_point(cfg: DictConfig) -> float: """Experiment entry point.""" # Allow dynamic attributes. diff --git a/stoix/systems/ddpg/ff_td3.py b/stoix/systems/ddpg/ff_td3.py index 96e42cd7..09e7bc09 100644 --- a/stoix/systems/ddpg/ff_td3.py +++ b/stoix/systems/ddpg/ff_td3.py @@ -674,7 +674,11 @@ def run_experiment(_config: DictConfig) -> float: return eval_performance -@hydra.main(config_path="../../configs", config_name="default_ff_td3.yaml", version_base="1.2") +@hydra.main( + config_path="../../configs/default/anakin", + config_name="default_ff_td3.yaml", + version_base="1.2", +) def hydra_entry_point(cfg: DictConfig) -> float: """Experiment entry point.""" # Allow dynamic attributes. diff --git a/stoix/systems/mpo/ff_mpo.py b/stoix/systems/mpo/ff_mpo.py index e64d3171..00a96bd1 100644 --- a/stoix/systems/mpo/ff_mpo.py +++ b/stoix/systems/mpo/ff_mpo.py @@ -752,7 +752,11 @@ def run_experiment(_config: DictConfig) -> float: return eval_performance -@hydra.main(config_path="../../configs", config_name="default_ff_mpo.yaml", version_base="1.2") +@hydra.main( + config_path="../../configs/default/anakin", + config_name="default_ff_mpo.yaml", + version_base="1.2", +) def hydra_entry_point(cfg: DictConfig) -> float: """Experiment entry point.""" # Allow dynamic attributes. diff --git a/stoix/systems/mpo/ff_mpo_continuous.py b/stoix/systems/mpo/ff_mpo_continuous.py index a2ef6f05..e56121ef 100644 --- a/stoix/systems/mpo/ff_mpo_continuous.py +++ b/stoix/systems/mpo/ff_mpo_continuous.py @@ -784,7 +784,9 @@ def run_experiment(_config: DictConfig) -> float: @hydra.main( - config_path="../../configs", config_name="default_ff_mpo_continuous.yaml", version_base="1.2" + config_path="../../configs/default/anakin", + config_name="default_ff_mpo_continuous.yaml", + version_base="1.2", ) def hydra_entry_point(cfg: DictConfig) -> float: """Experiment entry point.""" diff --git a/stoix/systems/mpo/ff_vmpo.py b/stoix/systems/mpo/ff_vmpo.py index 7f5891f3..052b36cd 100644 --- a/stoix/systems/mpo/ff_vmpo.py +++ b/stoix/systems/mpo/ff_vmpo.py @@ -602,7 +602,11 @@ def run_experiment(_config: DictConfig) -> float: return eval_performance -@hydra.main(config_path="../../configs", config_name="default_ff_vmpo.yaml", version_base="1.2") +@hydra.main( + config_path="../../configs/default/anakin", + config_name="default_ff_vmpo.yaml", + version_base="1.2", +) def hydra_entry_point(cfg: DictConfig) -> float: """Experiment entry point.""" # Allow dynamic attributes. 
diff --git a/stoix/systems/mpo/ff_vmpo_continuous.py b/stoix/systems/mpo/ff_vmpo_continuous.py index e20fc0ff..5b5b3690 100644 --- a/stoix/systems/mpo/ff_vmpo_continuous.py +++ b/stoix/systems/mpo/ff_vmpo_continuous.py @@ -678,7 +678,9 @@ def run_experiment(_config: DictConfig) -> float: @hydra.main( - config_path="../../configs", config_name="default_ff_vmpo_continuous.yaml", version_base="1.2" + config_path="../../configs/default/anakin", + config_name="default_ff_vmpo_continuous.yaml", + version_base="1.2", ) def hydra_entry_point(cfg: DictConfig) -> float: """Experiment entry point.""" diff --git a/stoix/systems/ppo/ff_dpo_continuous.py b/stoix/systems/ppo/anakin/ff_dpo_continuous.py similarity index 96% rename from stoix/systems/ppo/ff_dpo_continuous.py rename to stoix/systems/ppo/anakin/ff_dpo_continuous.py index c8743a27..56764047 100644 --- a/stoix/systems/ppo/ff_dpo_continuous.py +++ b/stoix/systems/ppo/anakin/ff_dpo_continuous.py @@ -21,7 +21,7 @@ CriticApply, ExperimentOutput, LearnerFn, - LearnerState, + OnPolicyLearnerState, ) from stoix.evaluator import evaluator_setup, get_distribution_act_fn from stoix.networks.base import FeedForwardActor as Actor @@ -47,14 +47,16 @@ def get_learner_fn( apply_fns: Tuple[ActorApply, CriticApply], update_fns: Tuple[optax.TransformUpdateFn, optax.TransformUpdateFn], config: DictConfig, -) -> LearnerFn[LearnerState]: +) -> LearnerFn[OnPolicyLearnerState]: """Get the learner function.""" # Get apply and update functions for actor and critic networks. actor_apply_fn, critic_apply_fn = apply_fns actor_update_fn, critic_update_fn = update_fns - def _update_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, Tuple]: + def _update_step( + learner_state: OnPolicyLearnerState, _: Any + ) -> Tuple[OnPolicyLearnerState, Tuple]: """A single update of the network. This function steps the environment and records the trajectory batch for @@ -72,7 +74,9 @@ def _update_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, Tup _ (Any): The current metrics info. """ - def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTransition]: + def _env_step( + learner_state: OnPolicyLearnerState, _: Any + ) -> Tuple[OnPolicyLearnerState, PPOTransition]: """Step the environment.""" params, opt_states, key, env_state, last_timestep = learner_state @@ -101,7 +105,7 @@ def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTra last_timestep.observation, info, ) - learner_state = LearnerState(params, opt_states, key, env_state, timestep) + learner_state = OnPolicyLearnerState(params, opt_states, key, env_state, timestep) return learner_state, transition # STEP ENVIRONMENT FOR ROLLOUT LENGTH @@ -272,11 +276,11 @@ def _critic_loss_fn( ) params, opt_states, traj_batch, advantages, targets, key = update_state - learner_state = LearnerState(params, opt_states, key, env_state, last_timestep) + learner_state = OnPolicyLearnerState(params, opt_states, key, env_state, last_timestep) metric = traj_batch.info return learner_state, (metric, loss_info) - def learner_fn(learner_state: LearnerState) -> ExperimentOutput[LearnerState]: + def learner_fn(learner_state: OnPolicyLearnerState) -> ExperimentOutput[OnPolicyLearnerState]: """Learner function. 
This function represents the learner, it updates the network parameters @@ -308,7 +312,7 @@ def learner_fn(learner_state: LearnerState) -> ExperimentOutput[LearnerState]: def learner_setup( env: Environment, keys: chex.Array, config: DictConfig -) -> Tuple[LearnerFn[LearnerState], Actor, LearnerState]: +) -> Tuple[LearnerFn[OnPolicyLearnerState], Actor, OnPolicyLearnerState]: """Initialise learner_fn, network, optimiser, environment and states.""" # Get available TPU cores. n_devices = len(jax.devices()) @@ -420,7 +424,7 @@ def learner_setup( # Initialise learner state. params, opt_states = replicate_learner - init_learner_state = LearnerState(params, opt_states, step_keys, env_states, timesteps) + init_learner_state = OnPolicyLearnerState(params, opt_states, step_keys, env_states, timesteps) return learn, actor_network, init_learner_state @@ -568,7 +572,9 @@ def run_experiment(_config: DictConfig) -> float: @hydra.main( - config_path="../../configs", config_name="default_ff_dpo_continuous.yaml", version_base="1.2" + config_path="../../../configs/default/anakin", + config_name="default_ff_dpo_continuous.yaml", + version_base="1.2", ) def hydra_entry_point(cfg: DictConfig) -> float: """Experiment entry point.""" diff --git a/stoix/systems/ppo/ff_ppo.py b/stoix/systems/ppo/anakin/ff_ppo.py similarity index 96% rename from stoix/systems/ppo/ff_ppo.py rename to stoix/systems/ppo/anakin/ff_ppo.py index ebed06b7..db729d98 100644 --- a/stoix/systems/ppo/ff_ppo.py +++ b/stoix/systems/ppo/anakin/ff_ppo.py @@ -21,7 +21,7 @@ CriticApply, ExperimentOutput, LearnerFn, - LearnerState, + OnPolicyLearnerState, ) from stoix.evaluator import evaluator_setup, get_distribution_act_fn from stoix.networks.base import FeedForwardActor as Actor @@ -47,14 +47,16 @@ def get_learner_fn( apply_fns: Tuple[ActorApply, CriticApply], update_fns: Tuple[optax.TransformUpdateFn, optax.TransformUpdateFn], config: DictConfig, -) -> LearnerFn[LearnerState]: +) -> LearnerFn[OnPolicyLearnerState]: """Get the learner function.""" # Get apply and update functions for actor and critic networks. actor_apply_fn, critic_apply_fn = apply_fns actor_update_fn, critic_update_fn = update_fns - def _update_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, Tuple]: + def _update_step( + learner_state: OnPolicyLearnerState, _: Any + ) -> Tuple[OnPolicyLearnerState, Tuple]: """A single update of the network. This function steps the environment and records the trajectory batch for @@ -72,7 +74,9 @@ def _update_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, Tup _ (Any): The current metrics info. 
""" - def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTransition]: + def _env_step( + learner_state: OnPolicyLearnerState, _: Any + ) -> Tuple[OnPolicyLearnerState, PPOTransition]: """Step the environment.""" params, opt_states, key, env_state, last_timestep = learner_state @@ -101,7 +105,7 @@ def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTra last_timestep.observation, info, ) - learner_state = LearnerState(params, opt_states, key, env_state, timestep) + learner_state = OnPolicyLearnerState(params, opt_states, key, env_state, timestep) return learner_state, transition # STEP ENVIRONMENT FOR ROLLOUT LENGTH @@ -267,11 +271,11 @@ def _critic_loss_fn( ) params, opt_states, traj_batch, advantages, targets, key = update_state - learner_state = LearnerState(params, opt_states, key, env_state, last_timestep) + learner_state = OnPolicyLearnerState(params, opt_states, key, env_state, last_timestep) metric = traj_batch.info return learner_state, (metric, loss_info) - def learner_fn(learner_state: LearnerState) -> ExperimentOutput[LearnerState]: + def learner_fn(learner_state: OnPolicyLearnerState) -> ExperimentOutput[OnPolicyLearnerState]: """Learner function. This function represents the learner, it updates the network parameters @@ -303,7 +307,7 @@ def learner_fn(learner_state: LearnerState) -> ExperimentOutput[LearnerState]: def learner_setup( env: Environment, keys: chex.Array, config: DictConfig -) -> Tuple[LearnerFn[LearnerState], Actor, LearnerState]: +) -> Tuple[LearnerFn[OnPolicyLearnerState], Actor, OnPolicyLearnerState]: """Initialise learner_fn, network, optimiser, environment and states.""" # Get available TPU cores. n_devices = len(jax.devices()) @@ -410,7 +414,7 @@ def learner_setup( # Initialise learner state. params, opt_states = replicate_learner - init_learner_state = LearnerState(params, opt_states, step_keys, env_states, timesteps) + init_learner_state = OnPolicyLearnerState(params, opt_states, step_keys, env_states, timesteps) return learn, actor_network, init_learner_state @@ -557,7 +561,11 @@ def run_experiment(_config: DictConfig) -> float: return eval_performance -@hydra.main(config_path="../../configs", config_name="default_ff_ppo.yaml", version_base="1.2") +@hydra.main( + config_path="../../../configs/default/anakin", + config_name="default_ff_ppo.yaml", + version_base="1.2", +) def hydra_entry_point(cfg: DictConfig) -> float: """Experiment entry point.""" # Allow dynamic attributes. diff --git a/stoix/systems/ppo/ff_ppo_continuous.py b/stoix/systems/ppo/anakin/ff_ppo_continuous.py similarity index 96% rename from stoix/systems/ppo/ff_ppo_continuous.py rename to stoix/systems/ppo/anakin/ff_ppo_continuous.py index eef74838..bdb4bbb0 100644 --- a/stoix/systems/ppo/ff_ppo_continuous.py +++ b/stoix/systems/ppo/anakin/ff_ppo_continuous.py @@ -21,7 +21,7 @@ CriticApply, ExperimentOutput, LearnerFn, - LearnerState, + OnPolicyLearnerState, ) from stoix.evaluator import evaluator_setup, get_distribution_act_fn from stoix.networks.base import FeedForwardActor as Actor @@ -47,14 +47,16 @@ def get_learner_fn( apply_fns: Tuple[ActorApply, CriticApply], update_fns: Tuple[optax.TransformUpdateFn, optax.TransformUpdateFn], config: DictConfig, -) -> LearnerFn[LearnerState]: +) -> LearnerFn[OnPolicyLearnerState]: """Get the learner function.""" # Get apply and update functions for actor and critic networks. 
actor_apply_fn, critic_apply_fn = apply_fns actor_update_fn, critic_update_fn = update_fns - def _update_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, Tuple]: + def _update_step( + learner_state: OnPolicyLearnerState, _: Any + ) -> Tuple[OnPolicyLearnerState, Tuple]: """A single update of the network. This function steps the environment and records the trajectory batch for @@ -72,7 +74,9 @@ def _update_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, Tup _ (Any): The current metrics info. """ - def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTransition]: + def _env_step( + learner_state: OnPolicyLearnerState, _: Any + ) -> Tuple[OnPolicyLearnerState, PPOTransition]: """Step the environment.""" params, opt_states, key, env_state, last_timestep = learner_state @@ -101,7 +105,7 @@ def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTra last_timestep.observation, info, ) - learner_state = LearnerState(params, opt_states, key, env_state, timestep) + learner_state = OnPolicyLearnerState(params, opt_states, key, env_state, timestep) return learner_state, transition # STEP ENVIRONMENT FOR ROLLOUT LENGTH @@ -273,11 +277,11 @@ def _critic_loss_fn( ) params, opt_states, traj_batch, advantages, targets, key = update_state - learner_state = LearnerState(params, opt_states, key, env_state, last_timestep) + learner_state = OnPolicyLearnerState(params, opt_states, key, env_state, last_timestep) metric = traj_batch.info return learner_state, (metric, loss_info) - def learner_fn(learner_state: LearnerState) -> ExperimentOutput[LearnerState]: + def learner_fn(learner_state: OnPolicyLearnerState) -> ExperimentOutput[OnPolicyLearnerState]: """Learner function. This function represents the learner, it updates the network parameters @@ -309,7 +313,7 @@ def learner_fn(learner_state: LearnerState) -> ExperimentOutput[LearnerState]: def learner_setup( env: Environment, keys: chex.Array, config: DictConfig -) -> Tuple[LearnerFn[LearnerState], Actor, LearnerState]: +) -> Tuple[LearnerFn[OnPolicyLearnerState], Actor, OnPolicyLearnerState]: """Initialise learner_fn, network, optimiser, environment and states.""" # Get available TPU cores. n_devices = len(jax.devices()) @@ -421,7 +425,7 @@ def learner_setup( # Initialise learner state. 
params, opt_states = replicate_learner - init_learner_state = LearnerState(params, opt_states, step_keys, env_states, timesteps) + init_learner_state = OnPolicyLearnerState(params, opt_states, step_keys, env_states, timesteps) return learn, actor_network, init_learner_state @@ -569,7 +573,9 @@ def run_experiment(_config: DictConfig) -> float: @hydra.main( - config_path="../../configs", config_name="default_ff_ppo_continuous.yaml", version_base="1.2" + config_path="../../../configs/default/anakin", + config_name="default_ff_ppo_continuous.yaml", + version_base="1.2", ) def hydra_entry_point(cfg: DictConfig) -> float: """Experiment entry point.""" diff --git a/stoix/systems/ppo/ff_ppo_penalty.py b/stoix/systems/ppo/anakin/ff_ppo_penalty.py similarity index 96% rename from stoix/systems/ppo/ff_ppo_penalty.py rename to stoix/systems/ppo/anakin/ff_ppo_penalty.py index 8b578570..b3240620 100644 --- a/stoix/systems/ppo/ff_ppo_penalty.py +++ b/stoix/systems/ppo/anakin/ff_ppo_penalty.py @@ -21,7 +21,7 @@ CriticApply, ExperimentOutput, LearnerFn, - LearnerState, + OnPolicyLearnerState, ) from stoix.evaluator import evaluator_setup, get_distribution_act_fn from stoix.networks.base import FeedForwardActor as Actor @@ -47,14 +47,16 @@ def get_learner_fn( apply_fns: Tuple[ActorApply, CriticApply], update_fns: Tuple[optax.TransformUpdateFn, optax.TransformUpdateFn], config: DictConfig, -) -> LearnerFn[LearnerState]: +) -> LearnerFn[OnPolicyLearnerState]: """Get the learner function.""" # Get apply and update functions for actor and critic networks. actor_apply_fn, critic_apply_fn = apply_fns actor_update_fn, critic_update_fn = update_fns - def _update_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, Tuple]: + def _update_step( + learner_state: OnPolicyLearnerState, _: Any + ) -> Tuple[OnPolicyLearnerState, Tuple]: """A single update of the network. This function steps the environment and records the trajectory batch for @@ -72,7 +74,9 @@ def _update_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, Tup _ (Any): The current metrics info. """ - def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTransition]: + def _env_step( + learner_state: OnPolicyLearnerState, _: Any + ) -> Tuple[OnPolicyLearnerState, PPOTransition]: """Step the environment.""" params, opt_states, key, env_state, last_timestep = learner_state @@ -101,7 +105,7 @@ def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTra last_timestep.observation, info, ) - learner_state = LearnerState(params, opt_states, key, env_state, timestep) + learner_state = OnPolicyLearnerState(params, opt_states, key, env_state, timestep) return learner_state, transition # STEP ENVIRONMENT FOR ROLLOUT LENGTH @@ -276,11 +280,11 @@ def _critic_loss_fn( ) params, opt_states, traj_batch, advantages, targets, key = update_state - learner_state = LearnerState(params, opt_states, key, env_state, last_timestep) + learner_state = OnPolicyLearnerState(params, opt_states, key, env_state, last_timestep) metric = traj_batch.info return learner_state, (metric, loss_info) - def learner_fn(learner_state: LearnerState) -> ExperimentOutput[LearnerState]: + def learner_fn(learner_state: OnPolicyLearnerState) -> ExperimentOutput[OnPolicyLearnerState]: """Learner function. 
This function represents the learner, it updates the network parameters @@ -312,7 +316,7 @@ def learner_fn(learner_state: LearnerState) -> ExperimentOutput[LearnerState]: def learner_setup( env: Environment, keys: chex.Array, config: DictConfig -) -> Tuple[LearnerFn[LearnerState], Actor, LearnerState]: +) -> Tuple[LearnerFn[OnPolicyLearnerState], Actor, OnPolicyLearnerState]: """Initialise learner_fn, network, optimiser, environment and states.""" # Get available TPU cores. n_devices = len(jax.devices()) @@ -419,7 +423,7 @@ def learner_setup( # Initialise learner state. params, opt_states = replicate_learner - init_learner_state = LearnerState(params, opt_states, step_keys, env_states, timesteps) + init_learner_state = OnPolicyLearnerState(params, opt_states, step_keys, env_states, timesteps) return learn, actor_network, init_learner_state @@ -566,7 +570,11 @@ def run_experiment(_config: DictConfig) -> float: return eval_performance -@hydra.main(config_path="../../configs", config_name="default_ff_ppo.yaml", version_base="1.2") +@hydra.main( + config_path="../../../configs/default/anakin", + config_name="default_ff_ppo.yaml", + version_base="1.2", +) def hydra_entry_point(cfg: DictConfig) -> float: """Experiment entry point.""" # Allow dynamic attributes. diff --git a/stoix/systems/ppo/ff_ppo_penalty_continuous.py b/stoix/systems/ppo/anakin/ff_ppo_penalty_continuous.py similarity index 96% rename from stoix/systems/ppo/ff_ppo_penalty_continuous.py rename to stoix/systems/ppo/anakin/ff_ppo_penalty_continuous.py index 6ab7011a..46e0285c 100644 --- a/stoix/systems/ppo/ff_ppo_penalty_continuous.py +++ b/stoix/systems/ppo/anakin/ff_ppo_penalty_continuous.py @@ -21,7 +21,7 @@ CriticApply, ExperimentOutput, LearnerFn, - LearnerState, + OnPolicyLearnerState, ) from stoix.evaluator import evaluator_setup, get_distribution_act_fn from stoix.networks.base import FeedForwardActor as Actor @@ -47,14 +47,16 @@ def get_learner_fn( apply_fns: Tuple[ActorApply, CriticApply], update_fns: Tuple[optax.TransformUpdateFn, optax.TransformUpdateFn], config: DictConfig, -) -> LearnerFn[LearnerState]: +) -> LearnerFn[OnPolicyLearnerState]: """Get the learner function.""" # Get apply and update functions for actor and critic networks. actor_apply_fn, critic_apply_fn = apply_fns actor_update_fn, critic_update_fn = update_fns - def _update_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, Tuple]: + def _update_step( + learner_state: OnPolicyLearnerState, _: Any + ) -> Tuple[OnPolicyLearnerState, Tuple]: """A single update of the network. This function steps the environment and records the trajectory batch for @@ -72,7 +74,9 @@ def _update_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, Tup _ (Any): The current metrics info. 
""" - def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTransition]: + def _env_step( + learner_state: OnPolicyLearnerState, _: Any + ) -> Tuple[OnPolicyLearnerState, PPOTransition]: """Step the environment.""" params, opt_states, key, env_state, last_timestep = learner_state @@ -101,7 +105,7 @@ def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTra last_timestep.observation, info, ) - learner_state = LearnerState(params, opt_states, key, env_state, timestep) + learner_state = OnPolicyLearnerState(params, opt_states, key, env_state, timestep) return learner_state, transition # STEP ENVIRONMENT FOR ROLLOUT LENGTH @@ -281,11 +285,11 @@ def _critic_loss_fn( ) params, opt_states, traj_batch, advantages, targets, key = update_state - learner_state = LearnerState(params, opt_states, key, env_state, last_timestep) + learner_state = OnPolicyLearnerState(params, opt_states, key, env_state, last_timestep) metric = traj_batch.info return learner_state, (metric, loss_info) - def learner_fn(learner_state: LearnerState) -> ExperimentOutput[LearnerState]: + def learner_fn(learner_state: OnPolicyLearnerState) -> ExperimentOutput[OnPolicyLearnerState]: """Learner function. This function represents the learner, it updates the network parameters @@ -317,7 +321,7 @@ def learner_fn(learner_state: LearnerState) -> ExperimentOutput[LearnerState]: def learner_setup( env: Environment, keys: chex.Array, config: DictConfig -) -> Tuple[LearnerFn[LearnerState], Actor, LearnerState]: +) -> Tuple[LearnerFn[OnPolicyLearnerState], Actor, OnPolicyLearnerState]: """Initialise learner_fn, network, optimiser, environment and states.""" # Get available TPU cores. n_devices = len(jax.devices()) @@ -429,7 +433,7 @@ def learner_setup( # Initialise learner state. params, opt_states = replicate_learner - init_learner_state = LearnerState(params, opt_states, step_keys, env_states, timesteps) + init_learner_state = OnPolicyLearnerState(params, opt_states, step_keys, env_states, timesteps) return learn, actor_network, init_learner_state @@ -577,7 +581,9 @@ def run_experiment(_config: DictConfig) -> float: @hydra.main( - config_path="../../configs", config_name="default_ff_ppo_continuous.yaml", version_base="1.2" + config_path="../../../configs/default/anakin", + config_name="default_ff_ppo_continuous.yaml", + version_base="1.2", ) def hydra_entry_point(cfg: DictConfig) -> float: """Experiment entry point.""" diff --git a/stoix/systems/ppo/rec_ppo.py b/stoix/systems/ppo/anakin/rec_ppo.py similarity index 99% rename from stoix/systems/ppo/rec_ppo.py rename to stoix/systems/ppo/anakin/rec_ppo.py index 78202123..be6ac789 100644 --- a/stoix/systems/ppo/rec_ppo.py +++ b/stoix/systems/ppo/anakin/rec_ppo.py @@ -744,7 +744,11 @@ def run_experiment(_config: DictConfig) -> float: return eval_performance -@hydra.main(config_path="../../configs", config_name="default_rec_ppo.yaml", version_base="1.2") +@hydra.main( + config_path="../../../configs/default/anakin", + config_name="default_rec_ppo.yaml", + version_base="1.2", +) def hydra_entry_point(cfg: DictConfig) -> float: """Experiment entry point.""" # Allow dynamic attributes. 
diff --git a/stoix/systems/ppo/sebulba/ff_ppo.py b/stoix/systems/ppo/sebulba/ff_ppo.py new file mode 100644 index 00000000..934a37c6 --- /dev/null +++ b/stoix/systems/ppo/sebulba/ff_ppo.py @@ -0,0 +1,875 @@ +import copy +import queue +import threading +import warnings +from collections import defaultdict +from queue import Queue +from typing import Any, Callable, Dict, List, Sequence, Tuple + +import chex +import flax +import hydra +import jax +import jax.numpy as jnp +import numpy as np +import optax +from colorama import Fore, Style +from flax.core.frozen_dict import FrozenDict +from flax.jax_utils import unreplicate +from omegaconf import DictConfig, OmegaConf +from rich.pretty import pprint + +from stoix.base_types import ( + ActorApply, + ActorCriticOptStates, + ActorCriticParams, + CoreLearnerState, + CriticApply, + ExperimentOutput, + Observation, + SebulbaLearnerFn, +) +from stoix.evaluator import get_distribution_act_fn, get_sebulba_eval_fn +from stoix.networks.base import FeedForwardActor as Actor +from stoix.networks.base import FeedForwardCritic as Critic +from stoix.systems.ppo.ppo_types import PPOTransition +from stoix.utils import make_env as environments +from stoix.utils.checkpointing import Checkpointer +from stoix.utils.env_factory import EnvFactory +from stoix.utils.jax_utils import merge_leading_dims +from stoix.utils.logger import LogEvent, StoixLogger +from stoix.utils.loss import clipped_value_loss, ppo_clip_loss +from stoix.utils.multistep import batch_truncated_generalized_advantage_estimation +from stoix.utils.sebulba_utils import ( + OnPolicyPipeline, + ParamsSource, + RecordTimeTo, + ThreadLifetime, +) +from stoix.utils.total_timestep_checker import check_total_timesteps +from stoix.utils.training import make_learning_rate +from stoix.wrappers.episode_metrics import get_final_step_metrics + + +def get_act_fn( + apply_fns: Tuple[ActorApply, CriticApply] +) -> Callable[ + [ActorCriticParams, Observation, chex.PRNGKey], Tuple[chex.Array, chex.Array, chex.Array] +]: + """Get the act function that is used by the actor threads.""" + actor_apply_fn, critic_apply_fn = apply_fns + + def actor_fn( + params: ActorCriticParams, observation: Observation, rng_key: chex.PRNGKey + ) -> Tuple[chex.Array, chex.Array, chex.Array]: + """Get the action, value and log_prob from the actor and critic networks.""" + rng_key, policy_key = jax.random.split(rng_key) + pi = actor_apply_fn(params.actor_params, observation) + value = critic_apply_fn(params.critic_params, observation) + action = pi.sample(seed=policy_key) + log_prob = pi.log_prob(action) + return action, value, log_prob + + return actor_fn + + +def get_rollout_fn( + env_factory: EnvFactory, + actor_device: jax.Device, + params_source: ParamsSource, + pipeline: OnPolicyPipeline, + apply_fns: Tuple[ActorApply, CriticApply], + config: DictConfig, + seeds: List[int], + thread_lifetime: ThreadLifetime, +) -> Callable[[chex.PRNGKey], None]: + """Get the rollout function that is used by the actor threads.""" + # Unpack and set up the functions + act_fn = get_act_fn(apply_fns) + act_fn = jax.jit(act_fn, device=actor_device) + cpu = jax.devices("cpu")[0] + move_to_device = lambda tree: jax.tree.map(lambda x: jax.device_put(x, actor_device), tree) + split_key_fn = jax.jit(jax.random.split, device=actor_device) + # Build the environments + envs = env_factory(config.arch.actor.envs_per_actor) + + # Create the rollout function + def rollout_fn(rng_key: chex.PRNGKey) -> None: + # Ensure all computation is on the actor device + with 
jax.default_device(actor_device): + # Reset the environment + timestep = envs.reset(seed=seeds) + + # Loop until the thread is stopped + while not thread_lifetime.should_stop(): + # Create the list to store transitions + traj: List[PPOTransition] = [] + # Create the dictionary to store timings for metrics + timings_dict: Dict[str, List[float]] = defaultdict(list) + # Rollout the environment + with RecordTimeTo(timings_dict["single_rollout_time"]): + # Loop until the rollout length is reached + for _ in range(config.system.rollout_length): + # Get the latest parameters from the source + with RecordTimeTo(timings_dict["get_params_time"]): + params = params_source.get() + + # Move the environment data to the actor device + cached_obs = move_to_device(timestep.observation) + + # Run the actor and critic networks to get the action, value and log_prob + with RecordTimeTo(timings_dict["compute_action_time"]): + rng_key, policy_key = split_key_fn(rng_key) + action, value, log_prob = act_fn(params, cached_obs, policy_key) + + # Move the action to the CPU + action_cpu = np.asarray(jax.device_put(action, cpu)) + + # Step the environment + with RecordTimeTo(timings_dict["env_step_time"]): + timestep = envs.step(action_cpu) + + # Get the next dones and truncation flags + dones = np.logical_and( + np.asarray(timestep.last()), np.asarray(timestep.discount == 0.0) + ) + trunc = np.logical_and( + np.asarray(timestep.last()), np.asarray(timestep.discount == 1.0) + ) + cached_next_dones = move_to_device(dones) + cached_next_trunc = move_to_device(trunc) + + # Append PPOTransition to the trajectory list + reward = timestep.reward + metrics = timestep.extras["metrics"] + traj.append( + PPOTransition( + cached_next_dones, + cached_next_trunc, + action, + value, + reward, + log_prob, + cached_obs, + metrics, + ) + ) + + # Send the trajectory to the pipeline + with RecordTimeTo(timings_dict["rollout_put_time"]): + try: + pipeline.put(traj, timestep, timings_dict) + except queue.Full: + warnings.warn( + "Waited too long to add to the rollout queue, killing the actor thread", + stacklevel=2, + ) + break + + # Close the environments + envs.close() + + return rollout_fn + + +def get_actor_thread( + env_factory: EnvFactory, + actor_device: jax.Device, + params_source: ParamsSource, + pipeline: OnPolicyPipeline, + apply_fns: Tuple[ActorApply, CriticApply], + rng_key: chex.PRNGKey, + config: DictConfig, + seeds: List[int], + thread_lifetime: ThreadLifetime, + name: str, +) -> threading.Thread: + """Get the actor thread that once started will collect data from the + environment and send it to the pipeline.""" + rng_key = jax.device_put(rng_key, actor_device) + + rollout_fn = get_rollout_fn( + env_factory, + actor_device, + params_source, + pipeline, + apply_fns, + config, + seeds, + thread_lifetime, + ) + + actor = threading.Thread( + target=rollout_fn, + args=(rng_key,), + name=name, + ) + + return actor + + +def get_learner_step_fn( + apply_fns: Tuple[ActorApply, CriticApply], + update_fns: Tuple[optax.TransformUpdateFn, optax.TransformUpdateFn], + config: DictConfig, +) -> SebulbaLearnerFn[CoreLearnerState, PPOTransition]: + """Get the learner update function which is used to update the actor and critic networks. + This function is used by the learner thread to update the networks.""" + + # Get apply and update functions for actor and critic networks. 
+ actor_apply_fn, critic_apply_fn = apply_fns + actor_update_fn, critic_update_fn = update_fns + + def _update_step( + learner_state: CoreLearnerState, traj_batch: PPOTransition + ) -> Tuple[CoreLearnerState, Tuple]: + + # CALCULATE ADVANTAGE + params, opt_states, key, last_timestep = learner_state + last_val = critic_apply_fn(params.critic_params, last_timestep.observation) + + r_t = traj_batch.reward + v_t = jnp.concatenate([traj_batch.value, last_val[None, ...]], axis=0) + d_t = 1.0 - traj_batch.done.astype(jnp.float32) + d_t = (d_t * config.system.gamma).astype(jnp.float32) + advantages, targets = batch_truncated_generalized_advantage_estimation( + r_t, + d_t, + config.system.gae_lambda, + v_t, + time_major=True, + standardize_advantages=config.system.standardize_advantages, + truncation_flags=traj_batch.truncated, + ) + + def _update_epoch(update_state: Tuple, _: Any) -> Tuple: + """Update the network for a single epoch.""" + + def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple: + """Update the network for a single minibatch.""" + + # UNPACK TRAIN STATE AND BATCH INFO + params, opt_states = train_state + traj_batch, advantages, targets = batch_info + + def _actor_loss_fn( + actor_params: FrozenDict, + traj_batch: PPOTransition, + gae: chex.Array, + ) -> Tuple: + """Calculate the actor loss.""" + # RERUN NETWORK + actor_policy = actor_apply_fn(actor_params, traj_batch.obs) + log_prob = actor_policy.log_prob(traj_batch.action) + + # CALCULATE ACTOR LOSS + loss_actor = ppo_clip_loss( + log_prob, traj_batch.log_prob, gae, config.system.clip_eps + ) + entropy = actor_policy.entropy().mean() + + total_loss_actor = loss_actor - config.system.ent_coef * entropy + loss_info = { + "actor_loss": loss_actor, + "entropy": entropy, + } + return total_loss_actor, loss_info + + def _critic_loss_fn( + critic_params: FrozenDict, + traj_batch: PPOTransition, + targets: chex.Array, + ) -> Tuple: + """Calculate the critic loss.""" + # RERUN NETWORK + value = critic_apply_fn(critic_params, traj_batch.obs) + + # CALCULATE VALUE LOSS + value_loss = clipped_value_loss( + value, traj_batch.value, targets, config.system.clip_eps + ) + + critic_total_loss = config.system.vf_coef * value_loss + loss_info = { + "value_loss": value_loss, + } + return critic_total_loss, loss_info + + # CALCULATE ACTOR LOSS + actor_grad_fn = jax.grad(_actor_loss_fn, has_aux=True) + actor_grads, actor_loss_info = actor_grad_fn( + params.actor_params, traj_batch, advantages + ) + + # CALCULATE CRITIC LOSS + critic_grad_fn = jax.grad(_critic_loss_fn, has_aux=True) + critic_grads, critic_loss_info = critic_grad_fn( + params.critic_params, traj_batch, targets + ) + + # Compute the parallel mean (pmean) over the batch. + # This pmean could be a regular mean as the batch axis is on the same device. + # pmean over devices. + actor_grads, actor_loss_info = jax.lax.pmean( + (actor_grads, actor_loss_info), axis_name="device" + ) + # pmean over devices. 
+ critic_grads, critic_loss_info = jax.lax.pmean( + (critic_grads, critic_loss_info), axis_name="device" + ) + + # UPDATE ACTOR PARAMS AND OPTIMISER STATE + actor_updates, actor_new_opt_state = actor_update_fn( + actor_grads, opt_states.actor_opt_state + ) + actor_new_params = optax.apply_updates(params.actor_params, actor_updates) + + # UPDATE CRITIC PARAMS AND OPTIMISER STATE + critic_updates, critic_new_opt_state = critic_update_fn( + critic_grads, opt_states.critic_opt_state + ) + critic_new_params = optax.apply_updates(params.critic_params, critic_updates) + + # PACK NEW PARAMS AND OPTIMISER STATE + new_params = ActorCriticParams(actor_new_params, critic_new_params) + new_opt_state = ActorCriticOptStates(actor_new_opt_state, critic_new_opt_state) + + # PACK LOSS INFO + loss_info = { + **actor_loss_info, + **critic_loss_info, + } + return (new_params, new_opt_state), loss_info + + params, opt_states, traj_batch, advantages, targets, key = update_state + key, shuffle_key = jax.random.split(key) + + # SHUFFLE MINIBATCHES + # Since we shard the envs per actor across the devices + envs_per_batch = config.arch.actor.envs_per_actor // len(config.arch.learner.device_ids) + batch_size = config.system.rollout_length * envs_per_batch + permutation = jax.random.permutation(shuffle_key, batch_size) + batch = (traj_batch, advantages, targets) + batch = jax.tree_util.tree_map(lambda x: merge_leading_dims(x, 2), batch) + shuffled_batch = jax.tree_util.tree_map( + lambda x: jnp.take(x, permutation, axis=0), batch + ) + minibatches = jax.tree_util.tree_map( + lambda x: jnp.reshape(x, [config.system.num_minibatches, -1] + list(x.shape[1:])), + shuffled_batch, + ) + + # UPDATE MINIBATCHES + (params, opt_states), loss_info = jax.lax.scan( + _update_minibatch, (params, opt_states), minibatches + ) + + update_state = (params, opt_states, traj_batch, advantages, targets, key) + return update_state, loss_info + + update_state = (params, opt_states, traj_batch, advantages, targets, key) + + # UPDATE EPOCHS + update_state, loss_info = jax.lax.scan( + _update_epoch, update_state, None, config.system.epochs + ) + + params, opt_states, traj_batch, advantages, targets, key = update_state + learner_state = CoreLearnerState(params, opt_states, key, last_timestep) + metrics = traj_batch.info + return learner_state, (metrics, loss_info) + + def learner_step_fn( + learner_state: CoreLearnerState, traj_batch: PPOTransition + ) -> ExperimentOutput[CoreLearnerState]: + """Learner function. + + This function represents the learner, it updates the network parameters + by iteratively applying the `_update_step` function for a fixed number of + updates. The `_update_step` function is vectorized over a batch of inputs. + + Args: + learner_state (NamedTuple): + - params (ActorCriticParams): The initial model parameters. + - opt_states (OptStates): The initial optimizer state. + - key (chex.PRNGKey): The random number generator state. + - env_state (LogEnvState): The environment state. + - timesteps (TimeStep): The initial timestep in the initial trajectory. 
+ """ + + learner_state, (episode_info, loss_info) = _update_step(learner_state, traj_batch) + + return ExperimentOutput( + learner_state=learner_state, + episode_metrics=episode_info, + train_metrics=loss_info, + ) + + return learner_step_fn + + +def get_learner_rollout_fn( + learn_step: SebulbaLearnerFn[CoreLearnerState, PPOTransition], + config: DictConfig, + eval_queue: Queue, + pipeline: OnPolicyPipeline, + params_sources: Sequence[ParamsSource], +) -> Callable[[CoreLearnerState], None]: + """Get the learner rollout function that is used by the learner thread to update the networks. + This function is what is actually run by the learner thread. It gets the data from the pipeline + and uses the learner update function to update the networks. It then sends these intermediate + network parameters to a queue for evaluation.""" + + def learner_rollout(learner_state: CoreLearnerState) -> None: + # Loop for the total number of evaluations selected to be performed. + for _ in range(config.arch.num_evaluation): + # Create the lists to store metrics and timings for this learning iteration. + metrics: List[Tuple[Dict, Dict]] = [] + rollout_times: List[Dict] = [] + q_sizes: List[int] = [] + learn_timings: Dict[str, List[float]] = defaultdict(list) + # Loop for the number of updates per evaluation + for _ in range(config.arch.num_updates_per_eval): + # Get the trajectory batch from the pipeline + # This is blocking so it will wait until the pipeline has data. + with RecordTimeTo(learn_timings["rollout_get_time"]): + traj_batch, timestep, rollout_time = pipeline.get(block=True) # type: ignore + # We then replace the timestep in the learner state with the latest timestep + # This means the learner has access to the entire trajectory as well as + # an additional timestep which it can use to bootstrap. + learner_state = learner_state._replace(timestep=timestep) + # We then call the update function to update the networks + with RecordTimeTo(learn_timings["learning_time"]): + learner_state, episode_metrics, train_metrics = learn_step( + learner_state, traj_batch + ) + + # We store the metrics and timings for this update + metrics.append((episode_metrics, train_metrics)) + rollout_times.append(rollout_time) + q_sizes.append(pipeline.qsize()) + + # After the update we need to update the params sources with the new params + unreplicated_params = unreplicate(learner_state.params) + # We loop over all params sources and update them with the new params + # This is so that all the actors can get the latest params + for source in params_sources: + source.update(unreplicated_params) + + # We then pass all the environment metrics, training metrics, current learner state + # and timings to the evaluation queue. This is so the evaluator correctly evaluates + # the performance of the networks at this point in time. + episode_metrics, train_metrics = jax.tree.map(lambda *x: np.asarray(x), *metrics) + rollout_times = jax.tree.map(lambda *x: np.mean(x), *rollout_times) + timing_dict = rollout_times | learn_timings + timing_dict["pipeline_qsize"] = q_sizes + timing_dict = jax.tree.map(np.mean, timing_dict, is_leaf=lambda x: isinstance(x, list)) + try: + # We add a timeout mainly for sanity checks + # If the queue is full for more than 60 seconds we kill the learner thread + # This should never happen + eval_queue.put( + (episode_metrics, train_metrics, learner_state, timing_dict), timeout=60 + ) + except queue.Full: + warnings.warn( + "Waited too long to add to the evaluation queue, killing the learner thread. 
" + "This should not happen.", + stacklevel=2, + ) + break + + return learner_rollout + + +def get_learner_thread( + learn: SebulbaLearnerFn[CoreLearnerState, PPOTransition], + learner_state: CoreLearnerState, + config: DictConfig, + eval_queue: Queue, + pipeline: OnPolicyPipeline, + params_sources: Sequence[ParamsSource], +) -> threading.Thread: + """Get the learner thread that is used to update the networks.""" + + learner_rollout_fn = get_learner_rollout_fn(learn, config, eval_queue, pipeline, params_sources) + + learner_thread = threading.Thread( + target=learner_rollout_fn, + args=(learner_state,), + name="Learner", + ) + + return learner_thread + + +def learner_setup( + env_factory: EnvFactory, + keys: chex.Array, + learner_devices: Sequence[jax.Device], + config: DictConfig, +) -> Tuple[ + SebulbaLearnerFn[CoreLearnerState, PPOTransition], + Tuple[ActorApply, CriticApply], + CoreLearnerState, +]: + """Setup for the learner state and networks.""" + + # Create a single environment just to get the observation and action specs. + env = env_factory(num_envs=1) + # Get number/dimension of actions. + num_actions = int(env.action_spec().num_values) + config.system.action_dim = num_actions + example_obs = env.observation_spec().generate_value() + env.close() + + # PRNG keys. + key, actor_net_key, critic_net_key = keys + + # Define network and optimiser. + actor_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso) + actor_action_head = hydra.utils.instantiate( + config.network.actor_network.action_head, action_dim=num_actions + ) + critic_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso) + critic_head = hydra.utils.instantiate(config.network.critic_network.critic_head) + + actor_network = Actor(torso=actor_torso, action_head=actor_action_head) + critic_network = Critic(torso=critic_torso, critic_head=critic_head) + + actor_lr = make_learning_rate( + config.system.actor_lr, config, config.system.epochs, config.system.num_minibatches + ) + critic_lr = make_learning_rate( + config.system.critic_lr, config, config.system.epochs, config.system.num_minibatches + ) + + actor_optim = optax.chain( + optax.clip_by_global_norm(config.system.max_grad_norm), + optax.adam(actor_lr, eps=1e-5), + ) + critic_optim = optax.chain( + optax.clip_by_global_norm(config.system.max_grad_norm), + optax.adam(critic_lr, eps=1e-5), + ) + + # Initialise observation + init_x = example_obs + init_x = jax.tree_util.tree_map(lambda x: x[None, ...], init_x) + + # Initialise actor params and optimiser state. + actor_params = actor_network.init(actor_net_key, init_x) + actor_opt_state = actor_optim.init(actor_params) + + # Initialise critic params and optimiser state. + critic_params = critic_network.init(critic_net_key, init_x) + critic_opt_state = critic_optim.init(critic_params) + + # Pack params. + params = ActorCriticParams(actor_params, critic_params) + + # Extract apply functions. + actor_network_apply_fn = actor_network.apply + critic_network_apply_fn = critic_network.apply + + # Pack apply and update functions. + apply_fns = (actor_network_apply_fn, critic_network_apply_fn) + update_fns = (actor_optim.update, critic_optim.update) + + # Get batched iterated update and replicate it to pmap it over cores. + learn_step = get_learner_step_fn(apply_fns, update_fns, config) + learn_step = jax.pmap(learn_step, axis_name="device") + + # Load model from checkpoint if specified. 
+    if config.logger.checkpointing.load_model:
+        loaded_checkpoint = Checkpointer(
+            model_name=config.system.system_name,
+            **config.logger.checkpointing.load_args,  # Other checkpoint args
+        )
+        # Restore the learner state from the checkpoint
+        restored_params, _ = loaded_checkpoint.restore_params()
+        # Update the params
+        params = restored_params
+
+    # Define params to be replicated across learner devices.
+    opt_states = ActorCriticOptStates(actor_opt_state, critic_opt_state)
+    replicate_learner = (params, opt_states)
+
+    # Duplicate across learner devices.
+    replicate_learner = flax.jax_utils.replicate(replicate_learner, devices=learner_devices)
+
+    # Initialise learner state.
+    params, opt_states = replicate_learner
+    key, step_key = jax.random.split(key)
+    step_keys = jax.random.split(step_key, len(learner_devices))
+    init_learner_state = CoreLearnerState(params, opt_states, step_keys, None)
+
+    return learn_step, apply_fns, init_learner_state
+
+
+def run_experiment(_config: DictConfig) -> float:
+    """Runs experiment."""
+    config = copy.deepcopy(_config)
+
+    # Perform some checks on the config
+    # This additionally calculates certain
+    # values based on the config
+    config = check_total_timesteps(config)
+
+    assert (
+        config.arch.num_updates > config.arch.num_evaluation
+    ), "Total number of updates must be greater than the number of evaluations."
+
+    # Calculate the number of updates per evaluation
+    config.arch.num_updates_per_eval = int(config.arch.num_updates // config.arch.num_evaluation)
+
+    # Get the learner and actor devices
+    local_devices = jax.local_devices()
+    global_devices = jax.devices()
+    assert len(local_devices) == len(
+        global_devices
+    ), "Local and global devices must be the same for now. We don't support multi-host just yet."
+    # Extract the actor and learner devices
+    actor_devices = [local_devices[device_id] for device_id in config.arch.actor.device_ids]
+    local_learner_devices = [
+        local_devices[device_id] for device_id in config.arch.learner.device_ids
+    ]
+    # For evaluation we simply use the first learner device
+    evaluator_device = local_learner_devices[0]
+    print(f"{Fore.BLUE}{Style.BRIGHT}Actors devices: {actor_devices}{Style.RESET_ALL}")
+    print(f"{Fore.GREEN}{Style.BRIGHT}Learner devices: {local_learner_devices}{Style.RESET_ALL}")
+    print(f"{Fore.MAGENTA}{Style.BRIGHT}Global devices: {global_devices}{Style.RESET_ALL}")
+    # Set the number of learning and acting devices in the config
+    # useful for keeping track of experimental setup
+    config.num_learning_devices = len(local_learner_devices)
+    config.num_actor_actor_devices = len(actor_devices)
+
+    # Calculate the number of envs per actor
+    assert (
+        config.arch.num_envs == config.arch.total_num_envs
+    ), "arch.num_envs must equal arch.total_num_envs for Sebulba architectures"
+    # We first simply take the total number of envs and divide by the number of actor devices
+    # to get the number of envs per actor device
+    num_envs_per_actor_device = config.arch.total_num_envs // len(actor_devices)
+    # We then divide this by the number of actors per device to get the number of envs per actor
+    num_envs_per_actor = int(num_envs_per_actor_device // config.arch.actor.actor_per_device)
+    config.arch.actor.envs_per_actor = num_envs_per_actor
+
+    # We then perform a simple check to ensure that the number of envs per actor is
+    # divisible by the number of learner devices.
This is because we shard the envs + # per actor across the learner devices This check is mainly relevant for on-policy + # algorithms + assert num_envs_per_actor % len(local_learner_devices) == 0, ( + f"The number of envs per actor must be divisible by the number of learner devices. " + f"Got {num_envs_per_actor} envs per actor and {len(local_learner_devices)} learner devices" + ) + + # Create the environment factory. + env_factory = environments.make_factory(config) + assert isinstance( + env_factory, EnvFactory + ), "Environment factory must be an instance of EnvFactory" + + # PRNG keys. + key, key_e, actor_net_key, critic_net_key = jax.random.split( + jax.random.PRNGKey(config.arch.seed), num=4 + ) + np_rng = np.random.default_rng(config.arch.seed) + + # Setup learner. + learn_step, apply_fns, learner_state = learner_setup( + env_factory, (key, actor_net_key, critic_net_key), local_learner_devices, config + ) + actor_apply_fn, _ = apply_fns + eval_act_fn = get_distribution_act_fn(config, actor_apply_fn) + # Setup evaluator. + evaluator, evaluator_envs = get_sebulba_eval_fn( + env_factory, eval_act_fn, config, np_rng, evaluator_device + ) + + # Logger setup + logger = StoixLogger(config) + cfg: Dict = OmegaConf.to_container(config, resolve=True) + cfg["arch"]["devices"] = jax.devices() + pprint(cfg) + + # Set up checkpointer + save_checkpoint = config.logger.checkpointing.save_model + if save_checkpoint: + checkpointer = Checkpointer( + metadata=config, # Save all config as metadata in the checkpoint + model_name=config.system.system_name, + **config.logger.checkpointing.save_args, # Checkpoint args + ) + + # Get initial parameters + initial_params = unreplicate(learner_state.params) + + # Get the number of steps per rollout + steps_per_rollout = ( + config.system.rollout_length * config.arch.total_num_envs * config.arch.num_updates_per_eval + ) + + # Creating the pipeline + # First we create the lifetime so we can stop the pipeline when we want + pipeline_lifetime = ThreadLifetime() + # Now we create the pipeline + pipeline = OnPolicyPipeline( + config.arch.pipeline_queue_size, local_learner_devices, pipeline_lifetime + ) + # Start the pipeline + pipeline.start() + + # Create a single lifetime for all the actors and params sources + actors_lifetime = ThreadLifetime() + params_sources_lifetime = ThreadLifetime() + + # Create the params sources and actor threads + params_sources: List[ParamsSource] = [] + actor_threads: List[threading.Thread] = [] + for actor_device in actor_devices: + # Create 1 params source per actor device as this will be used + # to pass the params to the actors + params_source = ParamsSource(initial_params, actor_device, params_sources_lifetime) + params_source.start() + params_sources.append(params_source) + # Now for each device we choose to create multiple actor threads + for i in range(config.arch.actor.actor_per_device): + key, actors_key = jax.random.split(key) + seeds = np_rng.integers( + np.iinfo(np.int32).max, size=config.arch.actor.envs_per_actor + ).tolist() + actor_thread = get_actor_thread( + env_factory, + actor_device, + params_source, + pipeline, + apply_fns, + actors_key, + config, + seeds, + actors_lifetime, + f"Actor-{actor_device}-{i}", + ) + actor_thread.start() + actor_threads.append(actor_thread) + + # Create the evaluation queue + eval_queue: Queue = Queue(maxsize=config.arch.num_evaluation) + # Create the learner thread + learner_thread = get_learner_thread( + learn_step, learner_state, config, eval_queue, pipeline, params_sources + ) + 
learner_thread.start() + + # Run experiment for a total number of evaluations. + max_episode_return = jnp.float32(-1e7) + best_params = initial_params.actor_params + # This is the main loop, all it does is evaluation and logging. + # Acting and learning is happening in their own threads. + # This loop waits for the learner to finish an update before evaluation and logging. + for eval_step in range(config.arch.num_evaluation): + # Get the next set of params and metrics from the learner + episode_metrics, train_metrics, learner_state, timings_dict = eval_queue.get(block=True) + + # Log the metrics and timings + t = int(steps_per_rollout * (eval_step + 1)) + timings_dict["timestep"] = t + logger.log(timings_dict, t, eval_step, LogEvent.MISC) + + episode_metrics, ep_completed = get_final_step_metrics(episode_metrics) + episode_metrics["steps_per_second"] = ( + steps_per_rollout / timings_dict["single_rollout_time"] + ) + if ep_completed: + logger.log(episode_metrics, t, eval_step, LogEvent.ACT) + + logger.log(train_metrics, t, eval_step, LogEvent.TRAIN) + + # Evaluate the current model and log the metrics + unreplicated_actor_params = unreplicate(learner_state.params.actor_params) + key, eval_key = jax.random.split(key, 2) + eval_metrics = evaluator(unreplicated_actor_params, eval_key) + logger.log(eval_metrics, t, eval_step, LogEvent.EVAL) + + episode_return = jnp.mean(eval_metrics["episode_return"]) + + if save_checkpoint: + # Save checkpoint of learner state + checkpointer.save( + timestep=steps_per_rollout * (eval_step + 1), + unreplicated_learner_state=unreplicate(learner_state), + episode_return=episode_return, + ) + + if config.arch.absolute_metric and max_episode_return <= episode_return: + best_params = copy.deepcopy(unreplicated_actor_params) + max_episode_return = episode_return + + evaluator_envs.close() + eval_performance = float(jnp.mean(eval_metrics[config.env.eval_metric])) + + print(f"{Fore.MAGENTA}{Style.BRIGHT}Closing learner...{Style.RESET_ALL}") + # Now we stop the learner + learner_thread.join() + + # First we stop all actors + actors_lifetime.stop() + + # Now we stop the actors and params sources + print(f"{Fore.MAGENTA}{Style.BRIGHT}Closing actors...{Style.RESET_ALL}") + for actor in actor_threads: + # We clear the pipeline before stopping each actor thread + # since actors can be blocked on the pipeline + pipeline.clear() + actor.join() + + print(f"{Fore.MAGENTA}{Style.BRIGHT}Closing pipeline...{Style.RESET_ALL}") + # Stop the pipeline + pipeline_lifetime.stop() + pipeline.join() + + print(f"{Fore.MAGENTA}{Style.BRIGHT}Closing params sources...{Style.RESET_ALL}") + # Stop the params sources + params_sources_lifetime.stop() + for param_source in params_sources: + param_source.join() + + # Measure absolute metric. + if config.arch.absolute_metric: + print(f"{Fore.MAGENTA}{Style.BRIGHT}Measuring absolute metric...{Style.RESET_ALL}") + abs_metric_evaluator, abs_metric_evaluator_envs = get_sebulba_eval_fn( + env_factory, eval_act_fn, config, np_rng, evaluator_device, eval_multiplier=10 + ) + key, eval_key = jax.random.split(key, 2) + eval_metrics = abs_metric_evaluator(best_params, eval_key) + + t = int(steps_per_rollout * (eval_step + 1)) + logger.log(eval_metrics, t, eval_step, LogEvent.ABSOLUTE) + abs_metric_evaluator_envs.close() + + # Stop the logger. 
+ logger.stop() + + return eval_performance + + +@hydra.main( + config_path="../../../configs/default/sebulba", + config_name="default_ff_ppo.yaml", + version_base="1.2", +) +def hydra_entry_point(cfg: DictConfig) -> float: + """Experiment entry point.""" + # Allow dynamic attributes. + OmegaConf.set_struct(cfg, False) + + # Run experiment. + eval_performance = run_experiment(cfg) + + print(f"{Fore.CYAN}{Style.BRIGHT}PPO experiment completed{Style.RESET_ALL}") + return eval_performance + + +if __name__ == "__main__": + hydra_entry_point() diff --git a/stoix/systems/q_learning/ff_c51.py b/stoix/systems/q_learning/ff_c51.py index 00b2edef..073d2dfd 100644 --- a/stoix/systems/q_learning/ff_c51.py +++ b/stoix/systems/q_learning/ff_c51.py @@ -559,7 +559,11 @@ def run_experiment(_config: DictConfig) -> float: return eval_performance -@hydra.main(config_path="../../configs", config_name="default_ff_c51.yaml", version_base="1.2") +@hydra.main( + config_path="../../configs/default/anakin", + config_name="default_ff_c51.yaml", + version_base="1.2", +) def hydra_entry_point(cfg: DictConfig) -> float: """Experiment entry point.""" # Allow dynamic attributes. diff --git a/stoix/systems/q_learning/ff_ddqn.py b/stoix/systems/q_learning/ff_ddqn.py index 516ec870..b7f45b06 100644 --- a/stoix/systems/q_learning/ff_ddqn.py +++ b/stoix/systems/q_learning/ff_ddqn.py @@ -546,7 +546,11 @@ def run_experiment(_config: DictConfig) -> float: return eval_performance -@hydra.main(config_path="../../configs", config_name="default_ff_ddqn.yaml", version_base="1.2") +@hydra.main( + config_path="../../configs/default/anakin", + config_name="default_ff_ddqn.yaml", + version_base="1.2", +) def hydra_entry_point(cfg: DictConfig) -> float: """Experiment entry point.""" # Allow dynamic attributes. diff --git a/stoix/systems/q_learning/ff_dqn.py b/stoix/systems/q_learning/ff_dqn.py index 5e938a1b..cf0de979 100644 --- a/stoix/systems/q_learning/ff_dqn.py +++ b/stoix/systems/q_learning/ff_dqn.py @@ -552,7 +552,11 @@ def run_experiment(_config: DictConfig) -> float: return eval_performance -@hydra.main(config_path="../../configs", config_name="default_ff_dqn.yaml", version_base="1.2") +@hydra.main( + config_path="../../configs/default/anakin", + config_name="default_ff_dqn.yaml", + version_base="1.2", +) def hydra_entry_point(cfg: DictConfig) -> float: """Experiment entry point.""" # Allow dynamic attributes. diff --git a/stoix/systems/q_learning/ff_dqn_reg.py b/stoix/systems/q_learning/ff_dqn_reg.py index 34fb5656..47df21bf 100644 --- a/stoix/systems/q_learning/ff_dqn_reg.py +++ b/stoix/systems/q_learning/ff_dqn_reg.py @@ -549,7 +549,11 @@ def run_experiment(_config: DictConfig) -> float: return eval_performance -@hydra.main(config_path="../../configs", config_name="default_ff_dqn_reg.yaml", version_base="1.2") +@hydra.main( + config_path="../../configs/default/anakin", + config_name="default_ff_dqn_reg.yaml", + version_base="1.2", +) def hydra_entry_point(cfg: DictConfig) -> float: """Experiment entry point.""" # Allow dynamic attributes. 
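The new Sebulba `ff_ppo.py` above decouples acting and learning with plain Python threads: actor threads push rollouts into a bounded `OnPolicyPipeline`, the learner thread blocks on `get`, and `ThreadLifetime` flags tell the threads when to shut down. Below is a minimal, framework-free sketch of that handshake; it mirrors the `ThreadLifetime` utility added in `stoix/utils/sebulba_utils.py`, but uses a plain `queue.Queue` in place of the device-sharding pipeline and fake string payloads instead of trajectories.

```python
# Minimal sketch of the Sebulba actor/learner handshake: actor threads push
# rollouts onto a bounded queue, the learner blocks on get, and a shared stop
# flag ends the threads. Not the Stoix implementation; payloads are fake.
import queue
import threading
import time


class ThreadLifetime:
    """Mutable stop flag shared by the threads (mirrors the utility used above)."""

    def __init__(self) -> None:
        self._stop = False

    def should_stop(self) -> bool:
        return self._stop

    def stop(self) -> None:
        self._stop = True


def actor(pipeline: queue.Queue, lifetime: ThreadLifetime) -> None:
    step = 0
    while not lifetime.should_stop():
        time.sleep(0.01)  # pretend to roll out a trajectory
        try:
            pipeline.put(f"trajectory-{step}", timeout=1)
        except queue.Full:
            break  # nobody is consuming any more, so stop producing
        step += 1


def learner(pipeline: queue.Queue, lifetime: ThreadLifetime, num_updates: int) -> None:
    for _ in range(num_updates):
        traj = pipeline.get(block=True)  # blocks until an actor produces data
        print(f"learning on {traj}")
    lifetime.stop()


pipeline: queue.Queue = queue.Queue(maxsize=4)  # bounded, like the on-policy pipeline
lifetime = ThreadLifetime()
threads = [
    threading.Thread(target=actor, args=(pipeline, lifetime), name="Actor-0"),
    threading.Thread(target=learner, args=(pipeline, lifetime, 5), name="Learner"),
]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

Bounding the queue is what keeps the setup close to on-policy: a full pipeline back-pressures the actors until the learner has consumed their previous rollouts, which is exactly the behaviour the `OnPolicyPipeline.put` comments describe.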
diff --git a/stoix/systems/q_learning/ff_mdqn.py b/stoix/systems/q_learning/ff_mdqn.py index 9d72a2f8..cb154988 100644 --- a/stoix/systems/q_learning/ff_mdqn.py +++ b/stoix/systems/q_learning/ff_mdqn.py @@ -549,7 +549,11 @@ def run_experiment(_config: DictConfig) -> float: return eval_performance -@hydra.main(config_path="../../configs", config_name="default_ff_mdqn.yaml", version_base="1.2") +@hydra.main( + config_path="../../configs/default/anakin", + config_name="default_ff_mdqn.yaml", + version_base="1.2", +) def hydra_entry_point(cfg: DictConfig) -> float: """Experiment entry point.""" # Allow dynamic attributes. diff --git a/stoix/systems/q_learning/ff_qr_dqn.py b/stoix/systems/q_learning/ff_qr_dqn.py index e92907fe..e5be23a8 100644 --- a/stoix/systems/q_learning/ff_qr_dqn.py +++ b/stoix/systems/q_learning/ff_qr_dqn.py @@ -573,7 +573,11 @@ def run_experiment(_config: DictConfig) -> float: return eval_performance -@hydra.main(config_path="../../configs", config_name="default_ff_qr_dqn.yaml", version_base="1.2") +@hydra.main( + config_path="../../configs/default/anakin", + config_name="default_ff_qr_dqn.yaml", + version_base="1.2", +) def hydra_entry_point(cfg: DictConfig) -> float: """Experiment entry point.""" # Allow dynamic attributes. diff --git a/stoix/systems/q_learning/ff_rainbow.py b/stoix/systems/q_learning/ff_rainbow.py index 97134a26..52426c68 100644 --- a/stoix/systems/q_learning/ff_rainbow.py +++ b/stoix/systems/q_learning/ff_rainbow.py @@ -647,7 +647,11 @@ def run_experiment(_config: DictConfig) -> float: return eval_performance -@hydra.main(config_path="../../configs", config_name="default_ff_rainbow.yaml", version_base="1.2") +@hydra.main( + config_path="../../configs/default/anakin", + config_name="default_ff_rainbow.yaml", + version_base="1.2", +) def hydra_entry_point(cfg: DictConfig) -> float: """Experiment entry point.""" # Allow dynamic attributes. diff --git a/stoix/systems/sac/ff_sac.py b/stoix/systems/sac/ff_sac.py index 45ecadf0..859fecf6 100644 --- a/stoix/systems/sac/ff_sac.py +++ b/stoix/systems/sac/ff_sac.py @@ -666,7 +666,11 @@ def run_experiment(_config: DictConfig) -> float: return eval_performance -@hydra.main(config_path="../../configs", config_name="default_ff_sac.yaml", version_base="1.2") +@hydra.main( + config_path="../../configs/default/anakin", + config_name="default_ff_sac.yaml", + version_base="1.2", +) def hydra_entry_point(cfg: DictConfig) -> float: """Experiment entry point.""" # Allow dynamic attributes. diff --git a/stoix/systems/search/ff_az.py b/stoix/systems/search/ff_az.py index a5b5b0fb..1e236f33 100644 --- a/stoix/systems/search/ff_az.py +++ b/stoix/systems/search/ff_az.py @@ -710,7 +710,9 @@ def run_experiment(_config: DictConfig) -> float: return eval_performance -@hydra.main(config_path="../../configs", config_name="default_ff_az.yaml", version_base="1.2") +@hydra.main( + config_path="../../configs/default/anakin", config_name="default_ff_az.yaml", version_base="1.2" +) def hydra_entry_point(cfg: DictConfig) -> float: """Experiment entry point.""" # Allow dynamic attributes. 
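All of the entry points above now resolve their defaults from the reorganised `configs/default/anakin` (or `configs/default/sebulba`) tree, and `config_path` in `@hydra.main` is interpreted relative to the file that declares it, which is why the systems moved into `anakin/` and `sebulba/` subdirectories gained an extra `../`. As a quick way to inspect one of these configs without launching a run, the sketch below uses Hydra's compose API; the relative `config_path` and the `arch.seed` override are illustrative and need adjusting to wherever the calling module actually lives.

```python
# Hedged sketch: load one of the relocated default configs with Hydra's compose
# API instead of the @hydra.main decorator. initialize() resolves config_path
# relative to the calling module, so the "../" prefix below is illustrative.
from hydra import compose, initialize
from omegaconf import OmegaConf

with initialize(config_path="../configs/default/anakin", version_base="1.2"):
    cfg = compose(config_name="default_ff_sac.yaml", overrides=["arch.seed=7"])

print(OmegaConf.to_yaml(cfg))
```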
diff --git a/stoix/systems/search/ff_mz.py b/stoix/systems/search/ff_mz.py index d749f252..5a7b322c 100644 --- a/stoix/systems/search/ff_mz.py +++ b/stoix/systems/search/ff_mz.py @@ -826,7 +826,9 @@ def run_experiment(_config: DictConfig) -> float: return eval_performance -@hydra.main(config_path="../../configs", config_name="default_ff_mz.yaml", version_base="1.2") +@hydra.main( + config_path="../../configs/default/anakin", config_name="default_ff_mz.yaml", version_base="1.2" +) def hydra_entry_point(cfg: DictConfig) -> float: """Experiment entry point.""" # Allow dynamic attributes. diff --git a/stoix/systems/search/ff_sampled_az.py b/stoix/systems/search/ff_sampled_az.py index e3a7b76c..14701b8f 100644 --- a/stoix/systems/search/ff_sampled_az.py +++ b/stoix/systems/search/ff_sampled_az.py @@ -844,7 +844,9 @@ def run_experiment(_config: DictConfig) -> float: @hydra.main( - config_path="../../configs", config_name="default_ff_sampled_az.yaml", version_base="1.2" + config_path="../../configs/default/anakin", + config_name="default_ff_sampled_az.yaml", + version_base="1.2", ) def hydra_entry_point(cfg: DictConfig) -> float: """Experiment entry point.""" diff --git a/stoix/systems/search/ff_sampled_mz.py b/stoix/systems/search/ff_sampled_mz.py index 5860cd8a..e41b3b50 100644 --- a/stoix/systems/search/ff_sampled_mz.py +++ b/stoix/systems/search/ff_sampled_mz.py @@ -958,7 +958,9 @@ def run_experiment(_config: DictConfig) -> float: @hydra.main( - config_path="../../configs", config_name="default_ff_sampled_mz.yaml", version_base="1.2" + config_path="../../configs/default/anakin", + config_name="default_ff_sampled_mz.yaml", + version_base="1.2", ) def hydra_entry_point(cfg: DictConfig) -> float: """Experiment entry point.""" diff --git a/stoix/systems/vpg/ff_reinforce.py b/stoix/systems/vpg/ff_reinforce.py index aa5a0c68..5b99ae75 100644 --- a/stoix/systems/vpg/ff_reinforce.py +++ b/stoix/systems/vpg/ff_reinforce.py @@ -22,7 +22,7 @@ CriticApply, ExperimentOutput, LearnerFn, - LearnerState, + OnPolicyLearnerState, ) from stoix.evaluator import evaluator_setup, get_distribution_act_fn from stoix.networks.base import FeedForwardActor as Actor @@ -43,15 +43,19 @@ def get_learner_fn( apply_fns: Tuple[ActorApply, CriticApply], update_fns: Tuple[optax.TransformUpdateFn, optax.TransformUpdateFn], config: DictConfig, -) -> LearnerFn[LearnerState]: +) -> LearnerFn[OnPolicyLearnerState]: """Get the learner function.""" # Get apply and update functions for actor and critic networks. 
actor_apply_fn, critic_apply_fn = apply_fns actor_update_fn, critic_update_fn = update_fns - def _update_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, Tuple]: - def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, Transition]: + def _update_step( + learner_state: OnPolicyLearnerState, _: Any + ) -> Tuple[OnPolicyLearnerState, Tuple]: + def _env_step( + learner_state: OnPolicyLearnerState, _: Any + ) -> Tuple[OnPolicyLearnerState, Transition]: """Step the environment.""" params, opt_states, key, env_state, last_timestep = learner_state @@ -71,7 +75,7 @@ def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, Transi transition = Transition( done, action, value, timestep.reward, last_timestep.observation, info ) - learner_state = LearnerState(params, opt_states, key, env_state, timestep) + learner_state = OnPolicyLearnerState(params, opt_states, key, env_state, timestep) return learner_state, transition # STEP ENVIRONMENT FOR ROLLOUT LENGTH @@ -190,11 +194,13 @@ def _critic_loss_fn( **critic_loss_info, } - learner_state = LearnerState(new_params, new_opt_state, key, env_state, last_timestep) + learner_state = OnPolicyLearnerState( + new_params, new_opt_state, key, env_state, last_timestep + ) metric = traj_batch.info return learner_state, (metric, loss_info) - def learner_fn(learner_state: LearnerState) -> ExperimentOutput[LearnerState]: + def learner_fn(learner_state: OnPolicyLearnerState) -> ExperimentOutput[OnPolicyLearnerState]: batched_update_step = jax.vmap(_update_step, in_axes=(0, None), axis_name="batch") @@ -212,7 +218,7 @@ def learner_fn(learner_state: LearnerState) -> ExperimentOutput[LearnerState]: def learner_setup( env: Environment, keys: chex.Array, config: DictConfig -) -> Tuple[LearnerFn[LearnerState], Actor, LearnerState]: +) -> Tuple[LearnerFn[OnPolicyLearnerState], Actor, OnPolicyLearnerState]: """Initialise learner_fn, network, optimiser, environment and states.""" # Get available TPU cores. n_devices = len(jax.devices()) @@ -315,7 +321,7 @@ def learner_setup( # Initialise learner state. params, opt_states = replicate_learner - init_learner_state = LearnerState(params, opt_states, step_keys, env_states, timesteps) + init_learner_state = OnPolicyLearnerState(params, opt_states, step_keys, env_states, timesteps) return learn, actor_network, init_learner_state @@ -463,7 +469,9 @@ def run_experiment(_config: DictConfig) -> float: @hydra.main( - config_path="../../configs", config_name="default_ff_reinforce.yaml", version_base="1.2" + config_path="../../configs/default/anakin", + config_name="default_ff_reinforce.yaml", + version_base="1.2", ) def hydra_entry_point(cfg: DictConfig) -> float: """Experiment entry point.""" diff --git a/stoix/systems/vpg/ff_reinforce_continuous.py b/stoix/systems/vpg/ff_reinforce_continuous.py index 66bd744a..df0d9af0 100644 --- a/stoix/systems/vpg/ff_reinforce_continuous.py +++ b/stoix/systems/vpg/ff_reinforce_continuous.py @@ -22,7 +22,7 @@ CriticApply, ExperimentOutput, LearnerFn, - LearnerState, + OnPolicyLearnerState, ) from stoix.evaluator import evaluator_setup, get_distribution_act_fn from stoix.networks.base import FeedForwardActor as Actor @@ -43,15 +43,19 @@ def get_learner_fn( apply_fns: Tuple[ActorApply, CriticApply], update_fns: Tuple[optax.TransformUpdateFn, optax.TransformUpdateFn], config: DictConfig, -) -> LearnerFn[LearnerState]: +) -> LearnerFn[OnPolicyLearnerState]: """Get the learner function.""" # Get apply and update functions for actor and critic networks. 
actor_apply_fn, critic_apply_fn = apply_fns actor_update_fn, critic_update_fn = update_fns - def _update_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, Tuple]: - def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, Transition]: + def _update_step( + learner_state: OnPolicyLearnerState, _: Any + ) -> Tuple[OnPolicyLearnerState, Tuple]: + def _env_step( + learner_state: OnPolicyLearnerState, _: Any + ) -> Tuple[OnPolicyLearnerState, Transition]: """Step the environment.""" params, opt_states, key, env_state, last_timestep = learner_state @@ -71,7 +75,7 @@ def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, Transi transition = Transition( done, action, value, timestep.reward, last_timestep.observation, info ) - learner_state = LearnerState(params, opt_states, key, env_state, timestep) + learner_state = OnPolicyLearnerState(params, opt_states, key, env_state, timestep) return learner_state, transition # STEP ENVIRONMENT FOR ROLLOUT LENGTH @@ -188,11 +192,13 @@ def _critic_loss_fn( **critic_loss_info, } - learner_state = LearnerState(new_params, new_opt_state, key, env_state, last_timestep) + learner_state = OnPolicyLearnerState( + new_params, new_opt_state, key, env_state, last_timestep + ) metric = traj_batch.info return learner_state, (metric, loss_info) - def learner_fn(learner_state: LearnerState) -> ExperimentOutput[LearnerState]: + def learner_fn(learner_state: OnPolicyLearnerState) -> ExperimentOutput[OnPolicyLearnerState]: batched_update_step = jax.vmap(_update_step, in_axes=(0, None), axis_name="batch") @@ -210,7 +216,7 @@ def learner_fn(learner_state: LearnerState) -> ExperimentOutput[LearnerState]: def learner_setup( env: Environment, keys: chex.Array, config: DictConfig -) -> Tuple[LearnerFn[LearnerState], Actor, LearnerState]: +) -> Tuple[LearnerFn[OnPolicyLearnerState], Actor, OnPolicyLearnerState]: """Initialise learner_fn, network, optimiser, environment and states.""" # Get available TPU cores. n_devices = len(jax.devices()) @@ -318,7 +324,7 @@ def learner_setup( # Initialise learner state. params, opt_states = replicate_learner - init_learner_state = LearnerState(params, opt_states, step_keys, env_states, timesteps) + init_learner_state = OnPolicyLearnerState(params, opt_states, step_keys, env_states, timesteps) return learn, actor_network, init_learner_state @@ -466,7 +472,7 @@ def run_experiment(_config: DictConfig) -> float: @hydra.main( - config_path="../../configs", + config_path="../../configs/default/anakin", config_name="default_ff_reinforce_continuous.yaml", version_base="1.2", ) diff --git a/stoix/utils/env_factory.py b/stoix/utils/env_factory.py new file mode 100644 index 00000000..9fbaf4bd --- /dev/null +++ b/stoix/utils/env_factory.py @@ -0,0 +1,59 @@ +import abc +import threading +from typing import Any + +import envpool +import gymnasium + +from stoix.wrappers.envpool import EnvPoolToJumanji +from stoix.wrappers.gymnasium import VecGymToJumanji + + +class EnvFactory(abc.ABC): + """ + Abstract class to create environments + """ + + def __init__(self, task_id: str, init_seed: int = 42, **kwargs: Any): + self.task_id = task_id + self.seed = init_seed + # a lock is needed because this object will be used from different threads. 
+ # We want to make sure all seeds are unique + self.lock = threading.Lock() + self.kwargs = kwargs + + @abc.abstractmethod + def __call__(self, num_envs: int) -> Any: + pass + + +class EnvPoolFactory(EnvFactory): + """ + Create environments with different seeds for each `Actor` + """ + + def __call__(self, num_envs: int) -> Any: + with self.lock: + seed = self.seed + self.seed += num_envs + return EnvPoolToJumanji( + envpool.make( + task_id=self.task_id, + env_type="gymnasium", + num_envs=num_envs, + seed=seed, + gym_reset_return_info=True, + **self.kwargs + ) + ) + + +class GymnasiumFactory(EnvFactory): + """ + Create environments using gymnasium + """ + + def __call__(self, num_envs: int) -> Any: + with self.lock: + vec_env = gymnasium.make_vec(id=self.task_id, num_envs=num_envs, **self.kwargs) + return VecGymToJumanji(vec_env) diff --git a/stoix/utils/logger.py b/stoix/utils/logger.py index bb2a9d7f..5483ad02 100644 --- a/stoix/utils/logger.py +++ b/stoix/utils/logger.py @@ -372,7 +372,7 @@ def get_logger_path(config: DictConfig, logger_type: str) -> str: def describe(x: ArrayLike) -> Union[Dict[str, ArrayLike], ArrayLike]: """Generate summary statistics for an array of metrics (mean, std, min, max).""" - if not isinstance(x, jax.Array): + if not isinstance(x, (jax.Array, np.ndarray)) or x.size <= 1: return x # np instead of jnp because we don't jit here diff --git a/stoix/utils/make_env.py b/stoix/utils/make_env.py index 67ebed25..65821a8b 100644 --- a/stoix/utils/make_env.py +++ b/stoix/utils/make_env.py @@ -1,5 +1,5 @@ import copy -from typing import Tuple +from typing import Tuple, Union import gymnax import hydra @@ -25,6 +25,7 @@ from xminigrid.registration import _REGISTRY as XMINIGRID_REGISTRY from stoix.utils.debug_env import IdentityGame, SequenceGame +from stoix.utils.env_factory import EnvPoolFactory, GymnasiumFactory from stoix.wrappers import GymnaxWrapper, JumanjiWrapper, RecordEpisodeMetrics from stoix.wrappers.brax import BraxJumanjiWrapper from stoix.wrappers.jaxmarl import JaxMarlWrapper, MabraxWrapper, SmaxWrapper @@ -371,6 +372,20 @@ def make_navix_env(env_name: str, config: DictConfig) -> Tuple[Environment, Envi return env, eval_env +def make_gymnasium_factory(env_name: str, config: DictConfig) -> GymnasiumFactory: + + env_factory = GymnasiumFactory(env_name, init_seed=config.arch.seed, **config.env.kwargs) + + return env_factory + + +def make_envpool_factory(env_name: str, config: DictConfig) -> EnvPoolFactory: + + env_factory = EnvPoolFactory(env_name, init_seed=config.arch.seed, **config.env.kwargs) + + return env_factory + + def make(config: DictConfig) -> Tuple[Environment, Environment]: """ Create environments for training and evaluation.. @@ -379,7 +394,7 @@ def make(config: DictConfig) -> Tuple[Environment, Environment]: config (Dict): The configuration of the environment. Returns: - A tuple of the environments. + training and evaluation environments. """ env_name = config.env.scenario.name @@ -409,3 +424,24 @@ def make(config: DictConfig) -> Tuple[Environment, Environment]: envs = apply_optional_wrappers(envs, config) return envs + + +def make_factory(config: DictConfig) -> Union[GymnasiumFactory, EnvPoolFactory]: + """ + Create a env_factory for sebulba systems. + + Args: + config (Dict): The configuration of the environment. + + Returns: + A factory to create environments. 
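+
+    Example (illustrative only; assumes a config whose `env.env_name` contains "envpool"):
+        factory = make_factory(config)
+        envs = factory(num_envs=8)  # Jumanji-API wrapped vectorised environments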
+ """ + env_name = config.env.scenario.name + suite_name = config.env.env_name + + if "envpool" in suite_name: + return make_envpool_factory(env_name, config) + elif "gymnasium" in suite_name: + return make_gymnasium_factory(env_name, config) + else: + raise ValueError(f"{suite_name} is not a supported suite.") diff --git a/stoix/utils/sebulba_utils.py b/stoix/utils/sebulba_utils.py new file mode 100644 index 00000000..b912eaba --- /dev/null +++ b/stoix/utils/sebulba_utils.py @@ -0,0 +1,173 @@ +import queue +import threading +import time +from functools import partial +from typing import Any, Dict, List, Sequence, Tuple, Union + +import jax +import jax.numpy as jnp +from colorama import Fore, Style +from jumanji.types import TimeStep + +from stoix.base_types import Parameters, StoixTransition + + +# Copied from https://github.com/instadeepai/sebulba/blob/main/sebulba/core.py +class ThreadLifetime: + """Simple class for a mutable boolean that can be used to signal a thread to stop.""" + + def __init__(self) -> None: + self._stop = False + + def should_stop(self) -> bool: + return self._stop + + def stop(self) -> None: + self._stop = True + + +class OnPolicyPipeline(threading.Thread): + """ + The `Pipeline` shards trajectories into `learner_devices`, + ensuring trajectories are consumed in the right order to avoid being off-policy + and limit the max number of samples in device memory at one time to avoid OOM issues. + """ + + def __init__(self, max_size: int, learner_devices: List[jax.Device], lifetime: ThreadLifetime): + """ + Initializes the pipeline with a maximum size and the devices to shard trajectories across. + + Args: + max_size: The maximum number of trajectories to keep in the pipeline. + learner_devices: The devices to shard trajectories across. + """ + super().__init__(name="Pipeline") + self.learner_devices = learner_devices + self.tickets_queue: queue.Queue = queue.Queue() + self._queue: queue.Queue = queue.Queue(maxsize=max_size) + self.lifetime = lifetime + + def run(self) -> None: + """This function ensures that trajectories on the queue are consumed in the right order. The + start_condition and end_condition are used to ensure that only 1 thread is processing an + item from the queue at one time, ensuring predictable memory usage. + """ + while not self.lifetime.should_stop(): + try: + start_condition, end_condition = self.tickets_queue.get(timeout=1) + with end_condition: + with start_condition: + start_condition.notify() + end_condition.wait() + except queue.Empty: + continue + + def put(self, traj: Sequence[StoixTransition], timestep: TimeStep, timings_dict: Dict) -> None: + """Put a trajectory on the queue to be consumed by the learner.""" + start_condition, end_condition = (threading.Condition(), threading.Condition()) + with start_condition: + self.tickets_queue.put((start_condition, end_condition)) + start_condition.wait() # wait to be allowed to start + + # [Transition(num_envs)] * rollout_len --> Transition[(rollout_len, num_envs,) + traj = self.stack_trajectory(traj) + # Split trajectory on the num envs axis so each learner device gets a valid full rollout + sharded_traj = jax.tree.map(lambda x: self.shard_split_playload(x, axis=1), traj) + + # Timestep[(num_envs, ...), ...] --> + # [(num_envs / num_learner_devices, ...)] * num_learner_devices + sharded_timestep = jax.tree.map(self.shard_split_playload, timestep) + + # We block on the put to ensure that actors wait for the learners to catch up. This does two + # things: + # 1. 
It ensures that the actors don't get too far ahead of the learners, which could lead to + # off-policy data. + # 2. It ensures that the actors don't in a sense "waste" samples and their time by + # generating samples that the learners can't consume. + # However, we put a timeout of 180 seconds to avoid deadlocks in case the learner + # is not consuming the data. This is a safety measure and should not be hit in normal + # operation. We use a try-finally since the lock has to be released even if an exception + # is raised. + try: + self._queue.put((sharded_traj, sharded_timestep, timings_dict), block=True, timeout=180) + except queue.Full: + print( + f"{Fore.RED}{Style.BRIGHT}Pipeline is full and actor has timed out, " + f"this should not happen. A deadlock might be occurring{Style.RESET_ALL}" + ) + finally: + with end_condition: + end_condition.notify() # tell we have finish + + def qsize(self) -> int: + """Returns the number of trajectories in the pipeline.""" + return self._queue.qsize() + + def get( + self, block: bool = True, timeout: Union[float, None] = None + ) -> Tuple[StoixTransition, TimeStep, Dict]: + """Get a trajectory from the pipeline.""" + return self._queue.get(block, timeout) # type: ignore + + @partial(jax.jit, static_argnums=(0,)) + def stack_trajectory(self, trajectory: List[StoixTransition]) -> StoixTransition: + """Stack a list of parallel_env transitions into a single + transition of shape [rollout_len, num_envs, ...].""" + return jax.tree_map(lambda *x: jnp.stack(x, axis=0), *trajectory) # type: ignore + + def shard_split_playload(self, payload: Any, axis: int = 0) -> Any: + split_payload = jnp.split(payload, len(self.learner_devices), axis=axis) + return jax.device_put_sharded(split_payload, devices=self.learner_devices) + + def clear(self) -> None: + """Clear the pipeline.""" + while not self._queue.empty(): + self._queue.get() + + +class ParamsSource(threading.Thread): + """A `ParamSource` is a component that allows networks params to be passed from a + `Learner` component to `Actor` components. + """ + + def __init__(self, init_value: Parameters, device: jax.Device, lifetime: ThreadLifetime): + super().__init__(name=f"ParamsSource-{device.id}") + self.value: Parameters = jax.device_put(init_value, device) + self.device = device + self.new_value: queue.Queue = queue.Queue() + self.lifetime = lifetime + + def run(self) -> None: + """This function is responsible for updating the value of the `ParamSource` when a new value + is available. + """ + while not self.lifetime.should_stop(): + try: + waiting = self.new_value.get(block=True, timeout=1) + self.value = jax.device_put(jax.block_until_ready(waiting), self.device) + except queue.Empty: + continue + + def update(self, new_params: Parameters) -> None: + """Update the value of the `ParamSource` with a new value. + + Args: + new_params: The new value to update the `ParamSource` with. 
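+
+        Note:
+            The swap is applied asynchronously by the background thread, so `get()`
+            may return the previous parameters until the new ones have been placed
+            on the device.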
+ """ + self.new_value.put(new_params) + + def get(self) -> Parameters: + """Get the current value of the `ParamSource`.""" + return self.value + + +class RecordTimeTo: + def __init__(self, to: Any): + self.to = to + + def __enter__(self) -> None: + self.start = time.monotonic() + + def __exit__(self, *args: Any) -> None: + end = time.monotonic() + self.to.append(end - self.start) diff --git a/stoix/utils/total_timestep_checker.py b/stoix/utils/total_timestep_checker.py index 1a370c28..9e64685d 100644 --- a/stoix/utils/total_timestep_checker.py +++ b/stoix/utils/total_timestep_checker.py @@ -5,24 +5,42 @@ def check_total_timesteps(config: DictConfig) -> DictConfig: """Check if total_timesteps is set, if not, set it based on the other parameters""" - assert config.arch.total_num_envs % (config.num_devices * config.arch.update_batch_size) == 0, ( + # Check if the number of devices is set in the config + # If not, it is assumed that the number of devices is 1 + # For the case of using a sebulba config, the number of + # devices is set to 1 for the calculation + # of the number of environments per device, etc + if "num_devices" not in config: + num_devices = 1 + else: + num_devices = config.num_devices + + # If update_batch_size is not in the config, usually this means a sebulba config is being used. + if "update_batch_size" not in config.arch: + update_batch_size = 1 + print(f"{Fore.YELLOW}{Style.BRIGHT}Using Sebulba System!{Style.RESET_ALL}") + else: + update_batch_size = config.arch.update_batch_size + print(f"{Fore.YELLOW}{Style.BRIGHT}Using Anakin System!{Style.RESET_ALL}") + + assert config.arch.total_num_envs % (num_devices * update_batch_size) == 0, ( f"{Fore.RED}{Style.BRIGHT}The total number of environments " - + "should be divisible by the n_devices*update_batch_size!{Style.RESET_ALL}" + + f"should be divisible by the n_devices*update_batch_size!{Style.RESET_ALL}" ) - config.arch.num_envs = config.arch.total_num_envs // ( - config.num_devices * config.arch.update_batch_size + config.arch.num_envs = int( + config.arch.total_num_envs // (num_devices * update_batch_size) ) # Number of environments per device if config.arch.total_timesteps is None: config.arch.total_timesteps = ( - config.num_devices + num_devices * config.arch.num_updates * config.system.rollout_length - * config.arch.update_batch_size + * update_batch_size * config.arch.num_envs ) print( - f"{Fore.YELLOW}{Style.BRIGHT} Changing the total number of timesteps " + f"{Fore.YELLOW}{Style.BRIGHT}Changing the total number of timesteps " + f"to {config.arch.total_timesteps}: If you want to train" + " for a specific number of timesteps, please set num_updates to None!" + f"{Style.RESET_ALL}" @@ -31,12 +49,12 @@ def check_total_timesteps(config: DictConfig) -> DictConfig: config.arch.num_updates = ( config.arch.total_timesteps // config.system.rollout_length - // config.arch.update_batch_size + // update_batch_size // config.arch.num_envs - // config.num_devices + // num_devices ) print( - f"{Fore.YELLOW}{Style.BRIGHT} Changing the number of updates " + f"{Fore.YELLOW}{Style.BRIGHT}Changing the number of updates " + f"to {config.arch.num_updates}: If you want to train" + " for a specific number of updates, please set total_timesteps to None!" 
+ f"{Style.RESET_ALL}" @@ -45,10 +63,10 @@ def check_total_timesteps(config: DictConfig) -> DictConfig: # Calculate the actual number of timesteps that will be run num_updates_per_eval = config.arch.num_updates // config.arch.num_evaluation steps_per_rollout = ( - config.num_devices + num_devices * num_updates_per_eval * config.system.rollout_length - * config.arch.update_batch_size + * update_batch_size * config.arch.num_envs ) total_actual_timesteps = steps_per_rollout * config.arch.num_evaluation diff --git a/stoix/wrappers/envpool.py b/stoix/wrappers/envpool.py new file mode 100644 index 00000000..08f0317f --- /dev/null +++ b/stoix/wrappers/envpool.py @@ -0,0 +1,178 @@ +from typing import Any, Dict, Optional + +import numpy as np +from jumanji.specs import Array, DiscreteArray, Spec +from jumanji.types import StepType, TimeStep +from numpy.typing import NDArray + +from stoix.base_types import Observation + +NEXT_OBS_KEY_IN_EXTRAS = "next_obs" + + +class EnvPoolToJumanji: + """Converts from the Gymnasium envpool API to Jumanji's API.""" + + def __init__(self, env: Any): + self.env = env + obs, _ = self.env.reset() + self.num_envs = obs.shape[0] + self.obs_shape = obs.shape[1:] + self.num_actions = self.env.action_space.n + self._default_action_mask = np.ones((self.num_envs, self.num_actions), dtype=np.float32) + + # Create the metrics + self.running_count_episode_return = np.zeros(self.num_envs, dtype=float) + self.running_count_episode_length = np.zeros(self.num_envs, dtype=int) + self.episode_return = np.zeros(self.num_envs, dtype=float) + self.episode_length = np.zeros(self.num_envs, dtype=int) + + # See if the env has lives - Atari specific + info = self.env.step(np.zeros(self.num_envs, dtype=int))[-1] + if "lives" in info and info["lives"].sum() > 0: + print("Env has lives") + self.has_lives = True + else: + self.has_lives = False + self.env.close() + + # Set the flag to use the gym autoreset API + # since envpool does auto resetting slightly differently + self._use_gym_autoreset_api = True + + def reset( + self, *, seed: Optional[list[int]] = None, options: Optional[list[dict]] = None + ) -> TimeStep: + obs, info = self.env.reset() + + ep_done = np.zeros(self.num_envs, dtype=float) + rewards = np.zeros(self.num_envs, dtype=float) + terminated = np.zeros(self.num_envs, dtype=float) + + # Reset the metrics + self.running_count_episode_return = np.zeros(self.num_envs, dtype=float) + self.running_count_episode_length = np.zeros(self.num_envs, dtype=int) + self.episode_return = np.zeros(self.num_envs, dtype=float) + self.episode_length = np.zeros(self.num_envs, dtype=int) + + # Create the metrics dict + metrics = { + "episode_return": np.zeros(self.num_envs, dtype=float), + "episode_length": np.zeros(self.num_envs, dtype=int), + "is_terminal_step": np.zeros(self.num_envs, dtype=bool), + } + + info["metrics"] = metrics + info[NEXT_OBS_KEY_IN_EXTRAS] = obs.copy() + + timestep = self._create_timestep(obs, ep_done, terminated, rewards, info) + + return timestep + + def step(self, action: list) -> TimeStep: + obs, rewards, terminated, truncated, info = self.env.step(action) + ep_done = np.logical_or(terminated, truncated) + not_done = 1 - ep_done + info[NEXT_OBS_KEY_IN_EXTRAS] = obs.copy() + if self._use_gym_autoreset_api: + env_ids_to_reset = np.where(ep_done)[0] + if len(env_ids_to_reset) > 0: + ( + reset_obs, + _, + _, + _, + _, + ) = self.env.step(np.zeros_like(action), env_ids_to_reset) + obs[env_ids_to_reset] = reset_obs + + # Counting episode return and length. 
+ if "reward" in info: + metric_reward = info["reward"] + else: + metric_reward = rewards + + # Counting episode return and length. + new_episode_return = self.running_count_episode_return + metric_reward + new_episode_length = self.running_count_episode_length + 1 + + # Previous episode return/length until done and then the next episode return. + # If the env has lives (Atari), we only consider the return and length of the episode + # every time all lives are exhausted. + if self.has_lives: + all_lives_exhausted = info["lives"] == 0 + not_all_lives_exhausted = 1 - all_lives_exhausted + # Update the episode return and length if all lives are exhausted otherwise + # keep the previous values + episode_return_info = ( + self.episode_return * not_all_lives_exhausted + + new_episode_return * all_lives_exhausted + ) + episode_length_info = ( + self.episode_length * not_all_lives_exhausted + + new_episode_length * all_lives_exhausted + ) + # Update the running count + self.running_count_episode_return = new_episode_return * not_all_lives_exhausted + self.running_count_episode_length = new_episode_length * not_all_lives_exhausted + else: + # Update the episode return and length if the episode is done otherwise + # keep the previous values + episode_return_info = self.episode_return * not_done + new_episode_return * ep_done + episode_length_info = self.episode_length * not_done + new_episode_length * ep_done + # Update the running count + self.running_count_episode_return = new_episode_return * not_done + self.running_count_episode_length = new_episode_length * not_done + + self.episode_return = episode_return_info + self.episode_length = episode_length_info + + # Create the metrics dict + metrics = { + "episode_return": episode_return_info, + "episode_length": episode_length_info, + "is_terminal_step": ep_done, + } + info["metrics"] = metrics + + timestep = self._create_timestep(obs, ep_done, terminated, rewards, info) + + return timestep + + def _format_observation(self, obs: NDArray, info: Dict) -> Observation: + action_mask = self._default_action_mask + return Observation(agent_view=obs, action_mask=action_mask) + + def _create_timestep( + self, obs: NDArray, ep_done: NDArray, terminated: NDArray, rewards: NDArray, info: Dict + ) -> TimeStep: + obs = self._format_observation(obs, info) + extras = {"metrics": info["metrics"]} + extras[NEXT_OBS_KEY_IN_EXTRAS] = self._format_observation( + info[NEXT_OBS_KEY_IN_EXTRAS], info + ) + step_type = np.where(ep_done, StepType.LAST, StepType.MID) + + return TimeStep( + step_type=step_type, + reward=rewards, + discount=1.0 - terminated, + observation=obs, + extras=extras, + ) + + def observation_spec(self) -> Spec: + agent_view_spec = Array(shape=self.obs_shape, dtype=float) + return Spec( + Observation, + "ObservationSpec", + agent_view=agent_view_spec, + action_mask=Array(shape=(self.num_actions,), dtype=float), + step_count=Array(shape=(), dtype=int), + ) + + def action_spec(self) -> Spec: + return DiscreteArray(num_values=self.num_actions) + + def close(self) -> None: + self.env.close() diff --git a/stoix/wrappers/gymnasium.py b/stoix/wrappers/gymnasium.py new file mode 100644 index 00000000..2022e59f --- /dev/null +++ b/stoix/wrappers/gymnasium.py @@ -0,0 +1,155 @@ +from typing import Dict, Optional + +import gymnasium +import numpy as np +from jumanji.specs import Array, DiscreteArray, Spec +from jumanji.types import StepType, TimeStep +from numpy.typing import NDArray + +from stoix.base_types import Observation + +NEXT_OBS_KEY_IN_EXTRAS = "next_obs" 
+ + +class VecGymToJumanji: + """Converts from a Vectorised Gymnasium environment to Jumanji's API.""" + + def __init__(self, env: gymnasium.vector.AsyncVectorEnv): + self.env = env + self.num_envs = int(self.env.num_envs) + if isinstance(self.env.single_action_space, gymnasium.spaces.Discrete): + self.num_actions = self.env.single_action_space.n + self.discrete = True + else: + self.num_actions = self.env.single_action_space.shape[0] + self.discrete = False + self.obs_shape = self.env.single_observation_space.shape + self._default_action_mask = np.ones((self.num_envs, self.num_actions), dtype=np.float32) + + # Create the metrics + self.running_count_episode_return = np.zeros(self.num_envs, dtype=float) + self.running_count_episode_length = np.zeros(self.num_envs, dtype=int) + self.episode_return = np.zeros(self.num_envs, dtype=float) + self.episode_length = np.zeros(self.num_envs, dtype=int) + + def reset( + self, *, seed: Optional[list[int]] = None, options: Optional[list[dict]] = None + ) -> TimeStep: + obs, info = self.env.reset(seed=seed, options=options) + obs = np.asarray(obs) + ep_done = np.zeros(self.num_envs, dtype=float) + rewards = np.zeros(self.num_envs, dtype=float) + terminated = np.zeros(self.num_envs, dtype=float) + + # Reset the metrics + self.running_count_episode_return = np.zeros(self.num_envs, dtype=float) + self.running_count_episode_length = np.zeros(self.num_envs, dtype=int) + self.episode_return = np.zeros(self.num_envs, dtype=float) + self.episode_length = np.zeros(self.num_envs, dtype=int) + + # Create the metrics dict + metrics = { + "episode_return": np.zeros(self.num_envs, dtype=float), + "episode_length": np.zeros(self.num_envs, dtype=int), + "is_terminal_step": np.zeros(self.num_envs, dtype=bool), + } + if "final_observation" in info: + real_next_obs = info["final_observation"] + real_next_obs = np.asarray( + [obs[i] if x is None else np.asarray(x) for i, x in enumerate(real_next_obs)] + ) + else: + real_next_obs = obs + + info["metrics"] = metrics + info[NEXT_OBS_KEY_IN_EXTRAS] = real_next_obs + + timestep = self._create_timestep(obs, ep_done, terminated, rewards, info) + + return timestep + + def step(self, action: list) -> TimeStep: + obs, rewards, terminated, truncated, info = self.env.step(action) + obs = np.asarray(obs) + rewards = np.asarray(rewards) + terminated = np.asarray(terminated) + truncated = np.asarray(truncated) + ep_done = np.logical_or(terminated, truncated) + not_done = 1 - ep_done + + # Counting episode return and length. + new_episode_return = self.running_count_episode_return + rewards + new_episode_length = self.running_count_episode_length + 1 + + # Previous episode return/length until done and then the next episode return. 
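+        # Masking with `not_done`/`ep_done` keeps the stats of the last completed episode
+        # visible until the next episode in that env finishes.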
+ episode_return_info = self.episode_return * not_done + new_episode_return * ep_done + episode_length_info = self.episode_length * not_done + new_episode_length * ep_done + + metrics = { + "episode_return": episode_return_info, + "episode_length": episode_length_info, + "is_terminal_step": ep_done, + } + info["metrics"] = metrics + + if "final_observation" in info: + real_next_obs = info["final_observation"] + real_next_obs = np.asarray( + [obs[i] if x is None else np.asarray(x) for i, x in enumerate(real_next_obs)] + ) + else: + real_next_obs = obs + + info["metrics"] = metrics + info[NEXT_OBS_KEY_IN_EXTRAS] = real_next_obs + + # Update the metrics + self.running_count_episode_return = new_episode_return * not_done + self.running_count_episode_length = new_episode_length * not_done + self.episode_return = episode_return_info + self.episode_length = episode_length_info + + timestep = self._create_timestep(obs, ep_done, terminated, rewards, info) + + return timestep + + def _format_observation(self, obs: NDArray, info: Dict) -> Observation: + action_mask = self._default_action_mask + return Observation(agent_view=obs, action_mask=action_mask) + + def _create_timestep( + self, obs: NDArray, ep_done: NDArray, terminated: NDArray, rewards: NDArray, info: Dict + ) -> TimeStep: + obs = self._format_observation(obs, info) + extras = {"metrics": info["metrics"]} + extras[NEXT_OBS_KEY_IN_EXTRAS] = self._format_observation( + info[NEXT_OBS_KEY_IN_EXTRAS], info + ) + step_type = np.where(ep_done, StepType.LAST, StepType.MID) + + return TimeStep( + step_type=step_type, + reward=rewards, + discount=1.0 - terminated, + observation=obs, + extras=extras, + ) + + def observation_spec(self) -> Spec: + agent_view_spec = Array(shape=self.obs_shape, dtype=float) + return Spec( + Observation, + "ObservationSpec", + agent_view=agent_view_spec, + action_mask=Array(shape=(self.num_actions,), dtype=float), + step_count=Array(shape=(), dtype=int), + ) + + def action_spec(self) -> Spec: + if self.discrete: + return DiscreteArray(num_values=self.num_actions) + else: + return Array(shape=(self.num_actions,), dtype=float) + + def close(self) -> None: + self.env.close()