Fix the serialization of SequenceSummary block #927
Fixes the model saving error in #908, caused by using `SequenceSummary` as a post block.
@@ -193,6 +193,10 @@ def __init__(self, summary: str, initializer_range: float = 0.02, **kwargs):
        config = SimpleNamespace(summary_type=summary)
        super().__init__(config, initializer_range=initializer_range, **kwargs)

    def get_config(self):
        config = super().get_config()
We may also need to save the `summary` argument from `__init__` in this config, and implement `from_config` to restore the layer to the same state correctly.
You're right, I have updated `SequenceSummary` to make it more flexible.
The reason I didn't add `from_config` is that `SequenceSummary` was used as a base class of the transform blocks `SequenceLast`, `SequenceFirst`, `SequenceMean`, and `SequenceClsIndex`, where the `summary` parameter was set to a default value for each block.
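To make the point in this exchange concrete, here is a minimal, framework-free sketch of the `get_config` / `from_config` round trip. It is not the code from this PR: `BaseLayer` is a hypothetical stand-in for the Keras-style base class, the `"mean"` and `"last"` values are illustrative, and the `from_config` override on the subclass is one possible way (an assumption, not the author's method) to handle a subclass whose `__init__` fixes `summary` and therefore cannot accept it as a keyword.

```python
class BaseLayer:
    """Hypothetical stand-in for a Keras-style layer base class."""

    def __init__(self, name=None):
        self.name = name or type(self).__name__.lower()

    def get_config(self):
        return {"name": self.name}

    @classmethod
    def from_config(cls, config):
        # Default behaviour: pass every saved key back to __init__.
        return cls(**config)


class SequenceSummary(BaseLayer):
    """Simplified sketch; the real block wraps a transformers summary layer."""

    def __init__(self, summary, initializer_range=0.02, **kwargs):
        super().__init__(**kwargs)
        self.summary = summary
        self.initializer_range = initializer_range

    def get_config(self):
        # Save the __init__ arguments so the layer can be rebuilt identically.
        config = super().get_config()
        config.update(summary=self.summary, initializer_range=self.initializer_range)
        return config


class SequenceLast(SequenceSummary):
    """Subclass that fixes `summary`, so __init__ does not take it."""

    def __init__(self, initializer_range=0.02, **kwargs):
        super().__init__("last", initializer_range=initializer_range, **kwargs)

    @classmethod
    def from_config(cls, config):
        # Assumed workaround: drop the key that this subclass hard-codes.
        config = dict(config)
        config.pop("summary", None)
        return cls(**config)


original = SequenceSummary("mean", initializer_range=0.05)
restored = SequenceSummary.from_config(original.get_config())

last = SequenceLast()
restored_last = SequenceLast.from_config(last.get_config())
```

Without the override, the default `from_config` would call `SequenceLast(summary="last", ...)`, which collides with the positional `summary` that the subclass already passes to its parent.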
bfe1602 to 378be10
* fix serialization of SequenceSummary block, used by Transformer-based models
* Add `from_config` method
Saving a transformer-based model with `SequenceSummary` as a post block was throwing an error.

Goals ⚽
- Add a `get_config` method to the `SequenceSummary` block.

Testing Details 🔍
- `test_transformer_encoder_with_post` checks the serialization of the transformer block defined with a `SequenceSummary` post block.