diff --git a/CHANGELOG.md b/CHANGELOG.md index 3e14ded4..0e2d634c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). - Model instantiator support for different shared model structures in PyTorch - Support for other model types than Gaussian and Deterministic in runners - Support for automatic mixed precision training in PyTorch +- `init_state_dict` method to initialize model's lazy modules in PyTorch ### Changed - Call agent's `pre_interaction` method during evaluation diff --git a/docs/source/api/utils/model_instantiators.rst b/docs/source/api/utils/model_instantiators.rst index 0380f09d..283e50c0 100644 --- a/docs/source/api/utils/model_instantiators.rst +++ b/docs/source/api/utils/model_instantiators.rst @@ -59,6 +59,13 @@ Implementation details: :start-after: [start-structure-yaml] :end-before: [end-structure-yaml] + .. group-tab:: Python + + .. literalinclude:: ../../snippets/model_instantiators.txt + :language: python + :start-after: [start-structure-python] + :end-before: [end-structure-python] + | Inputs @@ -278,6 +285,42 @@ Apply a linear transformation (:py:class:`torch.nn.Linear` in PyTorch, :py:class :start-after: [start-layer-linear-dict] :end-before: [end-layer-linear-dict] + .. group-tab:: Python + + .. tabs:: + + .. group-tab:: Single value + + .. literalinclude:: ../../snippets/model_instantiators.txt + :language: python + :start-after: [start-layer-linear-basic-python] + :end-before: [end-layer-linear-basic-python] + + .. group-tab:: As int + + .. literalinclude:: ../../snippets/model_instantiators.txt + :language: python + :start-after: [start-layer-linear-int-python] + :end-before: [end-layer-linear-int-python] + + .. group-tab:: As list + + .. literalinclude:: ../../snippets/model_instantiators.txt + :language: python + :start-after: [start-layer-linear-list-python] + :end-before: [end-layer-linear-list-python] + + .. group-tab:: As dict + + .. hint:: + + The parameter names can be interchanged/mixed between PyTorch and JAX + + .. literalinclude:: ../../snippets/model_instantiators.txt + :language: python + :start-after: [start-layer-linear-dict-python] + :end-before: [end-layer-linear-dict-python] + | conv2d @@ -365,6 +408,28 @@ Apply a 2D convolution (:py:class:`torch.nn.Conv2d` in PyTorch, :py:class:`flax. :start-after: [start-layer-conv2d-dict] :end-before: [end-layer-conv2d-dict] + .. group-tab:: Python + + .. tabs:: + + .. group-tab:: As list + + .. literalinclude:: ../../snippets/model_instantiators.txt + :language: python + :start-after: [start-layer-conv2d-list-python] + :end-before: [end-layer-conv2d-list-python] + + .. group-tab:: As dict + + .. hint:: + + The parameter names can be interchanged/mixed between PyTorch and JAX + + .. literalinclude:: ../../snippets/model_instantiators.txt + :language: python + :start-after: [start-layer-conv2d-dict-python] + :end-before: [end-layer-conv2d-dict-python] + | flatten @@ -425,6 +490,35 @@ Flatten a contiguous range of dimensions (:py:class:`torch.nn.Flatten` in PyTorc :start-after: [start-layer-flatten-dict] :end-before: [end-layer-flatten-dict] + .. group-tab:: Python + + .. tabs:: + + .. group-tab:: Single value + + .. literalinclude:: ../../snippets/model_instantiators.txt + :language: python + :start-after: [start-layer-flatten-basic-python] + :end-before: [end-layer-flatten-basic-python] + + .. group-tab:: As list + + .. literalinclude:: ../../snippets/model_instantiators.txt + :language: python + :start-after: [start-layer-flatten-list-python] + :end-before: [end-layer-flatten-list-python] + + .. group-tab:: As dict + + .. hint:: + + The parameter names can be interchanged/mixed between PyTorch and JAX + + .. literalinclude:: ../../snippets/model_instantiators.txt + :language: python + :start-after: [start-layer-flatten-dict-python] + :end-before: [end-layer-flatten-dict-python] + .. raw:: html
diff --git a/docs/source/examples/gym/jax_gym_cartpole_dqn.py b/docs/source/examples/gym/jax_gym_cartpole_dqn.py index 2030ce3e..d487c952 100644 --- a/docs/source/examples/gym/jax_gym_cartpole_dqn.py +++ b/docs/source/examples/gym/jax_gym_cartpole_dqn.py @@ -42,22 +42,24 @@ action_space=env.action_space, device=device, clip_actions=False, - input_shape=Shape.OBSERVATIONS, - hiddens=[64, 64], - hidden_activation=["relu", "relu"], - output_shape=Shape.ACTIONS, - output_activation=None, - output_scale=1.0) + network=[{ + "name": "net", + "input": "STATES", + "layers": [64, 64], + "activations": "relu", + }], + output="ACTIONS") models["target_q_network"] = deterministic_model(observation_space=env.observation_space, action_space=env.action_space, device=device, clip_actions=False, - input_shape=Shape.OBSERVATIONS, - hiddens=[64, 64], - hidden_activation=["relu", "relu"], - output_shape=Shape.ACTIONS, - output_activation=None, - output_scale=1.0) + network=[{ + "name": "net", + "input": "STATES", + "layers": [64, 64], + "activations": "relu", + }], + output="ACTIONS") # instantiate models' state dict for role, model in models.items(): diff --git a/docs/source/examples/gym/jax_gym_cartpole_vector_dqn.py b/docs/source/examples/gym/jax_gym_cartpole_vector_dqn.py index f723dcf0..a2292669 100644 --- a/docs/source/examples/gym/jax_gym_cartpole_vector_dqn.py +++ b/docs/source/examples/gym/jax_gym_cartpole_vector_dqn.py @@ -42,22 +42,24 @@ action_space=env.action_space, device=device, clip_actions=False, - input_shape=Shape.OBSERVATIONS, - hiddens=[64, 64], - hidden_activation=["relu", "relu"], - output_shape=Shape.ACTIONS, - output_activation=None, - output_scale=1.0) + network=[{ + "name": "net", + "input": "STATES", + "layers": [64, 64], + "activations": "relu", + }], + output="ACTIONS") models["target_q_network"] = deterministic_model(observation_space=env.observation_space, action_space=env.action_space, device=device, clip_actions=False, - input_shape=Shape.OBSERVATIONS, - hiddens=[64, 64], - hidden_activation=["relu", "relu"], - output_shape=Shape.ACTIONS, - output_activation=None, - output_scale=1.0) + network=[{ + "name": "net", + "input": "STATES", + "layers": [64, 64], + "activations": "relu", + }], + output="ACTIONS") # instantiate models' state dict for role, model in models.items(): diff --git a/docs/source/examples/gym/torch_gym_cartpole_dqn.py b/docs/source/examples/gym/torch_gym_cartpole_dqn.py index 1e5340c8..74c4b700 100644 --- a/docs/source/examples/gym/torch_gym_cartpole_dqn.py +++ b/docs/source/examples/gym/torch_gym_cartpole_dqn.py @@ -38,22 +38,28 @@ action_space=env.action_space, device=device, clip_actions=False, - input_shape=Shape.OBSERVATIONS, - hiddens=[64, 64], - hidden_activation=["relu", "relu"], - output_shape=Shape.ACTIONS, - output_activation=None, - output_scale=1.0) + network=[{ + "name": "net", + "input": "STATES", + "layers": [64, 64], + "activations": "relu", + }], + output="ACTIONS") models["target_q_network"] = deterministic_model(observation_space=env.observation_space, action_space=env.action_space, device=device, clip_actions=False, - input_shape=Shape.OBSERVATIONS, - hiddens=[64, 64], - hidden_activation=["relu", "relu"], - output_shape=Shape.ACTIONS, - output_activation=None, - output_scale=1.0) + network=[{ + "name": "net", + "input": "STATES", + "layers": [64, 64], + "activations": "relu", + }], + output="ACTIONS") + +# initialize models' lazy modules +for role, model in models.items(): + model.init_state_dict(role) # initialize models' parameters (weights and biases) for model in models.values(): diff --git a/docs/source/examples/gym/torch_gym_cartpole_vector_dqn.py b/docs/source/examples/gym/torch_gym_cartpole_vector_dqn.py index fcbcce25..03304a00 100644 --- a/docs/source/examples/gym/torch_gym_cartpole_vector_dqn.py +++ b/docs/source/examples/gym/torch_gym_cartpole_vector_dqn.py @@ -38,22 +38,28 @@ action_space=env.action_space, device=device, clip_actions=False, - input_shape=Shape.OBSERVATIONS, - hiddens=[64, 64], - hidden_activation=["relu", "relu"], - output_shape=Shape.ACTIONS, - output_activation=None, - output_scale=1.0) + network=[{ + "name": "net", + "input": "STATES", + "layers": [64, 64], + "activations": "relu", + }], + output="ACTIONS") models["target_q_network"] = deterministic_model(observation_space=env.observation_space, action_space=env.action_space, device=device, clip_actions=False, - input_shape=Shape.OBSERVATIONS, - hiddens=[64, 64], - hidden_activation=["relu", "relu"], - output_shape=Shape.ACTIONS, - output_activation=None, - output_scale=1.0) + network=[{ + "name": "net", + "input": "STATES", + "layers": [64, 64], + "activations": "relu", + }], + output="ACTIONS") + +# initialize models' lazy modules +for role, model in models.items(): + model.init_state_dict(role) # initialize models' parameters (weights and biases) for model in models.values(): diff --git a/docs/source/examples/gymnasium/jax_gymnasium_cartpole_dqn.py b/docs/source/examples/gymnasium/jax_gymnasium_cartpole_dqn.py index 473b033d..c537719c 100644 --- a/docs/source/examples/gymnasium/jax_gymnasium_cartpole_dqn.py +++ b/docs/source/examples/gymnasium/jax_gymnasium_cartpole_dqn.py @@ -42,22 +42,24 @@ action_space=env.action_space, device=device, clip_actions=False, - input_shape=Shape.OBSERVATIONS, - hiddens=[64, 64], - hidden_activation=["relu", "relu"], - output_shape=Shape.ACTIONS, - output_activation=None, - output_scale=1.0) + network=[{ + "name": "net", + "input": "STATES", + "layers": [64, 64], + "activations": "relu", + }], + output="ACTIONS") models["target_q_network"] = deterministic_model(observation_space=env.observation_space, action_space=env.action_space, device=device, clip_actions=False, - input_shape=Shape.OBSERVATIONS, - hiddens=[64, 64], - hidden_activation=["relu", "relu"], - output_shape=Shape.ACTIONS, - output_activation=None, - output_scale=1.0) + network=[{ + "name": "net", + "input": "STATES", + "layers": [64, 64], + "activations": "relu", + }], + output="ACTIONS") # instantiate models' state dict for role, model in models.items(): diff --git a/docs/source/examples/gymnasium/jax_gymnasium_cartpole_vector_dqn.py b/docs/source/examples/gymnasium/jax_gymnasium_cartpole_vector_dqn.py index cf90d656..f4738ea7 100644 --- a/docs/source/examples/gymnasium/jax_gymnasium_cartpole_vector_dqn.py +++ b/docs/source/examples/gymnasium/jax_gymnasium_cartpole_vector_dqn.py @@ -42,22 +42,24 @@ action_space=env.action_space, device=device, clip_actions=False, - input_shape=Shape.OBSERVATIONS, - hiddens=[64, 64], - hidden_activation=["relu", "relu"], - output_shape=Shape.ACTIONS, - output_activation=None, - output_scale=1.0) + network=[{ + "name": "net", + "input": "STATES", + "layers": [64, 64], + "activations": "relu", + }], + output="ACTIONS") models["target_q_network"] = deterministic_model(observation_space=env.observation_space, action_space=env.action_space, device=device, clip_actions=False, - input_shape=Shape.OBSERVATIONS, - hiddens=[64, 64], - hidden_activation=["relu", "relu"], - output_shape=Shape.ACTIONS, - output_activation=None, - output_scale=1.0) + network=[{ + "name": "net", + "input": "STATES", + "layers": [64, 64], + "activations": "relu", + }], + output="ACTIONS") # instantiate models' state dict for role, model in models.items(): diff --git a/docs/source/examples/gymnasium/torch_gymnasium_cartpole_dqn.py b/docs/source/examples/gymnasium/torch_gymnasium_cartpole_dqn.py index 10f13151..0ca9ff22 100644 --- a/docs/source/examples/gymnasium/torch_gymnasium_cartpole_dqn.py +++ b/docs/source/examples/gymnasium/torch_gymnasium_cartpole_dqn.py @@ -38,22 +38,28 @@ action_space=env.action_space, device=device, clip_actions=False, - input_shape=Shape.OBSERVATIONS, - hiddens=[64, 64], - hidden_activation=["relu", "relu"], - output_shape=Shape.ACTIONS, - output_activation=None, - output_scale=1.0) + network=[{ + "name": "net", + "input": "STATES", + "layers": [64, 64], + "activations": "relu", + }], + output="ACTIONS") models["target_q_network"] = deterministic_model(observation_space=env.observation_space, action_space=env.action_space, device=device, clip_actions=False, - input_shape=Shape.OBSERVATIONS, - hiddens=[64, 64], - hidden_activation=["relu", "relu"], - output_shape=Shape.ACTIONS, - output_activation=None, - output_scale=1.0) + network=[{ + "name": "net", + "input": "STATES", + "layers": [64, 64], + "activations": "relu", + }], + output="ACTIONS") + +# initialize models' lazy modules +for role, model in models.items(): + model.init_state_dict(role) # initialize models' parameters (weights and biases) for model in models.values(): diff --git a/docs/source/examples/gymnasium/torch_gymnasium_cartpole_vector_dqn.py b/docs/source/examples/gymnasium/torch_gymnasium_cartpole_vector_dqn.py index 63101422..97e5d79b 100644 --- a/docs/source/examples/gymnasium/torch_gymnasium_cartpole_vector_dqn.py +++ b/docs/source/examples/gymnasium/torch_gymnasium_cartpole_vector_dqn.py @@ -38,22 +38,28 @@ action_space=env.action_space, device=device, clip_actions=False, - input_shape=Shape.OBSERVATIONS, - hiddens=[64, 64], - hidden_activation=["relu", "relu"], - output_shape=Shape.ACTIONS, - output_activation=None, - output_scale=1.0) + network=[{ + "name": "net", + "input": "STATES", + "layers": [64, 64], + "activations": "relu", + }], + output="ACTIONS") models["target_q_network"] = deterministic_model(observation_space=env.observation_space, action_space=env.action_space, device=device, clip_actions=False, - input_shape=Shape.OBSERVATIONS, - hiddens=[64, 64], - hidden_activation=["relu", "relu"], - output_shape=Shape.ACTIONS, - output_activation=None, - output_scale=1.0) + network=[{ + "name": "net", + "input": "STATES", + "layers": [64, 64], + "activations": "relu", + }], + output="ACTIONS") + +# initialize models' lazy modules +for role, model in models.items(): + model.init_state_dict(role) # initialize models' parameters (weights and biases) for model in models.values(): diff --git a/docs/source/snippets/model_instantiators.txt b/docs/source/snippets/model_instantiators.txt index 014e29c6..f0f50383 100644 --- a/docs/source/snippets/model_instantiators.txt +++ b/docs/source/snippets/model_instantiators.txt @@ -13,6 +13,25 @@ network: - # [end-structure-yaml] +# [start-structure-python] +network=[ + { + "name": , # container name + "input": , # container input (certain operations are supported) + "layers": [ # list of supported layers + , + ..., + , + ], + "activations": [ # list of supported activation functions + , + ..., + , + ], + }, +] +# [end-structure-python] + # ============================================================================= # [start-layer-linear-basic] @@ -20,24 +39,48 @@ layers: - 32 # [end-layer-linear-basic] +# [start-layer-linear-basic-python] +"layers": [ + 32, +] +# [end-layer-linear-basic-python] + # [start-layer-linear-int] layers: - linear: 32 # [end-layer-linear-int] +# [start-layer-linear-int-python] +"layers": [ + {"linear": 32}, +] +# [end-layer-linear-int-python] + # [start-layer-linear-list] layers: - linear: [32] # [end-layer-linear-list] +# [start-layer-linear-list-python] +"layers": [ + {"linear": [32]}, +] +# [end-layer-linear-list-python] + # [start-layer-linear-dict] layers: - linear: {out_features: 32} # [end-layer-linear-dict] +# [start-layer-linear-dict-python] +"layers": [ + {"linear": {"out_features": 32}}, +] +# [end-layer-linear-dict-python] + # ============================================================================= # [start-layer-conv2d-list] @@ -45,12 +88,24 @@ layers: - conv2d: [32, 8, [4, 4]] # [end-layer-conv2d-list] +# [start-layer-conv2d-list-python] +"layers": [ + {"conv2d": [32, 8, [4, 4]]}, +] +# [end-layer-conv2d-list-python] + # [start-layer-conv2d-dict] layers: - conv2d: {out_channels: 32, kernel_size: 8, stride: [4, 4]} # [end-layer-conv2d-dict] +# [start-layer-conv2d-dict-python] +"layers": [ + {"conv2d": {"out_channels": 32, "kernel_size": 8, "stride": [4, 4]}}, +] +# [end-layer-conv2d-dict-python] + # ============================================================================= # [start-layer-flatten-basic] @@ -58,14 +113,32 @@ layers: - flatten # [end-layer-flatten-basic] +# [start-layer-flatten-basic-python] +"layers": [ + "flatten", +] +# [end-layer-flatten-basic-python] + # [start-layer-flatten-list] layers: - flatten: [1, -1] # [end-layer-flatten-list] +# [start-layer-flatten-list-python] +"layers": [ + {"flatten": [1, -1]}, +] +# [end-layer-flatten-list-python] + # [start-layer-flatten-dict] layers: - flatten: {start_dim: 1, end_dim: -1} # [end-layer-flatten-dict] + +# [start-layer-flatten-dict-python] +"layers": [ + {"flatten": {"start_dim": 1, "end_dim": -1}}, +] +# [end-layer-flatten-dict-python]