Skip to content

Commit

Permalink
Stop using ert_config for workflows
Browse files Browse the repository at this point in the history
Co-authored-by: Jonathan Karlsen <[email protected]>
  • Loading branch information
yngve-sk and jonathan-eq committed Mar 4, 2025
1 parent 7b807f7 commit 688a1f8
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 46 deletions.
21 changes: 9 additions & 12 deletions src/ert/gui/tools/plugins/plugin_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,13 @@ def run(self) -> None:
ert_config = self.ert_config
try:
plugin = self.__plugin
run_paths = Runpaths(
jobname_format=ert_config.model_config.jobname_format_string,
runpath_format=ert_config.model_config.runpath_format_string,
filename=str(ert_config.runpath_file),
substitutions=ert_config.substitutions,
eclbase=ert_config.model_config.eclbase_format_string,
)
arguments = plugin.getArguments(
fixtures={
"storage": self.storage,
Expand All @@ -46,24 +53,14 @@ def run(self) -> None:
),
"observation_settings": ert_config.analysis_config.observation_settings,
"es_settings": ert_config.analysis_config.es_module,
"run_paths": Runpaths(
jobname_format=ert_config.model_config.jobname_format_string,
runpath_format=ert_config.model_config.runpath_format_string,
filename=str(ert_config.runpath_file),
substitutions=ert_config.substitutions,
eclbase=ert_config.model_config.eclbase_format_string,
),
"run_paths": run_paths,
}
)
dialog = ProcessJobDialog(plugin.getName(), plugin.getParentWindow())
dialog.setObjectName("process_job_dialog")

