-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtf_run_frozen.py
94 lines (83 loc) · 2.78 KB
/
tf_run_frozen.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#!/usr/bin/env python3
import argparse
import os
import sys
import time
from typing import Iterable

import numpy as np
import tensorflow as tf
# Command-line interface: a single --file option naming the frozen GraphDef
# (a serialized protobuf, read in binary mode below) to inspect and benchmark.
parser = argparse.ArgumentParser(
    description='Print the structure of a frozen TensorFlow graph and '
                'benchmark the output of its last operation.')
# required=True: without it, omitting --file leaves args.file as None and
# os.path.exists(None) raises TypeError instead of printing a usage error.
parser.add_argument('--file', type=str, required=True,
                    help='The file name of the frozen graph.')
args = parser.parse_args()
# Fail fast with a readable message. parser.exit() writes the message
# verbatim to stderr, so it must carry its own trailing newline.
if not os.path.exists(args.file):
    parser.exit(1, 'The specified file does not exist: {}\n'.format(args.file))
# Read the serialized GraphDef protobuf from disk. Both names are pre-bound to
# None so the following stages can assert that loading actually succeeded.
# NOTE(review): tf.gfile / tf.GraphDef are TensorFlow 1.x names; on TF 2.x
# they live under tf.io.gfile / tf.compat.v1 — confirm the target TF version.
graph_def = None
graph = None
print('Loading graph definition ...', file=sys.stderr)
try:
    with tf.gfile.GFile(args.file, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
except Exception as e:
    # Exception (not BaseException) so Ctrl-C and SystemExit propagate
    # normally instead of being reported as a load failure.
    parser.exit(2, 'Error loading the graph definition: {}\n'.format(str(e)))
# Import the parsed GraphDef into a fresh tf.Graph. name='' disables the
# default "import/" prefix, so node names stay exactly as stored in the file.
print('Importing graph ...', file=sys.stderr)
try:
    assert graph_def is not None
    with tf.Graph().as_default() as graph:  # type: tf.Graph
        tf.import_graph_def(
            graph_def,
            input_map=None,
            return_elements=None,
            name='',
            op_dict=None,
            producer_op_list=None
        )
except Exception as e:
    # Exception (not BaseException) so Ctrl-C and SystemExit propagate
    # normally instead of being reported as an import failure.
    parser.exit(2, 'Error importing the graph: {}\n'.format(str(e)))
# List every operation in graph order. Along the way, remember:
#   input_nodes - outputs of all Placeholder ops (the tensors that get fed);
#   last_nodes  - outputs of the final op in graph order (evaluated later).
# NOTE(review): "last in graph order" is a heuristic for the model output and
# is not guaranteed to be the real output tensor — confirm per graph.
print()
print('Operations:')
assert graph is not None
ops = graph.get_operations()  # type: Iterable[tf.Operation]
input_nodes = []
last_nodes = []
for operation in ops:
    print('- {0:20s} "{1}" ({2} outputs)'.format(
        operation.type, operation.name, len(operation.outputs)))
    last_nodes = operation.outputs
    if operation.type == 'Placeholder':
        input_nodes.extend(operation.outputs)
# Source operations: those that consume no input tensors at all.
print()
print('Sources (operations without inputs):')
for source_op in (o for o in ops if len(o.inputs) == 0):
    print('- {0}'.format(source_op.name))
# For each operation that has inputs, print its (padded) name followed by the
# comma-separated names of the tensors it consumes.
print()
print('Operation inputs:')
for consumer in (o for o in ops if len(o.inputs) > 0):
    input_names = ', '.join(t.name for t in consumer.inputs)
    print('- {0:20}'.format(consumer.name))
    print(' {0}'.format(input_names))
# Every tensor produced by any operation: its shape, dtype name and full name.
print()
print('Tensors:')
for tensor in (t for o in ops for t in o.outputs):
    print('- {0:20} {1:10} "{2}"'.format(
        str(tensor.shape), tensor.dtype.name, tensor.name))
# Feed every placeholder with a tensor of ones, evaluate the chosen output
# tensor once (printing its shape and value), then time num_steps more runs.
# Validate before opening the session so the failure is a clean message.
if len(last_nodes) != 1:
    parser.exit(2, 'Output tensor should be exactly one, while received number = %d\n'
                % len(last_nodes))
# The output is taken from the last operation in graph order. To target a
# known tensor instead, use e.g. graph.get_tensor_by_name('logits:0').
logits = last_nodes[-1]

# Build one all-ones array per placeholder. Dimensions the graph leaves
# unknown (None — typically a batch axis) are given size 1; a placeholder of
# entirely unknown rank cannot be fed this way, so report it and stop.
feed_dict = {}
for node in input_nodes:
    if node.shape.ndims is None:
        parser.exit(2, 'Cannot build a feed for input {} of unknown rank.\n'
                    .format(node.name))
    shape = [1 if dim is None else dim for dim in node.shape.as_list()]
    # as_numpy_dtype is a property holding the NumPy type itself; calling it
    # would create a scalar instance (and fails outright for e.g. tf.string).
    feed_dict[node] = np.ones(shape, dtype=node.dtype.as_numpy_dtype)

with tf.Session(graph=graph) as sess:
    print('>> Output Shape =', logits.shape)
    # This first run also acts as a warm-up; it happens before timing starts.
    print('>> Output Value =', sess.run(logits, feed_dict=feed_dict))
    print('>> Evaluating Benchmark ...')
    num_steps = 100
    # perf_counter is monotonic and high-resolution, unlike time.time().
    t_start = time.perf_counter()
    for _ in range(num_steps):
        sess.run(logits, feed_dict=feed_dict)
    t_end = time.perf_counter()
    print('>> Average time for each run: %.4f ms;' % ((t_end - t_start) * 1e3 / num_steps))