Skip to content

Commit

Permalink
🔊 Log model version information when a model is loaded from the regis…
Browse files Browse the repository at this point in the history
…try (#552) (#588)

* 🔊 Log model version information when a model is loaded from the registry (#552)

* get run  id from metadata

* simplify and log only the run_id

* add failing test

* up changelog and add extra info on test

* typo in changelog

* remove unused client property

* fix changelog typo
  • Loading branch information
Galileo-Galilei authored Oct 26, 2024
1 parent d6758ac commit 8f8f6aa
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

- :sparkles: Implement missing ``PipelineML`` filtering functionalities to let ``kedro`` display resume hints and avoid breaking ``kedro-viz`` ([#377](https://github.com/Galileo-Galilei/kedro-mlflow/pull/377), [#601, Calychas](https://github.com/Galileo-Galilei/kedro-mlflow/pull/601))
- :sparkles: Sanitize parameters name with unsupported characters to avoid ``mlflow`` errors when logging ([#595, pascalwhoop](https://github.com/Galileo-Galilei/kedro-mlflow/pull/595))
- :loud_sound: Add logs about the exact ``run_id`` loaded within a ``MlflowRegistryDataset`` because some URIs are confusing (e.g. ``latest``) and hard to debug ([#552](https://github.com/Galileo-Galilei/kedro-mlflow/pull/552))

### Changed

Expand Down
15 changes: 14 additions & 1 deletion kedro_mlflow/io/models/mlflow_model_registry_dataset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from logging import Logger, getLogger
from typing import Any, Dict, Optional, Union

from kedro.io.core import DatasetError
Expand Down Expand Up @@ -67,6 +68,10 @@ def __init__(
else f"models:/{model_name}/{stage_or_version}"
)

# Logger named after this module (via getLogger(__name__)); `_load` uses it
# to report the run_id of the model that was actually loaded.
@property
def _logger(self) -> Logger:
return getLogger(__name__)

def _load(self) -> Any:
"""Loads an MLflow model from local path or from MLflow run.
Expand All @@ -77,10 +82,18 @@ def _load(self) -> Any:
# If `run_id` is specified, pull the model from MLflow.
# TODO: enable loading from another mlflow conf (with a client with another tracking uri)
# Alternatively, use local path to load the model.
return self._mlflow_model_module.load_model(
model = self._mlflow_model_module.load_model(
model_uri=self.model_uri, **self._load_args
)

# log some info because "latest" model is not very informative
# the model itself does not have information about its registry
# because the same run can be registered under several different names
# in the registry. See https://github.com/Galileo-Galilei/kedro-mlflow/issues/552

self._logger.info(f"Loading model from run_id='{model.metadata.run_id}'")
return model

def _save(self, model: Any) -> None:
raise NotImplementedError(
"The 'save' method is not implemented for MlflowModelRegistryDataset. You can pass 'registered_model_name' argument in 'MLflowModelTrackingDataset(..., save_args={registered_model_name='my_model'}' to save and register a model in the same step. "
Expand Down
1 change: 0 additions & 1 deletion tests/framework/cli/test_cli_modelify.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,6 @@ def test_modelify_with_pip_requirements(monkeypatch, kp_for_modelify):
runs_list_before_cmd = context.mlflow.server._mlflow_client.search_runs(
context.mlflow.tracking.experiment._experiment.experiment_id
)
print(runs_list_before_cmd)
cli_runner = CliRunner()

result = cli_runner.invoke(
Expand Down
46 changes: 46 additions & 0 deletions tests/io/models/test_mlflow_model_registry_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,52 @@ def test_mlflow_model_registry_alias_and_stage_or_version_fails(tmp_path):
)


# This test is marked xfail because of long-standing pytest issues such as:
# https://github.com/pytest-dev/pytest/issues/7335
# https://github.com/pytest-dev/pytest/issues/5160
# To make logging occur, we need to `from kedro.framework.project import LOGGING` at the beginning.
# Ironically, the stderr output reported by pytest shows that logging actually occurs!
# If `with mlflow.start_run()` is removed, caplog is indeed not empty; it seems mlflow flushes the internal logger,
# probably related to https://github.com/mlflow/mlflow/issues/4957
@pytest.mark.xfail
def test_mlflow_model_registry_logs_run_id(caplog, tmp_path, monkeypatch):
# we must change the working directory because when
# using mlflow with a local database tracking, the artifacts
# are stored in a relative mlruns/ folder so we need to have
# the same working directory as the one of the tracking uri
monkeypatch.chdir(tmp_path)
tracking_and_registry_uri = r"sqlite:///" + (tmp_path / "mlruns3.db").as_posix()
mlflow.set_tracking_uri(tracking_and_registry_uri)
mlflow.set_registry_uri(tracking_and_registry_uri)

# setup: we train 2 versions of a model under a single registered model
# and record the run_id of each run under the key i + 1
# (presumably matching the registry version number -- TODO confirm)
run_ids = {}
for i in range(2):
with mlflow.start_run():
model = DecisionTreeClassifier()
mlflow.sklearn.log_model(
model, artifact_path="demo_model", registered_model_name="demo_model"
)
run_ids[i + 1] = mlflow.active_run().info.run_id

# case 1: an explicit version (1) is requested, so the run_id recorded
# for the first run is expected to appear in the captured logs

ml_ds = MlflowModelRegistryDataset(model_name="demo_model", stage_or_version=1)
ml_ds.load()

# caplog.text, caplog.messages, caplog.records are all empty (???), but the stderr will show them
assert run_ids[1] in caplog.text

# case 2: "latest" is requested, so the run_id recorded for the second
# run is expected to appear in the captured logs
# NOTE(review): this case calls the private `_load()` while case 1 calls
# the public `load()` -- confirm whether the difference is intentional
ml_ds = MlflowModelRegistryDataset(
model_name="demo_model", stage_or_version="latest"
)
ml_ds._load()

assert run_ids[2] in caplog.text


def test_mlflow_model_registry_load_given_stage_or_version(tmp_path, monkeypatch):
# we must change the working directory because when
# using mlflow with a local database tracking, the artifacts
Expand Down

0 comments on commit 8f8f6aa

Please sign in to comment.