From c845d7ab6e8963290d7c1bd9b370c53e332c6e3d Mon Sep 17 00:00:00 2001
From: "Yngve S. Kristiansen"
Date: Mon, 10 Feb 2025 13:58:06 +0100
Subject: [PATCH] Address review

---
 src/ert/cli/workflow.py                       | 16 +++-
 .../tools/workflows/run_workflow_widget.py    | 15 +++-
 src/ert/libres_facade.py                      | 19 +++--
 src/ert/run_models/base_run_model.py          | 78 ++++++++++++++---
 src/ert/run_models/ensemble_experiment.py     | 12 ++-
 src/ert/run_models/ensemble_smoother.py       | 12 ++-
 .../run_models/multiple_data_assimilation.py  | 12 ++-
 src/ert/workflow_runner.py                    | 26 +++---
 .../unit_tests/cli/test_model_hook_order.py   | 84 +++++++++++++++++--
 .../workflow_runner/test_workflow_runner.py   | 14 ++--
 10 files changed, 221 insertions(+), 67 deletions(-)

diff --git a/src/ert/cli/workflow.py b/src/ert/cli/workflow.py
index 8321f8d95d1..f7393e06a6c 100644
--- a/src/ert/cli/workflow.py
+++ b/src/ert/cli/workflow.py
@@ -1,8 +1,10 @@
 from __future__ import annotations
 
 import logging
+from pathlib import Path
 from typing import TYPE_CHECKING
 
+from ert.run_models.base_run_model import WorkflowFixtures
 from ert.workflow_runner import WorkflowRunner
 
 if TYPE_CHECKING:
@@ -20,11 +22,19 @@ def execute_workflow(
         msg = "Workflow {} is not in the list of available workflows"
         logger.error(msg.format(workflow_name))
         return
+
     runner = WorkflowRunner(
         workflow=workflow,
-        storage=storage,
-        ert_config=ert_config,
-        config_file=ert_config.user_config_file,
+        fixtures={
+            WorkflowFixtures.storage: storage,
+            WorkflowFixtures.config_file: ert_config.user_config_file,
+            WorkflowFixtures.random_seed: ert_config.random_seed,
+            WorkflowFixtures.reports_dir: storage.path.parent
+            / "reports"
+            / Path(ert_config.user_config_file).stem,
+            WorkflowFixtures.observation_settings: ert_config.analysis_config.observation_settings,
+            WorkflowFixtures.es_settings: ert_config.analysis_config.es_module,
+        },
     )
     runner.run_blocking()
     if not all(v["completed"] for v in runner.workflowReport().values()):
diff --git a/src/ert/gui/tools/workflows/run_workflow_widget.py b/src/ert/gui/tools/workflows/run_workflow_widget.py
index ce06a91c66f..656bc9e4c12 100644
--- a/src/ert/gui/tools/workflows/run_workflow_widget.py
+++ b/src/ert/gui/tools/workflows/run_workflow_widget.py
@@ -2,6 +2,7 @@
 import time
 from collections.abc import Iterable
+from pathlib import Path
 from typing import TYPE_CHECKING
 
 from PyQt6.QtCore import QSize, Qt
@@ -20,6 +21,7 @@
 from _ert.threading import ErtThread
 from ert.gui.ertwidgets import EnsembleSelector
 from ert.gui.tools.workflows.workflow_dialog import WorkflowDialog
+from ert.run_models.base_run_model import WorkflowFixtures
 from ert.workflow_runner import WorkflowRunner
 
 if TYPE_CHECKING:
@@ -126,9 +128,16 @@ def startWorkflow(self) -> None:
         workflow = self.config.workflows[self.getCurrentWorkflowName()]
         self._workflow_runner = WorkflowRunner(
             workflow=workflow,
-            storage=self.storage,
-            ensemble=self.source_ensemble_selector.currentData(),
-            config_file=self.config.user_config_file,
+            fixtures={
+                WorkflowFixtures.ensemble: self.source_ensemble_selector.currentData(),
+                WorkflowFixtures.storage: self.storage,
+                WorkflowFixtures.random_seed: self.config.random_seed,
+                WorkflowFixtures.reports_dir: self.storage.path.parent
+                / "reports"
+                / Path(self.config.user_config_file).stem,
+                WorkflowFixtures.observation_settings: self.config.analysis_config.observation_settings,
+                WorkflowFixtures.es_settings: self.config.analysis_config.es_module,
+            },
         )
         self._workflow_runner.run()
 
diff --git a/src/ert/libres_facade.py b/src/ert/libres_facade.py
index a64e0e029e1..7287fd6ad68 100644
--- a/src/ert/libres_facade.py
+++ b/src/ert/libres_facade.py
@@ -26,6 +26,7 @@
 from ert.load_status import LoadResult, LoadStatus
 
 from .plugins import ErtPluginContext
+from .run_models.base_run_model import WorkflowFixtures
 
 _logger = logging.getLogger(__name__)
 
@@ -98,9 +99,6 @@ def get_field_parameters(self) -> list[str]:
             if isinstance(val, Field)
         ]
 
-    def get_gen_kw(self) -> list[str]:
-        return self.config.ensemble_config.get_keylist_gen_kw()
-
     def get_ensemble_size(self) -> int:
         return self.config.model_config.num_realizations
 
@@ -221,10 +219,17 @@ def run_ertscript(  # type: ignore
             [],
             argument_values=args,
             fixtures={
-                "ert_config": self.config,
-                "ensemble": ensemble,
-                "storage": storage,
-                "config_file": self.config.user_config_file,
+                WorkflowFixtures.storage: storage,
+                WorkflowFixtures.ensemble: ensemble,
+                WorkflowFixtures.reports_dir: (
+                    storage.path.parent
+                    / "reports"
+                    / Path(self.user_config_file).stem
+                    / ensemble.name
+                ),
+                WorkflowFixtures.observation_settings: self.config.analysis_config.observation_settings,
+                WorkflowFixtures.es_settings: self.config.analysis_config.es_module,
+                WorkflowFixtures.random_seed: self.config.random_seed,
             },
             **kwargs,
         )
diff --git a/src/ert/run_models/base_run_model.py b/src/ert/run_models/base_run_model.py
index 058745c217a..42136249172 100644
--- a/src/ert/run_models/base_run_model.py
+++ b/src/ert/run_models/base_run_model.py
@@ -12,6 +12,7 @@
 from collections import defaultdict
 from collections.abc import Generator, MutableSequence
 from contextlib import contextmanager
+from enum import StrEnum
 from pathlib import Path
 from queue import SimpleQueue
 from typing import TYPE_CHECKING, Any, cast
@@ -101,6 +102,16 @@ def emit(self, record: logging.LogRecord) -> None:
         self.messages.append(record.getMessage())
 
 
+class WorkflowFixtures(StrEnum):
+    ensemble = "ensemble"
+    storage = "storage"
+    config_file = "config_file"
+    random_seed = "random_seed"
+    reports_dir = "reports_dir"
+    observation_settings = "observation_settings"
+    es_settings = "es_settings"
+
+
 @contextmanager
 def captured_logs(
     messages: MutableSequence[str], level: int = logging.ERROR
@@ -181,6 +192,14 @@ def __init__(
         self.start_iteration = start_iteration
         self.restart = False
 
+    def reports_dir(self, ensemble_name: str) -> str:
+        return (
+            self._storage.path.parent
+            / "reports"
+            / Path(self._user_config_file).stem
+            / ensemble_name
+        )
+
     def log_at_startup(self) -> None:
         keys_to_drop = [
             "_end_queue",
@@ -677,16 +696,10 @@ def validate_successful_realizations_count(self) -> None:
     def run_workflows(
         self,
         runtime: HookRuntime,
-        storage: Storage | None = None,
-        ensemble: Ensemble | None = None,
+        fixtures: dict[WorkflowFixtures, Any],
     ) -> None:
         for workflow in self._hooked_workflows[runtime]:
-            WorkflowRunner(
-                workflow=workflow,
-                storage=storage,
-                ensemble=ensemble,
-                config_file=str(self._user_config_file),
-            ).run_blocking()
+            WorkflowRunner(workflow=workflow, fixtures=fixtures).run_blocking()
 
     def _evaluate_and_postprocess(
         self,
@@ -708,7 +721,17 @@ def _evaluate_and_postprocess(
             context_env=self._context_env,
         )
 
-        self.run_workflows(HookRuntime.PRE_SIMULATION, self._storage, ensemble)
+        self.run_workflows(
+            HookRuntime.PRE_SIMULATION,
+            fixtures={
+                WorkflowFixtures.storage: self._storage,
+                WorkflowFixtures.ensemble: ensemble,
+                WorkflowFixtures.reports_dir: self.reports_dir(
+                    ensemble_name=ensemble.name
+                ),
+                WorkflowFixtures.random_seed: self.random_seed,
+            },
+        )
         successful_realizations = self.run_ensemble_evaluator(
             run_args,
             ensemble,
@@ -734,7 +757,17 @@ def _evaluate_and_postprocess(
                 f"{self.ensemble_size - num_successful_realizations}"
             )
         logger.info(f"Experiment run finished in: {self.get_runtime()}s")
-        self.run_workflows(HookRuntime.POST_SIMULATION, self._storage, ensemble)
+        self.run_workflows(
+            HookRuntime.POST_SIMULATION,
+            fixtures={
+                WorkflowFixtures.storage: self._storage,
+                WorkflowFixtures.ensemble: ensemble,
+                WorkflowFixtures.reports_dir: self.reports_dir(
+                    ensemble_name=ensemble.name
+                ),
+                WorkflowFixtures.random_seed: self.random_seed,
+            },
+        )
 
         return num_successful_realizations
 
@@ -799,6 +832,16 @@ def update(
                 msg="Creating posterior ensemble..",
             )
         )
+
+        workflow_fixtures = {
+            WorkflowFixtures.storage: self._storage,
+            WorkflowFixtures.ensemble: prior,
+            WorkflowFixtures.observation_settings: self._update_settings,
+            WorkflowFixtures.es_settings: self._analysis_settings,
+            WorkflowFixtures.random_seed: self.random_seed,
+            WorkflowFixtures.reports_dir: self.reports_dir(ensemble_name=prior.name),
+        }
+
         posterior = self._storage.create_ensemble(
             prior.experiment,
             ensemble_size=prior.ensemble_size,
@@ -807,8 +850,14 @@ def update(
             prior_ensemble=prior,
         )
         if prior.iteration == 0:
-            self.run_workflows(HookRuntime.PRE_FIRST_UPDATE, self._storage, prior)
-            self.run_workflows(HookRuntime.PRE_UPDATE, self._storage, prior)
+            self.run_workflows(
+                HookRuntime.PRE_FIRST_UPDATE,
+                fixtures=workflow_fixtures,
+            )
+        self.run_workflows(
+            HookRuntime.PRE_UPDATE,
+            fixtures=workflow_fixtures,
+        )
         try:
             smoother_update(
                 prior,
@@ -830,5 +879,8 @@ def update(
                 "Update algorithm failed for iteration:"
                 f"{posterior.iteration}. The following error occurred: {e}"
             ) from e
-        self.run_workflows(HookRuntime.POST_UPDATE, self._storage, prior)
+        self.run_workflows(
+            HookRuntime.POST_UPDATE,
+            fixtures=workflow_fixtures,
+        )
         return posterior
diff --git a/src/ert/run_models/ensemble_experiment.py b/src/ert/run_models/ensemble_experiment.py
index a7aac4d0563..14f3f56a0a8 100644
--- a/src/ert/run_models/ensemble_experiment.py
+++ b/src/ert/run_models/ensemble_experiment.py
@@ -14,7 +14,7 @@
 from ert.trace import tracer
 
 from ..run_arg import create_run_arguments
-from .base_run_model import BaseRunModel, ErtRunError, StatusEvents
+from .base_run_model import BaseRunModel, ErtRunError, StatusEvents, WorkflowFixtures
 
 if TYPE_CHECKING:
     from ert.config import ErtConfig, QueueConfig
@@ -93,7 +93,10 @@ def run_experiment(
             raise ErtRunError(str(exc)) from exc
 
         if not restart:
-            self.run_workflows(HookRuntime.PRE_EXPERIMENT)
+            self.run_workflows(
+                HookRuntime.PRE_EXPERIMENT,
+                fixtures={WorkflowFixtures.random_seed: self.random_seed},
+            )
             self.experiment = self._storage.create_experiment(
                 name=self.experiment_name,
                 parameters=(
@@ -143,7 +146,10 @@ def run_experiment(
             self.ensemble,
             evaluator_server_config,
         )
-        self.run_workflows(HookRuntime.POST_EXPERIMENT)
+        self.run_workflows(
+            HookRuntime.POST_EXPERIMENT,
+            fixtures={WorkflowFixtures.random_seed: self.random_seed},
+        )
 
     @classmethod
     def name(cls) -> str:
diff --git a/src/ert/run_models/ensemble_smoother.py b/src/ert/run_models/ensemble_smoother.py
index 1a6f8cf3382..0b2c704e831 100644
--- a/src/ert/run_models/ensemble_smoother.py
+++ b/src/ert/run_models/ensemble_smoother.py
@@ -14,7 +14,7 @@
 from ert.trace import tracer
 
 from ..run_arg import create_run_arguments
-from .base_run_model import StatusEvents, UpdateRunModel
+from .base_run_model import StatusEvents, UpdateRunModel, WorkflowFixtures
 
 if TYPE_CHECKING:
     from ert.config import QueueConfig
@@ -74,7 +74,10 @@ def run_experiment(
     ) -> None:
         self.log_at_startup()
         self.restart = restart
-        self.run_workflows(HookRuntime.PRE_EXPERIMENT)
+        self.run_workflows(
+            HookRuntime.PRE_EXPERIMENT,
+            fixtures={WorkflowFixtures.random_seed: self.random_seed},
+        )
         ensemble_format = self.target_ensemble_format
         experiment = self._storage.create_experiment(
             parameters=self._parameter_configuration,
@@ -120,7 +123,10 @@ def run_experiment(
             posterior,
             evaluator_server_config,
         )
-        self.run_workflows(HookRuntime.POST_EXPERIMENT)
+        self.run_workflows(
+            HookRuntime.POST_EXPERIMENT,
+            fixtures={WorkflowFixtures.random_seed: self.random_seed},
+        )
 
     @classmethod
     def name(cls) -> str:
diff --git a/src/ert/run_models/multiple_data_assimilation.py b/src/ert/run_models/multiple_data_assimilation.py
index 8b799a44af8..f3b6fa52c9c 100644
--- a/src/ert/run_models/multiple_data_assimilation.py
+++ b/src/ert/run_models/multiple_data_assimilation.py
@@ -15,7 +15,7 @@
 from ert.trace import tracer
 
 from ..run_arg import create_run_arguments
-from .base_run_model import ErtRunError, StatusEvents, UpdateRunModel
+from .base_run_model import ErtRunError, StatusEvents, UpdateRunModel, WorkflowFixtures
 
 if TYPE_CHECKING:
     from ert.config import QueueConfig
@@ -116,7 +116,10 @@ def run_experiment(
                     f"Prior ensemble with ID: {id} does not exists"
                 ) from err
         else:
-            self.run_workflows(HookRuntime.PRE_EXPERIMENT)
+            self.run_workflows(
+                HookRuntime.PRE_EXPERIMENT,
+                fixtures={WorkflowFixtures.random_seed: self.random_seed},
+            )
             sim_args = {"weights": self._relative_weights}
             experiment = self._storage.create_experiment(
                 parameters=self._parameter_configuration,
@@ -171,7 +174,10 @@ def run_experiment(
             )
             prior = posterior
 
-        self.run_workflows(HookRuntime.POST_EXPERIMENT)
+        self.run_workflows(
+            HookRuntime.POST_EXPERIMENT,
+            fixtures={WorkflowFixtures.random_seed: self.random_seed},
+        )
 
     @staticmethod
     def parse_weights(weights: str) -> list[float]:
diff --git a/src/ert/workflow_runner.py b/src/ert/workflow_runner.py
index ddc5b7f711a..9b0f104f067 100644
--- a/src/ert/workflow_runner.py
+++ b/src/ert/workflow_runner.py
@@ -5,10 +5,15 @@
 from concurrent.futures import Future
 from typing import TYPE_CHECKING, Any, Self
 
-from ert.config import ErtConfig, ErtScript, ExternalErtScript, Workflow, WorkflowJob
+from ert.config import (
+    ErtScript,
+    ExternalErtScript,
+    Workflow,
+    WorkflowJob,
+)
 
 if TYPE_CHECKING:
-    from ert.storage import Ensemble, Storage
+    pass
 
 
 class WorkflowJobRunner:
@@ -107,16 +112,10 @@ class WorkflowRunner:
    def __init__(
        self,
        workflow: Workflow,
-        config_file: str | None = None,
-        storage: Storage | None = None,
-        ensemble: Ensemble | None = None,
-        ert_config: ErtConfig | None = None,
+        fixtures: dict[str, Any],
    ) -> None:
        self.__workflow = workflow
-        self.storage = storage
-        self.ensemble = ensemble
-        self.ert_config = ert_config  # Should eventually be removed
-        self.config_file = config_file
+        self.fixtures = fixtures
 
        self.__workflow_result: bool | None = None
        self._workflow_executor = futures.ThreadPoolExecutor(max_workers=1)
@@ -152,18 +151,13 @@ def run_blocking(self) -> None:
        # Reset status
        self.__status = {}
        self.__running = True
-        fixtures = {
-            k: getattr(self, k)
-            for k in ["storage", "ensemble", "ert_config", "config_file"]
-            if getattr(self, k)
-        }
 
        for job, args in self.__workflow:
            jobrunner = WorkflowJobRunner(job)
            self.__current_job = jobrunner
            if not self.__cancelled:
                logger.info(f"Workflow job {jobrunner.name} starting")
-                jobrunner.run(args, fixtures=fixtures)
+                jobrunner.run(args, fixtures=self.fixtures)
                self.__status[jobrunner.name] = {
                    "stdout": jobrunner.stdoutdata(),
                    "stderr": jobrunner.stderrdata(),
diff --git a/tests/ert/unit_tests/cli/test_model_hook_order.py b/tests/ert/unit_tests/cli/test_model_hook_order.py
index ea2b3872b7a..b408b2ac1e4 100644
--- a/tests/ert/unit_tests/cli/test_model_hook_order.py
+++ b/tests/ert/unit_tests/cli/test_model_hook_order.py
@@ -12,17 +12,83 @@
     ensemble_smoother,
     multiple_data_assimilation,
 )
+from ert.run_models.base_run_model import WorkflowFixtures
 
 EXPECTED_CALL_ORDER = [
-    call(HookRuntime.PRE_EXPERIMENT),
-    call(HookRuntime.PRE_SIMULATION, ANY, ANY),
-    call(HookRuntime.POST_SIMULATION, ANY, ANY),
-    call(HookRuntime.PRE_FIRST_UPDATE, ANY, ANY),
-    call(HookRuntime.PRE_UPDATE, ANY, ANY),
-    call(HookRuntime.POST_UPDATE, ANY, ANY),
-    call(HookRuntime.PRE_SIMULATION, ANY, ANY),
-    call(HookRuntime.POST_SIMULATION, ANY, ANY),
-    call(HookRuntime.POST_EXPERIMENT),
+    call(HookRuntime.PRE_EXPERIMENT, fixtures={WorkflowFixtures.random_seed: ANY}),
+    call(
+        HookRuntime.PRE_SIMULATION,
+        fixtures={
+            WorkflowFixtures.storage: ANY,
+            WorkflowFixtures.ensemble: ANY,
+            WorkflowFixtures.reports_dir: ANY,
+            WorkflowFixtures.random_seed: ANY,
+        },
+    ),
+    call(
+        HookRuntime.POST_SIMULATION,
+        fixtures={
+            WorkflowFixtures.storage: ANY,
+            WorkflowFixtures.ensemble: ANY,
+            WorkflowFixtures.reports_dir: ANY,
+            WorkflowFixtures.random_seed: ANY,
+        },
+    ),
+    call(
+        HookRuntime.PRE_FIRST_UPDATE,
+        fixtures={
+            WorkflowFixtures.storage: ANY,
+            WorkflowFixtures.ensemble: ANY,
+            WorkflowFixtures.reports_dir: ANY,
+            WorkflowFixtures.random_seed: ANY,
+            WorkflowFixtures.es_settings: ANY,
+            WorkflowFixtures.observation_settings: ANY,
+        },
+    ),
+    call(
+        HookRuntime.PRE_UPDATE,
+        fixtures={
+            WorkflowFixtures.storage: ANY,
+            WorkflowFixtures.ensemble: ANY,
+            WorkflowFixtures.reports_dir: ANY,
+            WorkflowFixtures.random_seed: ANY,
+            WorkflowFixtures.es_settings: ANY,
+            WorkflowFixtures.observation_settings: ANY,
+        },
+    ),
+    call(
+        HookRuntime.POST_UPDATE,
+        fixtures={
+            WorkflowFixtures.storage: ANY,
+            WorkflowFixtures.ensemble: ANY,
+            WorkflowFixtures.reports_dir: ANY,
+            WorkflowFixtures.random_seed: ANY,
+            WorkflowFixtures.es_settings: ANY,
+            WorkflowFixtures.observation_settings: ANY,
+        },
+    ),
+    call(
+        HookRuntime.PRE_SIMULATION,
+        fixtures={
+            WorkflowFixtures.storage: ANY,
+            WorkflowFixtures.ensemble: ANY,
+            WorkflowFixtures.reports_dir: ANY,
+            WorkflowFixtures.random_seed: ANY,
+        },
+    ),
+    call(
+        HookRuntime.POST_SIMULATION,
+        fixtures={
+            WorkflowFixtures.storage: ANY,
+            WorkflowFixtures.ensemble: ANY,
+            WorkflowFixtures.reports_dir: ANY,
+            WorkflowFixtures.random_seed: ANY,
+        },
+    ),
+    call(
+        HookRuntime.POST_EXPERIMENT,
+        fixtures={WorkflowFixtures.random_seed: ANY},
+    ),
 ]
diff --git a/tests/ert/unit_tests/workflow_runner/test_workflow_runner.py b/tests/ert/unit_tests/workflow_runner/test_workflow_runner.py
index 95e68c73936..8530d301f62 100644
--- a/tests/ert/unit_tests/workflow_runner/test_workflow_runner.py
+++ b/tests/ert/unit_tests/workflow_runner/test_workflow_runner.py
@@ -149,7 +149,7 @@ def test_workflow_run():
     job, args = workflow[1]
     assert job.name == "DUMP"
 
-    WorkflowRunner(workflow).run_blocking()
+    WorkflowRunner(workflow, fixtures={}).run_blocking()
 
     with open("dump1", encoding="utf-8") as f:
         assert f.read() == "dump_text_1"
@@ -169,7 +169,7 @@ def test_workflow_thread_cancel_ert_script():
 
     assert len(workflow) == 3
 
-    workflow_runner = WorkflowRunner(workflow)
+    workflow_runner = WorkflowRunner(workflow, fixtures={})
 
     assert not workflow_runner.isRunning()
 
@@ -208,7 +208,7 @@ def test_workflow_thread_cancel_external():
 
     assert len(workflow) == 3
 
-    workflow_runner = WorkflowRunner(workflow)
+    workflow_runner = WorkflowRunner(workflow, fixtures={})
 
     assert not workflow_runner.isRunning()
 
@@ -237,7 +237,7 @@ def test_workflow_failed_job():
     workflow = Workflow.from_file("dump_workflow", Substitutions(), {"DUMP": dump_job})
     assert len(workflow) == 2
 
-    workflow_runner = WorkflowRunner(workflow)
+    workflow_runner = WorkflowRunner(workflow, fixtures={})
     assert not workflow_runner.isRunning()
 
     with (
@@ -272,7 +272,7 @@ def test_workflow_success():
 
     assert len(workflow) == 2
 
-    workflow_runner = WorkflowRunner(workflow)
+    workflow_runner = WorkflowRunner(workflow, fixtures={})
 
     assert not workflow_runner.isRunning()
     with workflow_runner:
@@ -306,7 +306,7 @@ def test_workflow_stops_with_stopping_job():
         job_dict={"DUMP": job_failing_dump},
     )
 
-    runner = WorkflowRunner(workflow)
+    runner = WorkflowRunner(workflow, fixtures={})
     with pytest.raises(RuntimeError, match="Workflow job dump_failing_job failed"):
         runner.run_blocking()
 
@@ -322,4 +322,4 @@ def test_workflow_stops_with_stopping_job():
     )
 
     # Expect no error raised
-    WorkflowRunner(workflow).run_blocking()
+    WorkflowRunner(workflow, fixtures={}).run_blocking()
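
Note on the pattern: every call site in this patch now passes a single dict keyed by the new WorkflowFixtures enum instead of separate storage/ensemble/ert_config/config_file arguments, and WorkflowRunner simply forwards that dict to each job via jobrunner.run(args, fixtures=self.fixtures). A rough standalone sketch of the idea is below; it is not ERT code, and run_job and report_seed are hypothetical stand-ins for the runner and a workflow job. Only the WorkflowFixtures member names are taken from the patch.

    from enum import StrEnum
    from inspect import signature
    from typing import Any, Callable

    class WorkflowFixtures(StrEnum):
        # Member names copied from the enum added in base_run_model.py.
        ensemble = "ensemble"
        storage = "storage"
        config_file = "config_file"
        random_seed = "random_seed"
        reports_dir = "reports_dir"
        observation_settings = "observation_settings"
        es_settings = "es_settings"

    def run_job(job: Callable[..., Any], fixtures: dict[WorkflowFixtures, Any]) -> Any:
        # Hypothetical helper: forward only the fixtures the job's signature declares,
        # mirroring how a runner can hand a subset of the shared dict to each job.
        wanted = signature(job).parameters
        return job(**{str(k): v for k, v in fixtures.items() if str(k) in wanted})

    def report_seed(random_seed: int, reports_dir: str) -> str:
        # Toy workflow job that consumes two of the available fixtures.
        return f"seed={random_seed}, reports in {reports_dir}"

    if __name__ == "__main__":
        fixtures = {
            WorkflowFixtures.random_seed: 1234,
            WorkflowFixtures.reports_dir: "reports/my_config",
        }
        print(run_job(report_seed, fixtures))  # seed=1234, reports in reports/my_config

Keying the dict with a StrEnum rather than bare strings keeps the set of available fixtures discoverable in one place, while the keys still compare equal to the plain parameter-name strings that existing job signatures use.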