Skip to content

Commit

Permalink
Loosen dtype constraint
Browse files Browse the repository at this point in the history
  • Loading branch information
james77777778 committed Aug 12, 2024
1 parent a0f0579 commit 95f4d30
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 9 deletions.
23 changes: 20 additions & 3 deletions keras_aug/_src/backend/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,30 @@ def transform_dtype(self, images, from_dtype, to_dtype, scale=True):
num_bits_input = self._num_bits_of_dtype(from_dtype)
num_bits_output = self._num_bits_of_dtype(to_dtype)

def right_shift(inputs, bits):
if self.name == "tensorflow":
import tensorflow as tf

return tf.bitwise.right_shift(inputs, bits)
else:
return inputs >> bits

def left_shift(inputs, bits):
if self.name == "tensorflow":
import tensorflow as tf

return tf.bitwise.left_shift(inputs, bits)
else:
return inputs << bits

if num_bits_input > num_bits_output:
    return ops.cast(
        right_shift(images, (num_bits_input - num_bits_output)),
        to_dtype,
    )
else:
    return left_shift(
        ops.cast(images, to_dtype), num_bits_output - num_bits_input
    )

def crop(self, images, top, left, height, width, data_format=None):
Expand Down
9 changes: 6 additions & 3 deletions keras_aug/_src/layers/base/vision_random_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,17 @@ class VisionRandomLayer(keras.Layer):
IS_DICT = "is_dict"
BATCHED = "batched"

SUPPORTED_INT_DTYPES = ("uint8", "int16", "int32")

def __init__(self, has_generator=True, seed=None, **kwargs):
super().__init__(**kwargs)
# Check dtype
if not backend.is_float_dtype(self.compute_dtype):
if self.compute_dtype not in self.SUPPORTED_INT_DTYPES:
    raise ValueError(
        f"Only floating and {self.SUPPORTED_INT_DTYPES} are "
        "supported for compute dtype. "
        f"Received: compute_dtype={self.compute_dtype}"
    )

self._backend = DynamicBackend(backend.backend())
Expand Down
6 changes: 5 additions & 1 deletion keras_aug/_src/layers/vision/to_dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@
class ToDType(VisionRandomLayer):
"""Converts the input to a specific dtype, optionally scaling the values.
If `scale` is `True`, the value range will change as follows:
- `"uint8"`: `[0, 255]`
- `"int16"`: `[-32768, 32767]`
- `"int32"`: `[-2147483648, 2147483647]`
- float: `[0.0, 1.0]`
Args:
to_dtype: A string specifying the target dtype.
Expand Down
13 changes: 11 additions & 2 deletions keras_aug/_src/layers/vision/to_dtype_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ def tearDown(self) -> None:

@parameterized.named_parameters(
named_product(
from_dtype=["uint8", "int16", "int32", "bfloat16", "float32"],
to_dtype=["uint8", "int16", "bfloat16", "float32"],
scale=[True, False],
)
)
Expand All @@ -37,13 +37,22 @@ def test_correctness(self, from_dtype, to_dtype, scale):
layer = ToDType(to_dtype, scale)
y = layer(x)

if from_dtype == "bfloat16":
x = x.astype("float32")
ref_y = TF.to_dtype(
torch.tensor(np.transpose(x, [0, 3, 1, 2])),
dtype=to_torch_dtype(to_dtype),
scale=scale,
)

if to_dtype == "bfloat16":
y = keras.ops.cast(y, "float32")
ref_y = ref_y.to(torch.float32)
to_dtype = "float32"
ref_y = np.transpose(ref_y.cpu().numpy(), [0, 2, 3, 1])
self.assertDType(y, to_dtype)
if from_dtype == "bfloat16" and to_dtype in ("uint8", "int16"):
return
self.assertAllClose(y, ref_y)

def test_shape(self):
Expand Down
6 changes: 6 additions & 0 deletions keras_aug/_src/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ def get_images(dtype, data_format="channels_first", size=(32, 32)):
x = np.random.uniform(0, 255, (2, 3, *size)).astype(dtype)
elif dtype == "int8":
x = np.random.uniform(-128, 127, (2, 3, *size)).astype(dtype)
elif dtype == "int16":
x = np.random.uniform(-32768, 32767, (2, 3, *size)).astype(dtype)
elif dtype == "int32":
x = np.random.uniform(-2147483648, 2147483647, (2, 3, *size)).astype(
dtype
)
if data_format == "channels_last":
x = np.transpose(x, [0, 2, 3, 1])
return x
Expand Down

0 comments on commit 95f4d30

Please sign in to comment.