Skip to content

Commit

Permalink
Split ERT 3 parameters into separate records
Browse files Browse the repository at this point in the history
  • Loading branch information
pinkwah authored and Markus Fanebust Dregi committed Jun 23, 2021
1 parent 52efbba commit d26c243
Show file tree
Hide file tree
Showing 7 changed files with 326 additions and 121 deletions.
11 changes: 9 additions & 2 deletions ert3/config/_parameters_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,15 @@ class ParametersConfig(_ParametersConfig):
def __iter__(self) -> Iterator[_ParameterConfig]: # type: ignore
return iter(self.__root__)

def __getitem__(self, item: int) -> _ParameterConfig:
return self.__root__[item]
def __getitem__(self, item: Union[int, str]) -> _ParameterConfig:
if isinstance(item, int):
return self.__root__[item]
elif isinstance(item, str):
for group in self:
if group.name == item:
return group
raise ValueError(f"No parameter group found named: {item}")
raise TypeError(f"Item should be int or str, not {type(item)}")

def __len__(self) -> int:
return len(self.__root__)
Expand Down
12 changes: 1 addition & 11 deletions ert3/engine/_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,6 @@ def load_record(workspace: Path, record_name: str, record_file: Path) -> None:
)


def _get_distribution(
parameter_group_name: str, parameters_config: ert3.config.ParametersConfig
) -> ert3.stats.Distribution:
for parameter_group in parameters_config:
if parameter_group.name == parameter_group_name:
return parameter_group.as_distribution()

raise ValueError(f"No parameter group found named: {parameter_group_name}")


# pylint: disable=too-many-arguments
def sample_record(
workspace: Path,
Expand All @@ -38,7 +28,7 @@ def sample_record(
ensemble_size: int,
experiment_name: Optional[str] = None,
) -> None:
distribution = _get_distribution(parameter_group_name, parameters_config)
distribution = parameters_config[parameter_group_name].as_distribution()
ensrecord = ert3.data.EnsembleRecord(
records=[distribution.sample() for _ in range(ensemble_size)]
)
Expand Down
56 changes: 49 additions & 7 deletions ert3/engine/_run.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,67 @@
import pathlib
from typing import List, Dict
from typing import List, Dict, Set, Union

import ert3


# Character used to separate record source "paths".
_SOURCE_SEPARATOR = "."


def _prepare_experiment(
workspace_root: pathlib.Path,
experiment_name: str,
ensemble: ert3.config.EnsembleConfig,
ensemble_size: int,
parameters_config: ert3.config.ParametersConfig,
) -> None:
if ert3.workspace.experiment_has_run(workspace_root, experiment_name):
raise ValueError(f"Experiment {experiment_name} have been carried out.")

parameter_names = [elem.record for elem in ensemble.input]
parameters: Dict[str, List[str]] = {}
for input_record in ensemble.input:
record_name = input_record.record
record_source = input_record.source.split(_SOURCE_SEPARATOR)
parameters[record_name] = _get_experiment_record_indices(
workspace_root, record_name, record_source, parameters_config
)

responses = [elem.record for elem in ensemble.output]

ert3.storage.init_experiment(
workspace=workspace_root,
experiment_name=experiment_name,
parameters=parameter_names,
parameters=parameters,
ensemble_size=ensemble_size,
responses=responses,
)


def _get_experiment_record_indices(
workspace_root: pathlib.Path,
record_name: str,
record_source: List[str],
parameters_config: ert3.config.ParametersConfig,
) -> List[str]:
assert len(record_source) == 2
source, source_record_name = record_source

if source == "storage":
ensemble_record = ert3.storage.get_ensemble_record(
workspace=workspace_root, record_name=source_record_name
)
indices: Set[Union[str, int]] = set()
for record in ensemble_record.records:
assert record.index is not None
indices.update(record.index)
return [str(x) for x in indices]

elif source == "stochastic":
return list(parameters_config[source_record_name].variables)

raise ValueError("Unknown record source location {}".format(source))


# pylint: disable=too-many-arguments
def _prepare_experiment_record(
record_name: str,
Expand Down Expand Up @@ -66,11 +104,13 @@ def _prepare_evaluation(
# This reassures mypy that the ensemble size is defined
assert ensemble.size is not None

_prepare_experiment(workspace_root, experiment_name, ensemble, ensemble.size)
_prepare_experiment(
workspace_root, experiment_name, ensemble, ensemble.size, parameters_config
)

for input_record in ensemble.input:
record_name = input_record.record
record_source = input_record.source.split(".")
record_source = input_record.source.split(_SOURCE_SEPARATOR)

_prepare_experiment_record(
record_name,
Expand All @@ -94,7 +134,7 @@ def _load_ensemble_parameters(
ensemble_parameters = {}
for input_record in ensemble.input:
record_name = input_record.record
record_source = input_record.source.split(".")
record_source = input_record.source.split(_SOURCE_SEPARATOR)
assert len(record_source) == 2
assert record_source[0] == "stochastic"
parameter_group_name = record_source[1]
Expand All @@ -113,7 +153,9 @@ def _prepare_sensitivity(
)
input_records = ert3.algorithms.one_at_the_time(parameter_distributions)

_prepare_experiment(workspace_root, experiment_name, ensemble, len(input_records))
_prepare_experiment(
workspace_root, experiment_name, ensemble, len(input_records), parameters_config
)

parameters: Dict[str, List[ert3.data.Record]] = {
param.record: [] for param in ensemble.input
Expand Down
Loading

0 comments on commit d26c243

Please sign in to comment.