From a3e61fe7aadb6148497d79b134c44ffccee4ead9 Mon Sep 17 00:00:00 2001 From: cofri Date: Mon, 24 Jul 2023 14:24:38 +0200 Subject: [PATCH] fix(test): deserialization of h5 models w/ custom obj doesn't work in TF2.13 It seems that (de)serialization of h5 models in TF 2.13 has been changed: custom objects, even with "register_keras_serializable", cannot be loaded. Two solutions are possible: - use "with_custom_object_scope()" to load a model with custom objects. - save model in Keras format and not h5. The second option was chosen because this format is now preferred for saving models. --- tests/test_layers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_layers.py b/tests/test_layers.py index c378025e..46c3cfc4 100644 --- a/tests/test_layers.py +++ b/tests/test_layers.py @@ -226,8 +226,8 @@ def train_k_lip_model( ) empirical_lip_const = evaluate_lip_const(model=model, x=x, seed=42) # save the model - model_checkpoint_path = os.path.join(logdir, "model.h5") - model.save(model_checkpoint_path, overwrite=True, save_format="h5") + model_checkpoint_path = os.path.join(logdir, "model.keras") + model.save(model_checkpoint_path, overwrite=True) del model K.clear_session() model = load_model(model_checkpoint_path) @@ -1050,7 +1050,7 @@ def test_vanilla_export(self): # Test saving/loading model with tempfile.TemporaryDirectory() as tmpdir: - model_path = os.path.join(tmpdir, "model.h5") + model_path = os.path.join(tmpdir, "model.keras") model.save(model_path) tf.keras.models.load_model(model_path)