Skip to content

Commit e914db2

Browse files
committed
Update model instantiators implementations
1 parent 8f1900e commit e914db2

9 files changed

+361
-351
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,5 @@
1-
from enum import Enum
2-
31
from skrl.utils.model_instantiators.torch.categorical import categorical_model
42
from skrl.utils.model_instantiators.torch.deterministic import deterministic_model
53
from skrl.utils.model_instantiators.torch.gaussian import gaussian_model
64
from skrl.utils.model_instantiators.torch.multivariate_gaussian import multivariate_gaussian_model
75
from skrl.utils.model_instantiators.torch.shared import shared_model
8-
9-
10-
# keep for compatibility with versions prior to 1.3.0
11-
class Shape(Enum):
12-
"""
13-
Enum to select the shape of the model's inputs and outputs
14-
"""
15-
16-
ONE = 1
17-
STATES = 0
18-
OBSERVATIONS = 0
19-
ACTIONS = -1
20-
STATES_ACTIONS = -2

skrl/utils/model_instantiators/torch/categorical.py

+29-36
Original file line numberDiff line numberDiff line change
@@ -8,52 +8,37 @@
88

99
from skrl.models.torch import CategoricalMixin # noqa
1010
from skrl.models.torch import Model
11-
from skrl.utils.model_instantiators.torch.common import convert_deprecated_parameters, generate_containers
11+
from skrl.utils.model_instantiators.torch.common import generate_containers
1212
from skrl.utils.spaces.torch import unflatten_tensorized_space # noqa
1313

1414

