Add cut_mix preprocessing layer
shashaka committed Jan 16, 2025
1 parent 25d6d80 commit f91cd84
Showing 5 changed files with 267 additions and 0 deletions.
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/layers/__init__.py
@@ -146,6 +146,7 @@
from keras.src.layers.preprocessing.image_preprocessing.center_crop import (
CenterCrop,
)
from keras.src.layers.preprocessing.image_preprocessing.cut_mix import CutMix
from keras.src.layers.preprocessing.image_preprocessing.equalization import (
Equalization,
)
1 change: 1 addition & 0 deletions keras/api/layers/__init__.py
@@ -146,6 +146,7 @@
from keras.src.layers.preprocessing.image_preprocessing.center_crop import (
CenterCrop,
)
from keras.src.layers.preprocessing.image_preprocessing.cut_mix import CutMix
from keras.src.layers.preprocessing.image_preprocessing.equalization import (
Equalization,
)
1 change: 1 addition & 0 deletions keras/src/layers/__init__.py
@@ -90,6 +90,7 @@
from keras.src.layers.preprocessing.image_preprocessing.center_crop import (
CenterCrop,
)
from keras.src.layers.preprocessing.image_preprocessing.cut_mix import CutMix
from keras.src.layers.preprocessing.image_preprocessing.equalization import (
Equalization,
)
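
The three __init__.py additions above re-export the new layer through the public API. A minimal import check, assuming a Keras build that includes this commit (not part of the diff itself):

from keras import layers

# Exposed as keras.layers.CutMix via the @keras_export decorator below.
layer = layers.CutMix(alpha=1.0, seed=42)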
186 changes: 186 additions & 0 deletions keras/src/layers/preprocessing/image_preprocessing/cut_mix.py
@@ -0,0 +1,186 @@
from keras.src.api_export import keras_export
from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501
BaseImagePreprocessingLayer,
)
from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import ( # noqa: E501
convert_format,
)
from keras.src.random import SeedGenerator


@keras_export("keras.layers.CutMix")
class CutMix(BaseImagePreprocessingLayer):
"""CutMix implements the CutMix data augmentation technique.
Args:
alpha: Float between 0 and 1. Inverse scale parameter for the gamma
distribution. This controls the shape of the distribution from which
the smoothing values are sampled. Defaults to 1.0, which is a
recommended value when training an imagenet1k classification model.
seed: Integer. Used to create a random seed.
References:
- [CutMix paper]( https://arxiv.org/abs/1905.04899).
"""

def __init__(self, alpha=1.0, seed=None, data_format=None, **kwargs):
super().__init__(data_format=data_format, **kwargs)
self.alpha = alpha
self.seed = seed
self.generator = SeedGenerator(seed)

def get_random_transformation(self, data, training=True, seed=None):
if not training:
return None

if isinstance(data, dict):
images = data["images"]
else:
images = data

images_shape = self.backend.shape(images)
if len(images_shape) == 3:
return None

batch_size = images_shape[0]
if self.data_format == "channels_first":
image_height = images_shape[-2]
image_width = images_shape[-1]
else:
image_height = images_shape[-3]
image_width = images_shape[-2]

seed = seed or self._get_seed_generator(self.backend._backend)

permutation_order = self.backend.random.shuffle(
self.backend.numpy.arange(0, batch_size, dtype="int64"),
seed=seed,
)

mix_weight = self.backend.random.beta(
(batch_size,), self.alpha, self.alpha, seed=seed
)

ratio = self.backend.numpy.sqrt(1 - mix_weight)
cut_height = self.backend.cast(
ratio * image_height, dtype=self.compute_dtype
)
cut_width = self.backend.cast(
ratio * image_width, dtype=self.compute_dtype
)

random_center_height = self.backend.random.uniform(
shape=[batch_size],
minval=0,
maxval=image_height,
dtype=self.compute_dtype,
)
random_center_width = self.backend.random.uniform(
shape=[batch_size],
minval=0,
maxval=image_width,
dtype=self.compute_dtype,
)

return {
"permutation_order": permutation_order,
"cut_height": cut_height,
"cut_width": cut_width,
"random_center_height": random_center_height,
"random_center_width": random_center_width,
"input_shape": (batch_size, image_height, image_width),
}

def transform_images(self, images, transformation=None, training=True):
if training:
if transformation is not None:
images = self._cut_mix(images, transformation)
images = self.backend.cast(images, self.compute_dtype)
return images

def _cut_mix(self, images, transformation):
def _axis_mask(starts, ends, mask_len, batch_size):
axis_indices = self.backend.numpy.arange(0, mask_len)
axis_indices = self.backend.numpy.expand_dims(axis_indices, 0)
axis_indices = self.backend.numpy.tile(
axis_indices, [batch_size, 1]
)

axis_mask = self.backend.numpy.greater_equal(
axis_indices, starts
) & self.backend.numpy.less(axis_indices, ends)
return axis_mask

def corners_to_mask(bounding_boxes, mask_shape):
batch_size, mask_height, mask_width = mask_shape
x0, y0, x1, y1 = self.backend.numpy.split(
bounding_boxes, 4, axis=-1
)

w_mask = _axis_mask(x0, x1, mask_width, batch_size)
h_mask = _axis_mask(y0, y1, mask_height, batch_size)

w_mask = self.backend.numpy.expand_dims(w_mask, axis=1)
h_mask = self.backend.numpy.expand_dims(h_mask, axis=2)
masks = self.backend.numpy.logical_and(w_mask, h_mask)
return masks

images = self.backend.cast(images, self.compute_dtype)

if self.data_format == "channels_first":
channel_axis = 1
else:
channel_axis = -1

permutation_order = transformation["permutation_order"]
random_center_width = transformation["random_center_width"]
random_center_height = transformation["random_center_height"]
cut_width = transformation["cut_width"]
cut_height = transformation["cut_height"]
input_shape = transformation["input_shape"]

xywh = self.backend.numpy.stack(
[random_center_width, random_center_height, cut_width, cut_height],
axis=1,
)
corners = convert_format(xywh, source="center_xywh", target="xyxy")
is_rectangle = corners_to_mask(corners, input_shape)
is_rectangle = self.backend.numpy.expand_dims(
is_rectangle, channel_axis
)

images = self.backend.numpy.where(
is_rectangle,
self.backend.numpy.take(images, permutation_order, axis=0),
images,
)
return images

def transform_labels(self, labels, transformation, training=True):
return labels

def transform_bounding_boxes(
self,
bounding_boxes,
transformation,
training=True,
):
raise NotImplementedError()

def transform_segmentation_masks(
self, segmentation_masks, transformation, training=True
):
return self.transform_images(
segmentation_masks, transformation, training
)

def compute_output_shape(self, input_shape):
return input_shape

def get_config(self):
config = {
"alpha": self.alpha,
"seed": self.seed,
}
base_config = super().get_config()
return {**base_config, **config}
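
For orientation, a minimal usage sketch of the layer added above. The shapes, seed, and values below are illustrative assumptions and are not part of the commit:

import numpy as np
from keras import layers

# A batch of 8 RGB images in channels_last layout (illustrative data).
images = np.random.uniform(0.0, 1.0, size=(8, 64, 64, 3)).astype("float32")

cut_mix = layers.CutMix(alpha=1.0, seed=7)

# Training mode: a rectangular patch of each image is replaced by the same
# region taken from another image in the batch (per permutation_order).
augmented = cut_mix(images, training=True)

# Inference mode: inputs pass through unchanged.
passthrough = cut_mix(images, training=False)

Note that in this version transform_labels returns the labels unchanged and transform_bounding_boxes raises NotImplementedError, so only images and segmentation masks are actually mixed.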
78 changes: 78 additions & 0 deletions keras/src/layers/preprocessing/image_preprocessing/cut_mix_test.py
@@ -0,0 +1,78 @@
import numpy as np
import pytest
from tensorflow import data as tf_data

from keras.src import backend
from keras.src import layers
from keras.src import testing


class CutMixTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_layer(self):
self.run_layer_test(
layers.CutMix,
init_kwargs={
"alpha": 1.0,
"seed": 1,
},
input_shape=(8, 3, 4, 3),
supports_masking=False,
expected_output_shape=(8, 3, 4, 3),
)

def test_cut_mix_inference(self):
seed = 3481
layer = layers.CutMix()

np.random.seed(seed)
inputs = np.random.randint(0, 255, size=(224, 224, 3))
output = layer(inputs, training=False)
self.assertAllClose(inputs, output)

def test_cut_mix_basic(self):
data_format = backend.config.image_data_format()
if data_format == "channels_last":
image1 = np.ones((2, 2, 1))
image2 = np.zeros((2, 2, 1))
inputs = np.asarray([image1, image2])
expected_output = np.asarray(
[
[[[1.0], [1.0]], [[1.0], [0.0]]],
[[[0.0], [0.0]], [[0.0], [1.0]]],
]
)

else:
image1 = np.ones((1, 2, 2))
image2 = np.zeros((1, 2, 2))
inputs = np.asarray([image1, image2])
expected_output = np.asarray(
[[[[1.0, 1.0], [1.0, 0.0]]], [[[0.0, 0.0], [0.0, 1.0]]]]
)

layer = layers.CutMix(data_format=data_format)

transformation = {
"cut_height": [1.0, 1.0],
"cut_width": [1.0, 1.0],
"input_shape": (2, 2, 2),
"permutation_order": [1, 0],
"random_center_height": [1.0, 1.0],
"random_center_width": [1.0, 1.0],
}
output = layer.transform_images(inputs, transformation)

self.assertAllClose(expected_output, output)

def test_tf_data_compatibility(self):
data_format = backend.config.image_data_format()
if data_format == "channels_last":
input_data = np.random.random((2, 8, 8, 3))
else:
input_data = np.random.random((2, 3, 8, 8))
layer = layers.CutMix(data_format=data_format)

ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)
for output in ds.take(1):
output.numpy()
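
As a sanity check on the geometry in get_random_transformation above (a sketch, not part of the commit): because both cut_height and cut_width are scaled by sqrt(1 - mix_weight), the cut rectangle covers roughly a (1 - mix_weight) fraction of the image area before any clipping at the borders.

import numpy as np

rng = np.random.default_rng(0)
height, width, alpha = 224, 224, 1.0

mix_weight = rng.beta(alpha, alpha)  # same Beta(alpha, alpha) draw the layer uses
ratio = np.sqrt(1.0 - mix_weight)
cut_height, cut_width = ratio * height, ratio * width

# Fraction of the image replaced by the paired image (ignoring border clipping).
print((cut_height * cut_width) / (height * width))  # ~= 1 - mix_weight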
