Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add rand_augment processing layer #20716

Merged
merged 12 commits into from
Jan 7, 2025
3 changes: 3 additions & 0 deletions keras/api/_tf_keras/keras/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,9 @@
MaxNumBoundingBoxes,
)
from keras.src.layers.preprocessing.image_preprocessing.mix_up import MixUp
from keras.src.layers.preprocessing.image_preprocessing.rand_augment import (
RandAugment,
)
from keras.src.layers.preprocessing.image_preprocessing.random_brightness import (
RandomBrightness,
)
Expand Down
3 changes: 3 additions & 0 deletions keras/api/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,9 @@
MaxNumBoundingBoxes,
)
from keras.src.layers.preprocessing.image_preprocessing.mix_up import MixUp
from keras.src.layers.preprocessing.image_preprocessing.rand_augment import (
RandAugment,
)
from keras.src.layers.preprocessing.image_preprocessing.random_brightness import (
RandomBrightness,
)
Expand Down
3 changes: 3 additions & 0 deletions keras/src/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@
MaxNumBoundingBoxes,
)
from keras.src.layers.preprocessing.image_preprocessing.mix_up import MixUp
from keras.src.layers.preprocessing.image_preprocessing.rand_augment import (
RandAugment,
)
from keras.src.layers.preprocessing.image_preprocessing.random_brightness import (
RandomBrightness,
)
Expand Down
235 changes: 235 additions & 0 deletions keras/src/layers/preprocessing/image_preprocessing/rand_augment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
import random

import keras.src.layers as layers
from keras.src.api_export import keras_export
from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501
BaseImagePreprocessingLayer,
)
from keras.src.random import SeedGenerator
from keras.src.utils import backend_utils


@keras_export("keras.layers.RandAugment")
class RandAugment(BaseImagePreprocessingLayer):
"""RandAugment performs the Rand Augment operation on input images.

This layer can be thought of as an all-in-one image augmentation layer. The
policy implemented by this layer has been benchmarked extensively and is
effective on a wide variety of datasets.

References:
- [RandAugment](https://arxiv.org/abs/1909.13719)

Args:
value_range: The range of values the input image can take.
Default is `(0, 255)`. Typically, this would be `(0, 1)`
for normalized images or `(0, 255)` for raw images.
num_ops: The number of augmentation operations to apply sequentially
to each image. Default is 2.
factor: The strength of the augmentation as a normalized value
between 0 and 1. Default is 0.5.
interpolation: The interpolation method to use for resizing operations.
Options include `nearest`, `bilinear`. Default is `bilinear`.
seed: Integer. Used to create a random seed.

"""

_USE_BASE_FACTOR = False
_FACTOR_BOUNDS = (0, 1)

_AUGMENT_LAYERS = [
"random_shear",
"random_translation",
"random_rotation",
"random_brightness",
"random_color_degeneration",
"random_contrast",
"random_sharpness",
"random_posterization",
"solarization",
"auto_contrast",
"equalization",
]

def __init__(
self,
value_range=(0, 255),
num_ops=2,
factor=0.5,
interpolation="bilinear",
seed=None,
data_format=None,
**kwargs,
):
super().__init__(data_format=data_format, **kwargs)

self.value_range = value_range
self.num_ops = num_ops
self._set_factor(factor)
self.interpolation = interpolation
self.seed = seed
self.generator = SeedGenerator(seed)

self.random_shear = layers.RandomShear(
x_factor=self.factor,
y_factor=self.factor,
interpolation=interpolation,
seed=self.seed,
data_format=data_format,
**kwargs,
)

self.random_translation = layers.RandomTranslation(
height_factor=self.factor,
width_factor=self.factor,
interpolation=interpolation,
seed=self.seed,
data_format=data_format,
**kwargs,
)

self.random_rotation = layers.RandomRotation(
factor=self.factor,
interpolation=interpolation,
seed=self.seed,
data_format=data_format,
**kwargs,
)

self.random_brightness = layers.RandomBrightness(
factor=self.factor,
value_range=self.value_range,
seed=self.seed,
data_format=data_format,
**kwargs,
)

self.random_color_degeneration = layers.RandomColorDegeneration(
factor=self.factor,
value_range=self.value_range,
seed=self.seed,
data_format=data_format,
**kwargs,
)

self.random_contrast = layers.RandomContrast(
factor=self.factor,
value_range=self.value_range,
seed=self.seed,
data_format=data_format,
**kwargs,
)

self.random_sharpness = layers.RandomSharpness(
factor=self.factor,
value_range=self.value_range,
seed=self.seed,
data_format=data_format,
**kwargs,
)

self.solarization = layers.Solarization(
addition_factor=self.factor,
threshold_factor=self.factor,
value_range=self.value_range,
seed=self.seed,
data_format=data_format,
**kwargs,
)

self.random_posterization = layers.RandomPosterization(
factor=max(1, int(8 * self.factor[1])),
value_range=self.value_range,
seed=self.seed,
data_format=data_format,
**kwargs,
)

self.auto_contrast = layers.AutoContrast(
value_range=self.value_range, data_format=data_format, **kwargs
)

self.equalization = layers.Equalization(
value_range=self.value_range, data_format=data_format, **kwargs
)

def build(self, input_shape):
for layer_name in self._AUGMENT_LAYERS:
augmentation_layer = getattr(self, layer_name)
augmentation_layer.build(input_shape)

def get_random_transformation(self, data, training=True, seed=None):
if not training:
return None

if backend_utils.in_tf_graph():
self.backend.set_backend("tensorflow")

