Skip to content

Commit

Permalink
[Doc] Fix modules doc (#2531)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Oct 31, 2024
1 parent 6799a7f commit edbf3de
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 160 deletions.
211 changes: 52 additions & 159 deletions docs/source/reference/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,7 @@ Some algorithms such as PPO require a probabilistic policy to be implemented.
In TorchRL, these policies take the form of a model, followed by a distribution
constructor.

.. note::
The choice of a probabilistic or regular actor class depends on the algorithm
.. note:: The choice of a probabilistic or regular actor class depends on the algorithm
that is being implemented. On-policy algorithms usually require a probabilistic
actor, off-policy usually have a deterministic actor with an extra exploration
strategy. There are, however, many exceptions to this rule.
Expand All @@ -103,8 +102,12 @@ and outputs the parameters of a distribution, while the distribution constructor
reads these parameters and gets a random sample from the distribution and/or
provides a :class:`torch.distributions.Distribution` object.

>>> from tensordict.nn import NormalParamExtractor, TensorDictSequential
>>> from tensordict.nn import NormalParamExtractor, TensorDictSequential, TensorDictModule
>>> from torchrl.modules import SafeProbabilisticModule
>>> from torchrl.envs import GymEnv
>>> from torch.distributions import Normal
>>> from torch import nn
>>>
>>> env = GymEnv("Pendulum-v1")
>>> action_spec = env.action_spec
>>> model = nn.Sequential(nn.LazyLinear(action_spec.shape[-1] * 2), NormalParamExtractor())
Expand All @@ -125,6 +128,7 @@ provides a :class:`torch.distributions.Distribution` object.
To facilitate the construction of probabilistic policies, we provide a dedicated
:class:`~torchrl.modules.tensordict_module.ProbabilisticActor`:

>>> from torchrl.modules import ProbabilisticActor
>>> policy = ProbabilisticActor(
... model,
... in_keys=["loc", "scale"],
Expand Down Expand Up @@ -154,69 +158,31 @@ of this action.
Q-Value actors
~~~~~~~~~~~~~~

Q-Value actors are a special type of policy that does not directly predict an action
from an observation, but picks the action that maximised the value (or *quality*)
of a (s,a) -> v map. This map can be a table or a function.
For discrete action spaces with continuous (or near-continuous such as pixels)
states, it is customary to use a non-linear model such as a neural network for
the map.
The semantic of the Q-Value network is hopefully quite simple: we just need to
feed a tensor-to-tensor map that given a certain state (the input tensor),
outputs a list of action values to choose from. The wrapper will write the
resulting action in the input tensordict along with the list of action values.
Q-Value actors are a type of policy that selects actions based on the maximum value
(or "quality") of a state-action pair. This value can be represented as a table or a
function. For discrete action spaces with continuous states, it's common to use a non-linear
model like a neural network to represent this function.

>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn.functional_modules import make_functional
>>> from torch import nn
>>> from torchrl.data import OneHot
>>> from torchrl.modules.tensordict_module.actors import QValueActor
>>> td = TensorDict({'observation': torch.randn(5, 3)}, [5])
>>> # we have 4 actions to choose from
>>> action_spec = OneHot(4)
>>> # the model reads a state of dimension 3 and outputs 4 values, one for each action available
>>> module = nn.Linear(3, 4)
>>> qvalue_actor = QValueActor(module=module, spec=action_spec)
>>> qvalue_actor(td)
>>> print(td)
TensorDict(
fields={
action: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False),
action_value: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False),
chosen_action_value: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False),
observation: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([5]),
device=None,
is_shared=False)
QValueActor
^^^^^^^^^^^

Distributional Q-learning is slightly different: in this case, the value network
does not output a scalar value for each state-action value.
Instead, the value space is divided in a an arbitrary number of "bins". The
value network outputs a probability that the state-action value belongs to one bin
or another.
Hence, for a state space of dimension M, an action space of dimension N and a number of bins B,
the value network encodes a
of a (s,a) -> v map. This map can be a table or a function.
For discrete action spaces with continuous (or near-continuous such as pixels)
states, it is customary to use a non-linear model such as a neural network for
the map.
The semantic of the Q-Value network is hopefully quite simple: we just need to
feed a tensor-to-tensor map that given a certain state (the input tensor),
outputs a list of action values to choose from. The wrapper will write the
resulting action in the input tensordict along with the list of action values.
The :class:`~torchrl.modules.QValueActor` class takes in a module and an action
specification, and outputs the selected action and its corresponding value.

