Skip to content

Commit

Permalink
fix: Refactor RandomApply and RandomChoice layers
Browse files Browse the repository at this point in the history
- Simplified docstrings.
- Removed copyright headers from files.
- Refactored RandomApply and RandomChoice.
- Streamlined transformation logic and improved overall code organization.
- Eliminated redundant backend-specific code paths.
  • Loading branch information
harshaljanjani committed Jan 16, 2025
1 parent 9d72cd9 commit 9fd3bd2
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 336 deletions.
260 changes: 71 additions & 189 deletions keras/src/layers/preprocessing/image_preprocessing/random_apply.py
Original file line number Diff line number Diff line change
@@ -1,134 +1,35 @@
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import keras.src.backend as K
import keras.src.random as random
from keras.src import ops
from keras.src.api_export import keras_export
from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501
BaseImagePreprocessingLayer,
)
from keras.src.random.seed_generator import SeedGenerator


@keras_export("keras.layers.RandomApply")
class RandomApply(BaseImagePreprocessingLayer):
"""A preprocessing layer that randomly applies a provided layer to elements
in a batch.
This layer is useful for applying data augmentations or transformations
with a specified probability. During training, each input (or batch of
inputs) has to be transformed by the provided layer, controlled by the
`rate` parameter. This allows for stochastic application of augmentations
which can improve model robustness.
**Example:**
```python
# Create a layer that zeroes out all pixels in an image
zero_out = keras.layers.Lambda(lambda x: 0 * x)
# Create a batch of random 2x2 images
images = tf.random.uniform(shape=(5, 2, 2, 1))
print(images[..., 0])
# <tf.Tensor: shape=(5, 2, 2), dtype=float32, numpy=
# array([[[0.82, 0.41],
# [0.39, 0.32]],
#
# [[0.35, 0.73],
# [0.56, 0.98]],
#
# [[0.55, 0.13],
# [0.29, 0.51]],
#
# [[0.39, 0.81],
# [0.60, 0.11]],
#
# [[0.52, 0.13],
# [0.29, 0.25]]], dtype=float32)>
# Apply the layer with 50% probability
random_apply = RandomApply(layer=zero_out, rate=0.5, seed=1234)
outputs = random_apply(images)
print(outputs[..., 0])
# <tf.Tensor: shape=(5, 2, 2), dtype=float32, numpy=
# array([[[0.00, 0.00],
# [0.00, 0.00]],
#
# [[0.35, 0.73],
# [0.56, 0.98]],
#
# [[0.55, 0.13],
# [0.29, 0.51]],
#
# [[0.39, 0.81],
# [0.60, 0.11]],
#
# [[0.00, 0.00],
# [0.00, 0.00]]], dtype=float32)>
# Observe that the layer was applied to 2 out of 5 samples.
```
**Args:**
layer: A `keras.Layer` or `BaseImagePreprocessingLayer` instance. This
will be applied to randomly selected inputs in the batch. The layer
should not modify the shape of the input.
rate: A float between 0 and 1, which is the probability of applying the
layer.
- `1.0` means the layer is applied to all inputs.
- `0.0` means the layer is never applied.
Defaults to `0.5`.
batchwise: A boolean, indicating whether the decision to apply the layer
should be made for the entire batch at once (True) or for each input
individually (False). When True, the layer is either applied to
entire batch or not at all. When False, the layer is applied
independently to each input in the batch. Defaults to False.
auto_vectorize: A boolean, for whether to use vectorized operations for
batched inputs. Can improve performance but may not work with XLA.
"""A preprocessing layer that randomly applies a specified layer during
training.
This layer randomly applies a given transformation layer to inputs based on
the `rate` parameter. It is useful for stochastic data augmentation to
improve model robustness. At inference time, the output is identical to
the input. Call the layer with `training=True` to enable random application.
Args:
layer: A `keras.Layer` to apply. The layer must not modify input shape.
rate: Float between 0.0 and 1.0, representing the probability of
applying the layer. Defaults to 0.5.
batchwise: Boolean. If True, the decision to apply the layer is made for
the entire batch. If False, it is made independently for each input.
Defaults to False.
seed: An integer, used to seed the random number generator for
reproducibility. Defaults to None.
**Call Arguments:**
inputs: Single input tensor (rank 3), a batch of input tensors (rank 4),
or a dictionary of tensors. Input will be transformed by provided
layer with probability `rate`.
**Returns:**
Transformed inputs, with the same shape and structure as the input.
seed: Optional integer to ensure reproducibility.
**Notes:**
- When `batchwise=True`, layer is applied to the entire batch or not
at all, based on a single random decision.
- When `batchwise=False`, layer applied independently to each input in
the batch, allowing for more fine-grained control.
- The provided `layer` should not modify the shape of the input, as this
could lead to inconsistencies in the output.
Inputs: A tensor (rank 3 for single input, rank 4 for batch input). The
input can have any dtype and range.
**Example with Batchwise Application:**
```python
# Apply a layer to the entire batch with 50% probability
random_apply = RandomApply(layer=zero_out, rate=0.5, batchwise=True)
outputs = random_apply(images) # Either all images zeroed out or none are
```
**Example with Per-Input Application:**
```python
# Apply a layer to each input independently with 50% probability
random_apply = RandomApply(layer=zero_out, rate=0.5, batchwise=False)
outputs = random_apply(images) # Each image independently zeroed out or
# left unchanged
```
Output: A tensor with the same shape and dtype as the input, with the
transformation layer applied to selected inputs.
"""

def __init__(
Expand All @@ -150,59 +51,62 @@ def __init__(
self.auto_vectorize = auto_vectorize
self.batchwise = batchwise
self.seed = seed
self.generator = SeedGenerator(seed)
self.built = True
if K.backend() == "jax":
self.seed_generator = random.SeedGenerator(seed)

def _get_should_augment(self, inputs, seed=None):
input_shape = ops.shape(inputs)

if self.batchwise:
return self._rate > random.uniform(shape=(), seed=seed)

batch_size = input_shape[0]
random_values = random.uniform(shape=(batch_size,), seed=seed)
should_augment = random_values < self._rate
def get_random_transformation(self, data, training=True, seed=None):
if not training:
return None

ndims = len(inputs.shape)
ones = [1] * (ndims - 1)
broadcast_shape = tuple([batch_size] + ones)
if seed is None:
seed = self._get_seed_generator(self.backend._backend)

return ops.reshape(should_augment, broadcast_shape)
if isinstance(data, dict):
inputs = data["images"]
else:
inputs = data

def _augment_single(self, inputs, seed=None):
random_value = random.uniform(shape=(), seed=seed)
should_augment = random_value < self._rate
input_shape = self.backend.shape(inputs)
if self.batchwise:
should_augment = self._rate > self.backend.random.uniform(
shape=(), seed=seed
)
else:
batch_size = input_shape[0]
random_values = self.backend.random.uniform(
shape=(batch_size,), seed=seed
)
should_augment = random_values < self._rate

def apply_layer():
if hasattr(self._layer, "get_random_transformation"):
transformation = self._layer.get_random_transformation(
inputs, training=True, seed=seed
)
return self._layer.transform_images(
inputs, transformation, training=True
)
return self._layer(inputs)
ndims = len(input_shape)
ones = [1] * (ndims - 1)
broadcast_shape = tuple([batch_size] + ones)
should_augment = self.backend.numpy.reshape(
should_augment, broadcast_shape
)

def return_inputs():
return inputs
return {
"should_augment": should_augment,
"input_shape": input_shape,
}

return ops.cond(should_augment, apply_layer, return_inputs)
def transform_images(self, images, transformation, training=True):
if not training or transformation is None:
return images

def _augment_batch(self, inputs, seed=None):
should_augment = self._get_should_augment(inputs, seed=seed)
should_augment = transformation["should_augment"]

if hasattr(self._layer, "get_random_transformation"):
transformation = self._layer.get_random_transformation(
inputs, training=True, seed=seed
layer_transform = self._layer.get_random_transformation(
images, training=True
)
augmented = self._layer.transform_images(
inputs, transformation, training=True
images, layer_transform, training=True
)
else:
augmented = self._layer(inputs)
augmented = self._layer(images)

return ops.where(should_augment, augmented, inputs)
return self.backend.numpy.where(should_augment, augmented, images)

def call(self, inputs):
if isinstance(inputs, dict):
Expand All @@ -214,23 +118,8 @@ def call(self, inputs):
return self._call_single(inputs)

def _call_single(self, inputs):
inputs_rank = len(inputs.shape)
is_single_sample = ops.equal(inputs_rank, 3)
is_batch = ops.equal(inputs_rank, 4)

if K.backend() == "jax":
seed = self.seed_generator.next()
else:
seed = self.seed

def augment_single():
return self._augment_single(inputs, seed=seed)

def augment_batch():
return self._augment_batch(inputs, seed=seed)

condition = ops.logical_or(is_single_sample, is_batch)
return ops.cond(ops.all(condition), augment_batch, augment_single)
transformation = self.get_random_transformation(inputs, training=True)
return self.transform_images(inputs, transformation, training=True)

@staticmethod
def _num_zero_batches(images):
Expand All @@ -241,29 +130,22 @@ def _num_zero_batches(images):
num_non_zero_batches = ops.sum(ops.cast(any_nonzero, dtype="int32"))
return num_batches - num_non_zero_batches

def transform_images(self, images, transformation=None, training=True):
if not training:
return images
return self.call(images)

def transform_labels(self, labels, transformation=None, training=True):
if not training:
def transform_labels(self, labels, transformation, training=True):
if not training or transformation is None:
return labels
return self.call(labels)
return self.transform_images(labels, transformation, training)

def transform_bounding_boxes(
self, bboxes, transformation=None, training=True
):
if not training:
def transform_bounding_boxes(self, bboxes, transformation, training=True):
if not training or transformation is None:
return bboxes
return self.call(bboxes)
return self.transform_images(bboxes, transformation, training)

def transform_segmentation_masks(
self, masks, transformation=None, training=True
self, masks, transformation, training=True
):
if not training:
if not training or transformation is None:
return masks
return self.call(masks)
return self.transform_images(masks, transformation, training)

def get_config(self):
config = super().get_config()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,3 @@
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
from absl.testing import parameterized

Expand Down
Loading

0 comments on commit 9fd3bd2

Please sign in to comment.