Skip to content

Commit

Permalink
pass runner through params
Browse files Browse the repository at this point in the history
  • Loading branch information
Galileo-Galilei committed Nov 30, 2024
1 parent 25010aa commit 47a0ceb
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions kedro_mlflow/mlflow/kedro_pipeline_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from kedro.io import DataCatalog, MemoryDataset
from kedro.pipeline import Pipeline
from kedro.runner import AbstractRunner, SequentialRunner
from kedro.utils import load_obj
from kedro_datasets.pickle import PickleDataset
from mlflow.pyfunc import PythonModel

Expand Down Expand Up @@ -208,16 +209,23 @@ def predict(self, context, model_input, params=None):
# TODO runner

params = params or {}

runner_class = params.pop("runner", "SequentialRunner")

# we don't want to recreate the runner object on each predict
# because reimporting comes with a performance penalty in a serving setup
# so if it is the default we just use the existing runner
runner = (
self.runner
) # runner="build it dynamically from runner class" or self.runner
if runner_class == self.runner.__name__
else load_obj(runner_class, "kedro.runner")
)

hook_manager = _create_hook_manager()
# _register_hooks(hook_manager, predict_params.hooks)

for name, value in params.items():
# no need to check if params are ni the catalog, because mlflow already checks that the params mathc the signature
# no need to check if params are in the catalog, because mlflow already checks that the params matching the signature
param = f"params:{name}"
self._logger.info(f"Using {param}={value} for the prediction")
self.loaded_catalog.save(name=param, data=value)
Expand Down

0 comments on commit 47a0ceb

Please sign in to comment.