Skip to content

Commit

Permalink
feat: add function to create arbitrary chained torsos and an example …
Browse files Browse the repository at this point in the history
…config
  • Loading branch information
EdanToledo committed Sep 16, 2024
1 parent 8cd990e commit 13dc19f
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 1 deletion.
40 changes: 40 additions & 0 deletions stoix/configs/network/chained_torsos.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# ---Example of chaining arbitrary torsos---
actor_network:
pre_torso:
_target_: stoix.networks.base.chained_torsos # we call the chained torsos creation function
_recursive_: false # we disable recursive instantiation for this object as we do the creation manually
torso_cfgs:

- _target_: stoix.networks.resnet.ResNetTorso
hidden_units_per_group: [64, 64]
blocks_per_group: [1, 1]
use_layer_norm: False
activation: silu

- _target_: stoix.networks.torso.MLPTorso
layer_sizes: [64, 64]
use_layer_norm: False
activation: relu

action_head:
_target_: stoix.networks.heads.CategoricalHead

critic_network:
pre_torso:
_target_: stoix.networks.base.chained_torsos # we call the chained torsos creation function
_recursive_: false # we disable recursive instantiation for this object as we do the creation manually
torso_cfgs:

- _target_: stoix.networks.resnet.ResNetTorso
hidden_units_per_group: [64, 64]
blocks_per_group: [1, 1]
use_layer_norm: False
activation: silu

- _target_: stoix.networks.torso.MLPTorso
layer_sizes: [64, 64]
use_layer_norm: False
activation: relu

critic_head:
_target_: stoix.networks.heads.ScalarCriticHead
16 changes: 15 additions & 1 deletion stoix/networks/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import functools
from typing import Sequence, Tuple, Union
from typing import Any, Dict, List, Sequence, Tuple, Union

import chex
import distrax
import hydra
import jax
import jax.numpy as jnp
import numpy as np
Expand Down Expand Up @@ -182,3 +183,16 @@ def __call__(
critic_output = self.critic_head(critic_output)

return critic_hidden_state, critic_output


def chained_torsos(torso_cfgs: List[Dict[str, Any]]) -> nn.Module:
"""Create a network by chaining multiple torsos together using a list of configs.
This makes use of hydra to instantiate the modules and the composite network
to chain them together.
Args:
torso_cfgs: List of dictionaries containing the configuration for each torso.
These configs should use the same format as the individual torso configs."""

torso_modules = [hydra.utils.instantiate(torso_cfg) for torso_cfg in torso_cfgs]
return CompositeNetwork(torso_modules)

0 comments on commit 13dc19f

Please sign in to comment.