From d7152a4d3b92f1be339f71493135627f9a3529c8 Mon Sep 17 00:00:00 2001 From: Ofri Masad Date: Mon, 8 Apr 2024 16:37:19 +0300 Subject: [PATCH] allow passing input_names and output_names params to onnx convert (#1941) * allow passing input_names and output_names params to onnx convert * fix for None onnx_export_kwargs * fix crash --------- Co-authored-by: Eugene Khvedchenya Co-authored-by: Shay Aharon <80472096+shaydeci@users.noreply.github.com> --- .../module_interfaces/exportable_detector.py | 12 ++++++++++-- .../module_interfaces/exportable_pose_estimation.py | 12 ++++++++++-- .../module_interfaces/exportable_segmentation.py | 12 ++++++++++-- 3 files changed, 30 insertions(+), 6 deletions(-) diff --git a/src/super_gradients/module_interfaces/exportable_detector.py b/src/super_gradients/module_interfaces/exportable_detector.py index 558719b357..bf89ab835f 100644 --- a/src/super_gradients/module_interfaces/exportable_detector.py +++ b/src/super_gradients/module_interfaces/exportable_detector.py @@ -338,7 +338,15 @@ def export( # This variable holds the output names of the model. # If postprocessing is enabled, it will be set to the output names of the postprocessing module. - output_names: Optional[List[str]] = None + if onnx_export_kwargs is not None and "output_names" in onnx_export_kwargs: + output_names = onnx_export_kwargs.pop("output_names") + else: + output_names = None + + if onnx_export_kwargs is not None and "input_names" in onnx_export_kwargs: + input_names = onnx_export_kwargs.pop("input_names") + else: + input_names = ["input"] if isinstance(postprocessing, nn.Module): # If a user-specified postprocessing module is provided, we will attach is to the model and not @@ -452,7 +460,7 @@ def export( model=complete_model, model_input=onnx_input, onnx_filename=output, - input_names=["input"], + input_names=input_names, output_names=output_names, onnx_opset=onnx_export_kwargs.get("opset_version", None), do_constant_folding=onnx_export_kwargs.get("do_constant_folding", True), diff --git a/src/super_gradients/module_interfaces/exportable_pose_estimation.py b/src/super_gradients/module_interfaces/exportable_pose_estimation.py index 834798d82d..1b0114441b 100644 --- a/src/super_gradients/module_interfaces/exportable_pose_estimation.py +++ b/src/super_gradients/module_interfaces/exportable_pose_estimation.py @@ -325,7 +325,15 @@ def export( # This variable holds the output names of the model. # If postprocessing is enabled, it will be set to the output names of the postprocessing module. - output_names: Optional[List[str]] = None + if onnx_export_kwargs is not None and "output_names" in onnx_export_kwargs: + output_names = onnx_export_kwargs.pop("output_names") + else: + output_names = None + + if onnx_export_kwargs is not None and "input_names" in onnx_export_kwargs: + input_names = onnx_export_kwargs.pop("input_names") + else: + input_names = ["input"] if isinstance(postprocessing, nn.Module): # If a user-specified postprocessing module is provided, we will attach is to the model and not @@ -438,7 +446,7 @@ def export( model=complete_model, model_input=onnx_input, onnx_filename=output, - input_names=["input"], + input_names=input_names, output_names=output_names, onnx_opset=onnx_export_kwargs.get("opset_version", None), do_constant_folding=onnx_export_kwargs.get("do_constant_folding", True), diff --git a/src/super_gradients/module_interfaces/exportable_segmentation.py b/src/super_gradients/module_interfaces/exportable_segmentation.py index 8ef5134847..8499f25cd3 100644 --- a/src/super_gradients/module_interfaces/exportable_segmentation.py +++ b/src/super_gradients/module_interfaces/exportable_segmentation.py @@ -326,7 +326,15 @@ def export( # This variable holds the output names of the model. # If postprocessing is enabled, it will be set to the output names of the postprocessing module. - output_names: Optional[List[str]] = None + if onnx_export_kwargs is not None and "output_names" in onnx_export_kwargs: + output_names = onnx_export_kwargs.pop("output_names") + else: + output_names = None + + if onnx_export_kwargs is not None and "input_names" in onnx_export_kwargs: + input_names = onnx_export_kwargs.pop("input_names") + else: + input_names = ["input"] if isinstance(postprocessing, nn.Module): # If a user-specified postprocessing module is provided, we will attach is to the model and not @@ -412,7 +420,7 @@ def export( model=complete_model, model_input=onnx_input, onnx_filename=output, - input_names=["input"], + input_names=input_names, output_names=output_names, onnx_opset=onnx_export_kwargs.get("opset_version", None), do_constant_folding=onnx_export_kwargs.get("do_constant_folding", True),