for layer_name in self._AUGMENT_LAYERS:
augmentation_layer = getattr(self, layer_name)
augmentation_layer.backend.set_backend("tensorflow")

transformation = {}
random.shuffle(self._AUGMENT_LAYERS)
for layer_name in self._AUGMENT_LAYERS[: self.num_ops]:
augmentation_layer = getattr(self, layer_name)
transformation[layer_name] = (
augmentation_layer.get_random_transformation(
data,
training=training,
seed=self._get_seed_generator(self.backend._backend),
)
)

return transformation

def transform_images(self, images, transformation, training=True):
if training:
images = self.backend.cast(images, self.compute_dtype)

for layer_name, transformation_value in transformation.items():
augmentation_layer = getattr(self, layer_name)
images = augmentation_layer.transform_images(
images, transformation_value
)

images = self.backend.cast(images, self.compute_dtype)
return images

def transform_labels(self, labels, transformation, training=True):
return labels

def transform_bounding_boxes(
self,
bounding_boxes,
transformation,
training=True,
):
if training:
for layer_name, transformation_value in transformation.items():
augmentation_layer = getattr(self, layer_name)
bounding_boxes = augmentation_layer.transform_bounding_boxes(
bounding_boxes, transformation_value, training=training
)
return bounding_boxes

def transform_segmentation_masks(
self, segmentation_masks, transformation, training=True
):
return self.transform_images(
segmentation_masks, transformation, training=training
)

def compute_output_shape(self, input_shape):
return input_shape

def get_config(self):
config = {
"value_range": self.value_range,
"num_ops": self.num_ops,
"factor": self.factor,
"interpolation": self.interpolation,
"seed": self.seed,
}
base_config = super().get_config()
return {**base_config, **config}
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import numpy as np
import pytest
from tensorflow import data as tf_data

from keras.src import backend
from keras.src import layers
from keras.src import testing


class RandAugmentTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_layer(self):
self.run_layer_test(
layers.RandAugment,
init_kwargs={
"value_range": (0, 255),
"num_ops": 2,
"factor": 1,
"interpolation": "nearest",
"seed": 1,
"data_format": "channels_last",
},
input_shape=(8, 3, 4, 3),
supports_masking=False,
expected_output_shape=(8, 3, 4, 3),
)

def test_rand_augment_inference(self):
seed = 3481
layer = layers.RandAugment()

np.random.seed(seed)
inputs = np.random.randint(0, 255, size=(224, 224, 3))
output = layer(inputs, training=False)
self.assertAllClose(inputs, output)

def test_rand_augment_basic(self):
data_format = backend.config.image_data_format()
if data_format == "channels_last":
input_data = np.random.random((2, 8, 8, 3))
else:
input_data = np.random.random((2, 3, 8, 8))
layer = layers.RandAugment(data_format=data_format)

augmented_image = layer(input_data)
self.assertEqual(augmented_image.shape, input_data.shape)

def test_rand_augment_no_operations(self):
data_format = backend.config.image_data_format()
if data_format == "channels_last":
input_data = np.random.random((2, 8, 8, 3))
else:
input_data = np.random.random((2, 3, 8, 8))
layer = layers.RandAugment(num_ops=0, data_format=data_format)

augmented_image = layer(input_data)
self.assertAllClose(
backend.convert_to_numpy(augmented_image), input_data
)

def test_random_augment_randomness(self):
data_format = backend.config.image_data_format()
if data_format == "channels_last":
input_data = np.random.random((2, 8, 8, 3))
else:
input_data = np.random.random((2, 3, 8, 8))

layer = layers.RandAugment(num_ops=11, data_format=data_format)
augmented_image = layer(input_data)

self.assertNotAllClose(
backend.convert_to_numpy(augmented_image), input_data
)

def test_tf_data_compatibility(self):
data_format = backend.config.image_data_format()
if data_format == "channels_last":
input_data = np.random.random((2, 8, 8, 3))
else:
input_data = np.random.random((2, 3, 8, 8))
layer = layers.RandAugment(data_format=data_format)

ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)
for output in ds.take(1):
output.numpy()

def test_rand_augment_tf_data_bounding_boxes(self):
data_format = backend.config.image_data_format()
if data_format == "channels_last":
image_shape = (1, 10, 8, 3)
else:
image_shape = (1, 3, 10, 8)
input_image = np.random.random(image_shape)
bounding_boxes = {
"boxes": np.array(
[
[
[2, 1, 4, 3],
[6, 4, 8, 6],
]
]
),
"labels": np.array([[1, 2]]),
}

input_data = {"images": input_image, "bounding_boxes": bounding_boxes}

ds = tf_data.Dataset.from_tensor_slices(input_data)
layer = layers.RandAugment(
data_format=data_format,
seed=42,
bounding_box_format="xyxy",
)
ds.map(layer)
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def test_correctness(self):
seed = 2390

# Always scale up, but randomly between 0 ~ 255
layer = layers.RandomBrightness([0, 1.0])
layer = layers.RandomBrightness([0.1, 1.0])
np.random.seed(seed)
inputs = np.random.randint(0, 255, size=(224, 224, 3))
output = backend.convert_to_numpy(layer(inputs))
Expand All @@ -44,7 +44,7 @@ def test_correctness(self):
self.assertTrue(np.mean(diff) > 0)

# Always scale down, but randomly between 0 ~ 255
layer = layers.RandomBrightness([-1.0, 0.0])
layer = layers.RandomBrightness([-1.0, -0.1])
np.random.seed(seed)
inputs = np.random.randint(0, 255, size=(224, 224, 3))
output = backend.convert_to_numpy(layer(inputs))
Expand Down
Loading
Loading