Skip to content

Commit

Permalink
Introduce TestCase and ensure mixed precision works with all layers (
Browse files Browse the repository at this point in the history
…#137)

* Add TestCase and ensure mixed precision support

* Update version

* Fix CI

* Add checks

* Update tests

* Fix RandAugment and TrivialAugmentWide
  • Loading branch information
james77777778 authored Aug 12, 2024
1 parent 5fe4414 commit e1a01b6
Show file tree
Hide file tree
Showing 46 changed files with 503 additions and 650 deletions.
2 changes: 1 addition & 1 deletion keras_aug/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@
from keras_aug import visualization
from keras_aug._src.version import version

__version__ = "1.1.0"
__version__ = "1.1.1"
4 changes: 2 additions & 2 deletions keras_aug/_src/backend/bounding_box_test.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import numpy as np
from absl.testing import parameterized
from keras.src import testing
from keras.src.testing.test_utils import named_product

from keras_aug._src.backend.bounding_box import BoundingBoxBackend
from keras_aug._src.testing.test_case import TestCase


class BoundingBoxBackendTest(testing.TestCase, parameterized.TestCase):
class BoundingBoxBackendTest(TestCase):
size = 1000.0
xyxy_box = np.array([[[10, 20, 110, 120], [20, 30, 120, 130]]], "float32")
yxyx_box = np.array([[[20, 10, 120, 110], [30, 20, 130, 120]]], "float32")
Expand Down
17 changes: 12 additions & 5 deletions keras_aug/_src/backend/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,10 @@ def pad(
def adjust_brightness(self, images, factor):
ops = self.backend
images = ops.convert_to_tensor(images)
factor = ops.convert_to_tensor(factor)
original_dtype = backend.standardize_dtype(images.dtype)
is_float_inputs = backend.is_float_dtype(original_dtype)
compute_dtype = backend.result_type(original_dtype, float)
factor = ops.convert_to_tensor(factor, compute_dtype)
max_value = self._max_value_of_dtype(original_dtype)
if len(ops.shape(factor)) == 1:
factor = ops.numpy.expand_dims(factor, [1, 2, 3])
Expand All @@ -154,9 +155,10 @@ def adjust_contrast(self, images, factor, data_format=None):

ops = self.backend
images = ops.convert_to_tensor(images)
factor = ops.convert_to_tensor(factor)
original_dtype = backend.standardize_dtype(images.dtype)
is_float_inputs = backend.is_float_dtype(original_dtype)
compute_dtype = backend.result_type(original_dtype, float)
factor = ops.convert_to_tensor(factor, compute_dtype)
if len(ops.shape(factor)) == 1:
factor = ops.numpy.expand_dims(factor, [1, 2, 3])

Expand All @@ -172,9 +174,10 @@ def adjust_saturation(self, images, factor, data_format=None):

ops = self.backend
images = ops.convert_to_tensor(images)
factor = ops.convert_to_tensor(factor)
original_dtype = backend.standardize_dtype(images.dtype)
is_float_inputs = backend.is_float_dtype(original_dtype)
compute_dtype = backend.result_type(original_dtype, float)
factor = ops.convert_to_tensor(factor, compute_dtype)
if len(ops.shape(factor)) == 1:
factor = ops.numpy.expand_dims(factor, [1, 2, 3])

Expand All @@ -190,8 +193,9 @@ def adjust_hue(self, images, factor, data_format=None):

ops = self.backend
images = ops.convert_to_tensor(images)
factor = ops.convert_to_tensor(factor)
original_dtype = backend.standardize_dtype(images.dtype)
compute_dtype = backend.result_type(original_dtype, float)
factor = ops.convert_to_tensor(factor, compute_dtype)
max_value = self._max_value_of_dtype(original_dtype)
if len(ops.shape(factor)) == 1:
factor = ops.numpy.expand_dims(factor, [1, 2, 3])
Expand Down Expand Up @@ -529,6 +533,9 @@ def posterize_int(images):
mask, ops.numpy.power(2, dtype_bits - bits)
)
mask = ops.cast(mask, images.dtype)
if len(ops.shape(mask)) == 0:
mask = ops.numpy.expand_dims(mask, axis=0)
mask = ops.numpy.expand_dims(mask, axis=(1, 2, 3))
return images & mask

if backend.is_float_dtype(dtype):
Expand Down Expand Up @@ -611,7 +618,7 @@ def solarize(self, images, threshold):
threshold = ops.numpy.expand_dims(threshold, axis=0)
threshold = ops.numpy.expand_dims(threshold, axis=(1, 2, 3))
images = ops.numpy.where(
images >= ops.cast(threshold, images.dtype),
ops.numpy.greater_equal(images, threshold),
self.invert(images),
images,
)
Expand Down
38 changes: 8 additions & 30 deletions keras_aug/_src/backend/image_test.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,15 @@
from absl.testing import parameterized
from keras import backend
from keras import ops
from keras.src import testing
from keras.src.testing.test_utils import named_product

from keras_aug._src.backend.image import ImageBackend
from keras_aug._src.testing.test_case import TestCase
from keras_aug._src.utils.test_utils import get_images
from keras_aug._src.utils.test_utils import uses_gpu


class ImageBackendTest(testing.TestCase, parameterized.TestCase):
def setUp(self):
# Defaults to channels_last
self.data_format = backend.image_data_format()
backend.set_image_data_format("channels_last")
return super().setUp()

def tearDown(self) -> None:
backend.set_image_data_format(self.data_format)
return super().tearDown()

class ImageBackendTest(TestCase):
def test_crop(self):
image_backend = ImageBackend()

Expand Down Expand Up @@ -75,7 +65,6 @@ def test_adjust_brightness(self, dtype):
y = image_backend.adjust_brightness(x, 0.5)

ref_y = TF.adjust_brightness(torch.tensor(x), 0.5)
ref_y = ref_y.cpu().numpy()
self.assertAllClose(y, ref_y)
self.assertDType(y, dtype)

Expand All @@ -89,7 +78,6 @@ def test_adjust_contrast(self, dtype):
y = image_backend.adjust_contrast(x, 0.5, "channels_first")

ref_y = TF.adjust_contrast(torch.tensor(x), 0.5)
ref_y = ref_y.cpu().numpy()
self.assertAllClose(y, ref_y)
self.assertDType(y, dtype)

Expand All @@ -103,22 +91,23 @@ def test_adjust_hue(self, dtype):
y = image_backend.adjust_hue(x, 0.5, "channels_first")

ref_y = TF.adjust_hue(torch.tensor(x), 0.5)
ref_y = ref_y.cpu().numpy()
self.assertAllClose(y, ref_y)
self.assertDType(y, dtype)

@parameterized.named_parameters(named_product(dtype=["float32", "uint8"]))
def test_adjust_saturation(self, dtype):
import torch
import torchvision.transforms.v2.functional as TF
from keras.src.backend.torch import convert_to_tensor

atol = 2 if dtype == "uint8" else 1e-6
rtol = 2 if dtype == "uint8" else 1e-6

image_backend = ImageBackend()
x = get_images(dtype, "channels_first")
y = image_backend.adjust_saturation(x, 0.5, "channels_first")

ref_y = TF.adjust_saturation(torch.tensor(x), 0.5)
ref_y = ref_y.cpu().numpy()
self.assertAllClose(y, ref_y)
ref_y = TF.adjust_saturation(convert_to_tensor(x), 0.5)
self.assertAllClose(y, ref_y, atol=atol, rtol=rtol)
self.assertDType(y, dtype)

@parameterized.named_parameters(
Expand Down Expand Up @@ -169,7 +158,6 @@ def test_affine(self, dtype, args, interpolation):
[shear_x, shear_y],
torch_interpolation,
)
ref_y = ref_y.cpu().numpy()
# TODO: Test uint8
if dtype != "uint8":
# TODO: Investigate these parameters
Expand Down Expand Up @@ -198,7 +186,6 @@ def test_auto_contrast(self, dtype):
y = image_backend.auto_contrast(x, "channels_first")

ref_y = TF.autocontrast(torch.tensor(x))
ref_y = ref_y.cpu().numpy()
self.assertAllClose(y, ref_y, atol=atol)
self.assertDType(y, dtype)

Expand All @@ -213,7 +200,6 @@ def test_blend(self, dtype):
y = image_backend.blend(x1, x2, 0.5)

ref_y = TF._color._blend(torch.tensor(x1), torch.tensor(x2), 0.5)
ref_y = ref_y.cpu().numpy()
self.assertAllClose(y, ref_y)
self.assertDType(y, dtype)

Expand All @@ -232,7 +218,6 @@ def test_equalize(self, dtype):
y = image_backend.equalize(x, data_format="channels_first")

ref_y = TF.equalize(torch.tensor(x))
ref_y = ref_y.cpu().numpy()
self.assertAllClose(y, ref_y, atol=atol)
self.assertDType(y, dtype)

Expand All @@ -254,7 +239,6 @@ def test_guassian_blur(self, dtype):
)

