-
Notifications
You must be signed in to change notification settings - Fork 19.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Migrate RandomChoice and RandomApply layers from keras-cv to keras
- Added RandomChoice and RandomApply layers to keras.layers.preprocessing.image_preprocessing. - Implemented necessary transform methods (transform_images, transform_labels, transform_bounding_boxes, transform_segmentation_masks) to comply with BaseImagePreprocessingLayer. - Updated tests to ensure compatibility with keras and fix failing test cases. - Added support for batchwise processing, auto-vectorization, and random seed control.
- Loading branch information
1 parent
e010829
commit 40c285e
Showing
6 changed files
with
586 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
192 changes: 192 additions & 0 deletions
192
keras/src/layers/preprocessing/image_preprocessing/random_apply.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,192 @@ | ||
# Copyright 2022 The KerasCV Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from keras.src.api_export import keras_export | ||
from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 | ||
BaseImagePreprocessingLayer, | ||
) | ||
import tensorflow as tf | ||
|
||
@keras_export("keras.layers.RandomApply") | ||
class RandomApply(BaseImagePreprocessingLayer): | ||
"""Apply provided layer to random elements in a batch. | ||
Args: | ||
layer: a keras `Layer` or `BaseImagePreprocessingLayer`. This layer will | ||
be applied to randomly chosen samples in a batch. Layer should not | ||
modify the size of provided inputs. | ||
rate: controls the frequency of applying the layer. 1.0 means all | ||
elements in a batch will be modified. 0.0 means no elements will be | ||
modified. Defaults to 0.5. | ||
batchwise: (Optional) bool, whether to pass entire batches to the | ||
underlying layer. When set to true, only a single random sample is | ||
drawn to determine if the batch should be passed to the underlying | ||
layer. | ||
auto_vectorize: bool, whether to use tf.vectorized_map or tf.map_fn for | ||
batched input. Setting this to True might give better performance | ||
but currently doesn't work with XLA. Defaults to False. | ||
seed: integer, controls random behaviour. | ||
Example: | ||
``` | ||
# Let's declare an example layer that will set all image pixels to zero. | ||
zero_out = keras.layers.Lambda(lambda x: {"images": 0 * x["images"]}) | ||
# Create a small batch of random, single-channel, 2x2 images: | ||
images = tf.random.stateless_uniform(shape=(5, 2, 2, 1), seed=[0, 1]) | ||
print(images[..., 0]) | ||
# <tf.Tensor: shape=(5, 2, 2), dtype=float32, numpy= | ||
# array([[[0.08216608, 0.40928006], | ||
# [0.39318466, 0.3162533 ]], | ||
# | ||
# [[0.34717774, 0.73199546], | ||
# [0.56369007, 0.9769211 ]], | ||
# | ||
# [[0.55243933, 0.13101244], | ||
# [0.2941643 , 0.5130266 ]], | ||
# | ||
# [[0.38977218, 0.80855536], | ||
# [0.6040567 , 0.10502195]], | ||
# | ||
# [[0.51828027, 0.12730157], | ||
# [0.288486 , 0.252975 ]]], dtype=float32)> | ||
# Apply the layer with 50% probability: | ||
random_apply = RandomApply(layer=zero_out, rate=0.5, seed=1234) | ||
outputs = random_apply(images) | ||
print(outputs[..., 0]) | ||
# <tf.Tensor: shape=(5, 2, 2), dtype=float32, numpy= | ||
# array([[[0. , 0. ], | ||
# [0. , 0. ]], | ||
# | ||
# [[0.34717774, 0.73199546], | ||
# [0.56369007, 0.9769211 ]], | ||
# | ||
# [[0.55243933, 0.13101244], | ||
# [0.2941643 , 0.5130266 ]], | ||
# | ||
# [[0.38977218, 0.80855536], | ||
# [0.6040567 , 0.10502195]], | ||
# | ||
# [[0. , 0. ], | ||
# [0. , 0. ]]], dtype=float32)> | ||
# We can observe that the layer has been randomly applied to 2 out of 5 | ||
samples. | ||
``` | ||
""" | ||
def __init__( | ||
self, | ||
layer, | ||
rate=0.5, | ||
batchwise=False, | ||
auto_vectorize=False, | ||
seed=None, | ||
**kwargs, | ||
): | ||
super().__init__(seed=seed, **kwargs) | ||
if not (0 <= rate <= 1.0): | ||
raise ValueError(f"rate must be in range [0, 1]. Received rate: {rate}") | ||
self._layer = layer | ||
self._rate = rate | ||
self.auto_vectorize = auto_vectorize | ||
self.batchwise = batchwise | ||
self.seed = seed | ||
self.built = True | ||
|
||
def _get_should_augment(self, inputs): | ||
input_shape = tf.shape(inputs) | ||
|
||
if self.batchwise: | ||
return self._rate > tf.random.uniform(shape=(), seed=self.seed) | ||
|
||
batch_size = input_shape[0] | ||
random_values = tf.random.uniform(shape=(batch_size,), seed=self.seed) | ||
should_augment = random_values < self._rate | ||
|
||
ndims = tf.rank(inputs) | ||
broadcast_shape = tf.concat( | ||
[input_shape[:1], tf.ones(ndims - 1, dtype=tf.int32)], | ||
axis=0 | ||
) | ||
return tf.reshape(should_augment, broadcast_shape) | ||
|
||
def _augment_single(self, inputs): | ||
random_value = tf.random.uniform(shape=(), seed=self.seed) | ||
should_augment = random_value < self._rate | ||
|
||
def apply_layer(): | ||
return self._layer(inputs) | ||
|
||
def return_inputs(): | ||
return inputs | ||
|
||
return tf.cond(should_augment, apply_layer, return_inputs) | ||
|
||
def _augment_batch(self, inputs): | ||
should_augment = self._get_should_augment(inputs) | ||
augmented = self._layer(inputs) | ||
return tf.where(should_augment, augmented, inputs) | ||
|
||
def call(self, inputs): | ||
if isinstance(inputs, dict): | ||
return {key: self._call_single(input_tensor) for | ||
key, input_tensor in inputs.items()} | ||
else: | ||
return self._call_single(inputs) | ||
|
||
def _call_single(self, inputs): | ||
inputs_rank = tf.rank(inputs) | ||
is_single_sample = tf.equal(inputs_rank, 3) | ||
is_batch = tf.equal(inputs_rank, 4) | ||
|
||
def augment_single(): | ||
return self._augment_single(inputs) | ||
|
||
def augment_batch(): | ||
return self._augment_batch(inputs) | ||
|
||
condition = tf.logical_or(is_single_sample, is_batch) | ||
return tf.cond(tf.reduce_all(condition), augment_batch, augment_single) | ||
|
||
def transform_images(self, images, transformation=None, training=True): | ||
if not training: | ||
return images | ||
return self.call(images) | ||
|
||
def transform_labels(self, labels, transformation=None, training=True): | ||
if not training: | ||
return labels | ||
return self.call(labels) | ||
|
||
def transform_bounding_boxes(self, bboxes, transformation=None, training=True): | ||
if not training: | ||
return bboxes | ||
return self.call(bboxes) | ||
|
||
def transform_segmentation_masks(self, masks, transformation=None, training=True): | ||
if not training: | ||
return masks | ||
return self.call(masks) | ||
|
||
def get_config(self): | ||
config = super().get_config() | ||
config.update({ | ||
"rate": self._rate, | ||
"layer": self._layer, | ||
"seed": self.seed, | ||
"batchwise": self.batchwise, | ||
"auto_vectorize": self.auto_vectorize, | ||
}) | ||
return config |
124 changes: 124 additions & 0 deletions
124
keras/src/layers/preprocessing/image_preprocessing/random_apply_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
# Copyright 2022 The KerasCV Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import pytest | ||
import tensorflow as tf | ||
from absl.testing import parameterized | ||
|
||
from keras.src import layers | ||
from keras.src import ops | ||
from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 | ||
BaseImagePreprocessingLayer, | ||
) | ||
from keras.src.testing import TestCase | ||
|
||
|
||
class ZeroOut(BaseImagePreprocessingLayer): | ||
"""Layer that zeros out tensors.""" | ||
|
||
def __init__(self, **kwargs): | ||
super().__init__(**kwargs) | ||
self.built = True | ||
|
||
def call(self, inputs): | ||
return tf.zeros_like(inputs) | ||
|
||
def transform_images(self, images, transformation=None, training=True): | ||
return tf.zeros_like(images) | ||
|
||
def transform_segmentation_masks(self, masks, transformation=None, training=True): | ||
return tf.zeros_like(masks) | ||
|
||
def transform_bounding_boxes(self, bboxes, transformation=None, training=True): | ||
return tf.zeros_like(bboxes) | ||
|
||
def transform_labels(self, labels, transformation=None, training=True): | ||
return tf.zeros_like(labels) | ||
|
||
def get_config(self): | ||
return super().get_config() | ||
|
||
|
||
class RandomApplyTest(TestCase): | ||
rng = tf.random.Generator.from_seed(seed=1234) | ||
|
||
@parameterized.parameters([-0.5, 1.7]) | ||
def test_raises_error_on_invalid_rate_parameter(self, invalid_rate): | ||
with self.assertRaises(ValueError): | ||
layers.RandomApply(rate=invalid_rate, layer=ZeroOut()) | ||
|
||
def test_works_with_batched_input(self): | ||
batch_size = 32 | ||
dummy_inputs = self.rng.uniform(shape=(batch_size, 224, 224, 3)) | ||
layer = layers.RandomApply(rate=0.5, layer=ZeroOut(), seed=1234) | ||
|
||
outputs = ops.convert_to_numpy(layer(dummy_inputs)) | ||
num_zero_inputs = self._num_zero_batches(dummy_inputs) | ||
num_zero_outputs = self._num_zero_batches(outputs) | ||
|
||
self.assertEqual(num_zero_inputs, 0) | ||
self.assertLess(num_zero_outputs, batch_size) | ||
self.assertGreater(num_zero_outputs, 0) | ||
|
||
def test_works_with_batchwise_layers(self): | ||
batch_size = 32 | ||
dummy_inputs = self.rng.uniform(shape=(batch_size, 224, 224, 3)) | ||
random_flip_layer = layers.RandomFlip("vertical", data_format="channels_last", seed=42) | ||
layer = layers.RandomApply(random_flip_layer, rate=0.5, batchwise=True) | ||
outputs = layer(dummy_inputs) | ||
self.assertEqual(outputs.shape, dummy_inputs.shape) | ||
|
||
@staticmethod | ||
def _num_zero_batches(images): | ||
num_batches = tf.shape(images)[0] | ||
num_non_zero_batches = tf.math.count_nonzero( | ||
tf.math.count_nonzero(images, axis=[1, 2, 3]), dtype=tf.int32 | ||
) | ||
return num_batches - num_non_zero_batches | ||
|
||
def test_inputs_unchanged_with_zero_rate(self): | ||
dummy_inputs = self.rng.uniform(shape=(32, 224, 224, 3)) | ||
layer = layers.RandomApply(rate=0.0, layer=ZeroOut()) | ||
|
||
outputs = layer(dummy_inputs) | ||
|
||
self.assertAllClose(outputs, dummy_inputs) | ||
|
||
def test_all_inputs_changed_with_rate_equal_to_one(self): | ||
dummy_inputs = self.rng.uniform(shape=(32, 224, 224, 3)) | ||
layer = layers.RandomApply(rate=1.0, layer=ZeroOut()) | ||
outputs = layer(dummy_inputs) | ||
tf.reduce_all(tf.equal(outputs, tf.zeros_like(dummy_inputs))) | ||
|
||
def test_works_with_single_image(self): | ||
dummy_inputs = self.rng.uniform(shape=(224, 224, 3)) | ||
layer = layers.RandomApply(rate=1.0, layer=ZeroOut()) | ||
outputs = layer(dummy_inputs) | ||
tf.reduce_all(tf.equal(outputs, tf.zeros_like(dummy_inputs))) | ||
|
||
def test_can_modify_label(self): | ||
dummy_inputs = self.rng.uniform(shape=(32, 224, 224, 3)) | ||
dummy_labels = tf.ones(shape=(32, 2)) | ||
layer = layers.RandomApply(rate=1.0, layer=ZeroOut()) | ||
outputs = layer({"images": dummy_inputs, "labels": dummy_labels}) | ||
tf.reduce_all(tf.equal(outputs["labels"], tf.zeros_like(dummy_labels))) | ||
|
||
def test_works_with_xla(self): | ||
dummy_inputs = self.rng.uniform(shape=(32, 224, 224, 3)) | ||
layer = layers.RandomApply(rate=0.5, layer=ZeroOut(), auto_vectorize=False) | ||
|
||
@tf.function(jit_compile=True) | ||
def apply(x): | ||
return layer(x) | ||
|
||
outputs = apply(dummy_inputs) |
Oops, something went wrong.