Rework `Model.export` and `keras.export.ExportArchive` to support exporting in TFLite and ONNX formats in the future (#20631)

* Rework `Model.export` and `keras.export.ExportArchive`
* Try fixing PyDatasetAdapterTest CI issues

1 parent 32a642d · commit 5a9b26d
Showing 14 changed files with 662 additions and 301 deletions.
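
For orientation before the diffs: the workflow this commit reworks is `keras.export.ExportArchive`. Below is a minimal usage sketch under the TensorFlow backend; the toy model, the `serve` endpoint name, and the output path are illustrative, not from this commit.

import tensorflow as tf
import keras
from keras.export import ExportArchive

# Toy model; any Keras model can be tracked the same way (illustrative).
model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])

archive = ExportArchive()
archive.track(model)  # registers the model's variables with the archive
archive.add_endpoint(
    name="serve",
    fn=model.call,
    input_signature=[tf.TensorSpec(shape=(None, 4), dtype=tf.float32)],
)
archive.write_out("exported_model")  # writes a TF SavedModel directory

The higher-level `model.export(...)` call wraps this same flow; per the commit title, TFLite and ONNX output formats are future work.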

@@ -0,0 +1,194 @@
import copy
import inspect
import itertools
import string
import warnings

from keras.src import layers
from keras.src import tree
from keras.src.backend.common.stateless_scope import StatelessScope
from keras.src.utils.module_utils import tensorflow as tf


class JaxExportArchive:
    def __init__(self):
        self._backend_variables = []
        self._backend_trainable_variables = []
        self._backend_non_trainable_variables = []

    def track(self, resource):
        if not isinstance(resource, layers.Layer):
            raise ValueError(
                "Invalid resource type. Expected an instance of a "
                "JAX-based Keras `Layer` or `Model`. "
                f"Received instead an object of type '{type(resource)}'. "
                f"Object received: {resource}"
            )

        if isinstance(resource, layers.Layer):
            # Variables in the lists below are actually part of the trackables
            # that get saved, because the lists are created in __init__.
            trainable_variables = resource.trainable_variables
            non_trainable_variables = resource.non_trainable_variables

            self._tf_trackable.trainable_variables += tree.map_structure(
                self._convert_to_tf_variable, trainable_variables
            )
            self._tf_trackable.non_trainable_variables += tree.map_structure(
                self._convert_to_tf_variable, non_trainable_variables
            )
            self._tf_trackable.variables = (
                self._tf_trackable.trainable_variables
                + self._tf_trackable.non_trainable_variables
            )

            self._backend_trainable_variables += trainable_variables
            self._backend_non_trainable_variables += non_trainable_variables
            self._backend_variables = (
                self._backend_trainable_variables
                + self._backend_non_trainable_variables
            )

    def add_endpoint(self, name, fn, input_signature=None, **kwargs):
        jax2tf_kwargs = kwargs.pop("jax2tf_kwargs", None)
        # Use `copy.copy()` to avoid modification issues.
        jax2tf_kwargs = copy.copy(jax2tf_kwargs) or {}
        is_static = bool(kwargs.pop("is_static", False))

        # Configure `jax2tf_kwargs`
        if "native_serialization" not in jax2tf_kwargs:
            jax2tf_kwargs["native_serialization"] = (
                self._check_device_compatible()
            )
        if "polymorphic_shapes" not in jax2tf_kwargs:
            jax2tf_kwargs["polymorphic_shapes"] = self._to_polymorphic_shape(
                input_signature
            )

        # Note: we truncate the number of parameters to what is specified by
        # `input_signature`.
        fn_signature = inspect.signature(fn)
        fn_parameters = list(fn_signature.parameters.values())

        if is_static:
            from jax.experimental import jax2tf

            jax_fn = jax2tf.convert(fn, **jax2tf_kwargs)
            jax_fn.__signature__ = inspect.Signature(
                parameters=fn_parameters[0 : len(input_signature)],
                return_annotation=fn_signature.return_annotation,
            )

            decorated_fn = tf.function(
                jax_fn,
                input_signature=input_signature,
                autograph=False,
            )
        else:
            # 1. Create a stateless wrapper for `fn`
            # 2. jax2tf the stateless wrapper
            # 3. Create a stateful function that binds the variables with
            #    the jax2tf converted stateless wrapper
            # 4. Make the signature of the stateful function the same as the
            #    original function
            # 5. Wrap in a `tf.function`
            def stateless_fn(variables, *args, **kwargs):
                state_mapping = zip(self._backend_variables, variables)
                with StatelessScope(state_mapping=state_mapping) as scope:
                    output = fn(*args, **kwargs)

                # Gather updated non-trainable variables
                non_trainable_variables = []
                for var in self._backend_non_trainable_variables:
                    new_value = scope.get_current_value(var)
                    non_trainable_variables.append(new_value)
                return output, non_trainable_variables

            jax2tf_stateless_fn = self._convert_jax2tf_function(
                stateless_fn, input_signature, jax2tf_kwargs=jax2tf_kwargs
            )

            def stateful_fn(*args, **kwargs):
                output, non_trainable_variables = jax2tf_stateless_fn(
                    # Change the trackable `ListWrapper` to a plain `list`
                    list(self._tf_trackable.variables),
                    *args,
                    **kwargs,
                )
                for var, new_value in zip(
                    self._tf_trackable.non_trainable_variables,
                    non_trainable_variables,
                ):
                    var.assign(new_value)
                return output

            stateful_fn.__signature__ = inspect.Signature(
                parameters=fn_parameters[0 : len(input_signature)],
                return_annotation=fn_signature.return_annotation,
            )

            decorated_fn = tf.function(
                stateful_fn,
                input_signature=input_signature,
                autograph=False,
            )
        return decorated_fn

    def _convert_jax2tf_function(
        self, fn, input_signature, jax2tf_kwargs=None
    ):
        from jax.experimental import jax2tf

        variables_shapes = self._to_polymorphic_shape(
            self._backend_variables, allow_none=False
        )
        input_shapes = list(jax2tf_kwargs["polymorphic_shapes"])
        jax2tf_kwargs["polymorphic_shapes"] = [variables_shapes] + input_shapes
        return jax2tf.convert(fn, **jax2tf_kwargs)

    def _to_polymorphic_shape(self, struct, allow_none=True):
        if allow_none:
            # Generates unique names: a, b, ... z, aa, ab, ... az, ba, ... zz
            # for unknown non-batch dims. Defined here so that the names are
            # scoped per endpoint.
            dim_names = itertools.chain(
                string.ascii_lowercase,
                itertools.starmap(
                    lambda a, b: a + b,
                    itertools.product(string.ascii_lowercase, repeat=2),
                ),
            )

        def convert_shape(x):
            poly_shape = []
            for index, dim in enumerate(list(x.shape)):
                if dim is not None:
                    poly_shape.append(str(dim))
                elif not allow_none:
                    raise ValueError(
                        f"Illegal None dimension in {x} with shape {x.shape}"
                    )
                elif index == 0:
                    poly_shape.append("batch")
                else:
                    poly_shape.append(next(dim_names))
            return "(" + ", ".join(poly_shape) + ")"

        return tree.map_structure(convert_shape, struct)

    def _check_device_compatible(self):
        from jax import default_backend as jax_device

        if (
            jax_device() == "gpu"
            and len(tf.config.list_physical_devices("GPU")) == 0
        ):
            warnings.warn(
                "JAX backend is using GPU for export, but installed "
                "TF package cannot access GPU, so reloading the model with "
                "the TF runtime in the same environment will not work. "
                "To use JAX-native serialization for high-performance export "
                "and serving, please install `tensorflow-gpu` and ensure "
                "CUDA version compatibility between your JAX and TF "
                "installations."
            )
            return False
        else:
            return True
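
To make the shape-polymorphism naming above concrete, here is a standalone demo of the same `dim_names` generator (stdlib only; the shape walkthrough mirrors `convert_shape`'s branching as written above):

import itertools
import string

# Same construction as in `_to_polymorphic_shape`: yields
# a, b, ..., z, then aa, ab, ..., az, ba, ..., zz.
dim_names = itertools.chain(
    string.ascii_lowercase,
    itertools.starmap(
        lambda a, b: a + b,
        itertools.product(string.ascii_lowercase, repeat=2),
    ),
)
print(list(itertools.islice(dim_names, 28)))
# ['a', 'b', ..., 'z', 'aa', 'ab']

For a spec of shape (None, None, 128), `convert_shape` would emit "(batch, a, 128)": the leading unknown dim becomes "batch", each later unknown dim draws a fresh name from the generator, and known dims stay as literals.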

@@ -0,0 +1,10 @@
class NumpyExportArchive:
    def track(self, resource):
        raise NotImplementedError(
            "`track` is not implemented in the numpy backend."
        )

    def add_endpoint(self, name, fn, input_signature=None, **kwargs):
        raise NotImplementedError(
            "`add_endpoint` is not implemented in the numpy backend."
        )

@@ -0,0 +1,32 @@
import tensorflow as tf

from keras.src import layers


class TFExportArchive:
    def track(self, resource):
        if not isinstance(resource, tf.__internal__.tracking.Trackable):
            raise ValueError(
                "Invalid resource type. Expected an instance of a "
                "TensorFlow `Trackable` (such as a Keras `Layer` or `Model`). "
                f"Received instead an object of type '{type(resource)}'. "
                f"Object received: {resource}"
            )

        if isinstance(resource, layers.Layer):
            # Variables in the lists below are actually part of the trackables
            # that get saved, because the lists are created in __init__.
            variables = resource.variables
            trainable_variables = resource.trainable_variables
            non_trainable_variables = resource.non_trainable_variables
            self._tf_trackable.variables += variables
            self._tf_trackable.trainable_variables += trainable_variables
            self._tf_trackable.non_trainable_variables += (
                non_trainable_variables
            )

    def add_endpoint(self, name, fn, input_signature=None, **kwargs):
        decorated_fn = tf.function(
            fn, input_signature=input_signature, autograph=False
        )
        return decorated_fn
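
Unlike the JAX path, the TF backend needs no conversion step: an endpoint is simply `fn` wrapped in a `tf.function` with a fixed `input_signature` and AutoGraph disabled. A standalone sketch of the same wrapping (the toy function and the spec are illustrative):

import tensorflow as tf

def double(x):
    return x * 2.0

# Mirrors `TFExportArchive.add_endpoint`: fixed signature, autograph off.
decorated_fn = tf.function(
    double,
    input_signature=[tf.TensorSpec(shape=(None, 4), dtype=tf.float32)],
    autograph=False,
)
concrete = decorated_fn.get_concrete_function()  # traces against the spec
print(concrete.structured_input_signature)
# prints the fixed signature: ((TensorSpec(shape=(None, 4), ...),), {})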

@@ -0,0 +1,35 @@
from keras.src import layers
from keras.src import tree


class TorchExportArchive:
    def track(self, resource):
        if not isinstance(resource, layers.Layer):
            raise ValueError(
                "Invalid resource type. Expected an instance of a "
                "torch-based Keras `Layer` or `Model`. "
                f"Received instead an object of type '{type(resource)}'. "
                f"Object received: {resource}"
            )

        if isinstance(resource, layers.Layer):
            # Variables in the lists below are actually part of the trackables
            # that get saved, because the lists are created in __init__.
            variables = resource.variables
            trainable_variables = resource.trainable_variables
            non_trainable_variables = resource.non_trainable_variables
            self._tf_trackable.variables += tree.map_structure(
                self._convert_to_tf_variable, variables
            )
            self._tf_trackable.trainable_variables += tree.map_structure(
                self._convert_to_tf_variable, trainable_variables
            )
            self._tf_trackable.non_trainable_variables += tree.map_structure(
                self._convert_to_tf_variable, non_trainable_variables
            )

    def add_endpoint(self, name, fn, input_signature=None, **kwargs):
        # TODO: torch-xla?
        raise NotImplementedError(
            "`add_endpoint` is not implemented in the torch backend."
        )