diff --git a/src/super_gradients/module_interfaces/exportable_detector.py b/src/super_gradients/module_interfaces/exportable_detector.py index 558719b357..bf89ab835f 100644 --- a/src/super_gradients/module_interfaces/exportable_detector.py +++ b/src/super_gradients/module_interfaces/exportable_detector.py @@ -338,7 +338,15 @@ def export( # This variable holds the output names of the model. # If postprocessing is enabled, it will be set to the output names of the postprocessing module. - output_names: Optional[List[str]] = None + if onnx_export_kwargs is not None and "output_names" in onnx_export_kwargs: + output_names = onnx_export_kwargs.pop("output_names") + else: + output_names = None + + if onnx_export_kwargs is not None and "input_names" in onnx_export_kwargs: + input_names = onnx_export_kwargs.pop("input_names") + else: + input_names = ["input"] if isinstance(postprocessing, nn.Module): # If a user-specified postprocessing module is provided, we will attach is to the model and not @@ -452,7 +460,7 @@ def export( model=complete_model, model_input=onnx_input, onnx_filename=output, - input_names=["input"], + input_names=input_names, output_names=output_names, onnx_opset=onnx_export_kwargs.get("opset_version", None), do_constant_folding=onnx_export_kwargs.get("do_constant_folding", True), diff --git a/src/super_gradients/module_interfaces/exportable_pose_estimation.py b/src/super_gradients/module_interfaces/exportable_pose_estimation.py index 834798d82d..1b0114441b 100644 --- a/src/super_gradients/module_interfaces/exportable_pose_estimation.py +++ b/src/super_gradients/module_interfaces/exportable_pose_estimation.py @@ -325,7 +325,15 @@ def export( # This variable holds the output names of the model. # If postprocessing is enabled, it will be set to the output names of the postprocessing module. - output_names: Optional[List[str]] = None + if onnx_export_kwargs is not None and "output_names" in onnx_export_kwargs: + output_names = onnx_export_kwargs.pop("output_names") + else: + output_names = None + + if onnx_export_kwargs is not None and "input_names" in onnx_export_kwargs: + input_names = onnx_export_kwargs.pop("input_names") + else: + input_names = ["input"] if isinstance(postprocessing, nn.Module): # If a user-specified postprocessing module is provided, we will attach is to the model and not @@ -438,7 +446,7 @@ def export( model=complete_model, model_input=onnx_input, onnx_filename=output, - input_names=["input"], + input_names=input_names, output_names=output_names, onnx_opset=onnx_export_kwargs.get("opset_version", None), do_constant_folding=onnx_export_kwargs.get("do_constant_folding", True), diff --git a/src/super_gradients/module_interfaces/exportable_segmentation.py b/src/super_gradients/module_interfaces/exportable_segmentation.py index 8ef5134847..8499f25cd3 100644 --- a/src/super_gradients/module_interfaces/exportable_segmentation.py +++ b/src/super_gradients/module_interfaces/exportable_segmentation.py @@ -326,7 +326,15 @@ def export( # This variable holds the output names of the model. # If postprocessing is enabled, it will be set to the output names of the postprocessing module. - output_names: Optional[List[str]] = None + if onnx_export_kwargs is not None and "output_names" in onnx_export_kwargs: + output_names = onnx_export_kwargs.pop("output_names") + else: + output_names = None + + if onnx_export_kwargs is not None and "input_names" in onnx_export_kwargs: + input_names = onnx_export_kwargs.pop("input_names") + else: + input_names = ["input"] if isinstance(postprocessing, nn.Module): # If a user-specified postprocessing module is provided, we will attach is to the model and not @@ -412,7 +420,7 @@ def export( model=complete_model, model_input=onnx_input, onnx_filename=output, - input_names=["input"], + input_names=input_names, output_names=output_names, onnx_opset=onnx_export_kwargs.get("opset_version", None), do_constant_folding=onnx_export_kwargs.get("do_constant_folding", True),