Skip to content

Commit

Permalink
fix change of default signature in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Galileo-Galilei committed Nov 30, 2024
1 parent 47a0ceb commit 678ad94
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
2 changes: 1 addition & 1 deletion kedro_mlflow/mlflow/kedro_pipeline_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def predict(self, context, model_input, params=None):
# so if it is the default we just use the existing runner
runner = (
self.runner
if runner_class == self.runner.__name__
if runner_class == type(self.runner).__name__
else load_obj(runner_class, "kedro.runner")
)

Expand Down
11 changes: 9 additions & 2 deletions tests/framework/hooks/test_hook_pipeline_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,14 @@ def convert_probs_to_pred(data, threshold):
@pytest.fixture
def dummy_signature(dummy_catalog, dummy_pipeline_ml):
input_data = dummy_catalog.load(dummy_pipeline_ml.input_name)
dummy_signature = infer_signature(input_data)
params_dict = {
key: dummy_catalog.load(key)
for key in dummy_pipeline_ml.inference.inputs()
if key.startswith("params:")
}
dummy_signature = infer_signature(
model_input=input_data, params={**params_dict, "runner": "SequentialRunner"}
)
return dummy_signature


Expand Down Expand Up @@ -303,7 +310,7 @@ def test_mlflow_hook_save_pipeline_ml(
assert trained_model.metadata.signature.to_dict() == {
"inputs": '[{"type": "long", "name": "a", "required": true}]',
"outputs": None,
"params": None,
"params": '[{"name": "runner", "type": "string", "default": "SequentialRunner", "shape": null}]',
}


Expand Down

0 comments on commit 678ad94

Please sign in to comment.