-
Notifications
You must be signed in to change notification settings - Fork 205
/
controllable_generation.py
301 lines (249 loc) · 13.5 KB
/
controllable_generation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
from models import utils as mutils
import jax.numpy as jnp
import jax
import jax.random as random
from sampling import NoneCorrector, NonePredictor, shared_corrector_update_fn, shared_predictor_update_fn
import functools
from utils import batch_mul
def get_pc_inpainter(sde, model, predictor, corrector, inverse_scaler, snr,
                     n_steps=1, probability_flow=False, continuous=False,
                     denoise=True, eps=1e-5):
  """Create an image inpainting function that uses PC samplers.

  Args:
    sde: An `sde_lib.SDE` object that represents the forward SDE.
    model: A `flax.linen.Module` object that represents the architecture of the score-based model.
    predictor: A subclass of `sampling.Predictor` that represents a predictor algorithm.
    corrector: A subclass of `sampling.Corrector` that represents a corrector algorithm.
    inverse_scaler: The inverse data normalizer.
    snr: A `float` number. The signal-to-noise ratio for the corrector.
    n_steps: An integer. The number of corrector steps per update of the predictor.
    probability_flow: If `True`, predictor solves the probability flow ODE for sampling.
    continuous: `True` indicates that the score-based model was trained with continuous time.
    denoise: If `True`, return the noise-free mean of the last update instead of the noisy sample.
    eps: A `float` number. The reverse-time SDE/ODE is integrated to `eps` for numerical stability.

  Returns:
    A pmapped inpainting function.
  """
  # Bind the sampler configuration once; the resulting callables are invoked
  # below as `update_fn(rng, state, x, vec_t)`.
  predictor_update_fn = functools.partial(shared_predictor_update_fn,
                                          sde=sde,
                                          model=model,
                                          predictor=predictor,
                                          probability_flow=probability_flow,
                                          continuous=continuous)
  corrector_update_fn = functools.partial(shared_corrector_update_fn,
                                          sde=sde,
                                          model=model,
                                          corrector=corrector,
                                          continuous=continuous,
                                          snr=snr,
                                          n_steps=n_steps)

  def get_inpaint_update_fn(update_fn):
    """Modify the update function of predictor & corrector to incorporate data information."""

    def inpaint_update_fn(rng, state, data, mask, x, t):
      rng, step_rng = jax.random.split(rng)
      vec_t = jnp.ones(data.shape[0]) * t
      x, x_mean = update_fn(step_rng, state, x, vec_t)
      # Perturb the observed data to the noise level of time `t` so that the
      # known region is statistically consistent with the current sample.
      masked_data_mean, std = sde.marginal_prob(data, vec_t)
      masked_data = masked_data_mean + batch_mul(jax.random.normal(rng, x.shape), std)
      x = x * (1. - mask) + masked_data * mask
      # Use the noise-free `x_mean` from the update for the unknown region.
      # Building it from the noisy `x` would discard the denoised estimate and
      # make `denoise=True` ineffective on the inpainted pixels.
      x_mean = x_mean * (1. - mask) + masked_data_mean * mask
      return x, x_mean

    return inpaint_update_fn

  predictor_inpaint_update_fn = get_inpaint_update_fn(predictor_update_fn)
  corrector_inpaint_update_fn = get_inpaint_update_fn(corrector_update_fn)

  def pc_inpainter(rng, state, data, mask):
    """Predictor-Corrector (PC) sampler for image inpainting.

    Args:
      rng: A JAX random state.
      state: A `flax.struct.dataclass` object that contains training state.
      data: A JAX array that represents a mini-batch of images to inpaint.
      mask: A {0, 1} array with the same shape of `data`. Value `1` marks known pixels,
        and value `0` marks pixels that require inpainting.

    Returns:
      Inpainted (complete) images.
    """
    # Initial sample: known pixels come from `data`, the rest from the prior.
    rng, step_rng = random.split(rng)
    x = data * mask + sde.prior_sampling(step_rng, data.shape) * (1. - mask)
    timesteps = jnp.linspace(sde.T, eps, sde.N)

    def loop_body(i, val):
      rng, x, x_mean = val
      t = timesteps[i]
      # One corrector update followed by one predictor update per time step,
      # each with its own fresh key.
      rng, step_rng = random.split(rng)
      x, x_mean = corrector_inpaint_update_fn(step_rng, state, data, mask, x, t)
      rng, step_rng = random.split(rng)
      x, x_mean = predictor_inpaint_update_fn(step_rng, state, data, mask, x, t)
      return rng, x, x_mean

    _, x, x_mean = jax.lax.fori_loop(0, sde.N, loop_body, (rng, x, x))
    return inverse_scaler(x_mean if denoise else x)

  return jax.pmap(pc_inpainter, axis_name='batch')
