forked from dvschultz/stylegan2-ada
-
Notifications
You must be signed in to change notification settings - Fork 15
/
adaptive.py
399 lines (356 loc) · 19.3 KB
/
adaptive.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
# coding=utf-8
# Copyright 2019 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Implements the adaptive form of the loss.
You should only use this function if 1) you want the loss to change it's shape
during training (otherwise use general.py) or 2) you want to impose the loss on
a wavelet or DCT image representation, a only this function has easy support for
that.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from robust_loss import distribution
from robust_loss import util
from robust_loss import wavelet
def _check_scale(scale_lo, scale_init):
"""Helper function for checking `scale_lo` and `scale_init`."""
if not np.isscalar(scale_lo):
raise ValueError('`scale_lo` must be a scalar, but is of type {}'.format(
type(scale_lo)))
if not np.isscalar(scale_init):
raise ValueError('`scale_init` must be a scalar, but is of type {}'.format(
type(scale_init)))
if not scale_lo > 0:
raise ValueError('`scale_lo` must be > 0, but is {}'.format(scale_lo))
if not scale_init >= scale_lo:
raise ValueError('`scale_init` = {} must be >= `scale_lo` = {}'.format(
scale_init, scale_lo))
def _construct_scale(x, scale_lo, scale_init, float_dtype, var_suffix=''):
"""Helper function for constructing scale variables."""
if scale_lo == scale_init:
# If the difference between the minimum and initial scale is zero, then
# we just fix `scale` to be a constant.
scale = tf.tile(
tf.cast(scale_init, float_dtype)[tf.newaxis, tf.newaxis],
(1, x.shape[1]))
else:
# Otherwise we construct a "latent" scale variable and define `scale`
# As an affine function of a softplus on that latent variable.
latent_scale = tf.get_variable(
'LatentScale' + var_suffix, initializer=tf.zeros((1, x.shape[1]), float_dtype))
scale = util.affine_softplus(latent_scale, lo=scale_lo, ref=scale_init)
return scale
def lossfun(x,
alpha_lo=0.001,
alpha_hi=1.999,
alpha_init=None,
scale_lo=1e-5,
scale_init=1.,
var_suffix='',
**kwargs):
"""Computes the adaptive form of the robust loss on a matrix.
This function behaves differently from general.lossfun() and
distribution.nllfun(), which are "stateless", allow the caller to specify the
shape and scale of the loss, and allow for arbitrary sized inputs. This
function only allows for rank-2 inputs for the residual `x`, and expects that
`x` is of the form [batch_index, dimension_index]. This function then
constructs free parameters (TF variables) that define the alpha and scale
parameters for each dimension of `x`, such that all alphas are in
(`alpha_lo`, `alpha_hi`) and all scales are in (`scale_lo`, Infinity).
The assumption is that `x` is, say, a matrix where x[i,j] corresponds to a
pixel at location j for image i, with the idea being that all pixels at
location j should be modeled with the same shape and scale parameters across
all images in the batch. This function also returns handles to the scale and
shape parameters being optimized over, mostly for debugging and introspection.
If the user wants to fix alpha or scale to be a constant, this can be done by
setting alpha_lo=alpha_hi or scale_lo=scale_init respectively.
Args:
x: The residual for which the loss is being computed. Must be a rank-2
tensor, where the innermost dimension is the batch index, and the
outermost dimension corresponds to different "channels", where this
function will assign each channel its own variable shape (alpha) and scale
parameters that are constructed as TF variables and can be optimized over.
Must be a TF tensor or numpy array of single or double precision floats.
The precision of `x` will determine the precision of the latent variables
used to model scale and alpha internally.
alpha_lo: The lowest possible value for loss's alpha parameters, must be >=
0 and a scalar. Should probably be in (0, 2).
alpha_hi: The highest possible value for loss's alpha parameters, must be >=
alpha_lo and a scalar. Should probably be in (0, 2).
alpha_init: The value that the loss's alpha parameters will be initialized
to, must be in (`alpha_lo`, `alpha_hi`), unless `alpha_lo` == `alpha_hi`
in which case this will be ignored. Defaults to (`alpha_lo` + `alpha_hi`)
/ 2
scale_lo: The lowest possible value for the loss's scale parameters. Must be
> 0 and a scalar. This value may have more of an effect than you think, as
the loss is unbounded as scale approaches zero (say, at a delta function).
scale_init: The initial value used for the loss's scale parameters. This
also defines the zero-point of the latent representation of scales, so SGD
may cause optimization to gravitate towards producing scales near this
value.
**kwargs: Arguments to be passed to the underlying distribution.nllfun().
Returns:
A tuple of the form (`loss`, `alpha`, `scale`).
`loss`: a TF tensor of the same type and shape as input `x`, containing
the loss at each element of `x` as a function of `x`, `alpha`, and
`scale`. These "losses" are actually negative log-likelihoods (as produced
by distribution.nllfun()) and so they are not actually bounded from below
by zero. You'll probably want to minimize their sum or mean.
`scale`: a TF tensor of the same type as x, of size (1, x.shape[1]), as we
construct a scale variable for each dimension of `x` but not for each
batch element. This contains the current estimated scale parameter for
each dimension, and will change during optimization.
`alpha`: a TF tensor of the same type as x, of size (1, x.shape[1]), as we
construct an alpha variable for each dimension of `x` but not for each
batch element. This contains the current estimated alpha parameter for
each dimension, and will change during optimization.
Raises:
ValueError: If any of the arguments are invalid.
"""
_check_scale(scale_lo, scale_init)
if not np.isscalar(alpha_lo):
raise ValueError('`alpha_lo` must be a scalar, but is of type {}'.format(
type(alpha_lo)))
if not np.isscalar(alpha_hi):
raise ValueError('`alpha_hi` must be a scalar, but is of type {}'.format(
type(alpha_hi)))
if alpha_init is not None and not np.isscalar(alpha_init):
raise ValueError(
'`alpha_init` must be None or a scalar, but is of type {}'.format(
type(alpha_init)))
if not alpha_lo >= 0:
raise ValueError('`alpha_lo` must be >= 0, but is {}'.format(alpha_lo))
if not alpha_hi >= alpha_lo:
raise ValueError('`alpha_hi` = {} must be >= `alpha_lo` = {}'.format(
alpha_hi, alpha_lo))
if alpha_init is not None and alpha_lo != alpha_hi:
if not (alpha_init > alpha_lo and alpha_init < alpha_hi):
raise ValueError(
'`alpha_init` = {} must be in (`alpha_lo`, `alpha_hi`) = ({} {})'
.format(alpha_init, alpha_lo, alpha_hi))
float_dtype = x.dtype
assert_ops = [tf.Assert(tf.equal(tf.rank(x), 2), [tf.rank(x)])]
with tf.control_dependencies(assert_ops):
if alpha_lo == alpha_hi:
# If the range of alphas is a single item, then we just fix `alpha` to be
# a constant.
alpha = tf.tile(
tf.cast(alpha_lo, float_dtype)[tf.newaxis, tf.newaxis],
(1, x.shape[1]))
else:
# Otherwise we construct a "latent" alpha variable and define `alpha`
# As an affine function of a sigmoid on that latent variable, initialized
# such that `alpha` starts off as `alpha_init`.
if alpha_init is None:
alpha_init = (alpha_lo + alpha_hi) / 2.
latent_alpha_init = util.inv_affine_sigmoid(
alpha_init, lo=alpha_lo, hi=alpha_hi)
latent_alpha = tf.get_variable(
'LatentAlpha' + var_suffix,
initializer=tf.fill((1, x.shape[1]),
tf.cast(latent_alpha_init, dtype=float_dtype)))
alpha = util.affine_sigmoid(latent_alpha, lo=alpha_lo, hi=alpha_hi)
scale = _construct_scale(x, scale_lo, scale_init, float_dtype, var_suffix=var_suffix)
loss = distribution.nllfun(x, alpha, scale, **kwargs)
return loss, alpha, scale
def lossfun_students(x, scale_lo=1e-5, scale_init=1., var_suffix=''):
"""A variant of lossfun() that uses the NLL of a Student's t-distribution.
Args:
x: The residual for which the loss is being computed. Must be a rank-2
tensor, where the innermost dimension is the batch index, and the
outermost dimension corresponds to different "channels", where this
function will assign each channel its own variable shape (log-df) and
scale parameters that are constructed as TF variables and can be optimized
over. Must be a TF tensor or numpy array of single or double precision
floats. The precision of `x` will determine the precision of the latent
variables used to model scale and log-df internally.
scale_lo: The lowest possible value for the loss's scale parameters. Must be
> 0 and a scalar. This value may have more of an effect than you think, as
the loss is unbounded as scale approaches zero (say, at a delta function).
scale_init: The initial value used for the loss's scale parameters. This
also defines the zero-point of the latent representation of scales, so SGD
may cause optimization to gravitate towards producing scales near this
value.
Returns:
A tuple of the form (`loss`, `log_df`, `scale`).
`loss`: a TF tensor of the same type and shape as input `x`, containing
the loss at each element of `x` as a function of `x`, `log_df`, and
`scale`. These "losses" are actually negative log-likelihoods (as produced
by distribution.nllfun()) and so they are not actually bounded from below
by zero. You'll probably want to minimize their sum or mean.
`scale`: a TF tensor of the same type as x, of size (1, x.shape[1]), as we
construct a scale variable for each dimension of `x` but not for each
batch element. This contains the current estimated scale parameter for
each dimension, and will change during optimization.
`log_df`: a TF tensor of the same type as x, of size (1, x.shape[1]), as we
construct an log-DF variable for each dimension of `x` but not for each
batch element. This contains the current estimated log(degrees-of-freedom)
parameter for each dimension, and will change during optimization.
Raises:
ValueError: If any of the arguments are invalid.
"""
_check_scale(scale_lo, scale_init)
float_dtype = x.dtype
assert_ops = [tf.Assert(tf.equal(tf.rank(x), 2), [tf.rank(x)])]
with tf.control_dependencies(assert_ops):
log_df = tf.get_variable(
name='LogDf', initializer=tf.zeros((1, x.shape[1]), float_dtype))
scale = _construct_scale(x, scale_lo, scale_init, float_dtype, var_suffix=var_suffix)
loss = util.students_t_nll(x, tf.math.exp(log_df), scale)
return loss, log_df, scale
def image_lossfun(x,
color_space='YUV',
representation='CDF9/7',
wavelet_num_levels=5,
wavelet_scale_base=1,
use_students_t=False,
summarize_loss=True,
**kwargs):
"""Computes the adaptive form of the robust loss on a set of images.
This function is a wrapper around lossfun() above. Like lossfun(), this
function is not "stateless" --- it requires inputs of a specific shape and
size, and constructs TF variables describing each non-batch dimension in `x`.
`x` is expected to be the difference between sets of RGB images, and the other
arguments to this function allow for the color space and spatial
representation of `x` to be changed before the loss is imposed. By default,
this function uses a CDF9/7 wavelet decomposition in a YUV color space, which
often works well. This function also returns handles to the scale and
shape parameters (both in the shape of images) being optimized over,
and summarizes both parameters in TensorBoard.
Args:
x: A set of image residuals for which the loss is being computed. Must be a
rank-4 tensor of size (num_batches, width, height, color_channels). This
is assumed to be a set of differences between RGB images.
color_space: The color space that `x` will be transformed into before
computing the loss. Must be 'RGB' (in which case no transformation is
applied) or 'YUV' (in which case we actually use a volume-preserving
scaled YUV colorspace so that log-likelihoods still have meaning, see
util.rgb_to_syuv()). Note that changing this argument does not change the
assumption that `x` is the set of differences between RGB images, it just
changes what color space `x` is converted to from RGB when computing the
loss.
representation: The spatial image representation that `x` will be
transformed into after converting the color space and before computing the
loss. If this is a valid type of wavelet according to
wavelet.generate_filters() then that is what will be used, but we also
support setting this to 'DCT' which applies a 2D DCT to the images, and to
'PIXEL' which applies no transformation to the image, thereby causing the
loss to be imposed directly on pixels.
wavelet_num_levels: If `representation` is a kind of wavelet, this is the
number of levels used when constructing wavelet representations. Otherwise
this is ignored. Should probably be set to as large as possible a value
that is supported by the input resolution, such as that produced by
wavelet.get_max_num_levels().
wavelet_scale_base: If `representation` is a kind of wavelet, this is the
base of the scaling used when constructing wavelet representations.
Otherwise this is ignored. For image_lossfun() to be volume preserving (a
useful property when evaluating generative models) this value must be ==
1. If the goal of this loss isn't proper statistical modeling, then
modifying this value (say, setting it to 0.5 or 2) may significantly
improve performance.
use_students_t: If true, use the NLL of Student's T-distribution instead
of the adaptive loss. This causes all `alpha_*` inputs to be ignored.
summarize_loss: Whether or not to make TF summaries describing the latent
state of the loss function. True by default.
**kwargs: Arguments to be passed to the underlying lossfun().
Returns:
A tuple of the form (`loss`, `alpha`, `scale`). If use_students_t == True,
then `log(df)` is returned instead of `alpha`.
`loss`: a TF tensor of the same type and shape as input `x`, containing
the loss at each element of `x` as a function of `x`, `alpha`, and
`scale`. These "losses" are actually negative log-likelihoods (as produced
by distribution.nllfun()) and so they are not actually bounded from below
by zero. You'll probably want to minimize their sum or mean.
`scale`: a TF tensor of the same type as x, of size
(width, height, color_channels),
as we construct a scale variable for each spatial and color dimension of `x`
but not for each batch element. This contains the current estimated scale
parameter for each dimension, and will change during optimization.
`alpha`: a TF tensor of the same type as x, of size
(width, height, color_channels),
as we construct an alpha variable for each spatial and color dimension of
`x` but not for each batch element. This contains the current estimated
alpha parameter for each dimension, and will change during optimization.
Raises:
ValueError: if `color_space` of `representation` are unsupported color
spaces or image representations, respectively.
"""
color_spaces = ['RGB', 'YUV']
if color_space not in color_spaces:
raise ValueError('`color_space` must be in {}, but is {!r}'.format(
color_spaces, color_space))
representations = wavelet.generate_filters() + ['DCT', 'PIXEL']
if representation not in representations:
raise ValueError('`representation` must be in {}, but is {!r}'.format(
representations, representation))
assert_ops = [tf.Assert(tf.equal(tf.rank(x), 4), [tf.rank(x)])]
with tf.control_dependencies(assert_ops):
if color_space == 'YUV':
x = util.rgb_to_syuv(x)
# If `color_space` == 'RGB', do nothing.
# Reshape `x` from
# (num_batches, width, height, num_channels) to
# (num_batches * num_channels, width, height)
_, width, height, num_channels = x.shape.as_list()
x_stack = tf.reshape(tf.transpose(x, (0, 3, 1, 2)), (-1, width, height))
# Turn each channel in `x_stack` into the spatial representation specified
# by `representation`.
if representation in wavelet.generate_filters():
x_stack = wavelet.flatten(
wavelet.rescale(
wavelet.construct(x_stack, wavelet_num_levels, representation),
wavelet_scale_base))
elif representation == 'DCT':
x_stack = util.image_dct(x_stack)
# If `representation` == 'PIXEL', do nothing.
# Reshape `x_stack` from
# (num_batches * num_channels, width, height) to
# (num_batches, num_channels * width * height)
x_mat = tf.reshape(
tf.transpose(
tf.reshape(x_stack, [-1, num_channels, width, height]),
[0, 2, 3, 1]), [-1, width * height * num_channels])
# Set up the adaptive loss. Note, if `use_students_t` == True then
# `alpha_mat` actually contains "log(df)" values.
if use_students_t:
loss_mat, alpha_mat, scale_mat = lossfun_students(x_mat, **kwargs)
else:
loss_mat, alpha_mat, scale_mat = lossfun(x_mat, **kwargs)
# Reshape the loss function's outputs to have the shapes as the input.
loss = tf.reshape(loss_mat, [-1, width, height, num_channels])
alpha = tf.reshape(alpha_mat, [width, height, num_channels])
scale = tf.reshape(scale_mat, [width, height, num_channels])
if summarize_loss:
# Summarize the `alpha` and `scale` parameters as images (normalized to
# [0, 1]) and histograms.
# Note that these may look unintuitive unless the colorspace is 'RGB' and
# the image representation is 'PIXEL', as the image summaries (like most
# images) are rendered as RGB pixels.
alpha_min = tf.reduce_min(alpha)
alpha_max = tf.reduce_max(alpha)
tf.summary.image(
'robust/alpha',
(alpha[tf.newaxis] - alpha_min) / (alpha_max - alpha_min + 1e-10))
tf.summary.histogram('robust/alpha', alpha)
log_scale = tf.math.log(scale)
log_scale_min = tf.reduce_min(log_scale)
log_scale_max = tf.reduce_max(log_scale)
tf.summary.image('robust/log_scale',
(log_scale[tf.newaxis] - log_scale_min) /
(log_scale_max - log_scale_min + 1e-10))
tf.summary.histogram('robust/log_scale', log_scale)
return loss, alpha, scale