-
Notifications
You must be signed in to change notification settings - Fork 205
/
sde_lib.py
256 lines (208 loc) · 7.28 KB
/
sde_lib.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
"""Abstract SDE classes, Reverse SDE, and VE/VP SDEs."""
import abc
import jax.numpy as jnp
import jax
import numpy as np
from utils import batch_mul
class SDE(abc.ABC):
    """SDE abstract class. Functions are designed for a mini-batch of inputs.

    Subclasses supply the end time `T`, the drift/diffusion coefficients
    (`sde`), the marginal distribution parameters (`marginal_prob`) and the
    prior (`prior_sampling`, `prior_logp`). `discretize` and `reverse` are
    implemented here on top of those.
    """

    def __init__(self, N):
        """Construct an SDE.

        Args:
          N: number of discretization time steps.
        """
        super().__init__()
        self.N = N

    @property
    @abc.abstractmethod
    def T(self):
        """End time of the SDE."""
        pass

    @abc.abstractmethod
    def sde(self, x, t):
        """Return the `(drift, diffusion)` pair of the SDE evaluated at `(x, t)`."""
        pass

    @abc.abstractmethod
    def marginal_prob(self, x, t):
        """Parameters to determine the marginal distribution of the SDE, $p_t(x)$."""
        pass

    @abc.abstractmethod
    def prior_sampling(self, rng, shape):
        """Generate one sample from the prior distribution, $p_T(x)$."""
        pass

    @abc.abstractmethod
    def prior_logp(self, z):
        """Compute log-density of the prior distribution.

        Useful for computing the log-likelihood via probability flow ODE.

        Args:
          z: latent code

        Returns:
          log probability density
        """
        pass

    def discretize(self, x, t):
        """Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i.

        Useful for reverse diffusion sampling and probability flow sampling.
        Defaults to Euler-Maruyama discretization.

        Args:
          x: a JAX tensor.
          t: a JAX float representing the time step (from 0 to `self.T`)

        Returns:
          f, G
        """
        # N steps cover the interval [0, T], so one step has length T / N.
        # This equals the former hard-coded 1 / N whenever T == 1.
        dt = self.T / self.N
        drift, diffusion = self.sde(x, t)
        f = drift * dt
        G = diffusion * jnp.sqrt(dt)
        return f, G

    def reverse(self, score_fn, probability_flow=False):
        """Create the reverse-time SDE/ODE.

        Args:
          score_fn: A time-dependent score-based model that takes x and t and
            returns the score.
          probability_flow: If `True`, create the reverse-time ODE used for
            probability flow sampling.

        Returns:
          An instance of a subclass of `type(self)`, created on the fly, whose
          `sde` and `discretize` methods give the reverse-time coefficients.
        """
        # Capture the forward quantities in locals so the nested class reaches
        # them through the closure instead of through its own attributes.
        N = self.N
        T = self.T
        sde_fn = self.sde
        discretize_fn = self.discretize

        # Build the class for reverse-time SDE.
        class RSDE(self.__class__):
            """Reverse-time counterpart of the enclosing forward SDE."""

            def __init__(self):
                # The parent constructor is not invoked here; only `N` and the
                # ODE/SDE switch are stored on the instance.
                self.N = N
                self.probability_flow = probability_flow

            @property
            def T(self):
                return T

            def sde(self, x, t):
                """Create the drift and diffusion functions for the reverse SDE/ODE."""
                drift, diffusion = sde_fn(x, t)
                score = score_fn(x, t)
                # The score term is halved for the probability flow ODE.
                drift = drift - batch_mul(
                    diffusion ** 2,
                    score * (0.5 if self.probability_flow else 1.))
                # Set the diffusion function to zero for ODEs.
                diffusion = jnp.zeros_like(diffusion) if self.probability_flow else diffusion
                return drift, diffusion

            def discretize(self, x, t):
                """Create discretized iteration rules for the reverse diffusion sampler."""
                f, G = discretize_fn(x, t)
                rev_f = f - batch_mul(
                    G ** 2,
                    score_fn(x, t) * (0.5 if self.probability_flow else 1.))
                rev_G = jnp.zeros_like(G) if self.probability_flow else G
                return rev_f, rev_G

        return RSDE()
