Ensure TopKEncoder has correct outputs when model is saved (#1225)
* Remove output_names from base BaseModel.

* Add assertion for output signature of saved model to test_topk_encoder

* Move compile method from BaseModel to Model

* Correct name of structured outputs

* Move compile method back and add special case for TopKOutput
oliverholworthy authored Nov 10, 2023
1 parent 16d289a commit 80b086f
Showing 2 changed files with 14 additions and 1 deletion.
6 changes: 5 additions & 1 deletion merlin/models/tf/models/base.py
@@ -69,6 +69,7 @@
 from merlin.models.tf.outputs.base import ModelOutput, ModelOutputType
 from merlin.models.tf.outputs.classification import CategoricalOutput
 from merlin.models.tf.outputs.contrastive import ContrastiveOutput
+from merlin.models.tf.outputs.topk import TopKOutput
 from merlin.models.tf.prediction_tasks.base import ParallelPredictionBlock, PredictionTask
 from merlin.models.tf.transforms.features import PrepareFeatures, expected_input_cols_from_schema
 from merlin.models.tf.transforms.sequence import SequenceTransform
@@ -446,7 +447,10 @@ def compile(
         if num_v1_blocks > 0:
             self.output_names = [task.task_name for task in self.prediction_tasks]
         else:
-            self.output_names = [block.full_name for block in self.model_outputs]
+            if num_v2_blocks == 1 and isinstance(self.model_outputs[0], TopKOutput):
+                pass
+            else:
+                self.output_names = [block.full_name for block in self.model_outputs]
 
         # This flag will make Keras change the metric-names which is not needed in v2
         from_serialized = kwargs.pop("from_serialized", num_v2_blocks > 0)
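For context, a minimal sketch (not part of the commit) of the behaviour the special case above preserves. It assumes a topk_encoder and a feature dict like batch[0] from the test below; the field names scores and identifiers come from the diff and the TopKOutput block, everything else is illustrative.

# With a single TopKOutput, the encoder returns a structured prediction whose
# fields are named "scores" and "identifiers" (the test below reads
# batch_output.scores). Skipping the output_names override in compile() lets
# Keras keep those field names as the keys of the saved model's signature,
# instead of renaming the outputs after the single block's full_name.
prediction = topk_encoder(features)        # features: dict of input tensors, like batch[0] in the test
top_scores = prediction.scores             # shape (batch_size, TOP_K), float32
top_identifiers = prediction.identifiers   # shape (batch_size, TOP_K), int32 candidate ids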
9 changes: 9 additions & 0 deletions tests/unit/tf/core/test_encoder.py
@@ -122,6 +122,15 @@ def test_topk_encoder(music_streaming_data: Dataset):
     loaded_topk_encoder = tf.keras.models.load_model(tmpdir)
     batch_output = loaded_topk_encoder(batch[0])
 
+    output_signature = loaded_topk_encoder.signatures["serving_default"].structured_outputs
+    assert len(output_signature) == 2
+    assert output_signature["scores"] == tf.TensorSpec(
+        shape=(None, TOP_K), dtype=tf.float32, name="scores"
+    )
+    assert output_signature["identifiers"] == tf.TensorSpec(
+        shape=(None, TOP_K), dtype=tf.int32, name="identifiers"
+    )
+
     assert list(batch_output.scores.shape) == [BATCH_SIZE, TOP_K]
     tf.debugging.assert_equal(
         topk_encoder.topk_layer._candidates,
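As a usage illustration rather than part of the PR, the exported signature asserted above could be consumed roughly like this; tmpdir and batch mirror the test, and passing the feature dict as keyword arguments to the concrete serving function is an assumption about its input signature.

import tensorflow as tf

# Reload the SavedModel written by the test and fetch its default serving function.
loaded = tf.keras.models.load_model(tmpdir)
serving_fn = loaded.signatures["serving_default"]

# Because the structured outputs keep the TopKOutput names, results are read by key.
outputs = serving_fn(**batch[0])           # assumes batch[0] is a dict of feature tensors
top_scores = outputs["scores"]             # (batch_size, TOP_K) float32
top_identifiers = outputs["identifiers"]   # (batch_size, TOP_K) int32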
