From 73db296a2d1361273b579ea30d45df3603cc6a2c Mon Sep 17 00:00:00 2001 From: Jaehong Kim Date: Wed, 24 May 2023 17:03:36 -0700 Subject: [PATCH] Fix compatibility issues for the TF/Keras 2.13. PiperOrigin-RevId: 535031817 --- .../keras/default_8bit/default_8bit_transforms.py | 11 ++++++++--- .../default_n_bit/default_n_bit_transforms.py | 12 +++++++++--- .../python/core/sparsity/keras/prune_registry.py | 7 ++++++- 3 files changed, 23 insertions(+), 7 deletions(-) diff --git a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms.py b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms.py index 498360c74..58ec82303 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms.py @@ -17,7 +17,6 @@ import collections import inspect -from keras import backend import numpy as np import tensorflow as tf @@ -29,6 +28,12 @@ from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_registry from tensorflow_model_optimization.python.core.quantization.keras.graph_transformations import transforms +try: + from keras.backend import unique_object_name # pylint: disable=g-import-not-at-top +except ImportError: + # Path as seen in pip packages as of TF/Keras 2.13. + from keras.src.backend import unique_object_name # pylint: disable=g-import-not-at-top + LayerNode = transforms.LayerNode LayerPattern = transforms.LayerPattern @@ -364,9 +369,9 @@ def pattern(self): return LayerPattern('SeparableConv1D') def _get_name(self, prefix): - # TODO(pulkitb): Move away from `backend.unique_object_name` since it isn't + # TODO(pulkitb): Move away from `unique_object_name` since it isn't # exposed as externally usable. - return backend.unique_object_name(prefix) + return unique_object_name(prefix) def replacement(self, match_layer): if _has_custom_quantize_config(match_layer): diff --git a/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_transforms.py b/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_transforms.py index 320fb7267..62d37344b 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_transforms.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_transforms.py @@ -17,7 +17,6 @@ import collections import inspect -from keras import backend import numpy as np import tensorflow as tf @@ -29,6 +28,13 @@ from tensorflow_model_optimization.python.core.quantization.keras.experimental.default_n_bit import default_n_bit_quantize_registry from tensorflow_model_optimization.python.core.quantization.keras.graph_transformations import transforms + +try: + from keras.backend import unique_object_name # pylint: disable=g-import-not-at-top +except ImportError: + # Path as seen in pip packages as of TF/Keras 2.13. + from keras.src.backend import unique_object_name # pylint: disable=g-import-not-at-top + LayerNode = transforms.LayerNode LayerPattern = transforms.LayerPattern @@ -395,9 +401,9 @@ def pattern(self): return LayerPattern('SeparableConv1D') def _get_name(self, prefix): - # TODO(pulkitb): Move away from `backend.unique_object_name` since it isn't + # TODO(pulkitb): Move away from `unique_object_name` since it isn't # exposed as externally usable. - return backend.unique_object_name(prefix) + return unique_object_name(prefix) def replacement(self, match_layer): if _has_custom_quantize_config(match_layer): diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/prune_registry.py b/tensorflow_model_optimization/python/core/sparsity/keras/prune_registry.py index fdabb8f5c..36e203bc8 100644 --- a/tensorflow_model_optimization/python/core/sparsity/keras/prune_registry.py +++ b/tensorflow_model_optimization/python/core/sparsity/keras/prune_registry.py @@ -14,11 +14,16 @@ # ============================================================================== """Registry responsible for built-in keras classes.""" -from keras.engine import base_layer import tensorflow as tf from tensorflow_model_optimization.python.core.sparsity.keras import prunable_layer +try: + from keras.engine import base_layer # pylint: disable=g-import-not-at-top +except ImportError: + # Path as seen in pip packages as of TF/Keras 2.13. + from keras.src.engine import base_layer # pylint: disable=g-import-not-at-top + # TODO(b/139939526): move to public API. layers = tf.keras.layers