Rework `Model.export` and `keras.export.ExportArchive` to support exporting in TFLite and ONNX formats in the future (#20631)

* Rework `Model.export` and `keras.export.ExportArchive`

* Try fixing PyDatasetAdapterTest CI issues
james77777778 authored Dec 13, 2024
1 parent 32a642d commit 5a9b26d
Showing 14 changed files with 662 additions and 301 deletions.
194 changes: 194 additions & 0 deletions keras/src/backend/jax/export.py
@@ -0,0 +1,194 @@
import copy
import inspect
import itertools
import string
import warnings

from keras.src import layers
from keras.src import tree
from keras.src.backend.common.stateless_scope import StatelessScope
from keras.src.utils.module_utils import tensorflow as tf


class JaxExportArchive:
    # Note: `self._tf_trackable` and `self._convert_to_tf_variable` are not
    # defined on this class; they are expected to be provided by the
    # `ExportArchive` class that this backend-specific class is combined with.

    def __init__(self):
        self._backend_variables = []
        self._backend_trainable_variables = []
        self._backend_non_trainable_variables = []

    def track(self, resource):
        if not isinstance(resource, layers.Layer):
            raise ValueError(
                "Invalid resource type. Expected an instance of a "
                "JAX-based Keras `Layer` or `Model`. "
                f"Received instead an object of type '{type(resource)}'. "
                f"Object received: {resource}"
            )

        if isinstance(resource, layers.Layer):
            # Variables in the lists below are actually part of the trackables
            # that get saved, because the lists are created in __init__.
            trainable_variables = resource.trainable_variables
            non_trainable_variables = resource.non_trainable_variables

            self._tf_trackable.trainable_variables += tree.map_structure(
                self._convert_to_tf_variable, trainable_variables
            )
            self._tf_trackable.non_trainable_variables += tree.map_structure(
                self._convert_to_tf_variable, non_trainable_variables
            )
            self._tf_trackable.variables = (
                self._tf_trackable.trainable_variables
                + self._tf_trackable.non_trainable_variables
            )

            self._backend_trainable_variables += trainable_variables
            self._backend_non_trainable_variables += non_trainable_variables
            self._backend_variables = (
                self._backend_trainable_variables
                + self._backend_non_trainable_variables
            )

    def add_endpoint(self, name, fn, input_signature=None, **kwargs):
        jax2tf_kwargs = kwargs.pop("jax2tf_kwargs", None)
        # Use `copy.copy()` to avoid modification issues.
        jax2tf_kwargs = copy.copy(jax2tf_kwargs) or {}
        is_static = bool(kwargs.pop("is_static", False))

        # Configure `jax2tf_kwargs`
        if "native_serialization" not in jax2tf_kwargs:
            jax2tf_kwargs["native_serialization"] = (
                self._check_device_compatible()
            )
        if "polymorphic_shapes" not in jax2tf_kwargs:
            jax2tf_kwargs["polymorphic_shapes"] = self._to_polymorphic_shape(
                input_signature
            )

        # Note: we truncate the number of parameters to what is specified by
        # `input_signature`.
        fn_signature = inspect.signature(fn)
        fn_parameters = list(fn_signature.parameters.values())

        if is_static:
            from jax.experimental import jax2tf

            jax_fn = jax2tf.convert(fn, **jax2tf_kwargs)
            jax_fn.__signature__ = inspect.Signature(
                parameters=fn_parameters[0 : len(input_signature)],
                return_annotation=fn_signature.return_annotation,
            )

            decorated_fn = tf.function(
                jax_fn,
                input_signature=input_signature,
                autograph=False,
            )
        else:
            # 1. Create a stateless wrapper for `fn`
            # 2. jax2tf the stateless wrapper
            # 3. Create a stateful function that binds the variables with
            #    the jax2tf-converted stateless wrapper
            # 4. Make the signature of the stateful function the same as the
            #    original function
            # 5. Wrap in a `tf.function`
            def stateless_fn(variables, *args, **kwargs):
                state_mapping = zip(self._backend_variables, variables)
                with StatelessScope(state_mapping=state_mapping) as scope:
                    output = fn(*args, **kwargs)

                # Gather updated non-trainable variables
                non_trainable_variables = []
                for var in self._backend_non_trainable_variables:
                    new_value = scope.get_current_value(var)
                    non_trainable_variables.append(new_value)
                return output, non_trainable_variables

            jax2tf_stateless_fn = self._convert_jax2tf_function(
                stateless_fn, input_signature, jax2tf_kwargs=jax2tf_kwargs
            )

            def stateful_fn(*args, **kwargs):
                output, non_trainable_variables = jax2tf_stateless_fn(
                    # Change the trackable `ListWrapper` to a plain `list`
                    list(self._tf_trackable.variables),
                    *args,
                    **kwargs,
                )
                for var, new_value in zip(
                    self._tf_trackable.non_trainable_variables,
                    non_trainable_variables,
                ):
                    var.assign(new_value)
                return output

            stateful_fn.__signature__ = inspect.Signature(
                parameters=fn_parameters[0 : len(input_signature)],
                return_annotation=fn_signature.return_annotation,
            )

            decorated_fn = tf.function(
                stateful_fn,
                input_signature=input_signature,
                autograph=False,
            )
        return decorated_fn

    def _convert_jax2tf_function(self, fn, input_signature, jax2tf_kwargs=None):
        from jax.experimental import jax2tf

        # The variables' shapes are prepended because `stateless_fn` takes the
        # variables as its first argument.
        variables_shapes = self._to_polymorphic_shape(
            self._backend_variables, allow_none=False
        )
        input_shapes = list(jax2tf_kwargs["polymorphic_shapes"])
        jax2tf_kwargs["polymorphic_shapes"] = [variables_shapes] + input_shapes
        return jax2tf.convert(fn, **jax2tf_kwargs)

    def _to_polymorphic_shape(self, struct, allow_none=True):
        if allow_none:
            # Generates unique names: a, b, ... z, aa, ab, ... az, ba, ... zz
            # for unknown non-batch dims. Defined here so the generated names
            # are scoped per endpoint.
            dim_names = itertools.chain(
                string.ascii_lowercase,
                itertools.starmap(
                    lambda a, b: a + b,
                    itertools.product(string.ascii_lowercase, repeat=2),
                ),
            )

        def convert_shape(x):
            poly_shape = []
            for index, dim in enumerate(list(x.shape)):
                if dim is not None:
                    poly_shape.append(str(dim))
                elif not allow_none:
                    raise ValueError(
                        f"Illegal None dimension in {x} with shape {x.shape}"
                    )
                elif index == 0:
                    poly_shape.append("batch")
                else:
                    poly_shape.append(next(dim_names))
            return "(" + ", ".join(poly_shape) + ")"

        return tree.map_structure(convert_shape, struct)
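
    # Example of the conversion above: an input spec with shape
    # (None, None, 3) becomes the polymorphic shape string "(batch, a, 3)":
    # the leading unknown dim is named "batch", every other unknown dim gets
    # a fresh name from `dim_names`, and known dims are kept as literal sizes.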

    def _check_device_compatible(self):
        from jax import default_backend as jax_device

        if (
            jax_device() == "gpu"
            and len(tf.config.list_physical_devices("GPU")) == 0
        ):
            warnings.warn(
                "JAX backend is using GPU for export, but installed "
                "TF package cannot access GPU, so reloading the model with "
                "the TF runtime in the same environment will not work. "
                "To use JAX-native serialization for high-performance export "
                "and serving, please install `tensorflow-gpu` and ensure "
                "CUDA version compatibility between your JAX and TF "
                "installations."
            )
            return False
        else:
            return True
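
For context, these backend-specific classes sit underneath the unchanged public `keras.export.ExportArchive` workflow. A minimal sketch of driving it by hand (the model, endpoint name, input shape, and output path are all illustrative):

import tensorflow as tf

import keras

model = keras.Sequential(
    [keras.layers.Dense(8, activation="relu"), keras.layers.Dense(1)]
)
model.build((None, 4))

export_archive = keras.export.ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
    name="serve",
    fn=model.call,
    input_signature=[tf.TensorSpec(shape=(None, 4), dtype=tf.float32)],
)
export_archive.write_out("/tmp/exported_model")  # writes a TF SavedModel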
10 changes: 10 additions & 0 deletions keras/src/backend/numpy/export.py
@@ -0,0 +1,10 @@
class NumpyExportArchive:
    def track(self, resource):
        raise NotImplementedError(
            "`track` is not implemented in the numpy backend."
        )

    def add_endpoint(self, name, fn, input_signature=None, **kwargs):
        raise NotImplementedError(
            "`add_endpoint` is not implemented in the numpy backend."
        )
32 changes: 32 additions & 0 deletions keras/src/backend/tensorflow/export.py
@@ -0,0 +1,32 @@
import tensorflow as tf

from keras.src import layers


class TFExportArchive:
    def track(self, resource):
        if not isinstance(resource, tf.__internal__.tracking.Trackable):
            raise ValueError(
                "Invalid resource type. Expected an instance of a "
                "TensorFlow `Trackable` (such as a Keras `Layer` or `Model`). "
                f"Received instead an object of type '{type(resource)}'. "
                f"Object received: {resource}"
            )

        if isinstance(resource, layers.Layer):
            # Variables in the lists below are actually part of the trackables
            # that get saved, because the lists are created in __init__.
            variables = resource.variables
            trainable_variables = resource.trainable_variables
            non_trainable_variables = resource.non_trainable_variables
            self._tf_trackable.variables += variables
            self._tf_trackable.trainable_variables += trainable_variables
            self._tf_trackable.non_trainable_variables += (
                non_trainable_variables
            )

    def add_endpoint(self, name, fn, input_signature=None, **kwargs):
        decorated_fn = tf.function(
            fn, input_signature=input_signature, autograph=False
        )
        return decorated_fn
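
On the TensorFlow backend, `add_endpoint` is a thin wrapper over `tf.function`. The endpoint it returns behaves like this hand-written equivalent (the function body is a stand-in for a real `model.call`):

import tensorflow as tf

@tf.function(
    input_signature=[tf.TensorSpec(shape=(None, 4), dtype=tf.float32)],
    autograph=False,
)
def serve(x):
    # Stand-in computation; in `add_endpoint`, `fn` is typically `model.call`.
    return tf.nn.relu(x)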
35 changes: 35 additions & 0 deletions keras/src/backend/torch/export.py
@@ -0,0 +1,35 @@
from keras.src import layers
from keras.src import tree


class TorchExportArchive:
    def track(self, resource):
        if not isinstance(resource, layers.Layer):
            raise ValueError(
                "Invalid resource type. Expected an instance of a "
                "torch-based Keras `Layer` or `Model`. "
                f"Received instead an object of type '{type(resource)}'. "
                f"Object received: {resource}"
            )

        if isinstance(resource, layers.Layer):
            # Variables in the lists below are actually part of the trackables
            # that get saved, because the lists are created in __init__.
            variables = resource.variables
            trainable_variables = resource.trainable_variables
            non_trainable_variables = resource.non_trainable_variables
            self._tf_trackable.variables += tree.map_structure(
                self._convert_to_tf_variable, variables
            )
            self._tf_trackable.trainable_variables += tree.map_structure(
                self._convert_to_tf_variable, trainable_variables
            )
            self._tf_trackable.non_trainable_variables += tree.map_structure(
                self._convert_to_tf_variable, non_trainable_variables
            )

    def add_endpoint(self, name, fn, input_signature=None, **kwargs):
        # TODO: torch-xla?
        raise NotImplementedError(
            "`add_endpoint` is not implemented in the torch backend."
        )