def get_pc_colorizer(sde, model, predictor, corrector, inverse_scaler,
                     snr, n_steps=1, probability_flow=False, continuous=False,
                     denoise=True, eps=1e-5):
  """Create an image colorization function based on Predictor-Corrector (PC) sampling.

  Args:
    sde: An `sde_lib.SDE` object that represents the forward SDE.
    model: A `flax.linen.Module` object that represents the architecture of the score model.
    predictor: A subclass of `sampling.Predictor` that represents a predictor algorithm.
    corrector: A subclass of `sampling.Corrector` that represents a corrector algorithm.
    inverse_scaler: The inverse data normalizer.
    snr: A `float` number. The signal-to-noise ratio for correctors.
    n_steps: An integer. The number of corrector steps per update of the predictor.
    probability_flow: If `True`, solve the probability flow ODE for sampling with the predictor.
    continuous: `True` indicates that the score-based model was trained with continuous time steps.
    denoise: If `True`, return the noise-free mean of the last update instead of the noisy sample.
    eps: A `float` number. The reverse-time SDE/ODE is integrated down to `eps` to avoid
      numerical instabilities.

  Returns:
    A pmapped colorization function.
  """
  # `M` is an orthonormal matrix to decouple image space to a latent space where the gray-scale image
  # occupies a separate channel
  M = jnp.asarray([[5.7735014e-01, -8.1649649e-01, 4.7008697e-08],
                   [5.7735026e-01, 4.0824834e-01, 7.0710671e-01],
                   [5.7735026e-01, 4.0824822e-01, -7.0710683e-01]])
  # Alternative matrix kept for reference (currently unused):
  # M = jnp.asarray([[0.28361226, 0.95408977, 0.09631611],
  #                  [0.95408977, -0.29083936, 0.07159032],
  #                  [0.09631611, 0.07159032, -0.9927729]])
  # `invM` is the inverse transformation of `M`
  invM = jnp.linalg.inv(M)

  def decouple(inputs):
    """Map a `BHWC` batch into the latent space by contracting the channel axis with `M`."""
    return jnp.einsum('BHWi,ij->BHWj', inputs, M)

  def couple(inputs):
    """The inverse function to `decouple` (contracts the channel axis with `invM`)."""
    return jnp.einsum('BHWi,ij->BHWj', inputs, invM)

  def get_mask(image):
    """Return a mask shaped like `image`: ones on the first channel, zeros on the others."""
    mask = jnp.concatenate([jnp.ones_like(image[..., :1]),
                            jnp.zeros_like(image[..., 1:])], axis=-1)
    return mask

  predictor_update_fn = functools.partial(shared_predictor_update_fn,
                                          sde=sde,
                                          model=model,
                                          predictor=predictor,
                                          probability_flow=probability_flow,
                                          continuous=continuous)
  corrector_update_fn = functools.partial(shared_corrector_update_fn,
                                          sde=sde,
                                          model=model,
                                          corrector=corrector,
                                          continuous=continuous,
                                          snr=snr,
                                          n_steps=n_steps)

  def get_colorization_update_fn(update_fn):
    """Modify update functions of predictor & corrector to incorporate information of gray-scale images."""

    def colorization_update_fn(rng, state, gray_scale_img, x, t):
      mask = get_mask(x)
      rng, step_rng = jax.random.split(rng)
      vec_t = jnp.ones(x.shape[0]) * t
      x, x_mean = update_fn(step_rng, state, x, vec_t)
      # Perturb the decoupled gray-scale observation to the noise level of `t`.
      masked_data_mean, std = sde.marginal_prob(decouple(gray_scale_img), vec_t)
      masked_data = masked_data_mean + batch_mul(jax.random.normal(rng, x.shape), std)
      # Replace the masked latent channel with the (noisy) observation and keep
      # the remaining latent channels from the sampler.
      x = couple(decouple(x) * (1. - mask) + masked_data * mask)
      # Use the noise-free `x_mean` from the update for the free channels.
      # Building it from the noisy `x` would discard the denoised estimate and
      # make `denoise=True` ineffective on them.
      x_mean = couple(decouple(x_mean) * (1. - mask) + masked_data_mean * mask)
      return x, x_mean

    return colorization_update_fn

  predictor_colorize_update_fn = get_colorization_update_fn(predictor_update_fn)
  corrector_colorize_update_fn = get_colorization_update_fn(corrector_update_fn)

  def pc_colorizer(rng, state, gray_scale_img):
    """Colorize gray-scale images using Predictor-Corrector (PC) sampler.

    Args:
      rng: A JAX random state.
      state: A `flax.struct.dataclass` object that represents the training state.
      gray_scale_img: A minibatch of gray-scale images. Their R,G,B channels have same values.

    Returns:
      Colorized images.
    """
    shape = gray_scale_img.shape
    mask = get_mask(gray_scale_img)
    # Initial sample. The mask selects a channel of the *latent* space, so it
    # must be applied after `decouple`; masking the prior in RGB space would
    # zero its red channel instead of the gray latent channel.
    rng, step_rng = random.split(rng)
    x = couple(decouple(gray_scale_img) * mask +
               decouple(sde.prior_sampling(step_rng, shape)) * (1. - mask))
    timesteps = jnp.linspace(sde.T, eps, sde.N)

    def loop_body(i, val):
      rng, x, x_mean = val
      t = timesteps[i]
      # One corrector update followed by one predictor update per time step,
      # each with its own fresh key.
      rng, step_rng = random.split(rng)
      x, x_mean = corrector_colorize_update_fn(step_rng, state, gray_scale_img, x, t)
      rng, step_rng = random.split(rng)
      x, x_mean = predictor_colorize_update_fn(step_rng, state, gray_scale_img, x, t)
      return rng, x, x_mean

    _, x, x_mean = jax.lax.fori_loop(0, sde.N, loop_body, (rng, x, x))
    return inverse_scaler(x_mean if denoise else x)

  return jax.pmap(pc_colorizer, axis_name='batch')
