Skip to content

Commit

Permalink
Update examples that use model instantiators to the latest API and In…
Browse files Browse the repository at this point in the history
…itialize models' lazy modules (#247)
  • Loading branch information
Toni-SM authored Jan 7, 2025
1 parent 65da82b commit deb28ff
Show file tree
Hide file tree
Showing 11 changed files with 296 additions and 96 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
- Model instantiator support for different shared model structures in PyTorch
- Support for other model types than Gaussian and Deterministic in runners
- Support for automatic mixed precision training in PyTorch
- `init_state_dict` method to initialize model's lazy modules in PyTorch

### Changed
- Call agent's `pre_interaction` method during evaluation
Expand Down
94 changes: 94 additions & 0 deletions docs/source/api/utils/model_instantiators.rst
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,13 @@ Implementation details:
:start-after: [start-structure-yaml]
:end-before: [end-structure-yaml]

.. group-tab:: Python

.. literalinclude:: ../../snippets/model_instantiators.txt
:language: python
:start-after: [start-structure-python]
:end-before: [end-structure-python]

|
Inputs
Expand Down Expand Up @@ -278,6 +285,42 @@ Apply a linear transformation (:py:class:`torch.nn.Linear` in PyTorch, :py:class
:start-after: [start-layer-linear-dict]
:end-before: [end-layer-linear-dict]

.. group-tab:: Python

.. tabs::

.. group-tab:: Single value

.. literalinclude:: ../../snippets/model_instantiators.txt
:language: python
:start-after: [start-layer-linear-basic-python]
:end-before: [end-layer-linear-basic-python]

.. group-tab:: As int

.. literalinclude:: ../../snippets/model_instantiators.txt
:language: python
:start-after: [start-layer-linear-int-python]
:end-before: [end-layer-linear-int-python]

.. group-tab:: As list

.. literalinclude:: ../../snippets/model_instantiators.txt
:language: python
:start-after: [start-layer-linear-list-python]
:end-before: [end-layer-linear-list-python]

.. group-tab:: As dict

.. hint::

The parameter names can be interchanged/mixed between PyTorch and JAX

.. literalinclude:: ../../snippets/model_instantiators.txt
:language: python
:start-after: [start-layer-linear-dict-python]
:end-before: [end-layer-linear-dict-python]

|
conv2d
Expand Down Expand Up @@ -365,6 +408,28 @@ Apply a 2D convolution (:py:class:`torch.nn.Conv2d` in PyTorch, :py:class:`flax.
:start-after: [start-layer-conv2d-dict]
:end-before: [end-layer-conv2d-dict]

.. group-tab:: Python

.. tabs::

.. group-tab:: As list

.. literalinclude:: ../../snippets/model_instantiators.txt
:language: python
:start-after: [start-layer-conv2d-list-python]
:end-before: [end-layer-conv2d-list-python]

.. group-tab:: As dict

.. hint::

The parameter names can be interchanged/mixed between PyTorch and JAX

.. literalinclude:: ../../snippets/model_instantiators.txt
:language: python
:start-after: [start-layer-conv2d-dict-python]
:end-before: [end-layer-conv2d-dict-python]

|
flatten
Expand Down Expand Up @@ -425,6 +490,35 @@ Flatten a contiguous range of dimensions (:py:class:`torch.nn.Flatten` in PyTorc
:start-after: [start-layer-flatten-dict]
:end-before: [end-layer-flatten-dict]

.. group-tab:: Python

.. tabs::

.. group-tab:: Single value

.. literalinclude:: ../../snippets/model_instantiators.txt
:language: python
:start-after: [start-layer-flatten-basic-python]
:end-before: [end-layer-flatten-basic-python]

.. group-tab:: As list

.. literalinclude:: ../../snippets/model_instantiators.txt
:language: python
:start-after: [start-layer-flatten-list-python]
:end-before: [end-layer-flatten-list-python]

.. group-tab:: As dict

.. hint::

The parameter names can be interchanged/mixed between PyTorch and JAX

.. literalinclude:: ../../snippets/model_instantiators.txt
:language: python
:start-after: [start-layer-flatten-dict-python]
:end-before: [end-layer-flatten-dict-python]

.. raw:: html

<br>
Expand Down
26 changes: 14 additions & 12 deletions docs/source/examples/gym/jax_gym_cartpole_dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,22 +42,24 @@
action_space=env.action_space,
device=device,
clip_actions=False,
input_shape=Shape.OBSERVATIONS,
hiddens=[64, 64],
hidden_activation=["relu", "relu"],
output_shape=Shape.ACTIONS,
output_activation=None,
output_scale=1.0)
network=[{
"name": "net",
"input": "STATES",
"layers": [64, 64],
"activations": "relu",
}],
output="ACTIONS")
models["target_q_network"] = deterministic_model(observation_space=env.observation_space,
action_space=env.action_space,
device=device,
clip_actions=False,
input_shape=Shape.OBSERVATIONS,
hiddens=[64, 64],
hidden_activation=["relu", "relu"],
output_shape=Shape.ACTIONS,
output_activation=None,
output_scale=1.0)
network=[{
"name": "net",
"input": "STATES",
"layers": [64, 64],
"activations": "relu",
}],
output="ACTIONS")

# instantiate models' state dict
for role, model in models.items():
Expand Down
26 changes: 14 additions & 12 deletions docs/source/examples/gym/jax_gym_cartpole_vector_dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,22 +42,24 @@
action_space=env.action_space,
device=device,
clip_actions=False,
input_shape=Shape.OBSERVATIONS,
hiddens=[64, 64],
hidden_activation=["relu", "relu"],
output_shape=Shape.ACTIONS,
output_activation=None,
output_scale=1.0)
network=[{
"name": "net",
"input": "STATES",
"layers": [64, 64],
"activations": "relu",
}],
output="ACTIONS")
models["target_q_network"] = deterministic_model(observation_space=env.observation_space,
action_space=env.action_space,
device=device,
clip_actions=False,
input_shape=Shape.OBSERVATIONS,
hiddens=[64, 64],
hidden_activation=["relu", "relu"],
output_shape=Shape.ACTIONS,
output_activation=None,
output_scale=1.0)
network=[{
"name": "net",
"input": "STATES",
"layers": [64, 64],
"activations": "relu",
}],
output="ACTIONS")

# instantiate models' state dict
for role, model in models.items():
Expand Down
30 changes: 18 additions & 12 deletions docs/source/examples/gym/torch_gym_cartpole_dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,22 +38,28 @@
action_space=env.action_space,
device=device,
clip_actions=False,
input_shape=Shape.OBSERVATIONS,
hiddens=[64, 64],
hidden_activation=["relu", "relu"],
output_shape=Shape.ACTIONS,
output_activation=None,
output_scale=1.0)
network=[{
"name": "net",
"input": "STATES",
"layers": [64, 64],
"activations": "relu",
}],
output="ACTIONS")
models["target_q_network"] = deterministic_model(observation_space=env.observation_space,
action_space=env.action_space,
device=device,
clip_actions=False,
input_shape=Shape.OBSERVATIONS,
hiddens=[64, 64],
hidden_activation=["relu", "relu"],
output_shape=Shape.ACTIONS,
output_activation=None,
output_scale=1.0)
network=[{
"name": "net",
"input": "STATES",
"layers": [64, 64],
"activations": "relu",
}],
output="ACTIONS")

# initialize models' lazy modules
for role, model in models.items():
model.init_state_dict(role)

# initialize models' parameters (weights and biases)
for model in models.values():
Expand Down
30 changes: 18 additions & 12 deletions docs/source/examples/gym/torch_gym_cartpole_vector_dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,22 +38,28 @@
action_space=env.action_space,
device=device,
clip_actions=False,
input_shape=Shape.OBSERVATIONS,
hiddens=[64, 64],
hidden_activation=["relu", "relu"],
output_shape=Shape.ACTIONS,
output_activation=None,
output_scale=1.0)
network=[{
"name": "net",
"input": "STATES",
"layers": [64, 64],
"activations": "relu",
}],
output="ACTIONS")
models["target_q_network"] = deterministic_model(observation_space=env.observation_space,
action_space=env.action_space,
device=device,
clip_actions=False,
input_shape=Shape.OBSERVATIONS,
hiddens=[64, 64],
hidden_activation=["relu", "relu"],
output_shape=Shape.ACTIONS,
output_activation=None,
output_scale=1.0)
network=[{
"name": "net",
"input": "STATES",
"layers": [64, 64],
"activations": "relu",
}],
output="ACTIONS")

# initialize models' lazy modules
for role, model in models.items():
model.init_state_dict(role)

# initialize models' parameters (weights and biases)
for model in models.values():
Expand Down
26 changes: 14 additions & 12 deletions docs/source/examples/gymnasium/jax_gymnasium_cartpole_dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,22 +42,24 @@
action_space=env.action_space,
device=device,
clip_actions=False,
input_shape=Shape.OBSERVATIONS,
hiddens=[64, 64],
hidden_activation=["relu", "relu"],
output_shape=Shape.ACTIONS,
output_activation=None,
output_scale=1.0)
network=[{
"name": "net",
"input": "STATES",
"layers": [64, 64],
"activations": "relu",
}],
output="ACTIONS")
models["target_q_network"] = deterministic_model(observation_space=env.observation_space,
action_space=env.action_space,
device=device,
clip_actions=False,
input_shape=Shape.OBSERVATIONS,
hiddens=[64, 64],
hidden_activation=["relu", "relu"],
output_shape=Shape.ACTIONS,
output_activation=None,
output_scale=1.0)
network=[{
"name": "net",
"input": "STATES",
"layers": [64, 64],
"activations": "relu",
}],
output="ACTIONS")

# instantiate models' state dict
for role, model in models.items():
Expand Down
26 changes: 14 additions & 12 deletions docs/source/examples/gymnasium/jax_gymnasium_cartpole_vector_dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,22 +42,24 @@
action_space=env.action_space,
device=device,
clip_actions=False,
input_shape=Shape.OBSERVATIONS,
hiddens=[64, 64],
hidden_activation=["relu", "relu"],
output_shape=Shape.ACTIONS,
output_activation=None,
output_scale=1.0)
network=[{
"name": "net",
"input": "STATES",
"layers": [64, 64],
"activations": "relu",
}],
output="ACTIONS")
models["target_q_network"] = deterministic_model(observation_space=env.observation_space,
action_space=env.action_space,
device=device,
clip_actions=False,
input_shape=Shape.OBSERVATIONS,
hiddens=[64, 64],
hidden_activation=["relu", "relu"],
output_shape=Shape.ACTIONS,
output_activation=None,
output_scale=1.0)
network=[{
"name": "net",
"input": "STATES",
"layers": [64, 64],
"activations": "relu",
}],
output="ACTIONS")

# instantiate models' state dict
for role, model in models.items():
Expand Down
Loading

0 comments on commit deb28ff

Please sign in to comment.