@@ -2,11 +2,9 @@
 
 import ast
 
-from skrl import logger
-
 
 def _get_activation_function(activation: Union[str, None], as_module: bool = True) -> Union[str, None]:
-    """Get the activation function
+    """Get the activation function.
 
     Supported activation functions:
 
@@ -20,10 +18,10 @@ def _get_activation_function(activation: Union[str, None], as_module: bool = Tru
     - "softsign"
     - "tanh"
 
-    :param activation: Activation function name
-    :param as_module: Whether to return a PyTorch module instance rather than a functional method
+    :param activation: Activation function name.
+    :param as_module: Whether to return a PyTorch module instance rather than a functional method.
 
-    :return: Activation function or None if the activation is not supported
+    :return: Activation function or ``None`` if the activation is not supported.
     """
     activations = {
         "elu": "nn.ELU()" if as_module else "functional.elu",
@@ -40,11 +38,11 @@ def _get_activation_function(activation: Union[str, None], as_module: bool = Tru
 
 
 def _parse_input(source: str) -> str:
-    """Parse a network input expression by replacing substitutions and applying operations
+    """Parse a network input expression by replacing substitutions and applying operations.
 
-    :param source: Input expression
+    :param source: Input expression.
 
-    :return: Parsed network input
+    :return: Parsed network input.
     """
 
     class NodeTransformer(ast.NodeTransformer):
@@ -64,24 +62,20 @@ def visit_Call(self, node: ast.Call):
     NodeTransformer().visit(tree)
     source = ast.unparse(tree)
     # enum substitutions
-    source = source.replace("Shape.STATES_ACTIONS", "STATES_ACTIONS").replace(
-        "STATES_ACTIONS", "torch.cat([states, taken_actions], dim=1)"
-    )
-    source = source.replace("Shape.OBSERVATIONS_ACTIONS", "OBSERVATIONS_ACTIONS").replace(
-        "OBSERVATIONS_ACTIONS", "torch.cat([states, taken_actions], dim=1)"
-    )
-    source = source.replace("Shape.STATES", "STATES").replace("STATES", "states")
-    source = source.replace("Shape.OBSERVATIONS", "OBSERVATIONS").replace("OBSERVATIONS", "states")
-    source = source.replace("Shape.ACTIONS", "ACTIONS").replace("ACTIONS", "taken_actions")
+    source = source.replace("OBSERVATIONS_ACTIONS", "torch.cat([observations, taken_actions], dim=1)")
+    source = source.replace("STATES_ACTIONS", "torch.cat([states, taken_actions], dim=1)")
+    source = source.replace("OBSERVATIONS", "observations")
+    source = source.replace("STATES", "states")
+    source = source.replace("ACTIONS", "taken_actions")
     return source
 
 
 def _parse_output(source: Union[str, Sequence[str]]) -> Tuple[Union[str, Sequence[str]], Sequence[str], int]:
-    """Parse the network output expression by replacing substitutions and applying operations
+    """Parse the network output expression by replacing substitutions and applying operations.
 
-    :param source: Output expression
+    :param source: Output expression.
 
-    :return: Tuple with the parsed network output, generated modules and output size/shape
+    :return: Tuple with the parsed network output, generated modules and output size/shape.
     """
 
     class NodeTransformer(ast.NodeTransformer):
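
A short, standalone sketch (not part of the diff) of why the new substitution pass replaces the compound tokens before the single ones: a plain str.replace of "OBSERVATIONS" applied first would corrupt "OBSERVATIONS_ACTIONS". The sample expression is hypothetical.

expr = "OBSERVATIONS_ACTIONS"

# wrong order: replacing the single token first destroys the compound token
bad = expr.replace("OBSERVATIONS", "observations")
print(bad)  # observations_ACTIONS

# order used above: compound tokens first, then single tokens
good = expr.replace("OBSERVATIONS_ACTIONS", "torch.cat([observations, taken_actions], dim=1)")
good = good.replace("OBSERVATIONS", "observations").replace("ACTIONS", "taken_actions")
print(good)  # torch.cat([observations, taken_actions], dim=1)
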
@@ -101,7 +95,6 @@ def visit_Call(self, node: ast.Call):
     modules = []
     if type(source) is str:
         # enum substitutions
-        source = source.replace("Shape.ACTIONS", "ACTIONS").replace("Shape.ONE", "ONE")
         token = "ACTIONS" if "ACTIONS" in source else None
         token = "ONE" if "ONE" in source else token
         if token:
@@ -120,13 +113,13 @@ def visit_Call(self, node: ast.Call):
 
 
 def _generate_modules(layers: Sequence[str], activations: Union[Sequence[str], str]) -> Sequence[str]:
-    """Generate network modules
+    """Generate network modules.
 
     :param layers: Layer definitions
     :param activations: Activation function definitions applied after each layer (except ``flatten`` layers).
-        If a single activation function is specified (str or lis), it will be applied after each layer
+        If a single activation function is specified (str or list), it will be applied after each layer.
 
-    :return: A list of generated modules
+    :return: A list of generated modules.
     """
     # expand activations
     if type(activations) is str:
@@ -224,21 +217,24 @@ def _generate_modules(layers: Sequence[str], activations: Union[Sequence[str], s
 
 
 def get_num_units(token: Union[str, Any]) -> Union[str, Any]:
-    """Get the number of units/features a token represent
+    """Get the number of units/features a token represents.
 
-    :param token: Token
+    :param token: Token.
 
-    :return: Number of units/features a token represent. If the token is unknown, its value will be returned as it
+    :return: Number of units/features a token represents. If the token is unknown, its value will be returned as it.
     """
     num_units = {
         "ONE": "1",
-        "STATES": "self.num_observations",
+        "NUM_OBSERVATIONS": "self.num_observations",
+        "NUM_STATES": "self.num_states",
+        "NUM_ACTIONS": "self.num_actions",
         "OBSERVATIONS": "self.num_observations",
+        "STATES": "self.num_states",
         "ACTIONS": "self.num_actions",
-        "STATES_ACTIONS": "self.num_observations + self.num_actions",
         "OBSERVATIONS_ACTIONS": "self.num_observations + self.num_actions",
+        "STATES_ACTIONS": "self.num_states + self.num_actions",
     }
-    token_as_str = str(token).replace("Shape.", "")
+    token_as_str = str(token)
    if token_as_str in num_units:
         return num_units[token_as_str]
     return token
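
A standalone sketch (not part of the diff) that reproduces the lookup-or-passthrough behavior shown above with a subset of the updated table; the "self.*" strings are presumably emitted into the generated model source, and unknown tokens are returned unchanged.

# subset of the table above, enough to show the behavior
num_units = {
    "STATES": "self.num_states",
    "OBSERVATIONS_ACTIONS": "self.num_observations + self.num_actions",
    "STATES_ACTIONS": "self.num_states + self.num_actions",
}

def get_num_units(token):
    token_as_str = str(token)  # the "Shape." prefix is no longer stripped
    return num_units.get(token_as_str, token)  # unknown tokens are returned as-is

print(get_num_units("STATES_ACTIONS"))  # self.num_states + self.num_actions
print(get_num_units(64))                # 64 (e.g., a literal layer size passes through)
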
@@ -247,16 +243,16 @@ def get_num_units(token: Union[str, Any]) -> Union[str, Any]:
 def generate_containers(
     network: Sequence[Mapping[str, Any]], output: Union[str, Sequence[str]], embed_output: bool = True, indent: int = -1
 ) -> Tuple[Sequence[Mapping[str, Any]], Mapping[str, Any]]:
-    """Generate network containers
+    """Generate network containers.
 
-    :param network: Network definition
-    :param output: Network's output expression
+    :param network: Network definition.
+    :param output: Network's output expression.
     :param embed_output: Whether to embed the output modules (if any) in the container definition.
-        If True, the output modules will be append to the last container module
+        If True, the output modules will be append to the last container module.
     :param indent: Indentation level used to generate the Sequential definition.
-        If negative, no indentation will be applied
+        If negative, no indentation will be applied.
 
-    :return: Network containers and output
+    :return: Network containers and output.
     """
     # parse output
     output, output_modules, output_size = _parse_output(output)
@@ -290,37 +286,3 @@ def generate_containers(
     output = output.replace("PLACEHOLDER", container["name"] if embed_output else "output")
     output = {"output": output, "modules": output_modules, "size": output_size}
     return containers, output
-
-
-def convert_deprecated_parameters(parameters: Mapping[str, Any]) -> Tuple[Mapping[str, Any], str]:
-    """Function to convert deprecated parameters to network-output format
-
-    :param parameters: Deprecated parameters and their values.
-
-    :return: Network and output definitions
-    """
-    logger.warning(
-        f'The following parameters ({", ".join(list(parameters.keys()))}) are deprecated. '
-        "See https://skrl.readthedocs.io/en/latest/api/utils/model_instantiators.html"
-    )
-    # network definition
-    activations = parameters.get("hidden_activation", [])
-    if type(activations) in [list, tuple] and len(set(activations)) == 1:
-        activations = activations[0]
-    network = [
-        {
-            "name": "net",
-            "input": str(parameters.get("input_shape", "STATES")),
-            "layers": parameters.get("hiddens", []),
-            "activations": activations,
-        }
-    ]
-    # output
-    output_scale = parameters.get("output_scale", 1.0)
-    scale_operation = f"{output_scale} * " if output_scale != 1.0 else ""
-    if parameters.get("output_activation", None):
-        output = f'{scale_operation}{parameters["output_activation"]}({str(parameters.get("output_shape", "ACTIONS"))})'
-    else:
-        output = f'{scale_operation}{str(parameters.get("output_shape", "ACTIONS"))}'
-
-    return network, output
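
For reference, a sketch (not part of the diff) of the conversion that the removed convert_deprecated_parameters() performed: deprecated keyword arguments such as input_shape, hiddens, hidden_activation, output_shape, output_activation and output_scale were rewritten into the network/output definitions shown below, which presumably must now be supplied directly. The values are illustrative only.

# deprecated-style parameters (illustrative values)
parameters = {
    "input_shape": "STATES",
    "hiddens": [64, 64],
    "hidden_activation": ["relu", "relu"],
    "output_shape": "ACTIONS",
    "output_activation": "tanh",
    "output_scale": 1.0,
}

# equivalent network/output definitions, as the removed function would have produced them
network = [
    {
        "name": "net",
        "input": "STATES",
        "layers": [64, 64],
        "activations": "relu",  # identical per-layer activations were collapsed to a single value
    }
]
output = "tanh(ACTIONS)"  # output_scale of 1.0 adds no "<scale> * " prefix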