Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a protocol for expected interface of a model and save input schema to metadata directory #680

Merged
merged 26 commits into from
Nov 1, 2022
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
83887ab
Add a runtime checkable protocol for expected interface of a model
oliverholworthy Aug 25, 2022
eb5d221
Save input and output schema from XGBoost model
oliverholworthy Oct 4, 2022
c28574f
Add `load_model` function to load any Merlin Model
oliverholworthy Oct 4, 2022
9e4741b
Add save and load methods to Tensorflow Model
oliverholworthy Oct 4, 2022
6d5a8cf
Merge branch 'main' into model-protocol
oliverholworthy Oct 5, 2022
0ad32ae
Use `mm` instead of `ml` in `test_reload`
oliverholworthy Oct 5, 2022
a75a336
Use InputBlockV2 in model `test_reload`
oliverholworthy Oct 5, 2022
f875c4a
Merge branch 'main' into model-protocol
oliverholworthy Oct 7, 2022
37893ee
Add docstring to `save` method of MerlinModel protocol
oliverholworthy Oct 7, 2022
4315de2
Update name of variable in `merlin.models.io`
oliverholworthy Oct 7, 2022
e45f25b
Check signatures in save and load Model test
oliverholworthy Oct 7, 2022
7822dbb
Don't raise exception if merlin metadata directory already exists
oliverholworthy Oct 10, 2022
324217f
Raise ValueError when target columns passed to input block
oliverholworthy Oct 10, 2022
6b875a6
Remove targets before passing to input block in test_save_and_load
oliverholworthy Oct 10, 2022
6c5c496
Add Optional typehint reflecting valid inputs to save_merlin_metadata
oliverholworthy Oct 10, 2022
27d165c
Remove `load_model` function from this PR
oliverholworthy Oct 10, 2022
ac87c58
Revert change to check targets in InputBlock
oliverholworthy Oct 10, 2022
167bc9e
Restore newline in InputBlockV2
oliverholworthy Oct 10, 2022
d8ab1f0
Merge branch 'main' into model-protocol
oliverholworthy Oct 10, 2022
0707b8f
Merge branch 'main' into model-protocol
oliverholworthy Oct 12, 2022
cb01a27
Merge branch 'main' into model-protocol
oliverholworthy Oct 18, 2022
e9a5445
Merge branch 'main' into model-protocol
marcromeyn Oct 20, 2022
4425c88
Update merlin metadata directory name to `.merlin`
oliverholworthy Oct 21, 2022
85566e4
Merge branch 'main' into model-protocol
marcromeyn Oct 24, 2022
9e46d3d
Merge branch 'main' into model-protocol
oliverholworthy Oct 24, 2022
3f5da6d
Merge branch 'main' into model-protocol
marcromeyn Nov 1, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions merlin/models/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Protocol, runtime_checkable


@runtime_checkable
class MerlinModel(Protocol):
    """Structural interface shared by every Merlin model.

    Any model implementation, whatever framework backs it
    (e.g. Tensorflow, XGBoost, Pytorch), is expected to provide the
    methods declared here. Because the protocol is runtime checkable,
    ``isinstance(obj, MerlinModel)`` verifies that these methods are present.
    """

    def save(self, path) -> None:
        """Persist the model to a local path.

        Parameters
        ----------
        path : Union[str, os.PathLike]
            Location the model is written to.
        """

    @classmethod
    def load(cls, path):
        """Restore a model from the location it was previously saved to.

        Parameters
        ----------
        path : Union[str, os.PathLike]
            A path corresponding to the directory where a model was previously saved.
        """

    def fit(self, dataset, *args, **kwargs):
        """Train the model on a dataset.

        Parameters
        ----------
        dataset : merlin.io.Dataset
            The training dataset used to fit the model.
        """

    def evaluate(self, dataset, *args, **kwargs):
        """Compute evaluation metrics for a dataset.

        Parameters
        ----------
        dataset : merlin.io.Dataset
            The evaluation dataset the metrics are computed on.
        """

    def predict(self, dataset, *args, **kwargs):
        """Produce model predictions for a dataset.

        Parameters
        ----------
        dataset : merlin.io.Dataset
            The dataset to generate predictions from.
        """
47 changes: 47 additions & 0 deletions merlin/models/io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import pathlib
from typing import Optional, Union

from merlin.models.api import MerlinModel
from merlin.models.utils.schema_utils import schema_to_tensorflow_metadata_json
from merlin.schema import Schema

_MERLIN_METADATA_DIR_NAME = "merlin_metadata"
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Calling this metadata directory merlin_metadata here.

Another alternative name could be something like .merlin.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would vote for .merlin



def save_merlin_metadata(
    export_path: Union[str, os.PathLike],
    model: MerlinModel,
    input_schema: Optional[Schema],
    output_schema: Optional[Schema],
) -> None:
    """Write the model's schemas into the Merlin metadata directory.

    The metadata directory is created inside ``export_path`` (including any
    missing parent directories) and is reused if it already exists.

    Parameters
    ----------
    export_path : Union[str, os.PathLike]
        Directory the model was (or is being) exported to.
    model : MerlinModel
        The model being saved. Currently not used by this function.
    input_schema : Optional[Schema]
        Written to ``input_schema.json`` when it is not None.
    output_schema : Optional[Schema]
        Written to ``output_schema.json`` when it is not None.
    """
    merlin_metadata_dir = pathlib.Path(export_path) / _MERLIN_METADATA_DIR_NAME
    # parents=True so that saving metadata does not depend on the caller
    # having created export_path beforehand.
    merlin_metadata_dir.mkdir(parents=True, exist_ok=True)

    schema_files = (
        ("input_schema.json", input_schema),
        ("output_schema.json", output_schema),
    )
    for file_name, schema in schema_files:
        if schema is not None:
            schema_to_tensorflow_metadata_json(schema, merlin_metadata_dir / file_name)