ref_y = TF.gaussian_blur(torch.tensor(x), (3, 3), (0.1, 0.1))
ref_y = ref_y.cpu().numpy()
self.assertAllClose(y, ref_y, atol=atol)
self.assertDType(y, dtype)

Expand All @@ -268,15 +252,13 @@ def test_rgb_to_grayscale(self, dtype):
y = image_backend.rgb_to_grayscale(x, 3, "channels_first")

ref_y = TF.rgb_to_grayscale(torch.tensor(x), 3)
ref_y = ref_y.cpu().numpy()
self.assertAllClose(y, ref_y)
self.assertDType(y, dtype)

# Test ops.image.rgb_to_grayscale
y = ops.image.rgb_to_grayscale(x, "channels_first")

ref_y = TF.rgb_to_grayscale(torch.tensor(x))
ref_y = ref_y.cpu().numpy()
self.assertAllClose(y, ref_y)
self.assertDType(y, dtype)

Expand All @@ -292,7 +274,6 @@ def test_invert(self, dtype):
y = image_backend.invert(x)

ref_y = TF.invert(torch.tensor(x))
ref_y = ref_y.cpu().numpy()
self.assertAllClose(y, ref_y)
self.assertDType(y, dtype)

Expand All @@ -306,7 +287,6 @@ def test_posterize(self, dtype):
y = image_backend.posterize(x, bits=3)

