-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy path train.py
101 lines (79 loc) · 3.14 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# -*- coding: utf-8 -*-
"""
Created on Mon Nov 12 09:15:27 2018
@author: shen1994
"""
import os
import cv2
import numpy as np
import tensorflow as tf
from model import PRNet
from generate import Generator
from predictor import Predictor
from show import show_G
def main():
    """Train PRNet on the 300W-LP ``*_GEN`` sets and checkpoint it periodically.

    Every ``report_every`` steps the mean training loss since the previous
    report is printed, and predictions for ``num_show`` images drawn from the
    AFLW2000 validation generator are displayed with ``show_G``.  Every
    ``save_every`` steps (excluding step 0) a checkpoint is written under
    ``model_path``.  If ``model_path`` already holds a checkpoint it is
    restored first; note the step counter still starts at 0, so later
    checkpoints are numbered from scratch.

    A ``q`` key press reported by ``cv2.waitKey`` stops training early
    (presumably this needs an OpenCV window, e.g. one opened by ``show_G`` --
    unverified here).
    """
    # Respect a GPU selection already made in the caller's environment; fall
    # back to device 2 otherwise.  Set before the session is created.
    os.environ.setdefault('CUDA_VISIBLE_DEVICES', '2')

    model_path = 'model'
    os.makedirs(model_path, exist_ok=True)

    # define params
    batch_size = 32
    num_steps = 100001   # training iterations (batches), not data epochs
    num_show = 16        # validation images visualised per report
    report_every = 100   # steps between loss reports / visualisations
    save_every = 1000    # steps between checkpoints
    image_shape = (256, 256, 3)

    # define model; x, y, m are the tensors fed below and pos is handed to the
    # predictor (presumably the network output -- see PRNet to confirm)
    x, m, y, pos, loss, optimizer = PRNet()

    # define predictor
    v_predictor = Predictor()

    # define generators
    train_paths = ['images/300W_LP/AFW_GEN', 'images/300W_LP/HELEN_GEN',
                   'images/300W_LP/IBUG_GEN', 'images/300W_LP/LFPW_GEN',
                   'images/300W_LP/AFW_Flip_GEN', 'images/300W_LP/HELEN_Flip_GEN',
                   'images/300W_LP/IBUG_Flip_GEN', 'images/300W_LP/LFPW_Flip_GEN']
    gen_kwargs = dict(train_paths=train_paths,
                      valid_path='images/AFLW2000_GEN',
                      mask_path='images/uv_weight_mask2.png',
                      image_shape=image_shape,
                      batch_size=batch_size)
    # Separate instances, in case Generator keeps per-instance iteration state.
    t_generator = Generator(**gen_kwargs).generate(is_training=True)
    v_generator = Generator(**gen_kwargs).generate(is_training=False)

    saver = tf.train.Saver()
    with tf.Session() as sess:
        # initial variables
        sess.run(tf.local_variables_initializer())
        sess.run(tf.global_variables_initializer())

        # latest_checkpoint returns None when no checkpoint exists; only that
        # case means "start from scratch".  A checkpoint that exists but fails
        # to restore is a real error and is allowed to propagate instead of
        # silently training from freshly initialised weights.
        ckpt = tf.train.latest_checkpoint(model_path)
        if ckpt is None:
            print('No existed model to use!')
        else:
            saver.restore(sess, ckpt)

        loss_sum = 0.0
        loss_count = 0
        for step in range(num_steps):
            x_in, y_in, m_in = next(t_generator)
            feed = {x: x_in, y: y_in, m: m_in}
            # Fetch the train op and the loss in one run so the reported loss
            # comes from the same forward pass as the update (no second pass).
            _, batch_loss = sess.run([optimizer, loss], feed_dict=feed)
            loss_sum += batch_loss
            loss_count += 1

            if step % report_every == 0:
                # mean loss over the steps since the previous report
                print('%d: train --->cost:%s' % (step, loss_sum / loss_count))
                print("---------------------------------------->")
                loss_sum = 0.0
                loss_count = 0

                # show keypoints for the first num_show validation images;
                # indexing raises if a validation batch is smaller than that
                x_ou = next(v_generator)
                x_pre_ou = [v_predictor.predictor(sess, x, pos, x_ou[i])
                            for i in range(num_show)]
                show_G(x_ou[:num_show], np.array(x_pre_ou), num_show, "3DFace")

            # save model
            if step % save_every == 0 and step != 0:
                saver.save(sess, os.path.join(model_path, 'model%d.ckpt' % step))

            # stop training early on 'q'
            if cv2.waitKey(1) == ord('q'):
                break


if __name__ == "__main__":
    main()