>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn.functional_modules import make_functional
>>> from torch import nn
>>> from torchrl.data import OneHot
>>> from torchrl.modules.tensordict_module.actors import QValueActor
>>> # Create a tensor dict with an observation
>>> td = TensorDict({'observation': torch.randn(5, 3)}, [5])
>>> # we have 4 actions to choose from
>>> # Define the action space
>>> action_spec = OneHot(4)
>>> # the model reads a state of dimension 3 and outputs 4 values, one for each action available
>>> # Create a linear module to output action values
>>> module = nn.Linear(3, 4)
>>> # Create a QValueActor instance
>>> qvalue_actor = QValueActor(module=module, spec=action_spec)
>>> # Run the actor on the tensor dict
>>> qvalue_actor(td)
>>> print(td)
TensorDict(
Expand All @@ -229,122 +195,48 @@ resulting action in the input tensordict along with the list of action values.
device=None,
is_shared=False)

Distributional Q-learning is slightly different: in this case, the value network
does not output a scalar value for each state-action value.
Instead, the value space is divided in a an arbitrary number of "bins". The
value network outputs a probability that the state-action value belongs to one bin
or another.
Hence, for a state space of dimension M, an action space of dimension N and a number of bins B,
the value network encodes a
of a (s,a) -> v map. This map can be a table or a function.
For discrete action spaces with continuous (or near-continuous such as pixels)
states, it is customary to use a non-linear model such as a neural network for
the map.
The semantic of the Q-Value network is hopefully quite simple: we just need to
feed a tensor-to-tensor map that given a certain state (the input tensor),
outputs a list of action values to choose from. The wrapper will write the
resulting action in the input tensordict along with the list of action values.
This will output a tensor dict with the selected action and its corresponding value.

Distributional Q-Learning
^^^^^^^^^^^^^^^^^^^^^^^^^

Distributional Q-learning is a variant of Q-learning that represents the value function as a
probability distribution over possible values, rather than a single scalar value.
This allows the agent to learn about the uncertainty in the environment and make more informed
decisions.
In TorchRL, distributional Q-learning is implemented using the :class:`~torchrl.modules.DistributionalQValueActor`
class. This class takes in a module, an action specification, and a support vector, and outputs the selected
action and its corresponding value distribution.


