Skip to content

Commit

Permalink
Address review
Browse files Browse the repository at this point in the history
  • Loading branch information
yngve-sk committed Feb 10, 2025
1 parent ab916da commit c845d7a
Show file tree
Hide file tree
Showing 10 changed files with 221 additions and 67 deletions.
16 changes: 13 additions & 3 deletions src/ert/cli/workflow.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

import logging
from pathlib import Path
from typing import TYPE_CHECKING

from ert.run_models.base_run_model import WorkflowFixtures
from ert.workflow_runner import WorkflowRunner

if TYPE_CHECKING:
Expand All @@ -20,11 +22,19 @@ def execute_workflow(
msg = "Workflow {} is not in the list of available workflows"
logger.error(msg.format(workflow_name))
return

runner = WorkflowRunner(
workflow=workflow,
storage=storage,
ert_config=ert_config,
config_file=ert_config.user_config_file,
fixtures={
WorkflowFixtures.storage: storage,
WorkflowFixtures.config_file: ert_config.user_config_file,
WorkflowFixtures.random_seed: ert_config.random_seed,
WorkflowFixtures.reports_dir: storage.path.parent
/ "reports"
/ Path(ert_config.user_config_file).stem,
WorkflowFixtures.observation_settings: ert_config.analysis_config.observation_settings,
WorkflowFixtures.es_settings: ert_config.analysis_config.es_module,
},
)
runner.run_blocking()
if not all(v["completed"] for v in runner.workflowReport().values()):
Expand Down
15 changes: 12 additions & 3 deletions src/ert/gui/tools/workflows/run_workflow_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import time
from collections.abc import Iterable
from pathlib import Path
from typing import TYPE_CHECKING

from PyQt6.QtCore import QSize, Qt
Expand All @@ -20,6 +21,7 @@
from _ert.threading import ErtThread
from ert.gui.ertwidgets import EnsembleSelector
from ert.gui.tools.workflows.workflow_dialog import WorkflowDialog
from ert.run_models.base_run_model import WorkflowFixtures
from ert.workflow_runner import WorkflowRunner

if TYPE_CHECKING:
Expand Down Expand Up @@ -126,9 +128,16 @@ def startWorkflow(self) -> None:
workflow = self.config.workflows[self.getCurrentWorkflowName()]
self._workflow_runner = WorkflowRunner(
workflow=workflow,
storage=self.storage,
ensemble=self.source_ensemble_selector.currentData(),
config_file=self.config.user_config_file,
fixtures={
WorkflowFixtures.ensemble: self.source_ensemble_selector.currentData(),
WorkflowFixtures.storage: self.storage,
WorkflowFixtures.random_seed: self.config.random_seed,
WorkflowFixtures.reports_dir: self.storage.path.parent
/ "reports"
/ Path(self.config.user_config_file).stem,
WorkflowFixtures.observation_settings: self.config.analysis_config.observation_settings,
WorkflowFixtures.es_settings: self.config.analysis_config.es_module,
},
)
self._workflow_runner.run()

Expand Down
19 changes: 12 additions & 7 deletions src/ert/libres_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from ert.load_status import LoadResult, LoadStatus

from .plugins import ErtPluginContext
from .run_models.base_run_model import WorkflowFixtures

_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -98,9 +99,6 @@ def get_field_parameters(self) -> list[str]:
if isinstance(val, Field)
]

def get_gen_kw(self) -> list[str]:
return self.config.ensemble_config.get_keylist_gen_kw()

def get_ensemble_size(self) -> int:
return self.config.model_config.num_realizations

Expand Down Expand Up @@ -221,10 +219,17 @@ def run_ertscript( # type: ignore
[],
argument_values=args,
fixtures={
"ert_config": self.config,
"ensemble": ensemble,
"storage": storage,
"config_file": self.config.user_config_file,
WorkflowFixtures.storage: storage,
WorkflowFixtures.ensemble: ensemble,
WorkflowFixtures.reports_dir: (
storage.path.parent
/ "reports"
/ Path(self.user_config_file).stem

Check failure on line 227 in src/ert/libres_facade.py

View workflow job for this annotation

GitHub Actions / type-checking (3.12)

Argument 1 to "Path" has incompatible type "str | None"; expected "str | PathLike[str]"
/ ensemble.name
),
WorkflowFixtures.observation_settings: self.config.analysis_config.observation_settings,
WorkflowFixtures.es_settings: self.config.analysis_config.es_module,
WorkflowFixtures.random_seed: self.config.random_seed,
},
**kwargs,
)
Expand Down
78 changes: 65 additions & 13 deletions src/ert/run_models/base_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from collections import defaultdict
from collections.abc import Generator, MutableSequence
from contextlib import contextmanager
from enum import StrEnum
from pathlib import Path
from queue import SimpleQueue
from typing import TYPE_CHECKING, Any, cast
Expand Down Expand Up @@ -101,6 +102,16 @@ def emit(self, record: logging.LogRecord) -> None:
self.messages.append(record.getMessage())


