Skip to content

Commit

Permalink
Add random_invert layer
Browse files Browse the repository at this point in the history
  • Loading branch information
shashaka committed Jan 20, 2025
1 parent a25881c commit 4e8cc29
Show file tree
Hide file tree
Showing 5 changed files with 203 additions and 0 deletions.
3 changes: 3 additions & 0 deletions keras/api/_tf_keras/keras/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,9 @@
from keras.src.layers.preprocessing.image_preprocessing.random_hue import (
RandomHue,
)
from keras.src.layers.preprocessing.image_preprocessing.random_invert import (
RandomInvert,
)
from keras.src.layers.preprocessing.image_preprocessing.random_posterization import (
RandomPosterization,
)
Expand Down
3 changes: 3 additions & 0 deletions keras/api/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,9 @@
from keras.src.layers.preprocessing.image_preprocessing.random_hue import (
RandomHue,
)
from keras.src.layers.preprocessing.image_preprocessing.random_invert import (
RandomInvert,
)
from keras.src.layers.preprocessing.image_preprocessing.random_posterization import (
RandomPosterization,
)
Expand Down
3 changes: 3 additions & 0 deletions keras/src/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@
from keras.src.layers.preprocessing.image_preprocessing.random_hue import (
RandomHue,
)
from keras.src.layers.preprocessing.image_preprocessing.random_invert import (
RandomInvert,
)
from keras.src.layers.preprocessing.image_preprocessing.random_posterization import (
RandomPosterization,
)
Expand Down
126 changes: 126 additions & 0 deletions keras/src/layers/preprocessing/image_preprocessing/random_invert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
from keras.src.api_export import keras_export
from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501
BaseImagePreprocessingLayer,
)


@keras_export("keras.layers.RandomInvert")
class RandomInvert(BaseImagePreprocessingLayer):
"""Preprocessing layer for random inversion of image colors.
This layer randomly inverts the colors of input images with a specified
probability range. When applied, each image has a chance of having its
colors inverted, where the pixel values are transformed to their
complementary values. Images that are not selected for inversion
remain unchanged.
Args:
factor: A single float or a tuple of two floats.
`factor` controls the probability of inverting the image colors.
If a tuple is provided, the value is sampled between the two values
for each image, where `factor[0]` is the minimum and `factor[1]` is
the maximum probability. If a single float is provided, a value
between `0.0` and the provided float is sampled.
Defaults to `(0, 1)`.
value_range: a tuple or a list of two elements. The first value
represents the lower bound for values in passed images, the second
represents the upper bound. Images passed to the layer should have
values within `value_range`. Defaults to `(0, 255)`.
seed: Integer. Used to create a random seed.
"""

_USE_BASE_FACTOR = False
_FACTOR_BOUNDS = (0, 1)

def __init__(
self,
factor=1.0,
value_range=(0, 255),
seed=None,
data_format=None,
**kwargs,
):
super().__init__(data_format=data_format, **kwargs)
self._set_factor(factor)
self.value_range = value_range
self.seed = seed
self.generator = self.backend.random.SeedGenerator(seed)

def get_random_transformation(self, data, training=True, seed=None):
if not training:
return None

if isinstance(data, dict):
images = data["images"]
else:
images = data

seed = seed or self._get_seed_generator(self.backend._backend)

images_shape = self.backend.shape(images)
rank = len(images_shape)
if rank == 3:
batch_size = 1
elif rank == 4:
batch_size = images_shape[0]
else:
raise ValueError(
"Expected the input image to be rank 3 or 4. Received "
f"inputs.shape={images_shape}"
)

invert_probability = self.backend.random.uniform(
shape=(batch_size,),
minval=self.factor[0],
maxval=self.factor[1],
seed=seed,
)

random_threshold = self.backend.random.uniform(
shape=(batch_size,),
minval=0,
maxval=1,
seed=seed,
)

apply_inversion = random_threshold < invert_probability
return {"apply_inversion": apply_inversion}

def transform_images(self, images, transformation, training=True):
if training:
images = self.backend.cast(images, self.compute_dtype)
apply_inversion = transformation["apply_inversion"]
return self.backend.numpy.where(
apply_inversion[:, None, None, None],
self.value_range[1] - images,
images,
)
return images

def transform_labels(self, labels, transformation, training=True):
return labels

def transform_bounding_boxes(
self,
bounding_boxes,
transformation,
training=True,
):
return bounding_boxes

def transform_segmentation_masks(
self, segmentation_masks, transformation, training=True
):
return segmentation_masks

def compute_output_shape(self, input_shape):
return input_shape

def get_config(self):
config = {
"factor": self.factor,
"value_range": self.value_range,
"seed": self.seed,
}
base_config = super().get_config()
return {**base_config, **config}
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import numpy as np
import pytest
from tensorflow import data as tf_data

from keras.src import backend
from keras.src import layers
from keras.src import testing


class RandomInvertTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_layer(self):
self.run_layer_test(
layers.RandomInvert,
init_kwargs={
"factor": 0.75,
"value_range": (20, 200),
"seed": 1,
},
input_shape=(8, 3, 4, 3),
supports_masking=False,
expected_output_shape=(8, 3, 4, 3),
)

def test_random_invert_inference(self):
seed = 3481
layer = layers.RandomInvert()
np.random.seed(seed)
inputs = np.random.randint(0, 255, size=(224, 224, 3))
output = layer(inputs, training=False)
self.assertAllClose(inputs, output)

def test_random_invert_no_op(self):
seed = 3481
layer = layers.RandomInvert(factor=0)
np.random.seed(seed)
inputs = np.random.randint(0, 255, size=(224, 224, 3))
output = layer(inputs)
self.assertAllClose(inputs, output)

def test_random_invert_basic(self):
data_format = backend.config.image_data_format()
if data_format == "channels_last":
input_data = np.random.random((1, 8, 8, 3))
else:
input_data = np.random.random((1, 3, 8, 8))
layer = layers.RandomInvert(
factor=(1, 1),
value_range=[0, 1],
data_format=data_format,
seed=1337,
)
output = layer(input_data)
self.assertAllClose(1 - input_data, output)

def test_tf_data_compatibility(self):
data_format = backend.config.image_data_format()
if data_format == "channels_last":
input_data = np.random.random((2, 8, 8, 3))
else:
input_data = np.random.random((2, 3, 8, 8))
layer = layers.RandomInvert(
factor=0.5, value_range=[0, 1], data_format=data_format, seed=1337
)

ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)
for output in ds.take(1):
output.numpy()

0 comments on commit 4e8cc29

Please sign in to comment.