45 changes: 45 additions & 0 deletions merlin/models/tf/models/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,23 @@
#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from __future__ import annotations

import collections
import inspect
import os
import sys
import warnings
from collections.abc import Sequence as SequenceCollection
Expand All @@ -14,6 +30,7 @@
from tensorflow.keras.utils import unpack_x_y_sample_weight

import merlin.io
from merlin.models.io import save_merlin_metadata
from merlin.models.tf.core.base import Block, ModelContext, PredictionOutput, is_input_block
from merlin.models.tf.core.combinators import ParallelBlock, SequentialBlock
from merlin.models.tf.core.prediction import Prediction, PredictionContext
Expand Down Expand Up @@ -858,6 +875,34 @@ def __init__(
self.schema = sum(input_block_schemas, Schema())
self._frozen_blocks = set()

def save(
    self,
    export_path: Union[str, os.PathLike],
    include_optimizer=True,
    save_traces=True,
) -> None:
    """Export the model as a Tensorflow Saved Model plus Merlin metadata.

    Parameters
    ----------
    export_path : Union[str, os.PathLike]
        Destination passed to the parent class ``save`` and to
        ``save_merlin_metadata``.
    include_optimizer : bool
        Forwarded unchanged to the parent class ``save``.
    save_traces : bool
        Forwarded unchanged to the parent class ``save``.
    """
    # The save format is always "tf"; it is not configurable by callers.
    parent_save_options = {
        "include_optimizer": include_optimizer,
        "save_traces": save_traces,
        "save_format": "tf",
    }
    super().save(export_path, **parent_save_options)

    # Only the input schema is recorded; the output schema is passed as None.
    save_merlin_metadata(export_path, self, self.schema, None)

@classmethod
def load(cls, export_path: Union[str, os.PathLike]) -> "Model":
    """Load a model previously written with `model.save()`.

    Parameters
    ----------
    export_path : Union[str, os.PathLike]
        The path to the saved model.

    Returns
    -------
    The object produced by ``tf.keras.models.load_model`` for ``export_path``.
    """
    restored_model = tf.keras.models.load_model(export_path)
    return restored_model

def _maybe_build(self, inputs):
if isinstance(inputs, dict):
_ragged_inputs = ListToRagged()(inputs)
Expand Down
9 changes: 9 additions & 0 deletions merlin/models/xgb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from merlin.core.utils import global_dask_client
from merlin.io import Dataset
from merlin.models.io import save_merlin_metadata
from merlin.models.utils.schema_utils import (
schema_to_tensorflow_metadata_json,
tensorflow_metadata_json_to_schema,
Expand Down Expand Up @@ -254,6 +255,14 @@ def save(self, path: Union[str, os.PathLike]) -> None:
export_dir.mkdir(parents=True)
self.booster.save_model(export_dir / "model.json")
schema_to_tensorflow_metadata_json(self.schema, export_dir / "schema.json")

save_merlin_metadata(
export_dir,
self,
self.schema.select_by_name(self.feature_columns),
self.schema.select_by_name(self.target_columns),
)

with open(export_dir / "params.json", "w") as f:
json.dump(self.params, f, indent=4)
with open(export_dir / "config.json", "w") as f:
Expand Down
26 changes: 26 additions & 0 deletions tests/unit/tf/models/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,32 @@ def test_unfreeze_all_blocks(ecommerce_data):
model.fit(ecommerce_data, batch_size=128, epochs=1)


def test_save_and_load(tmpdir):
    """Round-trip a small model through save/load and compare predictions.

    Checks that the reloaded model's serving signature exposes only the
    ``user_age`` input and that both models predict the same values.
    """
    dataset = generate_data("e-commerce", num_rows=10)
    dataset.schema = dataset.schema.select_by_name(["click", "user_age"])

    # Target columns are removed before the schema reaches the input block.
    feature_schema = dataset.schema.remove_by_tag(Tags.TARGET)
    input_block = mm.InputBlockV2(feature_schema)
    model = mm.Model(
        input_block,
        mm.MLPBlock([4]),
        mm.BinaryClassificationTask("click"),
    )
    model.compile()
    model.fit(dataset, epochs=1, batch_size=10)

    model.save(tmpdir)
    reloaded_model = mm.Model.load(tmpdir)

    serving_signature = reloaded_model.signatures["serving_default"]
    keyword_inputs = serving_signature.structured_input_signature[1]
    assert set(keyword_inputs.keys()) == {"user_age"}

    test_case = TestCase()
    original_predictions = model.predict(dataset, batch_size=10)
    reloaded_predictions = reloaded_model.predict(dataset, batch_size=10)
    test_case.assertAllClose(original_predictions, reloaded_predictions)


def test_retrieval_model_query(ecommerce_data: Dataset, run_eagerly=True):
query = ecommerce_data.schema.select_by_tag(Tags.USER_ID)
candidate = ecommerce_data.schema.select_by_tag(Tags.ITEM_ID)
Expand Down