diff --git a/conftest.py b/conftest.py
index 5c27d947c13..d71de30a616 100644
--- a/conftest.py
+++ b/conftest.py
@@ -15,6 +15,9 @@
 import pytest  # noqa: E402
 
 from keras.src.backend import backend  # noqa: E402
+from keras.src.saving.object_registration import (  # noqa: E402
+    get_custom_objects,
+)
 
 
 def pytest_configure(config):
@@ -32,3 +35,9 @@ def pytest_collection_modifyitems(config, items):
     for item in items:
         if "requires_trainable_backend" in item.keywords:
             item.add_marker(requires_trainable_backend)
+
+
+# Ensure each test is run in isolation regarding the custom objects dict
+@pytest.fixture(autouse=True)
+def reset_custom_objects_global_dictionary(request):
+    get_custom_objects().clear()
diff --git a/keras/src/models/model_test.py b/keras/src/models/model_test.py
index 7fa91c5b95d..cd06e154616 100644
--- a/keras/src/models/model_test.py
+++ b/keras/src/models/model_test.py
@@ -1,5 +1,7 @@
 import pickle
+import sys
 
+import cloudpickle
 import numpy as np
 import pytest
 from absl.testing import parameterized
@@ -11,6 +13,7 @@
 from keras.src.models.functional import Functional
 from keras.src.models.model import Model
 from keras.src.models.model import model_from_json
+from keras.src.saving.object_registration import register_keras_serializable
 
 
 def _get_model():
@@ -68,6 +71,17 @@ def _get_model_multi_outputs_dict():
     return model
 
 
+@pytest.fixture
+def fake_main_module(request, monkeypatch):
+    original_main = sys.modules["__main__"]
+
+    def restore_main_module():
+        sys.modules["__main__"] = original_main
+
+    request.addfinalizer(restore_main_module)
+    sys.modules["__main__"] = sys.modules[__name__]
+
+
 @pytest.mark.requires_trainable_backend
 class ModelTest(testing.TestCase, parameterized.TestCase):
     def test_functional_rerouting(self):
@@ -141,6 +155,45 @@ def test_functional_pickling(self, model_fn):
 
         self.assertAllClose(np.array(pred_reloaded), np.array(pred))
 
+    # Fake the __main__ module because cloudpickle only serializes
+    # functions & classes if they are defined in the __main__ module.
+    @pytest.mark.usefixtures("fake_main_module")
+    def test_functional_pickling_custom_layer(self):
+        @register_keras_serializable()
+        class CustomDense(layers.Layer):
+            def __init__(self, units, **kwargs):
+                super().__init__(**kwargs)
+                self.units = units
+                self.dense = layers.Dense(units)
+
+            def call(self, x):
+                return self.dense(x)
+
+            def get_config(self):
+                config = super().get_config()
+                config.update({"units": self.units})
+                return config
+
+        x = Input(shape=(3,), name="input_a")
+        output_a = CustomDense(10, name="output_a")(x)
+        model = Model(x, output_a)
+
+        self.assertIsInstance(model, Functional)
+        model.compile()
+        x = np.random.rand(8, 3)
+
+        dumped_pickle = cloudpickle.dumps(model)
+
+        # Verify that we can load the dumped pickle even if the custom object
+        # is not available in the loading environment.
+        del CustomDense
+        reloaded_pickle = cloudpickle.loads(dumped_pickle)
+
+        pred_reloaded = reloaded_pickle.predict(x)
+        pred = model.predict(x)
+
+        self.assertAllClose(np.array(pred_reloaded), np.array(pred))
+
     @parameterized.named_parameters(
         ("single_output_1", _get_model_single_output, None),
         ("single_output_2", _get_model_single_output, "list"),
diff --git a/keras/src/saving/keras_saveable.py b/keras/src/saving/keras_saveable.py
index 7fc536b470c..e122a4fa8d3 100644
--- a/keras/src/saving/keras_saveable.py
+++ b/keras/src/saving/keras_saveable.py
@@ -1,5 +1,7 @@
 import io
 
+from keras.src.saving.object_registration import get_custom_objects
+
 
 class KerasSaveable:
     # Note: renaming this function will cause old pickles to be broken.
@@ -14,12 +16,23 @@ def _obj_type(self):
         )
 
     @classmethod
-    def _unpickle_model(cls, bytesio):
+    def _unpickle_model(cls, data):
         import keras.src.saving.saving_lib as saving_lib
 
         # pickle is not safe regardless of what you do.
+
+        if "custom_objects_buf" in data.keys():
+            import pickle
+
+            custom_objects = pickle.load(data["custom_objects_buf"])
+        else:
+            custom_objects = None
+
         return saving_lib._load_model_from_fileobj(
-            bytesio, custom_objects=None, compile=True, safe_mode=False
+            data["model_buf"],
+            custom_objects=custom_objects,
+            compile=True,
+            safe_mode=False,
         )
 
     def __reduce__(self):
@@ -30,9 +43,23 @@ def __reduce__(self):
         keras saving library."""
         import keras.src.saving.saving_lib as saving_lib
 
-        buf = io.BytesIO()
-        saving_lib._save_model_to_fileobj(self, buf, "h5")
+        data = {}
+
+        model_buf = io.BytesIO()
+        saving_lib._save_model_to_fileobj(self, model_buf, "h5")
+        data["model_buf"] = model_buf
+
+        try:
+            import cloudpickle
+
+            custom_objects_buf = io.BytesIO()
+            cloudpickle.dump(get_custom_objects(), custom_objects_buf)
+            custom_objects_buf.seek(0)
+            data["custom_objects_buf"] = custom_objects_buf
+        except ImportError:
+            pass
+
         return (
             self._unpickle_model,
-            (buf,),
+            (data,),
         )
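
For context, a minimal sketch of the round trip this change enables from user code, assuming keras and cloudpickle are installed. The CustomScale layer, the shapes, and all variable names below are illustrative, not part of the diff.

# Sketch only: CustomScale is a made-up layer used to illustrate the new
# pickling path; it is not part of the change above.
import cloudpickle
import numpy as np

import keras
from keras.saving import register_keras_serializable


@register_keras_serializable()
class CustomScale(keras.layers.Layer):
    def __init__(self, factor=2.0, **kwargs):
        super().__init__(**kwargs)
        self.factor = factor

    def call(self, x):
        return x * self.factor

    def get_config(self):
        config = super().get_config()
        config.update({"factor": self.factor})
        return config


inputs = keras.Input(shape=(4,))
outputs = CustomScale(3.0)(keras.layers.Dense(8)(inputs))
model = keras.Model(inputs, outputs)

# __reduce__ now bundles the h5 model bytes together with a cloudpickled copy
# of the custom-object registry, so the blob carries CustomScale with it
# (cloudpickle serializes classes defined in __main__ by value).
blob = cloudpickle.dumps(model)

# In a process that never defined CustomScale (e.g. a multiprocessing or Ray
# worker), _unpickle_model unpickles the bundled custom objects and passes
# them to the loader instead of custom_objects=None.
restored = cloudpickle.loads(blob)
print(restored.predict(np.random.rand(2, 4)))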