-
Notifications
You must be signed in to change notification settings - Fork 56
/
tfkeras_example.py
62 lines (55 loc) · 2.3 KB
/
tfkeras_example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import tensorflow
from tensorflow import keras
import onnx
import tf2onnx
import onnx_tool
from onnx_tool import create_ndarray_f32
# Scratch ONNX file path (relative to the current working directory). Each
# demo function below overwrites it with its exported model before profiling.
temp_model_file = 'tmp.onnx'
def InceptionV3():
    """Export Keras InceptionV3 (ImageNet weights) to ONNX and profile it.

    The Keras model is built for a single float32 NHWC input of shape
    (1, 299, 299, 3), converted with tf2onnx, written to ``temp_model_file``
    and then profiled with ``onnx_tool.model_profile`` using that input shape.
    """
    inputshape = (1, 299, 299, 3)
    model = tensorflow.keras.applications.InceptionV3(
        include_top=True,
        weights="imagenet",
        input_tensor=None,
        input_shape=inputshape[1:],
        pooling=None,
        classes=1000,
        classifier_activation="softmax",
    )
    # All other from_keras arguments were previously passed explicitly with
    # their default values, so only the model is needed here.
    onnx_model = tf2onnx.convert.from_keras(model)
    # tf2onnx documents the return value as (model_proto, external_storage).
    # Unwrap the proto, but fall back to the object itself rather than leaving
    # ``onnxproto`` unbound if a bare ModelProto is ever returned.
    if isinstance(onnx_model, (list, tuple)):
        onnxproto = onnx_model[0]
    else:
        onnxproto = onnx_model
    onnx.save_model(onnxproto, temp_model_file)
    # Keras auto-generates input layer names from a session-wide counter
    # ('input_1', 'input_2', ...), so a hard-coded name breaks as soon as
    # another model is built first. Use the name actually present in the
    # exported graph, skipping initializers that some IR versions also list
    # as graph inputs.
    graph = onnxproto.graph
    initializer_names = {init.name for init in graph.initializer}
    input_name = next(
        inp.name for inp in graph.input if inp.name not in initializer_names
    )
    dynamics_input = {
        input_name: create_ndarray_f32(inputshape)
    }
    onnx_tool.model_profile(temp_model_file, dynamic_shapes=dynamics_input)
def MobileNetV3Large():
    """Export Keras MobileNetV3Large (ImageNet weights) to ONNX and profile it.

    The Keras model is built for a single float32 NHWC input of shape
    (1, 299, 299, 3), converted with tf2onnx and written to
    ``temp_model_file``. It is then profiled with ``onnx_tool.model_profile``,
    which is also asked to save a shape-annotated copy to 'shapes.onnx'.
    """
    # NOTE(review): Keras' ImageNet weights for MobileNetV3 are documented
    # for 224x224 inputs; confirm 299x299 is intentional here.
    inputshape = (1, 299, 299, 3)
    model = tensorflow.keras.applications.MobileNetV3Large(
        include_top=True,
        weights="imagenet",
        input_tensor=None,
        input_shape=inputshape[1:],
        pooling=None,
        classes=1000,
        classifier_activation="softmax",
    )
    # All other from_keras arguments were previously passed explicitly with
    # their default values, so only the model is needed here.
    onnx_model = tf2onnx.convert.from_keras(model)
    # tf2onnx documents the return value as (model_proto, external_storage).
    # Unwrap the proto, but fall back to the object itself rather than leaving
    # ``onnxproto`` unbound if a bare ModelProto is ever returned.
    if isinstance(onnx_model, (list, tuple)):
        onnxproto = onnx_model[0]
    else:
        onnxproto = onnx_model
    onnx.save_model(onnxproto, temp_model_file)
    # The previously hard-coded 'input_2' only matched because another Keras
    # model had been built earlier in the same session (auto-generated names
    # come from a session-wide counter). Use the name actually present in the
    # exported graph, skipping initializers that some IR versions also list
    # as graph inputs.
    graph = onnxproto.graph
    initializer_names = {init.name for init in graph.initializer}
    input_name = next(
        inp.name for inp in graph.input if inp.name not in initializer_names
    )
    dynamics_input = {
        input_name: create_ndarray_f32(inputshape)
    }
    onnx_tool.model_profile(
        temp_model_file,
        saveshapesmodel='shapes.onnx',
        dynamic_shapes=dynamics_input,
    )
if __name__ == "__main__":
    # Run both profiling demos only when executed as a script, so importing
    # this module does not download weights or write ONNX files.
    InceptionV3()
    MobileNetV3Large()