diff --git a/keras_aug/_src/backend/image.py b/keras_aug/_src/backend/image.py index b7a34b1..3184a4e 100644 --- a/keras_aug/_src/backend/image.py +++ b/keras_aug/_src/backend/image.py @@ -53,13 +53,30 @@ def transform_dtype(self, images, from_dtype, to_dtype, scale=True): num_bits_input = self._num_bits_of_dtype(from_dtype) num_bits_output = self._num_bits_of_dtype(to_dtype) + def right_shift(inputs, bits): + if self.name == "tensorflow": + import tensorflow as tf + + return tf.bitwise.right_shift(inputs, bits) + else: + return inputs >> bits + + def left_shift(inputs, bits): + if self.name == "tensorflow": + import tensorflow as tf + + return tf.bitwise.left_shift(inputs, bits) + else: + return inputs << bits + if num_bits_input > num_bits_output: return ops.cast( - images >> (num_bits_input - num_bits_output), to_dtype + right_shift(images, (num_bits_input - num_bits_output)), + to_dtype, ) else: - return ops.cast(images, to_dtype) << ( - num_bits_output - num_bits_input + return left_shift( + ops.cast(images, to_dtype), num_bits_output - num_bits_input ) def crop(self, images, top, left, height, width, data_format=None): diff --git a/keras_aug/_src/layers/base/vision_random_layer.py b/keras_aug/_src/layers/base/vision_random_layer.py index 34d3e4c..6946c9f 100644 --- a/keras_aug/_src/layers/base/vision_random_layer.py +++ b/keras_aug/_src/layers/base/vision_random_layer.py @@ -74,14 +74,17 @@ class VisionRandomLayer(keras.Layer): IS_DICT = "is_dict" BATCHED = "batched" + SUPPORTED_INT_DTYPES = ("uint8", "int16", "int32") + def __init__(self, has_generator=True, seed=None, **kwargs): super().__init__(**kwargs) # Check dtype if not backend.is_float_dtype(self.compute_dtype): - if self.compute_dtype != "uint8": + if self.compute_dtype not in self.SUPPORTED_INT_DTYPES: raise ValueError( - "Only floating and 'uint8' are supported for compute dtype." 
- f" Received: compute_dtype={self.compute_dtype}" + f"Only floating and {self.SUPPORTED_INT_DTYPES} are " + "supported for compute dtype. " + f"Received: compute_dtype={self.compute_dtype}" ) self._backend = DynamicBackend(backend.backend()) diff --git a/keras_aug/_src/layers/vision/to_dtype.py b/keras_aug/_src/layers/vision/to_dtype.py index 62bd89a..a4fe50e 100644 --- a/keras_aug/_src/layers/vision/to_dtype.py +++ b/keras_aug/_src/layers/vision/to_dtype.py @@ -10,7 +10,11 @@ class ToDType(VisionRandomLayer): """Converts the input to a specific dtype, optionally scaling the values. - If + If `scale` is `True`, the value range will be changed as follows: + - `"uint8"`: `[0, 255]` + - `"int16"`: `[-32768, 32767]` + - `"int32"`: `[-2147483648, 2147483647]` + - float: `[0.0, 1.0]` Args: to_dtype: A string specifying the target dtype. diff --git a/keras_aug/_src/layers/vision/to_dtype_test.py b/keras_aug/_src/layers/vision/to_dtype_test.py index 4c79ac6..d706737 100644 --- a/keras_aug/_src/layers/vision/to_dtype_test.py +++ b/keras_aug/_src/layers/vision/to_dtype_test.py @@ -22,8 +22,8 @@ def tearDown(self) -> None: @parameterized.named_parameters( named_product( - from_dtype=["uint8", "float16", "float32"], - to_dtype=["uint8", "float16", "float32"], + from_dtype=["uint8", "int16", "int32", "bfloat16", "float32"], + to_dtype=["uint8", "int16", "bfloat16", "float32"], scale=[True, False], ) ) @@ -37,13 +37,22 @@ def test_correctness(self, from_dtype, to_dtype, scale): layer = ToDType(to_dtype, scale) y = layer(x) + if from_dtype == "bfloat16": + x = x.astype("float32") ref_y = TF.to_dtype( torch.tensor(np.transpose(x, [0, 3, 1, 2])), dtype=to_torch_dtype(to_dtype), scale=scale, ) + + if to_dtype == "bfloat16": + y = keras.ops.cast(y, "float32") + ref_y = ref_y.to(torch.float32) + to_dtype = "float32" ref_y = np.transpose(ref_y.cpu().numpy(), [0, 2, 3, 1]) self.assertDType(y, to_dtype) + if from_dtype == "bfloat16" and to_dtype in ("uint8", "int16"): + return 
self.assertAllClose(y, ref_y) def test_shape(self): diff --git a/keras_aug/_src/utils/test_utils.py b/keras_aug/_src/utils/test_utils.py index 614e5f3..6dc42db 100644 --- a/keras_aug/_src/utils/test_utils.py +++ b/keras_aug/_src/utils/test_utils.py @@ -15,6 +15,12 @@ def get_images(dtype, data_format="channels_first", size=(32, 32)): x = np.random.uniform(0, 255, (2, 3, *size)).astype(dtype) elif dtype == "int8": x = np.random.uniform(-128, 127, (2, 3, *size)).astype(dtype) + elif dtype == "int16": + x = np.random.uniform(-32768, 32767, (2, 3, *size)).astype(dtype) + elif dtype == "int32": + x = np.random.uniform(-2147483648, 2147483647, (2, 3, *size)).astype( + dtype + ) if data_format == "channels_last": x = np.transpose(x, [0, 2, 3, 1]) return x