-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain_cms_train_vae_template.py
109 lines (86 loc) · 4.84 KB
/
main_cms_train_vae_template.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
# --- imports and global runtime configuration -------------------------------
import os
import setGPU  # side-effect import: selects a free GPU before TF is loaded (project dependency)
import numpy as np
from collections import namedtuple
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # silence TF C++ INFO/WARNING logs; must be set before importing tensorflow
import tensorflow as tf
print('tensorflow version: ', tf.__version__)
# NOTE(review): running tf.functions eagerly is a debugging aid and can slow
# training considerably — confirm this is intentional for full training runs.
tf.config.run_functions_eagerly(True)
#tf.debugging.enable_check_numerics()
from tensorflow.python import debug as tf_debug
import vae.vae_particle as vap
import vae.losses as losses
import case_paths.path_constants.sample_dict_file_parts_input as sdi
import case_paths.util.experiment as expe
import case_paths.util.sample_factory as safa
import util.data_generator as dage
import case_readers.data_reader as dare
import case_paths.phase_space.cut_constants as cuts
import training as tra
import random
import sys
# first CLI argument serves as both the RNG seed and the experiment run number
seed = int(sys.argv[1])
def set_seeds(seed):
    """Seed every relevant RNG source (Python hashing, `random`, NumPy, TensorFlow)
    so that repeated runs with the same seed are reproducible."""
    os.environ['PYTHONHASHSEED'] = str(seed)
    # seed each library's generator in turn with the same value
    for seeder in (random.seed, tf.random.set_seed, np.random.seed):
        seeder(seed)

set_seeds(seed)
# ********************************************************
# runtime params
# ********************************************************
# All hyper-parameters for this run collected in one immutable record.
Parameters = namedtuple('Parameters', 'run_n input_shape kernel_sz kernel_ini_n beta epochs train_total_n gen_part_n valid_total_n batch_n z_sz activation initializer learning_rate max_lr_decay lambda_reg')
params = Parameters(run_n=seed,              # run number doubles as the CLI-supplied seed
                    input_shape=(100,3),     # 100 constituents x 3 features per jet — TODO confirm feature ordering
                    kernel_sz=(1,3),         # conv kernel spans one constituent, all 3 features
                    kernel_ini_n=12,         # initial number of conv kernels
                    beta=0.5,                # passed as `beta` to VAEparticle and Trainer below
                    epochs=20,
                    train_total_n=int(2e6),  # total training samples drawn from the generator
                    valid_total_n=int(5e5),  # samples read for the validation tensor
                    gen_part_n=int(5e5),     # samples read per generator chunk
                    batch_n=128,
                    z_sz=12,                 # latent-space dimensionality
                    activation='elu',
                    initializer='he_uniform',
                    learning_rate=0.001,
                    max_lr_decay=8,          # max learning-rate decay steps allowed by the Trainer
                    lambda_reg=0.1) # 'L1L2'
# create per-run output directories (model checkpoints + figures)
experiment = expe.Experiment(params.run_n).setup(model_dir=True, fig_dir=True)
# resolves logical sample names (e.g. 'qcdSideTrain') to directory paths
paths = safa.SamplePathDirFactory(sdi.path_dict)
# ********************************************************
# prepare training (generator) and validation data
# ********************************************************
# train: streamed through a generator in chunks of gen_part_n so the full
# train_total_n sample set never has to sit in memory at once
print('>>> Preparing training dataset generator')
data_train_generator = dage.CaseDataGenerator(path=paths.sample_dir_path('qcdSideTrain'), sample_part_n=params.gen_part_n, sample_max_n=params.train_total_n, **cuts.global_cuts)
train_ds = tf.data.Dataset.from_generator(data_train_generator, output_types=tf.float32, output_shapes=params.input_shape).batch(params.batch_n, drop_remainder=True)  # already shuffled
# validation: small enough to read fully into a single tensor
print('>>> Preparing validation dataset from', paths.sample_dir_path('qcdSideTest'))
const_valid, _, features_valid, _, truth_valid = dare.CaseDataReader(path=paths.sample_dir_path('qcdSideTest')).read_events_from_dir(max_n=params.valid_total_n, **cuts.global_cuts)  # truth_valid unused below
data_valid = dage.events_to_input_samples(const_valid, features_valid)
valid_ds = tf.data.Dataset.from_tensor_slices(data_valid).batch(params.batch_n, drop_remainder=True)
# per-feature stats of the training sample, used by the model's normalization layer
mean_stdev = data_train_generator.get_mean_and_stdev()
print('>>> Normalization stats (mean, stdev):', mean_stdev)
# *******************************************************
# training options
# *******************************************************
print('>>> Preparing optimizer')
optimizer = tf.keras.optimizers.Adam(learning_rate=params.learning_rate)
# reconstruction loss from the project's loss library; the KL term is presumably
# handled via `beta` inside the model/trainer — TODO confirm in vae.losses
loss_fn = losses.threeD_loss
# *******************************************************
# build model
# *******************************************************
print('>>> Building model')
vae = vap.VAEparticle(input_shape=params.input_shape, z_sz=params.z_sz, kernel_ini_n=params.kernel_ini_n, kernel_sz=params.kernel_sz, activation=params.activation, initializer=params.initializer, beta=params.beta)
# bake the training-set mean/stdev into the model's input normalization
vae.build(mean_stdev)
# ******************************************************* # train and save # *******************************************************
print('>>> Launching Training')
trainer = tra.Trainer(optimizer=optimizer, beta=params.beta, patience=3, min_delta=0.03, max_lr_decay=params.max_lr_decay, lambda_reg=params.lambda_reg, annealing=False, datalength=params.train_total_n, batchsize=params.batch_n)
# returns the per-epoch training and validation loss histories for plotting
losses_reco, losses_valid = trainer.train(vae=vae, loss_fn=loss_fn, train_ds=train_ds, valid_ds=valid_ds, epochs=params.epochs, model_dir=experiment.model_dir)
tra.plot_training_results(losses_reco, losses_valid, experiment.fig_dir)
vae.save(path=experiment.model_dir)  # persist the trained model to this run's model dir