Add a protocol for expected interface of a model and save input schema to metadata directory #680
Merged
marcromeyn merged 26 commits into NVIDIA-Merlin:main from oliverholworthy:model-protocol on Nov 1, 2022
Changes from 18 commits
Commits
83887ab Add a runtime checkable protocol for expected interface of a model (oliverholworthy)
eb5d221 Save input and output schema from XGBoost model (oliverholworthy)
c28574f Add `load_model` function to load any Merlin Model (oliverholworthy)
9e4741b Add save and load methods to Tensorflow Model (oliverholworthy)
6d5a8cf Merge branch 'main' into model-protocol (oliverholworthy)
0ad32ae Use `mm` instead of `ml` in `test_reload` (oliverholworthy)
a75a336 Use InputBlockV2 in model `test_reload` (oliverholworthy)
f875c4a Merge branch 'main' into model-protocol (oliverholworthy)
37893ee Add docstring to `save` method of MerlinModel protocol (oliverholworthy)
4315de2 Update name of variable in `merlin.models.io` (oliverholworthy)
e45f25b Check signatures in save and load Model test (oliverholworthy)
7822dbb Don't raise exception if merlin metadata directory already exists (oliverholworthy)
324217f Raise ValueError when target columns passed to input block (oliverholworthy)
6b875a6 Remove targets before passing to input block in test_save_and_load (oliverholworthy)
6c5c496 Add Optional typehint reflecting valid inputs to save_merlin_metadata (oliverholworthy)
27d165c Remove `load_model` function from this PR (oliverholworthy)
ac87c58 Revert change to check targets in InputBlock (oliverholworthy)
167bc9e Restore newline in InputBlockV2 (oliverholworthy)
d8ab1f0 Merge branch 'main' into model-protocol (oliverholworthy)
0707b8f Merge branch 'main' into model-protocol (oliverholworthy)
cb01a27 Merge branch 'main' into model-protocol (oliverholworthy)
e9a5445 Merge branch 'main' into model-protocol (marcromeyn)
4425c88 Update merlin metadata directory name to `.merlin` (oliverholworthy)
85566e4 Merge branch 'main' into model-protocol (marcromeyn)
9e46d3d Merge branch 'main' into model-protocol (oliverholworthy)
3f5da6d Merge branch 'main' into model-protocol (marcromeyn)
@@ -0,0 +1,76 @@
```python
#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Protocol, runtime_checkable


@runtime_checkable
class MerlinModel(Protocol):
    """Protocol for a merlin model.

    This defines common methods that should be implemented on all model implementations,
    including those with different frameworks (e.g. TensorFlow, XGBoost, PyTorch).
    """

    def save(self, path) -> None:
        """Save the model to a local path.

        Parameters
        ----------
        path : Union[str, os.PathLike]
        """
        ...

    @classmethod
    def load(cls, path):
        """Load the model from a path-like argument provided where a model was previously saved.

        Parameters
        ----------
        path : Union[str, os.PathLike]
            A path corresponding to the directory where a model was previously saved.
        """
        ...

    def fit(self, dataset, *args, **kwargs):
        """Fit the model on the provided dataset.

        Parameters
        ----------
        dataset : merlin.io.Dataset
            The training dataset to be used to fit the model.
        """
        ...

    def evaluate(self, dataset, *args, **kwargs):
        """Return evaluation metrics on a dataset.

        Parameters
        ----------
        dataset : merlin.io.Dataset
            The evaluation dataset to be used to compute metrics.
        """
        ...

    def predict(self, dataset, *args, **kwargs):
        """Return predictions generated by the model.

        Parameters
        ----------
        dataset : merlin.io.Dataset
            The dataset to generate predictions from.
        """
        ...
```
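For readers unfamiliar with `runtime_checkable`, here is a minimal sketch of how a class can satisfy this protocol structurally, without inheriting from it. The protocol is repeated in trimmed form so the snippet runs on its own; `DummyModel` and `NotAModel` are hypothetical classes used only for illustration and are not part of this PR. Note that `isinstance` against a runtime-checkable protocol only checks that the members exist, not that their signatures match.

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class MerlinModel(Protocol):
    """Trimmed copy of the protocol, repeated so the sketch is standalone."""

    def save(self, path) -> None: ...

    @classmethod
    def load(cls, path): ...

    def fit(self, dataset, *args, **kwargs): ...

    def evaluate(self, dataset, *args, **kwargs): ...

    def predict(self, dataset, *args, **kwargs): ...


class DummyModel:
    """Hypothetical model: defines every protocol method, no inheritance."""

    def save(self, path) -> None:
        pass

    @classmethod
    def load(cls, path):
        return cls()

    def fit(self, dataset, *args, **kwargs):
        return self

    def evaluate(self, dataset, *args, **kwargs):
        return {}

    def predict(self, dataset, *args, **kwargs):
        return []


class NotAModel:
    """Hypothetical class that only defines `save`."""

    def save(self, path) -> None:
        pass


# The check is structural: it looks for the member names only.
is_model = isinstance(DummyModel(), MerlinModel)
is_not_model = isinstance(NotAModel(), MerlinModel)
```

Because the check is by attribute presence, a test that also compares signatures (as one of the commits in this PR does) is a useful complement.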
@@ -0,0 +1,47 @@
```python
#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import pathlib
from typing import Optional, Union

from merlin.models.api import MerlinModel
from merlin.models.utils.schema_utils import schema_to_tensorflow_metadata_json
from merlin.schema import Schema

_MERLIN_METADATA_DIR_NAME = "merlin_metadata"


def save_merlin_metadata(
    export_path: Union[str, os.PathLike],
    model: MerlinModel,
    input_schema: Optional[Schema],
    output_schema: Optional[Schema],
) -> None:
    """Saves data to Merlin Metadata Directory."""
    export_path = pathlib.Path(export_path)
    merlin_metadata_dir = export_path / _MERLIN_METADATA_DIR_NAME
    merlin_metadata_dir.mkdir(exist_ok=True)

    if input_schema is not None:
        schema_to_tensorflow_metadata_json(
            input_schema,
            merlin_metadata_dir / "input_schema.json",
        )
    if output_schema is not None:
        schema_to_tensorflow_metadata_json(
            output_schema,
            merlin_metadata_dir / "output_schema.json",
        )
```
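To illustrate the directory handling in `save_merlin_metadata`, here is a rough standalone sketch of the same pattern. It is not the PR's code: plain dicts stand in for `merlin.schema.Schema`, and `write_schema_json` and `save_metadata_sketch` are hypothetical helpers replacing `schema_to_tensorflow_metadata_json` so the snippet runs without Merlin installed. It shows that `mkdir(exist_ok=True)` makes repeated saves safe and that a `None` schema is simply skipped.

```python
import json
import pathlib
import tempfile

METADATA_DIR_NAME = "merlin_metadata"  # directory name as shown in this diff


def write_schema_json(schema: dict, path: pathlib.Path) -> None:
    """Hypothetical stand-in for schema_to_tensorflow_metadata_json."""
    path.write_text(json.dumps(schema))


def save_metadata_sketch(export_path, input_schema, output_schema) -> pathlib.Path:
    """Write whichever schemas are provided into the metadata directory."""
    metadata_dir = pathlib.Path(export_path) / METADATA_DIR_NAME
    # exist_ok=True: saving twice to the same path does not raise.
    metadata_dir.mkdir(exist_ok=True)
    if input_schema is not None:
        write_schema_json(input_schema, metadata_dir / "input_schema.json")
    if output_schema is not None:
        write_schema_json(output_schema, metadata_dir / "output_schema.json")
    return metadata_dir


with tempfile.TemporaryDirectory() as tmp:
    first = save_metadata_sketch(tmp, {"columns": ["user_id"]}, None)
    # Second call reuses the existing directory instead of failing.
    second = save_metadata_sketch(tmp, {"columns": ["user_id"]}, None)
    written = sorted(p.name for p in second.iterdir())
```

Since the output schema is `None` in both calls, only `input_schema.json` ends up in the directory.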
Calling this metadata directory `merlin_metadata` here. Another alternative name could be something like `.merlin`.

I would vote for `.merlin`