From 5a9b26d1c7d4ee7f5e9472f7d393888e79971a29 Mon Sep 17 00:00:00 2001 From: "Hongyu, Chiu" <20734616+james77777778@users.noreply.github.com> Date: Fri, 13 Dec 2024 08:34:27 +0800 Subject: [PATCH] Rework `Model.export` and `keras.export.ExportArchive` to support exporting in TFLite and ONNX formats in the future (#20631) * Rework `Model.export` and `keras.export.ExportArchive` * Try fixing PyDatasetAdapterTest CI issues --- keras/src/backend/jax/export.py | 194 ++++++++ keras/src/backend/numpy/export.py | 10 + keras/src/backend/tensorflow/export.py | 32 ++ keras/src/backend/torch/export.py | 35 ++ keras/src/export/export_lib.py | 440 ++++++++---------- keras/src/export/export_lib_test.py | 104 ++++- keras/src/layers/__init__.py | 1 + keras/src/layers/core/dense_test.py | 4 +- keras/src/layers/core/einsum_dense_test.py | 4 +- keras/src/layers/core/embedding_test.py | 2 +- keras/src/models/model.py | 76 ++- keras/src/models/model_test.py | 45 ++ .../data_adapters/py_dataset_adapter_test.py | 13 +- keras/src/utils/jax_layer_test.py | 3 +- 14 files changed, 662 insertions(+), 301 deletions(-) create mode 100644 keras/src/backend/jax/export.py create mode 100644 keras/src/backend/numpy/export.py create mode 100644 keras/src/backend/tensorflow/export.py create mode 100644 keras/src/backend/torch/export.py diff --git a/keras/src/backend/jax/export.py b/keras/src/backend/jax/export.py new file mode 100644 index 00000000000..963648460dc --- /dev/null +++ b/keras/src/backend/jax/export.py @@ -0,0 +1,194 @@ +import copy +import inspect +import itertools +import string +import warnings + +from keras.src import layers +from keras.src import tree +from keras.src.backend.common.stateless_scope import StatelessScope +from keras.src.utils.module_utils import tensorflow as tf + + +class JaxExportArchive: + def __init__(self): + self._backend_variables = [] + self._backend_trainable_variables = [] + self._backend_non_trainable_variables = [] + + def track(self, resource): + if not isinstance(resource, layers.Layer): + raise ValueError( + "Invalid resource type. Expected an instance of a " + "JAX-based Keras `Layer` or `Model`. " + f"Received instead an object of type '{type(resource)}'. " + f"Object received: {resource}" + ) + + if isinstance(resource, layers.Layer): + # Variables in the lists below are actually part of the trackables + # that get saved, because the lists are created in __init__. + trainable_variables = resource.trainable_variables + non_trainable_variables = resource.non_trainable_variables + + self._tf_trackable.trainable_variables += tree.map_structure( + self._convert_to_tf_variable, trainable_variables + ) + self._tf_trackable.non_trainable_variables += tree.map_structure( + self._convert_to_tf_variable, non_trainable_variables + ) + self._tf_trackable.variables = ( + self._tf_trackable.trainable_variables + + self._tf_trackable.non_trainable_variables + ) + + self._backend_trainable_variables += trainable_variables + self._backend_non_trainable_variables += non_trainable_variables + self._backend_variables = ( + self._backend_trainable_variables + + self._backend_non_trainable_variables + ) + + def add_endpoint(self, name, fn, input_signature=None, **kwargs): + jax2tf_kwargs = kwargs.pop("jax2tf_kwargs", None) + # Use `copy.copy()` to avoid modification issues. 
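+        # A shallow copy is enough here: the dict only ever gains new
+        # top-level keys below (`native_serialization`,
+        # `polymorphic_shapes`), so the caller's dict is never mutated.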
+        jax2tf_kwargs = copy.copy(jax2tf_kwargs) or {}
+        is_static = bool(kwargs.pop("is_static", False))
+
+        # Configure `jax2tf_kwargs`
+        if "native_serialization" not in jax2tf_kwargs:
+            jax2tf_kwargs["native_serialization"] = (
+                self._check_device_compatible()
+            )
+        if "polymorphic_shapes" not in jax2tf_kwargs:
+            jax2tf_kwargs["polymorphic_shapes"] = self._to_polymorphic_shape(
+                input_signature
+            )
+
+        # Note: we truncate the number of parameters to what is specified by
+        # `input_signature`.
+        fn_signature = inspect.signature(fn)
+        fn_parameters = list(fn_signature.parameters.values())
+
+        if is_static:
+            from jax.experimental import jax2tf
+
+            jax_fn = jax2tf.convert(fn, **jax2tf_kwargs)
+            jax_fn.__signature__ = inspect.Signature(
+                parameters=fn_parameters[0 : len(input_signature)],
+                return_annotation=fn_signature.return_annotation,
+            )
+
+            decorated_fn = tf.function(
+                jax_fn,
+                input_signature=input_signature,
+                autograph=False,
+            )
+        else:
+            # 1. Create a stateless wrapper for `fn`
+            # 2. jax2tf the stateless wrapper
+            # 3. Create a stateful function that binds the variables with
+            #    the jax2tf converted stateless wrapper
+            # 4. Make the signature of the stateful function the same as the
+            #    original function
+            # 5. Wrap in a `tf.function`
+            def stateless_fn(variables, *args, **kwargs):
+                state_mapping = zip(self._backend_variables, variables)
+                with StatelessScope(state_mapping=state_mapping) as scope:
+                    output = fn(*args, **kwargs)
+
+                # Gather updated non-trainable variables
+                non_trainable_variables = []
+                for var in self._backend_non_trainable_variables:
+                    new_value = scope.get_current_value(var)
+                    non_trainable_variables.append(new_value)
+                return output, non_trainable_variables
+
+            jax2tf_stateless_fn = self._convert_jax2tf_function(
+                stateless_fn, input_signature, jax2tf_kwargs=jax2tf_kwargs
+            )
+
+            def stateful_fn(*args, **kwargs):
+                output, non_trainable_variables = jax2tf_stateless_fn(
+                    # Change the trackable `ListWrapper` to a plain `list`
+                    list(self._tf_trackable.variables),
+                    *args,
+                    **kwargs,
+                )
+                for var, new_value in zip(
+                    self._tf_trackable.non_trainable_variables,
+                    non_trainable_variables,
+                ):
+                    var.assign(new_value)
+                return output
+
+            stateful_fn.__signature__ = inspect.Signature(
+                parameters=fn_parameters[0 : len(input_signature)],
+                return_annotation=fn_signature.return_annotation,
+            )
+
+            decorated_fn = tf.function(
+                stateful_fn,
+                input_signature=input_signature,
+                autograph=False,
+            )
+        return decorated_fn
+
+    def _convert_jax2tf_function(self, fn, input_signature, jax2tf_kwargs=None):
+        from jax.experimental import jax2tf
+
+        variables_shapes = self._to_polymorphic_shape(
+            self._backend_variables, allow_none=False
+        )
+        input_shapes = list(jax2tf_kwargs["polymorphic_shapes"])
+        jax2tf_kwargs["polymorphic_shapes"] = [variables_shapes] + input_shapes
+        return jax2tf.convert(fn, **jax2tf_kwargs)
+
+    def _to_polymorphic_shape(self, struct, allow_none=True):
+        if allow_none:
+            # Generates unique names: a, b, ... z, aa, ab, ... az, ba, ... zz
+            # for unknown non-batch dims. Defined here to be scoped per endpoint. 
+            dim_names = itertools.chain(
+                string.ascii_lowercase,
+                itertools.starmap(
+                    lambda a, b: a + b,
+                    itertools.product(string.ascii_lowercase, repeat=2),
+                ),
+            )
+
+        def convert_shape(x):
+            poly_shape = []
+            for index, dim in enumerate(list(x.shape)):
+                if dim is not None:
+                    poly_shape.append(str(dim))
+                elif not allow_none:
+                    raise ValueError(
+                        f"Illegal None dimension in {x} with shape {x.shape}"
+                    )
+                elif index == 0:
+                    poly_shape.append("batch")
+                else:
+                    poly_shape.append(next(dim_names))
+            return "(" + ", ".join(poly_shape) + ")"
+
+        return tree.map_structure(convert_shape, struct)
+
+    def _check_device_compatible(self):
+        from jax import default_backend as jax_device
+
+        if (
+            jax_device() == "gpu"
+            and len(tf.config.list_physical_devices("GPU")) == 0
+        ):
+            warnings.warn(
+                "JAX backend is using GPU for export, but installed "
+                "TF package cannot access GPU, so reloading the model with "
+                "the TF runtime in the same environment will not work. "
+                "To use JAX-native serialization for high-performance export "
+                "and serving, please install `tensorflow-gpu` and ensure "
+                "CUDA version compatibility between your JAX and TF "
+                "installations."
+            )
+            return False
+        else:
+            return True
diff --git a/keras/src/backend/numpy/export.py b/keras/src/backend/numpy/export.py
new file mode 100644
index 00000000000..f754c5bc633
--- /dev/null
+++ b/keras/src/backend/numpy/export.py
@@ -0,0 +1,10 @@
+class NumpyExportArchive:
+    def track(self, resource):
+        raise NotImplementedError(
+            "`track` is not implemented in the numpy backend."
+        )
+
+    def add_endpoint(self, name, fn, input_signature=None, **kwargs):
+        raise NotImplementedError(
+            "`add_endpoint` is not implemented in the numpy backend."
+        )
diff --git a/keras/src/backend/tensorflow/export.py b/keras/src/backend/tensorflow/export.py
new file mode 100644
index 00000000000..d3aaef63da5
--- /dev/null
+++ b/keras/src/backend/tensorflow/export.py
@@ -0,0 +1,32 @@
+import tensorflow as tf
+
+from keras.src import layers
+
+
+class TFExportArchive:
+    def track(self, resource):
+        if not isinstance(resource, tf.__internal__.tracking.Trackable):
+            raise ValueError(
+                "Invalid resource type. Expected an instance of a "
+                "TensorFlow `Trackable` (such as a Keras `Layer` or `Model`). "
+                f"Received instead an object of type '{type(resource)}'. "
+                f"Object received: {resource}"
+            )
+
+        if isinstance(resource, layers.Layer):
+            # Variables in the lists below are actually part of the trackables
+            # that get saved, because the lists are created in __init__.
+            variables = resource.variables
+            trainable_variables = resource.trainable_variables
+            non_trainable_variables = resource.non_trainable_variables
+            self._tf_trackable.variables += variables
+            self._tf_trackable.trainable_variables += trainable_variables
+            self._tf_trackable.non_trainable_variables += (
+                non_trainable_variables
+            )
+
+    def add_endpoint(self, name, fn, input_signature=None, **kwargs):
+        decorated_fn = tf.function(
+            fn, input_signature=input_signature, autograph=False
+        )
+        return decorated_fn
diff --git a/keras/src/backend/torch/export.py b/keras/src/backend/torch/export.py
new file mode 100644
index 00000000000..55fc68ed954
--- /dev/null
+++ b/keras/src/backend/torch/export.py
@@ -0,0 +1,35 @@
+from keras.src import layers
+from keras.src import tree
+
+
+class TorchExportArchive:
+    def track(self, resource):
+        if not isinstance(resource, layers.Layer):
+            raise ValueError(
+                "Invalid resource type. Expected an instance of a "
+                "torch-based Keras `Layer` or `Model`. 
" + f"Received instead an object of type '{type(resource)}'. " + f"Object received: {resource}" + ) + + if isinstance(resource, layers.Layer): + # Variables in the lists below are actually part of the trackables + # that get saved, because the lists are created in __init__. + variables = resource.variables + trainable_variables = resource.trainable_variables + non_trainable_variables = resource.non_trainable_variables + self._tf_trackable.variables += tree.map_structure( + self._convert_to_tf_variable, variables + ) + self._tf_trackable.trainable_variables += tree.map_structure( + self._convert_to_tf_variable, trainable_variables + ) + self._tf_trackable.non_trainable_variables += tree.map_structure( + self._convert_to_tf_variable, non_trainable_variables + ) + + def add_endpoint(self, name, fn, input_signature=None, **kwargs): + # TODO: torch-xla? + raise NotImplementedError( + "`add_endpoint` is not implemented in the torch backend." + ) diff --git a/keras/src/export/export_lib.py b/keras/src/export/export_lib.py index 923ca2e86e7..3f5e04d9309 100644 --- a/keras/src/export/export_lib.py +++ b/keras/src/export/export_lib.py @@ -1,24 +1,38 @@ """Library for exporting inference-only Keras models/layers.""" -import inspect -import itertools -import string - -from absl import logging - from keras.src import backend +from keras.src import layers from keras.src import tree from keras.src.api_export import keras_export -from keras.src.backend.common.stateless_scope import StatelessScope -from keras.src.layers import Layer from keras.src.models import Functional from keras.src.models import Sequential from keras.src.utils import io_utils from keras.src.utils.module_utils import tensorflow as tf +if backend.backend() == "tensorflow": + from keras.src.backend.tensorflow.export import ( + TFExportArchive as BackendExportArchive, + ) +elif backend.backend() == "jax": + from keras.src.backend.jax.export import ( + JaxExportArchive as BackendExportArchive, + ) +elif backend.backend() == "torch": + from keras.src.backend.torch.export import ( + TorchExportArchive as BackendExportArchive, + ) +elif backend.backend() == "numpy": + from keras.src.backend.numpy.export import ( + NumpyExportArchive as BackendExportArchive, + ) +else: + raise RuntimeError( + f"Backend '{backend.backend()}' must implement a layer mixin class." + ) + @keras_export("keras.export.ExportArchive") -class ExportArchive: +class ExportArchive(BackendExportArchive): """ExportArchive is used to write SavedModel artifacts (e.g. for inference). 
If you have a Keras model or layer that you want to export as SavedModel for @@ -42,7 +56,7 @@ class ExportArchive: export_archive.add_endpoint( name="serve", fn=model.call, - input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)], + input_signature=[keras.InputSpec(shape=(None, 3), dtype="float32")], ) export_archive.write_out("path/to/location") @@ -61,12 +75,12 @@ class ExportArchive: export_archive.add_endpoint( name="call_inference", fn=lambda x: model.call(x, training=False), - input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)], + input_signature=[keras.InputSpec(shape=(None, 3), dtype="float32")], ) export_archive.add_endpoint( name="call_training", fn=lambda x: model.call(x, training=True), - input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)], + input_signature=[keras.InputSpec(shape=(None, 3), dtype="float32")], ) export_archive.write_out("path/to/location") ``` @@ -85,6 +99,12 @@ class ExportArchive: """ def __init__(self): + super().__init__() + if backend.backend() not in ("tensorflow", "jax"): + raise NotImplementedError( + "The export API is only compatible with JAX and TF backends." + ) + self._endpoint_names = [] self._endpoint_signatures = {} self.tensorflow_version = tf.__version__ @@ -94,16 +114,6 @@ def __init__(self): self._tf_trackable.trainable_variables = [] self._tf_trackable.non_trainable_variables = [] - if backend.backend() == "jax": - self._backend_variables = [] - self._backend_trainable_variables = [] - self._backend_non_trainable_variables = [] - - if backend.backend() not in ("tensorflow", "jax"): - raise NotImplementedError( - "The export API is only compatible with JAX and TF backends." - ) - @property def variables(self): return self._tf_trackable.variables @@ -130,30 +140,11 @@ def track(self, resource): Arguments: resource: A trackable TensorFlow resource. """ - if backend.backend() == "tensorflow" and not isinstance( - resource, tf.__internal__.tracking.Trackable - ): + if isinstance(resource, layers.Layer) and not resource.built: raise ValueError( - "Invalid resource type. Expected an instance of a " - "TensorFlow `Trackable` (such as a Keras `Layer` or `Model`). " - f"Received instead an object of type '{type(resource)}'. " - f"Object received: {resource}" + "The layer provided has not yet been built. " + "It must be built before export." ) - if backend.backend() == "jax" and not isinstance( - resource, backend.jax.layer.JaxLayer - ): - raise ValueError( - "Invalid resource type. Expected an instance of a " - "JAX-based Keras `Layer` or `Model`. " - f"Received instead an object of type '{type(resource)}'. " - f"Object received: {resource}" - ) - if isinstance(resource, Layer): - if not resource.built: - raise ValueError( - "The layer provided has not yet been built. " - "It must be built before export." - ) # Layers in `_tracked` are not part of the trackables that get saved, # because we're creating the attribute in a @@ -162,66 +153,39 @@ def track(self, resource): self._tracked = [] self._tracked.append(resource) - if isinstance(resource, Layer): - # Variables in the lists below are actually part of the trackables - # that get saved, because the lists are created in __init__. 
-            if backend.backend() == "jax":
-                trainable_variables = tree.flatten(resource.trainable_variables)
-                non_trainable_variables = tree.flatten(
-                    resource.non_trainable_variables
-                )
-                self._backend_trainable_variables += trainable_variables
-                self._backend_non_trainable_variables += non_trainable_variables
-                self._backend_variables = (
-                    self._backend_trainable_variables
-                    + self._backend_non_trainable_variables
-                )
-
-                self._tf_trackable.trainable_variables += [
-                    tf.Variable(v) for v in trainable_variables
-                ]
-                self._tf_trackable.non_trainable_variables += [
-                    tf.Variable(v) for v in non_trainable_variables
-                ]
-                self._tf_trackable.variables = (
-                    self._tf_trackable.trainable_variables
-                    + self._tf_trackable.non_trainable_variables
-                )
-            else:
-                self._tf_trackable.variables += resource.variables
-                self._tf_trackable.trainable_variables += (
-                    resource.trainable_variables
-                )
-                self._tf_trackable.non_trainable_variables += (
-                    resource.non_trainable_variables
-                )
+        BackendExportArchive.track(self, resource)
 
-    def add_endpoint(self, name, fn, input_signature=None, jax2tf_kwargs=None):
+    def add_endpoint(self, name, fn, input_signature=None, **kwargs):
         """Register a new serving endpoint.
 
-        Arguments:
-            name: Str, name of the endpoint.
-            fn: A function. It should only leverage resources
-                (e.g. `tf.Variable` objects or `tf.lookup.StaticHashTable`
-                objects) that are available on the models/layers
-                tracked by the `ExportArchive` (you can call `.track(model)`
-                to track a new model).
+        Args:
+            name: `str`. The name of the endpoint.
+            fn: A callable. It should only leverage resources
+                (e.g. `keras.Variable` objects or `tf.lookup.StaticHashTable`
+                objects) that are available on the models/layers tracked by the
+                `ExportArchive` (you can call `.track(model)` to track a new
+                model).
                 The shape and dtype of the inputs to the function must be
-                known. For that purpose, you can either 1) make sure that
-                `fn` is a `tf.function` that has been called at least once, or
-                2) provide an `input_signature` argument that specifies the
-                shape and dtype of the inputs (see below).
-            input_signature: Used to specify the shape and dtype of the
-                inputs to `fn`. List of `tf.TensorSpec` objects (one
-                per positional input argument of `fn`). Nested arguments are
-                allowed (see below for an example showing a Functional model
-                with 2 input arguments).
-            jax2tf_kwargs: Optional. A dict for arguments to pass to `jax2tf`.
-                Supported only when the backend is JAX. See documentation for
-                [`jax2tf.convert`](
-                https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md).
-                The values for `native_serialization` and `polymorphic_shapes`,
-                if not provided, are automatically computed.
+                known. For that purpose, you can either 1) make sure that `fn`
+                is a `tf.function` that has been called at least once, or 2)
+                provide an `input_signature` argument that specifies the shape
+                and dtype of the inputs (see below).
+            input_signature: Optional. Specifies the shape and dtype of the
+                inputs to `fn`. Can be a structure of `keras.InputSpec`,
+                `tf.TensorSpec`, `backend.KerasTensor`, or backend tensor (see
+                below for an example showing a `Functional` model with 2 input
+                arguments). If not provided, `fn` must be a `tf.function` that
+                has been called at least once. Defaults to `None`.
+            **kwargs: Additional keyword arguments:
+                - Specific to the JAX backend:
+                    - `is_static`: Optional `bool`. Indicates whether `fn` is
+                        static. Set to `False` if `fn` involves state updates
+                        (e.g., RNG seeds). 
+ - `jax2tf_kwargs`: Optional `dict`. Arguments for + `jax2tf.convert`. See [`jax2tf.convert`]( + https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md). + If `native_serialization` and `polymorphic_shapes` are + not provided, they are automatically computed. Returns: The `tf.function` wrapping `fn` that was added to the archive. @@ -237,7 +201,7 @@ def add_endpoint(self, name, fn, input_signature=None, jax2tf_kwargs=None): export_archive.add_endpoint( name="serve", fn=model.call, - input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)], + input_signature=[keras.InputSpec(shape=(None, 3), dtype="float32")], ) ``` @@ -251,8 +215,8 @@ def add_endpoint(self, name, fn, input_signature=None, jax2tf_kwargs=None): name="serve", fn=model.call, input_signature=[ - tf.TensorSpec(shape=(None, 3), dtype=tf.float32), - tf.TensorSpec(shape=(None, 4), dtype=tf.float32), + keras.InputSpec(shape=(None, 3), dtype="float32"), + keras.InputSpec(shape=(None, 4), dtype="float32"), ], ) ``` @@ -271,8 +235,8 @@ def add_endpoint(self, name, fn, input_signature=None, jax2tf_kwargs=None): fn=model.call, input_signature=[ [ - tf.TensorSpec(shape=(None, 3), dtype=tf.float32), - tf.TensorSpec(shape=(None, 4), dtype=tf.float32), + keras.InputSpec(shape=(None, 3), dtype="float32"), + keras.InputSpec(shape=(None, 4), dtype="float32"), ], ], ) @@ -290,8 +254,8 @@ def add_endpoint(self, name, fn, input_signature=None, jax2tf_kwargs=None): fn=model.call, input_signature=[ { - "x1": tf.TensorSpec(shape=(None, 3), dtype=tf.float32), - "x2": tf.TensorSpec(shape=(None, 4), dtype=tf.float32), + "x1": keras.InputSpec(shape=(None, 3), dtype="float32"), + "x2": keras.InputSpec(shape=(None, 4), dtype="float32"), }, ], ) @@ -315,73 +279,15 @@ def serving_fn(x): if name in self._endpoint_names: raise ValueError(f"Endpoint name '{name}' is already taken.") - if jax2tf_kwargs and backend.backend() != "jax": - raise ValueError( - "'jax2tf_kwargs' is only supported with the jax backend. " - f"Current backend: {backend.backend()}" - ) - - if input_signature: - if backend.backend() == "tensorflow": - decorated_fn = tf.function( - fn, input_signature=input_signature, autograph=False - ) - else: # JAX backend - # 1. Create a stateless wrapper for `fn` - # 2. jax2tf the stateless wrapper - # 3. Create a stateful function that binds the variables with - # the jax2tf converted stateless wrapper - # 4. Make the signature of the stateful function the same as the - # original function - # 5. 
Wrap in a `tf.function` - def stateless_fn(variables, *args, **kwargs): - state_mapping = zip(self._backend_variables, variables) - with StatelessScope(state_mapping=state_mapping) as scope: - output = fn(*args, **kwargs) - - # Gather updated non-trainable variables - non_trainable_variables = [] - for var in self._backend_non_trainable_variables: - new_value = scope.get_current_value(var) - non_trainable_variables.append(new_value) - return output, non_trainable_variables - - jax2tf_stateless_fn = self._convert_jax2tf_function( - stateless_fn, - input_signature, - jax2tf_kwargs=jax2tf_kwargs, - ) - - def stateful_fn(*args, **kwargs): - output, non_trainable_variables = jax2tf_stateless_fn( - # Change the trackable `ListWrapper` to a plain `list` - list(self._tf_trackable.variables), - *args, - **kwargs, - ) - for var, new_value in zip( - self._tf_trackable.non_trainable_variables, - non_trainable_variables, - ): - var.assign(new_value) - return output - - # Note: we truncate the number of parameters to what is - # specified by `input_signature`. - fn_signature = inspect.signature(fn) - fn_parameters = list(fn_signature.parameters.values()) - stateful_fn.__signature__ = inspect.Signature( - parameters=fn_parameters[0 : len(input_signature)], - return_annotation=fn_signature.return_annotation, + if backend.backend() != "jax": + if "jax2tf_kwargs" in kwargs or "is_static" in kwargs: + raise ValueError( + "'jax2tf_kwargs' and 'is_static' are only supported with " + f"the jax backend. Current backend: {backend.backend()}" ) - decorated_fn = tf.function( - stateful_fn, - input_signature=input_signature, - autograph=False, - ) - self._endpoint_signatures[name] = input_signature - else: + # The fast path if `fn` is already a `tf.function`. + if input_signature is None: if isinstance(fn, tf.types.experimental.GenericFunction): if not fn._list_all_concrete_functions(): raise ValueError( @@ -404,13 +310,22 @@ def stateful_fn(*args, **kwargs): " name='call',\n" " fn=model.call,\n" " input_signature=[\n" - " tf.TensorSpec(\n" + " keras.InputSpec(\n" " shape=(None, 224, 224, 3),\n" - " dtype=tf.float32,\n" + " dtype='float32',\n" " )\n" " ],\n" ")" ) + setattr(self._tf_trackable, name, decorated_fn) + self._endpoint_names.append(name) + return decorated_fn + + input_signature = tree.map_structure(_make_tensor_spec, input_signature) + decorated_fn = BackendExportArchive.add_endpoint( + self, name, fn, input_signature, **kwargs + ) + self._endpoint_signatures[name] = input_signature setattr(self._tf_trackable, name, decorated_fn) self._endpoint_names.append(name) return decorated_fn @@ -431,7 +346,7 @@ def add_variable_collection(self, name, variables): export_archive.add_endpoint( name="serve", fn=model.call, - input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)], + input_signature=[keras.InputSpec(shape=(None, 3), dtype="float32")], ) # Save a variable collection export_archive.add_variable_collection( @@ -460,7 +375,9 @@ def add_variable_collection(self, name, variables): f"{list(set(type(v) for v in variables))}" ) if backend.backend() == "jax": - variables = tree.flatten(tree.map_structure(tf.Variable, variables)) + variables = tree.flatten( + tree.map_structure(self._convert_to_tf_variable, variables) + ) setattr(self._tf_trackable, name, list(variables)) def write_out(self, filepath, options=None, verbose=True): @@ -517,6 +434,20 @@ def write_out(self, filepath, options=None, verbose=True): f"{endpoints}" ) + def _convert_to_tf_variable(self, backend_variable): + if not 
isinstance(backend_variable, backend.Variable):
+            raise TypeError(
+                "`backend_variable` must be a `backend.Variable`. "
+                f"Received: backend_variable={backend_variable} of type "
+                f"({type(backend_variable)})"
+            )
+        return tf.Variable(
+            backend_variable.value,
+            dtype=backend_variable.dtype,
+            trainable=backend_variable.trainable,
+            name=backend_variable.name,
+        )
+
     def _get_concrete_fn(self, endpoint):
         """Workaround for some SavedModel quirks."""
         if endpoint in self._endpoint_signatures:
@@ -555,94 +486,81 @@ def _filter_and_track_resources(self):
         ):
             self._tf_trackable._misc_assets.append(trackable)
 
-    def _convert_jax2tf_function(self, fn, input_signature, jax2tf_kwargs=None):
-        from jax.experimental import jax2tf
-
-        if jax2tf_kwargs is None:
-            jax2tf_kwargs = {}
-
-        if "native_serialization" not in jax2tf_kwargs:
-            jax2tf_kwargs["native_serialization"] = (
-                self._check_device_compatible()
-            )
-        variables_shapes = self._to_polymorphic_shape(
-            self._backend_variables, allow_none=False
-        )
-        if "polymorphic_shapes" in jax2tf_kwargs:
-            input_shapes = jax2tf_kwargs["polymorphic_shapes"]
-        else:
-            input_shapes = self._to_polymorphic_shape(input_signature)
-        jax2tf_kwargs["polymorphic_shapes"] = [variables_shapes] + input_shapes
-
-        return jax2tf.convert(fn, **jax2tf_kwargs)
-
-    def _to_polymorphic_shape(self, struct, allow_none=True):
-        if allow_none:
-            # Generates unique names: a, b, ... z, aa, ab, ... az, ba, ... zz
-            # for unknown non-batch dims. Defined here to be scope per endpoint.
-            dim_names = itertools.chain(
-                string.ascii_lowercase,
-                itertools.starmap(
-                    lambda a, b: a + b,
-                    itertools.product(string.ascii_lowercase, repeat=2),
-                ),
-            )
+def export_saved_model(
+    model, filepath, verbose=True, input_signature=None, **kwargs
+):
+    """Export the model as a TensorFlow SavedModel artifact for inference.
+
+    **Note:** This feature is currently supported only with TensorFlow and
+    JAX backends.
+
+    This method lets you export a model to a lightweight SavedModel artifact
+    that contains the model's forward pass only (its `call()` method)
+    and can be served via e.g. TensorFlow Serving. The forward pass is
+    registered under the name `serve()` (see example below).
+
+    The original code of the model (including any custom layers you may
+    have used) is *no longer* necessary to reload the artifact -- it is
+    entirely standalone.
+
+    Args:
+        filepath: `str` or `pathlib.Path` object. The path to save the artifact.
+        verbose: `bool`. Whether to print a message during export. Defaults to
+            `True`.
+        input_signature: Optional. Specifies the shape and dtype of the model
+            inputs. Can be a structure of `keras.InputSpec`, `tf.TensorSpec`,
+            `backend.KerasTensor`, or backend tensor. If not provided, it will
+            be automatically computed. Defaults to `None`.
+        **kwargs: Additional keyword arguments:
+            - Specific to the JAX backend:
+                - `is_static`: Optional `bool`. Indicates whether `fn` is
+                    static. Set to `False` if `fn` involves state updates
+                    (e.g., RNG seeds).
+                - `jax2tf_kwargs`: Optional `dict`. Arguments for
+                    `jax2tf.convert`. See [`jax2tf.convert`](
+                    https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md).
+                    If `native_serialization` and `polymorphic_shapes` are not
+                    provided, they are automatically computed. 
- def convert_shape(x): - poly_shape = [] - for index, dim in enumerate(list(x.shape)): - if dim is not None: - poly_shape.append(str(dim)) - elif not allow_none: - raise ValueError( - f"Illegal None dimension in {x} with shape {x.shape}" - ) - elif index == 0: - poly_shape.append("batch") - else: - poly_shape.append(next(dim_names)) - return "(" + ", ".join(poly_shape) + ")" - - return tree.map_structure(convert_shape, struct) - - def _check_device_compatible(self): - from jax import default_backend as jax_device + Example: - if ( - jax_device() == "gpu" - and len(tf.config.list_physical_devices("GPU")) == 0 - ): - logging.warning( - "JAX backend is using GPU for export, but installed " - "TF package cannot access GPU, so reloading the model with " - "the TF runtime in the same environment will not work. " - "To use JAX-native serialization for high-performance export " - "and serving, please install `tensorflow-gpu` and ensure " - "CUDA version compatibility between your JAX and TF " - "installations." - ) - return False - else: - return True + ```python + # Export the model as a TensorFlow SavedModel artifact + model.export("path/to/location", format="tf_saved_model") + # Load the artifact in a different process/environment + reloaded_artifact = tf.saved_model.load("path/to/location") + predictions = reloaded_artifact.serve(input_data) + ``` -def export_model(model, filepath, verbose=True): + If you would like to customize your serving endpoints, you can + use the lower-level `keras.export.ExportArchive` class. The + `export()` method relies on `ExportArchive` internally. + """ export_archive = ExportArchive() export_archive.track(model) if isinstance(model, (Functional, Sequential)): - input_signature = tree.map_structure(_make_tensor_spec, model.inputs) + if input_signature is None: + input_signature = tree.map_structure( + _make_tensor_spec, model.inputs + ) if isinstance(input_signature, list) and len(input_signature) > 1: input_signature = [input_signature] - export_archive.add_endpoint("serve", model.__call__, input_signature) + export_archive.add_endpoint( + "serve", model.__call__, input_signature, **kwargs + ) else: - input_signature = _get_input_signature(model) + if input_signature is None: + input_signature = _get_input_signature(model) if not input_signature or not model._called: raise ValueError( "The model provided has never called. " "It must be called at least once before export." ) - export_archive.add_endpoint("serve", model.__call__, input_signature) + export_archive.add_endpoint( + "serve", model.__call__, input_signature, **kwargs + ) export_archive.write_out(filepath, verbose=verbose) @@ -677,7 +595,7 @@ def make_tensor_spec(structure): @keras_export("keras.layers.TFSMLayer") -class TFSMLayer(Layer): +class TFSMLayer(layers.Layer): """Reload a Keras model/layer that was saved via SavedModel / ExportArchive. Arguments: @@ -811,8 +729,28 @@ def get_config(self): def _make_tensor_spec(x): - shape = (None,) + x.shape[1:] - return tf.TensorSpec(shape, dtype=x.dtype, name=x.name) + if isinstance(x, layers.InputSpec): + if x.shape is None or x.dtype is None: + raise ValueError( + "The `shape` and `dtype` must be provided. 
" f"Received: x={x}" + ) + tensor_spec = tf.TensorSpec(x.shape, dtype=x.dtype, name=x.name) + elif isinstance(x, tf.TensorSpec): + tensor_spec = x + elif isinstance(x, backend.KerasTensor): + shape = (None,) + backend.standardize_shape(x.shape)[1:] + tensor_spec = tf.TensorSpec(shape, dtype=x.dtype, name=x.name) + elif backend.is_tensor(x): + shape = (None,) + backend.standardize_shape(x.shape)[1:] + dtype = backend.standardize_dtype(x.dtype) + tensor_spec = tf.TensorSpec(shape, dtype=dtype, name=None) + else: + raise TypeError( + f"Unsupported x={x} of the type ({type(x)}). Supported types are: " + "`keras.InputSpec`, `tf.TensorSpec`, `keras.KerasTensor` and " + "backend tensor." + ) + return tensor_spec def _print_signature(fn, name, verbose=True): diff --git a/keras/src/export/export_lib_test.py b/keras/src/export/export_lib_test.py index 7e5a9c52010..def6203df61 100644 --- a/keras/src/export/export_lib_test.py +++ b/keras/src/export/export_lib_test.py @@ -54,7 +54,7 @@ def get_model(type="sequential", input_shape=(10,), layer_list=None): reason="Export only currently supports the TF and JAX backends.", ) @pytest.mark.skipif(testing.jax_uses_gpu(), reason="Leads to core dumps on CI") -class ExportArchiveTest(testing.TestCase): +class ExportSavedModelTest(testing.TestCase): @parameterized.named_parameters( named_product(model_type=["sequential", "functional", "subclass"]) ) @@ -64,7 +64,7 @@ def test_standard_model_export(self, model_type): ref_input = tf.random.normal((3, 10)) ref_output = model(ref_input) - export_lib.export_model(model, temp_filepath) + export_lib.export_saved_model(model, temp_filepath) revived_model = tf.saved_model.load(temp_filepath) self.assertAllClose(ref_output, revived_model.serve(ref_input)) # Test with a different batch size @@ -89,7 +89,7 @@ def call(self, inputs): ref_input = tf.random.normal((3, 10)) ref_output = model(ref_input) - export_lib.export_model(model, temp_filepath) + export_lib.export_saved_model(model, temp_filepath) revived_model = tf.saved_model.load(temp_filepath) self.assertEqual(ref_output.shape, revived_model.serve(ref_input).shape) # Test with a different batch size @@ -118,7 +118,7 @@ def call(self, inputs): model = get_model(model_type, layer_list=[StateLayer()]) model(tf.random.normal((3, 10))) - export_lib.export_model(model, temp_filepath) + export_lib.export_saved_model(model, temp_filepath) revived_model = tf.saved_model.load(temp_filepath) # The non-trainable counter is expected to increment @@ -139,7 +139,7 @@ def test_model_with_tf_data_layer(self, model_type): ref_input = tf.random.normal((3, 10)) ref_output = model(ref_input) - export_lib.export_model(model, temp_filepath) + export_lib.export_saved_model(model, temp_filepath) revived_model = tf.saved_model.load(temp_filepath) self.assertAllClose(ref_output, revived_model.serve(ref_input)) # Test with a different batch size @@ -182,7 +182,7 @@ def call(self, inputs): temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") ref_output = model(tree.map_structure(ops.convert_to_tensor, ref_input)) - export_lib.export_model(model, temp_filepath) + export_lib.export_saved_model(model, temp_filepath) revived_model = tf.saved_model.load(temp_filepath) self.assertAllClose(ref_output, revived_model.serve(ref_input)) # Test with a different batch size @@ -205,7 +205,7 @@ def call(self, inputs): }, ) self.assertAllClose(ref_output, revived_model(ref_input)) - export_lib.export_model(revived_model, self.get_temp_dir()) + export_lib.export_saved_model(revived_model, 
self.get_temp_dir()) def test_model_with_multiple_inputs(self): class TwoInputsModel(models.Model): @@ -221,7 +221,7 @@ def build(self, y_shape, x_shape): ref_input_y = tf.random.normal((3, 10)) ref_output = model(ref_input_x, ref_input_y) - export_lib.export_model(model, temp_filepath) + export_lib.export_saved_model(model, temp_filepath) revived_model = tf.saved_model.load(temp_filepath) self.assertAllClose( ref_output, revived_model.serve(ref_input_x, ref_input_y) @@ -231,6 +231,80 @@ def build(self, y_shape, x_shape): tf.random.normal((6, 10)), tf.random.normal((6, 10)) ) + @parameterized.named_parameters( + named_product( + model_type=["sequential", "functional", "subclass"], + input_signature=[ + layers.InputSpec( + dtype="float32", shape=(None, 10), name="inputs" + ), + tf.TensorSpec((None, 10), dtype="float32", name="inputs"), + backend.KerasTensor((None, 10), dtype="float32", name="inputs"), + "backend_tensor", + ], + ) + ) + def test_input_signature(self, model_type, input_signature): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + model = get_model(model_type) + ref_input = ops.random.uniform((3, 10)) + ref_output = model(ref_input) + + if input_signature == "backend_tensor": + input_signature = (ref_input,) + else: + input_signature = (input_signature,) + export_lib.export_saved_model( + model, temp_filepath, input_signature=input_signature + ) + revived_model = tf.saved_model.load(temp_filepath) + self.assertAllClose(ref_output, revived_model.serve(ref_input)) + + def test_input_signature_error(self): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + model = get_model("functional") + with self.assertRaisesRegex(TypeError, "Unsupported x="): + input_signature = (123,) + export_lib.export_saved_model( + model, temp_filepath, input_signature=input_signature + ) + + @parameterized.named_parameters( + named_product( + model_type=["sequential", "functional", "subclass"], + is_static=(True, False), + jax2tf_kwargs=( + None, + {"enable_xla": False, "native_serialization": False}, + ), + ) + ) + @pytest.mark.skipif( + backend.backend() != "jax", + reason="This test is only for the jax backend.", + ) + def test_jax_specific_kwargs(self, model_type, is_static, jax2tf_kwargs): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + model = get_model(model_type) + ref_input = ops.random.uniform((3, 10)) + ref_output = model(ref_input) + + export_lib.export_saved_model( + model, + temp_filepath, + is_static=is_static, + jax2tf_kwargs=jax2tf_kwargs, + ) + revived_model = tf.saved_model.load(temp_filepath) + self.assertAllClose(ref_output, revived_model.serve(ref_input)) + + +@pytest.mark.skipif( + backend.backend() not in ("tensorflow", "jax"), + reason="Export only currently supports the TF and JAX backends.", +) +@pytest.mark.skipif(testing.jax_uses_gpu(), reason="Leads to core dumps on CI") +class ExportArchiveTest(testing.TestCase): @parameterized.named_parameters( named_product(model_type=["sequential", "functional", "subclass"]) ) @@ -679,7 +753,7 @@ def test_multi_input_output_functional_model(self): # ref_input = tf.convert_to_tensor(["one two three four"]) # ref_output = model(ref_input) - # export_lib.export_model(model, temp_filepath) + # export_lib.export_saved_model(model, temp_filepath) # revived_model = tf.saved_model.load(temp_filepath) # self.assertAllClose(ref_output, revived_model.serve(ref_input)) @@ -782,19 +856,19 @@ def test_variable_collection(self): revived_model = tf.saved_model.load(temp_filepath) 
self.assertLen(revived_model.my_vars, 2) - def test_export_model_errors(self): + def test_export_saved_model_errors(self): temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") # Model has not been built model = models.Sequential([layers.Dense(2)]) with self.assertRaisesRegex(ValueError, "It must be built"): - export_lib.export_model(model, temp_filepath) + export_lib.export_saved_model(model, temp_filepath) # Subclassed model has not been called model = get_model("subclass") model.build((2, 10)) with self.assertRaisesRegex(ValueError, "It must be called"): - export_lib.export_model(model, temp_filepath) + export_lib.export_saved_model(model, temp_filepath) def test_export_archive_errors(self): temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") @@ -893,7 +967,7 @@ def test_reloading_export_archive(self): ref_input = tf.random.normal((3, 10)) ref_output = model(ref_input) - export_lib.export_model(model, temp_filepath) + export_lib.export_saved_model(model, temp_filepath) reloaded_layer = export_lib.TFSMLayer(temp_filepath) self.assertAllClose(reloaded_layer(ref_input), ref_output, atol=1e-7) self.assertLen(reloaded_layer.weights, len(model.weights)) @@ -977,7 +1051,7 @@ def test_serialization(self): ref_input = tf.random.normal((3, 10)) ref_output = model(ref_input) - export_lib.export_model(model, temp_filepath) + export_lib.export_saved_model(model, temp_filepath) reloaded_layer = export_lib.TFSMLayer(temp_filepath) # Test reinstantiation from config @@ -999,7 +1073,7 @@ def test_errors(self): # Test missing call endpoint temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") model = models.Sequential([layers.Input((2,)), layers.Dense(3)]) - export_lib.export_model(model, temp_filepath) + export_lib.export_saved_model(model, temp_filepath) with self.assertRaisesRegex(ValueError, "The endpoint 'wrong'"): export_lib.TFSMLayer(temp_filepath, call_endpoint="wrong") diff --git a/keras/src/layers/__init__.py b/keras/src/layers/__init__.py index 9d987923718..34ffc2d327b 100644 --- a/keras/src/layers/__init__.py +++ b/keras/src/layers/__init__.py @@ -30,6 +30,7 @@ from keras.src.layers.core.lambda_layer import Lambda from keras.src.layers.core.masking import Masking from keras.src.layers.core.wrapper import Wrapper +from keras.src.layers.input_spec import InputSpec from keras.src.layers.layer import Layer from keras.src.layers.merging.add import Add from keras.src.layers.merging.add import add diff --git a/keras/src/layers/core/dense_test.py b/keras/src/layers/core/dense_test.py index 6ef9a55f42c..2c2faac218a 100644 --- a/keras/src/layers/core/dense_test.py +++ b/keras/src/layers/core/dense_test.py @@ -565,7 +565,7 @@ def test_quantize_int8_when_lora_enabled(self): temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") ref_input = tf.random.normal((2, 8)) ref_output = model(ref_input) - export_lib.export_model(model, temp_filepath) + model.export(temp_filepath, format="tf_saved_model") reloaded_layer = export_lib.TFSMLayer(temp_filepath) self.assertAllClose( reloaded_layer(ref_input), ref_output, atol=1e-7 @@ -737,7 +737,7 @@ def test_quantize_float8_fitting(self): temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") ref_input = tf.random.normal((2, 8)) ref_output = model(ref_input) - export_lib.export_model(model, temp_filepath) + model.export(temp_filepath, format="tf_saved_model") reloaded_layer = export_lib.TFSMLayer(temp_filepath) self.assertAllClose(reloaded_layer(ref_input), ref_output) 
self.assertLen(reloaded_layer.weights, len(model.weights)) diff --git a/keras/src/layers/core/einsum_dense_test.py b/keras/src/layers/core/einsum_dense_test.py index 06409ed6f55..796cb37fd76 100644 --- a/keras/src/layers/core/einsum_dense_test.py +++ b/keras/src/layers/core/einsum_dense_test.py @@ -698,7 +698,7 @@ def test_quantize_int8_when_lora_enabled( temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") ref_input = tf.random.normal(input_shape) ref_output = model(ref_input) - export_lib.export_model(model, temp_filepath) + model.export(temp_filepath, format="tf_saved_model") reloaded_layer = export_lib.TFSMLayer(temp_filepath) self.assertAllClose( reloaded_layer(ref_input), ref_output, atol=1e-7 @@ -877,7 +877,7 @@ def test_quantize_float8_fitting(self): temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") ref_input = tf.random.normal((2, 3)) ref_output = model(ref_input) - export_lib.export_model(model, temp_filepath) + model.export(temp_filepath, format="tf_saved_model") reloaded_layer = export_lib.TFSMLayer(temp_filepath) self.assertAllClose(reloaded_layer(ref_input), ref_output) self.assertLen(reloaded_layer.weights, len(model.weights)) diff --git a/keras/src/layers/core/embedding_test.py b/keras/src/layers/core/embedding_test.py index 1e4f6c69258..ac4b6d6c8c7 100644 --- a/keras/src/layers/core/embedding_test.py +++ b/keras/src/layers/core/embedding_test.py @@ -438,7 +438,7 @@ def test_quantize_when_lora_enabled(self): temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") ref_input = tf.random.normal((32, 3)) ref_output = model(ref_input) - export_lib.export_model(model, temp_filepath) + model.export(temp_filepath, format="tf_saved_model") reloaded_layer = export_lib.TFSMLayer(temp_filepath) self.assertAllClose( reloaded_layer(ref_input), ref_output, atol=1e-7 diff --git a/keras/src/models/model.py b/keras/src/models/model.py index 3b5b6d8fc0a..9684995d19c 100644 --- a/keras/src/models/model.py +++ b/keras/src/models/model.py @@ -458,44 +458,72 @@ def to_json(self, **kwargs): model_config = serialization_lib.serialize_keras_object(self) return json.dumps(model_config, **kwargs) - def export(self, filepath, format="tf_saved_model", verbose=True): - """Create a TF SavedModel artifact for inference. - - **Note:** This can currently only be used with - the TensorFlow or JAX backends. - - This method lets you export a model to a lightweight SavedModel artifact - that contains the model's forward pass only (its `call()` method) - and can be served via e.g. TF-Serving. The forward pass is registered - under the name `serve()` (see example below). + def export( + self, + filepath, + format="tf_saved_model", + verbose=True, + input_signature=None, + **kwargs, + ): + """Export the model as an artifact for inference. - The original code of the model (including any custom layers you may - have used) is *no longer* necessary to reload the artifact -- it is - entirely standalone. + **Note:** This feature is currently supported only with TensorFlow and + JAX backends. + **Note:** Currently, only `format="tf_saved_model"` is supported. Args: - filepath: `str` or `pathlib.Path` object. Path where to save - the artifact. - verbose: whether to print all the variables of the exported model. + filepath: `str` or `pathlib.Path` object. The path to save the + artifact. + format: `str`. The export format. Supported value: + `"tf_saved_model"`. Defaults to `"tf_saved_model"`. + verbose: `bool`. Whether to print a message during export. Defaults + to `True`. 
+ input_signature: Optional. Specifies the shape and dtype of the + model inputs. Can be a structure of `keras.InputSpec`, + `tf.TensorSpec`, `backend.KerasTensor`, or backend tensor. If + not provided, it will be automatically computed. Defaults to + `None`. + **kwargs: Additional keyword arguments: + - Specific to the JAX backend: + - `is_static`: Optional `bool`. Indicates whether `fn` is + static. Set to `False` if `fn` involves state updates + (e.g., RNG seeds and counters). + - `jax2tf_kwargs`: Optional `dict`. Arguments for + `jax2tf.convert`. See the documentation for + [`jax2tf.convert`]( + https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md). + If `native_serialization` and `polymorphic_shapes` are + not provided, they will be automatically computed. Example: ```python - # Create the artifact - model.export("path/to/location") + # Export the model as a TensorFlow SavedModel artifact + model.export("path/to/location", format="tf_saved_model") - # Later, in a different process/environment... + # Load the artifact in a different process/environment reloaded_artifact = tf.saved_model.load("path/to/location") predictions = reloaded_artifact.serve(input_data) ``` - - If you would like to customize your serving endpoints, you can - use the lower-level `keras.export.ExportArchive` class. The - `export()` method relies on `ExportArchive` internally. """ from keras.src.export import export_lib - export_lib.export_model(self, filepath, verbose) + available_formats = ("tf_saved_model",) + if format not in available_formats: + raise ValueError( + f"Unrecognized format={format}. Supported formats are: " + f"{list(available_formats)}." + ) + + if format == "tf_saved_model": + export_lib.export_saved_model( + self, + filepath, + verbose, + input_signature=input_signature, + **kwargs, + ) @classmethod def from_config(cls, config, custom_objects=None): diff --git a/keras/src/models/model_test.py b/keras/src/models/model_test.py index 356c7598183..de7fd98e9db 100644 --- a/keras/src/models/model_test.py +++ b/keras/src/models/model_test.py @@ -1,3 +1,4 @@ +import os import pickle from collections import namedtuple @@ -1217,3 +1218,47 @@ def test_functional_deeply_nested_outputs_struct_losses(self): ] ) self.assertListEqual(hist_keys, ref_keys) + + @pytest.mark.skipif( + backend.backend() not in ("tensorflow", "jax"), + reason=( + "Currently, `Model.export` only supports the tensorflow and jax" + " backends." 
+ ), + ) + @pytest.mark.skipif( + testing.jax_uses_gpu(), reason="Leads to core dumps on CI" + ) + def test_export(self): + import tensorflow as tf + + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + model = _get_model() + x1 = np.random.rand(2, 3) + x2 = np.random.rand(2, 3) + ref_output = model([x1, x2]) + + model.export(temp_filepath) + revived_model = tf.saved_model.load(temp_filepath) + self.assertAllClose(ref_output, revived_model.serve([x1, x2])) + + # Test with a different batch size + revived_model.serve( + [np.concatenate([x1, x1], axis=0), np.concatenate([x2, x2], axis=0)] + ) + + def test_export_error(self): + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + model = _get_model() + + # Bad format + with self.assertRaisesRegex(ValueError, "Unrecognized format="): + model.export(temp_filepath, format="bad_format") + + # Bad backend + if backend.backend() not in ("tensorflow", "jax"): + with self.assertRaisesRegex( + NotImplementedError, + "The export API is only compatible with JAX and TF backends.", + ): + model.export(temp_filepath) diff --git a/keras/src/trainers/data_adapters/py_dataset_adapter_test.py b/keras/src/trainers/data_adapters/py_dataset_adapter_test.py index 9b3bf88331f..f4e3c95d914 100644 --- a/keras/src/trainers/data_adapters/py_dataset_adapter_test.py +++ b/keras/src/trainers/data_adapters/py_dataset_adapter_test.py @@ -87,8 +87,8 @@ def num_batches(self): def __getitem__(self, index): if index < 2: return ( - np.random.random((64, 4)).astype("float32"), - np.random.random((64, 2)).astype("float32"), + np.random.random((8, 4)).astype("float32"), + np.random.random((8, 2)).astype("float32"), ) raise ValueError("Expected exception") @@ -229,7 +229,7 @@ def test_speedup(self): x, y, batch_size=4, - delay=0.5, + delay=0.2, ) adapter = py_dataset_adapter.PyDatasetAdapter( no_speedup_py_dataset, shuffle=False @@ -249,7 +249,7 @@ def test_speedup(self): # multiprocessing # use_multiprocessing=True, max_queue_size=8, - delay=0.5, + delay=0.2, ) adapter = py_dataset_adapter.PyDatasetAdapter( speedup_py_dataset, shuffle=False @@ -361,6 +361,11 @@ def test_exception_reported( use_multiprocessing=False, max_queue_size=0, ): + if backend.backend() == "jax" and use_multiprocessing is True: + self.skipTest( + "The CI failed for an unknown reason with " + "`use_multiprocessing=True` in the jax backend" + ) dataset = ExceptionPyDataset( workers=workers, use_multiprocessing=use_multiprocessing, diff --git a/keras/src/utils/jax_layer_test.py b/keras/src/utils/jax_layer_test.py index 2c85ecc2e11..359bdca41c9 100644 --- a/keras/src/utils/jax_layer_test.py +++ b/keras/src/utils/jax_layer_test.py @@ -15,7 +15,6 @@ from keras.src import testing from keras.src import tree from keras.src import utils -from keras.src.export import export_lib from keras.src.saving import object_registration from keras.src.utils.jax_layer import FlaxLayer from keras.src.utils.jax_layer import JaxLayer @@ -321,7 +320,7 @@ def verify_identical_model(model): # export, load back and compare results path = os.path.join(self.get_temp_dir(), "jax_layer_export") - export_lib.export_model(model2, path) + model2.export(path, format="tf_saved_model") model4 = tf.saved_model.load(path) output4 = model4.serve(x_test) self.assertAllClose(output1, output4)
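
Taken together, the rework keeps two entry points: the one-shot `Model.export(..., format="tf_saved_model")` and the lower-level `keras.export.ExportArchive` for custom endpoints. Below is a minimal usage sketch of both paths, assembled from the docstrings and tests in this patch; the model, file paths, and shapes are illustrative only:

```python
import keras
import tensorflow as tf

# Illustrative model; any built (or already-called) Keras model works.
model = keras.Sequential([keras.layers.Input((10,)), keras.layers.Dense(2)])

# One-shot export. `input_signature` is optional; when omitted it is
# computed from the model. On the JAX backend, `is_static` and
# `jax2tf_kwargs` can additionally be passed through `**kwargs`.
model.export(
    "path/to/location",
    format="tf_saved_model",
    input_signature=[keras.InputSpec(shape=(None, 10), dtype="float32")],
)

# Lower-level path: register custom endpoints on an `ExportArchive`.
export_archive = keras.export.ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
    name="serve",
    fn=model.call,
    input_signature=[keras.InputSpec(shape=(None, 10), dtype="float32")],
)
export_archive.write_out("path/to/archive")

# Reloading needs only the TF runtime, not the original model code.
reloaded_artifact = tf.saved_model.load("path/to/location")
predictions = reloaded_artifact.serve(tf.random.normal((3, 10)))
```

Routing both paths through the per-backend `BackendExportArchive` mixin is what leaves room for the TFLite and ONNX formats mentioned in the title: a new format (or backend) only needs its own implementation of `track` and `add_endpoint`.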