class VPSDE(SDE):
    """Variance Preserving SDE with beta(t) interpolated linearly in t."""

    def __init__(self, beta_min=0.1, beta_max=20, N=1000):
        """Construct a Variance Preserving SDE.

        Args:
          beta_min: value of beta(0)
          beta_max: value of beta(1)
          N: number of discretization steps
        """
        super().__init__(N)
        self.beta_0, self.beta_1 = beta_min, beta_max
        self.N = N

        # Per-step betas and the alpha tables derived from them.
        betas = jnp.linspace(beta_min / N, beta_max / N, N)
        self.discrete_betas = betas
        self.alphas = 1. - self.discrete_betas
        cumulative = jnp.cumprod(self.alphas, axis=0)
        self.alphas_cumprod = cumulative
        self.sqrt_alphas_cumprod = jnp.sqrt(cumulative)
        self.sqrt_1m_alphas_cumprod = jnp.sqrt(1. - cumulative)

    @property
    def T(self):
        """End time of the SDE."""
        return 1

    def sde(self, x, t):
        """Return the `(drift, diffusion)` pair at `(x, t)`."""
        beta_range = self.beta_1 - self.beta_0
        beta_t = self.beta_0 + t * beta_range
        drift = -0.5 * batch_mul(beta_t, x)
        return drift, jnp.sqrt(beta_t)

    def marginal_prob(self, x, t):
        """Return the `(mean, std)` pair computed from `x` and `t`."""
        quadratic_part = -0.25 * t ** 2 * (self.beta_1 - self.beta_0)
        linear_part = 0.5 * t * self.beta_0
        log_mean_coeff = quadratic_part - linear_part
        mean = batch_mul(jnp.exp(log_mean_coeff), x)
        std = jnp.sqrt(1 - jnp.exp(2. * log_mean_coeff))
        return mean, std

    def prior_sampling(self, rng, shape):
        """Draw standard normal noise of the given `shape` using key `rng`."""
        return jax.random.normal(rng, shape)

    def prior_logp(self, z):
        """Return the standard normal log-density of each leading-axis entry of `z`."""
        # Number of elements per entry (product of all non-leading axes).
        dim = np.prod(z.shape[1:])

        def _single_logp(sample):
            return -dim / 2. * jnp.log(2 * np.pi) - jnp.sum(sample ** 2) / 2.

        return jax.vmap(_single_logp)(z)

    def discretize(self, x, t):
        """DDPM discretization."""
        # Map continuous time onto an index into the precomputed tables.
        step = (t * (self.N - 1) / self.T).astype(jnp.int32)
        beta_i = self.discrete_betas[step]
        alpha_i = self.alphas[step]
        f = batch_mul(jnp.sqrt(alpha_i), x) - x
        G = jnp.sqrt(beta_i)
        return f, G
