diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3e14ded4..0e2d634c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
- Model instantiator support for different shared model structures in PyTorch
- Support for other model types than Gaussian and Deterministic in runners
- Support for automatic mixed precision training in PyTorch
+- `init_state_dict` method to initialize model's lazy modules in PyTorch
### Changed
- Call agent's `pre_interaction` method during evaluation
diff --git a/docs/source/api/utils/model_instantiators.rst b/docs/source/api/utils/model_instantiators.rst
index 0380f09d..283e50c0 100644
--- a/docs/source/api/utils/model_instantiators.rst
+++ b/docs/source/api/utils/model_instantiators.rst
@@ -59,6 +59,13 @@ Implementation details:
:start-after: [start-structure-yaml]
:end-before: [end-structure-yaml]
+ .. group-tab:: Python
+
+ .. literalinclude:: ../../snippets/model_instantiators.txt
+ :language: python
+ :start-after: [start-structure-python]
+ :end-before: [end-structure-python]
+
|
Inputs
@@ -278,6 +285,42 @@ Apply a linear transformation (:py:class:`torch.nn.Linear` in PyTorch, :py:class
:start-after: [start-layer-linear-dict]
:end-before: [end-layer-linear-dict]
+ .. group-tab:: Python
+
+ .. tabs::
+
+ .. group-tab:: Single value
+
+ .. literalinclude:: ../../snippets/model_instantiators.txt
+ :language: python
+ :start-after: [start-layer-linear-basic-python]
+ :end-before: [end-layer-linear-basic-python]
+
+ .. group-tab:: As int
+
+ .. literalinclude:: ../../snippets/model_instantiators.txt
+ :language: python
+ :start-after: [start-layer-linear-int-python]
+ :end-before: [end-layer-linear-int-python]
+
+ .. group-tab:: As list
+
+ .. literalinclude:: ../../snippets/model_instantiators.txt
+ :language: python
+ :start-after: [start-layer-linear-list-python]
+ :end-before: [end-layer-linear-list-python]
+
+ .. group-tab:: As dict
+
+ .. hint::
+
+ The parameter names can be interchanged/mixed between PyTorch and JAX
+
+ .. literalinclude:: ../../snippets/model_instantiators.txt
+ :language: python
+ :start-after: [start-layer-linear-dict-python]
+ :end-before: [end-layer-linear-dict-python]
+
|
conv2d
@@ -365,6 +408,28 @@ Apply a 2D convolution (:py:class:`torch.nn.Conv2d` in PyTorch, :py:class:`flax.
:start-after: [start-layer-conv2d-dict]
:end-before: [end-layer-conv2d-dict]
+ .. group-tab:: Python
+
+ .. tabs::
+
+ .. group-tab:: As list
+
+ .. literalinclude:: ../../snippets/model_instantiators.txt
+ :language: python
+ :start-after: [start-layer-conv2d-list-python]
+ :end-before: [end-layer-conv2d-list-python]
+
+ .. group-tab:: As dict
+
+ .. hint::
+
+ The parameter names can be interchanged/mixed between PyTorch and JAX
+
+ .. literalinclude:: ../../snippets/model_instantiators.txt
+ :language: python
+ :start-after: [start-layer-conv2d-dict-python]
+ :end-before: [end-layer-conv2d-dict-python]
+
|
flatten
@@ -425,6 +490,35 @@ Flatten a contiguous range of dimensions (:py:class:`torch.nn.Flatten` in PyTorc
:start-after: [start-layer-flatten-dict]
:end-before: [end-layer-flatten-dict]
+ .. group-tab:: Python
+
+ .. tabs::
+
+ .. group-tab:: Single value
+
+ .. literalinclude:: ../../snippets/model_instantiators.txt
+ :language: python
+ :start-after: [start-layer-flatten-basic-python]
+ :end-before: [end-layer-flatten-basic-python]
+
+ .. group-tab:: As list
+
+ .. literalinclude:: ../../snippets/model_instantiators.txt
+ :language: python
+ :start-after: [start-layer-flatten-list-python]
+ :end-before: [end-layer-flatten-list-python]
+
+ .. group-tab:: As dict
+
+ .. hint::
+
+ The parameter names can be interchanged/mixed between PyTorch and JAX
+
+ .. literalinclude:: ../../snippets/model_instantiators.txt
+ :language: python
+ :start-after: [start-layer-flatten-dict-python]
+ :end-before: [end-layer-flatten-dict-python]
+
.. raw:: html
diff --git a/docs/source/examples/gym/jax_gym_cartpole_dqn.py b/docs/source/examples/gym/jax_gym_cartpole_dqn.py
index 2030ce3e..d487c952 100644
--- a/docs/source/examples/gym/jax_gym_cartpole_dqn.py
+++ b/docs/source/examples/gym/jax_gym_cartpole_dqn.py
@@ -42,22 +42,24 @@
action_space=env.action_space,
device=device,
clip_actions=False,
- input_shape=Shape.OBSERVATIONS,
- hiddens=[64, 64],
- hidden_activation=["relu", "relu"],
- output_shape=Shape.ACTIONS,
- output_activation=None,
- output_scale=1.0)
+ network=[{
+ "name": "net",
+ "input": "STATES",
+ "layers": [64, 64],
+ "activations": "relu",
+ }],
+ output="ACTIONS")
models["target_q_network"] = deterministic_model(observation_space=env.observation_space,
action_space=env.action_space,
device=device,
clip_actions=False,
- input_shape=Shape.OBSERVATIONS,
- hiddens=[64, 64],
- hidden_activation=["relu", "relu"],
- output_shape=Shape.ACTIONS,
- output_activation=None,
- output_scale=1.0)
+ network=[{
+ "name": "net",
+ "input": "STATES",
+ "layers": [64, 64],
+ "activations": "relu",
+ }],
+ output="ACTIONS")
# instantiate models' state dict
for role, model in models.items():
diff --git a/docs/source/examples/gym/jax_gym_cartpole_vector_dqn.py b/docs/source/examples/gym/jax_gym_cartpole_vector_dqn.py
index f723dcf0..a2292669 100644
--- a/docs/source/examples/gym/jax_gym_cartpole_vector_dqn.py
+++ b/docs/source/examples/gym/jax_gym_cartpole_vector_dqn.py
@@ -42,22 +42,24 @@
action_space=env.action_space,
device=device,
clip_actions=False,
- input_shape=Shape.OBSERVATIONS,
- hiddens=[64, 64],
- hidden_activation=["relu", "relu"],
- output_shape=Shape.ACTIONS,
- output_activation=None,
- output_scale=1.0)
+ network=[{
+ "name": "net",
+ "input": "STATES",
+ "layers": [64, 64],
+ "activations": "relu",
+ }],
+ output="ACTIONS")
models["target_q_network"] = deterministic_model(observation_space=env.observation_space,
action_space=env.action_space,
device=device,
clip_actions=False,
- input_shape=Shape.OBSERVATIONS,
- hiddens=[64, 64],
- hidden_activation=["relu", "relu"],
- output_shape=Shape.ACTIONS,
- output_activation=None,
- output_scale=1.0)
+ network=[{
+ "name": "net",
+ "input": "STATES",
+ "layers": [64, 64],
+ "activations": "relu",
+ }],
+ output="ACTIONS")
# instantiate models' state dict
for role, model in models.items():
diff --git a/docs/source/examples/gym/torch_gym_cartpole_dqn.py b/docs/source/examples/gym/torch_gym_cartpole_dqn.py
index 1e5340c8..74c4b700 100644
--- a/docs/source/examples/gym/torch_gym_cartpole_dqn.py
+++ b/docs/source/examples/gym/torch_gym_cartpole_dqn.py
@@ -38,22 +38,28 @@
action_space=env.action_space,
device=device,
clip_actions=False,
- input_shape=Shape.OBSERVATIONS,
- hiddens=[64, 64],
- hidden_activation=["relu", "relu"],
- output_shape=Shape.ACTIONS,
- output_activation=None,
- output_scale=1.0)
+ network=[{
+ "name": "net",
+ "input": "STATES",
+ "layers": [64, 64],
+ "activations": "relu",
+ }],
+ output="ACTIONS")
models["target_q_network"] = deterministic_model(observation_space=env.observation_space,
action_space=env.action_space,
device=device,
clip_actions=False,
- input_shape=Shape.OBSERVATIONS,
- hiddens=[64, 64],
- hidden_activation=["relu", "relu"],
- output_shape=Shape.ACTIONS,
- output_activation=None,
- output_scale=1.0)
+ network=[{
+ "name": "net",
+ "input": "STATES",
+ "layers": [64, 64],
+ "activations": "relu",
+ }],
+ output="ACTIONS")
+
+# initialize models' lazy modules
+for role, model in models.items():
+ model.init_state_dict(role)
# initialize models' parameters (weights and biases)
for model in models.values():
diff --git a/docs/source/examples/gym/torch_gym_cartpole_vector_dqn.py b/docs/source/examples/gym/torch_gym_cartpole_vector_dqn.py
index fcbcce25..03304a00 100644
--- a/docs/source/examples/gym/torch_gym_cartpole_vector_dqn.py
+++ b/docs/source/examples/gym/torch_gym_cartpole_vector_dqn.py
@@ -38,22 +38,28 @@
action_space=env.action_space,
device=device,
clip_actions=False,
- input_shape=Shape.OBSERVATIONS,
- hiddens=[64, 64],
- hidden_activation=["relu", "relu"],
- output_shape=Shape.ACTIONS,
- output_activation=None,
- output_scale=1.0)
+ network=[{
+ "name": "net",
+ "input": "STATES",
+ "layers": [64, 64],
+ "activations": "relu",
+ }],
+ output="ACTIONS")
models["target_q_network"] = deterministic_model(observation_space=env.observation_space,
action_space=env.action_space,
device=device,
clip_actions=False,
- input_shape=Shape.OBSERVATIONS,
- hiddens=[64, 64],
- hidden_activation=["relu", "relu"],
- output_shape=Shape.ACTIONS,
- output_activation=None,
- output_scale=1.0)
+ network=[{
+ "name": "net",
+ "input": "STATES",
+ "layers": [64, 64],
+ "activations": "relu",
+ }],
+ output="ACTIONS")
+
+# initialize models' lazy modules
+for role, model in models.items():
+ model.init_state_dict(role)
# initialize models' parameters (weights and biases)
for model in models.values():
diff --git a/docs/source/examples/gymnasium/jax_gymnasium_cartpole_dqn.py b/docs/source/examples/gymnasium/jax_gymnasium_cartpole_dqn.py
index 473b033d..c537719c 100644
--- a/docs/source/examples/gymnasium/jax_gymnasium_cartpole_dqn.py
+++ b/docs/source/examples/gymnasium/jax_gymnasium_cartpole_dqn.py
@@ -42,22 +42,24 @@
action_space=env.action_space,
device=device,
clip_actions=False,
- input_shape=Shape.OBSERVATIONS,
- hiddens=[64, 64],
- hidden_activation=["relu", "relu"],
- output_shape=Shape.ACTIONS,
- output_activation=None,
- output_scale=1.0)
+ network=[{
+ "name": "net",
+ "input": "STATES",
+ "layers": [64, 64],
+ "activations": "relu",
+ }],
+ output="ACTIONS")
models["target_q_network"] = deterministic_model(observation_space=env.observation_space,
action_space=env.action_space,
device=device,
clip_actions=False,
- input_shape=Shape.OBSERVATIONS,
- hiddens=[64, 64],
- hidden_activation=["relu", "relu"],
- output_shape=Shape.ACTIONS,
- output_activation=None,
- output_scale=1.0)
+ network=[{
+ "name": "net",
+ "input": "STATES",
+ "layers": [64, 64],
+ "activations": "relu",
+ }],
+ output="ACTIONS")
# instantiate models' state dict
for role, model in models.items():
diff --git a/docs/source/examples/gymnasium/jax_gymnasium_cartpole_vector_dqn.py b/docs/source/examples/gymnasium/jax_gymnasium_cartpole_vector_dqn.py
index cf90d656..f4738ea7 100644
--- a/docs/source/examples/gymnasium/jax_gymnasium_cartpole_vector_dqn.py
+++ b/docs/source/examples/gymnasium/jax_gymnasium_cartpole_vector_dqn.py
@@ -42,22 +42,24 @@
action_space=env.action_space,
device=device,
clip_actions=False,
- input_shape=Shape.OBSERVATIONS,
- hiddens=[64, 64],
- hidden_activation=["relu", "relu"],
- output_shape=Shape.ACTIONS,
- output_activation=None,
- output_scale=1.0)
+ network=[{
+ "name": "net",
+ "input": "STATES",
+ "layers": [64, 64],
+ "activations": "relu",
+ }],
+ output="ACTIONS")
models["target_q_network"] = deterministic_model(observation_space=env.observation_space,
action_space=env.action_space,
device=device,
clip_actions=False,
- input_shape=Shape.OBSERVATIONS,
- hiddens=[64, 64],
- hidden_activation=["relu", "relu"],
- output_shape=Shape.ACTIONS,
- output_activation=None,
- output_scale=1.0)
+ network=[{
+ "name": "net",
+ "input": "STATES",
+ "layers": [64, 64],
+ "activations": "relu",
+ }],
+ output="ACTIONS")
# instantiate models' state dict
for role, model in models.items():
diff --git a/docs/source/examples/gymnasium/torch_gymnasium_cartpole_dqn.py b/docs/source/examples/gymnasium/torch_gymnasium_cartpole_dqn.py
index 10f13151..0ca9ff22 100644
--- a/docs/source/examples/gymnasium/torch_gymnasium_cartpole_dqn.py
+++ b/docs/source/examples/gymnasium/torch_gymnasium_cartpole_dqn.py
@@ -38,22 +38,28 @@
action_space=env.action_space,
device=device,
clip_actions=False,
- input_shape=Shape.OBSERVATIONS,
- hiddens=[64, 64],
- hidden_activation=["relu", "relu"],
- output_shape=Shape.ACTIONS,
- output_activation=None,
- output_scale=1.0)
+ network=[{
+ "name": "net",
+ "input": "STATES",
+ "layers": [64, 64],
+ "activations": "relu",
+ }],
+ output="ACTIONS")
models["target_q_network"] = deterministic_model(observation_space=env.observation_space,
action_space=env.action_space,
device=device,
clip_actions=False,
- input_shape=Shape.OBSERVATIONS,
- hiddens=[64, 64],
- hidden_activation=["relu", "relu"],
- output_shape=Shape.ACTIONS,
- output_activation=None,
- output_scale=1.0)
+ network=[{
+ "name": "net",
+ "input": "STATES",
+ "layers": [64, 64],
+ "activations": "relu",
+ }],
+ output="ACTIONS")
+
+# initialize models' lazy modules
+for role, model in models.items():
+ model.init_state_dict(role)
# initialize models' parameters (weights and biases)
for model in models.values():
diff --git a/docs/source/examples/gymnasium/torch_gymnasium_cartpole_vector_dqn.py b/docs/source/examples/gymnasium/torch_gymnasium_cartpole_vector_dqn.py
index 63101422..97e5d79b 100644
--- a/docs/source/examples/gymnasium/torch_gymnasium_cartpole_vector_dqn.py
+++ b/docs/source/examples/gymnasium/torch_gymnasium_cartpole_vector_dqn.py
@@ -38,22 +38,28 @@
action_space=env.action_space,
device=device,
clip_actions=False,
- input_shape=Shape.OBSERVATIONS,
- hiddens=[64, 64],
- hidden_activation=["relu", "relu"],
- output_shape=Shape.ACTIONS,
- output_activation=None,
- output_scale=1.0)
+ network=[{
+ "name": "net",
+ "input": "STATES",
+ "layers": [64, 64],
+ "activations": "relu",
+ }],
+ output="ACTIONS")
models["target_q_network"] = deterministic_model(observation_space=env.observation_space,
action_space=env.action_space,
device=device,
clip_actions=False,
- input_shape=Shape.OBSERVATIONS,
- hiddens=[64, 64],
- hidden_activation=["relu", "relu"],
- output_shape=Shape.ACTIONS,
- output_activation=None,
- output_scale=1.0)
+ network=[{
+ "name": "net",
+ "input": "STATES",
+ "layers": [64, 64],
+ "activations": "relu",
+ }],
+ output="ACTIONS")
+
+# initialize models' lazy modules
+for role, model in models.items():
+ model.init_state_dict(role)
# initialize models' parameters (weights and biases)
for model in models.values():
diff --git a/docs/source/snippets/model_instantiators.txt b/docs/source/snippets/model_instantiators.txt
index 014e29c6..f0f50383 100644
--- a/docs/source/snippets/model_instantiators.txt
+++ b/docs/source/snippets/model_instantiators.txt
@@ -13,6 +13,25 @@ network:
-
# [end-structure-yaml]
+# [start-structure-python]
+network=[
+ {
+ "name": , # container name
+ "input": , # container input (certain operations are supported)
+ "layers": [ # list of supported layers
+ ,
+ ...,
+ ,
+ ],
+ "activations": [ # list of supported activation functions
+ ,
+ ...,
+ ,
+ ],
+ },
+]
+# [end-structure-python]
+
# =============================================================================
# [start-layer-linear-basic]
@@ -20,24 +39,48 @@ layers:
- 32
# [end-layer-linear-basic]
+# [start-layer-linear-basic-python]
+"layers": [
+ 32,
+]
+# [end-layer-linear-basic-python]
+
# [start-layer-linear-int]
layers:
- linear: 32
# [end-layer-linear-int]
+# [start-layer-linear-int-python]
+"layers": [
+ {"linear": 32},
+]
+# [end-layer-linear-int-python]
+
# [start-layer-linear-list]
layers:
- linear: [32]
# [end-layer-linear-list]
+# [start-layer-linear-list-python]
+"layers": [
+ {"linear": [32]},
+]
+# [end-layer-linear-list-python]
+
# [start-layer-linear-dict]
layers:
- linear: {out_features: 32}
# [end-layer-linear-dict]
+# [start-layer-linear-dict-python]
+"layers": [
+ {"linear": {"out_features": 32}},
+]
+# [end-layer-linear-dict-python]
+
# =============================================================================
# [start-layer-conv2d-list]
@@ -45,12 +88,24 @@ layers:
- conv2d: [32, 8, [4, 4]]
# [end-layer-conv2d-list]
+# [start-layer-conv2d-list-python]
+"layers": [
+ {"conv2d": [32, 8, [4, 4]]},
+]
+# [end-layer-conv2d-list-python]
+
# [start-layer-conv2d-dict]
layers:
- conv2d: {out_channels: 32, kernel_size: 8, stride: [4, 4]}
# [end-layer-conv2d-dict]
+# [start-layer-conv2d-dict-python]
+"layers": [
+ {"conv2d": {"out_channels": 32, "kernel_size": 8, "stride": [4, 4]}},
+]
+# [end-layer-conv2d-dict-python]
+
# =============================================================================
# [start-layer-flatten-basic]
@@ -58,14 +113,32 @@ layers:
- flatten
# [end-layer-flatten-basic]
+# [start-layer-flatten-basic-python]
+"layers": [
+ "flatten",
+]
+# [end-layer-flatten-basic-python]
+
# [start-layer-flatten-list]
layers:
- flatten: [1, -1]
# [end-layer-flatten-list]
+# [start-layer-flatten-list-python]
+"layers": [
+ {"flatten": [1, -1]},
+]
+# [end-layer-flatten-list-python]
+
# [start-layer-flatten-dict]
layers:
- flatten: {start_dim: 1, end_dim: -1}
# [end-layer-flatten-dict]
+
+# [start-layer-flatten-dict-python]
+"layers": [
+ {"flatten": {"start_dim": 1, "end_dim": -1}},
+]
+# [end-layer-flatten-dict-python]