def get_pc_conditional_sampler(sde, score_model, classifier, classifier_params, shape,
                               predictor, corrector, inverse_scaler, snr,
                               n_steps=1, probability_flow=False,
                               continuous=False, denoise=True, eps=1e-5):
  """Build a class-conditional sampler driven by Predictor-Corrector (PC) updates.

  Args:
    sde: An `sde_lib.SDE` object that represents the forward SDE.
    score_model: A `flax.linen.Module` object that represents the architecture of the score-based model.
    classifier: A `flax.linen.Module` object that represents the architecture of the noise-dependent classifier.
    classifier_params: A dictionary that contains the weights of the classifier.
    shape: A sequence of integers passed to `sde.prior_sampling`; `shape[0]` is also used
      as the length of the per-sample time vector.
    predictor: A subclass of `sampling.predictor` that represents a predictor algorithm,
      or `None` to use `NonePredictor`.
    corrector: A subclass of `sampling.corrector` that represents a corrector algorithm,
      or `None` to use `NoneCorrector`.
    inverse_scaler: The inverse data normalizer.
    snr: A `float` number. The signal-to-noise ratio for correctors.
    n_steps: An integer. The number of corrector steps per update of the predictor.
    probability_flow: If `True`, solve the probability flow ODE for sampling with the predictor.
    continuous: `True` indicates the score-based model was trained with continuous time.
    denoise: If `True`, return the mean of the last update instead of the noisy sample.
    eps: A `float` number. The SDE/ODE will be integrated to `eps` to avoid numerical issues.

  Returns:
    A pmapped class-conditional image sampler.
  """
  # Gradient of the noise-dependent classifier, derived from its logits.
  classifier_grad_fn = mutils.get_classifier_grad_fn(
      mutils.get_logit_fn(classifier, classifier_params))

  # Resolve the fallbacks once; `None` selects the no-op predictor/corrector.
  predictor_cls = NonePredictor if predictor is None else predictor
  corrector_cls = NoneCorrector if corrector is None else corrector

  def make_guided_score_fn(state, labels):
    """Return `(x, t) -> score + classifier gradient` for the given state and labels."""
    base_score_fn = mutils.get_score_fn(sde, score_model, state.params_ema,
                                        state.model_state, train=False,
                                        continuous=continuous)

    def guided_score_fn(x, t):
      # The classifier is conditioned on the std of the perturbation kernel at `t`.
      noise_scale = sde.marginal_prob(x, t)[1]
      return base_score_fn(x, t) + classifier_grad_fn(x, noise_scale, labels)

    return guided_score_fn

  def conditional_predictor_update_fn(rng, state, x, t, labels):
    """Run one predictor step using the classifier-guided score."""
    guided = make_guided_score_fn(state, labels)
    return predictor_cls(sde, guided, probability_flow).update_fn(rng, x, t)

  def conditional_corrector_update_fn(rng, state, x, t, labels):
    """Run one corrector step using the classifier-guided score."""
    guided = make_guided_score_fn(state, labels)
    return corrector_cls(sde, guided, snr, n_steps).update_fn(rng, x, t)

  def pc_conditional_sampler(rng, score_state, labels):
    """Draw class-conditional samples with Predictor-Corrector (PC) updates.

    Args:
      rng: A JAX random state.
      score_state: A `flax.struct.dataclass` object that represents the training state
        of the score-based model.
      labels: A JAX array of integers that represent the target label of each sample.

    Returns:
      Class-conditional samples.
    """
    # Start from a draw of the prior.
    rng, init_rng = random.split(rng)
    x_init = sde.prior_sampling(init_rng, shape)
    timesteps = jnp.linspace(sde.T, eps, sde.N)

    def loop_body(i, carry):
      key, sample, _ = carry
      vec_t = jnp.ones(shape[0]) * timesteps[i]
      # Corrector first, then predictor, each with a fresh key.
      key, corrector_key = random.split(key)
      sample, sample_mean = conditional_corrector_update_fn(
          corrector_key, score_state, sample, vec_t, labels)
      key, predictor_key = random.split(key)
      sample, sample_mean = conditional_predictor_update_fn(
          predictor_key, score_state, sample, vec_t, labels)
      return key, sample, sample_mean

    _, x_final, x_final_mean = jax.lax.fori_loop(
        0, sde.N, loop_body, (rng, x_init, x_init))
    if denoise:
      return inverse_scaler(x_final_mean)
    return inverse_scaler(x_final)

  return jax.pmap(pc_conditional_sampler, axis_name='batch')