>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn.functional_modules import make_functional
>>> from torch import nn
>>> from torchrl.data import OneHot
>>> from torchrl.modules.tensordict_module.actors import QValueActor
>>> td = TensorDict({'observation': torch.randn(5, 3)}, [5])
>>> # we have 4 actions to choose from
>>> from torchrl.modules import DistributionalQValueActor, MLP
>>> # Create a tensor dict with an observation
>>> td = TensorDict({'observation': torch.randn(5, 4)}, [5])
>>> # Define the action space
>>> action_spec = OneHot(4)
>>> # the model reads a state of dimension 3 and outputs 4 values, one for each action available
>>> module = nn.Linear(3, 4)
>>> qvalue_actor = QValueActor(module=module, spec=action_spec)
>>> qvalue_actor(td)
>>> # Define the number of bins for the value distribution
>>> nbins = 3
>>> # Create an MLP module to output logits for the value distribution
>>> module = MLP(out_features=(nbins, 4), depth=2)
>>> # Create a DistributionalQValueActor instance
>>> qvalue_actor = DistributionalQValueActor(module=module, spec=action_spec, support=torch.arange(nbins))
>>> # Run the actor on the tensor dict
>>> td = qvalue_actor(td)
>>> print(td)
TensorDict(
fields={
action: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False),
action_value: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False),
chosen_action_value: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False),
observation: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
action_value: Tensor(shape=torch.Size([5, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
observation: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([5]),
device=None,
is_shared=False)

Distributional Q-learning is slightly different: in this case, the value network
does not output a scalar value for each state-action value.
Instead, the value space is divided in a an arbitrary number of "bins". The
value network outputs a probability that the state-action value belongs to one bin
or another.
Hence, for a state space of dimension M, an action space of dimension N and a number of bins B,
the value network encodes a :math:`\mathbb{R}^{M} \rightarrow \mathbb{R}^{N \times B}`
map. The following example shows how this works in TorchRL with the :class:`~torchrl.modules.tensordict_module.DistributionalQValueActor`
class:

>>> import torch
>>> from tensordict import TensorDict
>>> from torch import nn
>>> from torchrl.data import OneHot
>>> from torchrl.modules import DistributionalQValueActor, MLP
>>> td = TensorDict({'observation': torch.randn(5, 4)}, [5])
>>> nbins = 3
>>> # our model reads the observation and outputs a stack of 4 logits (one for each action) of size nbins=3
>>> module = MLP(out_features=(nbins, 4), depth=2)
>>> action_spec = OneHot(4)
>>> qvalue_actor = DistributionalQValueActor(module=module, spec=action_spec, support=torch.arange(nbins))
>>> td = qvalue_actor(td)
>>> print(td)
TensorDict(
fields={
action: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False),
action_value: Tensor(shape=torch.Size([5, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
observation: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([5]),
device=None,
is_shared=False)

>>> import torch
>>> from tensordict import TensorDict
>>> from torch import nn
>>> from torchrl.data import OneHot
>>> from torchrl.modules import DistributionalQValueActor, MLP
>>> td = TensorDict({'observation': torch.randn(5, 4)}, [5])
>>> nbins = 3
>>> # our model reads the observation and outputs a stack of 4 logits (one for each action) of size nbins=3
>>> module = MLP(out_features=(nbins, 4), depth=2)
>>> action_spec = OneHot(4)
>>> qvalue_actor = DistributionalQValueActor(module=module, spec=action_spec, support=torch.arange(nbins))
>>> td = qvalue_actor(td)
>>> print(td)
TensorDict(
fields={
action: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False),
action_value: Tensor(shape=torch.Size([5, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
observation: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([5]),
device=None,
is_shared=False)

>>> import torch
>>> from tensordict import TensorDict
>>> from torch import nn
>>> from torchrl.data import OneHot
>>> from torchrl.modules import DistributionalQValueActor, MLP
>>> td = TensorDict({'observation': torch.randn(5, 4)}, [5])
>>> nbins = 3
>>> # our model reads the observation and outputs a stack of 4 logits (one for each action) of size nbins=3
>>> module = MLP(out_features=(nbins, 4), depth=2)
>>> action_spec = OneHot(4)
>>> qvalue_actor = DistributionalQValueActor(module=module, spec=action_spec, support=torch.arange(nbins))
>>> td = qvalue_actor(td)
>>> print(td)
TensorDict(
fields={
action: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.int64, is_shared=False),
action_value: Tensor(shape=torch.Size([5, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
observation: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([5]),
device=None,
is_shared=False)

This will output a tensor dict with the selected action and its corresponding value distribution.

.. currentmodule:: torchrl.modules.tensordict_module

Expand Down Expand Up @@ -403,11 +295,10 @@ without shared parameters. It is mainly intended as a replacement for

Domain-specific TensorDict modules
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. currentmodule:: torchrl.modules.tensordict_module

These modules include dedicated solutions for MBRL or RLHF pipelines.

.. currentmodule:: torchrl.modules.tensordict_module

.. autosummary::
:toctree: generated/
:template: rl_template_noinherit.rst
Expand Down Expand Up @@ -558,9 +449,11 @@ Some distributions are typically used in RL scripts.

Utils
-----

.. currentmodule:: torchrl.modules.utils

The module utils include functionals used to do some custom mappings as well as a tool to
build :class:`~torchrl.envs.TensorDictPrimer` instances from a given module.

.. autosummary::
:toctree: generated/
:template: rl_template_noinherit.rst
Expand Down
2 changes: 1 addition & 1 deletion torchrl/modules/tensordict_module/probabilistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def __init__(
if spec is not None and not isinstance(spec, TensorSpec):
raise TypeError("spec must be a TensorSpec subclass")
elif spec is not None and not isinstance(spec, Composite):
if len(self.out_keys) > 1:
if len(self.out_keys) - return_log_prob > 1:
raise RuntimeError(
f"got more than one out_key for the SafeModule: {self.out_keys},\nbut only one spec. "
"Consider using a Composite object or no spec at all."
Expand Down

0 comments on commit edbf3de

Please sign in to comment.