-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathexport_model.py
152 lines (120 loc) · 5.15 KB
/
export_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import os
import tensorflow as tf
from tensorflow.python.tools import freeze_graph
from deeplab import common
from deeplab import input_preprocess
from deeplab import model
# Contrib alias kept at module level; it is not referenced anywhere else in
# this file.
slim = tf.contrib.slim
# Command-line flag registry.
# NOTE(review): FLAGS.min_resize_value, FLAGS.max_resize_value,
# FLAGS.resize_factor, FLAGS.model_variant and FLAGS.image_pyramid are read
# below but are not defined in this file -- presumably they are registered
# when `deeplab.common` is imported; verify before relying on them.
flags = tf.app.flags
FLAGS = flags.FLAGS
# Checkpoint whose variable values are folded into the frozen graph.
# Required when run as a script.
flags.DEFINE_string('checkpoint_path', None, 'Checkpoint path')
# Destination file of the frozen GraphDef. main() calls MakeDirs on its
# parent directory and, optionally, writes inference_graph.pbtxt there.
# Required when run as a script.
flags.DEFINE_string('export_path', None,
                    'Path to output Tensorflow frozen graph.')
# Passed as outputs_to_num_classes[common.OUTPUT_TYPE] to common.ModelOptions.
flags.DEFINE_integer('num_classes', 21, 'Number of classes.')
# Used both as the preprocessing crop (index 0 = height, index 1 = width) and
# as ModelOptions.crop_size.
flags.DEFINE_multi_integer('crop_size', [513, 513],
                           'Crop size [height, width].')
# Forwarded unchanged to common.ModelOptions.
flags.DEFINE_multi_integer('atrous_rates', None,
                           'Atrous rates for atrous spatial pyramid pooling.')
# Forwarded unchanged to common.ModelOptions.
flags.DEFINE_integer('output_stride', 8,
                     'The ratio of input to output spatial resolution.')
# Exactly [1.0] selects single-scale model.predict_labels; any other list
# selects model.predict_labels_multi_scale with these scales.
flags.DEFINE_multi_float('inference_scales', [1.0],
                         'The scales to resize images for inference.')
# Only consulted on the multi-scale inference path.
flags.DEFINE_bool('add_flipped_images', False,
                  'Add flipped images during inference or not.')
# In this exporter the value only acts as a switch: when >= 0 the graph is
# rewritten with tf.contrib.quantize.create_eval_graph(), and combining it
# with multi-scale inference raises ValueError.
flags.DEFINE_integer(
    'quantize_delay_step', -1,
    'Steps to start quantized training. If < 0, will not quantize model.')
# When set, main() additionally writes the GraphDef captured before freezing
# as inference_graph.pbtxt in the export directory.
flags.DEFINE_bool('save_inference_graph', False,
                  'Save inference graph in text proto.')
# Node names that form the exported graph's interface.
# Name of the uint8 input placeholder.
_INPUT_NAME = 'ImageTensor'
# Name of the final int32 prediction tensor; also the only output node name
# passed to freeze_graph.
_OUTPUT_NAME = 'SemanticPredictions'
# Name of the float32 model predictions before cropping and resizing.
_RAW_OUTPUT_NAME = 'RawSemanticPredictions'
def _create_input_tensors():
  """Creates the exported model's input placeholder and preprocessing ops.

  A single image is fed through the `_INPUT_NAME` placeholder (uint8, shape
  `[1, None, None, 3]`). The size-1 batch axis is removed so that
  `input_preprocess.preprocess_image_and_label`, which works on a 3-D image,
  can process it with the crop/resize flags in inference mode. The batch axis
  is then added back for the model.

  Returns:
    A 3-tuple of tensors:
      image: the preprocessed image with a leading batch axis of size 1.
      original_image_size: int32 `[height, width]` of the fed image.
      resized_image_size: first two dimensions of the first value returned by
        preprocess_image_and_label -- presumably the image after resizing but
        before padding/cropping to crop_size; verify against input_preprocess.
  """
  fed_batch = tf.placeholder(
      tf.uint8, [1, None, None, 3], name=_INPUT_NAME)
  original_image_size = tf.shape(fed_batch)[1:3]

  # preprocess_image_and_label expects a single 3-D image, so drop the
  # size-1 batch axis first.
  single_image = tf.squeeze(fed_batch, axis=0)
  resized_image, processed_image, _ = (
      input_preprocess.preprocess_image_and_label(
          single_image,
          label=None,
          crop_height=FLAGS.crop_size[0],
          crop_width=FLAGS.crop_size[1],
          min_resize_value=FLAGS.min_resize_value,
          max_resize_value=FLAGS.max_resize_value,
          resize_factor=FLAGS.resize_factor,
          is_training=False,
          model_variant=FLAGS.model_variant))
  resized_image_size = tf.shape(resized_image)[:2]

  # The model consumes a 4-D batch; restore the batch axis.
  batched_image = tf.expand_dims(processed_image, 0)
  return batched_image, original_image_size, resized_image_size
