diff --git a/keras_aug/_src/layers/vision/trivial_augment.py b/keras_aug/_src/layers/vision/trivial_augment.py index 22e274f..4c688c9 100644 --- a/keras_aug/_src/layers/vision/trivial_augment.py +++ b/keras_aug/_src/layers/vision/trivial_augment.py @@ -23,6 +23,7 @@ class TrivialAugmentWide(VisionRandomLayer): - [TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation](https://arxiv.org/abs/2103.10158) Args: + p: A float specifying the probability. Defaults to `1.0`. num_magnitude_bins: The number of different magnitude values. Defaults to `31`. geometric: Whether to include geometric augmentations. This @@ -46,6 +47,7 @@ class TrivialAugmentWide(VisionRandomLayer): def __init__( self, + p: float = 1.0, num_magnitude_bins: int = 31, geometric: bool = True, interpolation: str = "bilinear", @@ -56,6 +58,7 @@ def __init__( **kwargs, ): super().__init__(**kwargs) + self.p = float(p) self.num_magnitude_bins = int(num_magnitude_bins) self.geometric = bool(geometric) self.interpolation = standardize_interpolation(interpolation) @@ -106,7 +109,7 @@ def __init__( ) p = [1.0] * len(self.augmentation_space) total = sum(p) - self.p = [prob / total for prob in p] + self.fn_idx_p = [prob / total for prob in p] def compute_output_shape(self, input_shape): return input_shape @@ -115,8 +118,9 @@ def get_params(self, batch_size, images=None, **kwargs): ops = self.backend random_generator = self.random_generator + p = ops.random.uniform([batch_size], seed=random_generator) magnitude = ops.random.randint([batch_size], 0, self.num_magnitude_bins) - fn_idx_p = ops.convert_to_tensor([self.p]) + fn_idx_p = ops.convert_to_tensor([self.fn_idx_p]) fn_idx = ops.random.categorical( ops.numpy.log(fn_idx_p), 1, seed=random_generator ) @@ -124,6 +128,7 @@ def get_params(self, batch_size, images=None, **kwargs): signed_p = ops.random.uniform([batch_size]) > 0.5 signed = ops.cast(ops.numpy.where(signed_p, 1.0, -1.0), dtype="float32") return dict( + p=p, # shape: (batch_size,) magnitude=magnitude, # shape: (batch_size,) fn_idx=fn_idx, # shape: (1,) signed=signed, # shape: (batch_size,) @@ -301,17 +306,25 @@ def _apply_images_transform(self, images, magnitude, idx, signed): return images def augment_images(self, images, transformations, **kwargs): + ops = self.backend + + p = transformations["p"] magnitude = transformations["magnitude"] fn_idx = transformations["fn_idx"][0] signed = transformations["signed"] - images = self._apply_images_transform(images, magnitude, fn_idx, signed) + prob = ops.numpy.expand_dims(p < self.p, axis=[1, 2, 3]) + images = ops.numpy.where( + prob, + self._apply_images_transform(images, magnitude, fn_idx, signed), + images, + ) return images def augment_labels(self, labels, transformations, **kwargs): return labels def _apply_bounding_boxes_transform( - self, bounding_boxes, height, width, magnitude, idx, signed + self, bounding_boxes, height, width, p, magnitude, idx, signed ): ops = self.backend @@ -423,7 +436,11 @@ def _apply_bounding_boxes_transform( width=width, ) ) - boxes = ops.core.switch(idx, transforms, bounding_boxes["boxes"]) + prob = ops.numpy.expand_dims(p < self.p, axis=[1, 2]) + boxes = bounding_boxes["boxes"] + boxes = ops.numpy.where( + prob, ops.core.switch(idx, transforms, boxes), boxes + ) bounding_boxes = bounding_boxes.copy() bounding_boxes["boxes"] = boxes return bounding_boxes @@ -443,6 +460,7 @@ def augment_bounding_boxes( ) ops = self.backend + p = transformations["p"] magnitude = transformations["magnitude"] fn_idx = transformations["fn_idx"][0] signed = transformations["signed"] @@ -458,7 +476,7 @@ def augment_bounding_boxes( dtype=self.bounding_box_dtype, ) bounding_boxes = self._apply_bounding_boxes_transform( - bounding_boxes, height, width, magnitude, fn_idx, signed + bounding_boxes, height, width, p, magnitude, fn_idx, signed ) bounding_boxes = self.bbox_backend.clip_to_images( bounding_boxes, @@ -480,6 +498,7 @@ def get_config(self): config = super().get_config() config.update( { + "p": self.p, "num_magnitude_bins": self.num_magnitude_bins, "geometric": self.geometric, "interpolation": self.interpolation, diff --git a/keras_aug/_src/layers/vision/trivial_augment_test.py b/keras_aug/_src/layers/vision/trivial_augment_test.py index 9b2a12a..caa9398 100644 --- a/keras_aug/_src/layers/vision/trivial_augment_test.py +++ b/keras_aug/_src/layers/vision/trivial_augment_test.py @@ -55,6 +55,16 @@ def test_config(self): y2 = layer(x) self.assertEqual(y.shape, y2.shape) + # Test `p=0.0` + layer = TrivialAugmentWide(p=0.0) + y = layer(x) + + layer = TrivialAugmentWide.from_config(layer.get_config()) + y2 = layer(x) + self.assertAllClose(y, x) + self.assertAllClose(y2, x) + self.assertEqual(y.shape, y2.shape) + def test_tf_data_compatibility(self): import tensorflow as tf