class subVPSDE(SDE):
    """Sub-VP SDE: same drift form as the VP SDE in this file, with a discounted diffusion."""

    def __init__(self, beta_min=0.1, beta_max=20, N=1000):
        """Construct the sub-VP SDE that excels at likelihoods.

        Args:
          beta_min: value of beta(0)
          beta_max: value of beta(1)
          N: number of discretization steps
        """
        super().__init__(N)
        self.beta_0, self.beta_1 = beta_min, beta_max
        self.N = N

    @property
    def T(self):
        """End time of the SDE."""
        return 1

    def sde(self, x, t):
        """Return the `(drift, diffusion)` pair at `(x, t)`."""
        beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0)
        drift = -0.5 * batch_mul(beta_t, x)
        # Factor that scales beta_t inside the square root of the diffusion.
        exponent = -2 * self.beta_0 * t - (self.beta_1 - self.beta_0) * t ** 2
        discount = 1. - jnp.exp(exponent)
        return drift, jnp.sqrt(beta_t * discount)

    def marginal_prob(self, x, t):
        """Return the `(mean, std)` pair computed from `x` and `t`."""
        quadratic_part = -0.25 * t ** 2 * (self.beta_1 - self.beta_0)
        linear_part = 0.5 * t * self.beta_0
        log_mean_coeff = quadratic_part - linear_part
        mean = batch_mul(jnp.exp(log_mean_coeff), x)
        # No square root is taken here, unlike a variance-to-std conversion.
        std = 1 - jnp.exp(2. * log_mean_coeff)
        return mean, std

    def prior_sampling(self, rng, shape):
        """Draw standard normal noise of the given `shape` using key `rng`."""
        return jax.random.normal(rng, shape)

    def prior_logp(self, z):
        """Return the standard normal log-density of each leading-axis entry of `z`."""
        # Number of elements per entry (product of all non-leading axes).
        dim = np.prod(z.shape[1:])

        def _single_logp(sample):
            return -dim / 2. * jnp.log(2 * np.pi) - jnp.sum(sample ** 2) / 2.

        return jax.vmap(_single_logp)(z)
class VESDE(SDE):
    """Variance Exploding SDE with sigma(t) interpolated geometrically in t."""

    def __init__(self, sigma_min=0.01, sigma_max=50, N=1000):
        """Construct a Variance Exploding SDE.

        Args:
          sigma_min: smallest sigma.
          sigma_max: largest sigma.
          N: number of discretization steps
        """
        super().__init__(N)
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        # N sigmas spaced evenly in log-space between sigma_min and sigma_max.
        log_sigmas = np.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N)
        self.discrete_sigmas = jnp.exp(log_sigmas)
        self.N = N

    @property
    def T(self):
        """End time of the SDE."""
        return 1

    def sde(self, x, t):
        """Return the `(drift, diffusion)` pair at `(x, t)`; the drift is zero."""
        sigma = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
        drift = jnp.zeros_like(x)
        log_gap = jnp.log(self.sigma_max) - jnp.log(self.sigma_min)
        diffusion = sigma * jnp.sqrt(2 * log_gap)
        return drift, diffusion

    def marginal_prob(self, x, t):
        """Return the `(mean, std)` pair; the mean is `x` itself."""
        std = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
        return x, std

    def prior_sampling(self, rng, shape):
        """Draw normal noise of the given `shape`, scaled by `sigma_max`."""
        noise = jax.random.normal(rng, shape)
        return noise * self.sigma_max

    def prior_logp(self, z):
        """Return the zero-mean, std-`sigma_max` normal log-density per leading-axis entry of `z`."""
        # Number of elements per entry (product of all non-leading axes).
        dim = np.prod(z.shape[1:])

        def _single_logp(sample):
            log_norm = -dim / 2. * jnp.log(2 * np.pi * self.sigma_max ** 2)
            return log_norm - jnp.sum(sample ** 2) / (2 * self.sigma_max ** 2)

        return jax.vmap(_single_logp)(z)

    def discretize(self, x, t):
        """SMLD(NCSN) discretization."""
        # Map continuous time onto an index into `discrete_sigmas`.
        step = (t * (self.N - 1) / self.T).astype(jnp.int32)
        sigma = self.discrete_sigmas[step]
        # At step 0 the previous sigma is taken to be zero; the `step - 1`
        # lookup is discarded by the `where` in that case.
        previous_sigma = jnp.where(
            step == 0,
            jnp.zeros_like(step),
            self.discrete_sigmas[step - 1])
        f = jnp.zeros_like(x)
        G = jnp.sqrt(sigma ** 2 - previous_sigma ** 2)
        return f, G