dialog.cancelConfirmed.connect(self.cancel)
fixtures = {
k: getattr(self, k)
for k in ["storage", "run_paths"]
if getattr(self, k)
}
fixtures = {"storage": self.storage, "run_paths": run_paths}
workflow_job_thread = ErtThread(
name="ert_gui_workflow_job_thread",
target=self.__runWorkflowJob,
Expand Down
43 changes: 26 additions & 17 deletions src/ert/plugins/hook_implementations/workflows/csv_export.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
import json
import os
from collections.abc import Sequence
from typing import TYPE_CHECKING

import pandas
import pandas as pd
import polars as pl

from ert import ErtScript, LibresFacade
from ert.config import ErtConfig
from ert.storage import Storage

if TYPE_CHECKING:
from ert.storage import Ensemble

def loadDesignMatrix(filename: str) -> pandas.DataFrame:
dm = pandas.read_csv(filename, delim_whitespace=True)

def loadDesignMatrix(filename: str) -> pd.DataFrame:
dm = pd.read_csv(filename, delim_whitespace=True)
dm = dm.rename(columns={dm.columns[0]: "Realization"})
dm = dm.set_index(["Realization"])
return dm
Expand Down Expand Up @@ -52,7 +56,6 @@ def getDescription() -> str:

def run(
self,
ert_config: ErtConfig,
storage: Storage,
workflow_args: Sequence[str],
) -> str:
Expand All @@ -61,17 +64,16 @@ def run(
design_matrix_path = None if len(workflow_args) < 3 else workflow_args[2]
_ = True if len(workflow_args) < 4 else workflow_args[3]
drop_const_cols = False if len(workflow_args) < 5 else workflow_args[4]
facade = LibresFacade(ert_config)

ensemble_data_as_dict = (
json.loads(ensemble_data_as_json) if ensemble_data_as_json else {}
)

# Use the keys (UUIDs as strings) to get ensembles
ensembles = []
ensembles: list[Ensemble] = []
for ensemble_id in ensemble_data_as_dict:
assert self.storage is not None
ensemble = self.storage.get_ensemble(ensemble_id)
assert storage is not None
ensemble = storage.get_ensemble(ensemble_id)
ensembles.append(ensemble)

if design_matrix_path is not None:
Expand All @@ -81,7 +83,7 @@ def run(
if not os.path.isfile(design_matrix_path):
raise UserWarning("The design matrix is not a file!")

data = pandas.DataFrame()
data = pd.DataFrame()

for ensemble in ensembles:
if not ensemble.has_data():
Expand All @@ -96,13 +98,20 @@ def run(
if not design_matrix_data.empty:
ensemble_data = ensemble_data.join(design_matrix_data, how="outer")

misfit_data = facade.load_all_misfit_data(ensemble)
misfit_data = LibresFacade.load_all_misfit_data(ensemble)
if not misfit_data.empty:
ensemble_data = ensemble_data.join(misfit_data, how="outer")
realizations = ensemble.get_realization_list_with_responses()

try:
summary_data = ensemble.load_responses("summary", tuple(realizations))
except (KeyError, ValueError):
summary_data = pl.DataFrame({})

summary_data = ensemble.load_all_summary_data()
if not summary_data.empty:
ensemble_data = ensemble_data.join(summary_data, how="outer")
if not summary_data.is_empty():
ensemble_data = ensemble_data.join(
summary_data.to_pandas(), how="outer"
)
else:
ensemble_data["Date"] = None
ensemble_data.set_index(["Date"], append=True, inplace=True)
Expand All @@ -113,9 +122,9 @@ def run(
["Ensemble", "Iteration"], append=True, inplace=True
)

data = pandas.concat([data, ensemble_data])

data = data.reorder_levels(["Realization", "Iteration", "Date", "Ensemble"])
data = pd.concat([data, ensemble_data])
if not data.empty:
data = data.reorder_levels(["Realization", "Iteration", "Date", "Ensemble"])
if drop_const_cols:
data = data.loc[:, (data != data.iloc[0]).any()]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from ert.exceptions import StorageError

if TYPE_CHECKING:
from ert.config import ErtConfig
from ert.storage import Ensemble


Expand All @@ -23,17 +22,14 @@ class ExportMisfitDataJob(ErtScript):
((response_value - observation_data) / observation_std)**2
"""

def run(
self, ert_config: ErtConfig, ensemble: Ensemble, workflow_args: list[Any]
) -> None:
def run(self, ensemble: Ensemble, workflow_args: list[Any]) -> None:
target_file = "misfit.hdf" if not workflow_args else workflow_args[0]

realizations = ensemble.get_realization_list_with_responses()

from ert import LibresFacade # noqa: PLC0415 (circular import)

facade = LibresFacade(ert_config)
misfit = facade.load_all_misfit_data(ensemble)
misfit = LibresFacade.load_all_misfit_data(ensemble)
if len(realizations) == 0 or misfit.empty:
raise StorageError("No responses loaded")
misfit.columns = pd.Index([val.split(":")[1] for val in misfit.columns])
Expand Down
20 changes: 9 additions & 11 deletions src/ert/plugins/hook_implementations/workflows/export_runpath.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from ert.validation import rangestring_to_list

if TYPE_CHECKING:
from ert.config import ErtConfig
from ert.storage import Ensemble


class ExportRunpathJob(ErtScript):
Expand All @@ -33,20 +33,18 @@ class ExportRunpathJob(ErtScript):
file.
"""

def run(self, ert_config: ErtConfig, workflow_args: list[Any]) -> None:
def run(
self, run_paths: Runpaths, ensemble: Ensemble, workflow_args: list[Any]
) -> None:
args = " ".join(workflow_args).split() # Make sure args is a list of words
run_paths = Runpaths(
jobname_format=ert_config.model_config.jobname_format_string,
runpath_format=ert_config.model_config.runpath_format_string,
filename=str(ert_config.runpath_file),
substitutions=ert_config.substitutions,
eclbase=ert_config.model_config.eclbase_format_string,
)
assert ensemble
iter = ensemble.iteration
reals = ensemble.ensemble_size
run_paths.write_runpath_list(
*self.get_ranges(
args,
ert_config.analysis_config.num_iterations,
ert_config.model_config.num_realizations,
iter,
reals,
)
)

Expand Down

0 comments on commit 688a1f8

Please sign in to comment.