From e66a37589b0c1603c6a773120fe3cecd277f95e6 Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Fri, 9 Jun 2023 15:14:17 +0100 Subject: [PATCH] Update handling of TensorFlow signature input names to support 2.12 (#365) * Remove tensorflow 2.9 upper-bound from test-cpu requirements file * Get input column name from tensor spec instead of signature keys * Add condition for selecting bweteen signature key and tensor spec --- merlin/systems/dag/ops/tensorflow.py | 14 +++++++++++--- requirements/test-cpu.txt | 1 - 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/merlin/systems/dag/ops/tensorflow.py b/merlin/systems/dag/ops/tensorflow.py index 419ea4bc7..40a8b8c3f 100644 --- a/merlin/systems/dag/ops/tensorflow.py +++ b/merlin/systems/dag/ops/tensorflow.py @@ -17,6 +17,8 @@ import pathlib import tempfile +from packaging import version + # this needs to be before any modules that import protobuf os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" @@ -170,7 +172,13 @@ def _ensure_input_spec_includes_names(model): def _build_schema_from_signature(signature): schema = Schema() - for col_name, col in signature.items(): + for signature_key, tensor_spec in signature.items(): + # Selecting input name from the signature depending on tensorflow version + # This is to handle kebab-case keys in inputs + if version.parse(tf.__version__) < version.parse("2.12.0"): + col_name = signature_key + else: + col_name = tensor_spec.name if "__offsets" in col_name or "__values" in col_name: col_name = col_name.replace("__offsets", "").replace("__values", "") col_values_sig = signature[f"{col_name}__values"] @@ -178,8 +186,8 @@ def _build_schema_from_signature(signature): col_dtype = col_values_sig.dtype.as_numpy_dtype col_dims = (col_offsets_sig.shape[0], None) else: - col_dtype = col.dtype.as_numpy_dtype - col_dims = col.shape + col_dtype = tensor_spec.dtype.as_numpy_dtype + col_dims = tensor_spec.shape col_schema = ColumnSchema(col_name, dtype=col_dtype, dims=col_dims) schema.column_schemas[col_name] = col_schema return schema diff --git a/requirements/test-cpu.txt b/requirements/test-cpu.txt index 871c5c051..1a59d917b 100644 --- a/requirements/test-cpu.txt +++ b/requirements/test-cpu.txt @@ -1,7 +1,6 @@ -r test.txt faiss-cpu==1.7.2 -tensorflow<=2.9.0 treelite==2.4.0 treelite_runtime==2.4.0 torch~=1.12