def _resize_label(label, label_size):
  """Resizes predicted label maps with nearest-neighbour sampling.

  Args:
    label: rank-3 tensor `[batch, height, width]` of predictions.
    label_size: `[new_height, new_width]` (tensor or list).

  Returns:
    int32 tensor `[batch, new_height, new_width]`.
  """
  # resize_images needs a channel axis: [b, h, w] -> [b, h, w, 1].
  label = tf.expand_dims(label, 3)
  # Nearest neighbour never blends neighbouring values, so predicted ids are
  # copied rather than interpolated.
  resized_label = tf.image.resize_images(
      label,
      label_size,
      method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
      align_corners=True)
  return tf.cast(tf.squeeze(resized_label, 3), tf.int32)


def main(unused_argv):
  """Builds the DeepLab inference graph and writes it as a frozen GraphDef.

  The graph takes a single uint8 image from the `_INPUT_NAME` placeholder and
  runs single-scale inference when `FLAGS.inference_scales` is exactly [1.0],
  otherwise multi-scale inference. It exposes two named tensors:

    * `_RAW_OUTPUT_NAME`: the model's `common.OUTPUT_TYPE` predictions cast
      to float32.
    * `_OUTPUT_NAME`: those predictions cropped to the resized-image region
      and resized back to the fed image's size (nearest neighbour, int32).

  If `FLAGS.quantize_delay_step >= 0` the graph is rewritten for quantized
  evaluation. Variables are restored from `FLAGS.checkpoint_path`, frozen
  into the graph, and the result is written to `FLAGS.export_path`. With
  `FLAGS.save_inference_graph` the pre-freeze GraphDef is also written as
  `inference_graph.pbtxt` in the same directory.

  Args:
    unused_argv: Arguments left over after flag parsing; ignored.

  Raises:
    ValueError: If quantization is requested together with multi-scale
      inference.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.logging.info('Prepare to export model to: %s', FLAGS.export_path)

  single_scale = tuple(FLAGS.inference_scales) == (1.0,)
  # Reject incompatible flag combinations before any graph construction.
  if not single_scale and FLAGS.quantize_delay_step >= 0:
    raise ValueError(
        'Quantize mode is not supported with multi-scale test.')

  with tf.Graph().as_default():
    image, image_size, resized_image_size = _create_input_tensors()

    model_options = common.ModelOptions(
        outputs_to_num_classes={common.OUTPUT_TYPE: FLAGS.num_classes},
        crop_size=FLAGS.crop_size,
        atrous_rates=FLAGS.atrous_rates,
        output_stride=FLAGS.output_stride)

    if single_scale:
      tf.logging.info('Exported model performs single-scale inference.')
      predictions = model.predict_labels(
          image,
          model_options=model_options,
          image_pyramid=FLAGS.image_pyramid)
    else:
      tf.logging.info('Exported model performs multi-scale inference.')
      predictions = model.predict_labels_multi_scale(
          image,
          model_options=model_options,
          eval_scales=FLAGS.inference_scales,
          add_flipped_images=FLAGS.add_flipped_images)

    raw_predictions = tf.identity(
        tf.cast(predictions[common.OUTPUT_TYPE], tf.float32),
        _RAW_OUTPUT_NAME)

    # Keep only the region covered by the resized image.
    # NOTE(review): presumably the model output is at crop_size resolution
    # with padding beyond the resized image -- confirm in input_preprocess.
    semantic_predictions = tf.slice(
        raw_predictions,
        [0, 0, 0],
        [1, resized_image_size[0], resized_image_size[1]])
    # Resize back to the size of the image that was fed in.
    semantic_predictions = _resize_label(semantic_predictions, image_size)
    semantic_predictions = tf.identity(semantic_predictions, name=_OUTPUT_NAME)

    if FLAGS.quantize_delay_step >= 0:
      tf.contrib.quantize.create_eval_graph()

    # tf.all_variables() is a deprecated alias of tf.global_variables().
    saver = tf.train.Saver(tf.global_variables())

    # os.path.dirname() returns '' for a bare file name; fall back to the
    # current directory so MakeDirs and write_graph receive a valid path.
    dirname = os.path.dirname(FLAGS.export_path) or '.'
    tf.gfile.MakeDirs(dirname)

    graph_def = tf.get_default_graph().as_graph_def(add_shapes=True)
    freeze_graph.freeze_graph_with_def_protos(
        graph_def,
        saver.as_saver_def(),
        FLAGS.checkpoint_path,
        _OUTPUT_NAME,
        restore_op_name=None,
        filename_tensor_name=None,
        output_graph=FLAGS.export_path,
        # Strip device placements so the frozen graph is not tied to the
        # devices of the exporting machine.
        clear_devices=True,
        initializer_nodes=None)

    if FLAGS.save_inference_graph:
      tf.train.write_graph(graph_def, dirname, 'inference_graph.pbtxt')
if __name__ == '__main__':
  # Both paths default to None, so refuse to run unless they are supplied.
  for _required_flag in ('checkpoint_path', 'export_path'):
    flags.mark_flag_as_required(_required_flag)
  # tf.app.run parses the command-line flags and then invokes main().
  tf.app.run()