diff --git a/merlin/models/tf/models/base.py b/merlin/models/tf/models/base.py index af0cb14c7e..02a56c2ec0 100644 --- a/merlin/models/tf/models/base.py +++ b/merlin/models/tf/models/base.py @@ -69,6 +69,7 @@ from merlin.models.tf.outputs.base import ModelOutput, ModelOutputType from merlin.models.tf.outputs.classification import CategoricalOutput from merlin.models.tf.outputs.contrastive import ContrastiveOutput +from merlin.models.tf.outputs.topk import TopKOutput from merlin.models.tf.prediction_tasks.base import ParallelPredictionBlock, PredictionTask from merlin.models.tf.transforms.features import PrepareFeatures, expected_input_cols_from_schema from merlin.models.tf.transforms.sequence import SequenceTransform @@ -446,7 +447,10 @@ def compile( if num_v1_blocks > 0: self.output_names = [task.task_name for task in self.prediction_tasks] else: - self.output_names = [block.full_name for block in self.model_outputs] + if num_v2_blocks == 1 and isinstance(self.model_outputs[0], TopKOutput): + pass + else: + self.output_names = [block.full_name for block in self.model_outputs] # This flag will make Keras change the metric-names which is not needed in v2 from_serialized = kwargs.pop("from_serialized", num_v2_blocks > 0) diff --git a/tests/unit/tf/core/test_encoder.py b/tests/unit/tf/core/test_encoder.py index 36c309ed64..af8390dd7c 100644 --- a/tests/unit/tf/core/test_encoder.py +++ b/tests/unit/tf/core/test_encoder.py @@ -122,6 +122,15 @@ def test_topk_encoder(music_streaming_data: Dataset): loaded_topk_encoder = tf.keras.models.load_model(tmpdir) batch_output = loaded_topk_encoder(batch[0]) + output_signature = loaded_topk_encoder.signatures["serving_default"].structured_outputs + assert len(output_signature) == 2 + assert output_signature["scores"] == tf.TensorSpec( + shape=(None, TOP_K), dtype=tf.float32, name="scores" + ) + assert output_signature["identifiers"] == tf.TensorSpec( + shape=(None, TOP_K), dtype=tf.int32, name="identifiers" + ) + assert list(batch_output.scores.shape) == [BATCH_SIZE, TOP_K] tf.debugging.assert_equal( topk_encoder.topk_layer._candidates,