Skip to content

Commit

Permalink
Add new hook for specifying config per forward model step
Browse files Browse the repository at this point in the history
  • Loading branch information
berland committed Oct 31, 2024
1 parent c4e1365 commit c550e32
Show file tree
Hide file tree
Showing 6 changed files with 203 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/ert/plugins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def inner(*args: P.args, **kwargs: P.kwargs) -> Any:
"installable_workflow_jobs",
"help_links",
"installable_forward_model_steps",
"forward_model_configuration",
"ecl100_config_path",
"ecl300_config_path",
"flow_config_path",
Expand Down
2 changes: 2 additions & 0 deletions src/ert/plugins/hook_specifications/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
flow_config_path,
)
from .forward_model_steps import (
forward_model_configuration,
installable_forward_model_steps,
)
from .help_resources import help_links
Expand All @@ -21,6 +22,7 @@
"ecl100_config_path",
"ecl300_config_path",
"flow_config_path",
"forward_model_configuration",
"help_links",
"installable_forward_model_steps",
"installable_jobs",
Expand Down
8 changes: 8 additions & 0 deletions src/ert/plugins/hook_specifications/forward_model_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,11 @@ def installable_forward_model_steps() -> (
:return: List of forward model step plugins in the form of subclasses of the
ForwardModelStepPlugin class
"""


@no_type_check
@hook_specification
def forward_model_configuration() -> PluginResponse[List[Type[ForwardModelStepPlugin]]]:
"""
:return: List of configurations to be merged to be provided to forward model steps.
"""
45 changes: 45 additions & 0 deletions src/ert/plugins/plugin_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import collections
import logging
import os
import shutil
Expand Down Expand Up @@ -139,6 +140,50 @@ def get_flow_config_path(self) -> Optional[str]:
hook=self.hook.flow_config_path, config_name="flow"
)

def get_forward_model_configuration(self) -> Dict[str, Dict[str, Any]]:
response: List[PluginResponse[Dict[str, str]]] = (
self.hook.forward_model_configuration()
)
if response == []:
return {}

fm_configs: Dict[str, Dict[str, Any]] = collections.defaultdict(dict)
for res in response:
if not isinstance(res.data, dict):
raise TypeError(
f"{res.plugin_metadata.plugin_name} did not provide "
"dict[str, dict]"
)

for fmstep_name, fmstep_config in res.data.items():
if not isinstance(fmstep_name, str):
raise TypeError(
f"{res.plugin_metadata.plugin_name} did not "
"provide dict[str, dict[str, Any]]"
)
if not isinstance(fmstep_config, dict):
raise TypeError(
f"{res.plugin_metadata.plugin_name} did not "
"provide dict[str, dict[str, Any]]"
)
for key, value in fmstep_config.items():
if not isinstance(key, str):
raise TypeError(
f"{res.plugin_metadata.plugin_name} did not "
"provide dict[str, dict[str, Any]]"
)
if key.lower() in [
existing.lower() for existing in fm_configs[fmstep_name]
]:
raise RuntimeError(
"Duplicate configuration or fm_step "
f"{fmstep_name} for key {key} when parsing plugin "
f"{res.plugin_metadata.plugin_name}, it is already "
"registered by another plugin."
)
fm_configs[fmstep_name][key] = value
return fm_configs

def _site_config_lines(self) -> List[str]:
try:
plugin_responses = self.hook.site_config_lines()
Expand Down
5 changes: 5 additions & 0 deletions tests/ert/unit_tests/plugins/dummy_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ def help_links():
return {"test": "test", "test2": "test"}


@plugin(name="dummy")
def forward_model_configuration():
return {"FLOW": {"mpipath": "/foo"}}


@plugin(name="dummy")
def ecl100_config_path():
return "/dummy/path/ecl100_config.yml"
Expand Down
142 changes: 142 additions & 0 deletions tests/ert/unit_tests/plugins/test_plugin_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
import tempfile
from unittest.mock import Mock

import pytest

import ert.plugins.hook_implementations
from ert import plugin
from ert.plugins import ErtPluginManager
from tests.ert.unit_tests.plugins import dummy_plugins
from tests.ert.unit_tests.plugins.dummy_plugins import (
Expand All @@ -16,6 +19,7 @@ def test_no_plugins():
assert pm.get_flow_config_path() is None
assert pm.get_ecl100_config_path() is None
assert pm.get_ecl300_config_path() is None
assert pm.get_forward_model_configuration() == {}

assert len(pm.forward_model_steps) > 0
assert len(pm._get_config_workflow_jobs()) > 0
Expand All @@ -38,6 +42,7 @@ def test_with_plugins():
assert pm.get_flow_config_path() == "/dummy/path/flow_config.yml"
assert pm.get_ecl100_config_path() == "/dummy/path/ecl100_config.yml"
assert pm.get_ecl300_config_path() == "/dummy/path/ecl300_config.yml"
assert pm.get_forward_model_configuration() == {"FLOW": {"mpipath": "/foo"}}

assert pm.get_installable_jobs()["job1"] == "/dummy/path/job1"
assert pm.get_installable_jobs()["job2"] == "/dummy/path/job2"
Expand All @@ -55,6 +60,143 @@ def test_with_plugins():
]


def test_fm_config_with_empty_config():
class SomePlugin:
@plugin(name="foo")
def forward_model_configuration():
return {}

assert (
ErtPluginManager(plugins=[SomePlugin]).get_forward_model_configuration() == {}
)


def test_fm_config_with_empty_config_for_step():
class SomePlugin:
@plugin(name="foo")
def forward_model_configuration():
return {"foo": {}}

assert (
ErtPluginManager(plugins=[SomePlugin]).get_forward_model_configuration() == {}
)


def test_fm_config_merges_data_for_step():
class SomePlugin:
@plugin(name="foo")
def forward_model_configuration():
return {"foo": {"com": 3}}

class OtherPlugin:
@plugin(name="bar")
def forward_model_configuration():
return {"foo": {"bar": 2}}

assert ErtPluginManager(
plugins=[SomePlugin, OtherPlugin]
).get_forward_model_configuration() == {"foo": {"com": 3, "bar": 2}}


def test_fm_config_multiple_steps():
class SomePlugin:
@plugin(name="foo")
def forward_model_configuration():
return {"foo100": {"com": 3}}

class OtherPlugin:
@plugin(name="bar")
def forward_model_configuration():
return {"foo200": {"bar": 2}}

assert ErtPluginManager(
plugins=[SomePlugin, OtherPlugin]
).get_forward_model_configuration() == {"foo100": {"com": 3}, "foo200": {"bar": 2}}


def test_fm_config_conflicting_config():
class SomePlugin:
@plugin(name="foo")
def forward_model_configuration():
return {"foo100": {"com": "from_someplugin"}}

class OtherPlugin:
@plugin(name="foo")
def forward_model_configuration():
return {"foo100": {"com": "from_otherplugin"}}

with pytest.raises(RuntimeError, match="Duplicate configuration"):
ErtPluginManager(
plugins=[SomePlugin, OtherPlugin]
).get_forward_model_configuration()


def test_fm_config_with_repeated_keys_different_fm_step():
class SomePlugin:
@plugin(name="foo")
def forward_model_configuration():
return {"foo1": {"bar": "1"}}

class OtherPlugin:
@plugin(name="foo2")
def forward_model_configuration():
return {"foo2": {"bar": "2"}}

assert ErtPluginManager(
plugins=[SomePlugin, OtherPlugin]
).get_forward_model_configuration() == {"foo1": {"bar": "1"}, "foo2": {"bar": "2"}}


def test_fm_config_with_repeated_keys_with_different_case():
class SomePlugin:
@plugin(name="foo")
def forward_model_configuration():
return {"foo": {"bar": "lower", "BAR": "higher"}}

with pytest.raises(RuntimeError, match="Duplicate configuration"):
ErtPluginManager(plugins=[SomePlugin]).get_forward_model_configuration()


def test_fm_config_with_wrong_type():
class SomePlugin:
@plugin(name="foo")
def forward_model_configuration():
return 1

with pytest.raises(TypeError, match="foo did not provide dict"):
ErtPluginManager(plugins=[SomePlugin]).get_forward_model_configuration()


def test_fm_config_with_wrong_steptype():
class SomePlugin:
@plugin(name="foo")
def forward_model_configuration():
return {1: {"bar": "1"}}

with pytest.raises(TypeError, match="foo did not provide dict"):
ErtPluginManager(plugins=[SomePlugin]).get_forward_model_configuration()


def test_fm_config_with_wrong_subtype():
class SomePlugin:
@plugin(name="foo")
def forward_model_configuration():
return {"foo100": 1}

with pytest.raises(TypeError, match="foo did not provide dict"):
ErtPluginManager(plugins=[SomePlugin]).get_forward_model_configuration()


def test_fm_config_with_wrong_keytype():
class SomePlugin:
@plugin(name="foo")
def forward_model_configuration():
return {"foo100": {1: "bar"}}

with pytest.raises(TypeError, match="foo did not provide dict"):
ErtPluginManager(plugins=[SomePlugin]).get_forward_model_configuration()


def test_job_documentation():
pm = ErtPluginManager(plugins=[dummy_plugins])
expected = {
Expand Down

0 comments on commit c550e32

Please sign in to comment.