# utilities for PPO training and evaluation
import jax
import jax.numpy as jnp
from flax import struct
from flax.training.train_state import TrainState
from xminigrid.environment import Environment, EnvParams


# Training stuff
class Transition(struct.PyTreeNode):
    done: jax.Array
    action: jax.Array
    value: jax.Array
    reward: jax.Array
    log_prob: jax.Array
    # for obs
    obs: jax.Array
    dir: jax.Array
    # for rnn policy
    prev_action: jax.Array
    prev_reward: jax.Array


def calculate_gae(
    transitions: Transition,
    last_val: jax.Array,
    gamma: float,
    gae_lambda: float,
) -> tuple[jax.Array, jax.Array]:
    # single iteration for the loop
    def _get_advantages(gae_and_next_value, transition):
        gae, next_value = gae_and_next_value
        delta = transition.reward + gamma * next_value * (1 - transition.done) - transition.value
        gae = delta + gamma * gae_lambda * (1 - transition.done) * gae
        return (gae, transition.value), gae

    _, advantages = jax.lax.scan(
        _get_advantages,
        (jnp.zeros_like(last_val), last_val),
        transitions,
        reverse=True,
    )
    # advantages and values (Q)
    return advantages, advantages + transitions.value
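
# Usage sketch (assumption): with `transitions` collected by a scan and `last_val` obtained by
# evaluating the critic on the final observation, advantages and bootstrapped value targets come
# from a single call, e.g.
#
#   advantages, targets = calculate_gae(transitions, last_val, gamma=0.99, gae_lambda=0.95)
#
# (the gamma/gae_lambda values here are illustrative defaults, not taken from this file).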


def ppo_update_networks(
    train_state: TrainState,
    transitions: Transition,
    init_hstate: jax.Array,
    advantages: jax.Array,
    targets: jax.Array,
    clip_eps: float,
    vf_coef: float,
    ent_coef: float,
):
    # NORMALIZE ADVANTAGES
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    def _loss_fn(params):
        # RERUN NETWORK
        dist, value, _ = train_state.apply_fn(
            params,
            {
                # [batch_size, seq_len, ...]
                "obs_img": transitions.obs,
                "obs_dir": transitions.dir,
                "prev_action": transitions.prev_action,
                "prev_reward": transitions.prev_reward,
            },
            init_hstate,
        )
        log_prob = dist.log_prob(transitions.action)

        # CALCULATE VALUE LOSS
        value_pred_clipped = transitions.value + (value - transitions.value).clip(-clip_eps, clip_eps)
        value_loss = jnp.square(value - targets)
        value_loss_clipped = jnp.square(value_pred_clipped - targets)
        value_loss = 0.5 * jnp.maximum(value_loss, value_loss_clipped).mean()
        # TODO: ablate this!
        # value_loss = jnp.square(value - targets).mean()

        # CALCULATE ACTOR LOSS
        ratio = jnp.exp(log_prob - transitions.log_prob)
        actor_loss1 = advantages * ratio
        actor_loss2 = advantages * jnp.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps)
        actor_loss = -jnp.minimum(actor_loss1, actor_loss2).mean()
        entropy = dist.entropy().mean()

        total_loss = actor_loss + vf_coef * value_loss - ent_coef * entropy
        return total_loss, (value_loss, actor_loss, entropy)

    (loss, (vloss, aloss, entropy)), grads = jax.value_and_grad(_loss_fn, has_aux=True)(train_state.params)
    (loss, vloss, aloss, entropy, grads) = jax.lax.pmean((loss, vloss, aloss, entropy, grads), axis_name="devices")
    train_state = train_state.apply_gradients(grads=grads)
    update_info = {
        "total_loss": loss,
        "value_loss": vloss,
        "actor_loss": aloss,
        "entropy": entropy,
    }
    return train_state, update_info
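
# Note (assumption): the jax.lax.pmean(..., axis_name="devices") call above requires that
# ppo_update_networks runs inside a jax.pmap whose mapped axis is named "devices", e.g. via a
# hypothetical outer training step
#
#   train_step_fn = jax.pmap(train_step, axis_name="devices")
#
# where `train_step` calls ppo_update_networks; outside such a context the pmean fails with an
# unbound axis name error.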


# for evaluation (evaluate for N consecutive episodes, sum rewards)
# N=1 single task, N>1 for meta-RL
class RolloutStats(struct.PyTreeNode):
    reward: jax.Array = jnp.asarray(0.0)
    length: jax.Array = jnp.asarray(0)
    episodes: jax.Array = jnp.asarray(0)


def rollout(
    rng: jax.Array,
    env: Environment,
    env_params: EnvParams,
    train_state: TrainState,
    init_hstate: jax.Array,
    num_consecutive_episodes: int = 1,
) -> RolloutStats:
    def _cond_fn(carry):
        rng, stats, timestep, prev_action, prev_reward, hstate = carry
        return jnp.less(stats.episodes, num_consecutive_episodes)

    def _body_fn(carry):
        rng, stats, timestep, prev_action, prev_reward, hstate = carry

        rng, _rng = jax.random.split(rng)
        dist, _, hstate = train_state.apply_fn(
            train_state.params,
            {
                "obs_img": timestep.observation["img"][None, None, ...],
                "obs_dir": timestep.observation["direction"][None, None, ...],
                "prev_action": prev_action[None, None, ...],
                "prev_reward": prev_reward[None, None, ...],
            },
            hstate,
        )
        action = dist.sample(seed=_rng).squeeze()
        timestep = env.step(env_params, timestep, action)

        stats = stats.replace(
            reward=stats.reward + timestep.reward,
            length=stats.length + 1,
            episodes=stats.episodes + timestep.last(),
        )
        carry = (rng, stats, timestep, action, timestep.reward, hstate)
        return carry

    timestep = env.reset(env_params, rng)
    prev_action = jnp.asarray(0)
    prev_reward = jnp.asarray(0)
    init_carry = (rng, RolloutStats(), timestep, prev_action, prev_reward, init_hstate)

    final_carry = jax.lax.while_loop(_cond_fn, _body_fn, init_val=init_carry)
    return final_carry[1]
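
# Evaluation sketch (assumption): rollout is usually vmapped over a batch of PRNG keys to
# evaluate many environments in parallel, e.g.
#
#   eval_keys = jax.random.split(rng, num_eval_envs)
#   eval_stats = jax.vmap(
#       lambda key: rollout(key, env, env_params, train_state, init_hstate, 1)
#   )(eval_keys)
#
# where `num_eval_envs`, `env`, `env_params`, `train_state`, and `init_hstate` are placeholders
# supplied by the caller; the result is a RolloutStats pytree with leading dimension num_eval_envs.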