-
Notifications
You must be signed in to change notification settings - Fork 8
/
actor.py
30 lines (21 loc) · 986 Bytes
/
actor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from typing import Tuple
import jax
import jax.numpy as jnp
from common import Batch, InfoDict, Model, Params, PRNGKey
def update(key: PRNGKey, actor: Model, critic: Model, value: Model,
           batch: Batch, temperature: float, *,
           max_weight: float = 100.0) -> Tuple[Model, InfoDict]:
    """Take one gradient step on the actor using exponentiated-advantage weights.

    The actor is trained to maximize the log-likelihood of ``batch.actions``
    under its policy distribution, with each sample weighted by
    ``min(exp((q - v) * temperature), max_weight)``. Here ``q`` is the
    element-wise minimum of the two critic outputs and ``v`` is the value
    network output for the same observations.

    Args:
        key: PRNG key passed to the actor's ``apply`` as its ``'dropout'``
            RNG stream.
        actor: Policy model being updated. ``actor.apply(...)`` must return
            an object with a ``log_prob`` method. ``actor.apply_gradient``
            is defined outside this file; it is called with the loss function
            and expected to return ``(new_actor, info)``. TODO confirm the
            exact contract in ``common.Model``.
        critic: Model called as ``critic(observations, actions)`` that
            returns a pair ``(q1, q2)``.
        value: Model called as ``value(observations)``.
        batch: Batch providing ``observations`` and ``actions``.
        temperature: Multiplier applied to the advantage ``q - v`` before
            exponentiation. Larger values concentrate weight on higher-
            advantage samples, so it behaves as an inverse temperature.
        max_weight: Upper bound on each per-sample weight. Defaults to
            ``100.0``, the value previously hard-coded here.

    Returns:
        ``(new_actor, info)`` as returned by ``actor.apply_gradient``.
        ``info`` is presumably the auxiliary dict produced by the loss:
        ``'actor_loss'`` (scalar) and ``'adv'`` (per-sample ``q - v``).
        This depends on ``Model.apply_gradient``, so verify it there.
    """
    v = value(batch.observations)
    q1, q2 = critic(batch.observations, batch.actions)
    # Pessimistic Q estimate: element-wise minimum of the two critic outputs.
    q = jnp.minimum(q1, q2)
    adv = q - v

    # The weights depend only on critic/value outputs, not on actor_params,
    # so they act as constants in the actor gradient and need no
    # stop_gradient. Clipping keeps a few large-advantage samples from
    # dominating the loss. If exp() overflows to inf, minimum() still caps
    # the weight at max_weight.
    exp_a = jnp.minimum(jnp.exp(adv * temperature), max_weight)

    def actor_loss_fn(actor_params: Params) -> Tuple[jnp.ndarray, InfoDict]:
        dist = actor.apply({'params': actor_params},
                           batch.observations,
                           training=True,
                           rngs={'dropout': key})
        log_probs = dist.log_prob(batch.actions)
        # Negative weighted log-likelihood, averaged over all elements.
        actor_loss = -(exp_a * log_probs).mean()
        return actor_loss, {'actor_loss': actor_loss, 'adv': adv}

    new_actor, info = actor.apply_gradient(actor_loss_fn)
    return new_actor, info