diff --git a/merlin/models/tf/transformers/transforms.py b/merlin/models/tf/transformers/transforms.py index ace3fce40b..53d3119931 100644 --- a/merlin/models/tf/transformers/transforms.py +++ b/merlin/models/tf/transformers/transforms.py @@ -190,9 +190,21 @@ def call(self, inputs: tf.Tensor) -> Dict[str, tf.Tensor]: @tf.keras.utils.register_keras_serializable(package="merlin.models") class SequenceSummary(TFSequenceSummary): def __init__(self, summary: str, initializer_range: float = 0.02, **kwargs): + self.summary = summary config = SimpleNamespace(summary_type=summary) super().__init__(config, initializer_range=initializer_range, **kwargs) + def get_config(self): + config = super().get_config() + config["summary"] = self.summary + return config + + @classmethod + def from_config(cls, config, custom_objects=None): + output = SequenceSummary(**config) + output.__class__ = cls + return output + @Block.registry.register("sequence_last") @tf.keras.utils.register_keras_serializable(package="merlin.models") diff --git a/tests/unit/tf/transformers/test_block.py b/tests/unit/tf/transformers/test_block.py index ddf5325911..398d469f04 100644 --- a/tests/unit/tf/transformers/test_block.py +++ b/tests/unit/tf/transformers/test_block.py @@ -141,7 +141,7 @@ def test_transformer_encoder_with_post(): post="sequence_mean", ) outputs = transformer_encod(inputs) - + testing_utils.assert_serialization(transformer_encod) assert list(outputs.shape) == [NUM_ROWS, EMBED_DIM]