1515
def categorical_model(
16-
observation_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None,
17-
action_space: Optional[Union[int, Tuple[int], gymnasium.Space]] = None,
16+
*,
17+
observation_space: Optional[gymnasium.Space] = None,
18+
state_space: Optional[gymnasium.Space] = None,
19+
action_space: Optional[gymnasium.Space] = None,
1820
device: Optional[Union[str, torch.device]] = None,
1921
unnormalized_log_prob: bool = True,
2022
network: Sequence[Mapping[str, Any]] = [],
2123
output: Union[str, Sequence[str]] = "",
2224
return_source: bool = False,
23-
*args,
24-
**kwargs,
2525
) -> Union[Model, str]:
26-
"""Instantiate a categorical model
26+
"""Instantiate a :class:`~skrl.models.torch.categorical.CategoricalMixin`-based model.
2727
28-
:param observation_space: Observation/state space or shape (default: None).
29-
If it is not None, the num_observations property will contain the size of that space
30-
:type observation_space: int, tuple or list of integers, gymnasium.Space or None, optional
31-
:param action_space: Action space or shape (default: None).
32-
If it is not None, the num_actions property will contain the size of that space
33-
:type action_space: int, tuple or list of integers, gymnasium.Space or None, optional
34-
:param device: Device on which a tensor/array is or will be allocated (default: ``None``).
35-
If None, the device will be either ``"cuda"`` if available or ``"cpu"``
36-
:type device: str or torch.device, optional
37-
:param unnormalized_log_prob: Flag to indicate how to be interpreted the model's output (default: True).
38-
If True, the model's output is interpreted as unnormalized log probabilities
39-
(it can be any real number), otherwise as normalized probabilities
40-
(the output must be non-negative, finite and have a non-zero sum)
41-
:type unnormalized_log_prob: bool, optional
42-
:param network: Network definition (default: [])
43-
:type network: list of dict, optional
44-
:param output: Output expression (default: "")
45-
:type output: list or str, optional
28+
:param observation_space: Observation space. The ``num_observations`` property will contain the size of the space.
29+
:param state_space: State space. The ``num_states`` property will contain the size of the space.
30+
:param action_space: Action space. The ``num_actions`` property will contain the size of the space.
31+
:param device: Data allocation and computation device. If not specified, the default device will be used.
32+
:param unnormalized_log_prob: Flag to indicate how to the model's output will be interpreted.
33+
If True, the model's output is interpreted as unnormalized log probabilities (it can be any real number),
34+
otherwise as normalized probabilities (the output must be non-negative, finite and have a non-zero sum).
35+
:param network: Network definition.
36+
:param output: Output expression.
4637
:param return_source: Whether to return the source string containing the model class used to
47-
instantiate the model rather than the model instance (default: False).
48-
:type return_source: bool, optional
38+
instantiate the model rather than the model instance.
4939
50-
:return: Categorical model instance or definition source
51-
:rtype: Model
40+
:return: Categorical model instance or definition source (if ``return_source`` is True).
5241
"""
53-
# compatibility with versions prior to 1.3.0
54-
if not network and kwargs:
55-
network, output = convert_deprecated_parameters(kwargs)
56-
5742
# parse model definition
5843
containers, output = generate_containers(network, output, embed_output=True, indent=1)
5944

@@ -77,14 +62,21 @@ def categorical_model(
7762
forward = textwrap.indent("\n".join(forward), prefix=" " * 8)[8:]
7863

7964
template = f"""class CategoricalModel(CategoricalMixin, Model):
80-
def __init__(self, observation_space, action_space, device, unnormalized_log_prob):
81-
Model.__init__(self, observation_space, action_space, device)
82-
CategoricalMixin.__init__(self, unnormalized_log_prob)
65+
def __init__(self, observation_space, state_space, action_space, device=None, unnormalized_log_prob=True, role=""):
66+
Model.__init__(
67+
self,
68+
observation_space=observation_space,
69+
state_space=state_space,
70+
action_space=action_space,
71+
device=device,
72+
)
73+
CategoricalMixin.__init__(self, unnormalized_log_prob=unnormalized_log_prob, role=role)
8374
8475
{networks}
8576
8677
def compute(self, inputs, role=""):
87-
states = unflatten_tensorized_space(self.observation_space, inputs.get("states"))
78+
observations = unflatten_tensorized_space(self.observation_space, inputs.get("observations"))
79+
states = unflatten_tensorized_space(self.state_space, inputs.get("states"))
8880
taken_actions = unflatten_tensorized_space(self.action_space, inputs.get("taken_actions"))
8981
{forward}
9082
return output, {{}}
@@ -98,6 +90,7 @@ def compute(self, inputs, role=""):
9890
exec(template, globals(), _locals)
9991
return _locals["CategoricalModel"](
10092
observation_space=observation_space,
93+
state_space=state_space,
10194
action_space=action_space,
10295
device=device,
10396
unnormalized_log_prob=unnormalized_log_prob,

skrl/utils/model_instantiators/torch/common.py

+33-71
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,9 @@
22

33
import ast
44

5-
from skrl import logger
6-
75

86
def _get_activation_function(activation: Union[str, None], as_module: bool = True) -> Union[str, None]:
9-
"""Get the activation function
7+
"""Get the activation function.
108
119
Supported activation functions:
1210
@@ -20,10 +18,10 @@ def _get_activation_function(activation: Union[str, None], as_module: bool = Tru
2018
- "softsign"
2119
- "tanh"
2220
23-
:param activation: Activation function name
24-
:param as_module: Whether to return a PyTorch module instance rather than a functional method
21+
:param activation: Activation function name.
22+
:param as_module: Whether to return a PyTorch module instance rather than a functional method.
2523
26-
:return: Activation function or None if the activation is not supported
24+
:return: Activation function or ``None`` if the activation is not supported.
2725
"""
2826
activations = {
2927
"elu": "nn.ELU()" if as_module else "functional.elu",
@@ -40,11 +38,11 @@ def _get_activation_function(activation: Union[str, None], as_module: bool = Tru
4038

4139

4240
def _parse_input(source: str) -> str:
43-
"""Parse a network input expression by replacing substitutions and applying operations
41+
"""Parse a network input expression by replacing substitutions and applying operations.
4442
45-
:param source: Input expression
43+
:param source: Input expression.
4644
47-
:return: Parsed network input
45+
:return: Parsed network input.
4846
"""
4947

5048
class NodeTransformer(ast.NodeTransformer):
@@ -64,24 +62,20 @@ def visit_Call(self, node: ast.Call):
6462
NodeTransformer().visit(tree)
6563
source = ast.unparse(tree)
6664
# enum substitutions
67-
source = source.replace("Shape.STATES_ACTIONS", "STATES_ACTIONS").replace(
68-
"STATES_ACTIONS", "torch.cat([states, taken_actions], dim=1)"
69-
)
70-
source = source.replace("Shape.OBSERVATIONS_ACTIONS", "OBSERVATIONS_ACTIONS").replace(
71-
"OBSERVATIONS_ACTIONS", "torch.cat([states, taken_actions], dim=1)"
72-
)
73-
source = source.replace("Shape.STATES", "STATES").replace("STATES", "states")
74-
source = source.replace("Shape.OBSERVATIONS", "OBSERVATIONS").replace("OBSERVATIONS", "states")
75-
source = source.replace("Shape.ACTIONS", "ACTIONS").replace("ACTIONS", "taken_actions")
65+
source = source.replace("OBSERVATIONS_ACTIONS", "torch.cat([observations, taken_actions], dim=1)")
66+
source = source.replace("STATES_ACTIONS", "torch.cat([states, taken_actions], dim=1)")
67+
source = source.replace("OBSERVATIONS", "observations")
68+
source = source.replace("STATES", "states")
69+
source = source.replace("ACTIONS", "taken_actions")
7670
return source
7771

7872

7973
def _parse_output(source: Union[str, Sequence[str]]) -> Tuple[Union[str, Sequence[str]], Sequence[str], int]:
80-
"""Parse the network output expression by replacing substitutions and applying operations
74+
"""Parse the network output expression by replacing substitutions and applying operations.
8175
82-
:param source: Output expression
76+
:param source: Output expression.
8377
84-
:return: Tuple with the parsed network output, generated modules and output size/shape
78+
:return: Tuple with the parsed network output, generated modules and output size/shape.
8579
"""
8680

8781
class NodeTransformer(ast.NodeTransformer):
@@ -101,7 +95,6 @@ def visit_Call(self, node: ast.Call):
10195
modules = []
10296
if type(source) is str:
10397
# enum substitutions
104-
source = source.replace("Shape.ACTIONS", "ACTIONS").replace("Shape.ONE", "ONE")
10598
token = "ACTIONS" if "ACTIONS" in source else None
10699
token = "ONE" if "ONE" in source else token
107100
if token:
@@ -120,13 +113,13 @@ def visit_Call(self, node: ast.Call):
120113

121114

122115
def _generate_modules(layers: Sequence[str], activations: Union[Sequence[str], str]) -> Sequence[str]:
123-
"""Generate network modules
116+
"""Generate network modules.
124117
125118
:param layers: Layer definitions
126119
:param activations: Activation function definitions applied after each layer (except ``flatten`` layers).
127-
If a single activation function is specified (str or lis), it will be applied after each layer
120+
If a single activation function is specified (str or list), it will be applied after each layer.
128121
129-
:return: A list of generated modules
122+
:return: A list of generated modules.
130123
"""
131124
# expand activations
132125
if type(activations) is str:
@@ -224,21 +217,24 @@ def _generate_modules(layers: Sequence[str], activations: Union[Sequence[str], s
224217

225218

226219
def get_num_units(token: Union[str, Any]) -> Union[str, Any]:
227-
"""Get the number of units/features a token represent
220+
"""Get the number of units/features a token represents.
228221
229-
:param token: Token
222+
:param token: Token.
230223
231-
:return: Number of units/features a token represent. If the token is unknown, its value will be returned as it
224+
:return: Number of units/features a token represents. If the token is unknown, its value will be returned as it.
232225
"""
233226
num_units = {
234227
"ONE": "1",
235-
"STATES": "self.num_observations",
228+
"NUM_OBSERVATIONS": "self.num_observations",
229+
"NUM_STATES": "self.num_states",
230+
"NUM_ACTIONS": "self.num_actions",
236231
"OBSERVATIONS": "self.num_observations",
232+
"STATES": "self.num_states",
237233
"ACTIONS": "self.num_actions",
238-
"STATES_ACTIONS": "self.num_observations + self.num_actions",
239234
"OBSERVATIONS_ACTIONS": "self.num_observations + self.num_actions",
235+
"STATES_ACTIONS": "self.num_states + self.num_actions",
240236
}
241-
token_as_str = str(token).replace("Shape.", "")
237+
token_as_str = str(token)
242238
if token_as_str in num_units:
243239
return num_units[token_as_str]
244240
return token
@@ -247,16 +243,16 @@ def get_num_units(token: Union[str, Any]) -> Union[str, Any]:
247243
def generate_containers(
248244
network: Sequence[Mapping[str, Any]], output: Union[str, Sequence[str]], embed_output: bool = True, indent: int = -1
249245
) -> Tuple[Sequence[Mapping[str, Any]], Mapping[str, Any]]:
250-
"""Generate network containers
246+
"""Generate network containers.
251247
252-
:param network: Network definition
253-
:param output: Network's output expression
248+
:param network: Network definition.
249+
:param output: Network's output expression.
254250
:param embed_output: Whether to embed the output modules (if any) in the container definition.
255-
If True, the output modules will be append to the last container module
251+
If True, the output modules will be append to the last container module.
256252
:param indent: Indentation level used to generate the Sequential definition.
257-
If negative, no indentation will be applied
253+
If negative, no indentation will be applied.
258254
259-
:return: Network containers and output
255+
:return: Network containers and output.
260256
"""
261257
# parse output
262258
output, output_modules, output_size = _parse_output(output)
@@ -290,37 +286,3 @@ def generate_containers(
290286
output = output.replace("PLACEHOLDER", container["name"] if embed_output else "output")
291287
output = {"output": output, "modules": output_modules, "size": output_size}
292288
return containers, output
293-
294-
295-
def convert_deprecated_parameters(parameters: Mapping[str, Any]) -> Tuple[Mapping[str, Any], str]:
296-
"""Function to convert deprecated parameters to network-output format
297-
298-
:param parameters: Deprecated parameters and their values.
299-
300-
:return: Network and output definitions
301-
"""
302-
logger.warning(
303-
f'The following parameters ({", ".join(list(parameters.keys()))}) are deprecated. '
304-
"See https://skrl.readthedocs.io/en/latest/api/utils/model_instantiators.html"
305-
)
306-
# network definition
307-
activations = parameters.get("hidden_activation", [])
308-
if type(activations) in [list, tuple] and len(set(activations)) == 1:
309-
activations = activations[0]
310-
network = [
311-
{
312-
"name": "net",
313-
"input": str(parameters.get("input_shape", "STATES")),
314-
"layers": parameters.get("hiddens", []),
315-
"activations": activations,
316-
}
317-
]
318-
# output
319-
output_scale = parameters.get("output_scale", 1.0)
320-
scale_operation = f"{output_scale} * " if output_scale != 1.0 else ""
321-
if parameters.get("output_activation", None):
322-
output = f'{scale_operation}{parameters["output_activation"]}({str(parameters.get("output_shape", "ACTIONS"))})'
323-
else:
324-
output = f'{scale_operation}{str(parameters.get("output_shape", "ACTIONS"))}'
325-
326-
return network, output

0 commit comments

Comments
 (0)