Skip to content

Commit

Permalink
Add support for logging of H2O MOJO Models (#486)
Browse files Browse the repository at this point in the history
* feat: h2o model logging supports MOJO

* feat: add support for deserializing of h2o MOJO artifacts

* fix: log MOJO model as directory

* test: update unit test to test for binary vs MOJO

* test: update h2o schema integration test to make sure binary model is used

* fix: download_mojo instead of save_mojo

* test: add unit test for retrieval of MOJO model

* chore: pre-commit

* chore: backwards compatibility for h2o

* chore: update literals
  • Loading branch information
thebrianbn authored Sep 30, 2024
1 parent cfc90ab commit 6f22bce
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 16 deletions.
23 changes: 20 additions & 3 deletions rubicon_ml/client/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def _get_data(self):
@failsafe
def get_data(
self,
deserialize: Optional[Literal["h2o", "pickle", "xgboost"]] = None,
deserialize: Optional[Literal["h2o", "h2o_binary", "h2o_mojo", "pickle", "xgboost"]] = None,
unpickle: bool = False, # TODO: deprecate & move to `deserialize`
):
"""Loads the data associated with this artifact and
Expand All @@ -82,7 +82,8 @@ def get_data(
deserialize : str, optional
Method to use to deserialize this artifact's data.
* None to disable deserialization and return the raw data.
* "h2o" to use `h2o.load_model` to load the data.
* "h2o" or "h2o_binary" to use `h2o.load_model` to load the data.
* "h2o_mojo" to use `h2o.import_mojo` to load the data.
* "pickle" to use pickles to load the data.
* "xgboost" to use xgboost's JSON loader to load the data as a fitted model.
Defaults to None.
Expand All @@ -101,6 +102,13 @@ def get_data(
)
deserialize = "pickle"

if deserialize == "h2o":
warnings.warn(
"'deserialize' method 'h2o' will be deprecated in a future release,"
" please use 'h2o_binary' instead.",
DeprecationWarning,
)

for repo in self.repositories or []:
try:
if deserialize == "xgboost":
Expand All @@ -119,12 +127,21 @@ def get_data(
except Exception as err:
return_err = err
else:
if deserialize == "h2o":
if deserialize in [
"h2o",
"h2o_binary",
]: # "h2o" will be deprecated in a future release
import h2o

data = h2o.load_model(
repo._get_artifact_data_path(project_name, experiment_id, self.id)
)
elif deserialize == "h2o_mojo":
import h2o

data = h2o.import_mojo(
repo._get_artifact_data_path(project_name, experiment_id, self.id)
)
elif deserialize == "pickle":
data = pickle.loads(data)

Expand Down
20 changes: 14 additions & 6 deletions rubicon_ml/client/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ def log_h2o_model(
h2o_model,
artifact_name: Optional[str] = None,
export_cross_validation_predictions: bool = False,
use_mojo: bool = False,
**log_artifact_kwargs,
) -> Artifact:
"""Log an `h2o` model as an artifact using `h2o.save_model`.
Expand All @@ -256,6 +257,9 @@ def log_h2o_model(
The name of the artifact. Defaults to None, using `h2o_model`'s class name.
export_cross_validation_predictions: bool, optional (default False)
Passed directly to `h2o.save_model`.
use_mojo: bool, optional (default False)
Whether to log the model in MOJO format. If False, the model will be
logged in binary format.
log_artifact_kwargs : dict
Additional kwargs to be passed directly to `self.log_artifact`.
"""
Expand All @@ -268,12 +272,16 @@ def log_h2o_model(
artifact_name = h2o_model.__class__.__name__

with tempfile.TemporaryDirectory() as temp_dir_name:
model_data_path = h2o.save_model(
h2o_model,
export_cross_validation_predictions=export_cross_validation_predictions,
filename=artifact_name,
path=temp_dir_name,
)
if use_mojo:
model_data_path = f"{temp_dir_name}/{artifact_name}.zip"
h2o_model.download_mojo(path=model_data_path)
else:
model_data_path = h2o.save_model(
h2o_model,
export_cross_validation_predictions=export_cross_validation_predictions,
filename=artifact_name,
path=temp_dir_name,
)

artifact = self.log_artifact(
name=artifact_name,
Expand Down
4 changes: 3 additions & 1 deletion tests/integration/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,4 +111,6 @@ def test_estimator_h2o_schema_train(
model_artifact = experiment.artifact(name=schema_cls.__name__)

assert len(project.schema_["parameters"]) == len(experiment.parameters())
assert model_artifact.get_data(deserialize="h2o").__class__.__name__ == schema_cls.__name__
assert (
model_artifact.get_data(deserialize="h2o_binary").__class__.__name__ == schema_cls.__name__
)
23 changes: 19 additions & 4 deletions tests/unit/client/test_artifact_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pytest
import xgboost
from h2o import H2OFrame
from h2o.estimators.generic import H2OGenericEstimator
from h2o.estimators.random_forest import H2ORandomForestEstimator

from rubicon_ml import domain
Expand Down Expand Up @@ -159,8 +160,19 @@ def test_download_location(mock_open, project_client):
mock_file().write.assert_called_once_with(data)


@pytest.mark.parametrize(
["use_mojo", "deserialization_method"],
[
(False, "h2o"),
(False, "h2o_binary"),
(True, "h2o_mojo"),
],
)
def test_get_data_deserialize_h2o(
make_classification_df, rubicon_local_filesystem_client_with_project
make_classification_df,
rubicon_local_filesystem_client_with_project,
use_mojo,
deserialization_method,
):
"""Test logging `h2o` model data."""
_, project = rubicon_local_filesystem_client_with_project
Expand All @@ -181,10 +193,13 @@ def test_get_data_deserialize_h2o(
y=target_name,
)

artifact = project.log_h2o_model(h2o_model)
artifact_data = artifact.get_data(deserialize="h2o")
artifact = project.log_h2o_model(h2o_model, use_mojo=use_mojo)
artifact_data = artifact.get_data(deserialize=deserialization_method)

assert artifact_data.__class__ == h2o_model.__class__
if use_mojo:
assert isinstance(artifact_data, H2OGenericEstimator)
else:
assert artifact_data.__class__ == h2o_model.__class__


def test_get_data_deserialize_xgboost(
Expand Down
7 changes: 5 additions & 2 deletions tests/unit/client/test_mixin_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,10 @@ def test_log_json(project_client):
assert artifact_b.id in [a.id for a in artifacts]


def test_log_h2o_model(make_classification_df, rubicon_local_filesystem_client_with_project):
@pytest.mark.parametrize("use_mojo", [False, True])
def test_log_h2o_model(
make_classification_df, rubicon_local_filesystem_client_with_project, use_mojo
):
"""Test logging `h2o` model data."""
_, project = rubicon_local_filesystem_client_with_project
X, y = make_classification_df
Expand All @@ -222,7 +225,7 @@ def test_log_h2o_model(make_classification_df, rubicon_local_filesystem_client_w
y=target_name,
)

artifact = project.log_h2o_model(h2o_model, tags=["h2o"])
artifact = project.log_h2o_model(h2o_model, use_mojo=use_mojo, tags=["h2o"])
read_artifact = project.artifact(name=artifact.name)

assert artifact.id == read_artifact.id
Expand Down

0 comments on commit 6f22bce

Please sign in to comment.