From 558d38c52aadd94fe6df83023c5aae2fb898b782 Mon Sep 17 00:00:00 2001 From: Grvzard Date: Wed, 26 Jun 2024 02:31:26 +0800 Subject: [PATCH] Fix `export_lib.make_tensor_spec` (#19915) * Fix `export_lib.make_tensor_spec` * Add test * chore(format) --- keras/src/export/export_lib.py | 11 ++++++++--- keras/src/export/export_lib_test.py | 16 ++++++++++++++++ 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/keras/src/export/export_lib.py b/keras/src/export/export_lib.py index 02714c55c0a..be75c06cd3d 100644 --- a/keras/src/export/export_lib.py +++ b/keras/src/export/export_lib.py @@ -654,13 +654,18 @@ def make_tensor_spec(structure): # into plain Python structures because they don't work with jax2tf/JAX. if isinstance(structure, dict): return {k: make_tensor_spec(v) for k, v in structure.items()} - if isinstance(structure, (list, tuple)): + elif isinstance(structure, tuple): if all(isinstance(d, (int, type(None))) for d in structure): return tf.TensorSpec( shape=(None,) + structure[1:], dtype=model.input_dtype ) - result = [make_tensor_spec(v) for v in structure] - return tuple(result) if isinstance(structure, tuple) else result + return tuple(make_tensor_spec(v) for v in structure) + elif isinstance(structure, list): + if all(isinstance(d, (int, type(None))) for d in structure): + return tf.TensorSpec( + shape=[None] + structure[1:], dtype=model.input_dtype + ) + return [make_tensor_spec(v) for v in structure] else: raise ValueError( f"Unsupported type {type(structure)} for {structure}" diff --git a/keras/src/export/export_lib_test.py b/keras/src/export/export_lib_test.py index 29504cfb2b1..7b4b7d332dc 100644 --- a/keras/src/export/export_lib_test.py +++ b/keras/src/export/export_lib_test.py @@ -196,6 +196,22 @@ def call(self, inputs): ) revived_model.serve(bigger_input) + # Test with keras.saving_lib + temp_filepath = os.path.join( + self.get_temp_dir(), "exported_model.keras" + ) + saving_lib.save_model(model, temp_filepath) + revived_model = saving_lib.load_model( + temp_filepath, + { + "TupleModel": TupleModel, + "ArrayModel": ArrayModel, + "DictModel": DictModel, + }, + ) + self.assertAllClose(ref_output, revived_model(ref_input)) + export_lib.export_model(revived_model, self.get_temp_dir()) + def test_model_with_multiple_inputs(self): class TwoInputsModel(models.Model):