Skip to content

Commit

Permalink
Add p to TrivialAugmentWide
Browse files Browse the repository at this point in the history
  • Loading branch information
james77777778 committed Aug 9, 2024
1 parent f30d2da commit 681c0c3
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 6 deletions.
31 changes: 25 additions & 6 deletions keras_aug/_src/layers/vision/trivial_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class TrivialAugmentWide(VisionRandomLayer):
- [TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation](https://arxiv.org/abs/2103.10158)
Args:
p: A float specifying the probability of applying the augmentation to each image. Defaults to `1.0`.
num_magnitude_bins: The number of different magnitude values. Defaults
to `31`.
geometric: Whether to include geometric augmentations. This
Expand All @@ -46,6 +47,7 @@ class TrivialAugmentWide(VisionRandomLayer):

def __init__(
self,
p: float = 1.0,
num_magnitude_bins: int = 31,
geometric: bool = True,
interpolation: str = "bilinear",
Expand All @@ -56,6 +58,7 @@ def __init__(
**kwargs,
):
super().__init__(**kwargs)
self.p = float(p)
self.num_magnitude_bins = int(num_magnitude_bins)
self.geometric = bool(geometric)
self.interpolation = standardize_interpolation(interpolation)
Expand Down Expand Up @@ -106,7 +109,7 @@ def __init__(
)
p = [1.0] * len(self.augmentation_space)
total = sum(p)
self.p = [prob / total for prob in p]
self.fn_idx_p = [prob / total for prob in p]

# Shape hook: the output shape is reported as identical to `input_shape`;
# the argument is returned as-is.
def compute_output_shape(self, input_shape):
return input_shape
Expand All @@ -115,15 +118,17 @@ def get_params(self, batch_size, images=None, **kwargs):
ops = self.backend
random_generator = self.random_generator

p = ops.random.uniform([batch_size], seed=random_generator)
magnitude = ops.random.randint([batch_size], 0, self.num_magnitude_bins)
fn_idx_p = ops.convert_to_tensor([self.p])
fn_idx_p = ops.convert_to_tensor([self.fn_idx_p])
fn_idx = ops.random.categorical(
ops.numpy.log(fn_idx_p), 1, seed=random_generator
)
fn_idx = fn_idx[0]
signed_p = ops.random.uniform([batch_size]) > 0.5
signed = ops.cast(ops.numpy.where(signed_p, 1.0, -1.0), dtype="float32")
return dict(
p=p, # shape: (batch_size,)
magnitude=magnitude, # shape: (batch_size,)
fn_idx=fn_idx, # shape: (1,)
signed=signed, # shape: (batch_size,)
Expand Down Expand Up @@ -301,17 +306,25 @@ def _apply_images_transform(self, images, magnitude, idx, signed):
return images

def augment_images(self, images, transformations, **kwargs):
ops = self.backend

p = transformations["p"]
magnitude = transformations["magnitude"]
fn_idx = transformations["fn_idx"][0]
signed = transformations["signed"]
images = self._apply_images_transform(images, magnitude, fn_idx, signed)
prob = ops.numpy.expand_dims(p < self.p, axis=[1, 2, 3])
images = ops.numpy.where(
prob,
self._apply_images_transform(images, magnitude, fn_idx, signed),
images,
)
return images

# Label hook: returns `labels` untouched; `transformations` and `kwargs`
# are accepted but not read here.
def augment_labels(self, labels, transformations, **kwargs):
return labels

def _apply_bounding_boxes_transform(
self, bounding_boxes, height, width, magnitude, idx, signed
self, bounding_boxes, height, width, p, magnitude, idx, signed
):
ops = self.backend

Expand Down Expand Up @@ -423,7 +436,11 @@ def _apply_bounding_boxes_transform(
width=width,
)
)
boxes = ops.core.switch(idx, transforms, bounding_boxes["boxes"])
prob = ops.numpy.expand_dims(p < self.p, axis=[1, 2])
boxes = bounding_boxes["boxes"]
boxes = ops.numpy.where(

Check warning on line 441 in keras_aug/_src/layers/vision/trivial_augment.py

View check run for this annotation

Codecov / codecov/patch

keras_aug/_src/layers/vision/trivial_augment.py#L439-L441

Added lines #L439 - L441 were not covered by tests
prob, ops.core.switch(idx, transforms, boxes), boxes
)
bounding_boxes = bounding_boxes.copy()
bounding_boxes["boxes"] = boxes
return bounding_boxes
Expand All @@ -443,6 +460,7 @@ def augment_bounding_boxes(
)
ops = self.backend

p = transformations["p"]

Check warning on line 463 in keras_aug/_src/layers/vision/trivial_augment.py

View check run for this annotation

Codecov / codecov/patch

keras_aug/_src/layers/vision/trivial_augment.py#L463

Added line #L463 was not covered by tests
magnitude = transformations["magnitude"]
fn_idx = transformations["fn_idx"][0]
signed = transformations["signed"]
Expand All @@ -458,7 +476,7 @@ def augment_bounding_boxes(
dtype=self.bounding_box_dtype,
)
bounding_boxes = self._apply_bounding_boxes_transform(
bounding_boxes, height, width, magnitude, fn_idx, signed
bounding_boxes, height, width, p, magnitude, fn_idx, signed
)
bounding_boxes = self.bbox_backend.clip_to_images(
bounding_boxes,
Expand All @@ -480,6 +498,7 @@ def get_config(self):
config = super().get_config()
config.update(
{
"p": self.p,
"num_magnitude_bins": self.num_magnitude_bins,
"geometric": self.geometric,
"interpolation": self.interpolation,
Expand Down
10 changes: 10 additions & 0 deletions keras_aug/_src/layers/vision/trivial_augment_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,16 @@ def test_config(self):
y2 = layer(x)
self.assertEqual(y.shape, y2.shape)

# Test `p=0.0`
layer = TrivialAugmentWide(p=0.0)
y = layer(x)

layer = TrivialAugmentWide.from_config(layer.get_config())
y2 = layer(x)
self.assertAllClose(y, x)
self.assertAllClose(y2, x)
self.assertEqual(y.shape, y2.shape)

def test_tf_data_compatibility(self):
import tensorflow as tf

Expand Down

0 comments on commit 681c0c3

Please sign in to comment.