Skip to content

Commit

Permalink
Commitin
Browse files Browse the repository at this point in the history
  • Loading branch information
dafeda committed Aug 21, 2024
1 parent 4013764 commit b20ae03
Show file tree
Hide file tree
Showing 15 changed files with 301 additions and 419 deletions.
6 changes: 6 additions & 0 deletions src/ert/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,12 @@ def get_ert_parser(parser: Optional[ArgumentParser] = None) -> ArgumentParser:
help="Name of the ensemble where the results for the "
"updated parameters will be stored.",
)
ensemble_smoother_parser.add_argument(
"--experiment-name",
type=str,
default="ensemble-experiment",
help="Name of the experiment",
)
ensemble_smoother_parser.add_argument(
"--realizations",
type=valid_realizations,
Expand Down
33 changes: 21 additions & 12 deletions src/ert/gui/ertnotifier.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Optional
from typing import List, Optional

from qtpy.QtCore import QObject, Signal, Slot

from ert.storage import Ensemble, Storage
from ert.storage import LocalEnsemble, LocalStorage


class ErtNotifier(QObject):
Expand All @@ -13,16 +13,16 @@ class ErtNotifier(QObject):
def __init__(self, config_file: str):
QObject.__init__(self)
self._config_file = config_file
self._storage: Optional[Storage] = None
self._current_ensemble: Optional[Ensemble] = None
self._storage: Optional[LocalStorage] = None
self._current_ensemble: Optional[LocalEnsemble] = None
self._is_simulation_running = False

@property
def is_storage_available(self) -> bool:
return self._storage is not None

@property
def storage(self) -> Storage:
def storage(self) -> LocalStorage:
assert self.is_storage_available
return self._storage # type: ignore

Expand All @@ -31,11 +31,11 @@ def config_file(self) -> str:
return self._config_file

@property
def current_ensemble(self) -> Optional[Ensemble]:
if self._current_ensemble is None and self._storage is not None:
ensembles = list(self._storage.ensembles)
if ensembles:
self._current_ensemble = ensembles[0]
def current_ensemble(self) -> Optional[LocalEnsemble]:
if self._current_ensemble is None and self.is_storage_available:
all_ensembles = self.get_all_ensembles()
if all_ensembles:
self._current_ensemble = all_ensembles[0]
return self._current_ensemble

@property
Expand All @@ -53,15 +53,24 @@ def emitErtChange(self) -> None:
self.ertChanged.emit()

@Slot(object)
def set_storage(self, storage: Storage) -> None:
def set_storage(self, storage: LocalStorage) -> None:
self._storage = storage
self._current_ensemble = None
self.storage_changed.emit(storage)

@Slot(object)
def set_current_ensemble(self, ensemble: Optional[Ensemble] = None) -> None:
def set_current_ensemble(self, ensemble: Optional[LocalEnsemble] = None) -> None:
self._current_ensemble = ensemble
self.current_ensemble_changed.emit(ensemble)

@Slot(bool)
def set_is_simulation_running(self, is_running: bool) -> None:
self._is_simulation_running = is_running

def get_all_ensembles(self) -> List[LocalEnsemble]:
if not self.is_storage_available:
return []
all_ensembles = []
for experiment in self.storage.experiments:
all_ensembles.extend(list(experiment.ensembles))
return all_ensembles
52 changes: 22 additions & 30 deletions src/ert/gui/ertwidgets/ensembleselector.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from ert.storage.realization_storage_state import RealizationStorageState

if TYPE_CHECKING:
from ert.storage import Ensemble
from ert.storage import LocalEnsemble


class EnsembleSelector(QComboBox):
Expand All @@ -24,26 +24,19 @@ def __init__(
):
super().__init__()
self.notifier = notifier

# If true current ensemble of ert will be change
self._update_ert = update_ert
# only show initialized ensembles
self._show_only_undefined = show_only_undefined
# If True, we filter out any ensembles which have children
# One use case is if a user wants to rerun because of failures
# not related to parameterization. We can allow that, but only
# if the ensemble has not been used in an update, as that would
# invalidate the result
# invalidate the result.
self._show_only_no_children = show_only_no_children
self.setSizeAdjustPolicy(QComboBox.AdjustToContents)

self.setEnabled(False)

if update_ert:
# Update ERT when this combo box is changed
self.currentIndexChanged[int].connect(self._on_current_index_changed)

# Update this combo box when ERT is changed
notifier.current_ensemble_changed.connect(
self._on_global_current_ensemble_changed
)
Expand All @@ -55,52 +48,51 @@ def __init__(
self.populate()

@property
def selected_ensemble(self) -> Ensemble:
def selected_ensemble(self) -> LocalEnsemble:
return self.itemData(self.currentIndex())

def populate(self) -> None:
block = self.blockSignals(True)

self.clear()

if self._ensemble_list():
self.setEnabled(True)

for ensemble in self._ensemble_list():
self.addItem(ensemble.name, userData=ensemble)

current_index = self.findData(
self.notifier.current_ensemble, Qt.ItemDataRole.UserRole
)

self.setCurrentIndex(max(current_index, 0))

self.blockSignals(block)

self.ensemble_populated.emit()

def _ensemble_list(self) -> Iterable[Ensemble]:
def _ensemble_list(self) -> Iterable[LocalEnsemble]:
if not self.notifier.is_storage_available:
return []

all_ensembles = []
for experiment in self.notifier.storage.experiments:
all_ensembles.extend(experiment.ensembles)

if self._show_only_undefined:
ensembles = (
all_ensembles = [
ensemble
for ensemble in self.notifier.storage.ensembles
for ensemble in all_ensembles
if all(
e == RealizationStorageState.UNDEFINED
for e in ensemble.get_ensemble_state()
)
)
else:
ensembles = self.notifier.storage.ensembles
ensemble_list = list(ensembles)
if self._show_only_no_children:
parents = [
ens.parent for ens in self.notifier.storage.ensembles if ens.parent
]
ensemble_list = [val for val in ensemble_list if val.id not in parents]
return sorted(ensemble_list, key=lambda x: x.started_at, reverse=True)

if self._show_only_no_children:
parents = [ens.parent for ens in all_ensembles if ens.parent]
all_ensembles = [val for val in all_ensembles if val.id not in parents]

return sorted(all_ensembles, key=lambda x: x.started_at, reverse=True)

def _on_current_index_changed(self, index: int) -> None:
self.notifier.set_current_ensemble(self.itemData(index))

def _on_global_current_ensemble_changed(self, data: Optional[Ensemble]) -> None:
def _on_global_current_ensemble_changed(
self, data: Optional[LocalEnsemble]
) -> None:
self.setCurrentIndex(max(self.findData(data, Qt.ItemDataRole.UserRole), 0))
11 changes: 5 additions & 6 deletions src/ert/run_models/base_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
)
from ert.mode_definitions import MODULE_MODE
from ert.runpaths import Runpaths
from ert.storage import Ensemble, Storage
from ert.storage import Ensemble, LocalEnsemble, Storage
from ert.workflow_runner import WorkflowRunner

from ..config.analysis_config import UpdateSettings
Expand Down Expand Up @@ -671,8 +671,8 @@ def __init__(
)

def update(
self, prior: Ensemble, posterior_name: str, weight: float = 1.0
) -> Ensemble:
self, prior: LocalEnsemble, posterior_name: str, weight: float = 1.0
) -> LocalEnsemble:
self.validate()
self.send_event(
RunModelUpdateBeginEvent(iteration=prior.iteration, run_id=prior.id)
Expand All @@ -684,11 +684,10 @@ def update(
msg="Creating posterior ensemble..",
)
)
posterior = self._storage.create_ensemble(
prior.experiment,
posterior = prior.experiment.create_ensemble(
name=posterior_name,
ensemble_size=prior.ensemble_size,
iteration=prior.iteration + 1,
name=posterior_name,
prior_ensemble=prior,
)
if prior.iteration == 0:
Expand Down
14 changes: 7 additions & 7 deletions src/ert/run_models/ensemble_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from ert.enkf_main import sample_prior
from ert.ensemble_evaluator import EvaluatorServerConfig
from ert.storage import Ensemble, Experiment, Storage
from ert.storage import LocalEnsemble, LocalExperiment, LocalStorage

from ..run_arg import create_run_arguments
from .base_run_model import BaseRunModel, StatusEvents
Expand All @@ -31,15 +31,14 @@ def __init__(
minimum_required_realizations: int,
random_seed: Optional[int],
config: ErtConfig,
storage: Storage,
storage: LocalStorage,
queue_config: QueueConfig,
status_queue: SimpleQueue[StatusEvents],
):
self.ensemble_name = ensemble_name
self.experiment_name = experiment_name
self.experiment: Experiment | None = None
self.ensemble: Ensemble | None = None

self.experiment: LocalExperiment | None = None
self.ensemble: LocalEnsemble | None = None
super().__init__(
config,
storage,
Expand All @@ -63,8 +62,8 @@ def run_experiment(
observations=self.ert_config.observations,
responses=self.ert_config.ensemble_config.response_configuration,
)
self.ensemble = self._storage.create_ensemble(
self.experiment,
assert self.experiment
self.ensemble = self.experiment.create_ensemble(
name=self.ensemble_name,
ensemble_size=self.ensemble_size,
)
Expand All @@ -82,6 +81,7 @@ def run_experiment(
np.array(self.active_realizations, dtype=bool),
ensemble=self.ensemble,
)

sample_prior(
self.ensemble,
np.where(self.active_realizations)[0],
Expand Down
19 changes: 7 additions & 12 deletions src/ert/run_models/ensemble_smoother.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from ert.config import ErtConfig
from ert.enkf_main import sample_prior
from ert.ensemble_evaluator import EvaluatorServerConfig
from ert.storage import Storage
from ert.storage import LocalStorage

from ..config.analysis_config import UpdateSettings
from ..config.analysis_module import ESSettings
Expand All @@ -19,7 +19,6 @@
if TYPE_CHECKING:
from ert.config import QueueConfig


logger = logging.getLogger(__name__)


Expand All @@ -32,7 +31,7 @@ def __init__(
minimum_required_realizations: int,
random_seed: Optional[int],
config: ErtConfig,
storage: Storage,
storage: LocalStorage,
queue_config: QueueConfig,
es_settings: ESSettings,
update_settings: UpdateSettings,
Expand All @@ -53,7 +52,6 @@ def __init__(
)
self.target_ensemble_format = target_ensemble
self.experiment_name = experiment_name

self.support_restart = False

def run_experiment(
Expand All @@ -67,39 +65,36 @@ def run_experiment(
responses=self.ert_config.ensemble_config.response_configuration,
name=self.experiment_name,
)

self.set_env_key("_ERT_EXPERIMENT_ID", str(experiment.id))
prior = self._storage.create_ensemble(
experiment,
ensemble_size=self.ensemble_size,

prior = experiment.create_ensemble(
name=ensemble_format % 0,
ensemble_size=self.ensemble_size,
)
self.set_env_key("_ERT_ENSEMBLE_ID", str(prior.id))

prior_args = create_run_arguments(
self.run_paths,
np.array(self.active_realizations, dtype=bool),
ensemble=prior,
)

sample_prior(
prior,
np.where(self.active_realizations)[0],
random_seed=self.random_seed,
)

self._evaluate_and_postprocess(
prior_args,
prior,
evaluator_server_config,
)
posterior = self.update(prior, ensemble_format % 1)

posterior = self.update(prior, ensemble_format % 1)
posterior_args = create_run_arguments(
self.run_paths,
np.array(self.active_realizations, dtype=bool),
ensemble=posterior,
)

self._evaluate_and_postprocess(
posterior_args,
posterior,
Expand Down
12 changes: 4 additions & 8 deletions src/ert/run_models/iterated_ensemble_smoother.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,9 @@ def run_experiment(
responses=self.ert_config.ensemble_config.response_configuration,
name=self.experiment_name,
)
prior = self._storage.create_ensemble(
experiment=experiment,
ensemble_size=self.ensemble_size,
prior = experiment.create_ensemble(
name=target_ensemble_format % 0,
ensemble_size=self.ensemble_size,
)
self.set_env_key("_ERT_ENSEMBLE_ID", str(prior.id))
self.set_env_key("_ERT_EXPERIMENT_ID", str(experiment.id))
Expand Down Expand Up @@ -172,9 +171,8 @@ def run_experiment(
)
)

posterior = self._storage.create_ensemble(
experiment,
name=target_ensemble_format % (prior_iter + 1), # noqa
posterior = experiment.create_ensemble(
name=target_ensemble_format % (prior_iter + 1),
ensemble_size=prior.ensemble_size,
iteration=prior_iter + 1,
prior_ensemble=prior,
Expand All @@ -194,8 +192,6 @@ def run_experiment(
initial_mask=initial_mask,
)

# sies iteration starts at 1, we keep iters at 0,
# so we subtract sies to be 0-indexed
analysis_success = prior_iter < (self.sies_iteration - 1)
if analysis_success:
update_success = True
Expand Down
Loading

0 comments on commit b20ae03

Please sign in to comment.