From e66ab50f6b576b08e5c3395ec8c9da08821ed0d7 Mon Sep 17 00:00:00 2001
From: Divyashree Sreepathihalli
Date: Wed, 8 Jan 2025 16:15:13 -0800
Subject: [PATCH] merge master (#20741)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Specify window_length dtype requirement in tf.keras.ops.istft in math.py (#20728)

The `window_length` parameter in `tf.keras.ops.istft` requires `tf.int32` dtype, but this isn't documented. This can cause an unexpected `ValueError` when using `tf.int64` or `tf.int16`. Here is an example case (see the workaround sketch below):

```
import tensorflow as tf

input_dict = {
    'stfts': tf.constant([[-0.87817144+1.14583987j, -0.32066484+0.25565411j]],
                         dtype=tf.complex128),
    'frame_length': tf.constant(256, dtype=tf.int16),
    'frame_step': tf.constant(5120, dtype=tf.int64)
}
result = tf.signal.inverse_stft(**input_dict)
print(result)
```

The code throws the following error:

```
ValueError: window_length: Tensor conversion requested dtype int32 for Tensor with dtype int64
```

* Add rand_augment processing layer (#20716) (see the usage sketch below)

* Add rand_augment init
* Update rand_augment init
* Add rand_augment
* Add NotImplementedError
* Add some test cases
* Fix failed test case
* Update rand_augment
* Update rand_augment test
* Fix random_rotation bug
* Add build method to suppress warning.
* Add implementation for transform_bboxes

* Fixing batch_dim_name attribute (#20674) (see the mesh sketch below)

* fixing wrong trainer assumption that batch dim is always the first one in the mesh
* need functools partial
* lint
* fix test failure when distribution=None
* lint2
* fix for test failure
* added data sharding for 3D+ meshes
* lint3
* added @property for batch_dim_name + refactoring
* fix typo

* Add support for `dtype` / `DTypePolicy` to `JaxLayer` and `FlaxLayer`. (#20732)

The `dtype` / `DTypePolicy` is applied to all float variables (see the dtype sketch below).

* Allow dynamic shape in `STFTSpectrogram` layer. (#20736)

by simply using `ops.shape(x)` instead of `x.shape`.

* Remove duplicate export tests in `model_test`. (#20735)

The same tests exist at:
- https://github.com/keras-team/keras/blob/master/keras/src/export/saved_model_test.py#L66
- https://github.com/keras-team/keras/blob/master/keras/src/export/onnx_test.py#L62

The goal is to isolate the use of `onnxruntime` to a single file, `onnx_test.py`.
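As referenced above, the workaround for the `istft` dtype constraint is to keep the length arguments in `tf.int32`. A minimal sketch; the signal and frame parameters here are illustrative (not from the original report), and exact reconstruction would additionally require `tf.signal.inverse_stft_window_fn`:

```
import numpy as np
import tensorflow as tf

signal = tf.constant(np.random.rand(8000), dtype=tf.float32)

# int32 is the dtype the internal window_length conversion expects;
# int16/int64 scalars trigger the ValueError shown above.
frame_length = tf.constant(256, dtype=tf.int32)
frame_step = tf.constant(128, dtype=tf.int32)

stfts = tf.signal.stft(signal, frame_length=256, frame_step=128, fft_length=256)
reconstructed = tf.signal.inverse_stft(
    stfts, frame_length, frame_step, fft_length=256
)
print(reconstructed.shape)  # (7936,) == (61 - 1) * 128 + 256
```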
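A minimal usage sketch for the new `RandAugment` layer. The constructor arguments mirror the signature in the diff below; the batch shape and pixel values are illustrative:

```
import numpy as np
from keras import layers

# Applies `num_ops` randomly chosen augmentations (shear, translation,
# rotation, contrast, solarization, ...) per image at the given strength.
augment = layers.RandAugment(
    value_range=(0, 255), num_ops=2, factor=0.5, seed=42
)

images = np.random.randint(0, 255, size=(8, 32, 32, 3)).astype("float32")
augmented = augment(images)  # same shape: (8, 32, 32, 3)
```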
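To illustrate the `batch_dim_name` fix: the trainer no longer assumes the batch dimension is the first mesh axis. A sketch, assuming eight available accelerators (the mesh layout and axis names here are hypothetical):

```
from keras.distribution import DeviceMesh, LayoutMap, ModelParallel

# A 2D mesh where the batch axis is deliberately *not* the first axis;
# previously, data distribution assumed the batch dim came first.
mesh = DeviceMesh(shape=(4, 2), axis_names=("model", "batch"))
distribution = ModelParallel(
    layout_map=LayoutMap(mesh), batch_dim_name="batch"
)
print(distribution.batch_dim_name)  # "batch", via the new public property
```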
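And a sketch of the new `dtype` support on `JaxLayer` (requires the JAX backend; the tiny `call_fn` and shapes are hypothetical). Float entries in `params` become variables governed by the layer's dtype policy:

```
import jax.numpy as jnp
import numpy as np
from keras import layers

def my_call(params, inputs):
    # A single dense projection; `params` is the structure tracked by the layer.
    return jnp.matmul(inputs, params["w"])

layer = layers.JaxLayer(
    call_fn=my_call,
    params={"w": np.random.rand(4, 2).astype("float32")},
    dtype="mixed_float16",  # variables stored in float32, compute in float16
)
out = layer(np.random.rand(8, 4).astype("float32"))
print(out.dtype)  # float16 under the mixed policy
```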
* Add OpenVINO into README.md (#20739)

* Add OpenVINO into README.md

Signed-off-by: Kazantsev, Roman

* Update README.md

---------

Signed-off-by: Kazantsev, Roman

* Remove duplicate example title in metrics.MeanIoU method (#20738)

Remove duplicate example title in metrics.MeanIoU method

---------

Signed-off-by: Kazantsev, Roman
Co-authored-by: LakshmiKalaKadali <149650845+LakshmiKalaKadali@users.noreply.github.com>
Co-authored-by: Ugeun Park <37043543+shashaka@users.noreply.github.com>
Co-authored-by: Martin Görner
Co-authored-by: hertschuh <1091026+hertschuh@users.noreply.github.com>
Co-authored-by: Roman Kazantsev
Co-authored-by: LavanyaKV1234 <154420106+LavanyaKV1234@users.noreply.github.com>
---
 README.md                                     |   8 +-
 keras/api/_tf_keras/keras/layers/__init__.py  |   3 +
 keras/api/layers/__init__.py                  |   3 +
 keras/src/backend/jax/distribution_lib.py     |  12 +-
 .../src/backend/jax/distribution_lib_test.py  |   4 +-
 keras/src/backend/jax/export.py               |   2 +-
 keras/src/backend/jax/trainer.py              |   8 +-
 keras/src/distribution/distribution_lib.py    |  20 +-
 .../src/distribution/distribution_lib_test.py |   6 +-
 keras/src/layers/__init__.py                  |   3 +
 keras/src/layers/layer.py                     |  10 +-
 .../image_preprocessing/rand_augment.py       | 235 ++++++++++++++++++
 .../image_preprocessing/rand_augment_test.py  | 114 +++++++++
 .../random_brightness_test.py                 |   4 +-
 .../image_preprocessing/random_rotation.py    |   3 +-
 .../layers/preprocessing/stft_spectrogram.py  |  10 +-
 .../preprocessing/stft_spectrogram_test.py    |  29 ++-
 keras/src/metrics/iou_metrics.py              |   1 -
 keras/src/models/model_test.py                |  69 -----
 keras/src/ops/math.py                         |   2 +-
 keras/src/utils/jax_layer.py                  |  38 ++-
 keras/src/utils/jax_layer_test.py             |  26 ++
 22 files changed, 489 insertions(+), 121 deletions(-)
 create mode 100644 keras/src/layers/preprocessing/image_preprocessing/rand_augment.py
 create mode 100644 keras/src/layers/preprocessing/image_preprocessing/rand_augment_test.py

diff --git a/README.md b/README.md index b8a179b18f6..047906baa9a 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # Keras 3: Deep Learning for Humans -Keras 3 is a multi-backend deep learning framework, with support for JAX, TensorFlow, and PyTorch. +Keras 3 is a multi-backend deep learning framework, with support for JAX, TensorFlow, PyTorch, and OpenVINO (for inference only). Effortlessly build and train models for computer vision, natural language processing, audio processing, timeseries forecasting, recommender systems, etc. @@ -73,7 +73,7 @@ python pip_build.py --install ## Configuring your backend You can export the environment variable `KERAS_BACKEND` or you can edit your local config file at `~/.keras/keras.json` -to configure your backend. Available backend options are: `"tensorflow"`, `"jax"`, `"torch"`. Example: +to configure your backend. Available backend options are: `"tensorflow"`, `"jax"`, `"torch"`, `"openvino"`. Example: ``` export KERAS_BACKEND="jax" ``` @@ -91,6 +91,10 @@ import keras **Note:** The backend must be configured before importing `keras`, and the backend cannot be changed after the package has been imported. +**Note:** The OpenVINO backend is an inference-only backend, meaning it is designed only for running model +predictions using the `model.predict()` method. +To use the `openvino` backend, install the required dependencies from the `requirements-openvino.txt` file. + ## Backwards compatibility Keras 3 is intended to work as a drop-in replacement for `tf.keras` (when using the TensorFlow backend).
diff --git a/keras/api/_tf_keras/keras/layers/__init__.py b/keras/api/_tf_keras/keras/layers/__init__.py index 4f13a596130..6ba8ce78308 100644 --- a/keras/api/_tf_keras/keras/layers/__init__.py +++ b/keras/api/_tf_keras/keras/layers/__init__.py @@ -152,6 +152,9 @@ MaxNumBoundingBoxes, ) from keras.src.layers.preprocessing.image_preprocessing.mix_up import MixUp +from keras.src.layers.preprocessing.image_preprocessing.rand_augment import ( + RandAugment, +) from keras.src.layers.preprocessing.image_preprocessing.random_brightness import ( RandomBrightness, ) diff --git a/keras/api/layers/__init__.py b/keras/api/layers/__init__.py index a4aaf7c9917..3aa267859c7 100644 --- a/keras/api/layers/__init__.py +++ b/keras/api/layers/__init__.py @@ -152,6 +152,9 @@ MaxNumBoundingBoxes, ) from keras.src.layers.preprocessing.image_preprocessing.mix_up import MixUp +from keras.src.layers.preprocessing.image_preprocessing.rand_augment import ( + RandAugment, +) from keras.src.layers.preprocessing.image_preprocessing.random_brightness import ( RandomBrightness, ) diff --git a/keras/src/backend/jax/distribution_lib.py b/keras/src/backend/jax/distribution_lib.py index edb1dc1184a..5dc5c057d29 100644 --- a/keras/src/backend/jax/distribution_lib.py +++ b/keras/src/backend/jax/distribution_lib.py @@ -100,7 +100,7 @@ def distribute_tensor(tensor, layout): return global_value -def distribute_data_input(per_process_batch, layout): +def distribute_data_input(per_process_batch, layout, batch_dim_name): """Distribute the input data with the corresponding layout. Note that the inputs here is a local worker batch. Within the local worker, @@ -117,9 +117,13 @@ def distribute_data_input(per_process_batch, layout): if not isinstance(layout, jax.sharding.Sharding): layout = _to_jax_layout(layout) - mesh_shape = list(layout.mesh.shape.values()) - num_model_replicas_total = mesh_shape[0] # batch dimension of the mesh - mesh_model_dim_size = mesh_shape[1] if len(mesh_shape) > 1 else 1 + num_model_replicas_total = layout.mesh.shape[batch_dim_name] + + mesh_model_dim_size = 1 + for name, dim_size in layout.mesh.shape.items(): + if not name == batch_dim_name: + mesh_model_dim_size *= dim_size + num_model_replicas_per_process = num_model_replicas_total / num_processes() per_process_batch_size = per_process_batch.shape[0] diff --git a/keras/src/backend/jax/distribution_lib_test.py b/keras/src/backend/jax/distribution_lib_test.py index 5ab8eeb4133..81ceddfd305 100644 --- a/keras/src/backend/jax/distribution_lib_test.py +++ b/keras/src/backend/jax/distribution_lib_test.py @@ -337,7 +337,9 @@ def test_distribute_data_input(self): mesh, jax.sharding.PartitionSpec("batch", None) ) - result = backend_dlib.distribute_data_input(per_process_batch, layout) + result = backend_dlib.distribute_data_input( + per_process_batch, layout, "batch" + ) # Check the shape of the global batch array self.assertEqual( diff --git a/keras/src/backend/jax/export.py b/keras/src/backend/jax/export.py index 963648460dc..dd754c14418 100644 --- a/keras/src/backend/jax/export.py +++ b/keras/src/backend/jax/export.py @@ -119,7 +119,7 @@ def stateful_fn(*args, **kwargs): self._tf_trackable.non_trainable_variables, non_trainable_variables, ): - var.assign(new_value) + var.assign(tf.cast(new_value, var.dtype)) return output stateful_fn.__signature__ = inspect.Signature( diff --git a/keras/src/backend/jax/trainer.py b/keras/src/backend/jax/trainer.py index f5ae91ea8d8..c127f7f8334 100644 --- a/keras/src/backend/jax/trainer.py +++
b/keras/src/backend/jax/trainer.py @@ -1,5 +1,6 @@ import collections import itertools +from functools import partial import jax import numpy as np @@ -988,15 +989,18 @@ def _get_jax_state( def _distribute_data(data, layouts=None): distribution = distribution_lib.distribution() + if distribution is not None: if layouts is None: layouts = tree.map_structure( lambda d: distribution.get_data_layout(d.shape), data, ) - return tree.map_structure( - jax_distribution_lib.distribute_data_input, data, layouts + jax_dist_data_input = partial( + jax_distribution_lib.distribute_data_input, + batch_dim_name=distribution.batch_dim_name, ) + return tree.map_structure(jax_dist_data_input, data, layouts) return tree.map_structure(jax.device_put, data) diff --git a/keras/src/distribution/distribution_lib.py b/keras/src/distribution/distribution_lib.py index b4736426afe..1528fa8fc15 100644 --- a/keras/src/distribution/distribution_lib.py +++ b/keras/src/distribution/distribution_lib.py @@ -287,8 +287,9 @@ class Distribution: device_mesh: A `DeviceMesh` instance. """ - def __init__(self, device_mesh): + def __init__(self, device_mesh, batch_dim_name=None): self._device_mesh = device_mesh + self._batch_dim_name = batch_dim_name def get_data_layout(self, data_shape): """Retrieve the `TensorLayout` for the input data. @@ -341,6 +342,10 @@ def scope(self): def device_mesh(self): return self._device_mesh + @property + def batch_dim_name(self): + return self._batch_dim_name + def distribute_dataset(self, dataset): """Create a distributed dataset instance from the original user dataset. @@ -395,7 +400,6 @@ def __init__(self, device_mesh=None, devices=None, auto_shard_dataset=True): else: self._initialize_mesh_from_list_devices() - self._batch_dim_name = self.device_mesh.axis_names[0] # Those following attributes might get convert to public methods. self._num_process = distribution_lib.num_processes() self._process_id = distribution_lib.process_id() @@ -408,7 +412,7 @@ def _initialize_with_device_mesh(self, device_mesh): "Expect `mesh` to be an instance of `DeviceMesh`. 
" f"Received: mesh={device_mesh} (of type {type(device_mesh)})" ) - super().__init__(device_mesh) + super().__init__(device_mesh, device_mesh.axis_names[0]) if self.device_mesh.devices.ndim != 1: warnings.warn( "Expect the input mesh to be 1D, but received " @@ -424,7 +428,7 @@ def _initialize_mesh_from_devices(self, devices): axis_names=[DEFAULT_BATCH_DIM_NAME], devices=devices, ) - super().__init__(device_mesh) + super().__init__(device_mesh, DEFAULT_BATCH_DIM_NAME) def _initialize_mesh_from_list_devices(self): devices = np.array(list_devices()) @@ -433,11 +437,11 @@ def _initialize_mesh_from_list_devices(self): axis_names=[DEFAULT_BATCH_DIM_NAME], devices=devices, ) - super().__init__(device_mesh) + super().__init__(device_mesh, DEFAULT_BATCH_DIM_NAME) def get_data_layout(self, data_shape): data_shard_spec = [None] * len(data_shape) - data_shard_spec[0] = self._batch_dim_name # Shard on the first dim + data_shard_spec[0] = self.batch_dim_name # Shard on the first dim return TensorLayout(data_shard_spec, self.device_mesh) def get_variable_layout(self, variable): @@ -590,7 +594,7 @@ def __init__(self, *, layout_map=None, batch_dim_name=None, **kwargs): def get_data_layout(self, data_shape): data_shard_spec = [None] * len(data_shape) - data_shard_spec[0] = self._batch_dim_name # Shard on the first dim + data_shard_spec[0] = self.batch_dim_name # Shard on the first dim return TensorLayout(data_shard_spec, self.device_mesh) def get_variable_layout(self, variable): @@ -631,7 +635,7 @@ def distribute_dataset(self, dataset): # Note that this might be smaller than one if model replicas are sharded # across multiple processes. mesh_batch_dim_index = self.device_mesh.axis_names.index( - self._batch_dim_name + self.batch_dim_name ) num_model_replicas = self.device_mesh.shape[mesh_batch_dim_index] if num_model_replicas == 1: diff --git a/keras/src/distribution/distribution_lib_test.py b/keras/src/distribution/distribution_lib_test.py index 8fd0988aec3..fba998fae46 100644 --- a/keras/src/distribution/distribution_lib_test.py +++ b/keras/src/distribution/distribution_lib_test.py @@ -186,7 +186,7 @@ def test_create_with_device_mesh(self): device_mesh = distribution.device_mesh self.assertEqual(len(device_mesh.devices), 8) self.assertEqual(device_mesh.axis_names, ["data"]) - self.assertEqual(distribution._batch_dim_name, "data") + self.assertEqual(distribution.batch_dim_name, "data") self.assertFalse(distribution._is_multi_process) self.assertEqual(distribution._process_id, 0) @@ -197,7 +197,7 @@ def test_create_with_devices(self): device_mesh = distribution.device_mesh self.assertEqual(len(device_mesh.devices), 8) self.assertEqual(device_mesh.axis_names, ["batch"]) - self.assertEqual(distribution._batch_dim_name, "batch") + self.assertEqual(distribution.batch_dim_name, "batch") @mock.patch.object( distribution_lib, @@ -211,7 +211,7 @@ def test_create_with_list_devices(self, mock_list_devices): device_mesh = distribution.device_mesh self.assertEqual(len(device_mesh.devices), 8) self.assertEqual(device_mesh.axis_names, ["batch"]) - self.assertEqual(distribution._batch_dim_name, "batch") + self.assertEqual(distribution.batch_dim_name, "batch") def test_get_data_layout(self): distribution = distribution_lib.DataParallel( diff --git a/keras/src/layers/__init__.py b/keras/src/layers/__init__.py index f9719bfe442..59f241cbaf2 100644 --- a/keras/src/layers/__init__.py +++ b/keras/src/layers/__init__.py @@ -96,6 +96,9 @@ MaxNumBoundingBoxes, ) from keras.src.layers.preprocessing.image_preprocessing.mix_up 
import MixUp +from keras.src.layers.preprocessing.image_preprocessing.rand_augment import ( + RandAugment, +) from keras.src.layers.preprocessing.image_preprocessing.random_brightness import ( RandomBrightness, ) diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py index 8e36bb20456..a4f830912d5 100644 --- a/keras/src/layers/layer.py +++ b/keras/src/layers/layer.py @@ -85,12 +85,10 @@ class Layer(BackendLayer, Operation, KerasSaveable): trainable: Boolean, whether the layer's variables should be trainable. name: String name of the layer. dtype: The dtype of the layer's computations and weights. Can also be a - `keras.DTypePolicy`, - which allows the computation and - weight dtype to differ. Defaults to `None`. `None` means to use - `keras.config.dtype_policy()`, - which is a `float32` policy unless set to different value - (via `keras.config.set_dtype_policy()`). + `keras.DTypePolicy`, which allows the computation and weight dtype + to differ. Defaults to `None`. `None` means to use + `keras.config.dtype_policy()`, which is a `float32` policy unless + set to different value (via `keras.config.set_dtype_policy()`). Attributes: name: The name of the layer (string). diff --git a/keras/src/layers/preprocessing/image_preprocessing/rand_augment.py b/keras/src/layers/preprocessing/image_preprocessing/rand_augment.py new file mode 100644 index 00000000000..f7d2794ce26 --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/rand_augment.py @@ -0,0 +1,235 @@ +import random + +import keras.src.layers as layers +from keras.src.api_export import keras_export +from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 + BaseImagePreprocessingLayer, +) +from keras.src.random import SeedGenerator +from keras.src.utils import backend_utils + + +@keras_export("keras.layers.RandAugment") +class RandAugment(BaseImagePreprocessingLayer): + """RandAugment performs the Rand Augment operation on input images. + + This layer can be thought of as an all-in-one image augmentation layer. The + policy implemented by this layer has been benchmarked extensively and is + effective on a wide variety of datasets. + + References: + - [RandAugment](https://arxiv.org/abs/1909.13719) + + Args: + value_range: The range of values the input image can take. + Default is `(0, 255)`. Typically, this would be `(0, 1)` + for normalized images or `(0, 255)` for raw images. + num_ops: The number of augmentation operations to apply sequentially + to each image. Default is 2. + factor: The strength of the augmentation as a normalized value + between 0 and 1. Default is 0.5. + interpolation: The interpolation method to use for resizing operations. + Options include `nearest`, `bilinear`. Default is `bilinear`. + seed: Integer. Used to create a random seed. 
+ + """ + + _USE_BASE_FACTOR = False + _FACTOR_BOUNDS = (0, 1) + + _AUGMENT_LAYERS = [ + "random_shear", + "random_translation", + "random_rotation", + "random_brightness", + "random_color_degeneration", + "random_contrast", + "random_sharpness", + "random_posterization", + "solarization", + "auto_contrast", + "equalization", + ] + + def __init__( + self, + value_range=(0, 255), + num_ops=2, + factor=0.5, + interpolation="bilinear", + seed=None, + data_format=None, + **kwargs, + ): + super().__init__(data_format=data_format, **kwargs) + + self.value_range = value_range + self.num_ops = num_ops + self._set_factor(factor) + self.interpolation = interpolation + self.seed = seed + self.generator = SeedGenerator(seed) + + self.random_shear = layers.RandomShear( + x_factor=self.factor, + y_factor=self.factor, + interpolation=interpolation, + seed=self.seed, + data_format=data_format, + **kwargs, + ) + + self.random_translation = layers.RandomTranslation( + height_factor=self.factor, + width_factor=self.factor, + interpolation=interpolation, + seed=self.seed, + data_format=data_format, + **kwargs, + ) + + self.random_rotation = layers.RandomRotation( + factor=self.factor, + interpolation=interpolation, + seed=self.seed, + data_format=data_format, + **kwargs, + ) + + self.random_brightness = layers.RandomBrightness( + factor=self.factor, + value_range=self.value_range, + seed=self.seed, + data_format=data_format, + **kwargs, + ) + + self.random_color_degeneration = layers.RandomColorDegeneration( + factor=self.factor, + value_range=self.value_range, + seed=self.seed, + data_format=data_format, + **kwargs, + ) + + self.random_contrast = layers.RandomContrast( + factor=self.factor, + value_range=self.value_range, + seed=self.seed, + data_format=data_format, + **kwargs, + ) + + self.random_sharpness = layers.RandomSharpness( + factor=self.factor, + value_range=self.value_range, + seed=self.seed, + data_format=data_format, + **kwargs, + ) + + self.solarization = layers.Solarization( + addition_factor=self.factor, + threshold_factor=self.factor, + value_range=self.value_range, + seed=self.seed, + data_format=data_format, + **kwargs, + ) + + self.random_posterization = layers.RandomPosterization( + factor=max(1, int(8 * self.factor[1])), + value_range=self.value_range, + seed=self.seed, + data_format=data_format, + **kwargs, + ) + + self.auto_contrast = layers.AutoContrast( + value_range=self.value_range, data_format=data_format, **kwargs + ) + + self.equalization = layers.Equalization( + value_range=self.value_range, data_format=data_format, **kwargs + ) + + def build(self, input_shape): + for layer_name in self._AUGMENT_LAYERS: + augmentation_layer = getattr(self, layer_name) + augmentation_layer.build(input_shape) + + def get_random_transformation(self, data, training=True, seed=None): + if not training: + return None + + if backend_utils.in_tf_graph(): + self.backend.set_backend("tensorflow") + + for layer_name in self._AUGMENT_LAYERS: + augmentation_layer = getattr(self, layer_name) + augmentation_layer.backend.set_backend("tensorflow") + + transformation = {} + random.shuffle(self._AUGMENT_LAYERS) + for layer_name in self._AUGMENT_LAYERS[: self.num_ops]: + augmentation_layer = getattr(self, layer_name) + transformation[layer_name] = ( + augmentation_layer.get_random_transformation( + data, + training=training, + seed=self._get_seed_generator(self.backend._backend), + ) + ) + + return transformation + + def transform_images(self, images, transformation, training=True): + if training: + images = 
self.backend.cast(images, self.compute_dtype) + + for layer_name, transformation_value in transformation.items(): + augmentation_layer = getattr(self, layer_name) + images = augmentation_layer.transform_images( + images, transformation_value + ) + + images = self.backend.cast(images, self.compute_dtype) + return images + + def transform_labels(self, labels, transformation, training=True): + return labels + + def transform_bounding_boxes( + self, + bounding_boxes, + transformation, + training=True, + ): + if training: + for layer_name, transformation_value in transformation.items(): + augmentation_layer = getattr(self, layer_name) + bounding_boxes = augmentation_layer.transform_bounding_boxes( + bounding_boxes, transformation_value, training=training + ) + return bounding_boxes + + def transform_segmentation_masks( + self, segmentation_masks, transformation, training=True + ): + return self.transform_images( + segmentation_masks, transformation, training=training + ) + + def compute_output_shape(self, input_shape): + return input_shape + + def get_config(self): + config = { + "value_range": self.value_range, + "num_ops": self.num_ops, + "factor": self.factor, + "interpolation": self.interpolation, + "seed": self.seed, + } + base_config = super().get_config() + return {**base_config, **config} diff --git a/keras/src/layers/preprocessing/image_preprocessing/rand_augment_test.py b/keras/src/layers/preprocessing/image_preprocessing/rand_augment_test.py new file mode 100644 index 00000000000..6abe65e240a --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/rand_augment_test.py @@ -0,0 +1,114 @@ +import numpy as np +import pytest +from tensorflow import data as tf_data + +from keras.src import backend +from keras.src import layers +from keras.src import testing + + +class RandAugmentTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_layer(self): + self.run_layer_test( + layers.RandAugment, + init_kwargs={ + "value_range": (0, 255), + "num_ops": 2, + "factor": 1, + "interpolation": "nearest", + "seed": 1, + "data_format": "channels_last", + }, + input_shape=(8, 3, 4, 3), + supports_masking=False, + expected_output_shape=(8, 3, 4, 3), + ) + + def test_rand_augment_inference(self): + seed = 3481 + layer = layers.RandAugment() + + np.random.seed(seed) + inputs = np.random.randint(0, 255, size=(224, 224, 3)) + output = layer(inputs, training=False) + self.assertAllClose(inputs, output) + + def test_rand_augment_basic(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + input_data = np.random.random((2, 8, 8, 3)) + else: + input_data = np.random.random((2, 3, 8, 8)) + layer = layers.RandAugment(data_format=data_format) + + augmented_image = layer(input_data) + self.assertEqual(augmented_image.shape, input_data.shape) + + def test_rand_augment_no_operations(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + input_data = np.random.random((2, 8, 8, 3)) + else: + input_data = np.random.random((2, 3, 8, 8)) + layer = layers.RandAugment(num_ops=0, data_format=data_format) + + augmented_image = layer(input_data) + self.assertAllClose( + backend.convert_to_numpy(augmented_image), input_data + ) + + def test_random_augment_randomness(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + input_data = np.random.random((2, 8, 8, 3)) + else: + input_data = np.random.random((2, 3, 8, 8)) + + layer = layers.RandAugment(num_ops=11, 
data_format=data_format) + augmented_image = layer(input_data) + + self.assertNotAllClose( + backend.convert_to_numpy(augmented_image), input_data + ) + + def test_tf_data_compatibility(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + input_data = np.random.random((2, 8, 8, 3)) + else: + input_data = np.random.random((2, 3, 8, 8)) + layer = layers.RandAugment(data_format=data_format) + + ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) + for output in ds.take(1): + output.numpy() + + def test_rand_augment_tf_data_bounding_boxes(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + image_shape = (1, 10, 8, 3) + else: + image_shape = (1, 3, 10, 8) + input_image = np.random.random(image_shape) + bounding_boxes = { + "boxes": np.array( + [ + [ + [2, 1, 4, 3], + [6, 4, 8, 6], + ] + ] + ), + "labels": np.array([[1, 2]]), + } + + input_data = {"images": input_image, "bounding_boxes": bounding_boxes} + + ds = tf_data.Dataset.from_tensor_slices(input_data) + layer = layers.RandAugment( + data_format=data_format, + seed=42, + bounding_box_format="xyxy", + ) + ds.map(layer) diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_brightness_test.py b/keras/src/layers/preprocessing/image_preprocessing/random_brightness_test.py index 6a6c3c79102..b33bb439c53 100644 --- a/keras/src/layers/preprocessing/image_preprocessing/random_brightness_test.py +++ b/keras/src/layers/preprocessing/image_preprocessing/random_brightness_test.py @@ -34,7 +34,7 @@ def test_correctness(self): seed = 2390 # Always scale up, but randomly between 0 ~ 255 - layer = layers.RandomBrightness([0, 1.0]) + layer = layers.RandomBrightness([0.1, 1.0]) np.random.seed(seed) inputs = np.random.randint(0, 255, size=(224, 224, 3)) output = backend.convert_to_numpy(layer(inputs)) @@ -44,7 +44,7 @@ def test_correctness(self): self.assertTrue(np.mean(diff) > 0) # Always scale down, but randomly between 0 ~ 255 - layer = layers.RandomBrightness([-1.0, 0.0]) + layer = layers.RandomBrightness([-1.0, -0.1]) np.random.seed(seed) inputs = np.random.randint(0, 255, size=(224, 224, 3)) output = backend.convert_to_numpy(layer(inputs)) diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_rotation.py b/keras/src/layers/preprocessing/image_preprocessing/random_rotation.py index ea1e4b882fe..7ddc2b3eaf0 100644 --- a/keras/src/layers/preprocessing/image_preprocessing/random_rotation.py +++ b/keras/src/layers/preprocessing/image_preprocessing/random_rotation.py @@ -182,12 +182,11 @@ def get_random_transformation(self, data, training=True, seed=None): images = data shape = ops.core.shape(images) if len(shape) == 4: + batch_size = shape[0] if self.data_format == "channels_last": - batch_size = shape[0] image_height = shape[1] image_width = shape[2] else: - batch_size = shape[1] image_height = shape[2] image_width = shape[3] else: diff --git a/keras/src/layers/preprocessing/stft_spectrogram.py b/keras/src/layers/preprocessing/stft_spectrogram.py index 6834bc356c2..eb44343589f 100644 --- a/keras/src/layers/preprocessing/stft_spectrogram.py +++ b/keras/src/layers/preprocessing/stft_spectrogram.py @@ -232,7 +232,7 @@ def build(self, input_shape): self.built = True def _adjust_shapes(self, outputs): - _, channels, freq_channels, time_seq = outputs.shape + _, channels, freq_channels, time_seq = ops.shape(outputs) batch_size = -1 if self.data_format == "channels_last": if self.expand_dims: @@ -258,11 +258,11 @@ def 
_adjust_shapes(self, outputs): def _apply_conv(self, inputs, kernel): if self.data_format == "channels_last": - _, time_seq, channels = inputs.shape + _, time_seq, channels = ops.shape(inputs) inputs = ops.transpose(inputs, [0, 2, 1]) inputs = ops.reshape(inputs, [-1, time_seq, 1]) else: - _, channels, time_seq = inputs.shape + _, channels, time_seq = ops.shape(inputs) inputs = ops.reshape(inputs, [-1, 1, time_seq]) outputs = ops.conv( @@ -274,14 +274,14 @@ def _apply_conv(self, inputs, kernel): ) batch_size = -1 if self.data_format == "channels_last": - _, time_seq, freq_channels = outputs.shape + _, time_seq, freq_channels = ops.shape(outputs) outputs = ops.transpose(outputs, [0, 2, 1]) outputs = ops.reshape( outputs, [batch_size, channels, freq_channels, time_seq], ) else: - _, freq_channels, time_seq = outputs.shape + _, freq_channels, time_seq = ops.shape(outputs) outputs = ops.reshape( outputs, [batch_size, channels, freq_channels, time_seq], diff --git a/keras/src/layers/preprocessing/stft_spectrogram_test.py b/keras/src/layers/preprocessing/stft_spectrogram_test.py index fa2eb878bb0..769a961cd5e 100644 --- a/keras/src/layers/preprocessing/stft_spectrogram_test.py +++ b/keras/src/layers/preprocessing/stft_spectrogram_test.py @@ -273,6 +273,30 @@ def test_spectrogram_basics(self): supports_masking=False, ) + @pytest.mark.skipif( + backend.backend() != "tensorflow", + reason="Backend does not support dynamic shapes", + ) + def test_spectrogram_dynamic_shape(self): + model = Sequential( + [ + Input(shape=(None, 1), dtype=TestSpectrogram.DTYPE), + layers.STFTSpectrogram( + frame_length=500, + frame_step=25, + fft_length=1024, + mode="stft", + data_format="channels_last", + ), + ] + ) + + def generator(): + yield (np.random.random((2, 16000, 1)),) + yield (np.random.random((3, 8000, 1)),) + + model.predict(generator()) + @pytest.mark.requires_trainable_backend def test_spectrogram_error(self): rnd = np.random.RandomState(41) @@ -310,10 +334,9 @@ def test_spectrogram_error(self): init_args["mode"] = "angle" y_true, y = self._calc_spectrograms(x, **init_args) - pi = np.arccos(np.float128(-1)).astype(y_true.dtype) mask = np.isclose(y, y_true, **tol_kwargs) - mask |= np.isclose(y + 2 * pi, y_true, **tol_kwargs) - mask |= np.isclose(y - 2 * pi, y_true, **tol_kwargs) + mask |= np.isclose(y + 2 * np.pi, y_true, **tol_kwargs) + mask |= np.isclose(y - 2 * np.pi, y_true, **tol_kwargs) mask |= np.isclose(np.cos(y), np.cos(y_true), **tol_kwargs) mask |= np.isclose(np.sin(y), np.sin(y_true), **tol_kwargs) diff --git a/keras/src/metrics/iou_metrics.py b/keras/src/metrics/iou_metrics.py index 65c84e591b9..f9a10ffb647 100644 --- a/keras/src/metrics/iou_metrics.py +++ b/keras/src/metrics/iou_metrics.py @@ -474,7 +474,6 @@ class MeanIoU(IoU): is used to determine each sample's most likely associated label. axis: (Optional) The dimension containing the logits. Defaults to `-1`. - Example: Example: diff --git a/keras/src/models/model_test.py b/keras/src/models/model_test.py index eb83cad4235..6ed7d3c6543 100644 --- a/keras/src/models/model_test.py +++ b/keras/src/models/model_test.py @@ -1219,75 +1219,6 @@ def test_functional_deeply_nested_outputs_struct_losses(self): ) self.assertListEqual(hist_keys, ref_keys) - @parameterized.named_parameters( - ("tf_saved_model", "tf_saved_model"), - ("onnx", "onnx"), - ) - @pytest.mark.skipif( - backend.backend() not in ("tensorflow", "jax", "torch"), - reason=( - "Currently, `Model.export` only supports the tensorflow, jax and " - "torch backends." 
- ), - ) - @pytest.mark.skipif( - testing.jax_uses_gpu(), reason="Leads to core dumps on CI" - ) - def test_export(self, export_format): - if export_format == "tf_saved_model" and testing.torch_uses_gpu(): - self.skipTest("Leads to core dumps on CI") - - temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") - model = _get_model() - x1 = np.random.rand(1, 3).astype("float32") - x2 = np.random.rand(1, 3).astype("float32") - ref_output = model([x1, x2]) - - model.export(temp_filepath, format=export_format) - - if export_format == "tf_saved_model": - import tensorflow as tf - - revived_model = tf.saved_model.load(temp_filepath) - self.assertAllClose(ref_output, revived_model.serve([x1, x2])) - - # Test with a different batch size - if backend.backend() == "torch": - # TODO: Dynamic shape is not supported yet in the torch backend - return - revived_model.serve( - [ - np.concatenate([x1, x1], axis=0), - np.concatenate([x2, x2], axis=0), - ] - ) - elif export_format == "onnx": - import onnxruntime - - ort_session = onnxruntime.InferenceSession(temp_filepath) - ort_inputs = { - k.name: v for k, v in zip(ort_session.get_inputs(), [x1, x2]) - } - self.assertAllClose( - ref_output, ort_session.run(None, ort_inputs)[0] - ) - - # Test with a different batch size - if backend.backend() == "torch": - # TODO: Dynamic shape is not supported yet in the torch backend - return - ort_inputs = { - k.name: v - for k, v in zip( - ort_session.get_inputs(), - [ - np.concatenate([x1, x1], axis=0), - np.concatenate([x2, x2], axis=0), - ], - ) - } - ort_session.run(None, ort_inputs) - def test_export_error(self): temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") model = _get_model() diff --git a/keras/src/ops/math.py b/keras/src/ops/math.py index 7caaa41f628..6cedef62cee 100644 --- a/keras/src/ops/math.py +++ b/keras/src/ops/math.py @@ -888,7 +888,7 @@ def istft( sequence_length: An integer representing the sequence length. sequence_stride: An integer representing the sequence hop size. fft_length: An integer representing the size of the FFT that produced - `stft`. + `stft`. Should be of type `int32`. length: An integer representing the output is clipped to exactly length. If not specified, no padding or clipping take place. Defaults to `None`. diff --git a/keras/src/utils/jax_layer.py b/keras/src/utils/jax_layer.py index 8fd69d1f5bf..67416cef61a 100644 --- a/keras/src/utils/jax_layer.py +++ b/keras/src/utils/jax_layer.py @@ -5,6 +5,8 @@ from keras.src import backend from keras.src import tree from keras.src.api_export import keras_export +from keras.src.backend.common.variables import is_float_dtype +from keras.src.backend.common.variables import standardize_dtype from keras.src.layers.layer import Layer from keras.src.saving import serialization_lib from keras.src.utils import jax_utils @@ -204,6 +206,8 @@ def my_haiku_module_fn(inputs, training): argument, then `init_fn` is called at build time to initialize the non-trainable state of the model. seed: Seed for random number generator. Optional. + dtype: The dtype of the layer's computations and weights. Can also be a + `keras.DTypePolicy`. Optional. Defaults to the default policy. """ def __init__( @@ -213,6 +217,7 @@ def __init__( params=None, state=None, seed=None, + dtype=None, **kwargs, ): if backend.backend() != "jax": @@ -226,9 +231,10 @@ def __init__( "`init_fn`, `params` and `state` cannot all be `None`." 
) - super().__init__(**kwargs) + super().__init__(dtype=dtype, **kwargs) self.call_fn = call_fn self.init_fn = init_fn + self.has_dtype_policy = dtype is not None self.seed_generator = backend.random.SeedGenerator(seed) self.tracked_params = self._create_variables(params, trainable=True) self.tracked_state = self._create_variables(state, trainable=False) @@ -291,18 +297,28 @@ def _create_variables(self, values, trainable): """ def create_variable(value): - if backend.is_tensor(value) or isinstance(value, np.ndarray): - variable = self.add_weight( - value.shape, initializer="zeros", trainable=trainable + if backend.is_tensor(value) or isinstance( + value, (np.ndarray, np.generic) + ): + dtype = value.dtype + if self.has_dtype_policy and is_float_dtype(dtype): + dtype = None # Use the layer dtype policy + return self.add_weight( + value.shape, + initializer=value, + dtype=dtype, + trainable=trainable, ) - variable.assign(value) - return variable - elif isinstance(value, (np.generic, int, float)): - variable = self.add_weight( - (), initializer="zeros", trainable=trainable + elif isinstance(value, (bool, int, float)): + dtype = standardize_dtype(type(value)) + if self.has_dtype_policy and is_float_dtype(dtype): + dtype = None # Use the layer dtype policy + return self.add_weight( + (), + initializer=backend.convert_to_tensor(value), + dtype=dtype, + trainable=trainable, ) - variable.assign(value) - return variable else: return value diff --git a/keras/src/utils/jax_layer_test.py b/keras/src/utils/jax_layer_test.py index 359bdca41c9..306c930660f 100644 --- a/keras/src/utils/jax_layer_test.py +++ b/keras/src/utils/jax_layer_test.py @@ -15,6 +15,7 @@ from keras.src import testing from keras.src import tree from keras.src import utils +from keras.src.dtype_policies.dtype_policy import DTypePolicy from keras.src.saving import object_registration from keras.src.utils.jax_layer import FlaxLayer from keras.src.utils.jax_layer import JaxLayer @@ -362,6 +363,18 @@ def call(self, inputs): "non_trainable_weights": 1, "non_trainable_params": 1, }, + { + "testcase_name": "training_state_dtype_policy", + "init_kwargs": { + "call_fn": jax_stateful_apply, + "init_fn": jax_stateful_init, + "dtype": DTypePolicy("mixed_float16"), + }, + "trainable_weights": 6, + "trainable_params": 266610, + "non_trainable_weights": 1, + "non_trainable_params": 1, + }, ) def test_jax_layer( self, @@ -414,6 +427,19 @@ def test_jax_layer( "non_trainable_weights": 8, "non_trainable_params": 536, }, + { + "testcase_name": "training_rng_unbound_method_dtype_policy", + "flax_model_class": "FlaxDropoutModel", + "flax_model_method": None, + "init_kwargs": { + "method": "flax_dropout_wrapper", + "dtype": DTypePolicy("mixed_float16"), + }, + "trainable_weights": 8, + "trainable_params": 648226, + "non_trainable_weights": 0, + "non_trainable_params": 0, + }, ) @pytest.mark.skipif(flax is None, reason="Flax library is not available.") def test_flax_layer(