diff --git a/kedro_mlflow/mlflow/kedro_pipeline_model.py b/kedro_mlflow/mlflow/kedro_pipeline_model.py index fab3e47e..23e78bff 100644 --- a/kedro_mlflow/mlflow/kedro_pipeline_model.py +++ b/kedro_mlflow/mlflow/kedro_pipeline_model.py @@ -217,7 +217,7 @@ def predict(self, context, model_input, params=None): # so if it is the default we just use the existing runner runner = ( self.runner - if runner_class == self.runner.__name__ + if runner_class == type(self.runner).__name__ else load_obj(runner_class, "kedro.runner") ) diff --git a/tests/framework/hooks/test_hook_pipeline_ml.py b/tests/framework/hooks/test_hook_pipeline_ml.py index 82c8d0b4..7731a20a 100644 --- a/tests/framework/hooks/test_hook_pipeline_ml.py +++ b/tests/framework/hooks/test_hook_pipeline_ml.py @@ -159,7 +159,14 @@ def convert_probs_to_pred(data, threshold): @pytest.fixture def dummy_signature(dummy_catalog, dummy_pipeline_ml): input_data = dummy_catalog.load(dummy_pipeline_ml.input_name) - dummy_signature = infer_signature(input_data) + params_dict = { + key: dummy_catalog.load(key) + for key in dummy_pipeline_ml.inference.inputs() + if key.startswith("params:") + } + dummy_signature = infer_signature( + model_input=input_data, params={**params_dict, "runner": "SequentialRunner"} + ) return dummy_signature @@ -303,7 +310,7 @@ def test_mlflow_hook_save_pipeline_ml( assert trained_model.metadata.signature.to_dict() == { "inputs": '[{"type": "long", "name": "a", "required": true}]', "outputs": None, - "params": None, + "params": '[{"name": "runner", "type": "string", "default": "SequentialRunner", "shape": null}]', }