Skip to content

Commit

Permalink
fix(test): deserialization of h5 models w/ custom obj doesn't work in…
Browse files Browse the repository at this point in the history
… TF2.13

It seems that (de)serialization of h5 models in TF 2.13 has been changed:
custom objects, even with "register_keras_serializable", cannot be loaded.
Two solutions are possible:
- use "with_custom_object_scope()" to load a model with custom objects.
- save model in Keras format and not h5.

The second option was chosen because this format is now preferred for saving
models.
  • Loading branch information
cofri committed Jul 26, 2023
1 parent 04b03f8 commit a3e61fe
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,8 @@ def train_k_lip_model(
)
empirical_lip_const = evaluate_lip_const(model=model, x=x, seed=42)
# save the model
model_checkpoint_path = os.path.join(logdir, "model.h5")
model.save(model_checkpoint_path, overwrite=True, save_format="h5")
model_checkpoint_path = os.path.join(logdir, "model.keras")
model.save(model_checkpoint_path, overwrite=True)
del model
K.clear_session()
model = load_model(model_checkpoint_path)
Expand Down Expand Up @@ -1050,7 +1050,7 @@ def test_vanilla_export(self):

# Test saving/loading model
with tempfile.TemporaryDirectory() as tmpdir:
model_path = os.path.join(tmpdir, "model.h5")
model_path = os.path.join(tmpdir, "model.keras")
model.save(model_path)
tf.keras.models.load_model(model_path)

Expand Down

0 comments on commit a3e61fe

Please sign in to comment.