Skip to content

Commit

Permalink
Merge pull request #732 from PINTO0309/feat/nms_v5
Browse files Browse the repository at this point in the history
Enable selection of V4 and V5 for `NonMaxSuppression`
  • Loading branch information
PINTO0309 authored Jan 21, 2025
2 parents f16ea9d + 0026c8d commit 372a8da
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 30 deletions.
18 changes: 16 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -299,15 +299,15 @@ Video speed is adjusted approximately 50 times slower than actual speed.
docker run --rm -it \
-v `pwd`:/workdir \
-w /workdir \
ghcr.io/pinto0309/onnx2tf:1.26.4
ghcr.io/pinto0309/onnx2tf:1.26.5

or

# Authentication is not required for pulls from Docker Hub.
docker run --rm -it \
-v `pwd`:/workdir \
-w /workdir \
docker.io/pinto0309/onnx2tf:1.26.4
docker.io/pinto0309/onnx2tf:1.26.5

or

Expand Down Expand Up @@ -1526,6 +1526,7 @@ usage: onnx2tf
[-ois OVERWRITE_INPUT_SHAPE [OVERWRITE_INPUT_SHAPE ...]]
[-nlt]
[-onwdt]
[-snms {v4,v5}]
[-k KEEP_NCW_OR_NCHW_OR_NCDHW_INPUT_NAMES [KEEP_NCW_OR_NCHW_OR_NCDHW_INPUT_NAMES ...]]
[-kt KEEP_NWC_OR_NHWC_OR_NDHWC_INPUT_NAMES [KEEP_NWC_OR_NHWC_OR_NDHWC_INPUT_NAMES ...]]
[-kat KEEP_SHAPE_ABSOLUTELY_INPUT_NAMES [KEEP_SHAPE_ABSOLUTELY_INPUT_NAMES ...]]
Expand Down Expand Up @@ -1725,6 +1726,12 @@ optional arguments:
enable --output_nms_with_dynamic_tensor:
output_tensor_shape: [N, 7]
-snms {v4,v5}, --switch_nms_version {v4,v5}
Switch the NMS version to V4 or V5 to convert.
e.g.
NonMaxSuppressionV4(default): --switch_nms_version v4
NonMaxSuppressionV5: --switch_nms_version v5
-k KEEP_NCW_OR_NCHW_OR_NCDHW_INPUT_NAMES [KEEP_NCW_OR_NCHW_OR_NCDHW_INPUT_NAMES ...], \
--keep_ncw_or_nchw_or_ncdhw_input_names KEEP_NCW_OR_NCHW_OR_NCDHW_INPUT_NAMES \
[KEEP_NCW_OR_NCHW_OR_NCDHW_INPUT_NAMES ...]
Expand Down Expand Up @@ -2010,6 +2017,7 @@ convert(
overwrite_input_shape: Union[List[str], NoneType] = None,
no_large_tensor: Optional[bool] = False,
output_nms_with_dynamic_tensor: Optional[bool] = False,
switch_nms_version: Optional[str] = 'v4',
keep_ncw_or_nchw_or_ncdhw_input_names: Union[List[str], NoneType] = None,
keep_nwc_or_nhwc_or_ndhwc_input_names: Union[List[str], NoneType] = None,
keep_shape_absolutely_input_names: Optional[List[str]] = None,
Expand Down Expand Up @@ -2215,6 +2223,12 @@ convert(
enable --output_nms_with_dynamic_tensor:
output_tensor_shape: [N, 7]

switch_nms_version {v4,v5}
Switch the NMS version to V4 or V5 to convert.
e.g.
NonMaxSuppressionV4(default): switch_nms_version="v4"
NonMaxSuppressionV5: switch_nms_version="v5"

keep_ncw_or_nchw_or_ncdhw_input_names: Optional[List[str]]
Holds the NCW or NCHW or NCDHW of the input shape for the specified INPUT OP names.
If a nonexistent INPUT OP name is specified, it is ignored.
Expand Down
2 changes: 1 addition & 1 deletion onnx2tf/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from onnx2tf.onnx2tf import convert, main

__version__ = '1.26.4'
__version__ = '1.26.5'
21 changes: 21 additions & 0 deletions onnx2tf/onnx2tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def convert(
overwrite_input_shape: Optional[List[str]] = None,
no_large_tensor: Optional[bool] = False,
output_nms_with_dynamic_tensor: Optional[bool] = False,
switch_nms_version: Optional[str] = 'v4',
keep_ncw_or_nchw_or_ncdhw_input_names: Optional[List[str]] = None,
keep_nwc_or_nhwc_or_ndhwc_input_names: Optional[List[str]] = None,
keep_shape_absolutely_input_names: Optional[List[str]] = None,
Expand Down Expand Up @@ -270,6 +271,12 @@ def convert(
enable --output_nms_with_dynamic_tensor:\n
output_tensor_shape: [N, 7]
switch_nms_version: Optional[str]
Switch the NMS version to V4 or V5 to convert.\n\n
e.g.\n
NonMaxSuppressionV4(default): --switch_nms_version v4\n
NonMaxSuppressionV5: --switch_nms_version v5
keep_ncw_or_nchw_or_ncdhw_input_names: Optional[List[str]]
Holds the NCW or NCHW or NCDHW of the input shape for the specified INPUT OP names.\n
If a nonexistent INPUT OP name is specified, it is ignored.\n
Expand Down Expand Up @@ -921,6 +928,7 @@ def sanitizing(node):
'mvn_epsilon': mvn_epsilon,
'output_signaturedefs': output_signaturedefs,
'output_nms_with_dynamic_tensor': output_nms_with_dynamic_tensor,
'switch_nms_version': switch_nms_version,
'output_integer_quantized_tflite': output_integer_quantized_tflite,
'gelu_replace_op_names': {},
'space_to_depth_replace_op_names': {},
Expand Down Expand Up @@ -2233,6 +2241,18 @@ def main():
'enable --output_nms_with_dynamic_tensor: \n' +
' output_tensor_shape: [N, 7]'
)
parser.add_argument(
'-snms',
'--switch_nms_version',
type=str,
choices=['v4', 'v5'],
default='v4',
help=\
'Switch the NMS version to V4 or V5 to convert. \n' +
'e.g. \n' +
'NonMaxSuppressionV4(default): --switch_nms_version v4 \n' +
'NonMaxSuppressionV5: --switch_nms_version v5'
)
parser.add_argument(
'-k',
'--keep_ncw_or_nchw_or_ncdhw_input_names',
Expand Down Expand Up @@ -2623,6 +2643,7 @@ def main():
overwrite_input_shape=args.overwrite_input_shape,
no_large_tensor=args.no_large_tensor,
output_nms_with_dynamic_tensor=args.output_nms_with_dynamic_tensor,
switch_nms_version=args.switch_nms_version,
keep_ncw_or_nchw_or_ncdhw_input_names=args.keep_ncw_or_nchw_or_ncdhw_input_names,
keep_nwc_or_nhwc_or_ndhwc_input_names=args.keep_nwc_or_nhwc_or_ndhwc_input_names,
keep_shape_absolutely_input_names=args.keep_shape_absolutely_input_names,
Expand Down
82 changes: 55 additions & 27 deletions onnx2tf/ops/NonMaxSuppression.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@


class NMSLayer(tf_keras.layers.Layer):
def __init__(self):
def __init__(self, switch_nms_version='v4'):
super(NMSLayer, self).__init__()
self.switch_nms_version = switch_nms_version

@dispatch.add_dispatch_support
def non_max_suppression(
Expand All @@ -40,29 +41,56 @@ def non_max_suppression(
name=None,
):
with ops.name_scope(name, 'non_max_suppression'):
selected_indices, num_valid = gen_image_ops.non_max_suppression_v4(
boxes=boxes,
scores=scores,
max_output_size=max_output_size \
if not isinstance(max_output_size, np.ndarray) \
else tf.convert_to_tensor(
value=max_output_size,
name='max_output_size'
),
iou_threshold=iou_threshold \
if not isinstance(iou_threshold, np.ndarray) \
else tf.convert_to_tensor(
value=iou_threshold,
name='iou_threshold',
),
score_threshold=score_threshold \
if not isinstance(score_threshold, np.ndarray) \
else tf.convert_to_tensor(
value=score_threshold,
name='score_threshold',
),
pad_to_max_output_size=pad_to_max_output_size,
)
if self.switch_nms_version == 'v4':
selected_indices, num_valid = gen_image_ops.non_max_suppression_v4(
boxes=boxes,
scores=scores,
max_output_size=max_output_size \
if not isinstance(max_output_size, np.ndarray) \
else tf.convert_to_tensor(
value=max_output_size,
name='max_output_size'
),
iou_threshold=iou_threshold \
if not isinstance(iou_threshold, np.ndarray) \
else tf.convert_to_tensor(
value=iou_threshold,
name='iou_threshold',
),
score_threshold=score_threshold \
if not isinstance(score_threshold, np.ndarray) \
else tf.convert_to_tensor(
value=score_threshold,
name='score_threshold',
),
pad_to_max_output_size=pad_to_max_output_size,
)

elif self.switch_nms_version == 'v5':
selected_indices, selected_scores, num_valid = gen_image_ops.non_max_suppression_v5(
boxes=boxes,
scores=scores,
max_output_size=max_output_size \
if not isinstance(max_output_size, np.ndarray) \
else tf.convert_to_tensor(
value=max_output_size,
name='max_output_size'
),
iou_threshold=iou_threshold \
if not isinstance(iou_threshold, np.ndarray) \
else tf.convert_to_tensor(
value=iou_threshold,
name='iou_threshold',
),
score_threshold=score_threshold \
if not isinstance(score_threshold, np.ndarray) \
else tf.convert_to_tensor(
value=score_threshold,
name='score_threshold',
),
soft_nms_sigma=0.0,
pad_to_max_output_size=pad_to_max_output_size,
)
if pad_to_max_output_size:
return selected_indices

Expand Down Expand Up @@ -130,8 +158,8 @@ def make_node(
scores = tf_layers_dict[graph_node_input_2.name]['tf_node'] \
if isinstance(graph_node_input_2, gs.Variable) else graph_node_input_2

output_nms_with_dynamic_tensor: bool = \
kwargs['output_nms_with_dynamic_tensor']
output_nms_with_dynamic_tensor: bool = kwargs['output_nms_with_dynamic_tensor']
switch_nms_version: str = kwargs['switch_nms_version']

# Pre-process transpose
boxes = pre_process_transpose(
Expand Down Expand Up @@ -339,7 +367,7 @@ def make_node(
axis=0,
)
# get the selected boxes indices
nms = NMSLayer()
nms = NMSLayer(switch_nms_version=switch_nms_version)
selected_indices = nms(
boxes=tf_boxes,
scores=tf_scores,
Expand Down

0 comments on commit 372a8da

Please sign in to comment.