diff --git a/merlin/systems/dag/ops/tensorflow.py b/merlin/systems/dag/ops/tensorflow.py index 419ea4bc7..40a8b8c3f 100644 --- a/merlin/systems/dag/ops/tensorflow.py +++ b/merlin/systems/dag/ops/tensorflow.py @@ -17,6 +17,8 @@ import pathlib import tempfile +from packaging import version + # this needs to be before any modules that import protobuf os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" @@ -170,7 +172,13 @@ def _ensure_input_spec_includes_names(model): def _build_schema_from_signature(signature): schema = Schema() - for col_name, col in signature.items(): + for signature_key, tensor_spec in signature.items(): + # Selecting input name from the signature depending on tensorflow version + # This is to handle kebab-case keys in inputs + if version.parse(tf.__version__) < version.parse("2.12.0"): + col_name = signature_key + else: + col_name = tensor_spec.name if "__offsets" in col_name or "__values" in col_name: col_name = col_name.replace("__offsets", "").replace("__values", "") col_values_sig = signature[f"{col_name}__values"] @@ -178,8 +186,8 @@ def _build_schema_from_signature(signature): col_dtype = col_values_sig.dtype.as_numpy_dtype col_dims = (col_offsets_sig.shape[0], None) else: - col_dtype = col.dtype.as_numpy_dtype - col_dims = col.shape + col_dtype = tensor_spec.dtype.as_numpy_dtype + col_dims = tensor_spec.shape col_schema = ColumnSchema(col_name, dtype=col_dtype, dims=col_dims) schema.column_schemas[col_name] = col_schema return schema diff --git a/requirements/test-cpu.txt b/requirements/test-cpu.txt index 871c5c051..1a59d917b 100644 --- a/requirements/test-cpu.txt +++ b/requirements/test-cpu.txt @@ -1,7 +1,6 @@ -r test.txt faiss-cpu==1.7.2 -tensorflow<=2.9.0 treelite==2.4.0 treelite_runtime==2.4.0 torch~=1.12