Skip to content

Commit

Permalink
Migrate RandomChoice and RandomApply layers from keras-cv to keras
Browse files Browse the repository at this point in the history
- Added RandomChoice and RandomApply layers to keras.layers.preprocessing.image_preprocessing.
- Implemented necessary transform methods (transform_images, transform_labels, transform_bounding_boxes, transform_segmentation_masks) to comply with BaseImagePreprocessingLayer.
- Updated tests to ensure compatibility with keras and fix failing test cases.
- Added support for batchwise processing, auto-vectorization, and random seed control.
  • Loading branch information
harshaljanjani committed Jan 13, 2025
1 parent e010829 commit 40c285e
Show file tree
Hide file tree
Showing 6 changed files with 586 additions and 2 deletions.
6 changes: 6 additions & 0 deletions keras/src/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,12 @@
from keras.src.layers.preprocessing.image_preprocessing.solarization import (
Solarization,
)
from keras.src.layers.preprocessing.image_preprocessing.random_choice import (
RandomChoice,
)
from keras.src.layers.preprocessing.image_preprocessing.random_apply import (
RandomApply,
)
from keras.src.layers.preprocessing.index_lookup import IndexLookup
from keras.src.layers.preprocessing.integer_lookup import IntegerLookup
from keras.src.layers.preprocessing.mel_spectrogram import MelSpectrogram
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,20 @@
densify_bounding_boxes,
)
from keras.src.layers.preprocessing.tf_data_layer import TFDataLayer

import tensorflow as tf

class BaseImagePreprocessingLayer(TFDataLayer):
_USE_BASE_FACTOR = True
_FACTOR_BOUNDS = (-1, 1)

def __init__(
self, factor=None, bounding_box_format=None, data_format=None, **kwargs
self, factor=None, bounding_box_format=None, data_format=None, seed=None, **kwargs
):
if seed is None:
self.random_generator = tf.random.Generator.from_non_deterministic_state()
else:
self.random_generator = tf.random.Generator.from_seed(seed)
self.seed = seed
super().__init__(**kwargs)
self.bounding_box_format = bounding_box_format
self.data_format = backend_config.standardize_data_format(data_format)
Expand Down
192 changes: 192 additions & 0 deletions keras/src/layers/preprocessing/image_preprocessing/random_apply.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from keras.src.api_export import keras_export
from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501
BaseImagePreprocessingLayer,
)
import tensorflow as tf

