\documentclass{article}
\title{Bijective networks}
\author{Arthur Breitman}
\date{\today}
\usepackage{amsfonts}
\usepackage{amsmath}
\begin{document}
\maketitle
\begin{abstract}
We introduce bijective networks -- thus named because they form a bijection
between their input and output space -- as a means of parametrizing
multivariate probability distributions. We present applications to probability
density estimation and variational inference. In these networks, the quantity
of interest is primarily the determinant of the Jacobian of the network, which
we interpret as a probability density. We also describe a mechanism to
efficiently train bijective networks.
\end{abstract}
\section{Bijective networks}
\subsection{Definition}
We define a bijective network as a triplet
\((\mathcal{I}, \mathcal{O}, f : \mathcal{I} \to \mathcal{O})\)
where \(f\) is a diffeomorphism. We use the term ``bijective networks''
because ``diffeomorphic networks'' doesn't quite have the same ring to it,
and differentiability is generally implied in neural networks\footnote{The
requirement can also be weakened to accommodate activation functions like
\textrm{ReLU} which are differentiable almost everywhere.}.
Throughout this paper, differentiability is assumed and we use ``bijective''
and ``diffeomorphic'' interchangeably; we also refer to the determinant of the
Jacobian matrix of that diffeomorphism simply as ``the Jacobian'' and the
logarithm of that determinant as ``the log-Jacobian''.
\subsection{Motivation}
A rich, parametric model of arbitrary probability distributions is a useful
building block for many statistical learning algorithms, such as variational
inference or density estimation.
The map \(f\) induces a pull-back measure on \(\mathcal{I}\), with
\(\left|J_f\right| = \frac{d \mathcal{O}}{d \mathcal{I}}\); normalized by the
volume of \(\mathcal{O}\), \(\left|J_f\right|\) can be interpreted as a
probability density over \(\mathcal{I}\). If \(\mathcal{O}\) is a hypercube,
\(f\) acts, essentially, as a copula transform.
\subsection{Construction}
\subsubsection{Layers \& activation functions}
Concretely, we implement bijective networks as a composition of fully
connected layers of \(n\) neurons each. Note that \(n\) is the dimension of
the input layer as well as that of the output layer. If the problem's
dimensionality is too low and limits the expressivity of the network,
it is possible to increase \(n\) by introducing dummy, independent variables
and marginalizing over them later.
Hidden layers take their values in \(\mathbb{R}^n\) and the output layer in
\((-1,1)^n\).
The input layer can take its values in a variety of domains,
provided they are diffeomorphic to \(\mathbb{R}^n\). These could be Cartesian
products of \(\mathbb{R}\) and intervals, where intervals are mapped to
\(\mathbb{R}\) via an affine transform and \(\mathrm{arctanh}\) or the probit
function, but they could just as easily be spheres, simplices, etc.
In what follows we assume the input layer takes its values in \(\mathbb{R}^n\),
without loss of generality. For the sake of concreteness, we specify the
non-linear activation functions used in the construction, but other appropriate
functions (e.g.\ leaky ReLU) may of course be substituted.
We use \(\mathrm{arsinh}\) as an unbounded close cousin of the traditional
sigmoid. Since its asymptotic behavior is logarithmic, it can easily lead to
problematic vanishing gradients; we therefore use
\(\mathrm{sinh}\) on every other layer, though careful initialization might
solve that problem by itself. The final output layer squashes the values to
\((-1,1)^n\) by applying \(\mathrm{tanh}\). All weight matrices are,
and remain, invertible. As a finite composition of diffeomorphisms, the network
is, itself, a diffeomorphism.
\subsubsection{Notation}
The input layer is represented as a vector \(x\), the \(2m\) hidden layers as
\(h_i\), \(i \in 0 \ldots 2m-1\), and the output layer is denoted \(y'\). In general,
if \(l\) is a layer, \(l'\) designates the ``activated'' layer, which has been
passed through the activation function.
All of these vectors have dimension \(n\). For ease of notation, we also let
\(h'_{-1} = x\).
\[
\left\{
\begin{aligned}
h_{i} &= W_{i} \cdot h'_{i-1} + b_{i}, \forall i \in 0 \ldots 2m-1 \\
h'_{2k} &= \mathrm{arsinh}(h_{2k}), \forall k \in 0 \ldots m-1 \\
h'_{2k+1} &= \mathrm{sinh}(h_{2k+1}), \forall k \in 0 \ldots m-1\\
y &= W_{2m} \cdot h'_{2m-1} + b_{2m} \\
y' &= \mathrm{tanh}(y) \\
\end{aligned}
\right.
\]
\subsection{Computing}
\subsubsection{Computing the Jacobian}
The Jacobian matrix of the network is the product of the Jacobian matrices of
each layer. The Jacobian matrix of the transformation \( u \mapsto W u + b \)
is simply \(W\); the Jacobian matrix of the coordinate-wise non-linear
transforms is the diagonal matrix whose entries are the derivatives of the
activation function at the matching coordinates.
More specifically, we are interested in the log of the determinant of the
Jacobian matrix (the log-Jacobian). The log-Jacobian of the network we defined
is exactly
\[
J(x) = \sum_{i=0}^{2m} \log |W_{i}| +
\sum_{j=1}^n \log(1-{y'_j}^2)
+ \sum_{k=0}^{m-1} \sum_{j=1}^n
\frac{1}{2} \log \left(\frac{1 + (h'_{2k+1})_j^2}{1 + (h_{2k})_j^2}\right)
\]
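For concreteness, the following listing sketches the forward pass and this
log-Jacobian computation. It is written in JAX; the parameter layout (a list of
\((W_i, b_i)\) pairs) and the use of \texttt{slogdet} for \(\log |W_i|\) are
choices made for the illustration rather than part of the construction.
\begin{verbatim}
# A minimal JAX sketch of the forward pass and log-Jacobian defined above.
# The parameter layout is an illustrative assumption.
import jax.numpy as jnp

def forward_and_log_jacobian(params, x):
    """params: list of (W, b) pairs; 2m hidden layers followed by the
    output layer. Returns (y', log-Jacobian at x)."""
    h, log_jac = x, 0.0
    for i, (W, b) in enumerate(params[:-1]):
        pre = W @ h + b                          # h_i = W_i h'_{i-1} + b_i
        log_jac += jnp.linalg.slogdet(W)[1]      # log |W_i|
        if i % 2 == 0:                           # arsinh on even layers
            h = jnp.arcsinh(pre)
            log_jac -= 0.5 * jnp.sum(jnp.log1p(pre ** 2))
        else:                                    # sinh on odd layers
            h = jnp.sinh(pre)
            log_jac += 0.5 * jnp.sum(jnp.log1p(h ** 2))
    W, b = params[-1]
    y = W @ h + b                                # y = W_2m h'_{2m-1} + b_2m
    log_jac += jnp.linalg.slogdet(W)[1]
    y_out = jnp.tanh(y)                          # y' = tanh(y)
    log_jac += jnp.sum(jnp.log(1.0 - y_out ** 2))
    return y_out, log_jac
\end{verbatim}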
\subsubsection{Breaking up matrices}
We are interested in computing the gradient of the log Jacobian with respect to
our parameters, that is the weight matrices and bias vectors. The chain rule
can be used, as in regular neural networks, with the only peculiarity that we
need to take derivatives of the log determinant of matrices with respect to
themselves.
This derivative exists in closed form through the identity
\(\frac{d \log |W|}{d W} = (W^{-1})^\top\). It looks as
though we might need to compute the inverse of the weight matrix, a costly
operation\footnote{\(\mathcal{O}(n^{2.4\ldots})\) for the optimist, but
\(\mathcal{O}(n^3)\) for the realist.}
each time we compute a gradient. Fortunately, we can, without loss of
generality, represent \(W\) as the product of \(n(n-1)\) matrices that each
operate on just two coordinates. To see why that is the case, consider that
the Gaussian elimination algorithm expresses the inverse of a matrix as the
product of such elementary matrices.
Concretely, instead of working with a full representation of the matrix \(W\) we
maintain a list of \(n(n-1)\) \((2 \times 2)\) matrices. The time complexity of
the matrix vector product remains \(\mathcal{O}(n^2)\), as if we were
multiplying directly by \(W\), but all other operations
(inverse, gradient updates, computation of the log Jacobian) can now
also be computed in \(\mathcal{O}(n^2)\).
One downside is losing optimized BLAS methods for matrix-vector products.
Additionally, large batch sizes cannot benefit from Strassen multiplication
with this representation.
TODO: explicitly derive gradients.
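The following listing sketches this representation in the same style as the
previous one. The \texttt{(i, j, M)} layout, with \(M\) a \(2 \times 2\) block
acting on coordinates \(i\) and \(j\), is an assumption made for the example;
the point is that the matrix-vector product, the log-determinant and the
inverse are all \(\mathcal{O}(n^2)\).
\begin{verbatim}
# Illustrative sketch: a weight matrix stored as a product of elementary
# matrices, each acting on two coordinates. The (i, j, M) layout is an
# assumption made for this example.
import jax.numpy as jnp

def apply_ops(ops, v):
    """Matrix-vector product with W: O(1) work per elementary matrix,
    O(n^2) overall for n(n-1) of them."""
    for i, j, M in ops:
        vi, vj = v[i], v[j]
        v = v.at[i].set(M[0, 0] * vi + M[0, 1] * vj)
        v = v.at[j].set(M[1, 0] * vi + M[1, 1] * vj)
    return v

def log_abs_det(ops):
    """log |W| is the sum of the 2x2 block log-determinants: O(n^2)."""
    return sum(jnp.log(jnp.abs(jnp.linalg.det(M))) for _, _, M in ops)

def inverse_ops(ops):
    """W^{-1}: invert each 2x2 block and reverse the order: O(n^2)."""
    return [(i, j, jnp.linalg.inv(M)) for i, j, M in reversed(ops)]
\end{verbatim}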
\section{Application to density estimation}
Assume a sample of \(N\) points, each in \(\mathbb{R}^n\). We are interested in
finding a probability distribution that fits that sample. Since a Dirac comb,
the empirical distribution, will trivially do the job, we need some sort of
regularization. In this case, we try to represent that probability distribution
with a bijective network, and the regularization comes from the limited
capacity of the network and the use of the stochastic gradient descent
algorithm with early stopping.
One way to look at the problem of density estimation is through the lens of
the minimum description length principle. Suppose that, to compress the
sample, we first pass it through an invertible function, and compress the
result by reducing the precision and truncating the output.
The more we can truncate the output without losing our ability to recover the
input within a certain tolerance, the better the compression.
Therefore, we would like that function to expand space in the regions around
the sample points and, consequently, to contract it in other regions, since
the image of the input domain is a hypercube of fixed volume.
This is equivalent to saying that the sum of the log-Jacobian, evaluated at
the sample points, should be maximized.
The Jacobian (divided by \(2^n\)) is a probability distribution
over the input space which smoothly approximates the empirical distribution.
We are essentially training the network to approximate a copula which evenly
spreads out the sample over a hypercube.
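A minimal sketch of the corresponding training objective, reusing
\texttt{forward\_and\_log\_jacobian} from the earlier listing: we minimize the
negative average log-density \(\log\left(\left|J_f(x)\right| / 2^n\right)\)
over mini-batches of sample points. The step size is an arbitrary choice, and
the sketch glosses over keeping the weight matrices invertible (which the
two-coordinate representation above takes care of).
\begin{verbatim}
# Illustrative density-estimation objective: minimize the negative average
# log-density |J_f(x)| / 2^n at the sample points. The step size and batching
# are assumptions; forward_and_log_jacobian is the sketch given earlier.
import jax
import jax.numpy as jnp

def neg_log_likelihood(params, batch, n):
    log_jacs = jax.vmap(
        lambda x: forward_and_log_jacobian(params, x)[1])(batch)
    return -(jnp.mean(log_jacs) - n * jnp.log(2.0))

def sgd_step(params, batch, n, lr=1e-3):
    grads = jax.grad(neg_log_likelihood)(params, batch, n)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
\end{verbatim}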
\section{Application to variational inference}
\subsection{Variational inference}
In this setup, we are given a prior \(P(Z)\) over some latent variable \(Z\),
a generative model \(P(X | Z)\) and a sample \(\overline{X}\). We would like to
estimate \(P(Z | \overline{X})\). Markov chain Monte Carlo techniques let us
sample from \(P(Z | \overline{X})\), but they tend to be slow,
hard to diagnose, and can suffer from poor convergence.
A popular approximate technique is variational inference. In this approach, a
parametric distribution \(Q_{\theta}\) is optimized to minimize the KL
divergence between \(Q_{\theta}\) and \(P(Z | X)\); up to an additive
constant, this amounts to minimizing
\[
\int Q_{\theta}(Z) \log \frac{Q_{\theta}(Z)}{P(Z, X)} \mathrm{d}Z
\]
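For reference, this objective differs from \(D_{KL}(Q_{\theta}(Z) \| P(Z|X))\)
only by the model evidence, which does not depend on \(\theta\):
\[
\int Q_{\theta}(Z) \log \frac{Q_{\theta}(Z)}{P(Z, X)} \mathrm{d}Z
= D_{KL}\left(Q_{\theta}(Z) \,\|\, P(Z | X)\right) - \log P(X)
\]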
Typically, \(Q_{\theta}\) is taken to have a simple parametric form that factors
into a product of independent distributions over the individual dimensions,
which renders the integration tractable and lends itself to fast optimization
algorithms.
However, \(Q_{\theta}\) is oftentimes a rather poor approximation of
\(P(Z | X)\) because the parametric family is too rigid to represent
distributions close to the shape of the true posterior.
\subsection{Variational inference, with bijective networks}
The Jacobian of a bijective network (divided by \(2^n\))
is a probability distribution. Its parametric representation is rich
enough to fit complex posteriors. What's more, it is easy to sample from it,
by drawing an output value uniformly at random and computing its preimage.
To our knowledge, no other model offers:
\begin{itemize}
\item A rich parametrization capable of representing any probability
distribution
\item Efficient sampling
\item Knowledge of the partition function (the constant \(2^n\) in our case)
\end{itemize}
The quantity we wish to minimize is
\[
\mathcal{L} = \mathbb{E}_{Q(Z)} \left(\log \frac{Q(Z)}{P(Z,X)} \right)
\]
By performing the change of variable \(U = f(Z)\), and using
\(Q(Z) = \left|J_{f}(Z)\right| / 2^n\),
\[
\begin{aligned}
\mathcal{L} &= \int \frac{Q(f^{-1}(u))}{\left|J_{f}(f^{-1}(u))\right|}
\log \frac{Q(f^{-1}(u))}{P(f^{-1}(u),X)} \mathrm{d}u\\
&= \frac{1}{2^n} \int \log \frac{Q(f^{-1}(u))}{P(f^{-1}(u),X)} \mathrm{d}u
\end{aligned}
\]
The latter quantity is an integral over \(U\), which we can minimize using
stochastic gradient descent. Our algorithm therefore broadly consists of the
following loop (sketched in code below):
\begin{enumerate}
\item Draw \(U_i\) uniformly at random in \((-1,1)^n\)
\item Let \(Z_i = f_{\theta}^{-1}(U_i)\) and \(Q_\theta(Z_i) = \left|J_{f_\theta}(Z_i)\right| / 2^n\)
\item Take a gradient step to minimize \(\log \frac{Q_\theta(Z_i)}{P(Z_i,X)}\)
\end{enumerate}
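The listing below sketches this loop in the same style as before, reusing
\texttt{forward\_and\_log\_jacobian}; the \texttt{inverse} function and the
\texttt{log\_p} callback (standing in for \(\log P(Z,X)\)) are assumptions made
for the example. Note that the gradient flows through
\(Z_i = f_{\theta}^{-1}(U_i)\).
\begin{verbatim}
# Illustrative variational-inference step. `inverse` and `log_p` (standing in
# for log P(Z, X)) are assumptions made for this sketch;
# forward_and_log_jacobian is the sketch given earlier.
import jax
import jax.numpy as jnp

def inverse(params, u):
    """z = f^{-1}(u): undo each layer of the network in reverse order."""
    h = jnp.arctanh(u)                        # undo tanh
    W, b = params[-1]
    h = jnp.linalg.solve(W, h - b)            # undo the output linear layer
    for i in reversed(range(len(params) - 1)):
        h = jnp.arcsinh(h) if i % 2 == 1 else jnp.sinh(h)  # undo sinh / arsinh
        W, b = params[i]
        h = jnp.linalg.solve(W, h - b)        # undo the linear layer
    return h

def vi_loss(params, u, log_p, n):
    z = inverse(params, u)                    # Z = f^{-1}(U), depends on theta
    log_jac = forward_and_log_jacobian(params, z)[1]
    log_q = log_jac - n * jnp.log(2.0)        # log Q(Z) = log |J_f(Z)| - n log 2
    return log_q - log_p(z)                   # log Q(Z) / P(Z, X)

def vi_step(params, key, log_p, n, lr=1e-3):
    u = jax.random.uniform(key, (n,), minval=-1.0, maxval=1.0)
    grads = jax.grad(vi_loss)(params, u, log_p, n)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
\end{verbatim}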
At first glance, it might seem that \(P\) doesn't appear in the derivative
of the cost function with respect to the network weights, which would mean the
algorithm has no chance of working! In fact, the dependency on
\(P\) appears when we take into account the fact that \(Z_i\) depends on the
network weights (it's essentially the reparametrization trick).
Crucially, this means that \(P(Z,X)\) must be differentiable with respect to
\(Z\).
Note that:
\begin{itemize}
\item We approximate \(\mathcal{L}\) by sampling from a \emph{fixed} distribution.
\item There is no overfitting! The lower \(\mathcal{L}\), the better the approximation
of the true posterior. Therefore, it's OK to build very large networks and to run
stochastic gradient descent all the way.
\end{itemize}
\subsection{What if \(P\) isn't differentiable?}
If \(P\) isn't differentiable, we can use a variant of the previous algorithm,
though convergence is harder to prove.
Let \(Q_{\theta_i}\) be the distribution \(Q_{\theta}\) at timestep \(i\).
We use \(Q_{\theta_i}\) as a proposal distribution to estimate \(\mathcal{L}\)
using Monte Carlo integration. Therefore, we draw \(Z_{i+1}\) from
\(Q_{\theta_i}\) and then take a gradient step to minimize
\[
\frac{Q_{\theta_{i+1}}(Z_{i+1})}{Q_{\theta_i}(Z_{i+1})} \log
\frac{Q_{\theta_{i+1}}(Z_{i+1})}{P(Z_{i+1},X)}
\]
Note that both \(Z_{i+1}\) and \(Q_{\theta_i}(Z_{i+1})\) are \emph{constant}
with respect to \(\theta_{i+1}\), the parameter being optimized. The derivative
with respect to \(\theta_{i+1}\) yields:
\[
\frac{1}{Q_{\theta_i}(Z_{i+1})}
\frac{\partial Q_{\theta_{i+1}}(Z_{i+1})}{\partial \theta_{i+1}}
\left( \log \frac{Q_{\theta_{i+1}}(Z_{i+1})}{P(Z_{i+1},X)} + 1 \right)
\]
The sign changes depending on whether \(Q_{\theta_{i+1}}(Z_{i+1})\) is greater
or smaller than \(P(Z_{i+1},X) / e\). That factor of \(e\) is hard to explain
and may suggest a problem with this approach. It may also be problematic that
the proposal distribution keeps changing over time.
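The following sketch, under the same assumptions as the previous listings,
implements this variant: \(Z_{i+1}\) and \(Q_{\theta_i}(Z_{i+1})\) are held
fixed with a stop-gradient, so differentiating the surrogate loss below
reproduces the derivative above; \(P\) only needs to be evaluated, not
differentiated.
\begin{verbatim}
# Illustrative variant for a non-differentiable P: the sample Z and the
# proposal density Q_{theta_i}(Z) are held constant with stop_gradient, so
# only Q_{theta_{i+1}} is differentiated. `inverse` and
# forward_and_log_jacobian are the earlier sketches.
import jax
import jax.numpy as jnp
from jax.lax import stop_gradient

def surrogate_loss(params, u, log_p, n):
    z = stop_gradient(inverse(params, u))            # Z_{i+1}, held constant
    log_q = forward_and_log_jacobian(params, z)[1] - n * jnp.log(2.0)
    log_q_old = stop_gradient(log_q)                 # log Q_{theta_i}(Z_{i+1})
    weight = jnp.exp(log_q - log_q_old)              # importance ratio, value 1
    return weight * (log_q - log_p(z))               # matches the objective above

def step(params, key, log_p, n, lr=1e-3):
    u = jax.random.uniform(key, (n,), minval=-1.0, maxval=1.0)
    grads = jax.grad(surrogate_loss)(params, u, log_p, n)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
\end{verbatim}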
\subsection{Switching the KL divergence}
In fact, since we are not attempting to find closed-form formulas for
integrating over \(Q(Z)\), we do not have to use the variational inference
trick of minimizing \(D_{KL}(Q(Z)||P(Z|\overline{X}))\); we can instead tackle
\(D_{KL}(P(Z|\overline{X})||Q(Z))\), which is a more natural metric. We note that
\[
\begin{aligned}
& D_{KL}(P(Z|\overline{X})|| Q(Z)) \\
&= \int_Z P(z|\overline{X}) \log \frac{P(z|\overline{X})}{Q(z)} \mathrm{d}z \\
&= \int_Z Q(z) \frac{P(z|\overline{X})}{Q(z)}
\log \frac{P(z|\overline{X})}{Q(z)} \mathrm{d}z \\
&= \int_Z Q(z) \frac{P(\overline{X}, z)}{P(\overline{X})Q(z)}
\log \frac{P(\overline{X}, z)}{P(\overline{X})Q(z)} \mathrm{d}z \\
&= \frac{1}{2^n} \int\frac{P(\overline{X}, f^{-1}(u))}{
P(\overline{X})Q(f^{-1}(u))}
\log \frac{P(\overline{X}, f^{-1}(u))}{Q(f^{-1}(u))} \mathrm{d}u
- \int P(z|\overline{X}) \log P(\overline{X}) \mathrm{d}z
\end{aligned}
\]
We can minimize this metric by repeatedly sampling \(z_i\) from \(Q\),
as before, and minimizing
\[
\frac{P(\overline{X}, z_i)}{Q(z_i)} \log \frac{P(\overline{X}, z_i)}{Q(z_i)}
\]
instead of \(\log(Q(z_i)/P(z_i, X))\).
TODO: the gradient is 0 for \(Q(z_i) = e P(z_i)\); something's weird, this is wrong.
\section{Empirical results}
Test this thing
\section{Conclusion}
Bijective neural networks can be trained to fit distributions instead of
functions, and they admit a backpropagation algorithm with the same asymptotic
complexity as standard networks. This opens interesting possibilities for Bayesian inference and
density estimation.
\end{document}