forked from fabiocarrara/meye
-
Notifications
You must be signed in to change notification settings - Fork 0
/
convert_to_tfjs.py
58 lines (46 loc) · 1.86 KB
/
convert_to_tfjs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import argparse
import os
os.sys.path += ['expman']
import json
from glob import glob
import subprocess
from tqdm import tqdm
import expman
def main(args):
    """Convert the selected experiments' SavedModels to TF.js graph models.

    For every experiment gathered from ``args.run`` and selected by
    ``args.filter``, and for every entry in ``variants``, this runs
    ``tensorflowjs_converter`` on the experiment's ``best_savedmodel/``.

    Output location:
        - With ``--output``: ``<args.output>/<exp_name><suffix>``.
        - Otherwise: ``tfjs_graph<suffix>`` inside the experiment directory.

    An existing output is not reconverted unless ``args.force`` is set.

    The names of all available models are emitted as the JS snippet
    ``models = [...]``. Available means already present, or converted
    successfully in this run. The snippet is written to
    ``<args.output>/models.js`` when ``--output`` is given, and printed to
    stdout otherwise. Failures are reported on stderr and the affected model
    is left out of the list.

    Args:
        args: parsed command-line namespace with ``run``, ``filter``,
            ``output`` and ``force`` attributes.
    """
    # Local imports: script entry point, keeps the file's import header as is.
    import shutil
    import sys

    # (name suffix, extra converter flags). Uncomment a line to also produce
    # a quantized variant of every model.
    variants = (
        ('', []),
        # ('_qf16', ['--quantize_float16', '*']),
        # ('_qu16', ['--quantize_uint16', '*']),
        # ('_qu8', ['--quantize_uint8', '*']),
    )

    def warn(msg):
        # tqdm.write does not break the progress bar. stderr keeps stdout
        # clean for the models list printed at the end.
        tqdm.write(f'WARNING: {msg}', file=sys.stderr)

    converted_models = []

    exps = expman.gather(args.run).filter(args.filter)
    for exp_name, exp in tqdm(exps.items()):
        ckpt = exp.path_to('best_savedmodel/')

        for suffix, extra_args in variants:
            name = exp_name + suffix
            if args.output:
                out = os.path.join(args.output, name)
            else:
                out = exp.path_to(f'tfjs_graph{suffix}')

            out_existed = os.path.exists(out)
            if args.force or not out_existed:
                # Check before creating `out`. An empty output dir left
                # behind would make later runs skip (and list) this model.
                if not os.path.exists(ckpt):
                    warn(f'{name}: checkpoint not found ({ckpt}), skipping')
                    continue

                os.makedirs(out, exist_ok=True)
                cmd = ['tensorflowjs_converter',
                       '--input_format', 'tf_saved_model',
                       '--output_format', 'tfjs_graph_model'] + extra_args + [ckpt, out]
                returncode = subprocess.call(cmd)
                if returncode != 0:
                    if not out_existed:
                        # Everything in `out` came from this failed run.
                        # Remove it so the next run retries the conversion.
                        shutil.rmtree(out, ignore_errors=True)
                    warn(f'{name}: tensorflowjs_converter exited with code '
                         f'{returncode}, not listed')
                    continue

            converted_models.append(name)

    js_output = 'models = ' + json.dumps(converted_models)
    if args.output:
        # The output dir may not exist yet, e.g. when nothing was converted.
        os.makedirs(args.output, exist_ok=True)
        js_filename = os.path.join(args.output, 'models.js')
        with open(js_filename, 'w', encoding='utf-8') as f:
            f.write(js_output)
    else:
        print(js_output)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Convert runs to tfjs')
    parser.add_argument('-f', '--filter', default={}, type=expman.exp_filter,
                        help='experiment filter (parsed by expman.exp_filter)')
    # nargs='?' makes the positional optional. Without it argparse ignores
    # `default` and always demands the argument.
    parser.add_argument('run', nargs='?', default='runs/',
                        help='run directory to gather experiments from '
                             '(default: %(default)s)')
    parser.add_argument('--output', help='output dir for models, defaults to run dir')
    parser.add_argument('--force', action='store_true', default=False,
                        help='reconvert models even if their output already exists')
    args = parser.parse_args()
    main(args)