"""
DPPO: Diffusion Policy Policy Optimization.
K: number of denoising steps
To: observation sequence length
Ta: action chunk size
Do: observation dimension
Da: action dimension
C: image channels
H, W: image height and width
"""
from typing import Optional
import torch
import logging
import math
log = logging.getLogger(__name__)
from model.diffusion.diffusion_vpg import VPGDiffusion
class PPODiffusion(VPGDiffusion):

    def __init__(
        self,
        gamma_denoising: float,
        clip_ploss_coef: float,
        clip_ploss_coef_base: float = 1e-3,
        clip_ploss_coef_rate: float = 3,
        clip_vloss_coef: Optional[float] = None,
        clip_advantage_lower_quantile: float = 0,
        clip_advantage_upper_quantile: float = 1,
        norm_adv: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Whether to normalize advantages within batch
        self.norm_adv = norm_adv

        # Clipping value for policy loss
        self.clip_ploss_coef = clip_ploss_coef
        self.clip_ploss_coef_base = clip_ploss_coef_base
        self.clip_ploss_coef_rate = clip_ploss_coef_rate

        # Clipping value for value loss
        self.clip_vloss_coef = clip_vloss_coef

        # Discount factor for diffusion MDP
        self.gamma_denoising = gamma_denoising

        # Quantiles for clipping advantages
        self.clip_advantage_lower_quantile = clip_advantage_lower_quantile
        self.clip_advantage_upper_quantile = clip_advantage_upper_quantile
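    # A hypothetical construction sketch (not from this file): the remaining
    # keyword arguments (actor network, denoising schedule, critic, device, ...)
    # are consumed by the VPGDiffusion base class and are represented here by
    # the placeholder `vpg_kwargs`; the numeric values are illustrative only.
    #
    #   policy = PPODiffusion(
    #       gamma_denoising=0.99,
    #       clip_ploss_coef=0.01,
    #       clip_ploss_coef_base=1e-3,
    #       clip_ploss_coef_rate=3,
    #       norm_adv=True,
    #       **vpg_kwargs,
    #   )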
    def loss(
        self,
        obs,
        chains_prev,
        chains_next,
        denoising_inds,
        returns,
        oldvalues,
        advantages,
        oldlogprobs,
        use_bc_loss=False,
        reward_horizon=4,
    ):
        """
        PPO loss

        obs: dict with key state/rgb; more recent obs at the end
            state: (B, To, Do)
            rgb: (B, To, C, H, W)
        chains: (B, K+1, Ta, Da)
        returns: (B, )
        values: (B, )
        advantages: (B,)
        oldlogprobs: (B, K, Ta, Da)
        use_bc_loss: whether to add BC regularization loss
        reward_horizon: action horizon that backpropagates gradient
        """
        # Get new logprobs for denoising steps from T-1 to 0 - entropy is fixed for diffusion
        newlogprobs, eta = self.get_logprobs_subsample(
            obs,
            chains_prev,
            chains_next,
            denoising_inds,
            get_ent=True,
        )
        entropy_loss = -eta.mean()
        newlogprobs = newlogprobs.clamp(min=-5, max=2)
        oldlogprobs = oldlogprobs.clamp(min=-5, max=2)

        # Only backpropagate through the earlier action steps (i.e., the ones actually executed in the environment)
        newlogprobs = newlogprobs[:, :reward_horizon, :]
        oldlogprobs = oldlogprobs[:, :reward_horizon, :]

        # Average the logprobs over the action dimensions - the batch dimension covers B and the sampled denoising steps
        newlogprobs = newlogprobs.mean(dim=(-1, -2)).view(-1)
        oldlogprobs = oldlogprobs.mean(dim=(-1, -2)).view(-1)
        bc_loss = 0
        if use_bc_loss:
            # See Eqn. 2 of https://arxiv.org/pdf/2403.03949.pdf
            # Give a reward for maximizing probability of teacher policy's action with current policy.
            # Actions are chosen along trajectory induced by current policy.

            # Get counterfactual teacher actions
            samples = self.forward(
                cond=obs,
                deterministic=False,
                return_chain=True,
                use_base_policy=True,
            )
            # Get logprobs of teacher actions under this policy
            bc_logprobs = self.get_logprobs(
                obs,
                samples.chains,
                get_ent=False,
                use_base_policy=False,
            )
            bc_logprobs = bc_logprobs.clamp(min=-5, max=2)
            bc_logprobs = bc_logprobs.mean(dim=(-1, -2)).view(-1)
            bc_loss = -bc_logprobs.mean()
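        # In effect (a reading of the code above, not a statement from the paper):
        # bc_loss = -E[ log pi_theta(a_teacher | s) ], where a_teacher is sampled
        # from the frozen base (pre-trained) policy on the observations collected
        # by the current policy, so minimizing it keeps pi_theta close to the base policy.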
        # Normalize advantages within the batch
        if self.norm_adv:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # Clip advantages at the configured lower/upper quantiles (defaults 0 and 1, i.e., no clipping)
        advantage_min = torch.quantile(advantages, self.clip_advantage_lower_quantile)
        advantage_max = torch.quantile(advantages, self.clip_advantage_upper_quantile)
        advantages = advantages.clamp(min=advantage_min, max=advantage_max)

        # Weight advantages by the denoising discount
        discount = torch.tensor(
            [
                self.gamma_denoising ** (self.ft_denoising_steps - i - 1)
                for i in denoising_inds
            ]
        ).to(self.device)
        advantages *= discount
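        # For example (illustrative numbers, not defaults from this file): with
        # gamma_denoising=0.99 and ft_denoising_steps=10, a sample at denoising
        # index i=9 (the final, cleanest step) keeps its full advantage
        # (0.99**0 = 1.0), while a sample at i=0 (the noisiest fine-tuned step)
        # is down-weighted to 0.99**9 ~= 0.914 of its advantage.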
        # Get the importance ratio
        logratio = newlogprobs - oldlogprobs
        ratio = logratio.exp()

        # Exponentially interpolate between the base and the final clipping value across denoising steps
        if self.ft_denoising_steps > 1:
            t = (denoising_inds.float() / (self.ft_denoising_steps - 1)).to(self.device)
            clip_ploss_coef = self.clip_ploss_coef_base + (
                self.clip_ploss_coef - self.clip_ploss_coef_base
            ) * (torch.exp(self.clip_ploss_coef_rate * t) - 1) / (
                math.exp(self.clip_ploss_coef_rate) - 1
            )
        else:
            # With a single fine-tuned denoising step there is nothing to
            # interpolate, so use the final clipping value directly (this also
            # avoids dividing by ft_denoising_steps - 1 = 0).
            clip_ploss_coef = self.clip_ploss_coef
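        # A quick sanity check of the schedule above (a standalone sketch with
        # illustrative numbers, assuming clip_ploss_coef=0.01, base=1e-3, rate=3,
        # ft_denoising_steps=10):
        #
        #   import math, torch
        #   base, coef, rate, K_ft = 1e-3, 0.01, 3.0, 10
        #   t = torch.arange(K_ft).float() / (K_ft - 1)
        #   clip = base + (coef - base) * (torch.exp(rate * t) - 1) / (math.exp(rate) - 1)
        #
        # clip[0] equals the base value at the noisiest fine-tuned step and
        # clip[-1] equals clip_ploss_coef at the final step, growing exponentially
        # in between, so later (less noisy) steps are allowed larger policy updates.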
        # Diagnostics: approximate KL divergence and fraction of clipped ratios
        with torch.no_grad():
            # old_approx_kl: the approximate Kullback-Leibler divergence, measured by
            # (-logratio).mean(), which corresponds to the k1 estimator in John Schulman's
            # blog post on approximating KL: http://joschu.net/blog/kl-approx.html
            # approx_kl: better alternative to old_approx_kl, measured by
            # ((ratio - 1) - logratio).mean(), which corresponds to the k3 estimator
            # in the same post.
            # old_approx_kl = (-logratio).mean()
            approx_kl = ((ratio - 1) - logratio).mean()
            clipfrac = ((ratio - 1.0).abs() > clip_ploss_coef).float().mean().item()

        # Policy loss with clipping
        pg_loss1 = -advantages * ratio
        pg_loss2 = -advantages * torch.clamp(
            ratio, 1 - clip_ploss_coef, 1 + clip_ploss_coef
        )
        pg_loss = torch.max(pg_loss1, pg_loss2).mean()
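        # This is the standard PPO clipped surrogate objective written as a loss:
        #   L_clip = -E[ min( r * A, clip(r, 1 - eps, 1 + eps) * A ) ]
        # with r = ratio, A = the (discounted) advantages, and eps = clip_ploss_coef;
        # taking max over the two negated terms above is equivalent to negating the min.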
        # Value loss, optionally with clipping
        newvalues = self.critic(obs).view(-1)
        if self.clip_vloss_coef is not None:
            v_loss_unclipped = (newvalues - returns) ** 2
            v_clipped = oldvalues + torch.clamp(
                newvalues - oldvalues,
                -self.clip_vloss_coef,
                self.clip_vloss_coef,
            )
            v_loss_clipped = (v_clipped - returns) ** 2
            v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
            v_loss = 0.5 * v_loss_max.mean()
        else:
            v_loss = 0.5 * ((newvalues - returns) ** 2).mean()
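        # The clipped branch mirrors the value-loss clipping used in common PPO
        # implementations: the squared error is also computed with the new value
        # clipped to within +/- clip_vloss_coef of the old value, and the larger
        # of the two errors is penalized.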
        return (
            pg_loss,
            entropy_loss,
            v_loss,
            clipfrac,
            approx_kl.item(),
            ratio.mean().item(),
            bc_loss,
            eta.mean().item(),
        )
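# A hypothetical training-step sketch (not from this file): how the returned
# losses might be combined during a PPO update. The buffer fields, coefficient
# names (ent_coef, vf_coef, bc_coef), and `optimizer` are placeholders, not the
# repository's actual training-agent API.
#
#   pg_loss, entropy_loss, v_loss, clipfrac, approx_kl, ratio, bc_loss, eta = (
#       policy.loss(
#           obs=batch["obs"],
#           chains_prev=batch["chains_prev"],
#           chains_next=batch["chains_next"],
#           denoising_inds=batch["denoising_inds"],
#           returns=batch["returns"],
#           oldvalues=batch["values"],
#           advantages=batch["advantages"],
#           oldlogprobs=batch["logprobs"],
#           use_bc_loss=False,
#       )
#   )
#   total_loss = pg_loss + ent_coef * entropy_loss + vf_coef * v_loss + bc_coef * bc_loss
#   optimizer.zero_grad()
#   total_loss.backward()
#   optimizer.step()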