@keras_export("keras.layers.RandomApply")
class RandomApply(BaseImagePreprocessingLayer):
    """Apply a wrapped layer to randomly chosen elements of a batch.

    Args:
        layer: a keras `Layer` or `BaseImagePreprocessingLayer`. This layer
            will be applied to randomly chosen samples in a batch. It is
            called directly on the input tensor and should not modify the
            size of the provided inputs.
        rate: controls the frequency of applying the layer. 1.0 means all
            elements in a batch will be modified. 0.0 means no elements will
            be modified. Defaults to 0.5.
        batchwise: (Optional) bool, whether to pass entire batches to the
            underlying layer. When set to true, only a single random sample
            is drawn to determine if the batch should be passed to the
            underlying layer.
        auto_vectorize: bool, intended to select between
            `tf.vectorized_map` and `tf.map_fn` for batched input. NOTE: in
            this implementation the flag is only stored and returned by
            `get_config()`; no code path consults it. Defaults to False.
        seed: integer, controls random behaviour.

    Raises:
        ValueError: if `rate` is outside the range `[0, 1]`.

    Example:

    ```
    # Let's declare an example layer that will set all image pixels to zero.
    zero_out = keras.layers.Lambda(lambda x: 0 * x)
    # Create a small batch of random, single-channel, 2x2 images:
    images = tf.random.stateless_uniform(shape=(5, 2, 2, 1), seed=[0, 1])
    print(images[..., 0])
    # <tf.Tensor: shape=(5, 2, 2), dtype=float32, numpy=
    # array([[[0.08216608, 0.40928006],
    #         [0.39318466, 0.3162533 ]],
    #
    #        [[0.34717774, 0.73199546],
    #         [0.56369007, 0.9769211 ]],
    #
    #        [[0.55243933, 0.13101244],
    #         [0.2941643 , 0.5130266 ]],
    #
    #        [[0.38977218, 0.80855536],
    #         [0.6040567 , 0.10502195]],
    #
    #        [[0.51828027, 0.12730157],
    #         [0.288486  , 0.252975  ]]], dtype=float32)>
    # Apply the layer with 50% probability:
    random_apply = RandomApply(layer=zero_out, rate=0.5, seed=1234)
    outputs = random_apply(images)
    print(outputs[..., 0])
    # Illustrative output (which samples are selected depends on the draw):
    # <tf.Tensor: shape=(5, 2, 2), dtype=float32, numpy=
    # array([[[0.        , 0.        ],
    #         [0.        , 0.        ]],
    #
    #        [[0.34717774, 0.73199546],
    #         [0.56369007, 0.9769211 ]],
    #
    #        [[0.55243933, 0.13101244],
    #         [0.2941643 , 0.5130266 ]],
    #
    #        [[0.38977218, 0.80855536],
    #         [0.6040567 , 0.10502195]],
    #
    #        [[0.        , 0.        ],
    #         [0.        , 0.        ]]], dtype=float32)>
    # Here the layer has been randomly applied to 2 out of 5 samples.
    ```
    """

    def __init__(
        self,
        layer,
        rate=0.5,
        batchwise=False,
        auto_vectorize=False,
        seed=None,
        **kwargs,
    ):
        super().__init__(seed=seed, **kwargs)
        if not (0 <= rate <= 1.0):
            raise ValueError(
                f"rate must be in range [0, 1]. Received rate: {rate}"
            )
        self._layer = layer
        self._rate = rate
        self.auto_vectorize = auto_vectorize
        self.batchwise = batchwise
        self.seed = seed
        # The layer owns no weights of its own, so it is marked built here.
        self.built = True

    def _get_should_augment(self, inputs):
        """Return the boolean mask selecting which inputs get augmented.

        In `batchwise` mode the result is a scalar boolean. Otherwise it has
        one entry per element along axis 0, reshaped to `(batch, 1, ..., 1)`
        so that it broadcasts against `inputs` in `tf.where`.
        """
        input_shape = tf.shape(inputs)

        if self.batchwise:
            # One draw decides the fate of the whole batch.
            return self._rate > tf.random.uniform(shape=(), seed=self.seed)

        batch_size = input_shape[0]
        random_values = tf.random.uniform(
            shape=(batch_size,), seed=self.seed
        )
        should_augment = random_values < self._rate

        # Pad the mask with trailing singleton axes up to the input's rank.
        ndims = tf.rank(inputs)
        broadcast_shape = tf.concat(
            [input_shape[:1], tf.ones(ndims - 1, dtype=tf.int32)],
            axis=0,
        )
        return tf.reshape(should_augment, broadcast_shape)

    def _augment_single(self, inputs):
        """Apply the wrapped layer to all of `inputs` with probability `rate`.

        A single scalar draw is made; the tensor is either passed through
        the wrapped layer as a whole or returned untouched.
        """
        random_value = tf.random.uniform(shape=(), seed=self.seed)
        should_augment = random_value < self._rate

        def apply_layer():
            return self._layer(inputs)

        def return_inputs():
            return inputs

        return tf.cond(should_augment, apply_layer, return_inputs)

    def _augment_batch(self, inputs):
        """Augment a batch, choosing per element whether to keep the result.

        The wrapped layer is run on the full batch and `tf.where` then picks
        between the augmented and the original values using the mask from
        `_get_should_augment`.
        """
        should_augment = self._get_should_augment(inputs)
        augmented = self._layer(inputs)
        return tf.where(should_augment, augmented, inputs)

    def call(self, inputs):
        """Randomly apply the wrapped layer to a tensor or a dict of tensors.

        For a dict, every value is processed on its own (each with its own
        random draw) and a dict with the same keys is returned.
        """
        if isinstance(inputs, dict):
            return {
                key: self._call_single(input_tensor)
                for key, input_tensor in inputs.items()
            }
        return self._call_single(inputs)

    def _call_single(self, inputs):
        """Dispatch one tensor to the batched or the single-sample path.

        Only rank-4 tensors are treated as batches. Every other rank,
        including a rank-3 single image, takes the single-sample path, where
        the wrapped layer is applied to the whole tensor or not at all.
        Sending rank-3 inputs down the batch path would treat the first
        image axis as a batch axis and augment slices of one image
        independently.
        """
        is_batch = tf.equal(tf.rank(inputs), 4)

        def augment_single():
            return self._augment_single(inputs)

        def augment_batch():
            return self._augment_batch(inputs)

        return tf.cond(is_batch, augment_batch, augment_single)

    def transform_images(self, images, transformation=None, training=True):
        """Randomly augment `images`; a no-op when `training` is false."""
        if not training:
            return images
        return self.call(images)

    def transform_labels(self, labels, transformation=None, training=True):
        """Randomly augment `labels`; a no-op when `training` is false."""
        if not training:
            return labels
        return self.call(labels)

    def transform_bounding_boxes(
        self, bboxes, transformation=None, training=True
    ):
        """Randomly augment `bboxes`; a no-op when `training` is false."""
        if not training:
            return bboxes
        return self.call(bboxes)

    def transform_segmentation_masks(
        self, masks, transformation=None, training=True
    ):
        """Randomly augment `masks`; a no-op when `training` is false."""
        if not training:
            return masks
        return self.call(masks)

    def get_config(self):
        """Return the layer config, with the wrapped layer serialized.

        The wrapped layer is stored in serialized form rather than as a live
        object so that the config stays JSON-serializable.
        """
        # Imported locally so this block does not depend on a file-level
        # import being present.
        from keras.src.saving import serialization_lib

        config = super().get_config()
        config.update(
            {
                "rate": self._rate,
                "layer": serialization_lib.serialize_keras_object(
                    self._layer
                ),
                "seed": self.seed,
                "batchwise": self.batchwise,
                "auto_vectorize": self.auto_vectorize,
            }
        )
        return config

    @classmethod
    def from_config(cls, config, custom_objects=None):
        """Rebuild the layer from the output of `get_config()`.

        Accepts the wrapped layer either in serialized (dict) form or as an
        already-instantiated layer object.
        """
        from keras.src.saving import serialization_lib

        # Work on a copy so the caller's config dict is not mutated.
        config = dict(config)
        layer = config.pop("layer")
        if isinstance(layer, dict):
            layer = serialization_lib.deserialize_keras_object(
                layer, custom_objects=custom_objects
            )
        return cls(layer=layer, **config)
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import tensorflow as tf
from absl.testing import parameterized

