Skip to content

Commit

Permalink
Edan toledo/issue112 (#117)
Browse files Browse the repository at this point in the history
* feat: add function to create arbitrary chained torsos and an example config
  • Loading branch information
EdanToledo authored Sep 17, 2024
1 parent 6125d36 commit ce449a2
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 3 deletions.
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -199,11 +199,10 @@ If you use Stoix in your work, please cite us:
```bibtex
@software{toledo2024stoix,
author = {Toledo, Edan},
doi = {10.5281/zenodo.10916258},
month = apr,
doi = {10.5281/zenodo.10916257},
title = {{Stoix: Distributed Single-Agent Reinforcement Learning End-to-End in JAX}},
url = {https://github.com/EdanToledo/Stoix},
version = {v0.0.1},
year = {2024}
}
```
Expand Down
40 changes: 40 additions & 0 deletions stoix/configs/network/chained_torsos.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# ---Example of chaining arbitrary torsos---
actor_network:
pre_torso:
_target_: stoix.networks.base.chained_torsos # we call the chained torsos creation function
_recursive_: false # disable Hydra's recursive instantiation here: chained_torsos instantiates each torso config itself
torso_cfgs:

- _target_: stoix.networks.resnet.ResNetTorso
hidden_units_per_group: [64, 64]
blocks_per_group: [1, 1]
use_layer_norm: False
activation: silu

- _target_: stoix.networks.torso.MLPTorso
layer_sizes: [64, 64]
use_layer_norm: False
activation: relu

action_head:
_target_: stoix.networks.heads.CategoricalHead

critic_network:
pre_torso:
_target_: stoix.networks.base.chained_torsos # we call the chained torsos creation function
_recursive_: false # disable Hydra's recursive instantiation here: chained_torsos instantiates each torso config itself
torso_cfgs:

- _target_: stoix.networks.resnet.ResNetTorso
hidden_units_per_group: [64, 64]
blocks_per_group: [1, 1]
use_layer_norm: False
activation: silu

- _target_: stoix.networks.torso.MLPTorso
layer_sizes: [64, 64]
use_layer_norm: False
activation: relu

critic_head:
_target_: stoix.networks.heads.ScalarCriticHead
16 changes: 15 additions & 1 deletion stoix/networks/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import functools
from typing import Sequence, Tuple, Union
from typing import Any, Dict, List, Sequence, Tuple, Union

import chex
import distrax
import hydra
import jax
import jax.numpy as jnp
import numpy as np
Expand Down Expand Up @@ -182,3 +183,16 @@ def __call__(
critic_output = self.critic_head(critic_output)

return critic_hidden_state, critic_output


def chained_torsos(torso_cfgs: List[Dict[str, Any]]) -> nn.Module:
    """Build one network out of an ordered list of torso configs.

    Every entry of ``torso_cfgs`` is passed to ``hydra.utils.instantiate`` and
    the resulting modules are handed, in list order, to ``CompositeNetwork``,
    which chains them together.

    Args:
        torso_cfgs: List of dictionaries containing the configuration for each
            torso. These configs should use the same format as the individual
            torso configs.

    Returns:
        The ``CompositeNetwork`` wrapping the instantiated torsos.
    """
    # Instantiate eagerly so the composite receives a concrete list of modules.
    instantiated = list(map(hydra.utils.instantiate, torso_cfgs))
    return CompositeNetwork(instantiated)

0 comments on commit ce449a2

Please sign in to comment.