Skip to content

Commit

Permalink
Fix the serialization of SequenceSummary block (#927)
Browse files Browse the repository at this point in the history
* fix serialization of SequenceSummary block, used by Transformer-based models

* Add `from_config` method
  • Loading branch information
sararb authored Dec 29, 2022
1 parent b1797e5 commit 9a29def
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
12 changes: 12 additions & 0 deletions merlin/models/tf/transformers/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,21 @@ def call(self, inputs: tf.Tensor) -> Dict[str, tf.Tensor]:
@tf.keras.utils.register_keras_serializable(package="merlin.models")
class SequenceSummary(TFSequenceSummary):
def __init__(self, summary: str, initializer_range: float = 0.02, **kwargs):
self.summary = summary
config = SimpleNamespace(summary_type=summary)
super().__init__(config, initializer_range=initializer_range, **kwargs)

def get_config(self):
config = super().get_config()
config["summary"] = self.summary
return config

@classmethod
def from_config(cls, config, custom_objects=None):
output = SequenceSummary(**config)
output.__class__ = cls
return output


@Block.registry.register("sequence_last")
@tf.keras.utils.register_keras_serializable(package="merlin.models")
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/tf/transformers/test_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def test_transformer_encoder_with_post():
post="sequence_mean",
)
outputs = transformer_encod(inputs)

testing_utils.assert_serialization(transformer_encod)
assert list(outputs.shape) == [NUM_ROWS, EMBED_DIM]


Expand Down

0 comments on commit 9a29def

Please sign in to comment.