from keras.src import layers
from keras.src import ops
from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501
BaseImagePreprocessingLayer,
)
from keras.src.testing import TestCase


class ZeroOut(BaseImagePreprocessingLayer):
    """Test helper layer: every output is all zeros, shaped like its input."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Nothing to build; mark the layer as ready for use immediately.
        self.built = True

    @staticmethod
    def _zeros(tensor):
        """Return a zero tensor with the shape and dtype of `tensor`."""
        return tf.zeros_like(tensor)

    def call(self, inputs):
        """Return zeros shaped like `inputs`."""
        return self._zeros(inputs)

    def transform_images(self, images, transformation=None, training=True):
        """Return zeros shaped like `images`, regardless of `training`."""
        return self._zeros(images)

    def transform_segmentation_masks(
        self, masks, transformation=None, training=True
    ):
        """Return zeros shaped like `masks`, regardless of `training`."""
        return self._zeros(masks)

    def transform_bounding_boxes(
        self, bboxes, transformation=None, training=True
    ):
        """Return zeros shaped like `bboxes`, regardless of `training`."""
        return self._zeros(bboxes)

    def transform_labels(self, labels, transformation=None, training=True):
        """Return zeros shaped like `labels`, regardless of `training`."""
        return self._zeros(labels)

    def get_config(self):
        """Return the parent class config unchanged."""
        base_config = super().get_config()
        return base_config


class RandomApplyTest(TestCase):
    """Behavioural tests for `layers.RandomApply`."""

    # Shared seeded generator used to build the dummy inputs.
    rng = tf.random.Generator.from_seed(seed=1234)

    @parameterized.parameters([-0.5, 1.7])
    def test_raises_error_on_invalid_rate_parameter(self, invalid_rate):
        """A rate outside [0, 1] must be rejected at construction time."""
        with self.assertRaises(ValueError):
            layers.RandomApply(rate=invalid_rate, layer=ZeroOut())

    def test_works_with_batched_input(self):
        """With rate 0.5, some but not all samples of a batch are zeroed."""
        batch_size = 32
        dummy_inputs = self.rng.uniform(shape=(batch_size, 224, 224, 3))
        layer = layers.RandomApply(rate=0.5, layer=ZeroOut(), seed=1234)

        outputs = ops.convert_to_numpy(layer(dummy_inputs))
        num_zero_inputs = self._num_zero_batches(dummy_inputs)
        num_zero_outputs = self._num_zero_batches(outputs)

        self.assertEqual(num_zero_inputs, 0)
        self.assertLess(num_zero_outputs, batch_size)
        self.assertGreater(num_zero_outputs, 0)

    def test_works_with_batchwise_layers(self):
        """Batchwise mode preserves the input shape."""
        batch_size = 32
        dummy_inputs = self.rng.uniform(shape=(batch_size, 224, 224, 3))
        random_flip_layer = layers.RandomFlip(
            "vertical", data_format="channels_last", seed=42
        )
        layer = layers.RandomApply(random_flip_layer, rate=0.5, batchwise=True)
        outputs = layer(dummy_inputs)
        self.assertEqual(outputs.shape, dummy_inputs.shape)

    @staticmethod
    def _num_zero_batches(images):
        """Count the samples along axis 0 whose values are all zero."""
        num_batches = tf.shape(images)[0]
        num_non_zero_batches = tf.math.count_nonzero(
            tf.math.count_nonzero(images, axis=[1, 2, 3]), dtype=tf.int32
        )
        return num_batches - num_non_zero_batches

    def test_inputs_unchanged_with_zero_rate(self):
        """Rate 0.0 leaves every input untouched."""
        dummy_inputs = self.rng.uniform(shape=(32, 224, 224, 3))
        layer = layers.RandomApply(rate=0.0, layer=ZeroOut())

        outputs = layer(dummy_inputs)

        self.assertAllClose(outputs, dummy_inputs)

    def test_all_inputs_changed_with_rate_equal_to_one(self):
        """Rate 1.0 applies the wrapped layer to every sample."""
        dummy_inputs = self.rng.uniform(shape=(32, 224, 224, 3))
        layer = layers.RandomApply(rate=1.0, layer=ZeroOut())
        outputs = layer(dummy_inputs)
        # Assert on the result; computing the comparison alone checks nothing.
        self.assertAllClose(outputs, tf.zeros_like(dummy_inputs))

    def test_works_with_single_image(self):
        """An unbatched rank-3 image is zeroed at rate 1.0."""
        dummy_inputs = self.rng.uniform(shape=(224, 224, 3))
        layer = layers.RandomApply(rate=1.0, layer=ZeroOut())
        outputs = layer(dummy_inputs)
        self.assertAllClose(outputs, tf.zeros_like(dummy_inputs))

    def test_can_modify_label(self):
        """Labels passed in a dict are transformed by the wrapped layer."""
        dummy_inputs = self.rng.uniform(shape=(32, 224, 224, 3))
        dummy_labels = tf.ones(shape=(32, 2))
        layer = layers.RandomApply(rate=1.0, layer=ZeroOut())
        outputs = layer({"images": dummy_inputs, "labels": dummy_labels})
        self.assertAllClose(outputs["labels"], tf.zeros_like(dummy_labels))

    def test_works_with_xla(self):
        """The layer runs inside a `jit_compile=True` function."""
        dummy_inputs = self.rng.uniform(shape=(32, 224, 224, 3))
        layer = layers.RandomApply(
            rate=0.5, layer=ZeroOut(), auto_vectorize=False
        )

        @tf.function(jit_compile=True)
        def apply(x):
            return layer(x)

        outputs = apply(dummy_inputs)
        self.assertEqual(outputs.shape, dummy_inputs.shape)
Loading

0 comments on commit 40c285e

Please sign in to comment.