-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathtrain.py
83 lines (57 loc) · 2.43 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import tensorflow as tf
from funcy import *
from game import *
from pathlib import *
def computational_graph(width=1024, depth=4):
    """Build the cost-regression network as one composed callable.

    The returned function maps a Keras input tensor through, in order:
    a 1x1 convolution to ``width`` channels, ``depth`` residual blocks,
    global average pooling, a 1-unit bias-free dense layer, and a ReLU.

    Args:
        width: Number of convolution filters used throughout the network
            (previously hard-coded as ``W = 1024``; default unchanged).
        depth: Number of stacked residual blocks (previously hard-coded as
            ``H = 4``; default unchanged).

    Returns:
        A callable ``tensor -> tensor`` built with funcy's ``rcompose``
        (left-to-right function composition).
    """
    def add():
        return tf.keras.layers.Add()

    def batch_normalization():
        return tf.keras.layers.BatchNormalization()

    def conv(filter_size, kernel_size=3):
        # 'same' padding keeps the spatial size, so the residual branch can be
        # added element-wise to the block input.
        return tf.keras.layers.Conv2D(filter_size, kernel_size, padding='same', use_bias=False, kernel_initializer='he_normal')

    def dense(unit_size):
        return tf.keras.layers.Dense(unit_size, use_bias=False, kernel_initializer='he_normal')

    def global_average_pooling():
        return tf.keras.layers.GlobalAveragePooling2D()

    def relu():
        return tf.keras.layers.ReLU()

    ####

    def residual_block(block_width):
        # ljuxt applies both functions to the same input and returns
        # [residual_branch(x), x]; Add() then sums them: F(x) + x.
        return rcompose(ljuxt(rcompose(batch_normalization(),
                                       conv(block_width),
                                       batch_normalization(),
                                       relu(),
                                       conv(block_width),
                                       batch_normalization()),
                              identity),
                        add())

    return rcompose(conv(width, 1),
                    # repeatedly() calls the factory `depth` times, so every
                    # block gets its own freshly created layers (no sharing).
                    rcompose(*repeatedly(partial(residual_block, width), depth)),
                    global_average_pooling(),
                    dense(1),
                    # ReLU keeps the predicted cost from going negative.
                    # NOTE(review): with a bias-free Dense(1) before it, a unit
                    # stuck at negative pre-activations gets zero gradient
                    # ("dying ReLU"); consider softplus if training stalls.
                    relu())
def main():
    """Train (or resume training) the cost model and save it to ./model/cost.h5.

    If ``./model/cost.h5`` exists it is loaded and training continues;
    otherwise a fresh model is built from ``computational_graph()``. The model
    is fitted on an endless stream of generated (features, step) batches and
    then saved back to the same path.
    """
    # Import explicitly rather than relying on `from game import *` to
    # re-export these names.
    import numpy as np
    from random import randrange

    def create_model():
        # juxt(identity, graph)(input) -> [input, output], which is unpacked
        # into tf.keras.Model(inputs, outputs).
        result = tf.keras.Model(*juxt(identity, computational_graph())(tf.keras.Input(shape=(3, 3, 6 * 6))))
        result.compile(optimizer='adam', loss='mean_squared_error', metrics=['mean_absolute_error'])
        result.summary()
        return result

    def create_generator(batch_size):
        # Endless generator of (xs, ys) batches. Each label is a step count in
        # [1, 31]; the features come from get_x() applied to the first element
        # returned by get_random_state(step) (presumably a state reached after
        # `step` random moves -- confirm in game.py).
        while True:
            xs = []
            ys = []
            for _ in range(batch_size):
                step = randrange(1, 32)
                xs.append(get_x(get_random_state(step)[0]))
                ys.append(step)
            yield np.array(xs), np.array(ys)

    model_path = Path('./model/cost.h5')
    # Pass str paths: older tf.keras load_model/save_model dispatch on string
    # paths and may not recognise pathlib.Path objects.
    model = tf.keras.models.load_model(str(model_path)) if model_path.exists() else create_model()
    # Model.fit accepts Python generators; fit_generator is deprecated since
    # TF 2.1 and removed in Keras 3.
    model.fit(create_generator(1000), steps_per_epoch=1000, epochs=100)
    model_path.parent.mkdir(parents=True, exist_ok=True)
    # Save to the same path we load from instead of a duplicated literal.
    tf.keras.models.save_model(model, str(model_path))
    tf.keras.backend.clear_session()
if __name__ == '__main__':
    # Run training only when this file is executed as a script, not on import.
    main()