-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: inverse_nn.py
68 lines (52 loc) · 1.74 KB
/
inverse_nn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch as th
import torch.nn as nn
from numpy import allclose
from tqdm import tqdm
th.manual_seed(1337)
def print(*xs):
    """Shadow the builtin print so output interleaves cleanly with tqdm bars."""
    message = " ".join(map(str, xs))
    tqdm.write(message)
def inv_tanh(x):
    """Return the elementwise inverse of tanh (artanh) of *x*.

    Uses torch's built-in ``atanh`` instead of the hand-rolled
    ``(log1p(x) - log1p(-x)) / 2`` formula — same mathematics, but
    implemented in C and numerically stable as ``|x|`` approaches 1.
    """
    return th.atanh(x)
def invert_our_diffeomorphism(sequential):
    """Build the inverse of a Sequential made of invertible Linear and Tanh layers.

    The forward map is e.g. ``W2 @ tanh(W1 @ x + b1) + b2``; the inverse walks
    the layers last-to-first, undoing each one in turn.

    Args:
        sequential: ``nn.Sequential`` containing only ``nn.Linear`` layers
            (with square, invertible weight matrices) and ``nn.Tanh`` layers.

    Returns:
        A function mapping ``y`` back to ``x`` such that
        ``returned_fn(sequential(x)) == x`` (up to float precision).

    Raises:
        ValueError: if the Sequential contains any other layer type.
    """
    def unlinear(y, linear_layer):
        # Invert y = x @ W.t() + b  =>  x = (y - b) @ (W^-1).t()
        bias = linear_layer.bias
        weight = linear_layer.weight
        return (y - bias) @ weight.inverse().t()

    def inverse(y):
        # Undo the layers in reverse order.
        for layer in reversed(sequential):
            if isinstance(layer, th.nn.Linear):
                y = unlinear(y, layer)
            elif isinstance(layer, th.nn.Tanh):
                # atanh is the exact elementwise inverse of tanh.
                y = th.atanh(y)
            else:
                raise ValueError("What kind of diffeomorphism did you make?!")
            # Catch a singular weight / out-of-range tanh input early.
            assert not th.isnan(y).any()
        return y

    return inverse
def main():
    """Sanity-check that invert_our_diffeomorphism really inverts the network."""
    dim = 3
    diffeomorphism = nn.Sequential(
        nn.Linear(dim, dim),
        nn.Tanh(),
        nn.Linear(dim, dim),
    )
    # Re-initialize the last layer with standard-normal draws so inversion is
    # exercised on weights larger than the default uniform initialization.
    nn.init.normal_(diffeomorphism[-1].weight)
    nn.init.normal_(diffeomorphism[-1].bias)
    weight = diffeomorphism[-1].weight
    bias = diffeomorphism[-1].bias

    inverted = invert_our_diffeomorphism(diffeomorphism)
    data = th.rand(dim, dim)  # dim samples of dimension dim

    print(diffeomorphism(data))
    # NOTE(review): this reference computation skips the first Linear layer,
    # so it only matches diffeomorphism(data) if that layer were the identity
    # — it is debug output, not an equality check.
    # (th.tanh replaces nn.functional.tanh, which is deprecated.)
    print(th.addmm(bias, th.tanh(data), weight.t()))

    # Round trip: inverting the forward pass must recover the input.
    assert allclose(data.numpy(),
                    inverted(diffeomorphism(data)).detach().numpy())


if __name__ == '__main__':
    main()