# train_simple.py (forked from jjcheer/ocrcn_tf2)
"""
training a simple net on Chinese Characters classification dataset
we got about 90% accuracy by simply applying a simple CNN net
"""
from alfred.dl.tf.common import mute_tf
mute_tf()  # silence TensorFlow's verbose logging (alfred helper)
import os
import sys
import numpy as np
import tensorflow as tf
from alfred.utils.log import logger as logging
import tensorflow_datasets as tfds
from dataset.casia_hwdb import load_ds, load_characters, load_val_ds
from models.cnn_net import CNNNet, build_net_002, build_net_003

target_size = 64    # images are resized to target_size x target_size
num_classes = 7356  # placeholder; train() uses the actual charset size
# use_keras_fit = False
use_keras_fit = True
ckpt_path = './checkpoints/cn_ocr-{epoch}.ckpt'
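

# The "simple CNN" mentioned in the module docstring is built by build_net_003
# from models/cnn_net.py. As a rough reference only, a network of that kind
# might look like this hypothetical (and unused) builder; the real
# architecture may differ:
def _example_simple_cnn(input_shape=(64, 64, 1), n_classes=num_classes):
    return tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=input_shape),
        tf.keras.layers.MaxPool2D(),
        tf.keras.layers.Conv2D(64, 3, activation='relu'),
        tf.keras.layers.MaxPool2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(n_classes, activation='softmax'),
    ])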


def preprocess(x):
    """Add a channel dim, resize, and normalize pixel values to [-1, 1]."""
    # original images are 64x64 grayscale
    x['image'] = tf.expand_dims(x['image'], axis=-1)
    x['image'] = tf.image.resize(x['image'], (target_size, target_size))
    x['image'] = (x['image'] - 128.) / 128.
    return x['image'], x['label']
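
# Example of the assumed element structure (a dict produced by
# dataset/casia_hwdb.py) and the shapes preprocess yields:
#   img, lbl = preprocess({'image': tf.zeros((64, 64), tf.uint8),
#                          'label': tf.constant(0, tf.int64)})
#   # img: float32, shape (64, 64, 1), values in [-1, 1]; lbl: the label tensor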


def train():
    all_characters = load_characters()
    num_classes = len(all_characters)
    logging.info('all characters: {}'.format(num_classes))

    train_dataset = load_ds()
    train_dataset = train_dataset.shuffle(100).map(preprocess).batch(32).repeat()
    val_ds = load_val_ds()
    val_ds = val_ds.shuffle(100).map(preprocess).batch(32).repeat()
    # sanity check: peek at the first two batches
    for data in train_dataset.take(2):
        print(data)

    # init model
    model = build_net_003((64, 64, 1), num_classes)
    model.summary()
    logging.info('model loaded.')

    # resume from the latest checkpoint if one exists
    start_epoch = 0
    latest_ckpt = tf.train.latest_checkpoint(os.path.dirname(ckpt_path))
    if latest_ckpt:
        start_epoch = int(latest_ckpt.split('-')[1].split('.')[0])
        model.load_weights(latest_ckpt)
        logging.info('model resumed from: {}, start at epoch: {}'.format(latest_ckpt, start_epoch))
    else:
        logging.info('skipping resume since no weights were found; training from scratch')

    if use_keras_fit:
        model.compile(
            optimizer=tf.keras.optimizers.Adam(),
            loss=tf.keras.losses.SparseCategoricalCrossentropy(),
            metrics=['accuracy'])
        callbacks = [
            # `period` is deprecated in newer TF2 releases in favor of `save_freq`
            tf.keras.callbacks.ModelCheckpoint(ckpt_path,
                                               save_weights_only=True,
                                               verbose=1,
                                               period=500)
        ]
        try:
            model.fit(
                train_dataset,
                validation_data=val_ds,
                validation_steps=1000,
                epochs=15000,
                steps_per_epoch=1024,
                callbacks=callbacks)
        except KeyboardInterrupt:
            model.save_weights(ckpt_path.format(epoch=0))
            logging.info('keras model saved.')
        model.save_weights(ckpt_path.format(epoch=0))
        model.save(os.path.join(os.path.dirname(ckpt_path), 'cn_ocr.h5'))
    else:
        # manual training loop with GradientTape
        loss_fn = tf.losses.SparseCategoricalCrossentropy()
        optimizer = tf.optimizers.RMSprop()
        train_loss = tf.metrics.Mean(name='train_loss')
        train_accuracy = tf.metrics.SparseCategoricalAccuracy(name='train_accuracy')
        for epoch in range(start_epoch, 120):
            try:
                # NOTE: train_dataset is repeat()-ed, so this inner loop only
                # ends via KeyboardInterrupt
                for batch, data in enumerate(train_dataset):
                    images, labels = data
                    with tf.GradientTape() as tape:
                        predictions = model(images)
                        loss = loss_fn(labels, predictions)
                    gradients = tape.gradient(loss, model.trainable_variables)
                    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
                    train_loss(loss)
                    train_accuracy(labels, predictions)
                    if batch % 10 == 0:
                        logging.info('Epoch: {}, iter: {}, loss: {}, train_acc: {}'.format(
                            epoch, batch, train_loss.result(), train_accuracy.result()))
            except KeyboardInterrupt:
                logging.info('interrupted.')
                model.save_weights(ckpt_path.format(epoch=epoch))
                logging.info('model saved into: {}'.format(ckpt_path.format(epoch=epoch)))
                exit(0)
if __name__ == "__main__":
train()
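
# Usage sketch (assumes the CASIA-HWDB data is laid out as expected by
# dataset/casia_hwdb.py):
#   python train_simple.py
# Checkpoints are written to ./checkpoints/; on restart, training resumes
# from the latest checkpoint automatically.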