actor_critic_mlp.py
import torch
import torch.nn as nn

from mlp import MLP


class ActorCriticMLP(nn.Module):
    def __init__(
        self,
        num_actor_obs,
        num_critic_obs,
        num_actions,
        actor_hidden_dims,
        critic_hidden_dims,
        activation="elu",
        init_noise_std=1.0,
        fixed_std=False,
        **kwargs,
    ):
        print("----------------------------------")
        print("ActorCriticMLP")
        if kwargs:
            print(
                "ActorCriticMLP.__init__ got unexpected arguments, which will be ignored: "
                + str([key for key in kwargs.keys()])
            )
        super(ActorCriticMLP, self).__init__()

        # Policy: maps actor observations to action means
        actor_num_input = num_actor_obs
        actor_num_output = num_actions
        actor_activation = activation
        self.actor = MLP(
            actor_num_input,
            actor_num_output,
            actor_hidden_dims,
            actor_activation,
            norm="none",
        )
        print(f"Actor MLP: {self.actor}")

        # Value function: maps critic observations to a scalar value estimate
        critic_num_input = num_critic_obs
        critic_num_output = 1
        critic_activation = activation
        self.critic = MLP(
            critic_num_input,
            critic_num_output,
            critic_hidden_dims,
            critic_activation,
            norm="none",
        )
        print(f"Critic MLP: {self.critic}")

        # Action noise: per-action std, either a fixed tensor or a learnable parameter
        self.fixed_std = fixed_std
        std = init_noise_std * torch.ones(num_actions)
        # std is already a tensor, so clone/detach it rather than calling
        # torch.tensor(std), which raises a copy-construction warning
        self.std = std.clone().detach() if fixed_std else nn.Parameter(std)
        self.distribution = None

    def reset(self, dones=None):
        pass

    def forward(self, observations):
        # Detach so gradients from the action output do not flow back
        # into the observation pipeline
        actions = self.actor(observations.detach())
        return actions

    def evaluate(self, critic_observations, **kwargs):
        values = self.critic(critic_observations)
        return values