diff --git a/keras/api/_tf_keras/keras/layers/__init__.py b/keras/api/_tf_keras/keras/layers/__init__.py index db8c361610c..2264fb694e9 100644 --- a/keras/api/_tf_keras/keras/layers/__init__.py +++ b/keras/api/_tf_keras/keras/layers/__init__.py @@ -176,6 +176,9 @@ from keras.src.layers.preprocessing.image_preprocessing.random_hue import ( RandomHue, ) +from keras.src.layers.preprocessing.image_preprocessing.random_posterization import ( + RandomPosterization, +) from keras.src.layers.preprocessing.image_preprocessing.random_rotation import ( RandomRotation, ) diff --git a/keras/api/layers/__init__.py b/keras/api/layers/__init__.py index e38f87bdef3..16ed66c84b2 100644 --- a/keras/api/layers/__init__.py +++ b/keras/api/layers/__init__.py @@ -176,6 +176,9 @@ from keras.src.layers.preprocessing.image_preprocessing.random_hue import ( RandomHue, ) +from keras.src.layers.preprocessing.image_preprocessing.random_posterization import ( + RandomPosterization, +) from keras.src.layers.preprocessing.image_preprocessing.random_rotation import ( RandomRotation, ) diff --git a/keras/src/layers/__init__.py b/keras/src/layers/__init__.py index 881a6620f68..eb342d08079 100644 --- a/keras/src/layers/__init__.py +++ b/keras/src/layers/__init__.py @@ -120,6 +120,9 @@ from keras.src.layers.preprocessing.image_preprocessing.random_hue import ( RandomHue, ) +from keras.src.layers.preprocessing.image_preprocessing.random_posterization import ( + RandomPosterization, +) from keras.src.layers.preprocessing.image_preprocessing.random_rotation import ( RandomRotation, ) diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_posterization.py b/keras/src/layers/preprocessing/image_preprocessing/random_posterization.py new file mode 100644 index 00000000000..55e7536724f --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_posterization.py @@ -0,0 +1,151 @@ +from keras.src.api_export import keras_export +from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 + BaseImagePreprocessingLayer, +) + + +@keras_export("keras.layers.RandomPosterization") +class RandomPosterization(BaseImagePreprocessingLayer): + """Reduces the number of bits for each color channel. + + References: + - [AutoAugment: Learning Augmentation Policies from Data](https://arxiv.org/abs/1805.09501) + - [RandAugment: Practical automated data augmentation with a reduced search space](https://arxiv.org/abs/1909.13719) + + Args: + value_range: a tuple or a list of two elements. The first value + represents the lower bound for values in passed images, the second + represents the upper bound. Images passed to the layer should have + values within `value_range`. Defaults to `(0, 255)`. + factor: integer, the number of bits to keep for each channel. Must be a + value between 1-8. + """ + + _USE_BASE_FACTOR = False + _FACTOR_BOUNDS = (1, 8) + _MAX_FACTOR = 8 + _VALUE_RANGE_VALIDATION_ERROR = ( + "The `value_range` argument should be a list of two numbers. " + ) + + def __init__( + self, + factor, + value_range=(0, 255), + data_format=None, + seed=None, + **kwargs, + ): + super().__init__(data_format=data_format, **kwargs) + self._set_factor(factor) + self._set_value_range(value_range) + self.seed = seed + self.generator = self.backend.random.SeedGenerator(seed) + + def _set_value_range(self, value_range): + if not isinstance(value_range, (tuple, list)): + raise ValueError( + self._VALUE_RANGE_VALIDATION_ERROR + + f"Received: value_range={value_range}" + ) + if len(value_range) != 2: + raise ValueError( + self._VALUE_RANGE_VALIDATION_ERROR + + f"Received: value_range={value_range}" + ) + self.value_range = sorted(value_range) + + def get_random_transformation(self, data, training=True, seed=None): + if isinstance(data, dict): + images = data["images"] + else: + images = data + images_shape = self.backend.shape(images) + rank = len(images_shape) + if rank == 3: + batch_size = 1 + elif rank == 4: + batch_size = images_shape[0] + else: + raise ValueError( + "Expected the input image to be rank 3 or 4. Received: " + f"inputs.shape={images_shape}" + ) + + if seed is None: + seed = self._get_seed_generator(self.backend._backend) + + if self.factor[0] != self.factor[1]: + factor = self.backend.random.randint( + (batch_size,), + minval=self.factor[0], + maxval=self.factor[1], + seed=seed, + dtype="uint8", + ) + else: + factor = ( + self.backend.numpy.ones((batch_size,), dtype="uint8") + * self.factor[0] + ) + + shift_factor = self._MAX_FACTOR - factor + return {"shift_factor": shift_factor} + + def transform_images(self, images, transformation=None, training=True): + if training: + shift_factor = transformation["shift_factor"] + + shift_factor = self.backend.numpy.reshape( + shift_factor, self.backend.shape(shift_factor) + (1, 1, 1) + ) + + images = self._transform_value_range( + images, + original_range=self.value_range, + target_range=(0, 255), + dtype=self.compute_dtype, + ) + + images = self.backend.cast(images, "uint8") + images = self.backend.numpy.bitwise_left_shift( + self.backend.numpy.bitwise_right_shift(images, shift_factor), + shift_factor, + ) + images = self.backend.cast(images, self.compute_dtype) + + images = self._transform_value_range( + images, + original_range=(0, 255), + target_range=self.value_range, + dtype=self.compute_dtype, + ) + + return images + + def transform_labels(self, labels, transformation, training=True): + return labels + + def transform_segmentation_masks( + self, segmentation_masks, transformation, training=True + ): + return segmentation_masks + + def transform_bounding_boxes( + self, bounding_boxes, transformation, training=True + ): + return bounding_boxes + + def get_config(self): + config = super().get_config() + config.update( + { + "factor": self.factor, + "value_range": self.value_range, + "seed": self.seed, + } + ) + return config + + def compute_output_shape(self, input_shape): + return input_shape diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_posterization_test.py b/keras/src/layers/preprocessing/image_preprocessing/random_posterization_test.py new file mode 100644 index 00000000000..347f82a3a96 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_posterization_test.py @@ -0,0 +1,85 @@ +import numpy as np +import pytest +from tensorflow import data as tf_data + +import keras +from keras.src import backend +from keras.src import layers +from keras.src import testing + + +class RandomPosterizationTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_layer(self): + self.run_layer_test( + layers.RandomPosterization, + init_kwargs={ + "factor": 1, + "value_range": (20, 200), + "seed": 1, + }, + input_shape=(8, 3, 4, 3), + supports_masking=False, + expected_output_shape=(8, 3, 4, 3), + ) + + def test_random_posterization_inference(self): + seed = 3481 + layer = layers.RandomPosterization(1, [0, 255]) + np.random.seed(seed) + inputs = np.random.randint(0, 255, size=(224, 224, 3)) + output = layer(inputs, training=False) + self.assertAllClose(inputs, output) + + def test_random_posterization_basic(self): + seed = 3481 + layer = layers.RandomPosterization( + 1, [0, 255], data_format="channels_last", seed=seed + ) + np.random.seed(seed) + inputs = np.asarray( + [[[128.0, 235.0, 87.0], [12.0, 1.0, 23.0], [24.0, 18.0, 121.0]]] + ) + output = layer(inputs) + expected_output = np.asarray( + [[[128.0, 128.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]] + ) + self.assertAllClose(expected_output, output) + + def test_random_posterization_value_range_0_to_1(self): + image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1) + + layer = layers.RandomPosterization(1, [0, 1.0]) + adjusted_image = layer(image) + + self.assertTrue(keras.ops.numpy.all(adjusted_image >= 0)) + self.assertTrue(keras.ops.numpy.all(adjusted_image <= 1)) + + def test_random_posterization_value_range_0_to_255(self): + image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=255) + + layer = layers.RandomPosterization(1, [0, 255]) + adjusted_image = layer(image) + + self.assertTrue(keras.ops.numpy.all(adjusted_image >= 0)) + self.assertTrue(keras.ops.numpy.all(adjusted_image <= 255)) + + def test_random_posterization_randomness(self): + image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1) + + layer = layers.RandomPosterization(1, [0, 255]) + adjusted_images = layer(image) + + self.assertNotAllClose(adjusted_images, image) + + def test_tf_data_compatibility(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + input_data = np.random.random((2, 8, 8, 3)) + else: + input_data = np.random.random((2, 3, 8, 8)) + layer = layers.RandomPosterization(1, [0, 255]) + + ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) + for output in ds.take(1): + output.numpy()