# Copyright © 2023-2024 Apple Inc.

import matplotlib.pyplot as plt
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
from flows import RealNVP
from sklearn import datasets, preprocessing
from tqdm import trange
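
# NOTE: RealNVP is imported from the local flows module, presumably a
# flows.py that sits next to this script.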


def get_moons_dataset(n_samples=100_000, noise=0.06):
    """Generate the two moons dataset with the given noise level,
    standardized to zero mean and unit variance."""
    x, _ = datasets.make_moons(n_samples=n_samples, noise=noise)
    scaler = preprocessing.StandardScaler()
    x = scaler.fit_transform(x)
    return x


def main(args):
    x = get_moons_dataset(n_samples=100_000, noise=args.noise)

    model = RealNVP(args.n_transforms, args.d_params, args.d_hidden, args.n_layers)
    mx.eval(model.parameters())
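
    # For a normalizing flow, model(x) evaluates the per-sample log-density,
    # so the loss below is the negative log-likelihood of the batch.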
    def loss_fn(model, x):
        return -mx.mean(model(x))

    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
    optimizer = optim.Adam(learning_rate=args.learning_rate)
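
    # Training loop: draw a random mini-batch, take an Adam step, and call
    # mx.eval to force MLX's lazy computation of the updated parameters.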
    with trange(args.n_steps) as steps:
        for step in steps:
            idx = np.random.choice(x.shape[0], replace=False, size=args.n_batch)
            loss, grads = loss_and_grad_fn(model, mx.array(x[idx]))
            optimizer.update(model, grads)
            mx.eval(model.parameters())
            steps.set_postfix(val=loss)

    # Plot samples from trained flow
    fig, axs = plt.subplots(1, args.n_transforms + 2, figsize=(26, 4))
    cmap = plt.get_cmap("Blues")
    bins = 100
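
    # axs[0] shows the base distribution, axs[1..n_transforms] the samples
    # after each successive transform, and axs[-1] the original data.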
    # Sample from intermediate flow-transformed distributions
    for n_transforms in range(args.n_transforms + 1):
        x_samples = model.sample((100_000, 2), n_transforms=n_transforms)
        axs[n_transforms].hist2d(x_samples[:, 0], x_samples[:, 1], bins=bins, cmap=cmap)
        axs[n_transforms].set_xlim(-2, 2)
        axs[n_transforms].set_ylim(-2, 2)
        axs[n_transforms].set_title(
            f"{n_transforms} transforms" if n_transforms > 0 else "Base distribution"
        )
        axs[n_transforms].set_xticklabels([])
        axs[n_transforms].set_yticklabels([])

    # Plot original data
    axs[-1].hist2d(x[:, 0], x[:, 1], bins=bins, cmap=cmap)
    axs[-1].set_xlim(-2, 2)
    axs[-1].set_ylim(-2, 2)
    axs[-1].set_title("Original data")
    axs[-1].set_xticklabels([])
    axs[-1].set_yticklabels([])

    plt.tight_layout()
    plt.savefig("samples.png")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--n_steps", type=int, default=5_000, help="Number of steps to train"
    )
    parser.add_argument("--n_batch", type=int, default=64, help="Batch size")
    parser.add_argument(
        "--n_transforms", type=int, default=6, help="Number of flow transforms"
    )
    parser.add_argument(
        "--d_params", type=int, default=2, help="Dimensionality of modeled distribution"
    )
    parser.add_argument(
        "--d_hidden",
        type=int,
        default=128,
        help="Hidden dimensionality of coupling conditioner",
    )
    parser.add_argument(
        "--n_layers",
        type=int,
        default=4,
        help="Number of layers in coupling conditioner",
    )
    parser.add_argument(
        "--learning_rate", type=float, default=3e-4, help="Learning rate"
    )
    parser.add_argument(
        "--noise", type=float, default=0.06, help="Noise level in two moons dataset"
    )
    parser.add_argument("--cpu", action="store_true", help="Run on the CPU")
    args = parser.parse_args()

    if args.cpu:
        mx.set_default_device(mx.cpu)

    main(args)
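
# Example invocations (a sketch, assuming flows.py with RealNVP is importable):
#   python main.py                                  # defaults; writes samples.png
#   python main.py --n_transforms 8 --n_steps 10000
#   python main.py --cpu                            # force CPU execution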