Skip to content

Commit

Permalink
Fix warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
james77777778 committed Jan 5, 2025
1 parent d3fdb3f commit 02e86d1
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
15 changes: 13 additions & 2 deletions keras/src/export/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from keras.src import tree
from keras.src.export.export_utils import convert_spec_to_tensor
from keras.src.export.export_utils import get_input_signature
from keras.src.export.saved_model import DEFAULT_ENDPOINT_NAME
from keras.src.export.saved_model import export_saved_model
from keras.src.utils.module_utils import tensorflow as tf

Expand Down Expand Up @@ -76,7 +77,12 @@ def export_onnx(model, filepath, verbose=True, input_signature=None, **kwargs):
input_signature,
**kwargs,
)
saved_model_to_onnx(temp_dir, filepath, model.name)
saved_model_to_onnx(
temp_dir,
filepath,
model.name,
signatures=[DEFAULT_ENDPOINT_NAME],
)

elif backend.backend() == "torch":
import torch
Expand Down Expand Up @@ -133,19 +139,24 @@ def _check_jax_kwargs(kwargs):
return kwargs


def saved_model_to_onnx(saved_model_dir, filepath, name):
def saved_model_to_onnx(saved_model_dir, filepath, name, signatures=None):
from keras.src.export.tf2onnx_lib import patch_tf2onnx
from keras.src.utils.module_utils import tf2onnx

# TODO: Remove this patch once `tf2onnx` supports `numpy>=2.0.0`.
patch_tf2onnx()

if signatures is None:
signatures = ["serve"]

# Convert to ONNX using `tf2onnx` library.
(graph_def, inputs, outputs, initialized_tables, tensors_to_rename) = (
tf2onnx.tf_loader.from_saved_model(
saved_model_dir,
None,
None,
tag=signatures,
signatures=signatures,
return_initialized_tables=True,
return_tensors_to_rename=True,
)
Expand Down
5 changes: 4 additions & 1 deletion keras/src/export/saved_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@
)


DEFAULT_ENDPOINT_NAME = "serve"


@keras_export("keras.export.ExportArchive")
class ExportArchive(BackendExportArchive):
"""ExportArchive is used to write SavedModel artifacts (e.g. for inference).
Expand Down Expand Up @@ -623,7 +626,7 @@ def export_saved_model(
input_signature = get_input_signature(model)

export_archive.track_and_add_endpoint(
"serve", model, input_signature, **kwargs
DEFAULT_ENDPOINT_NAME, model, input_signature, **kwargs
)
export_archive.write_out(filepath, verbose=verbose)

Expand Down

0 comments on commit 02e86d1

Please sign in to comment.