class WorkflowFixtures(StrEnum):
ensemble = "ensemble"
storage = "storage"
config_file = "config_file"
random_seed = "random_seed"
reports_dir = "reports_dir"
observation_settings = "observation_settings"
es_settings = "es_settings"


@contextmanager
def captured_logs(
messages: MutableSequence[str], level: int = logging.ERROR
Expand Down Expand Up @@ -181,6 +192,14 @@ def __init__(
self.start_iteration = start_iteration
self.restart = False

def reports_dir(self, ensemble_name: str) -> str:
return (
self._storage.path.parent

Check failure on line 197 in src/ert/run_models/base_run_model.py

View workflow job for this annotation

GitHub Actions / type-checking (3.12)

Incompatible return value type (got "Path", expected "str")
/ "reports"
/ Path(self._user_config_file).stem
/ ensemble_name
)

def log_at_startup(self) -> None:
keys_to_drop = [
"_end_queue",
Expand Down Expand Up @@ -677,16 +696,10 @@ def validate_successful_realizations_count(self) -> None:
def run_workflows(
self,
runtime: HookRuntime,
storage: Storage | None = None,
ensemble: Ensemble | None = None,
fixtures: dict[WorkflowFixtures, Any],
) -> None:
for workflow in self._hooked_workflows[runtime]:
WorkflowRunner(
workflow=workflow,
storage=storage,
ensemble=ensemble,
config_file=str(self._user_config_file),
).run_blocking()
WorkflowRunner(workflow=workflow, fixtures=fixtures).run_blocking()

Check failure on line 702 in src/ert/run_models/base_run_model.py

View workflow job for this annotation

GitHub Actions / type-checking (3.12)

Argument "fixtures" to "WorkflowRunner" has incompatible type "dict[WorkflowFixtures, Any]"; expected "dict[str, Any]"

def _evaluate_and_postprocess(
self,
Expand All @@ -708,7 +721,17 @@ def _evaluate_and_postprocess(
context_env=self._context_env,
)

self.run_workflows(HookRuntime.PRE_SIMULATION, self._storage, ensemble)
self.run_workflows(
HookRuntime.PRE_SIMULATION,
fixtures={
WorkflowFixtures.storage: self._storage,
WorkflowFixtures.ensemble: ensemble,
WorkflowFixtures.reports_dir: self.reports_dir(
ensemble_name=ensemble.name
),
WorkflowFixtures.random_seed: self.random_seed,
},
)
successful_realizations = self.run_ensemble_evaluator(
run_args,
ensemble,
Expand All @@ -734,7 +757,17 @@ def _evaluate_and_postprocess(
f"{self.ensemble_size - num_successful_realizations}"
)
logger.info(f"Experiment run finished in: {self.get_runtime()}s")
self.run_workflows(HookRuntime.POST_SIMULATION, self._storage, ensemble)
self.run_workflows(
HookRuntime.POST_SIMULATION,
fixtures={
WorkflowFixtures.storage: self._storage,
WorkflowFixtures.ensemble: ensemble,
WorkflowFixtures.reports_dir: self.reports_dir(
ensemble_name=ensemble.name
),
WorkflowFixtures.random_seed: self.random_seed,
},
)

return num_successful_realizations

Expand Down Expand Up @@ -799,6 +832,16 @@ def update(
msg="Creating posterior ensemble..",
)
)

workflow_fixtures = {
WorkflowFixtures.storage: self._storage,
WorkflowFixtures.ensemble: prior,
WorkflowFixtures.observation_settings: self._update_settings,
WorkflowFixtures.es_settings: self._analysis_settings,
WorkflowFixtures.random_seed: self.random_seed,
WorkflowFixtures.reports_dir: self.reports_dir(ensemble_name=prior.name),
}

posterior = self._storage.create_ensemble(
prior.experiment,
ensemble_size=prior.ensemble_size,
Expand All @@ -807,8 +850,14 @@ def update(
prior_ensemble=prior,
)
if prior.iteration == 0:
self.run_workflows(HookRuntime.PRE_FIRST_UPDATE, self._storage, prior)
self.run_workflows(HookRuntime.PRE_UPDATE, self._storage, prior)
self.run_workflows(
HookRuntime.PRE_FIRST_UPDATE,
fixtures=workflow_fixtures,
)
self.run_workflows(
HookRuntime.PRE_UPDATE,
fixtures=workflow_fixtures,
)
try:
smoother_update(
prior,
Expand All @@ -830,5 +879,8 @@ def update(
"Update algorithm failed for iteration:"
f"{posterior.iteration}. The following error occurred: {e}"
) from e
self.run_workflows(HookRuntime.POST_UPDATE, self._storage, prior)
self.run_workflows(
HookRuntime.POST_UPDATE,
fixtures=workflow_fixtures,
)
return posterior
12 changes: 9 additions & 3 deletions src/ert/run_models/ensemble_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ert.trace import tracer

from ..run_arg import create_run_arguments
from .base_run_model import BaseRunModel, ErtRunError, StatusEvents
from .base_run_model import BaseRunModel, ErtRunError, StatusEvents, WorkflowFixtures

if TYPE_CHECKING:
from ert.config import ErtConfig, QueueConfig
Expand Down Expand Up @@ -93,7 +93,10 @@ def run_experiment(
raise ErtRunError(str(exc)) from exc

if not restart:
self.run_workflows(HookRuntime.PRE_EXPERIMENT)
self.run_workflows(
HookRuntime.PRE_EXPERIMENT,
fixtures={WorkflowFixtures.random_seed: self.random_seed},
)
self.experiment = self._storage.create_experiment(
name=self.experiment_name,
parameters=(
Expand Down Expand Up @@ -143,7 +146,10 @@ def run_experiment(
self.ensemble,
evaluator_server_config,
)
self.run_workflows(HookRuntime.POST_EXPERIMENT)
self.run_workflows(
HookRuntime.POST_EXPERIMENT,
fixtures={WorkflowFixtures.random_seed: self.random_seed},
)

@classmethod
def name(cls) -> str:
Expand Down
12 changes: 9 additions & 3 deletions src/ert/run_models/ensemble_smoother.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ert.trace import tracer

from ..run_arg import create_run_arguments
from .base_run_model import StatusEvents, UpdateRunModel
from .base_run_model import StatusEvents, UpdateRunModel, WorkflowFixtures

if TYPE_CHECKING:
from ert.config import QueueConfig
Expand Down Expand Up @@ -74,7 +74,10 @@ def run_experiment(
) -> None:
self.log_at_startup()
self.restart = restart
self.run_workflows(HookRuntime.PRE_EXPERIMENT)
self.run_workflows(
HookRuntime.PRE_EXPERIMENT,
fixtures={WorkflowFixtures.random_seed: self.random_seed},
)
ensemble_format = self.target_ensemble_format
experiment = self._storage.create_experiment(
parameters=self._parameter_configuration,
Expand Down Expand Up @@ -120,7 +123,10 @@ def run_experiment(
posterior,
evaluator_server_config,
)
self.run_workflows(HookRuntime.POST_EXPERIMENT)
self.run_workflows(
HookRuntime.POST_EXPERIMENT,
fixtures={WorkflowFixtures.random_seed: self.random_seed},
)

@classmethod
def name(cls) -> str:
Expand Down
12 changes: 9 additions & 3 deletions src/ert/run_models/multiple_data_assimilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ert.trace import tracer

from ..run_arg import create_run_arguments
from .base_run_model import ErtRunError, StatusEvents, UpdateRunModel
from .base_run_model import ErtRunError, StatusEvents, UpdateRunModel, WorkflowFixtures

if TYPE_CHECKING:
from ert.config import QueueConfig
Expand Down Expand Up @@ -116,7 +116,10 @@ def run_experiment(
f"Prior ensemble with ID: {id} does not exists"
) from err
else:
self.run_workflows(HookRuntime.PRE_EXPERIMENT)
self.run_workflows(
HookRuntime.PRE_EXPERIMENT,
fixtures={WorkflowFixtures.random_seed: self.random_seed},
)
sim_args = {"weights": self._relative_weights}
experiment = self._storage.create_experiment(
parameters=self._parameter_configuration,
Expand Down Expand Up @@ -171,7 +174,10 @@ def run_experiment(
)
prior = posterior

self.run_workflows(HookRuntime.POST_EXPERIMENT)
self.run_workflows(
HookRuntime.POST_EXPERIMENT,
fixtures={WorkflowFixtures.random_seed: self.random_seed},
)

@staticmethod
def parse_weights(weights: str) -> list[float]:
Expand Down
Loading

0 comments on commit c845d7a

Please sign in to comment.