diff --git a/metalearners/metalearner.py b/metalearners/metalearner.py index 40007d9..15ec3d2 100644 --- a/metalearners/metalearner.py +++ b/metalearners/metalearner.py @@ -93,8 +93,8 @@ def _get_result_skeleton( def _initialize_model_dict(argument, expected_names: Collection[str]) -> dict: - if isinstance(argument, dict) and set(argument.keys()) == set(expected_names): - return argument + if isinstance(argument, dict) and set(argument.keys()) >= set(expected_names): + return {key: argument[key] for key in expected_names} return {name: argument for name in expected_names}