"""
This script is used to train the model.
"""
import tensorflow as tf
from src.models.captioner import ImageCaptioningModel
from src.objects.data_generator import DataGenerator
from src.utils.data_utils import DataUtils
# ====================================== PARAMETERS ====================================== #
# path to the folder containing the images
images_folder_path = 'data/training/images'
# path to the file containing the captions (either a .csv or a .txt file)
# Each row should follow the format: image,caption (include the header row)
captions_path = 'data/training/captions.csv'
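# A minimal example of the expected file contents (hypothetical rows):
#   image,caption
#   dog_on_beach.jpg,a dog running on the beach
#   sleeping_cat.jpg,a cat sleeping on a sofa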
# preprocess function to use.
preprocess_function = tf.keras.applications.xception.preprocess_input
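# Note: this should match the CNN backbone used inside ImageCaptioningModel
# (Xception here, judging by this preprocess function and the note on image
# dimensions below).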
# validation split (portion of the data used for validation; set to 0 to train on all data)
val_split = 0
# batch size
# Make sure the batch size is smaller than the number of samples in both the
# training and validation sets, or the generators will not work properly.
batch_size = 20
# epochs
epochs = 30
# image dimensions
# The Xception model works best with 299x299 images, but you can try other
# sizes as well if you're having memory issues.
# The dimensions should not be below 71x71 (Xception's minimum input size).
image_dimensions = (192, 192)
# embedding dimension (dimension of the Dense layer in the encoder and the Embedding layer in the decoder)
embedding_dim = 128
# number of units in the LSTM, Bahdanau attention and Dense layers
units = 256
if __name__ == '__main__':
    # ====================================== DATA GENERATION ====================================== #
    # load the data
    data_dic = DataUtils.load_training_data(images_folder_path, captions_path,
                                            image_dimensions, preprocess_function)
    # unpack it
    images_dic = data_dic['images_dic']
    captions_dic = data_dic['captions_dic']
    importance_features_dic = data_dic['importance_features_dic']
    tokenizer = data_dic['tokenizer']
    max_caption_length = data_dic['max_caption_length']
    # split the data into training and validation sets if val_split > 0
    if val_split > 0:
        split_dic = DataUtils.train_test_split(
            images_dic, importance_features_dic, captions_dic, val_split)
        # unpack the split data
        train_images_dic = split_dic['train_images_dic']
        train_importance_features_dic = split_dic['train_importance_features_dic']
        train_captions_dic = split_dic['train_captions_dic']
        val_images_dic = split_dic['val_images_dic']
        val_importance_features_dic = split_dic['val_importance_features_dic']
        val_captions_dic = split_dic['val_captions_dic']
        # create the training and validation data generators
        train_generator = DataGenerator(
            images_dic=train_images_dic,
            captions_dic=train_captions_dic,
            importance_features_dic=train_importance_features_dic,
            batch_size=batch_size
        )
        val_generator = DataGenerator(
            images_dic=val_images_dic,
            captions_dic=val_captions_dic,
            importance_features_dic=val_importance_features_dic,
            batch_size=batch_size
        )
    else:
        # create the training data generator
        train_generator = DataGenerator(
            images_dic=images_dic,
            captions_dic=captions_dic,
            importance_features_dic=importance_features_dic,
            batch_size=batch_size
        )
    # free up memory
    del images_dic
    del captions_dic
    del importance_features_dic

    # ====================================== MODEL ====================================== #
    # create the model
    model = ImageCaptioningModel(
        tokenizer=tokenizer,
        max_length=max_caption_length,
        units=units,
        embedding_dim=embedding_dim,
        image_dimensions=image_dimensions
    )
    # create the callback
    callback = tf.keras.callbacks.EarlyStopping(
        monitor='loss', verbose=1, mode='min', patience=5)
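    # Note: monitor='loss' watches the training loss; when val_split > 0 you may
    # prefer monitor='val_loss' so early stopping tracks validation performance.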
    # compile the model
    # Leave run_eagerly=True: training does not work without it (the exact cause is unclear).
    model.compile(optimizer=tf.keras.optimizers.Adam(), run_eagerly=True)

    # ====================================== TRAINING ====================================== #
    # train the model, retrying with a smaller batch size if memory runs out
    while True:
        try:
            if val_split > 0:
                model.fit(
                    x=train_generator,
                    epochs=epochs,
                    validation_data=val_generator,
                    callbacks=[callback]
                )
            else:
                model.fit(
                    x=train_generator,
                    epochs=epochs,
                    callbacks=[callback]
                )
            break
        except tf.errors.ResourceExhaustedError:
            # Ran out of memory: ask for a smaller batch size and retry.
            batch_size = input(
                'Batch size is too large. Please enter a smaller batch size: ')
            batch_size = int(batch_size.strip())
            train_generator.batch_size = batch_size
            if val_split > 0:
                val_generator.batch_size = batch_size

    # ====================================== SAVING ====================================== #
    # save the model
    model.save('saved_models')
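    # To reload the trained model later (a sketch; a subclassed Keras model may
    # require compile=False and/or custom_objects when loading):
    #   restored = tf.keras.models.load_model('saved_models', compile=False)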