-
Notifications
You must be signed in to change notification settings - Fork 24
/
Copy pathrun.py
33 lines (28 loc) · 896 Bytes
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
"""Example."""
import numpy as np
from loss_landscape_anim import LeNet, MNISTDataModule, loss_landscape_anim
# Length of the two direction vectors used by the custom-directions example.
# NOTE(review): presumably the flattened parameter count of LeNet -- confirm
# against the model before changing either one.
N_DIRECTION_COMPONENTS = 61706


def random_unit_vector(size, rng=None):
    """Return a random vector with *size* components and unit Euclidean norm.

    Components are drawn from a standard normal distribution and the sample
    is then divided by its own norm.

    Args:
        size: Number of components to draw.
        rng: Optional ``numpy.random.Generator``. When given, the sample is
            drawn from it, which makes the result reproducible. When omitted,
            the global ``np.random`` state is used.

    Returns:
        A ``numpy.ndarray`` holding the normalised sample.
    """
    draw = np.random.normal if rng is None else rng.normal
    raw = draw(size=size)
    return raw / np.linalg.norm(raw)


def run_default_example():
    """Run the animation passing only ``n_epochs=300``.

    Every other option is left to whatever ``loss_landscape_anim`` defaults to.
    """
    loss_landscape_anim(n_epochs=300)


def run_custom_directions_example():
    """Run the animation with LeNet on MNIST along two random directions.

    Not invoked by default; call it from the ``__main__`` guard to try it.
    """
    # Draw u first, then v, so the order of consumption of the global NumPy
    # random state is fixed.
    u = random_unit_vector(N_DIRECTION_COMPONENTS)
    v = random_unit_vector(N_DIRECTION_COMPONENTS)

    bs = 16
    lr = 1e-3
    datamodule = MNISTDataModule(batch_size=bs, n_examples=3000)
    # num_classes = 10 for MNIST, required for Accuracy metric
    model = LeNet(learning_rate=lr, num_classes=10)

    loss_landscape_anim(
        n_epochs=10,
        model=model,
        datamodule=datamodule,
        optimizer="adam",
        reduction_method="custom",
        custom_directions=(u, v),
        giffps=15,
        seed=180224,
        load_model=False,
        output_to_file=True,
        gpus=0,  # Set to # gpus if available
    )


if __name__ == "__main__":
    run_default_example()