-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathartificial_data.py
46 lines (42 loc) · 2.43 KB
/
artificial_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import argparse
from model import GIN
parser = argparse.ArgumentParser(description='Artificial data experiments (2d data embedded in 10d) with GIN.')
parser.add_argument('--n_clusters', type=int, default=5,
help='Number of components in gaussian mixture (default 5)')
parser.add_argument('--n_data_points', type=int, default=10000,
help='Number of data points in artificial data set (default 10000)')
parser.add_argument('--n_epochs', type=int, default=80,
help='Number of training epochs (default 80)')
parser.add_argument('--epochs_per_line', type=int, default=5,
help='Print a new line after this many epochs (default 5)')
parser.add_argument('--lr', type=float, default=1e-2,
help='Learn rate (default 1e-2)')
parser.add_argument('--lr_schedule', nargs='+', type=int, default=[50],
help='Learn rate schedule (decrease lr by factor of 10 at these epochs, default [50]). \
Usage example: --lr_schedule 20 40')
parser.add_argument('--batch_size', type=int, default=1000,
help='Batch size (default 1000)')
parser.add_argument('--incompressible_flow', type=int, default=1,
help='Use an incompressible flow (GIN) (1, default) or compressible flow (GLOW) (0)')
parser.add_argument('--empirical_vars', type=int, default=1,
help='Estimate empirical variables (means and stds) for each batch (1, default) or learn them along \
with model weights (0)')
parser.add_argument('--init_identity', type=int, default=1,
help='Initialize the network as the identity (1, default) or not (0)')
args = parser.parse_args()
assert args.incompressible_flow in [0,1], 'Argument should be 0 or 1'
assert args.empirical_vars in [0,1], 'Argument should be 0 or 1'
assert args.init_identity in [0,1], 'Argument should be 0 or 1'
model = GIN(dataset='10d',
n_classes=args.n_clusters,
n_data_points=args.n_data_points,
n_epochs=args.n_epochs,
epochs_per_line=args.epochs_per_line,
lr=args.lr,
lr_schedule=args.lr_schedule,
batch_size=args.batch_size,
incompressible_flow=args.incompressible_flow,
empirical_vars=args.empirical_vars,
init_identity=args.init_identity,
save_frequency=args.n_epochs)
model.train_model()