ref_y = TF.posterize(torch.tensor(x), bits=3)
ref_y = ref_y.cpu().numpy()
self.assertAllClose(y, ref_y)
self.assertDType(y, dtype)

Expand All @@ -322,7 +302,6 @@ def test_sharpen(self, dtype):
y = image_backend.sharpen(x, 0.5, "channels_first")

ref_y = TF.adjust_sharpness(torch.tensor(x), 0.5)
ref_y = ref_y.cpu().numpy()
self.assertAllClose(y, ref_y)
self.assertDType(y, dtype)

Expand All @@ -340,6 +319,5 @@ def test_solarize(self, dtype):
y = image_backend.solarize(x, threshold=threshold)

ref_y = TF.solarize(torch.tensor(x), threshold=threshold)
ref_y = ref_y.cpu().numpy()
self.assertAllClose(y, ref_y)
self.assertDType(y, dtype)
4 changes: 2 additions & 2 deletions keras_aug/_src/layers/base/vision_random_layer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
import numpy as np
import pytest
from keras import backend
from keras.src import testing

from keras_aug._src.layers.base.vision_random_layer import VisionRandomLayer
from keras_aug._src.testing.test_case import TestCase


class RandomAddLayer(VisionRandomLayer):
Expand Down Expand Up @@ -181,7 +181,7 @@ def augment_custom_annotations(
return custom_annotations


class VisionRandomLayerTest(testing.TestCase):
class VisionRandomLayerTest(TestCase):
def test_single_image(self):
add_layer = RandomAddLayer(fixed_value=2.0)
image = np.random.random(size=(8, 8, 3)).astype("float32")
Expand Down
22 changes: 6 additions & 16 deletions keras_aug/_src/layers/composition/random_apply_test.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,30 @@
import keras
import numpy as np
from absl.testing import parameterized
from keras import backend
from keras.src import testing

from keras_aug._src.layers.composition.random_apply import RandomApply
from keras_aug._src.layers.vision.rand_augment import RandAugment
from keras_aug._src.layers.vision.random_grayscale import RandomGrayscale
from keras_aug._src.layers.vision.resize import Resize
from keras_aug._src.testing.test_case import TestCase
from keras_aug._src.utils.test_utils import get_images


class RandomApplyTest(testing.TestCase, parameterized.TestCase):
def setUp(self):
# Defaults to channels_last
self.data_format = backend.image_data_format()
backend.set_image_data_format("channels_last")
return super().setUp()

def tearDown(self) -> None:
backend.set_image_data_format(self.data_format)
return super().tearDown()

class RandomApplyTest(TestCase):
def test_correctness(self):
import torch
import torchvision.transforms.v2.functional as TF
from keras.src.backend.torch import convert_to_tensor

layer = RandomApply(transforms=[RandomGrayscale(p=1.0)], p=1.0)

x = get_images("float32", "channels_last")
y = layer(x)

ref_y = TF.rgb_to_grayscale(
torch.tensor(np.transpose(x, [0, 3, 1, 2])), num_output_channels=3
convert_to_tensor(np.transpose(x, [0, 3, 1, 2])),
num_output_channels=3,
)
ref_y = np.transpose(ref_y.cpu().numpy(), [0, 2, 3, 1])
ref_y = torch.permute(ref_y, (0, 2, 3, 1))
self.assertAllClose(y, ref_y)

# Test p=0.0
Expand Down
22 changes: 6 additions & 16 deletions keras_aug/_src/layers/composition/random_choice_test.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,19 @@
import keras
import numpy as np
from absl.testing import parameterized
from keras import backend
from keras.src import testing

from keras_aug._src.layers.composition.random_choice import RandomChoice
from keras_aug._src.layers.vision.identity import Identity
from keras_aug._src.layers.vision.random_grayscale import RandomGrayscale
from keras_aug._src.layers.vision.resize import Resize
from keras_aug._src.testing.test_case import TestCase
from keras_aug._src.utils.test_utils import get_images


class RandomChoiceTest(testing.TestCase, parameterized.TestCase):
def setUp(self):
# Defaults to channels_last
self.data_format = backend.image_data_format()
backend.set_image_data_format("channels_last")
return super().setUp()

def tearDown(self) -> None:
backend.set_image_data_format(self.data_format)
return super().tearDown()

class RandomChoiceTest(TestCase):
def test_correctness(self):
import torch
import torchvision.transforms.v2.functional as TF
from keras.src.backend.torch import convert_to_tensor

layer = RandomChoice(
transforms=[RandomGrayscale(p=1.0), Identity()], p=[1.0, 0.0]
Expand All @@ -34,9 +23,10 @@ def test_correctness(self):
y = layer(x)

ref_y = TF.rgb_to_grayscale(
torch.tensor(np.transpose(x, [0, 3, 1, 2])), num_output_channels=3
convert_to_tensor(np.transpose(x, [0, 3, 1, 2])),
num_output_channels=3,
)
ref_y = np.transpose(ref_y.cpu().numpy(), [0, 2, 3, 1])
ref_y = torch.permute(ref_y, (0, 2, 3, 1))
self.assertAllClose(y, ref_y)

# Test p=0.0
Expand Down
Loading

0 comments on commit e1a01b6

Please sign in to comment.