
build_from_config() gets a dictionary with wrong types when loading a keras_v3 model #43

MidnessX opened this issue Jul 21, 2023 · 10 comments

MidnessX commented Jul 21, 2023

System information.

  • Have I written custom code (as opposed to using a stock example script provided in Keras): yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Fedora 38 x86_64
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version (use command below): 2.13
  • Python version: 3.11

Describe the problem.

When loading back a custom tf.keras.Model saved in the keras_v3 format, its build() method receives an input_shape argument of type list rather than the tf.TensorShape stated in the documentation.

This is a problem whenever build() relies on tf.TensorShape attributes or methods, such as rank: an exception is raised and the model cannot be loaded.

Describe the current behavior.

When the model is loaded, build_from_config() is called on the custom model with a build_config dictionary whose input_shape key holds a value of type list (see the traceback below).

Describe the expected behavior.

build_from_config() should receive a dictionary whose values have the documented types, i.e. input_shape as a tf.TensorShape.

Contributing.

  • Do you want to contribute a PR? (yes/no): no

Standalone code to reproduce the issue.

import tensorflow as tf

# Define a custom model
@tf.keras.saving.register_keras_serializable()
class CustomModel(tf.keras.models.Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def build(self, input_shape):
        assert isinstance(input_shape, tf.TensorShape)

    def call(self, *args, **kwargs):
        return tf.random.uniform([1], maxval=5)

# Instantiate and build it
x = tf.random.uniform([1, 10, 10])
model = CustomModel()
model(x)

# Save the model in the keras_v3 format
model.save("./model_v3.keras", save_format="keras_v3")

# Try to load it back
model = tf.keras.models.load_model("./model_v3.keras") # Raises AssertionError

Source code / logs.

AssertionError                            Traceback (most recent call last)
Cell In[36], line 1
----> 1 model = tf.keras.models.load_model("./model_v3.keras")
      2 model(x)

File ~/redacted/temp-venv/lib64/python3.11/site-packages/keras/src/saving/saving_api.py:230, in load_model(filepath, custom_objects, compile, safe_mode, **kwargs)
    225     if kwargs:
    226         raise ValueError(
    227             "The following argument(s) are not supported "
    228             f"with the native Keras format: {list(kwargs.keys())}"
    229         )
--> 230     return saving_lib.load_model(
    231         filepath,
    232         custom_objects=custom_objects,
    233         compile=compile,
    234         safe_mode=safe_mode,
    235     )
    237 # Legacy case.
    238 return legacy_sm_saving_lib.load_model(
    239     filepath, custom_objects=custom_objects, compile=compile, **kwargs
    240 )

File ~/redacted/temp-venv/lib64/python3.11/site-packages/keras/src/saving/saving_lib.py:275, in load_model(filepath, custom_objects, compile, safe_mode)
    272             asset_store.close()
    274 except Exception as e:
--> 275     raise e
    276 else:
    277     return model

File ~/redacted/temp-venv/lib64/python3.11/site-packages/keras/src/saving/saving_lib.py:240, in load_model(filepath, custom_objects, compile, safe_mode)
    238 # Construct the model from the configuration file in the archive.
    239 with ObjectSharingScope():
--> 240     model = deserialize_keras_object(
    241         config_dict, custom_objects, safe_mode=safe_mode
    242     )
    244 all_filenames = zf.namelist()
    245 if _VARS_FNAME + ".h5" in all_filenames:

File ~/redacted/temp-venv/lib64/python3.11/site-packages/keras/src/saving/serialization_lib.py:707, in deserialize_keras_object(config, custom_objects, safe_mode, **kwargs)
    705 build_config = config.get("build_config", None)
    706 if build_config:
--> 707     instance.build_from_config(build_config)
    708 compile_config = config.get("compile_config", None)
    709 if compile_config:

File ~/redacted/temp-venv/lib64/python3.11/site-packages/keras/src/engine/base_layer.py:2341, in Layer.build_from_config(self, config)
   2339 input_shape = config["input_shape"]
   2340 if input_shape is not None:
-> 2341     self.build(input_shape)

Cell In[33], line 9, in CustomModel.build(self, input_shape)
      8 def build(self, input_shape):
----> 9     assert isinstance(input_shape, tf.TensorShape)

Workaround

Do not use any tf.TensorShape attributes or methods inside build(); treat input_shape as a plain list (e.g. use len(input_shape) instead of input_shape.rank).
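For example, a build() written in this style survives the save/load round trip. A minimal sketch adapted from the reproduction above (the rank/dimension checks are only illustrative):

import tensorflow as tf

@tf.keras.saving.register_keras_serializable()
class CustomModel(tf.keras.models.Model):
    def build(self, input_shape):
        # input_shape may be a tf.TensorShape (first call) or a plain list
        # (when loading a keras_v3 file), so rely only on list-like behavior.
        rank = len(input_shape)        # instead of input_shape.rank
        last_dim = input_shape[-1]     # plain indexing works in both cases
        assert rank == 3 and last_dim == 10

    def call(self, *args, **kwargs):
        return tf.random.uniform([1], maxval=5)

model = CustomModel()
model(tf.random.uniform([1, 10, 10]))
model.save("./model_v3.keras", save_format="keras_v3")
tf.keras.models.load_model("./model_v3.keras")  # no AssertionError this time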

@tilakrayal
Collaborator

@sachinprasadhs,
I was able to reproduce the issue on tensorflow v2.12, v2.13 and tf-nightly. Kindly find the gist of it here.

@nkovela1
Contributor

nkovela1 commented Aug 9, 2023

Hi @MidnessX, build_from_config() actually passes the input_shape argument as a list or tuple representing the dimensions of the input shape. This is because the input shape in the build_config needs to be serializable in order to save and load it.

The default build behavior for models, which accepts the input_shape arg as a list, is shown here:
https://github.com/keras-team/keras/blob/master/keras/engine/training.py#L427

To use the TensorShape methods, you can simply create a TensorShape using the input_shape arg.
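For instance (a sketch of that suggestion applied to the reproduction above, not code from the Keras repo), the only change needed is at the top of build():

    def build(self, input_shape):
        # Wrap the incoming list/tuple/TensorShape in a tf.TensorShape so
        # attributes such as .rank and methods such as .as_list() work in
        # both the first-call and the load-from-disk code paths.
        input_shape = tf.TensorShape(input_shape)
        assert input_shape.rank == 3  # now safe, even when loading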

@MidnessX
Author

Hi @nkovela1, thanks for your reply.

I imagined the problem had to do with serialization. However, I don't understand why the list isn't converted back into a TensorShape object before being passed to build(); that would prevent any trouble caused by the first argument's type not matching the one stated in the docs.
Alternatively, why do the docs mention TensorShape when the argument can be a list or a tuple?

@nkovela1
Contributor

@MidnessX Ah, I see the source of confusion here: the docs you referred to are for the Layer class build() method, but what is used by build_from_config is the Model class build() method (which overrides the Layer build() method with extra functionality to deal with lists and tuples): https://github.com/keras-team/keras/blob/master/keras/engine/training.py#L427.

I'm not sure why the method's API reference did not make it onto tensorflow.org under the Model class, sorry about that. I will contact @MarkDaoust from TFDocs for help on this.

@MarkDaoust
Contributor

MarkDaoust commented Aug 10, 2023

It is documented on Layer, and Model is a subclass:

https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer#build_from_config

It looks like it's set not to show the inherited methods, but I'm not 100% sure where.

@nkovela1
Contributor

I believe it would be useful to add a build() API reference to the Model class here:

https://www.tensorflow.org/api_docs/python/tf/keras/Model

Model.build_from_config() calls Model.build(), which does accept lists and tuples in addition to TensorShape (this is the different behavior the user is seeing).
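A quick way to see that behavior (a standalone sketch with a trivial subclassed model, not the reproduction above):

import tensorflow as tf

class TinyModel(tf.keras.Model):
    def call(self, inputs):
        return inputs

m = TinyModel()
# Model.build() accepts a plain tuple/list of dimensions, not only a
# tf.TensorShape, which is why the deserialized list works here.
m.build((None, 10))
print(m.built)  # True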

@MidnessX
Author

> @MidnessX Ah, I see the source of confusion here: the docs you referred to are for the Layer class build() method, but what is used by build_from_config is the Model class build() method (which overrides the Layer build() method with extra functionality to deal with lists and tuples): https://github.com/keras-team/keras/blob/master/keras/engine/training.py#L427.
>
> I'm not sure why the method's API reference did not make it onto tensorflow.org under the Model class, sorry about that. I will contact @MarkDaoust from TFDocs for help on this.

Yes, I didn't look at the source code, just the docs, so I thought that Model simply used build() from Layer. That's the source of my confusion!

@nkovela1
Contributor

@MidnessX thanks for raising this issue! It really helps us improve the documentation.

@MarkDaoust Are the API pages for these classes (Layer, Model, etc.) generated automatically from the docstrings? Is there a config file that controls which ones show up on the website?

@MarkDaoust
Contributor

Yes, all the API reference pages on tensorflow.org are generated from the pip package and its docstrings.

It's configurable, but mainly it tries to generate doc pages for everything in the public API.

What's visible is mostly controlled by the filters passed to callbacks here:

https://github.com/tensorflow/tensorflow/blob/7a44dcba5ed893685fe126abd2318650c126a3aa/tensorflow/tools/docs/generate2.py#L299

And these doc_controls tags:

https://github.com/tensorflow/tensorflow/blob/7a44dcba5ed893685fe126abd2318650c126a3aa/tensorflow/tools/docs/generate2.py#L265

sachinprasadhs transferred this issue from keras-team/keras on Sep 22, 2023