Edan toledo/issue112 #117

Merged: 4 commits, Sep 17, 2024
Changes from 3 commits
4 changes: 1 addition & 3 deletions README.md
@@ -199,11 +199,9 @@ If you use Stoix in your work, please cite us:
```bibtex
@software{toledo2024stoix,
author = {Toledo, Edan},
-doi = {10.5281/zenodo.10916258},
-month = apr,
+doi = {10.5281/zenodo.10916257},
title = {{Stoix: Distributed Single-Agent Reinforcement Learning End-to-End in JAX}},
url = {https://github.com/EdanToledo/Stoix},
-version = {v0.0.1},
year = {2024}
}
```
40 changes: 40 additions & 0 deletions stoix/configs/network/chained_torsos.yaml
@@ -0,0 +1,40 @@
# ---Example of chaining arbitrary torsos---
actor_network:
  pre_torso:
    _target_: stoix.networks.base.chained_torsos # we call the chained torsos creation function
    _recursive_: false # we disable recursive instantiation for this object as we do the creation manually
    torso_cfgs:

      - _target_: stoix.networks.resnet.ResNetTorso
        hidden_units_per_group: [64, 64]
        blocks_per_group: [1, 1]
        use_layer_norm: False
        activation: silu

      - _target_: stoix.networks.torso.MLPTorso
        layer_sizes: [64, 64]
        use_layer_norm: False
        activation: relu

  action_head:
    _target_: stoix.networks.heads.CategoricalHead

critic_network:
  pre_torso:
    _target_: stoix.networks.base.chained_torsos # we call the chained torsos creation function
    _recursive_: false # we disable recursive instantiation for this object as we do the creation manually
    torso_cfgs:

      - _target_: stoix.networks.resnet.ResNetTorso
        hidden_units_per_group: [64, 64]
        blocks_per_group: [1, 1]
        use_layer_norm: False
        activation: silu

      - _target_: stoix.networks.torso.MLPTorso
        layer_sizes: [64, 64]
        use_layer_norm: False
        activation: relu

  critic_head:
    _target_: stoix.networks.heads.ScalarCriticHead
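The `_recursive_: false` comment in this config can be illustrated with a toy model. The sketch below is not Hydra's implementation; it assumes Hydra's documented default of instantiating nested `_target_` entries before calling the parent target, and uses a hypothetical name `registry` in place of import-path resolution. `toy_instantiate`, `Layer`, and `describe` are made up for illustration.

```python
from typing import Any, Callable, Dict, List


def toy_instantiate(cfg: Any, registry: Dict[str, Callable[..., Any]]) -> Any:
    """Toy model of config instantiation (an assumption, not Hydra's code)."""
    if isinstance(cfg, list):
        return [toy_instantiate(item, registry) for item in cfg]
    if not isinstance(cfg, dict):
        return cfg  # plain values pass through unchanged
    if "_target_" not in cfg:
        return {k: toy_instantiate(v, registry) for k, v in cfg.items()}

    recursive = cfg.get("_recursive_", True)
    kwargs = {}
    for key, value in cfg.items():
        if key in ("_target_", "_recursive_"):
            continue
        # With recursion on, children are built first; with it off,
        # the target receives the raw child configs.
        kwargs[key] = toy_instantiate(value, registry) if recursive else value
    return registry[cfg["_target_"]](**kwargs)


class Layer:
    """Hypothetical stand-in for a torso module."""

    def __init__(self, size: int) -> None:
        self.size = size


def describe(torso_cfgs: List[Any]) -> List[str]:
    """Hypothetical target that reports what it was handed."""
    return [type(item).__name__ for item in torso_cfgs]


registry = {"Layer": Layer, "describe": describe}
cfg = {"_target_": "describe", "torso_cfgs": [{"_target_": "Layer", "size": 64}]}

eager = toy_instantiate(cfg, registry)
lazy = toy_instantiate({**cfg, "_recursive_": False}, registry)
```

In this toy model, `eager` sees already-built `Layer` objects while `lazy` sees plain dicts, which is the situation a function that instantiates its own children expects.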
16 changes: 15 additions & 1 deletion stoix/networks/base.py
@@ -1,8 +1,9 @@
import functools
-from typing import Sequence, Tuple, Union
+from typing import Any, Dict, List, Sequence, Tuple, Union

import chex
import distrax
+import hydra
import jax
import jax.numpy as jnp
import numpy as np
@@ -182,3 +183,16 @@ def __call__(
        critic_output = self.critic_head(critic_output)

        return critic_hidden_state, critic_output
+
+
+def chained_torsos(torso_cfgs: List[Dict[str, Any]]) -> nn.Module:
+    """Create a network by chaining multiple torsos together using a list of configs.
+    This makes use of hydra to instantiate the modules and the composite network
+    to chain them together.
+
+    Args:
+        torso_cfgs: List of dictionaries containing the configuration for each torso.
+            These configs should use the same format as the individual torso configs."""
+
+    torso_modules = [hydra.utils.instantiate(torso_cfg) for torso_cfg in torso_cfgs]
+    return CompositeNetwork(torso_modules)
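The pattern in `chained_torsos` (build each stage from its own config, then apply the stages in list order) can be sketched without Hydra or Flax. Everything named here is hypothetical: `Scale`, `Shift`, `REGISTRY`, `Chain`, and `chained` are illustration only. A name registry stands in for Hydra's import-path resolution, and `Chain` stands in for `CompositeNetwork`, whose implementation is not shown in this diff.

```python
from typing import Any, Callable, Dict, List


class Scale:
    """Hypothetical stage: multiplies its input by `factor`."""

    def __init__(self, factor: float) -> None:
        self.factor = factor

    def __call__(self, x: float) -> float:
        return x * self.factor


class Shift:
    """Hypothetical stage: adds `offset` to its input."""

    def __init__(self, offset: float) -> None:
        self.offset = offset

    def __call__(self, x: float) -> float:
        return x + self.offset


# Assumption: a simple lookup table replaces dotted-path `_target_` resolution.
REGISTRY: Dict[str, Callable[..., Any]] = {"Scale": Scale, "Shift": Shift}


class Chain:
    """Hypothetical sequential container: feeds each stage's output to the next."""

    def __init__(self, stages: List[Callable[[float], float]]) -> None:
        self.stages = list(stages)

    def __call__(self, x: float) -> float:
        for stage in self.stages:
            x = stage(x)
        return x


def chained(stage_cfgs: List[Dict[str, Any]]) -> Chain:
    """Build one stage per config dict and chain them in list order."""
    stages = []
    for cfg in stage_cfgs:
        kwargs = {k: v for k, v in cfg.items() if k != "_target_"}
        stages.append(REGISTRY[cfg["_target_"]](**kwargs))
    return Chain(stages)


cfgs = [
    {"_target_": "Scale", "factor": 2.0},
    {"_target_": "Shift", "offset": 1.0},
]
result = chained(cfgs)(3.0)  # (3.0 * 2.0) + 1.0
reversed_result = chained(cfgs[::-1])(3.0)  # (3.0 + 1.0) * 2.0
```

The two results differ because list order is application order, which is why the YAML lists the ResNet torso before the MLP torso.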