-
Notifications
You must be signed in to change notification settings - Fork 8
/
critic.py
50 lines (35 loc) · 1.54 KB
/
critic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from typing import Tuple
import jax.numpy as jnp
from common import Batch, InfoDict, Model, Params
def loss(diff, expectile=0.8):
    """Return the element-wise asymmetrically weighted squared error.

    Entries of ``diff`` that are strictly positive are weighted by
    ``expectile``; all other entries (zero or negative) are weighted by
    ``1 - expectile``.

    Args:
        diff: Residuals to penalise (array-like accepted by ``jnp.where``).
        expectile: Weight applied to positive residuals.

    Returns:
        An array of weighted squared residuals, one per entry of ``diff``.
    """
    squared_error = diff**2
    is_positive = diff > 0
    asymmetric_weight = jnp.where(is_positive, expectile, 1 - expectile)
    return asymmetric_weight * squared_error
def update_v(critic: Model, value: Model, batch: Batch,
             expectile: float) -> Tuple[Model, InfoDict]:
    """Update the value model by regressing it towards the critic outputs.

    The regression target is the element-wise minimum of the two outputs
    returned by ``critic`` for the batch's observations and actions. The
    objective is the mean of ``loss(target - v, expectile)``.

    Args:
        critic: Callable model returning two Q estimates.
        value: Model whose parameters are updated.
        batch: Provides ``observations`` and ``actions``.
        expectile: Forwarded to ``loss`` as the positive-residual weight.

    Returns:
        The pair returned by ``value.apply_gradient``: the new value model
        and an info dict with ``'value_loss'`` and ``'v'`` entries.
        NOTE(review): ``apply_gradient`` lives in ``common``; presumably it
        takes one optimiser step -- confirm there.
    """
    q_first, q_second = critic(batch.observations, batch.actions)
    # The target is computed outside the differentiated function, so only
    # the value parameters are touched by the gradient.
    q_target = jnp.minimum(q_first, q_second)

    def value_loss_fn(value_params: Params) -> Tuple[jnp.ndarray, InfoDict]:
        v_pred = value.apply({'params': value_params}, batch.observations)
        mean_loss = loss(q_target - v_pred, expectile).mean()
        metrics = {
            'value_loss': mean_loss,
            'v': v_pred.mean(),
        }
        return mean_loss, metrics

    updated_value, update_info = value.apply_gradient(value_loss_fn)
    return updated_value, update_info
def update_q(critic: Model, target_value: Model, batch: Batch,
             discount: float) -> Tuple[Model, InfoDict]:
    """Update the critic towards a one-step bootstrapped target.

    The target is ``rewards + discount * masks * target_value(next_obs)``.
    The objective is the mean over the batch of the summed squared errors
    of both critic outputs against that target.

    Args:
        critic: Model returning two Q estimates; its parameters are updated.
        target_value: Callable model evaluated on ``batch.next_observations``.
        batch: Provides ``observations``, ``actions``, ``rewards``, ``masks``
            and ``next_observations``.
        discount: Multiplier applied to the bootstrapped next-state value.

    Returns:
        The pair returned by ``critic.apply_gradient``: the new critic and
        an info dict with ``'critic_loss'``, ``'q1'`` and ``'q2'`` entries.
        NOTE(review): ``apply_gradient`` lives in ``common``; presumably it
        takes one optimiser step -- confirm there.
    """
    next_values = target_value(batch.next_observations)
    # ``masks`` scales the bootstrap term (presumably 0 at terminal
    # transitions -- confirm against the dataset code).
    bootstrap = discount * batch.masks * next_values
    td_target = batch.rewards + bootstrap

    def critic_loss_fn(critic_params: Params) -> Tuple[jnp.ndarray, InfoDict]:
        q_first, q_second = critic.apply({'params': critic_params},
                                         batch.observations, batch.actions)
        residual_first = q_first - td_target
        residual_second = q_second - td_target
        mean_loss = (residual_first**2 + residual_second**2).mean()
        metrics = {
            'critic_loss': mean_loss,
            'q1': q_first.mean(),
            'q2': q_second.mean()
        }
        return mean_loss, metrics

    updated_critic, update_info = critic.apply_gradient(critic_loss_fn)
    return updated_critic, update_info