Skip to content

Commit

Permalink
fixup! Pass more workflow fixtures explicitly
Browse files Browse the repository at this point in the history
  • Loading branch information
yngve-sk committed Feb 11, 2025
1 parent 6e47da3 commit 0aeee68
Show file tree
Hide file tree
Showing 15 changed files with 148 additions and 163 deletions.
17 changes: 8 additions & 9 deletions src/ert/cli/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from pathlib import Path
from typing import TYPE_CHECKING

from ert.config.workflow_fixtures import WorkflowFixtures
from ert.runpaths import Runpaths
from ert.workflow_runner import WorkflowRunner

Expand All @@ -27,14 +26,14 @@ def execute_workflow(
runner = WorkflowRunner(
workflow=workflow,
fixtures={
WorkflowFixtures.storage: storage,
WorkflowFixtures.random_seed: ert_config.random_seed,
WorkflowFixtures.reports_dir: storage.path.parent
/ "reports"
/ Path(ert_config.user_config_file).stem,
WorkflowFixtures.observation_settings: ert_config.analysis_config.observation_settings,
WorkflowFixtures.es_settings: ert_config.analysis_config.es_module,
WorkflowFixtures.run_paths: Runpaths(
"storage": storage,
"random_seed": ert_config.random_seed,
"reports_dir": str(
storage.path.parent / "reports" / Path(ert_config.user_config_file).stem
),
"observation_settings": ert_config.analysis_config.observation_settings,
"es_settings": ert_config.analysis_config.es_module,
"run_paths": Runpaths(
jobname_format=ert_config.model_config.jobname_format_string,
runpath_format=ert_config.model_config.runpath_format_string,
filename=str(ert_config.runpath_file),
Expand Down
9 changes: 5 additions & 4 deletions src/ert/config/ert_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def initializeAndRun(
self,
argument_types: list[type[Any]],
argument_values: list[str],
fixtures: dict[WorkflowFixtures, Any] | None = None,
fixtures: WorkflowFixtures | None = None,
**kwargs: dict[str, Any],
) -> Any:
fixtures = {} if fixtures is None else fixtures
Expand All @@ -96,7 +96,7 @@ def initializeAndRun(
else:
workflow_args.append(None)

fixtures[WorkflowFixtures.workflow_args] = workflow_args
fixtures["workflow_args"] = workflow_args

fixture_args = []
all_func_args = inspect.signature(self.run).parameters
Expand Down Expand Up @@ -173,14 +173,15 @@ def initializeAndRun(
def insert_fixtures(
self,
func_args: dict[str, inspect.Parameter],
fixtures: dict[WorkflowFixtures, Any],
fixtures: WorkflowFixtures,
kwargs: dict[str, Any],
) -> list[Any]:
arguments = []
errors = []
for val in func_args:
if val in fixtures:
arguments.append(fixtures[WorkflowFixtures(val)])
assert val in fixtures
arguments.append(fixtures.get(val))
elif val not in kwargs:
errors.append(val)
if errors:
Expand Down
31 changes: 19 additions & 12 deletions src/ert/config/workflow_fixtures.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
from __future__ import annotations

from enum import StrEnum


class WorkflowFixtures(StrEnum):
ensemble = "ensemble"
storage = "storage"
random_seed = "random_seed"
reports_dir = "reports_dir"
observation_settings = "observation_settings"
es_settings = "es_settings"
run_paths = "run_paths"
workflow_args = "workflow_args"
from typing import TYPE_CHECKING, Any

from typing_extensions import TypedDict

if TYPE_CHECKING:
from ert.config import ESSettings, UpdateSettings
from ert.runpaths import Runpaths
from ert.storage import Ensemble, Storage


class WorkflowFixtures(TypedDict, total=False):
ensemble: Ensemble
storage: Storage
random_seed: int | None
reports_dir: str
observation_settings: UpdateSettings
es_settings: ESSettings
run_paths: Runpaths
workflow_args: list[Any]
3 changes: 1 addition & 2 deletions src/ert/gui/tools/export/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
if TYPE_CHECKING:
from ert.config import ErtConfig, WorkflowJob

from ert.config.workflow_fixtures import WorkflowFixtures
from ert.gui.ertnotifier import ErtNotifier
from ert.workflow_runner import WorkflowJobRunner

Expand All @@ -30,7 +29,7 @@ def run_export(self, parameters: list[Any]) -> None:

export_job_runner = WorkflowJobRunner(self.export_job)
user_warn = export_job_runner.run(
fixtures={WorkflowFixtures.storage: self._notifier.storage},
fixtures={"storage": self._notifier.storage},
arguments=parameters,
)
if export_job_runner.hasFailed():
Expand Down
2 changes: 1 addition & 1 deletion src/ert/gui/tools/plugins/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def getName(self) -> str:
def getDescription(self) -> str:
return self.__description

def getArguments(self, fixtures: dict[WorkflowFixtures, Any]) -> list[Any]:
def getArguments(self, fixtures: WorkflowFixtures) -> list[Any]:
"""
Returns a list of arguments. Either from GUI or from arbitrary code.
If the user for example cancels in the GUI a CancelPluginException is raised.
Expand Down
20 changes: 11 additions & 9 deletions src/ert/gui/tools/plugins/plugin_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,16 @@ def run(self) -> None:
plugin = self.__plugin
arguments = plugin.getArguments(
fixtures={
WorkflowFixtures.storage: self.storage,
WorkflowFixtures.random_seed: ert_config.random_seed,
WorkflowFixtures.reports_dir: self.storage.path.parent
/ "reports"
/ Path(ert_config.user_config_file).stem,
WorkflowFixtures.observation_settings: ert_config.analysis_config.observation_settings,
WorkflowFixtures.es_settings: ert_config.analysis_config.es_module,
WorkflowFixtures.run_paths: Runpaths(
"storage": self.storage,
"random_seed": ert_config.random_seed,
"reports_dir": str(
self.storage.path.parent
/ "reports"
/ Path(ert_config.user_config_file).stem
),
"observation_settings": ert_config.analysis_config.observation_settings,
"es_settings": ert_config.analysis_config.es_module,
"run_paths": Runpaths(
jobname_format=ert_config.model_config.jobname_format_string,
runpath_format=ert_config.model_config.runpath_format_string,
filename=str(ert_config.runpath_file),
Expand Down Expand Up @@ -88,7 +90,7 @@ def run(self) -> None:
print("Plugin cancelled before execution!")

def __runWorkflowJob(
self, arguments: list[Any] | None, fixtures: dict[WorkflowFixtures, Any]
self, arguments: list[Any] | None, fixtures: WorkflowFixtures
) -> None:
self.__result = self._runner.run(arguments, fixtures=fixtures)

Expand Down
21 changes: 11 additions & 10 deletions src/ert/gui/tools/workflows/run_workflow_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
)

from _ert.threading import ErtThread
from ert.config.workflow_fixtures import WorkflowFixtures
from ert.gui.ertwidgets import EnsembleSelector
from ert.gui.tools.workflows.workflow_dialog import WorkflowDialog
from ert.runpaths import Runpaths
Expand Down Expand Up @@ -130,15 +129,17 @@ def startWorkflow(self) -> None:
self._workflow_runner = WorkflowRunner(
workflow=workflow,
fixtures={
WorkflowFixtures.ensemble: self.source_ensemble_selector.currentData(),
WorkflowFixtures.storage: self.storage,
WorkflowFixtures.random_seed: self.config.random_seed,
WorkflowFixtures.reports_dir: self.storage.path.parent
/ "reports"
/ Path(self.config.user_config_file).stem,
WorkflowFixtures.observation_settings: self.config.analysis_config.observation_settings,
WorkflowFixtures.es_settings: self.config.analysis_config.es_module,
WorkflowFixtures.run_paths: Runpaths(
"ensemble": self.source_ensemble_selector.currentData(),
"storage": self.storage,
"random_seed": self.config.random_seed,
"reports_dir": str(
self.storage.path.parent
/ "reports"
/ Path(self.config.user_config_file).stem
),
"observation_settings": self.config.analysis_config.observation_settings,
"es_settings": self.config.analysis_config.es_module,
"run_paths": Runpaths(
jobname_format=self.config.model_config.jobname_format_string,
runpath_format=self.config.model_config.runpath_format_string,
filename=str(self.config.runpath_file),
Expand Down
13 changes: 6 additions & 7 deletions src/ert/libres_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from ert.data._measured_data import ObservationError, ResponseError
from ert.load_status import LoadResult, LoadStatus

from .config.workflow_fixtures import WorkflowFixtures
from .plugins import ErtPluginContext

_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -219,17 +218,17 @@ def run_ertscript( # type: ignore
[],
argument_values=args,
fixtures={
WorkflowFixtures.storage: storage,
WorkflowFixtures.ensemble: ensemble,
WorkflowFixtures.reports_dir: (
"storage": storage,
"ensemble": ensemble,
"reports_dir": (
storage.path.parent
/ "reports"
/ Path(str(self.user_config_file)).stem
/ ensemble.name
),
WorkflowFixtures.observation_settings: self.config.analysis_config.observation_settings,
WorkflowFixtures.es_settings: self.config.analysis_config.es_module,
WorkflowFixtures.random_seed: self.config.random_seed,
"observation_settings": self.config.analysis_config.observation_settings,
"es_settings": self.config.analysis_config.es_module,
"random_seed": self.config.random_seed,
},
**kwargs,
)
Expand Down
42 changes: 19 additions & 23 deletions src/ert/run_models/base_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,7 +685,7 @@ def validate_successful_realizations_count(self) -> None:
def run_workflows(
self,
runtime: HookRuntime,
fixtures: dict[WorkflowFixtures, Any],
fixtures: WorkflowFixtures,
) -> None:
for workflow in self._hooked_workflows[runtime]:
WorkflowRunner(workflow=workflow, fixtures=fixtures).run_blocking()
Expand Down Expand Up @@ -713,13 +713,11 @@ def _evaluate_and_postprocess(
self.run_workflows(
HookRuntime.PRE_SIMULATION,
fixtures={
WorkflowFixtures.storage: self._storage,
WorkflowFixtures.ensemble: ensemble,
WorkflowFixtures.reports_dir: self.reports_dir(
ensemble_name=ensemble.name
),
WorkflowFixtures.random_seed: self.random_seed,
WorkflowFixtures.run_paths: self.run_paths,
"storage": self._storage,
"ensemble": ensemble,
"reports_dir": self.reports_dir(ensemble_name=ensemble.name),
"random_seed": self.random_seed,
"run_paths": self.run_paths,
},
)
successful_realizations = self.run_ensemble_evaluator(
Expand Down Expand Up @@ -750,13 +748,11 @@ def _evaluate_and_postprocess(
self.run_workflows(
HookRuntime.POST_SIMULATION,
fixtures={
WorkflowFixtures.storage: self._storage,
WorkflowFixtures.ensemble: ensemble,
WorkflowFixtures.reports_dir: self.reports_dir(
ensemble_name=ensemble.name
),
WorkflowFixtures.random_seed: self.random_seed,
WorkflowFixtures.run_paths: self.run_paths,
"storage": self._storage,
"ensemble": ensemble,
"reports_dir": self.reports_dir(ensemble_name=ensemble.name),
"random_seed": self.random_seed,
"run_paths": self.run_paths,
},
)

Expand Down Expand Up @@ -824,14 +820,14 @@ def update(
)
)

workflow_fixtures = {
WorkflowFixtures.storage: self._storage,
WorkflowFixtures.ensemble: prior,
WorkflowFixtures.observation_settings: self._update_settings,
WorkflowFixtures.es_settings: self._analysis_settings,
WorkflowFixtures.random_seed: self.random_seed,
WorkflowFixtures.reports_dir: self.reports_dir(ensemble_name=prior.name),
WorkflowFixtures.run_paths: self.run_paths,
workflow_fixtures: WorkflowFixtures = {
"storage": self._storage,
"ensemble": prior,
"observation_settings": self._update_settings,
"es_settings": self._analysis_settings,
"random_seed": self.random_seed,
"reports_dir": self.reports_dir(ensemble_name=prior.name),
"run_paths": self.run_paths,
}

posterior = self._storage.create_ensemble(
Expand Down
9 changes: 4 additions & 5 deletions src/ert/run_models/ensemble_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from ert.storage import Ensemble, Experiment, Storage
from ert.trace import tracer

from ..config.workflow_fixtures import WorkflowFixtures
from ..run_arg import create_run_arguments
from .base_run_model import BaseRunModel, ErtRunError, StatusEvents

Expand Down Expand Up @@ -96,7 +95,7 @@ def run_experiment(
if not restart:
self.run_workflows(
HookRuntime.PRE_EXPERIMENT,
fixtures={WorkflowFixtures.random_seed: self.random_seed},
fixtures={"random_seed": self.random_seed},
)
self.experiment = self._storage.create_experiment(
name=self.experiment_name,
Expand Down Expand Up @@ -150,9 +149,9 @@ def run_experiment(
self.run_workflows(
HookRuntime.POST_EXPERIMENT,
fixtures={
WorkflowFixtures.random_seed: self.random_seed,
WorkflowFixtures.storage: self._storage,
WorkflowFixtures.ensemble: self.ensemble,
"random_seed": self.random_seed,
"storage": self._storage,
"ensemble": self.ensemble,
},
)

Expand Down
9 changes: 4 additions & 5 deletions src/ert/run_models/ensemble_smoother.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from ert.storage import Storage
from ert.trace import tracer

from ..config.workflow_fixtures import WorkflowFixtures
from ..run_arg import create_run_arguments
from .base_run_model import StatusEvents, UpdateRunModel

Expand Down Expand Up @@ -77,7 +76,7 @@ def run_experiment(
self.restart = restart
self.run_workflows(
HookRuntime.PRE_EXPERIMENT,
fixtures={WorkflowFixtures.random_seed: self.random_seed},
fixtures={"random_seed": self.random_seed},
)
ensemble_format = self.target_ensemble_format
experiment = self._storage.create_experiment(
Expand Down Expand Up @@ -127,9 +126,9 @@ def run_experiment(
self.run_workflows(
HookRuntime.POST_EXPERIMENT,
fixtures={
WorkflowFixtures.random_seed: self.random_seed,
WorkflowFixtures.storage: self._storage,
WorkflowFixtures.ensemble: posterior,
"random_seed": self.random_seed,
"storage": self._storage,
"ensemble": posterior,
},
)

Expand Down
9 changes: 4 additions & 5 deletions src/ert/run_models/multiple_data_assimilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from ert.storage import Ensemble, Storage
from ert.trace import tracer

from ..config.workflow_fixtures import WorkflowFixtures
from ..run_arg import create_run_arguments
from .base_run_model import ErtRunError, StatusEvents, UpdateRunModel

Expand Down Expand Up @@ -119,7 +118,7 @@ def run_experiment(
else:
self.run_workflows(
HookRuntime.PRE_EXPERIMENT,
fixtures={WorkflowFixtures.random_seed: self.random_seed},
fixtures={"random_seed": self.random_seed},
)
sim_args = {"weights": self._relative_weights}
experiment = self._storage.create_experiment(
Expand Down Expand Up @@ -178,9 +177,9 @@ def run_experiment(
self.run_workflows(
HookRuntime.POST_EXPERIMENT,
fixtures={
WorkflowFixtures.random_seed: self.random_seed,
WorkflowFixtures.storage: self._storage,
WorkflowFixtures.ensemble: prior,
"random_seed": self.random_seed,
"storage": self._storage,
"ensemble": prior,
},
)

Expand Down
Loading

0 comments on commit 0aeee68

Please sign in to comment.