Skip to content

Commit

Permalink
Update handling of TensorFlow signature input names to support 2.12 (#…
Browse files Browse the repository at this point in the history
…365)

* Remove tensorflow 2.9 upper-bound from test-cpu requirements file

* Get input column name from tensor spec instead of signature keys

* Add condition for selecting bweteen signature key and tensor spec
  • Loading branch information
oliverholworthy authored Jun 9, 2023
1 parent b23ff21 commit e66a375
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
14 changes: 11 additions & 3 deletions merlin/systems/dag/ops/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import pathlib
import tempfile

from packaging import version

# this needs to be before any modules that import protobuf
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

Expand Down Expand Up @@ -170,16 +172,22 @@ def _ensure_input_spec_includes_names(model):

def _build_schema_from_signature(signature):
schema = Schema()
for col_name, col in signature.items():
for signature_key, tensor_spec in signature.items():
# Selecting input name from the signature depending on tensorflow version
# This is to handle kebab-case keys in inputs
if version.parse(tf.__version__) < version.parse("2.12.0"):
col_name = signature_key
else:
col_name = tensor_spec.name
if "__offsets" in col_name or "__values" in col_name:
col_name = col_name.replace("__offsets", "").replace("__values", "")
col_values_sig = signature[f"{col_name}__values"]
col_offsets_sig = signature[f"{col_name}__offsets"]
col_dtype = col_values_sig.dtype.as_numpy_dtype
col_dims = (col_offsets_sig.shape[0], None)
else:
col_dtype = col.dtype.as_numpy_dtype
col_dims = col.shape
col_dtype = tensor_spec.dtype.as_numpy_dtype
col_dims = tensor_spec.shape
col_schema = ColumnSchema(col_name, dtype=col_dtype, dims=col_dims)
schema.column_schemas[col_name] = col_schema
return schema
1 change: 0 additions & 1 deletion requirements/test-cpu.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
-r test.txt

faiss-cpu==1.7.2
tensorflow<=2.9.0
treelite==2.4.0
treelite_runtime==2.4.0
torch~=1.12
Expand Down

0 comments on commit e66a375

Please sign in to comment.