# Skip to content
# sngyo edited this page Nov 4, 2019 · 2 revisions
import tensorflow as tf
import tf2onnx

# TensorFlow session shared by the rest of the script: the model is built
# with it, tensors are looked up through `sess.graph`, and the freeze step
# receives it to read variable values.
sess = tf.Session()

# Variant switch: True selects the large model configuration, False the
# small one. Every BIG_MODEL branch further down keys off this flag.
BIG_MODEL = False

# Width/height pair that goes with the selected variant (180x108 for the
# large one, 60x36 for the small one).
# NOTE(review): presumably the eye-image input size; these two names are not
# read anywhere else in this script.
EYE_WIDTH, EYE_HEIGHT = (180, 108) if BIG_MODEL else (60, 36)

# The two variants differ only in these three constructor arguments;
# everything else passed to ELG is shared.
stride, modules, feature_maps = (3, 3, 64) if BIG_MODEL else (1, 2, 32)

elgmodel = ELG(
    sess,
    train_data={'videostream': data_source},
    first_layer_stride=stride,
    num_modules=modules,
    num_feature_maps=feature_maps,
    learning_schedule=[
        {
            'loss_terms_to_optimize': {'dummy': ['hourglass', 'radius']},
        },
    ],
)

# Set the model up in non-training mode, then load its checkpoints.
# NOTE(review): what exactly these two calls do is defined by ELG, which is
# not visible in this file.
elgmodel.initialize_if_not(training=False)
elgmodel.checkpoint.load_all()

# Look up the input tensor and the three output tensors by name.
# These four handles are not used again in this script; the freeze and
# conversion steps refer to the same nodes by name instead.
tf_graph = sess.graph
eye = tf_graph.get_tensor_by_name('eye:0')
# The heatmap tensor is taken from hourglass module hg_3 for the big model
# and hg_2 for the small one.
heatmaps = tf_graph.get_tensor_by_name(
    'hourglass/hg_3/after/hmap/conv/BiasAdd:0' if BIG_MODEL
    else 'hourglass/hg_2/after/hmap/conv/BiasAdd:0'
)
landmarks = tf_graph.get_tensor_by_name('upscale/mul:0')
radius = tf_graph.get_tensor_by_name('radius/out/fc/BiasAdd:0')

# Fix batch norm nodes.
# Works on a GraphDef copy of the session graph and rewrites two op types:
#   * 'AssignSub' becomes 'Sub', dropping its 'use_locking' attribute.
#   * 'RefSwitch' becomes 'Switch', and every input whose name contains
#     'moving_' is redirected to '<name>/read'.
# NOTE(review): presumably needed so the graph can be frozen and converted
# without reference-typed ops; confirm against the model's batch norm layers.
gd = sess.graph.as_graph_def()
for graph_node in gd.node:
    if graph_node.op == 'AssignSub':
        graph_node.op = 'Sub'
        if 'use_locking' in graph_node.attr:
            del graph_node.attr['use_locking']
        continue

    if graph_node.op != 'RefSwitch':
        continue

    graph_node.op = 'Switch'
    # Iterate over a snapshot so that assigning into the repeated field
    # cannot interfere with the iteration.
    for position, input_name in enumerate(list(graph_node.input)):
        if 'moving_' in input_name:
            graph_node.input[position] = input_name + '/read'

# Freeze the graph: turn the variables that the three output nodes depend on
# into constants, starting from the patched GraphDef `gd`.
# Only the hourglass output node differs between the variants (hg_3 vs hg_2).
output_node_names = [
    "upscale/mul",
    "hourglass/hg_3/after/hmap/conv/BiasAdd" if BIG_MODEL
    else "hourglass/hg_2/after/hmap/conv/BiasAdd",
    "radius/out/fc/BiasAdd",
]
frozen_graph_def = tf.graph_util.convert_variables_to_constants(
    sess, gd, output_node_names
)

# Convert placeholder to constant.
# The node named `target_node_name` in the frozen graph is swapped for a
# scalar boolean constant `False` carrying the same name; every other node
# is copied over unchanged.
target_node_name = "learning_params/Placeholder_1"

from tensorflow.core.framework import graph_pb2
import copy

# Build the replacement node in a scratch graph. If it were created in the
# current default graph and that graph already had an op with this name,
# TensorFlow would uniquify the new op's name (e.g. append "_1"), and the
# nodes that consume the original name would be left without their input.
with tf.Graph().as_default():
    c = tf.constant(False, dtype=bool, shape=[], name=target_node_name)

detected = False
new_graph_def = graph_pb2.GraphDef()
for node in frozen_graph_def.node:
    if node.name == target_node_name:
        detected = True
        new_graph_def.node.extend([c.op.node_def])
    else:
        new_graph_def.node.extend([copy.deepcopy(node)])

# Make a silent no-op visible: without this, a renamed or missing target
# node would go unnoticed until conversion or inference.
if not detected:
    print("WARNING: node '" + target_node_name
          + "' was not found in the frozen graph; nothing was replaced")

frozen_graph_def = new_graph_def

# Convert to onnx.
# The frozen GraphDef is imported into a fresh graph under the "import"
# prefix, which is why every tensor name below starts with "import/".
from tf2onnx.optimizer.transpose_optimizer import TransposeOptimizer

input_names = ["import/eye:0"]

# Heatmap output tensor and output file name depend on the model variant.
if BIG_MODEL:
    heatmap_output = "import/hourglass/hg_3/after/hmap/conv/BiasAdd:0"
    onnx_name = "gazeml_elg_i180x108_n64.onnx"
else:
    heatmap_output = "import/hourglass/hg_2/after/hmap/conv/BiasAdd:0"
    onnx_name = "gazeml_elg_i60x36_n32.onnx"
output_names = [
    "import/upscale/mul:0",
    heatmap_output,
    "import/radius/out/fc/BiasAdd:0",
]

graph1 = tf.Graph()
with graph1.as_default():
    # "import" is the default prefix; it is spelled out because the tensor
    # names above rely on it.
    tf.import_graph_def(frozen_graph_def, name="import")
    onnx_graph = tf2onnx.tfonnx.process_tf_graph(
        graph1,
        input_names=input_names,
        output_names=output_names,
        opset=10,
    )

    optimizer = TransposeOptimizer()
    # The return value is kept under its original name but is not read
    # again; the model below is built from `onnx_graph`.
    opt_model_proto = optimizer.optimize(onnx_graph)

    model_proto = onnx_graph.make_model("gazeml")

# Write the serialized ONNX model next to the script.
with open(onnx_name, "wb") as f:
    f.write(model_proto.SerializeToString())
# Clone this wiki locally