Skip to content

Commit

Permalink
Fix onnx generation
Browse files Browse the repository at this point in the history
  • Loading branch information
janjongboom committed Nov 11, 2022
1 parent 7ffd4e4 commit 6180451
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 12 deletions.
22 changes: 11 additions & 11 deletions ei-onnx-tools/onnx_operation_adder.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,8 @@ def add(
if graph_node.name == destop_name:
dest_ops.append(graph_node)

print('src_ops', src_ops)
print('dest_ops', dest_ops)
# print('src_ops', src_ops)
# print('dest_ops', dest_ops)

# Rewrite the input of the connection Gen OP
if single_op_graph_node_inputs:
Expand All @@ -243,8 +243,8 @@ def add(
if srcop_graph_node.name == srcop_name:
found_output = False
for srcop_graph_node_output in srcop_graph_node.outputs:
print('srcop_graph_node_output.name', srcop_graph_node_output.name,
'srcop_output_name', srcop_output_name)
# print('srcop_graph_node_output.name', srcop_graph_node_output.name,
# 'srcop_output_name', srcop_output_name)
if srcop_graph_node_output.name == srcop_output_name:
for idxs, single_op_graph_node_input in enumerate(single_op_graph_node_inputs):
if single_op_graph_node_input.name == addop_input_name:
Expand All @@ -257,9 +257,9 @@ def add(
else:
continue

if not found_output:
print('not found', srcop_output_name, 'in',
[ x.name for x in srcop_graph_node.outputs ])
# if not found_output:
# print('not found', srcop_output_name, 'in',
# [ x.name for x in srcop_graph_node.outputs ])

break
else:
Expand All @@ -275,13 +275,13 @@ def add(
if destop_graph_node.name == destop_name:
found_input = False
for idxd, destop_graph_node_input in enumerate(destop_graph_node.inputs):
print('destop_graph_node_input.name', destop_graph_node_input.name,
'destop_input_name', destop_input_name)
# print('destop_graph_node_input.name', destop_graph_node_input.name,
# 'destop_input_name', destop_input_name)

if destop_graph_node_input.name == destop_input_name:
for single_op_graph_node_output in single_op_graph_node_outputs:
print('single_op_graph_node_output.name', single_op_graph_node_output.name,
'addop_output_name', addop_output_name)
# print('single_op_graph_node_output.name', single_op_graph_node_output.name,
# 'addop_output_name', addop_output_name)
if single_op_graph_node_output.name == addop_output_name:
found_input = True
destop_graph_node.inputs[idxd] = single_op_graph_node_output
Expand Down
10 changes: 9 additions & 1 deletion yolox-repo/tools/export_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,11 +148,15 @@ def main():
if args.ckpt is None:
file_name = os.path.join(exp.output_dir, args.experiment_name)
ckpt_file = os.path.join(file_name, "best_ckpt.pth")
if not os.path.exists(ckpt_file):
ckpt_file = os.path.join(file_name, "latest_ckpt.pth")
elif args.ckpt == "random":
pass
else:
ckpt_file = args.ckpt

print('checkpoint file is', ckpt_file)

if args.ckpt == "random":
#Proceed with initialized values
ckpt = None
Expand All @@ -175,7 +179,11 @@ def main():
args.output = 'detections'

logger.info("loading checkpoint done.")
img = cv2.imread("/scripts/out/train/000000000000.jpg")

dir_path = os.path.dirname(os.path.realpath(__file__))
img_path = os.path.join(dir_path, '..', 'datasets', 'COCO', 'train2017', '000000000000.jpg')

img = cv2.imread(img_path)
img, ratio = preprocess(img, exp.test_size)
img = img[None, ...]
img = img.astype('float32')
Expand Down

0 comments on commit 6180451

Please sign in to comment.