-
Notifications
You must be signed in to change notification settings - Fork 50
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a protocol for expected interface of a model and save input schem…
…a to metadata directory (#680) * Add a runtime checkable protocol for expected interface of a model * Save input and output schema from XGBoost model * Add `load_model` function to load any Merlin Model * Add save and load methods to Tensorflow Model * Use `mm` instead of `ml` in `test_reload` * Use InputBlockV2 in model `test_reload` * Add docstring to `save` method of MerlinModel protocol * Update name of variable in `merlin.models.io` * Check signatures in save and load Model test * Don't raise exception if merlin metadata directory already exists * Raise ValueError when target columns passed to input block * Remove targets before passing to input block in test_save_and_load * Add Optional typehint reflecting valid inputs to save_merlin_metadata * Remove `load_model` function from this PR * Revert change to check targets in InputBlock * Restore newline in InputBlockV2 * Update merlin metadata directory name to `.merlin` Co-authored-by: Marc Romeyn <[email protected]>
- Loading branch information
1 parent
f9ab96e
commit 8ad65c5
Showing
5 changed files
with
203 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
# | ||
# Copyright (c) 2022, NVIDIA CORPORATION. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
from typing import Protocol, runtime_checkable | ||
|
||
|
||
@runtime_checkable | ||
class MerlinModel(Protocol): | ||
"""Protocol for a merlin model. | ||
This defines common methods that should be implemented on all model implementations. | ||
Including those with different frameworks (e.g. Tensorflow, XGBoost, Pytorch) | ||
""" | ||
|
||
def save(self, path) -> None: | ||
"""Save the model to a local path. | ||
Parameters | ||
---------- | ||
path : Union[str, os.PathLike] | ||
""" | ||
... | ||
|
||
@classmethod | ||
def load(cls, path): | ||
"""Load the model from a path-like argument provided where a model was previously saved. | ||
Parameters | ||
---------- | ||
path : Union[str, os.PathLike] | ||
A path correspoonding to the directory where a model was previously saved. | ||
""" | ||
... | ||
|
||
def fit(self, dataset, *args, **kwargs): | ||
"""Fit the model on the provided dataset. | ||
Parameters | ||
---------- | ||
dataset : merlin.io.Dataset | ||
The training dataset to be used to fit the model. | ||
""" | ||
... | ||
|
||
def evaluate(self, dataset, *args, **kwargs): | ||
"""Return evaluation metrics on a dataset. | ||
Parameters | ||
---------- | ||
dataset : merlin.io.Dataset | ||
The evaluation dataset to be used to compute metrics. | ||
""" | ||
... | ||
|
||
def predict(self, dataset, *args, **kwargs): | ||
"""Return predictions generated by the model. | ||
Parameters | ||
---------- | ||
dataset : merlin.io.Dataset | ||
The dataset to generate predictions from. | ||
""" | ||
... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# | ||
# Copyright (c) 2022, NVIDIA CORPORATION. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
import os | ||
import pathlib | ||
from typing import Optional, Union | ||
|
||
from merlin.models.api import MerlinModel | ||
from merlin.models.utils.schema_utils import schema_to_tensorflow_metadata_json | ||
from merlin.schema import Schema | ||
|
||
_MERLIN_METADATA_DIR_NAME = ".merlin" | ||
|
||
|
||
def save_merlin_metadata( | ||
export_path: Union[str, os.PathLike], | ||
model: MerlinModel, | ||
input_schema: Optional[Schema], | ||
output_schema: Optional[Schema], | ||
) -> None: | ||
"""Saves data to Merlin Metadata Directory.""" | ||
export_path = pathlib.Path(export_path) | ||
merlin_metadata_dir = export_path / _MERLIN_METADATA_DIR_NAME | ||
merlin_metadata_dir.mkdir(exist_ok=True) | ||
|
||
if input_schema is not None: | ||
schema_to_tensorflow_metadata_json( | ||
input_schema, | ||
merlin_metadata_dir / "input_schema.json", | ||
) | ||
if output_schema is not None: | ||
schema_to_tensorflow_metadata_json( | ||
output_schema, | ||
merlin_metadata_dir / "output_schema.json", | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters