Add ToDType and fix bugs in RandAugment and TrivialAugmentWide (#136)

* Add ToDType and fix seed bugs in RandAugment and TrivialAugmentWide

* Update version

* Loosen dtype constraint
james77777778 authored Aug 12, 2024
1 parent c9969c4 commit 5fe4414
Showing 11 changed files with 220 additions and 15 deletions.
2 changes: 1 addition & 1 deletion keras_aug/__init__.py
@@ -9,4 +9,4 @@
from keras_aug import visualization
from keras_aug._src.version import version

- __version__ = "1.0.1"
+ __version__ = "1.1.0"
27 changes: 23 additions & 4 deletions keras_aug/_src/backend/image.py
@@ -10,14 +10,16 @@ class ImageBackend(DynamicBackend):
def __init__(self, name=None):
super().__init__(name=name)

- def transform_dtype(self, images, from_dtype, to_dtype):
+ def transform_dtype(self, images, from_dtype, to_dtype, scale=True):
# Ref: torchvision.transforms.v2.ToDtype
ops = self.backend
from_dtype = backend.standardize_dtype(from_dtype)
to_dtype = backend.standardize_dtype(to_dtype)

if from_dtype == to_dtype:
return images
+ if scale is False:
+     return ops.cast(images, to_dtype)

is_float_input = backend.is_float_dtype(from_dtype)
is_float_output = backend.is_float_dtype(to_dtype)
@@ -51,13 +53,30 @@ def transform_dtype(self, images, from_dtype, to_dtype):
num_bits_input = self._num_bits_of_dtype(from_dtype)
num_bits_output = self._num_bits_of_dtype(to_dtype)

+ def right_shift(inputs, bits):
+     if self.name == "tensorflow":
+         import tensorflow as tf
+
+         return tf.bitwise.right_shift(inputs, bits)
+     else:
+         return inputs >> bits
+
+ def left_shift(inputs, bits):
+     if self.name == "tensorflow":
+         import tensorflow as tf
+
+         return tf.bitwise.left_shift(inputs, bits)
+     else:
+         return inputs << bits

if num_bits_input > num_bits_output:
return ops.cast(
-     images >> (num_bits_input - num_bits_output), to_dtype
+     right_shift(images, (num_bits_input - num_bits_output)),
+     to_dtype,
)
else:
- return ops.cast(images, to_dtype) << (
-     num_bits_output - num_bits_input
+ return left_shift(
+     ops.cast(images, to_dtype), num_bits_output - num_bits_input
)

def crop(self, images, top, left, height, width, data_format=None):
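Note: the integer-to-integer branch above rescales by bit shifting, and the new right_shift/left_shift helpers route through tf.bitwise, presumably because TensorFlow tensors do not support Python's >>/<< operators directly. Below is a minimal standalone sketch of the shift-based scaling, assuming torchvision-style "value bit" counts (8 for uint8, 15 for int16, 31 for int32); keras_aug's _num_bits_of_dtype is not shown in this diff, so treat the numbers as illustrative only.

import numpy as np

# Illustrative only: widen uint8 [0, 255] to int32 by left-shifting so the
# maximum input value lands near the maximum of the output dtype.
VALUE_BITS = {"uint8": 8, "int16": 15, "int32": 31}  # assumed bit counts

def widen_int(images, from_dtype="uint8", to_dtype="int32"):
    shift = VALUE_BITS[to_dtype] - VALUE_BITS[from_dtype]
    return images.astype(to_dtype) << shift  # cast first, then shift left

x = np.array([0, 128, 255], dtype="uint8")
print(widen_int(x))  # -> [0, 1073741824, 2139095040], close to the int32 max

Narrowing works the other way around: the code right-shifts by the bit difference, mapping, for example, int32 back down to the uint8 range.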
23 changes: 19 additions & 4 deletions keras_aug/_src/layers/base/vision_random_layer.py
@@ -74,14 +74,17 @@ class VisionRandomLayer(keras.Layer):
IS_DICT = "is_dict"
BATCHED = "batched"

+ SUPPORTED_INT_DTYPES = ("uint8", "int16", "int32")

def __init__(self, has_generator=True, seed=None, **kwargs):
super().__init__(**kwargs)
# Check dtype
if not backend.is_float_dtype(self.compute_dtype):
- if self.compute_dtype != "uint8":
+ if self.compute_dtype not in self.SUPPORTED_INT_DTYPES:
      raise ValueError(
-         "Only floating and 'uint8' are supported for compute dtype."
-         f" Received: compute_dtype={self.compute_dtype}"
+         f"Only floating and {self.SUPPORTED_INT_DTYPES} are "
+         "supported for compute dtype. "
+         f"Received: compute_dtype={self.compute_dtype}"
)

self._backend = DynamicBackend(backend.backend())
@@ -99,6 +102,7 @@ def __init__(self, has_generator=True, seed=None, **kwargs):
self._convert_input_args = False
self._allow_non_tensor_positional_args = True
self.autocast = False
+ self._transform_dtype_scale = True

@property
def image_dtype(self):
@@ -122,6 +126,14 @@ def backend(self):
def random_generator(self):
return self._random_generator.random_generator

+ @property
+ def transform_dtype_scale(self):
+     return self._transform_dtype_scale
+
+ @transform_dtype_scale.setter
+ def transform_dtype_scale(self, value):
+     self._transform_dtype_scale = bool(value)

def get_params(
self,
batch_size,
@@ -389,7 +401,10 @@ def _cast_inputs(self, inputs):
if self.IMAGES in inputs:
inputs[self.IMAGES] = ops.convert_to_tensor(inputs[self.IMAGES])
inputs[self.IMAGES] = self.image_backend.transform_dtype(
-     inputs[self.IMAGES], inputs[self.IMAGES].dtype, self.image_dtype
+     inputs[self.IMAGES],
+     inputs[self.IMAGES].dtype,
+     self.image_dtype,
+     scale=self.transform_dtype_scale,
)
if self.LABELS in inputs:
inputs[self.LABELS] = ops.convert_to_tensor(inputs[self.LABELS])
2 changes: 1 addition & 1 deletion keras_aug/_src/layers/vision/rand_augment.py
@@ -141,7 +141,7 @@ def get_params(self, batch_size, images=None, **kwargs):
ops.numpy.log(fn_idx_p), self.num_ops, seed=random_generator
)
fn_idx = fn_idx[0]
- signed_p = ops.random.uniform([batch_size]) > 0.5
+ signed_p = ops.random.uniform([batch_size], seed=random_generator) > 0.5
signed = ops.cast(ops.numpy.where(signed_p, 1.0, -1.0), dtype="float32")
return dict(
p=p, # shape: (batch_size,)
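Note: the one-line change above (mirrored in TrivialAugmentWide later in this commit) threads the layer's tracked generator into the uniform draw, which is what the commit message calls the seed bug fix. A minimal sketch of that seeding pattern using plain Keras 3 APIs, independent of keras_aug internals:

import keras

# All random draws share one tracked SeedGenerator, so they are reproducible
# and advance the same state. Omitting seed= falls back to unmanaged global
# randomness, which is the behavior the fix removes from these layers.
seed_gen = keras.random.SeedGenerator(seed=42)

signed_p = keras.random.uniform([8], seed=seed_gen) > 0.5
magnitude = keras.random.randint([8], 0, 31, seed=seed_gen)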
56 changes: 56 additions & 0 deletions keras_aug/_src/layers/vision/to_dtype.py
@@ -0,0 +1,56 @@
import keras
from keras import backend

from keras_aug._src.keras_aug_export import keras_aug_export
from keras_aug._src.layers.base.vision_random_layer import VisionRandomLayer


@keras_aug_export(parent_path=["keras_aug.layers.vision"])
@keras.saving.register_keras_serializable(package="keras_aug")
class ToDType(VisionRandomLayer):
"""Converts the input to a specific dtype, optionally scaling the values.
If `scale` is `True`, the value range will be changed as follows:
- `"uint8"`: `[0, 255]`
- `"int16"`: `[-32768, 32767]`
- `"int32"`: `[-2147483648, 2147483647]`
- float: `[0.0, 1.0]`
Args:
to_dtype: A string specifying the target dtype.
scale: Whether to scale the values. Defaults to `False`.
"""

def __init__(self, to_dtype, scale=False, **kwargs):
to_dtype = backend.standardize_dtype(to_dtype)
self.scale = bool(scale)
if "dtype" in kwargs:
kwargs.pop("dtype")
super().__init__(has_generator=False, dtype=to_dtype, **kwargs)
self.to_dtype = to_dtype
self.transform_dtype_scale = self.scale

def compute_output_shape(self, input_shape):
return input_shape

def augment_images(self, images, transformations, **kwargs):
return images

def augment_labels(self, labels, transformations, **kwargs):
return labels

def augment_bounding_boxes(self, bounding_boxes, transformations, **kwargs):
return bounding_boxes

def augment_segmentation_masks(
self, segmentation_masks, transformations, **kwargs
):
return segmentation_masks

def augment_keypoints(self, keypoints, transformations, **kwargs):
return keypoints

def get_config(self):
config = super().get_config()
config.update({"to_dtype": self.to_dtype, "scale": self.scale})
return config
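A quick usage sketch for the new layer, following the test file below. The import uses the private module path shown in this diff; the public alias registered in keras_aug/layers/vision/__init__.py should work the same way.

import numpy as np
from keras_aug._src.layers.vision.to_dtype import ToDType

images = np.random.randint(0, 256, size=(2, 32, 32, 3), dtype="uint8")

# scale=True: uint8 values in [0, 255] are rescaled to float32 [0.0, 1.0].
scaled = ToDType("float32", scale=True)(images)

# scale=False (the default): values are cast without rescaling, e.g. 255 -> 255.0.
cast_only = ToDType("float32")(images)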
102 changes: 102 additions & 0 deletions keras_aug/_src/layers/vision/to_dtype_test.py
@@ -0,0 +1,102 @@
import keras
import numpy as np
from absl.testing import parameterized
from keras import backend
from keras.src import testing
from keras.src.testing.test_utils import named_product

from keras_aug._src.layers.vision.to_dtype import ToDType
from keras_aug._src.utils.test_utils import get_images


class ToDTypeTest(testing.TestCase, parameterized.TestCase):
def setUp(self):
# Defaults to channels_last
self.data_format = backend.image_data_format()
backend.set_image_data_format("channels_last")
return super().setUp()

def tearDown(self) -> None:
backend.set_image_data_format(self.data_format)
return super().tearDown()

@parameterized.named_parameters(
named_product(
from_dtype=["uint8", "int16", "int32", "bfloat16", "float32"],
to_dtype=["uint8", "int16", "bfloat16", "float32"],
scale=[True, False],
)
)
def test_correctness(self, from_dtype, to_dtype, scale):
import torch
import torchvision.transforms.v2.functional as TF
from keras.src.backend.torch import to_torch_dtype

# Test channels_last
x = get_images(from_dtype, "channels_last")
layer = ToDType(to_dtype, scale)
y = layer(x)

if from_dtype == "bfloat16":
x = x.astype("float32")
ref_y = TF.to_dtype(
torch.tensor(np.transpose(x, [0, 3, 1, 2])),
dtype=to_torch_dtype(to_dtype),
scale=scale,
)

if to_dtype == "bfloat16":
y = keras.ops.cast(y, "float32")
ref_y = ref_y.to(torch.float32)
to_dtype = "float32"
ref_y = np.transpose(ref_y.cpu().numpy(), [0, 2, 3, 1])
self.assertDType(y, to_dtype)
if from_dtype == "bfloat16" and to_dtype in ("uint8", "int16"):
return
self.assertAllClose(y, ref_y)

def test_shape(self):
# Test dynamic shape
x = keras.KerasTensor((None, None, None, 3))
y = ToDType("float32", scale=True)(x)
self.assertEqual(y.shape, (None, None, None, 3))
backend.set_image_data_format("channels_first")
x = keras.KerasTensor((None, 3, None, None))
y = ToDType("float32", scale=True)(x)
self.assertEqual(y.shape, (None, 3, None, None))

# Test static shape
backend.set_image_data_format("channels_last")
x = keras.KerasTensor((None, 32, 32, 3))
y = ToDType("float32", scale=True)(x)
self.assertEqual(y.shape, (None, 32, 32, 3))
backend.set_image_data_format("channels_first")
x = keras.KerasTensor((None, 3, 32, 32))
y = ToDType("float32", scale=True)(x)
self.assertEqual(y.shape, (None, 3, 32, 32))

def test_model(self):
layer = ToDType("float32", scale=True)
inputs = keras.layers.Input(shape=[None, None, 5])
outputs = layer(inputs)
model = keras.models.Model(inputs, outputs)
self.assertEqual(model.output_shape, (None, None, None, 5))

def test_config(self):
x = get_images("float32", "channels_last")
layer = ToDType("float32", scale=True)
y = layer(x)

layer = ToDType.from_config(layer.get_config())
y2 = layer(x)
self.assertAllClose(y, y2)

def test_tf_data_compatibility(self):
import tensorflow as tf

layer = ToDType("float32", scale=True)
x = get_images("float32", "channels_last")
ds = tf.data.Dataset.from_tensor_slices(x).batch(2).map(layer)
for output in ds.take(1):
self.assertIsInstance(output, tf.Tensor)
self.assertEqual(output.shape, (2, 32, 32, 3))
6 changes: 4 additions & 2 deletions keras_aug/_src/layers/vision/trivial_augment.py
@@ -119,13 +119,15 @@ def get_params(self, batch_size, images=None, **kwargs):
random_generator = self.random_generator

p = ops.random.uniform([batch_size], seed=random_generator)
- magnitude = ops.random.randint([batch_size], 0, self.num_magnitude_bins)
+ magnitude = ops.random.randint(
+     [batch_size], 0, self.num_magnitude_bins, seed=random_generator
+ )
fn_idx_p = ops.convert_to_tensor([self.fn_idx_p])
fn_idx = ops.random.categorical(
ops.numpy.log(fn_idx_p), 1, seed=random_generator
)
fn_idx = fn_idx[0]
- signed_p = ops.random.uniform([batch_size]) > 0.5
+ signed_p = ops.random.uniform([batch_size], seed=random_generator) > 0.5
signed = ops.cast(ops.numpy.where(signed_p, 1.0, -1.0), dtype="float32")
return dict(
p=p, # shape: (batch_size,)
6 changes: 4 additions & 2 deletions keras_aug/_src/ops/image.py
@@ -5,9 +5,11 @@


@keras_aug_export(parent_path=["keras_aug.ops.image"])
- def transform_dtype(images, from_dtype, to_dtype):
+ def transform_dtype(images, from_dtype, to_dtype, scale=True):
backend = "tensorflow" if in_tf_graph() else None
- return ImageBackend(backend).transform_dtype(images, from_dtype, to_dtype)
+ return ImageBackend(backend).transform_dtype(
+     images, from_dtype, to_dtype, scale=scale
+ )


@keras_aug_export(parent_path=["keras_aug.ops.image"])
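The functional op mirrors the layer pathway, except that scale defaults to True here while the ToDType layer defaults to scale=False. A short sketch, importing from the private module shown in this diff (the decorator suggests a public keras_aug.ops.image export as well):

import numpy as np
from keras_aug._src.ops.image import transform_dtype

x = np.random.randint(0, 256, size=(2, 32, 32, 3), dtype="uint8")

# Rescale while converting: uint8 [0, 255] -> float32 [0.0, 1.0].
y = transform_dtype(x, "uint8", "float32", scale=True)

# Plain cast: values keep their magnitudes (255 stays 255.0).
y_cast = transform_dtype(x, "uint8", "float32", scale=False)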
8 changes: 8 additions & 0 deletions keras_aug/_src/utils/test_utils.py
@@ -9,10 +9,18 @@ def get_images(dtype, data_format="channels_first", size=(32, 32)):
x = np.random.uniform(0, 1, (2, 3, *size)).astype(dtype)
elif dtype == "bfloat16":
x = np.random.uniform(0, 1, (2, 3, *size)).astype(dtype)
+ elif dtype == "float16":
+     x = np.random.uniform(0, 1, (2, 3, *size)).astype(dtype)
elif dtype == "uint8":
x = np.random.uniform(0, 255, (2, 3, *size)).astype(dtype)
+ elif dtype == "int8":
+     x = np.random.uniform(-128, 127, (2, 3, *size)).astype(dtype)
elif dtype == "int16":
x = np.random.uniform(-32768, 32767, (2, 3, *size)).astype(dtype)
+ elif dtype == "int32":
+     x = np.random.uniform(-2147483648, 2147483647, (2, 3, *size)).astype(
+         dtype
+     )
if data_format == "channels_last":
x = np.transpose(x, [0, 2, 3, 1])
return x
2 changes: 1 addition & 1 deletion keras_aug/_src/version.py
@@ -1,6 +1,6 @@
from keras_aug._src.keras_aug_export import keras_aug_export

- __version__ = "1.0.1"
+ __version__ = "1.1.0"


@keras_aug_export("keras_aug")
1 change: 1 addition & 0 deletions keras_aug/layers/vision/__init__.py
@@ -35,4 +35,5 @@
from keras_aug._src.layers.vision.random_solarize import RandomSolarize
from keras_aug._src.layers.vision.rescale import Rescale
from keras_aug._src.layers.vision.resize import Resize
+ from keras_aug._src.layers.vision.to_dtype import ToDType
from keras_aug._src.layers.vision.trivial_augment import TrivialAugmentWide
