Skip to content

Commit

Permalink
Update PredictTensorflow to use low-level saved_model API (#367)
Browse files Browse the repository at this point in the history
* Use saved_model_utils to get inputs from tensorflow model

* Use tf.saved_model API to load and save model
  • Loading branch information
oliverholworthy authored Jun 13, 2023
1 parent f31ac3d commit c84a28c
Showing 1 changed file with 23 additions and 50 deletions.
73 changes: 23 additions & 50 deletions merlin/systems/dag/ops/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,8 @@
# limitations under the License.
#
import os
import pathlib
import tempfile

from packaging import version

# this needs to be before any modules that import protobuf
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

Expand Down Expand Up @@ -52,7 +49,7 @@ def __init__(self, model_or_path, custom_objects: dict = None):

if isinstance(model_or_path, (str, os.PathLike)):
self.path = model_or_path
self.model = tf.keras.models.load_model(self.path, custom_objects=custom_objects)
self.model = tf.saved_model.load(self.path)
else:
self.path = None
self.model = model_or_path
Expand Down Expand Up @@ -134,60 +131,36 @@ def compute_output_schema(
return self.output_schema


def _construct_schemas_from_model(model, *, signature_name="serving_default", tag_set="serve"):
    """Derive the input and output schemas of a TensorFlow model from a signature.

    The model is exported with ``tf.saved_model.save`` to a temporary directory
    and the requested signature is read back from the serialized MetaGraphDef.

    Parameters
    ----------
    model
        Object to export; passed unchanged to ``tf.saved_model.save``.
    signature_name : str
        Key of the signature to read from the exported model.
    tag_set : str
        Tag set forwarded to ``saved_model_utils.get_meta_graph_def`` to select
        the MetaGraphDef.

    Returns
    -------
    tuple
        ``(input_schema, output_schema)`` built from the signature's inputs
        and outputs.

    Raises
    ------
    ValueError
        If the exported model has no signature named ``signature_name``.
    """
    # Importing here because tensorflow is an optional dependency of Merlin Systems
    from tensorflow.python.tools import saved_model_utils

    # save to disk to generate signature from saved model
    with tempfile.TemporaryDirectory() as saved_model_dir:
        tf.saved_model.save(model, saved_model_dir)
        meta_graph_def = saved_model_utils.get_meta_graph_def(saved_model_dir, tag_set)

    signature_defs = meta_graph_def.signature_def
    # Check membership before indexing: indexing a protobuf message map with a
    # missing key inserts an empty entry instead of raising, which would
    # silently produce two empty schemas.
    if signature_name not in signature_defs:
        available = sorted(signature_defs)
        raise ValueError(
            f"Signature '{signature_name}' not found in saved model. "
            f"Available signatures: {available}"
        )
    signature_def = signature_defs[signature_name]

    input_schema = _build_schema_from_signature(signature_def.inputs)
    output_schema = _build_schema_from_signature(signature_def.outputs)

    return input_schema, output_schema


def _ensure_input_spec_includes_names(model):
    """Rename the tensor specs in the model's saved input spec after their keys.

    Only acts when ``model._saved_model_inputs_spec`` is a dict. Each value is
    replaced by a ``tf.TensorSpec`` with the same shape and dtype whose name is
    the dict key. A tuple value is replaced by a pair of such specs built from
    its first two elements. The model is modified in place and returned.
    """
    if not isinstance(model._saved_model_inputs_spec, dict):
        return model

    def _named_spec(original, name):
        # Same shape and dtype, but named after the input key.
        return tf.TensorSpec(shape=original.shape, dtype=original.dtype, name=name)

    for input_name, input_spec in model._saved_model_inputs_spec.items():
        if isinstance(input_spec, tuple):
            replacement = (
                _named_spec(input_spec[0], input_name),
                _named_spec(input_spec[1], input_name),
            )
        else:
            replacement = _named_spec(input_spec, input_name)
        # Only existing keys are reassigned, so the dict size is unchanged
        # while it is being iterated.
        model._saved_model_inputs_spec[input_name] = replacement

    return model


def _build_schema_from_signature(signature_def_inputs_or_outputs):
    """Build a Schema from the inputs or outputs mapping of a signature.

    Each entry maps a tensor name to an object exposing ``dtype`` and
    ``tensor_shape.dim`` (each dim having a ``size``). A pair of entries named
    ``<col>__values`` / ``<col>__offsets`` is collapsed into a single column
    ``<col>``. Its dtype comes from the values tensor and its first two
    dimensions are reported as unknown. Every other entry becomes a column
    named after the tensor.

    Parameters
    ----------
    signature_def_inputs_or_outputs
        Mapping of tensor name to tensor info.

    Returns
    -------
    Schema
        One ColumnSchema per column.
    """
    ragged_suffixes = ("__values", "__offsets")

    def _dims(tensor_shape):
        # A size of -1 marks an unknown dimension; report it as None.
        return [None if dim.size == -1 else dim.size for dim in tensor_shape.dim]

    schema = Schema()
    for tensor_name, tensor_info in signature_def_inputs_or_outputs.items():
        # Only a trailing suffix marks a ragged component; a name that merely
        # contains "__values"/"__offsets" elsewhere is left untouched.
        base_name = tensor_name
        for suffix in ragged_suffixes:
            if tensor_name.endswith(suffix):
                base_name = tensor_name[: -len(suffix)]
                break

        values_key = f"{base_name}__values"
        # Membership is tested before indexing: if the mapping is a protobuf
        # map, indexing a missing key would insert an entry while iterating.
        is_ragged = base_name != tensor_name and values_key in signature_def_inputs_or_outputs

        if is_ragged:
            col_name = base_name
            values_info = signature_def_inputs_or_outputs[values_key]
            col_dtype = tf.as_dtype(values_info.dtype)
            # The leading values dimension is replaced by two unknown dims;
            # any remaining dims are kept.
            col_dims = [None, None] + _dims(values_info.tensor_shape)[1:]
        else:
            col_name = tensor_name
            col_dtype = tf.as_dtype(tensor_info.dtype)
            col_dims = _dims(tensor_info.tensor_shape)

        # Both components of a ragged pair map to the same column name, so the
        # second one simply rewrites an identical entry.
        schema.column_schemas[col_name] = ColumnSchema(col_name, dtype=col_dtype, dims=col_dims)
    return schema

0 comments on commit c84a28c

Please sign in to comment.