diff --git a/keras/src/layers/__init__.py b/keras/src/layers/__init__.py index f9719bfe442..59f241cbaf2 100644 --- a/keras/src/layers/__init__.py +++ b/keras/src/layers/__init__.py @@ -96,6 +96,9 @@ MaxNumBoundingBoxes, ) from keras.src.layers.preprocessing.image_preprocessing.mix_up import MixUp +from keras.src.layers.preprocessing.image_preprocessing.rand_augment import ( + RandAugment, +) from keras.src.layers.preprocessing.image_preprocessing.random_brightness import ( RandomBrightness, ) diff --git a/keras/src/layers/preprocessing/image_preprocessing/rand_augment.py b/keras/src/layers/preprocessing/image_preprocessing/rand_augment.py index 02c65be7124..fdcbfa793d8 100644 --- a/keras/src/layers/preprocessing/image_preprocessing/rand_augment.py +++ b/keras/src/layers/preprocessing/image_preprocessing/rand_augment.py @@ -1,3 +1,5 @@ +import random + import keras.src.layers as layers from keras.src.api_export import keras_export from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 @@ -15,107 +17,123 @@ class RandAugment(BaseImagePreprocessingLayer): policy implemented by this layer has been benchmarked extensively and is effective on a wide variety of datasets. - The policy operates as follows: - - For each augmentation in the range `[0, augmentations_per_image]`, - the policy selects a random operation from a list of operations. - It then samples a random number and if that number is less than - `rate` applies it to the given image. - References: - [RandAugment](https://arxiv.org/abs/1909.13719) + Args: + value_range: + The range of values the input image can take. Default is (0, 255). + Typically, this would be (0, 1) for normalized images + or (0, 255) for raw images. + num_ops: + The number of augmentation operations to apply sequentially + to each image. Default is 2. + magnitude: + The strength of the augmentation as a normalized value + between 0 and 1. Default is 0.5. + interpolation: + The interpolation method to use for resizing operations. + Options include "nearest", "bilinear". Default is "bilinear". + seed: + Integer. Used to create a random seed. + """ - _AUGMENT_LAYERS = ["Identity", "random_shear", "random_translation", "random_rotation", - "random_brightness", "random_color_degeneration", "random_contrast", - "random_sharpness", "random_posterization", "solarization", "auto_contrast", "equalization"] + _AUGMENT_LAYERS = [ + "identity", + "random_shear", + "random_translation", + "random_rotation", + "random_brightness", + "random_color_degeneration", + "random_contrast", + "random_sharpness", + "random_posterization", + "solarization", + "auto_contrast", + "equalization", + ] def __init__( - self, - value_range=(0, 255), - num_ops=2, - magnitude=9, - num_magnitude_bins=31, - interpolation="nearest", - seed=None, - data_format=None, - **kwargs, + self, + value_range=(0, 255), + num_ops=2, + magnitude=0.5, + interpolation="bilinear", + seed=None, + data_format=None, + **kwargs, ): super().__init__(data_format=data_format, **kwargs) self.value_range = value_range self.num_ops = num_ops self.magnitude = magnitude - self.num_magnitude_bins = num_magnitude_bins self.interpolation = interpolation self.seed = seed self.generator = SeedGenerator(seed) - augmentation_space = self._augmentation_space(self.num_magnitude_bins) - - random_indices = self.backend.random.randint([11], 0, self.num_magnitude_bins, seed=self.seed) self.random_shear = layers.RandomShear( - x_factor=float(augmentation_space["ShearX"][int(random_indices[0])]), - y_factor=float(augmentation_space["ShearY"][int(random_indices[1])]), + x_factor=self.magnitude, + y_factor=self.magnitude, + interpolation=interpolation, seed=self.seed, data_format=data_format, ) self.random_translation = layers.RandomTranslation( - height_factor=float(augmentation_space["TranslateX"][int(random_indices[2])]), - width_factor=float(augmentation_space["TranslateY"][int(random_indices[3])]), + height_factor=self.magnitude, + width_factor=self.magnitude, + interpolation=interpolation, seed=self.seed, data_format=data_format, ) self.random_rotation = layers.RandomRotation( - factor=float(augmentation_space["Rotate"][int(random_indices[4])]), + factor=self.magnitude, + interpolation=interpolation, seed=self.seed, data_format=data_format, ) self.random_brightness = layers.RandomBrightness( - factor=float(augmentation_space["Brightness"][int(random_indices[5])]), + factor=self.magnitude, value_range=self.value_range, seed=self.seed, data_format=data_format, ) self.random_color_degeneration = layers.RandomColorDegeneration( - factor=float(augmentation_space["Color"][int(random_indices[6])]), + factor=self.magnitude, value_range=self.value_range, seed=self.seed, data_format=data_format, ) self.random_contrast = layers.RandomContrast( - factor=float(augmentation_space["Contrast"][int(random_indices[7])]), + factor=self.magnitude, value_range=self.value_range, seed=self.seed, data_format=data_format, ) self.random_sharpness = layers.RandomSharpness( - factor=float(augmentation_space["Sharpness"][int(random_indices[8])]), + factor=self.magnitude, value_range=self.value_range, seed=self.seed, data_format=data_format, ) self.solarization = layers.Solarization( - addition_factor=int(augmentation_space["Solarize"][int(random_indices[9])]), - threshold_factor=int(augmentation_space["Solarize"][int(random_indices[10])]), + addition_factor=self.magnitude, + threshold_factor=self.magnitude, value_range=self.value_range, seed=self.seed, data_format=data_format, ) - random_indices = self.backend.random.uniform((1,), - minval=0, maxval=len(augmentation_space["Posterize"]), - seed=self.seed) self.random_posterization = layers.RandomPosterization( - factor=int(augmentation_space["Posterize"][int(random_indices[0])]), + factor=max(1, int(8 * self.magnitude)), value_range=self.value_range, seed=self.seed, data_format=data_format, @@ -131,21 +149,6 @@ def __init__( data_format=data_format, ) - def _augmentation_space(self, num_bins): - return { - "ShearX": self.backend.numpy.linspace(0.0, 1.0, num_bins), - "ShearY": self.backend.numpy.linspace(0.0, 1.0, num_bins), - "TranslateX": self.backend.numpy.linspace(-1, 1, num_bins), - "TranslateY": self.backend.numpy.linspace(-1, 1, num_bins), - "Rotate": self.backend.numpy.linspace(-1, 1, num_bins), - "Brightness": self.backend.numpy.linspace(-1, 1, num_bins), - "Color": self.backend.numpy.linspace(0.0, 1.0, num_bins), - "Contrast": self.backend.numpy.linspace(0.0, 1.0, num_bins), - "Sharpness": self.backend.numpy.linspace(0.0, 1.0, num_bins), - "Solarize": self.backend.numpy.linspace(0.0, 1.0, num_bins), - "Posterize": 8. - (self.backend.numpy.arange(num_bins, dtype='float32') / ((num_bins - 1.) / 4)), - } - def get_random_transformation(self, data, training=True, seed=None): if not training: return None @@ -154,24 +157,24 @@ def get_random_transformation(self, data, training=True, seed=None): self.backend.set_backend("tensorflow") for layer_name in self._AUGMENT_LAYERS: - if layer_name == "Identity": + if layer_name == "identity": continue augmentation_layer = getattr(self, layer_name) augmentation_layer.backend.set_backend("tensorflow") transformation = {} - random_indices = self.backend.random.shuffle( - self.backend.numpy.arange(len(self._AUGMENT_LAYERS)), - seed=self.seed)[:self.num_ops] - for layer_idx in random_indices: - layer_name = self._AUGMENT_LAYERS[layer_idx] - if layer_name == "Identity": + random.shuffle(self._AUGMENT_LAYERS) + for layer_name in self._AUGMENT_LAYERS[: self.num_ops]: + if layer_name == "identity": continue augmentation_layer = getattr(self, layer_name) - transformation[layer_name] = augmentation_layer.get_random_transformation(data, - training=training, - seed=self._get_seed_generator( - self.backend._backend)) + transformation[layer_name] = ( + augmentation_layer.get_random_transformation( + data, + training=training, + seed=self._get_seed_generator(self.backend._backend), + ) + ) return transformation @@ -192,15 +195,15 @@ def transform_labels(self, labels, transformation, training=True): return labels def transform_bounding_boxes( - self, - bounding_boxes, - transformation, - training=True, + self, + bounding_boxes, + transformation, + training=True, ): return bounding_boxes def transform_segmentation_masks( - self, segmentation_masks, transformation, training=True + self, segmentation_masks, transformation, training=True ): return segmentation_masks @@ -212,7 +215,6 @@ def get_config(self): "value_range": self.value_range, "num_ops": self.num_ops, "magnitude": self.magnitude, - "num_magnitude_bins": self.num_magnitude_bins, "interpolation": self.interpolation, "seed": self.seed, } diff --git a/keras/src/layers/preprocessing/image_preprocessing/rand_augment_test.py b/keras/src/layers/preprocessing/image_preprocessing/rand_augment_test.py new file mode 100644 index 00000000000..0fe40b10dab --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/rand_augment_test.py @@ -0,0 +1,46 @@ +import numpy as np +import pytest +from tensorflow import data as tf_data + +from keras.src import backend +from keras.src import layers +from keras.src import testing + + +class RandAugmentTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_layer(self): + self.run_layer_test( + layers.RandAugment, + init_kwargs={ + "value_range": (0, 255), + "num_ops": 2, + "magnitude": 1, + "interpolation": "nearest", + "seed": 1, + }, + input_shape=(8, 3, 4, 3), + supports_masking=False, + expected_output_shape=(8, 3, 4, 3), + ) + + def test_rand_augment_inference(self): + seed = 3481 + layer = layers.RandAugment() + + np.random.seed(seed) + inputs = np.random.randint(0, 255, size=(224, 224, 3)) + output = layer(inputs, training=False) + self.assertAllClose(inputs, output) + + def test_tf_data_compatibility(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + input_data = np.random.random((2, 8, 8, 3)) + else: + input_data = np.random.random((2, 3, 8, 8)) + layer = layers.RandAugment() + + ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) + for output in ds.take(1): + output.numpy()