Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add registered custom objects inside pickled model file #19867

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
import pytest # noqa: E402

from keras.src.backend import backend # noqa: E402
from keras.src.saving.object_registration import ( # noqa: E402
get_custom_objects,
)


def pytest_configure(config):
Expand All @@ -32,3 +35,9 @@ def pytest_collection_modifyitems(config, items):
for item in items:
if "requires_trainable_backend" in item.keywords:
item.add_marker(requires_trainable_backend)


# Ensure each test is run in isolation regarding the custom objects dict
@pytest.fixture(autouse=True)
def reset_custom_objects_global_dictionary(request):
get_custom_objects().clear()
53 changes: 53 additions & 0 deletions keras/src/models/model_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import pickle
import sys

import cloudpickle
import numpy as np
import pytest
from absl.testing import parameterized
Expand All @@ -11,6 +13,7 @@
from keras.src.models.functional import Functional
from keras.src.models.model import Model
from keras.src.models.model import model_from_json
from keras.src.saving.object_registration import register_keras_serializable


def _get_model():
Expand Down Expand Up @@ -68,6 +71,17 @@ def _get_model_multi_outputs_dict():
return model


@pytest.fixture
def fake_main_module(request, monkeypatch):
original_main = sys.modules["__main__"]

def restore_main_module():
sys.modules["__main__"] = original_main

request.addfinalizer(restore_main_module)
sys.modules["__main__"] = sys.modules[__name__]


@pytest.mark.requires_trainable_backend
class ModelTest(testing.TestCase, parameterized.TestCase):
def test_functional_rerouting(self):
Expand Down Expand Up @@ -141,6 +155,45 @@ def test_functional_pickling(self, model_fn):

self.assertAllClose(np.array(pred_reloaded), np.array(pred))

# Fake the __main__ module because cloudpickle only serializes
# functions & classes if they are defined in the __main__ module.
@pytest.mark.usefixtures("fake_main_module")
def test_functional_pickling_custom_layer(self):
@register_keras_serializable()
class CustomDense(layers.Layer):
def __init__(self, units, **kwargs):
super().__init__(**kwargs)
self.units = units
self.dense = layers.Dense(units)

def call(self, x):
return self.dense(x)

def get_config(self):
config = super().get_config()
config.update({"units": self.units})
return config

x = Input(shape=(3,), name="input_a")
output_a = CustomDense(10, name="output_a")(x)
model = Model(x, output_a)

self.assertIsInstance(model, Functional)
model.compile()
x = np.random.rand(8, 3)

dumped_pickle = cloudpickle.dumps(model)

# Verify that we can load the dumped pickle even if the custom object
# is not available in the loading environment.
del CustomDense
reloaded_pickle = cloudpickle.loads(dumped_pickle)

pred_reloaded = reloaded_pickle.predict(x)
pred = model.predict(x)

self.assertAllClose(np.array(pred_reloaded), np.array(pred))

@parameterized.named_parameters(
("single_output_1", _get_model_single_output, None),
("single_output_2", _get_model_single_output, "list"),
Expand Down
37 changes: 32 additions & 5 deletions keras/src/saving/keras_saveable.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import io

from keras.src.saving.object_registration import get_custom_objects


class KerasSaveable:
# Note: renaming this function will cause old pickles to be broken.
Expand All @@ -14,12 +16,23 @@ def _obj_type(self):
)

@classmethod
def _unpickle_model(cls, bytesio):
def _unpickle_model(cls, data):
import keras.src.saving.saving_lib as saving_lib

# pickle is not safe regardless of what you do.

if "custom_objects_buf" in data.keys():
import pickle

custom_objects = pickle.load(data["custom_objects_buf"])
else:
custom_objects = None

return saving_lib._load_model_from_fileobj(
bytesio, custom_objects=None, compile=True, safe_mode=False
data["model_buf"],
custom_objects=custom_objects,
compile=True,
safe_mode=False,
)

def __reduce__(self):
Expand All @@ -30,9 +43,23 @@ def __reduce__(self):
keras saving library."""
import keras.src.saving.saving_lib as saving_lib

buf = io.BytesIO()
saving_lib._save_model_to_fileobj(self, buf, "h5")
data = {}

model_buf = io.BytesIO()
saving_lib._save_model_to_fileobj(self, model_buf, "h5")
data["model_buf"] = model_buf

try:
import cloudpickle

custom_objects_buf = io.BytesIO()
cloudpickle.dump(get_custom_objects(), custom_objects_buf)
custom_objects_buf.seek(0)
data["custom_objects_buf"] = custom_objects_buf
except ImportError:
pass

return (
self._unpickle_model,
(buf,),
(data,),
)
Loading