Skip to content

Commit

Permalink
Fix compatibility issues for the TF/Keras 2.13.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 535031817
  • Loading branch information
Xhark authored and tensorflower-gardener committed May 25, 2023
1 parent af9d021 commit 73db296
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import collections
import inspect

from keras import backend
import numpy as np
import tensorflow as tf

Expand All @@ -29,6 +28,12 @@
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_registry
from tensorflow_model_optimization.python.core.quantization.keras.graph_transformations import transforms

try:
from keras.backend import unique_object_name # pylint: disable=g-import-not-at-top
except ImportError:
# Path as seen in pip packages as of TF/Keras 2.13.
from keras.src.backend import unique_object_name # pylint: disable=g-import-not-at-top

LayerNode = transforms.LayerNode
LayerPattern = transforms.LayerPattern

Expand Down Expand Up @@ -364,9 +369,9 @@ def pattern(self):
return LayerPattern('SeparableConv1D')

def _get_name(self, prefix):
# TODO(pulkitb): Move away from `backend.unique_object_name` since it isn't
# TODO(pulkitb): Move away from `unique_object_name` since it isn't
# exposed as externally usable.
return backend.unique_object_name(prefix)
return unique_object_name(prefix)

def replacement(self, match_layer):
if _has_custom_quantize_config(match_layer):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import collections
import inspect

from keras import backend
import numpy as np
import tensorflow as tf

Expand All @@ -29,6 +28,13 @@
from tensorflow_model_optimization.python.core.quantization.keras.experimental.default_n_bit import default_n_bit_quantize_registry
from tensorflow_model_optimization.python.core.quantization.keras.graph_transformations import transforms


try:
from keras.backend import unique_object_name # pylint: disable=g-import-not-at-top
except ImportError:
# Path as seen in pip packages as of TF/Keras 2.13.
from keras.src.backend import unique_object_name # pylint: disable=g-import-not-at-top

LayerNode = transforms.LayerNode
LayerPattern = transforms.LayerPattern

Expand Down Expand Up @@ -395,9 +401,9 @@ def pattern(self):
return LayerPattern('SeparableConv1D')

def _get_name(self, prefix):
# TODO(pulkitb): Move away from `backend.unique_object_name` since it isn't
# TODO(pulkitb): Move away from `unique_object_name` since it isn't
# exposed as externally usable.
return backend.unique_object_name(prefix)
return unique_object_name(prefix)

def replacement(self, match_layer):
if _has_custom_quantize_config(match_layer):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,16 @@
# ==============================================================================
"""Registry responsible for built-in keras classes."""

from keras.engine import base_layer
import tensorflow as tf

from tensorflow_model_optimization.python.core.sparsity.keras import prunable_layer

try:
from keras.engine import base_layer # pylint: disable=g-import-not-at-top
except ImportError:
# Path as seen in pip packages as of TF/Keras 2.13.
from keras.src.engine import base_layer # pylint: disable=g-import-not-at-top

# TODO(b/139939526): move to public API.

layers = tf.keras.layers
Expand Down

0 comments on commit 73db296

Please sign in to comment.