diff --git a/README.md b/README.md index 96e3334..e6ee350 100644 --- a/README.md +++ b/README.md @@ -199,11 +199,10 @@ If you use Stoix in your work, please cite us: ```bibtex @software{toledo2024stoix, author = {Toledo, Edan}, -doi = {10.5281/zenodo.10916258}, month = apr, +doi = {10.5281/zenodo.10916257}, title = {{Stoix: Distributed Single-Agent Reinforcement Learning End-to-End in JAX}}, url = {https://github.com/EdanToledo/Stoix}, -version = {v0.0.1}, year = {2024} } ``` diff --git a/stoix/configs/network/chained_torsos.yaml b/stoix/configs/network/chained_torsos.yaml new file mode 100644 index 0000000..56f44d7 --- /dev/null +++ b/stoix/configs/network/chained_torsos.yaml @@ -0,0 +1,40 @@ +# ---Example of chaining arbitrary torsos--- +actor_network: + pre_torso: + _target_: stoix.networks.base.chained_torsos # we call the chained torsos creation function + _recursive_: false # we disable recursive instantiation for this object as we do the creation manually + torso_cfgs: + + - _target_: stoix.networks.resnet.ResNetTorso + hidden_units_per_group: [64, 64] + blocks_per_group: [1, 1] + use_layer_norm: False + activation: silu + + - _target_: stoix.networks.torso.MLPTorso + layer_sizes: [64, 64] + use_layer_norm: False + activation: relu + + action_head: + _target_: stoix.networks.heads.CategoricalHead + +critic_network: + pre_torso: + _target_: stoix.networks.base.chained_torsos # we call the chained torsos creation function + _recursive_: false # we disable recursive instantiation for this object as we do the creation manually + torso_cfgs: + + - _target_: stoix.networks.resnet.ResNetTorso + hidden_units_per_group: [64, 64] + blocks_per_group: [1, 1] + use_layer_norm: False + activation: silu + + - _target_: stoix.networks.torso.MLPTorso + layer_sizes: [64, 64] + use_layer_norm: False + activation: relu + + critic_head: + _target_: stoix.networks.heads.ScalarCriticHead diff --git a/stoix/networks/base.py b/stoix/networks/base.py index 8a8afec..3647691 100644 --- a/stoix/networks/base.py +++ b/stoix/networks/base.py @@ -1,8 +1,9 @@ import functools -from typing import Sequence, Tuple, Union +from typing import Any, Dict, List, Sequence, Tuple, Union import chex import distrax +import hydra import jax import jax.numpy as jnp import numpy as np @@ -182,3 +183,16 @@ def __call__( critic_output = self.critic_head(critic_output) return critic_hidden_state, critic_output + + +def chained_torsos(torso_cfgs: List[Dict[str, Any]]) -> nn.Module: + """Create a network by chaining multiple torsos together using a list of configs. + This makes use of hydra to instantiate the modules and the composite network + to chain them together. + + Args: + torso_cfgs: List of dictionaries containing the configuration for each torso. + These configs should use the same format as the individual torso configs.""" + + torso_modules = [hydra.utils.instantiate(torso_cfg) for torso_cfg in torso_cfgs] + return CompositeNetwork(torso_modules)