Skip to content

Commit

Permalink
First
Browse files Browse the repository at this point in the history
  • Loading branch information
frode-aarstad committed Oct 28, 2024
1 parent 9817d75 commit 2478adb
Show file tree
Hide file tree
Showing 8 changed files with 165 additions and 59 deletions.
1 change: 1 addition & 0 deletions src/ert/config/ert_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,7 @@ def _create_list_of_forward_model_steps_to_run(
def forward_model_step_name_list(self) -> List[str]:
return [j.name for j in self.forward_model_steps]


def forward_model_data_to_json(
self,
run_id: Optional[str] = None,
Expand Down
26 changes: 20 additions & 6 deletions src/ert/enkf_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,24 @@
import time
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Union
from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterable,
List,
Mapping,
Optional,
Tuple,
Union,
)

import orjson
from numpy.random import SeedSequence

from ert.config.model_config import ModelConfig
from ert.substitutions import Substitutions

from .config import (
ExtParamConfig,
Field,
Expand All @@ -22,7 +35,6 @@
from .runpaths import Runpaths

if TYPE_CHECKING:
from .config import ErtConfig
from .storage import Ensemble

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -183,20 +195,22 @@ def sample_prior(
def create_run_path(
run_args: List[RunArg],
ensemble: Ensemble,
ert_config: ErtConfig,
substitutions: Substitutions,
templates: List[Tuple[str, str]],
model_config: ModelConfig,
runpaths: Runpaths,
context_env: Optional[Dict[str, str]] = None,
) -> None:
if context_env is None:
context_env = {}
t = time.perf_counter()
substitutions = ert_config.substitutions
substitutions = substitutions
runpaths.set_ert_ensemble(ensemble.name)
for run_arg in run_args:
run_path = Path(run_arg.runpath)
if run_arg.active:
run_path.mkdir(parents=True, exist_ok=True)
for source_file, target_file in ert_config.ert_templates:
for source_file, target_file in templates:
target_file = substitutions.substitute_real_iter(
target_file, run_arg.iens, ensemble.iteration
)
Expand All @@ -220,7 +234,7 @@ def create_run_path(
)
target.write_text(result)

model_config = ert_config.model_config
model_config = model_config
_generate_parameter_files(
ensemble.experiment.parameter_configuration.values(),
model_config.gen_kw_export_name,
Expand Down
12 changes: 7 additions & 5 deletions src/ert/run_models/base_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,11 +669,13 @@ def _evaluate_and_postprocess(
evaluator_server_config: EvaluatorServerConfig,
) -> int:
create_run_path(
run_args,
ensemble,
self.ert_config,
self.run_paths,
self._context_env,
run_args=run_args,
ensemble=ensemble,
substitutions=self.ert_config.substitutions,
templates=self.ert_config.ert_templates,
model_config=self.ert_config.model_config,
runpaths=self.run_paths,
context_env=self._context_env,
)

self.run_workflows(HookRuntime.PRE_SIMULATION, self._storage, ensemble)
Expand Down
54 changes: 47 additions & 7 deletions src/ert/simulator/batch_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@
import numpy as np

from ert.config import ErtConfig, ExtParamConfig
from ert.config.analysis_config import AnalysisConfig
from ert.config.model_config import ModelConfig
from ert.config.parameter_config import ParameterConfig
from ert.config.parsing.hook_runtime import HookRuntime
from ert.config.queue_config import QueueConfig
from ert.config.workflow import Workflow
from ert.substitutions import Substitutions

from .batch_simulator_context import BatchContext

Expand All @@ -25,7 +32,15 @@
class BatchSimulator:
def __init__(
self,
ert_config: ErtConfig,
perferred_num_cpu : int,
runpath_file: str,
parameter_configurations : List[ParameterConfig],
queue_config: QueueConfig,
model_config: ModelConfig,
analysis_config: AnalysisConfig,
hooked_workflows: Dict[HookRuntime, List[Workflow]],
substitutions: Substitutions,
templates: List[Tuple[str, str]],
experiment: Experiment,
controls: Iterable[str],
results: Iterable[str],
Expand Down Expand Up @@ -95,10 +110,15 @@ def callback(*args, **kwargs):
....
"""
if not isinstance(ert_config, ErtConfig):
raise ValueError("The first argument must be valid ErtConfig instance")

self.ert_config = ert_config
self.preferred_num_cpu= perferred_num_cpu
self.runpath_file = runpath_file
self.queue_config = queue_config
self.model_config = model_config
self.analysis_config = analysis_config
self.hooked_workflows = hooked_workflows
self.substitutions = substitutions
self.templates = templates
self.parameter_configurations = parameter_configurations
self.experiment = experiment
self.control_keys = set(controls)
self.result_keys = set(results)
Expand Down Expand Up @@ -143,7 +163,15 @@ def _check_suffix(
raise KeyError(err_msg)

for control_name, control in controls.items():
ext_config = self.ert_config.ensemble_config[control_name]
ext_config= self.parameter_configurations[control_name]




# fix this



if isinstance(ext_config, ExtParamConfig):
if len(ext_config) != len(control.keys()):
raise KeyError(
Expand Down Expand Up @@ -233,7 +261,19 @@ def start(
itr = 0
mask = np.full(len(case_data), True, dtype=bool)
sim_context = BatchContext(
self.result_keys, self.ert_config, ensemble, mask, itr, case_data
result_keys=self.result_keys,
ensemble=ensemble,
preferred_num_cpu=self.preferred_num_cpu,
runpath_file=self.runpath_file,
queue_config=self.queue_config,
model_config=self.model_config,
analysis_config=self.analysis_config,
hooked_workflows=self.hooked_workflows,
substitutions=self.substitutions,
templates=self.templates,
mask=mask,
itr=itr,
case_data=case_data,
)

if self.callback:
Expand Down
89 changes: 58 additions & 31 deletions src/ert/simulator/batch_simulator_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,15 @@

from _ert.threading import ErtThread
from ert.config import HookRuntime
from ert.config.analysis_config import AnalysisConfig
from ert.config.model_config import ModelConfig
from ert.config.queue_config import QueueConfig
from ert.config.workflow import Workflow
from ert.enkf_main import create_run_path
from ert.ensemble_evaluator import Realization
from ert.runpaths import Runpaths
from ert.scheduler import JobState, Scheduler, create_driver
from ert.substitutions import Substitutions
from ert.workflow_runner import WorkflowRunner

from ..run_arg import RunArg, create_run_arguments
Expand All @@ -28,7 +33,6 @@

import numpy.typing as npt

from ert.config import ErtConfig
from ert.storage import Ensemble

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -71,20 +75,30 @@ def _slug(entity: str) -> str:


def _run_forward_model(
ert_config: "ErtConfig",
# ert_config: "ErtConfig",
prefered_num_cpu: int,
queue_config: QueueConfig,
analysis_config: AnalysisConfig,
scheduler: Scheduler,
run_args: List[RunArg],
) -> None:
# run simplestep
asyncio.run(_submit_and_run_jobqueue(ert_config, scheduler, run_args))
asyncio.run(
_submit_and_run_jobqueue(
prefered_num_cpu, queue_config, analysis_config, scheduler, run_args
)
)


async def _submit_and_run_jobqueue(
ert_config: "ErtConfig",
# ert_config: "ErtConfig",
preferred_num_cpu: int,
queue_config: QueueConfig,
analysis_config: AnalysisConfig,
scheduler: Scheduler,
run_args: List[RunArg],
) -> None:
max_runtime: Optional[int] = ert_config.analysis_config.max_runtime
max_runtime: Optional[int] = analysis_config.max_runtime
if max_runtime == 0:
max_runtime = None
for run_arg in run_args:
Expand All @@ -96,23 +110,31 @@ async def _submit_and_run_jobqueue(
active=True,
max_runtime=max_runtime,
run_arg=run_arg,
num_cpu=ert_config.preferred_num_cpu,
job_script=ert_config.queue_config.job_script,
realization_memory=ert_config.queue_config.realization_memory,
num_cpu=preferred_num_cpu,
job_script=queue_config.job_script,
realization_memory=queue_config.realization_memory,
)
scheduler.set_realization(realization)

required_realizations = 0
if ert_config.queue_config.stop_long_running:
required_realizations = ert_config.analysis_config.minimum_required_realizations
if queue_config.stop_long_running:
required_realizations = analysis_config.minimum_required_realizations
with contextlib.suppress(asyncio.CancelledError):
await scheduler.execute(required_realizations)


@dataclass
class BatchContext:
result_keys: "Iterable[str]"
ert_config: "ErtConfig"
# ert_config: "ErtConfig"
preferred_num_cpu: int
queue_config: QueueConfig
model_config: ModelConfig
analysis_config: AnalysisConfig
hooked_workflows: Dict[HookRuntime, List[Workflow]]
substitutions: Substitutions
templates: List[Tuple[str, str]]
runpath_file: str
ensemble: Ensemble
mask: npt.NDArray[np.bool_]
itr: int
Expand All @@ -122,24 +144,23 @@ def __post_init__(self) -> None:
"""
Handle which can be used to query status and results for batch simulation.
"""
ert_config = self.ert_config
driver = create_driver(ert_config.queue_config)
self._scheduler = Scheduler(
driver, max_running=self.ert_config.queue_config.max_running
)
# ert_config = self.ert_config

driver = create_driver(self.queue_config)
self._scheduler = Scheduler(driver, max_running=self.queue_config.max_running)

# fill in the missing geo_id data
global_substitutions = self.ert_config.substitutions
global_substitutions["<CASE_NAME>"] = _slug(self.ensemble.name)
self.substitutions["<CASE_NAME>"] = _slug(self.ensemble.name)
for sim_id, (geo_id, _) in enumerate(self.case_data):
if self.mask[sim_id]:
global_substitutions[f"<GEO_ID_{sim_id}_{self.itr}>"] = str(geo_id)
self.substitutions[f"<GEO_ID_{sim_id}_{self.itr}>"] = str(geo_id)

run_paths = Runpaths(
jobname_format=ert_config.model_config.jobname_format_string,
runpath_format=ert_config.model_config.runpath_format_string,
filename=str(ert_config.runpath_file),
substitutions=global_substitutions,
eclbase=ert_config.model_config.eclbase_format_string,
jobname_format=self.model_config.jobname_format_string,
runpath_format=self.model_config.runpath_format_string,
filename=str(self.runpath_file),
substitutions=self.substitutions,
eclbase=self.model_config.eclbase_format_string,
)
self.run_args = create_run_arguments(
run_paths,
Expand All @@ -152,13 +173,15 @@ def __post_init__(self) -> None:
"_ERT_SIMULATION_MODE": "batch_simulation",
}
create_run_path(
self.run_args,
self.ensemble,
ert_config,
run_paths,
context_env,
run_args=self.run_args,
ensemble=self.ensemble,
substitutions=self.substitutions,
templates=self.templates,
model_config=self.model_config,
runpaths=run_paths,
context_env=context_env,
)
for workflow in ert_config.hooked_workflows[HookRuntime.PRE_SIMULATION]:
for workflow in self.hooked_workflows[HookRuntime.PRE_SIMULATION]:
WorkflowRunner(workflow, None, self.ensemble).run_blocking()
self._sim_thread = self._run_simulations_simple_step()

Expand All @@ -176,7 +199,11 @@ def get_ensemble(self) -> Ensemble:
def _run_simulations_simple_step(self) -> Thread:
sim_thread = ErtThread(
target=lambda: _run_forward_model(
self.ert_config, self._scheduler, self.run_args
prefered_num_cpu=self.preferred_num_cpu,
queue_config=self.queue_config,
analysis_config=self.analysis_config,
scheduler=self._scheduler,
run_args=self.run_args,
)
)
sim_thread.start()
Expand Down
4 changes: 1 addition & 3 deletions src/ert/workflow_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from typing_extensions import Self

from ert.config import ErtConfig, ErtScript, ExternalErtScript, Workflow, WorkflowJob
from ert.config import ErtScript, ExternalErtScript, Workflow, WorkflowJob

if TYPE_CHECKING:
from ert.storage import Ensemble, Storage
Expand Down Expand Up @@ -111,12 +111,10 @@ def __init__(
workflow: Workflow,
storage: Optional[Storage] = None,
ensemble: Optional[Ensemble] = None,
ert_config: Optional[ErtConfig] = None,
) -> None:
self.__workflow = workflow
self.storage = storage
self.ensemble = ensemble
self.ert_config = ert_config

self.__workflow_result: Optional[bool] = None
self._workflow_executor = futures.ThreadPoolExecutor(max_workers=1)
Expand Down
16 changes: 15 additions & 1 deletion src/everest/detached/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,21 @@ def start_server(config: EverestConfig, ert_config: ErtConfig, storage):
responses=[],
)

_server = BatchSimulator(ert_config, experiment, {}, [])
_server = BatchSimulator(
experiment=experiment,
perferred_num_cpu=ert_config.preferred_num_cpu,
runpath_file=str(ert_config.runpath_file),
parameter_configurations=ert_config.ensemble_config.parameter_configuration,
queue_config=ert_config.queue_config,
model_config=ert_config.model_config,
analysis_config=ert_config.analysis_config,
hooked_workflows=ert_config.hooked_workflows,
substitutions=ert_config.substitutions,
templates=ert_config.ert_templates,
controls={},
results=[],
)

_context = _server.start("dispatch_server", [(0, {})])

return _context
Expand Down
Loading

0 comments on commit 2478adb

Please sign in to comment.