Skip to content

Commit

Permalink
Add a protocol for expected interface of a model and save input schem…
Browse files Browse the repository at this point in the history
…a to metadata directory (#680)

* Add a runtime checkable protocol for expected interface of a model

* Save input and output schema from XGBoost model

* Add `load_model` function to load any Merlin Model

* Add save and load methods to Tensorflow Model

* Use `mm` instead of `ml` in `test_reload`

* Use InputBlockV2 in model `test_reload`

* Add docstring to `save` method of MerlinModel protocol

* Update name of variable in `merlin.models.io`

* Check signatures in save and load Model test

* Don't raise exception if merlin metadata directory already exists

* Raise ValueError when target columns passed to input block

* Remove targets before passing to input block in test_save_and_load

* Add Optional typehint reflecting valid inputs to save_merlin_metadata

* Remove `load_model` function from this PR

* Revert change to check targets in InputBlock

* Restore newline in InputBlockV2

* Update merlin metadata directory name to `.merlin`

Co-authored-by: Marc Romeyn <[email protected]>
  • Loading branch information
oliverholworthy and marcromeyn authored Nov 1, 2022
1 parent f9ab96e commit 8ad65c5
Show file tree
Hide file tree
Showing 5 changed files with 203 additions and 0 deletions.
76 changes: 76 additions & 0 deletions merlin/models/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Protocol, runtime_checkable


@runtime_checkable
class MerlinModel(Protocol):
"""Protocol for a merlin model.
This defines common methods that should be implemented on all model implementations.
Including those with different frameworks (e.g. Tensorflow, XGBoost, Pytorch)
"""

def save(self, path) -> None:
"""Save the model to a local path.
Parameters
----------
path : Union[str, os.PathLike]
"""
...

@classmethod
def load(cls, path):
"""Load the model from a path-like argument provided where a model was previously saved.
Parameters
----------
path : Union[str, os.PathLike]
A path correspoonding to the directory where a model was previously saved.
"""
...

def fit(self, dataset, *args, **kwargs):
"""Fit the model on the provided dataset.
Parameters
----------
dataset : merlin.io.Dataset
The training dataset to be used to fit the model.
"""
...

def evaluate(self, dataset, *args, **kwargs):
"""Return evaluation metrics on a dataset.
Parameters
----------
dataset : merlin.io.Dataset
The evaluation dataset to be used to compute metrics.
"""
...

def predict(self, dataset, *args, **kwargs):
"""Return predictions generated by the model.
Parameters
----------
dataset : merlin.io.Dataset
The dataset to generate predictions from.
"""
...
47 changes: 47 additions & 0 deletions merlin/models/io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import pathlib
from typing import Optional, Union

from merlin.models.api import MerlinModel
from merlin.models.utils.schema_utils import schema_to_tensorflow_metadata_json
from merlin.schema import Schema

_MERLIN_METADATA_DIR_NAME = ".merlin"


def save_merlin_metadata(
export_path: Union[str, os.PathLike],
model: MerlinModel,
input_schema: Optional[Schema],
output_schema: Optional[Schema],
) -> None:
"""Saves data to Merlin Metadata Directory."""
export_path = pathlib.Path(export_path)
merlin_metadata_dir = export_path / _MERLIN_METADATA_DIR_NAME
merlin_metadata_dir.mkdir(exist_ok=True)

if input_schema is not None:
schema_to_tensorflow_metadata_json(
input_schema,
merlin_metadata_dir / "input_schema.json",
)
if output_schema is not None:
schema_to_tensorflow_metadata_json(
output_schema,
merlin_metadata_dir / "output_schema.json",
)
45 changes: 45 additions & 0 deletions merlin/models/tf/models/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,23 @@
#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from __future__ import annotations

import collections
import inspect
import os
import sys
import warnings
from collections.abc import Sequence as SequenceCollection
Expand All @@ -15,6 +31,7 @@
from tensorflow.keras.utils import unpack_x_y_sample_weight

import merlin.io
from merlin.models.io import save_merlin_metadata
from merlin.models.tf.core.base import Block, ModelContext, PredictionOutput, is_input_block
from merlin.models.tf.core.combinators import ParallelBlock, SequentialBlock
from merlin.models.tf.core.prediction import Prediction, PredictionContext, TensorLike
Expand Down Expand Up @@ -999,6 +1016,34 @@ def __init__(
self.schema = sum(input_block_schemas, Schema())
self._frozen_blocks = set()

def save(
self,
export_path: Union[str, os.PathLike],
include_optimizer=True,
save_traces=True,
) -> None:
"""Saves the model to export_path as a Tensorflow Saved Model.
Along with merlin model metadata.
"""
super().save(
export_path,
include_optimizer=include_optimizer,
save_traces=save_traces,
save_format="tf",
)
save_merlin_metadata(export_path, self, self.schema, None)

@classmethod
def load(cls, export_path: Union[str, os.PathLike]) -> "Model":
"""Loads a model that was saved with `model.save()`.
Parameters
----------
export_path : Union[str, os.PathLike]
The path to the saved model.
"""
return tf.keras.models.load_model(export_path)

def _maybe_build(self, inputs):
if isinstance(inputs, dict):
_ragged_inputs = ListToRagged()(inputs)
Expand Down
9 changes: 9 additions & 0 deletions merlin/models/xgb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from merlin.core.utils import global_dask_client
from merlin.io import Dataset
from merlin.models.io import save_merlin_metadata
from merlin.models.utils.schema_utils import (
schema_to_tensorflow_metadata_json,
tensorflow_metadata_json_to_schema,
Expand Down Expand Up @@ -254,6 +255,14 @@ def save(self, path: Union[str, os.PathLike]) -> None:
export_dir.mkdir(parents=True)
self.booster.save_model(export_dir / "model.json")
schema_to_tensorflow_metadata_json(self.schema, export_dir / "schema.json")

save_merlin_metadata(
export_dir,
self,
self.schema.select_by_name(self.feature_columns),
self.schema.select_by_name(self.target_columns),
)

with open(export_dir / "params.json", "w") as f:
json.dump(self.params, f, indent=4)
with open(export_dir / "config.json", "w") as f:
Expand Down
26 changes: 26 additions & 0 deletions tests/unit/tf/models/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,32 @@ def test_unfreeze_all_blocks(ecommerce_data):
model.fit(ecommerce_data, batch_size=128, epochs=1)


def test_save_and_load(tmpdir):
dataset = generate_data("e-commerce", num_rows=10)
dataset.schema = dataset.schema.select_by_name(["click", "user_age"])
model = mm.Model(
mm.InputBlockV2(dataset.schema.remove_by_tag(Tags.TARGET)),
mm.MLPBlock([4]),
mm.BinaryClassificationTask("click"),
)
model.compile()
_ = model.fit(
dataset,
epochs=1,
batch_size=10,
)
model.save(tmpdir)
reloaded_model = mm.Model.load(tmpdir)
signature_input_keys = set(
reloaded_model.signatures["serving_default"].structured_input_signature[1].keys()
)
assert signature_input_keys == {"user_age"}
test_case = TestCase()
test_case.assertAllClose(
model.predict(dataset, batch_size=10), reloaded_model.predict(dataset, batch_size=10)
)


def test_retrieval_model_query(ecommerce_data: Dataset, run_eagerly=True):
query = ecommerce_data.schema.select_by_tag(Tags.USER_ID)
candidate = ecommerce_data.schema.select_by_tag(Tags.ITEM_ID)
Expand Down

0 comments on commit 8ad65c5

Please sign in to comment.