Skip to content

Commit

Permalink
Add rand_augment
Browse files Browse the repository at this point in the history
  • Loading branch information
shashaka committed Jan 3, 2025
1 parent 288f0f2 commit 11d682d
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 70 deletions.
3 changes: 3 additions & 0 deletions keras/src/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@
MaxNumBoundingBoxes,
)
from keras.src.layers.preprocessing.image_preprocessing.mix_up import MixUp
from keras.src.layers.preprocessing.image_preprocessing.rand_augment import (
RandAugment,
)
from keras.src.layers.preprocessing.image_preprocessing.random_brightness import (
RandomBrightness,
)
Expand Down
142 changes: 72 additions & 70 deletions keras/src/layers/preprocessing/image_preprocessing/rand_augment.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import random

import keras.src.layers as layers
from keras.src.api_export import keras_export
from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501
Expand All @@ -15,107 +17,123 @@ class RandAugment(BaseImagePreprocessingLayer):
policy implemented by this layer has been benchmarked extensively and is
effective on a wide variety of datasets.
The policy operates as follows:
For each augmentation in the range `[0, augmentations_per_image]`,
the policy selects a random operation from a list of operations.
It then samples a random number and if that number is less than
`rate` applies it to the given image.
References:
- [RandAugment](https://arxiv.org/abs/1909.13719)
Args:
value_range:
The range of values the input image can take. Default is (0, 255).
Typically, this would be (0, 1) for normalized images
or (0, 255) for raw images.
num_ops:
The number of augmentation operations to apply sequentially
to each image. Default is 2.
magnitude:
The strength of the augmentation as a normalized value
between 0 and 1. Default is 0.5.
interpolation:
The interpolation method to use for resizing operations.
Options include "nearest", "bilinear". Default is "bilinear".
seed:
Integer. Used to create a random seed.
"""

_AUGMENT_LAYERS = ["Identity", "random_shear", "random_translation", "random_rotation",
"random_brightness", "random_color_degeneration", "random_contrast",
"random_sharpness", "random_posterization", "solarization", "auto_contrast", "equalization"]
_AUGMENT_LAYERS = [
"identity",
"random_shear",
"random_translation",
"random_rotation",
"random_brightness",
"random_color_degeneration",
"random_contrast",
"random_sharpness",
"random_posterization",
"solarization",
"auto_contrast",
"equalization",
]

def __init__(
self,
value_range=(0, 255),
num_ops=2,
magnitude=9,
num_magnitude_bins=31,
interpolation="nearest",
seed=None,
data_format=None,
**kwargs,
self,
value_range=(0, 255),
num_ops=2,
magnitude=0.5,
interpolation="bilinear",
seed=None,
data_format=None,
**kwargs,
):
super().__init__(data_format=data_format, **kwargs)

self.value_range = value_range
self.num_ops = num_ops
self.magnitude = magnitude
self.num_magnitude_bins = num_magnitude_bins
self.interpolation = interpolation
self.seed = seed
self.generator = SeedGenerator(seed)

augmentation_space = self._augmentation_space(self.num_magnitude_bins)

random_indices = self.backend.random.randint([11], 0, self.num_magnitude_bins, seed=self.seed)
self.random_shear = layers.RandomShear(
x_factor=float(augmentation_space["ShearX"][int(random_indices[0])]),
y_factor=float(augmentation_space["ShearY"][int(random_indices[1])]),
x_factor=self.magnitude,
y_factor=self.magnitude,
interpolation=interpolation,
seed=self.seed,
data_format=data_format,
)

self.random_translation = layers.RandomTranslation(
height_factor=float(augmentation_space["TranslateX"][int(random_indices[2])]),
width_factor=float(augmentation_space["TranslateY"][int(random_indices[3])]),
height_factor=self.magnitude,
width_factor=self.magnitude,
interpolation=interpolation,
seed=self.seed,
data_format=data_format,
)

self.random_rotation = layers.RandomRotation(
factor=float(augmentation_space["Rotate"][int(random_indices[4])]),
factor=self.magnitude,
interpolation=interpolation,
seed=self.seed,
data_format=data_format,
)

self.random_brightness = layers.RandomBrightness(
factor=float(augmentation_space["Brightness"][int(random_indices[5])]),
factor=self.magnitude,
value_range=self.value_range,
seed=self.seed,
data_format=data_format,
)

self.random_color_degeneration = layers.RandomColorDegeneration(
factor=float(augmentation_space["Color"][int(random_indices[6])]),
factor=self.magnitude,
value_range=self.value_range,
seed=self.seed,
data_format=data_format,
)

self.random_contrast = layers.RandomContrast(
factor=float(augmentation_space["Contrast"][int(random_indices[7])]),
factor=self.magnitude,
value_range=self.value_range,
seed=self.seed,
data_format=data_format,
)

self.random_sharpness = layers.RandomSharpness(
factor=float(augmentation_space["Sharpness"][int(random_indices[8])]),
factor=self.magnitude,
value_range=self.value_range,
seed=self.seed,
data_format=data_format,
)

self.solarization = layers.Solarization(
addition_factor=int(augmentation_space["Solarize"][int(random_indices[9])]),
threshold_factor=int(augmentation_space["Solarize"][int(random_indices[10])]),
addition_factor=self.magnitude,
threshold_factor=self.magnitude,
value_range=self.value_range,
seed=self.seed,
data_format=data_format,
)

random_indices = self.backend.random.uniform((1,),
minval=0, maxval=len(augmentation_space["Posterize"]),
seed=self.seed)
self.random_posterization = layers.RandomPosterization(
factor=int(augmentation_space["Posterize"][int(random_indices[0])]),
factor=max(1, int(8 * self.magnitude)),
value_range=self.value_range,
seed=self.seed,
data_format=data_format,
Expand All @@ -131,21 +149,6 @@ def __init__(
data_format=data_format,
)

def _augmentation_space(self, num_bins):
return {
"ShearX": self.backend.numpy.linspace(0.0, 1.0, num_bins),
"ShearY": self.backend.numpy.linspace(0.0, 1.0, num_bins),
"TranslateX": self.backend.numpy.linspace(-1, 1, num_bins),
"TranslateY": self.backend.numpy.linspace(-1, 1, num_bins),
"Rotate": self.backend.numpy.linspace(-1, 1, num_bins),
"Brightness": self.backend.numpy.linspace(-1, 1, num_bins),
"Color": self.backend.numpy.linspace(0.0, 1.0, num_bins),
"Contrast": self.backend.numpy.linspace(0.0, 1.0, num_bins),
"Sharpness": self.backend.numpy.linspace(0.0, 1.0, num_bins),
"Solarize": self.backend.numpy.linspace(0.0, 1.0, num_bins),
"Posterize": 8. - (self.backend.numpy.arange(num_bins, dtype='float32') / ((num_bins - 1.) / 4)),
}

def get_random_transformation(self, data, training=True, seed=None):
if not training:
return None
Expand All @@ -154,24 +157,24 @@ def get_random_transformation(self, data, training=True, seed=None):
self.backend.set_backend("tensorflow")

for layer_name in self._AUGMENT_LAYERS:
if layer_name == "Identity":
if layer_name == "identity":
continue
augmentation_layer = getattr(self, layer_name)
augmentation_layer.backend.set_backend("tensorflow")

transformation = {}
random_indices = self.backend.random.shuffle(
self.backend.numpy.arange(len(self._AUGMENT_LAYERS)),
seed=self.seed)[:self.num_ops]
for layer_idx in random_indices:
layer_name = self._AUGMENT_LAYERS[layer_idx]
if layer_name == "Identity":
random.shuffle(self._AUGMENT_LAYERS)
for layer_name in self._AUGMENT_LAYERS[: self.num_ops]:
if layer_name == "identity":
continue
augmentation_layer = getattr(self, layer_name)
transformation[layer_name] = augmentation_layer.get_random_transformation(data,
training=training,
seed=self._get_seed_generator(
self.backend._backend))
transformation[layer_name] = (
augmentation_layer.get_random_transformation(
data,
training=training,
seed=self._get_seed_generator(self.backend._backend),
)
)

return transformation

Expand All @@ -192,15 +195,15 @@ def transform_labels(self, labels, transformation, training=True):
return labels

def transform_bounding_boxes(
self,
bounding_boxes,
transformation,
training=True,
self,
bounding_boxes,
transformation,
training=True,
):
return bounding_boxes

def transform_segmentation_masks(
self, segmentation_masks, transformation, training=True
self, segmentation_masks, transformation, training=True
):
return segmentation_masks

Expand All @@ -212,7 +215,6 @@ def get_config(self):
"value_range": self.value_range,
"num_ops": self.num_ops,
"magnitude": self.magnitude,
"num_magnitude_bins": self.num_magnitude_bins,
"interpolation": self.interpolation,
"seed": self.seed,
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import numpy as np
import pytest
from tensorflow import data as tf_data

from keras.src import backend
from keras.src import layers
from keras.src import testing


class RandAugmentTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_layer(self):
self.run_layer_test(
layers.RandAugment,
init_kwargs={
"value_range": (0, 255),
"num_ops": 2,
"magnitude": 1,
"interpolation": "nearest",
"seed": 1,
},
input_shape=(8, 3, 4, 3),
supports_masking=False,
expected_output_shape=(8, 3, 4, 3),
)

def test_rand_augment_inference(self):
seed = 3481
layer = layers.RandAugment()

np.random.seed(seed)
inputs = np.random.randint(0, 255, size=(224, 224, 3))
output = layer(inputs, training=False)
self.assertAllClose(inputs, output)

def test_tf_data_compatibility(self):
data_format = backend.config.image_data_format()
if data_format == "channels_last":
input_data = np.random.random((2, 8, 8, 3))
else:
input_data = np.random.random((2, 3, 8, 8))
layer = layers.RandAugment()

ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)
for output in ds.take(1):
output.numpy()

0 comments on commit 11d682d

Please sign in to comment.