diff --git a/tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py b/tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py index aaae72f91..8dc0378fe 100644 --- a/tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py +++ b/tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py @@ -21,7 +21,7 @@ import numpy as np import tensorflow as tf -from tensorflow.python.keras import keras_parameterized +from keras import keras_parameterized from tensorflow_model_optimization.python.core.clustering.keras import cluster from tensorflow_model_optimization.python.core.clustering.keras import cluster_config from tensorflow_model_optimization.python.core.clustering.keras.experimental import cluster as experimental_cluster diff --git a/tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py b/tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py index 872296e5d..fa36dd00c 100644 --- a/tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py +++ b/tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py @@ -21,7 +21,6 @@ from absl.testing import parameterized import tensorflow as tf -from tensorflow.python.keras import keras_parameterized from tensorflow_model_optimization.python.core.clustering.keras import cluster from tensorflow_model_optimization.python.core.clustering.keras import cluster_config from tensorflow_model_optimization.python.core.clustering.keras import cluster_wrapper @@ -162,7 +161,6 @@ def _count_clustered_layers(self, model): count += 1 return count - @keras_parameterized.run_all_keras_modes def testClusterKerasClusterableLayer(self): """Verifies that a built-in keras layer marked as clusterable is being clustered correctly.""" wrapped_layer = self._build_clustered_layer_model( @@ -170,7 +168,6 @@ def testClusterKerasClusterableLayer(self): self._validate_clustered_layer(self.keras_clusterable_layer, wrapped_layer) - @keras_parameterized.run_all_keras_modes def testClusterKerasClusterableLayerWithSparsityPreservation(self): """Verifies that a built-in keras layer marked as clusterable is being clustered correctly when sparsity preservation is enabled.""" preserve_sparsity_params = {'preserve_sparsity': True} @@ -180,7 +177,6 @@ def testClusterKerasClusterableLayerWithSparsityPreservation(self): self._validate_clustered_layer(self.keras_clusterable_layer, wrapped_layer) - @keras_parameterized.run_all_keras_modes def testClusterKerasNonClusterableLayer(self): """Verifies that a built-in keras layer not marked as clusterable is not being clustered.""" wrapped_layer = self._build_clustered_layer_model( @@ -190,7 +186,6 @@ def testClusterKerasNonClusterableLayer(self): wrapped_layer) self.assertEqual([], wrapped_layer.layer.get_clusterable_weights()) - @keras_parameterized.run_all_keras_modes def testDepthwiseConv2DLayerNonClusterable(self): """Verifies that we don't cluster a DepthwiseConv2D layer, because clustering of this type of layer gives big unrecoverable accuracy loss.""" wrapped_layer = self._build_clustered_layer_model( @@ -200,7 +195,6 @@ def testDepthwiseConv2DLayerNonClusterable(self): wrapped_layer) self.assertEqual([], wrapped_layer.layer.get_clusterable_weights()) - @keras_parameterized.run_all_keras_modes def testDenseLayer(self): """Verifies that we can cluster a Dense layer.""" input_shape = (28, 1) @@ -214,7 +208,6 @@ def testDenseLayer(self): self.assertEqual([1, 10], wrapped_layer.layer.get_clusterable_weights()[0][1].shape) - @keras_parameterized.run_all_keras_modes def testConv1DLayer(self): """Verifies that we can cluster a Conv1D layer.""" input_shape = (28, 1) @@ -227,7 +220,6 @@ def testConv1DLayer(self): self.assertEqual([5, 1, 3], wrapped_layer.layer.get_clusterable_weights()[0][1].shape) - @keras_parameterized.run_all_keras_modes def testConv1DTransposeLayer(self): """Verifies that we can cluster a Conv1DTranspose layer.""" input_shape = (28, 1) @@ -240,7 +232,6 @@ def testConv1DTransposeLayer(self): self.assertEqual([5, 3, 1], wrapped_layer.layer.get_clusterable_weights()[0][1].shape) - @keras_parameterized.run_all_keras_modes def testConv2DLayer(self): """Verifies that we can cluster a Conv2D layer.""" input_shape = (28, 28, 1) @@ -253,7 +244,6 @@ def testConv2DLayer(self): self.assertEqual([4, 5, 1, 3], wrapped_layer.layer.get_clusterable_weights()[0][1].shape) - @keras_parameterized.run_all_keras_modes def testConv2DTransposeLayer(self): """Verifies that we can cluster a Conv2DTranspose layer.""" input_shape = (28, 28, 1) @@ -266,7 +256,6 @@ def testConv2DTransposeLayer(self): self.assertEqual([4, 5, 3, 1], wrapped_layer.layer.get_clusterable_weights()[0][1].shape) - @keras_parameterized.run_all_keras_modes def testConv3DLayer(self): """Verifies that we can cluster a Conv3D layer.""" input_shape = (28, 28, 28, 1) @@ -287,7 +276,6 @@ def testClusterKerasUnsupportedLayer(self): with self.assertRaises(ValueError): cluster.cluster_weights(keras_unsupported_layer, **self.params) - @keras_parameterized.run_all_keras_modes def testClusterCustomClusterableLayer(self): """Verifies that a custom clusterable layer is being clustered correctly.""" wrapped_layer = self._build_clustered_layer_model( @@ -297,7 +285,6 @@ def testClusterCustomClusterableLayer(self): self.assertEqual([('kernel', wrapped_layer.layer.kernel)], wrapped_layer.layer.get_clusterable_weights()) - @keras_parameterized.run_all_keras_modes def testClusterCustomClusterableLayerWithSparsityPreservation(self): """Verifies that a custom clusterable layer is being clustered correctly when sparsity preservation is enabled.""" preserve_sparsity_params = {'preserve_sparsity': True} @@ -424,7 +411,6 @@ def testStripClusteringSequentialModelWithBiasConstraint(self): keras_file = os.path.join(tmp_dir_name, 'cluster_test') stripped_model.save(keras_file, save_traces=True) - @keras_parameterized.run_all_keras_modes def testClusterSequentialModelSelectively(self): clustered_model = keras.Sequential() clustered_model.add( @@ -437,7 +423,6 @@ def testClusterSequentialModelSelectively(self): self.assertNotIsInstance(clustered_model.layers[1], cluster_wrapper.ClusterWeights) - @keras_parameterized.run_all_keras_modes def testClusterSequentialModelSelectivelyWithSparsityPreservation(self): """Verifies that layers within a sequential model can be clustered selectively when sparsity preservation is enabled.""" preserve_sparsity_params = {'preserve_sparsity': True} @@ -454,7 +439,6 @@ def testClusterSequentialModelSelectivelyWithSparsityPreservation(self): self.assertNotIsInstance(clustered_model.layers[1], cluster_wrapper.ClusterWeights) - @keras_parameterized.run_all_keras_modes def testClusterFunctionalModelSelectively(self): """Verifies that layers within a functional model can be clustered selectively.""" i1 = keras.Input(shape=(10,)) @@ -469,7 +453,6 @@ def testClusterFunctionalModelSelectively(self): self.assertNotIsInstance(clustered_model.layers[3], cluster_wrapper.ClusterWeights) - @keras_parameterized.run_all_keras_modes def testClusterFunctionalModelSelectivelyWithSparsityPreservation(self): """Verifies that layers within a functional model can be clustered selectively when sparsity preservation is enabled.""" preserve_sparsity_params = {'preserve_sparsity': True} @@ -486,7 +469,6 @@ def testClusterFunctionalModelSelectivelyWithSparsityPreservation(self): self.assertNotIsInstance(clustered_model.layers[3], cluster_wrapper.ClusterWeights) - @keras_parameterized.run_all_keras_modes def testClusterModelValidLayersSuccessful(self): """Verifies that clustering a sequential model results in all clusterable layers within the model being clustered.""" model = keras.Sequential([ @@ -500,7 +482,6 @@ def testClusterModelValidLayersSuccessful(self): for layer, clustered_layer in zip(model.layers, clustered_model.layers): self._validate_clustered_layer(layer, clustered_layer) - @keras_parameterized.run_all_keras_modes def testClusterModelValidLayersSuccessfulWithSparsityPreservation(self): """Verifies that clustering a sequential model results in all clusterable layers within the model being clustered when sparsity preservation is enabled.""" preserve_sparsity_params = {'preserve_sparsity': True} @@ -540,7 +521,6 @@ def testClusterModelCustomNonClusterableLayerRaisesError(self): self.custom_clusterable_layer, custom_non_clusterable_layer ]), **self.params) - @keras_parameterized.run_all_keras_modes def testClusterModelDoesNotWrapAlreadyWrappedLayer(self): """Verifies that clustering a model that contains an already clustered layer does not result in wrapping the clustered layer into another cluster_wrapper.""" model = keras.Sequential([ @@ -579,7 +559,6 @@ def testClusterSequentialModelNoInput(self): clustered_model = cluster.cluster_weights(model, **self.params) self.assertEqual(self._count_clustered_layers(clustered_model), 2) - @keras_parameterized.run_all_keras_modes def testClusterSequentialModelWithInput(self): """Verifies that a sequential model with an input layer is being clustered correctly.""" # With InputLayer @@ -607,7 +586,6 @@ def testClusterSequentialModelPreservesBuiltStateNoInput(self): json.loads(clustered_model.to_json())) self.assertEqual(loaded_model.built, False) - @keras_parameterized.run_all_keras_modes def testClusterSequentialModelPreservesBuiltStateWithInput(self): """Verifies that clustering a sequential model with an input layer preserves the built state of the model.""" # With InputLayer @@ -625,7 +603,6 @@ def testClusterSequentialModelPreservesBuiltStateWithInput(self): json.loads(clustered_model.to_json())) self.assertEqual(loaded_model.built, True) - @keras_parameterized.run_all_keras_modes def testClusterFunctionalModelPreservesBuiltState(self): """Verifies that clustering a functional model preserves the built state of the model.""" i1 = keras.Input(shape=(10,)) @@ -644,7 +621,6 @@ def testClusterFunctionalModelPreservesBuiltState(self): json.loads(clustered_model.to_json())) self.assertEqual(loaded_model.built, True) - @keras_parameterized.run_all_keras_modes def testClusterFunctionalModel(self): """Verifies that a functional model is being clustered correctly.""" i1 = keras.Input(shape=(10,)) @@ -656,7 +632,6 @@ def testClusterFunctionalModel(self): clustered_model = cluster.cluster_weights(model, **self.params) self.assertEqual(self._count_clustered_layers(clustered_model), 3) - @keras_parameterized.run_all_keras_modes def testClusterFunctionalModelWithLayerReused(self): """Verifies that a layer reused within a functional model multiple times is only being clustered once.""" # The model reuses the Dense() layer. Make sure it's only clustered once. @@ -668,14 +643,12 @@ def testClusterFunctionalModelWithLayerReused(self): clustered_model = cluster.cluster_weights(model, **self.params) self.assertEqual(self._count_clustered_layers(clustered_model), 1) - @keras_parameterized.run_all_keras_modes def testClusterSubclassModel(self): """Verifies that attempting to cluster an instance of a subclass of keras.Model raises an exception.""" model = TestModel() with self.assertRaises(ValueError): _ = cluster.cluster_weights(model, **self.params) - @keras_parameterized.run_all_keras_modes def testClusterSubclassModelAsSubmodel(self): """Verifies that attempting to cluster a model with submodel that is a subclass throws an exception.""" model_subclass = TestModel() @@ -683,7 +656,6 @@ def testClusterSubclassModelAsSubmodel(self): with self.assertRaisesRegex(ValueError, 'Subclassed models.*'): _ = cluster.cluster_weights(model, **self.params) - @keras_parameterized.run_all_keras_modes def testStripClusteringSequentialModel(self): """Verifies that stripping the clustering wrappers from a sequential model produces the expected config.""" model = keras.Sequential([ @@ -697,7 +669,6 @@ def testStripClusteringSequentialModel(self): self.assertEqual(self._count_clustered_layers(stripped_model), 0) self.assertEqual(model.get_config(), stripped_model.get_config()) - @keras_parameterized.run_all_keras_modes def testClusterStrippingFunctionalModel(self): """Verifies that stripping the clustering wrappers from a functional model produces the expected config.""" i1 = keras.Input(shape=(10,)) @@ -713,7 +684,6 @@ def testClusterStrippingFunctionalModel(self): self.assertEqual(self._count_clustered_layers(stripped_model), 0) self.assertEqual(model.get_config(), stripped_model.get_config()) - @keras_parameterized.run_all_keras_modes def testClusterWeightsStrippedWeights(self): """Verifies that stripping the clustering wrappers from a functional model preserves the clustered weights.""" i1 = keras.Input(shape=(10,)) @@ -728,7 +698,6 @@ def testClusterWeightsStrippedWeights(self): self.assertEqual(self._count_clustered_layers(stripped_model), 0) self.assertLen(stripped_model.get_weights(), cluster_weight_length) - @keras_parameterized.run_all_keras_modes def testStrippedKernel(self): """Verifies that stripping the clustering wrappers from a functional model restores the layers kernel and the layers weight array to the new clustered weight value.""" i1 = keras.Input(shape=(1, 1, 1)) @@ -746,7 +715,6 @@ def testStrippedKernel(self): self.assertIsNot(stripped_conv2d_layer.kernel, clustered_kernel) self.assertIn(stripped_conv2d_layer.kernel, stripped_conv2d_layer.weights) - @keras_parameterized.run_all_keras_modes def testStripSelectivelyClusteredFunctionalModel(self): """Verifies that invoking strip_clustering() on a selectively clustered functional model strips the clustering wrappers from the clustered layers.""" i1 = keras.Input(shape=(10,)) @@ -761,7 +729,6 @@ def testStripSelectivelyClusteredFunctionalModel(self): self.assertEqual(self._count_clustered_layers(stripped_model), 0) self.assertIsInstance(stripped_model.layers[2], layers.Dense) - @keras_parameterized.run_all_keras_modes def testStripSelectivelyClusteredSequentialModel(self): """Verifies that invoking strip_clustering() on a selectively clustered sequential model strips the clustering wrappers from the clustered layers.""" clustered_model = keras.Sequential([ @@ -775,7 +742,6 @@ def testStripSelectivelyClusteredSequentialModel(self): self.assertEqual(self._count_clustered_layers(stripped_model), 0) self.assertIsInstance(stripped_model.layers[0], layers.Dense) - @keras_parameterized.run_all_keras_modes def testStripClusteringAndSetOriginalWeightsBack(self): """Verifies that we can set_weights onto the stripped model.""" model = keras.Sequential([ diff --git a/tensorflow_model_optimization/python/core/quantization/keras/collaborative_optimizations/cluster_preserve/cluster_preserve_integration_test.py b/tensorflow_model_optimization/python/core/quantization/keras/collaborative_optimizations/cluster_preserve/cluster_preserve_integration_test.py index 78b43464d..492f752b6 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/collaborative_optimizations/cluster_preserve/cluster_preserve_integration_test.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/collaborative_optimizations/cluster_preserve/cluster_preserve_integration_test.py @@ -17,7 +17,6 @@ import numpy as np import tensorflow as tf -from tensorflow.python.keras import keras_parameterized from tensorflow_model_optimization.python.core.clustering.keras import cluster from tensorflow_model_optimization.python.core.clustering.keras import cluster_config from tensorflow_model_optimization.python.core.clustering.keras.experimental import cluster as experimental_cluster @@ -30,7 +29,6 @@ layers = tf.keras.layers -@keras_parameterized.run_all_keras_modes class ClusterPreserveIntegrationTest(tf.test.TestCase, parameterized.TestCase): def setUp(self): diff --git a/tensorflow_model_optimization/python/core/quantization/keras/collaborative_optimizations/cluster_preserve/cluster_preserve_quantize_registry_test.py b/tensorflow_model_optimization/python/core/quantization/keras/collaborative_optimizations/cluster_preserve/cluster_preserve_quantize_registry_test.py index 3fee5077c..99b601389 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/collaborative_optimizations/cluster_preserve/cluster_preserve_quantize_registry_test.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/collaborative_optimizations/cluster_preserve/cluster_preserve_quantize_registry_test.py @@ -18,7 +18,6 @@ import tensorflow as tf -from tensorflow.python.keras import keras_parameterized from tensorflow_model_optimization.python.core.clustering.keras import clustering_registry from tensorflow_model_optimization.python.core.quantization.keras import quantize_config @@ -29,7 +28,6 @@ layers = tf.keras.layers -@keras_parameterized.run_all_keras_modes class ClusterPreserveQuantizeRegistryTest(tf.test.TestCase, parameterized.TestCase): diff --git a/tensorflow_model_optimization/python/core/quantization/keras/collaborative_optimizations/prune_preserve/prune_preserve_quantize_registry_test.py b/tensorflow_model_optimization/python/core/quantization/keras/collaborative_optimizations/prune_preserve/prune_preserve_quantize_registry_test.py index 9124fa510..65bc5dd5d 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/collaborative_optimizations/prune_preserve/prune_preserve_quantize_registry_test.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/collaborative_optimizations/prune_preserve/prune_preserve_quantize_registry_test.py @@ -17,7 +17,6 @@ import tensorflow as tf -from tensorflow.python.keras import keras_parameterized from tensorflow_model_optimization.python.core.quantization.keras import quantize_config from tensorflow_model_optimization.python.core.quantization.keras.collaborative_optimizations.prune_preserve import ( prune_preserve_quantize_registry,) @@ -28,7 +27,6 @@ layers = tf.keras.layers -@keras_parameterized.run_all_keras_modes class PrunePreserveQuantizeRegistryTest(tf.test.TestCase, parameterized.TestCase): diff --git a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_registry_test.py b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_registry_test.py index 3efec2d6a..e603c2d7f 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_registry_test.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_registry_test.py @@ -24,7 +24,6 @@ import numpy as np import tensorflow as tf -from tensorflow.python.keras import keras_parameterized from tensorflow_model_optimization.python.core.quantization.keras import quantizers from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_registry @@ -73,7 +72,6 @@ def _assert_kernel_equality(self, a, b): self.assertAllEqual(a.numpy(), b.numpy()) -@keras_parameterized.run_all_keras_modes class QuantizeRegistryTest( tf.test.TestCase, parameterized.TestCase, _TestHelper): diff --git a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantizers_test.py b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantizers_test.py index 90f29339a..7da1d114a 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantizers_test.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantizers_test.py @@ -22,7 +22,6 @@ import tensorflow as tf -from tensorflow.python.keras import keras_parameterized from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantizers Default8BitConvWeightsQuantizer = default_8bit_quantizers.Default8BitConvWeightsQuantizer @@ -30,7 +29,6 @@ keras = tf.keras -@keras_parameterized.run_all_keras_modes class Default8BitConvWeightsQuantizerTest(tf.test.TestCase, parameterized.TestCase): diff --git a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms.py b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms.py index 0e4fd43e9..35363ec89 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms.py @@ -17,11 +17,10 @@ import collections import inspect +from keras import backend import numpy as np import tensorflow as tf -from tensorflow.python.keras import backend - from tensorflow_model_optimization.python.core.quantization.keras import quantize_aware_activation from tensorflow_model_optimization.python.core.quantization.keras import quantize_layer from tensorflow_model_optimization.python.core.quantization.keras import quantizers diff --git a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/quantize_numerical_test.py b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/quantize_numerical_test.py index e25be7d1a..01be17c6d 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/quantize_numerical_test.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/quantize_numerical_test.py @@ -22,12 +22,10 @@ import numpy as np import tensorflow as tf -from tensorflow.python.keras import keras_parameterized from tensorflow_model_optimization.python.core.quantization.keras import quantize from tensorflow_model_optimization.python.core.quantization.keras import utils -@keras_parameterized.run_all_keras_modes(always_skip_v1=True) class QuantizeNumericalTest(tf.test.TestCase, parameterized.TestCase): def _batch(self, dims, batch_size): diff --git a/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_quantize_registry_test.py b/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_quantize_registry_test.py index 75b9a31e7..cc59fef52 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_quantize_registry_test.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_quantize_registry_test.py @@ -24,7 +24,6 @@ import numpy as np import tensorflow as tf -from tensorflow.python.keras import keras_parameterized from tensorflow_model_optimization.python.core.quantization.keras import quantizers from tensorflow_model_optimization.python.core.quantization.keras.experimental.default_n_bit import default_n_bit_quantize_registry as n_bit_registry @@ -73,7 +72,6 @@ def _assert_kernel_equality(self, a, b): self.assertAllEqual(a.numpy(), b.numpy()) -@keras_parameterized.run_all_keras_modes class QuantizeRegistryTest( tf.test.TestCase, parameterized.TestCase, _TestHelper): diff --git a/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_quantizers_test.py b/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_quantizers_test.py index 6a6022342..dcbcc90bf 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_quantizers_test.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_quantizers_test.py @@ -22,7 +22,6 @@ import tensorflow as tf -from tensorflow.python.keras import keras_parameterized from tensorflow_model_optimization.python.core.quantization.keras.experimental.default_n_bit import default_n_bit_quantizers DefaultNBitConvWeightsQuantizer = default_n_bit_quantizers.DefaultNBitConvWeightsQuantizer @@ -30,7 +29,6 @@ keras = tf.keras -@keras_parameterized.run_all_keras_modes class DefaultNBitConvWeightsQuantizerTest(tf.test.TestCase, parameterized.TestCase): diff --git a/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_transforms.py b/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_transforms.py index 8c3c91afc..c2eeef955 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_transforms.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_transforms.py @@ -17,10 +17,10 @@ import collections import inspect +from keras import backend import numpy as np import tensorflow as tf -from tensorflow.python.keras import backend from tensorflow_model_optimization.python.core.quantization.keras import quantize_aware_activation from tensorflow_model_optimization.python.core.quantization.keras import quantize_layer from tensorflow_model_optimization.python.core.quantization.keras import quantizers diff --git a/tensorflow_model_optimization/python/core/quantization/keras/quant_ops_test.py b/tensorflow_model_optimization/python/core/quantization/keras/quant_ops_test.py index 3d4d7b861..4a28ceeb5 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/quant_ops_test.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/quant_ops_test.py @@ -22,14 +22,12 @@ import tensorflow as tf # TODO(b/139939526): move to public API. -from tensorflow.python.keras import keras_parameterized from tensorflow_model_optimization.python.core.keras import compat from tensorflow_model_optimization.python.core.quantization.keras import quant_ops _SYMMETRIC_RANGE_RATIO = 0.9921875 # 127 / 128 -@keras_parameterized.run_all_keras_modes class QuantOpsTest(tf.test.TestCase, parameterized.TestCase): def testAllValuesQuantiize_TrainingAssign(self): diff --git a/tensorflow_model_optimization/python/core/quantization/keras/quantize_aware_activation_test.py b/tensorflow_model_optimization/python/core/quantization/keras/quantize_aware_activation_test.py index eb1ae3168..d289f52f3 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/quantize_aware_activation_test.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/quantize_aware_activation_test.py @@ -23,7 +23,6 @@ import numpy as np import tensorflow as tf -from tensorflow.python.keras import keras_parameterized from tensorflow_model_optimization.python.core.quantization.keras import quantize_aware_activation from tensorflow_model_optimization.python.core.quantization.keras import quantizers @@ -37,7 +36,6 @@ MovingAverageQuantizer = quantizers.MovingAverageQuantizer -@keras_parameterized.run_all_keras_modes class QuantizeAwareQuantizationTest(tf.test.TestCase, parameterized.TestCase): def setUp(self): diff --git a/tensorflow_model_optimization/python/core/quantization/keras/quantize_functional_test.py b/tensorflow_model_optimization/python/core/quantization/keras/quantize_functional_test.py index de1515478..d1f809bd6 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/quantize_functional_test.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/quantize_functional_test.py @@ -25,7 +25,6 @@ import tensorflow as tf # TODO(b/139939526): move to public API. -from tensorflow.python.keras import keras_parameterized from tensorflow_model_optimization.python.core.keras import compat from tensorflow_model_optimization.python.core.keras.testing import test_utils_mnist from tensorflow_model_optimization.python.core.quantization.keras import quantize @@ -34,7 +33,6 @@ layers = tf.keras.layers -@keras_parameterized.run_all_keras_modes(always_skip_v1=True) class QuantizeFunctionalTest(tf.test.TestCase, parameterized.TestCase): # TODO(pulkitb): Parameterize test and include functional mnist, and diff --git a/tensorflow_model_optimization/python/core/quantization/keras/quantize_integration_test.py b/tensorflow_model_optimization/python/core/quantization/keras/quantize_integration_test.py index f604ddfce..1e7abae36 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/quantize_integration_test.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/quantize_integration_test.py @@ -26,7 +26,6 @@ import tensorflow as tf # TODO(b/139939526): move to public API. -from tensorflow.python.keras import keras_parameterized from tensorflow_model_optimization.python.core.keras import compat from tensorflow_model_optimization.python.core.keras import test_utils @@ -44,7 +43,6 @@ # TODO(tfmot): enable for v1. Currently fails because the decorator # on graph mode wraps everything in a graph, which is not compatible # with the TFLite converter's call to clear_session(). -@keras_parameterized.run_all_keras_modes(always_skip_v1=True) class QuantizeIntegrationTest(tf.test.TestCase, parameterized.TestCase): def _batch(self, dims, batch_size): diff --git a/tensorflow_model_optimization/python/core/quantization/keras/quantize_models_test.py b/tensorflow_model_optimization/python/core/quantization/keras/quantize_models_test.py index 99e5faf7a..c51f07d25 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/quantize_models_test.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/quantize_models_test.py @@ -26,12 +26,10 @@ import numpy as np import tensorflow as tf -from tensorflow.python.keras import keras_parameterized from tensorflow_model_optimization.python.core.quantization.keras import quantize from tensorflow_model_optimization.python.core.quantization.keras import utils -@keras_parameterized.run_all_keras_modes(always_skip_v1=True) class QuantizeModelsTest(tf.test.TestCase, parameterized.TestCase): # Derived using diff --git a/tensorflow_model_optimization/python/core/quantization/keras/quantizers_test.py b/tensorflow_model_optimization/python/core/quantization/keras/quantizers_test.py index 7b3dcc3ed..7df0567f2 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/quantizers_test.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/quantizers_test.py @@ -23,7 +23,6 @@ import numpy as np import tensorflow as tf -from tensorflow.python.keras import keras_parameterized from tensorflow_model_optimization.python.core.keras import compat from tensorflow_model_optimization.python.core.quantization.keras import quantizers @@ -31,7 +30,6 @@ serialize_keras_object = tf.keras.utils.serialize_keras_object -@keras_parameterized.run_all_keras_modes @parameterized.parameters( quantizers.LastValueQuantizer, quantizers.MovingAverageQuantizer, diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/prune_integration_test.py b/tensorflow_model_optimization/python/core/sparsity/keras/prune_integration_test.py index c7b443a77..ae1c1b117 100644 --- a/tensorflow_model_optimization/python/core/sparsity/keras/prune_integration_test.py +++ b/tensorflow_model_optimization/python/core/sparsity/keras/prune_integration_test.py @@ -21,7 +21,6 @@ import tensorflow as tf # TODO(b/139939526): move to public API. -from tensorflow.python.keras import keras_parameterized from tensorflow_model_optimization.python.core.keras import test_utils as keras_test_utils from tensorflow_model_optimization.python.core.sparsity.keras import prune from tensorflow_model_optimization.python.core.sparsity.keras import prune_registry @@ -37,7 +36,6 @@ ModelCompare = keras_test_utils.ModelCompare -@keras_parameterized.run_all_keras_modes class PruneIntegrationTest(tf.test.TestCase, parameterized.TestCase, ModelCompare): @@ -691,7 +689,6 @@ def testPruneWithPolynomialDecayPastEndStep_PreservesSparsity( self._check_strip_pruning_matches_original(model, 0.6) -@keras_parameterized.run_all_keras_modes(always_skip_v1=True) class PruneIntegrationCustomTrainingLoopTest(tf.test.TestCase, parameterized.TestCase): diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/prune_registry.py b/tensorflow_model_optimization/python/core/sparsity/keras/prune_registry.py index 0238b78e2..9b11eeaf8 100644 --- a/tensorflow_model_optimization/python/core/sparsity/keras/prune_registry.py +++ b/tensorflow_model_optimization/python/core/sparsity/keras/prune_registry.py @@ -14,12 +14,13 @@ # ============================================================================== """Registry responsible for built-in keras classes.""" +import keras import tensorflow as tf -# TODO(b/139939526): move to public API. -from tensorflow.python.keras.engine.base_layer import TensorFlowOpLayer from tensorflow_model_optimization.python.core.sparsity.keras import prunable_layer +# TODO(b/139939526): move to public API. + layers = tf.keras.layers layers_compat_v1 = tf.compat.v1.keras.layers @@ -100,7 +101,7 @@ class PruneRegistry(object): ], layers.experimental.SyncBatchNormalization: [], layers.experimental.preprocessing.Rescaling.__class__: [], - TensorFlowOpLayer: [], + keras.layers.Lambda: [], layers_compat_v1.BatchNormalization: [], } diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/prune_test.py b/tensorflow_model_optimization/python/core/sparsity/keras/prune_test.py index fe2fbb2f0..9b6062ab3 100644 --- a/tensorflow_model_optimization/python/core/sparsity/keras/prune_test.py +++ b/tensorflow_model_optimization/python/core/sparsity/keras/prune_test.py @@ -22,7 +22,6 @@ import tensorflow as tf # TODO(b/139939526): move to public API. -from tensorflow.python.keras import keras_parameterized from tensorflow_model_optimization.python.core.keras import test_utils as keras_test_utils from tensorflow_model_optimization.python.core.sparsity.keras import prunable_layer from tensorflow_model_optimization.python.core.sparsity.keras import prune @@ -202,7 +201,6 @@ def testPruneValidLayersListSuccessful(self): for layer, pruned_layer in zip(model_layers, pruned_layers): self._validate_pruned_layer(layer, pruned_layer) - @keras_parameterized.run_all_keras_modes def testPruneInferenceWorks_PruningStepCallbackNotRequired(self): model = prune.prune_low_magnitude( keras.Sequential([ diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_callbacks_test.py b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_callbacks_test.py index 8966c0144..588902995 100644 --- a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_callbacks_test.py +++ b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_callbacks_test.py @@ -22,7 +22,6 @@ import tensorflow as tf # TODO(b/139939526): move to public API. -from tensorflow.python.keras import keras_parameterized from tensorflow_model_optimization.python.core.keras import test_utils as keras_test_utils from tensorflow_model_optimization.python.core.sparsity.keras import prune from tensorflow_model_optimization.python.core.sparsity.keras import pruning_callbacks @@ -55,7 +54,6 @@ def _pruned_model_setup(self, custom_training_loop=False): pruned_model.compile(loss=loss, optimizer=optimizer, metrics=['accuracy']) return pruned_model, x_train, y_train - @keras_parameterized.run_all_keras_modes def testUpdatePruningStepsAndLogsSummaries(self): log_dir = tempfile.mkdtemp() pruned_model, x_train, y_train = self._pruned_model_setup() @@ -77,7 +75,6 @@ def testUpdatePruningStepsAndLogsSummaries(self): self._assertLogsExist(log_dir) # This style of custom training loop isn't available in graph mode. - @keras_parameterized.run_all_keras_modes(always_skip_v1=True) def testUpdatePruningStepsAndLogsSummaries_CustomTrainingLoop(self): log_dir = tempfile.mkdtemp() pruned_model, loss, optimizer, x_train, y_train = self._pruned_model_setup( @@ -116,7 +113,6 @@ def testUpdatePruningStepsAndLogsSummaries_CustomTrainingLoop(self): 3, tf.keras.backend.get_value(pruned_model.layers[1].pruning_step)) self._assertLogsExist(log_dir) - @keras_parameterized.run_all_keras_modes def testUpdatePruningStepsAndLogsSummaries_RunInference(self): pruned_model, _, _, x_train, _ = self._pruned_model_setup( custom_training_loop=True) @@ -128,7 +124,6 @@ def testUpdatePruningStepsAndLogsSummaries_RunInference(self): self.assertEqual( -1, tf.keras.backend.get_value(pruned_model.layers[1].pruning_step)) - @keras_parameterized.run_all_keras_modes def testPruneTrainingRaisesError_PruningStepCallbackMissing(self): pruned_model, x_train, y_train = self._pruned_model_setup() @@ -137,7 +132,6 @@ def testPruneTrainingRaisesError_PruningStepCallbackMissing(self): pruned_model.fit(x_train, y_train) # This style of custom training loop isn't available in graph mode. - @keras_parameterized.run_all_keras_modes(always_skip_v1=True) def testPruneTrainingLoopRaisesError_PruningStepCallbackMissing_CustomTrainingLoop( self): pruned_model, _, _, x_train, _ = self._pruned_model_setup( @@ -149,7 +143,6 @@ def testPruneTrainingLoopRaisesError_PruningStepCallbackMissing_CustomTrainingLo with tf.GradientTape(): pruned_model(inp, training=True) - @keras_parameterized.run_all_keras_modes def testPruningSummariesRaisesError_LogDirNotNonEmptyString(self): with self.assertRaises(ValueError): pruning_callbacks.PruningSummaries(log_dir='') diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl_test.py b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl_test.py index 1d0999369..0d25f9b0a 100644 --- a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl_test.py +++ b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl_test.py @@ -25,7 +25,6 @@ import tensorflow as tf # TODO(b/139939526): move to public API. -from tensorflow.python.keras import keras_parameterized from tensorflow_model_optimization.python.core.keras import compat from tensorflow_model_optimization.python.core.sparsity.keras import pruning_impl from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule @@ -43,7 +42,6 @@ def assign_add(ref, value): return ref.assign_add(value) -@keras_parameterized.run_all_keras_modes class PruningTest(test.TestCase, parameterized.TestCase): def setUp(self): diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_schedule_test.py b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_schedule_test.py index 46c849379..916d080ab 100644 --- a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_schedule_test.py +++ b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_schedule_test.py @@ -18,7 +18,6 @@ import tensorflow as tf # TODO(b/139939526): move to public API. -from tensorflow.python.keras import keras_parameterized from tensorflow_model_optimization.python.core.keras import compat from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule @@ -155,7 +154,6 @@ def testSparsityValueIsValid(self, schedule_type): # Tests to ensure begin_step, end_step, frequency are used correctly. - @keras_parameterized.run_all_keras_modes @parameterized.named_parameters( { 'testcase_name': 'ConstantSparsity', @@ -189,7 +187,6 @@ def testPrunesOnlyInBeginEndStepRange(self, schedule_type): self.assertFalse(self.evaluate(sparsity(step_201))[0]) self.assertFalse(self.evaluate(sparsity(step_210))[0]) - @keras_parameterized.run_all_keras_modes @parameterized.named_parameters( { 'testcase_name': 'ConstantSparsity', @@ -216,7 +213,6 @@ def testOnlyPrunesAtValidFrequencySteps(self, schedule_type): class ConstantSparsityTest(tf.test.TestCase, parameterized.TestCase): - @keras_parameterized.run_all_keras_modes def testPrunesForeverIfEndStepIsNegativeOne(self): sparsity = pruning_schedule.ConstantSparsity(0.5, 0, -1, 10) @@ -230,7 +226,6 @@ def testPrunesForeverIfEndStepIsNegativeOne(self): self.assertAllClose(0.5, self.evaluate(sparsity(step_10000))[1]) self.assertAllClose(0.5, self.evaluate(sparsity(step_100000000))[1]) - @keras_parameterized.run_all_keras_modes def testPrunesWithConstantSparsity(self): sparsity = pruning_schedule.ConstantSparsity(0.5, 100, 200, 10) @@ -263,7 +258,6 @@ def testRaisesErrorIfEndStepIsNegative(self): with self.assertRaises(ValueError): pruning_schedule.PolynomialDecay(0.4, 0.8, 10, -1) - @keras_parameterized.run_all_keras_modes def testPolynomialDecay_PrunesCorrectly(self): sparsity = pruning_schedule.PolynomialDecay(0.2, 0.8, 100, 110, 3, 2) diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py index 65d5a69a5..13927b471 100644 --- a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py +++ b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py @@ -20,12 +20,13 @@ from __future__ import print_function import inspect + # import g3 + +from keras.utils import generic_utils import numpy as np import tensorflow as tf -# TODO(b/139939526): update to use public API. -from tensorflow.python.keras.utils import generic_utils from tensorflow_model_optimization.python.core.keras import compat as tf_compat from tensorflow_model_optimization.python.core.keras import metrics from tensorflow_model_optimization.python.core.keras import utils @@ -35,6 +36,8 @@ from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule as pruning_sched from tensorflow_model_optimization.python.core.sparsity.keras.pruning_utils import convert_to_tuple_of_two_int +# TODO(b/139939526): update to use public API. + keras = tf.keras K = keras.backend Wrapper = keras.layers.Wrapper