Refactor _forward_model_evaluator
verveerpj committed Feb 5, 2025
1 parent e514606 commit 05dc0fa
Showing 1 changed file with 49 additions and 44 deletions.
93 changes: 49 additions & 44 deletions src/ert/run_models/everest_run_model.py
@@ -250,18 +250,21 @@ def _on_before_forward_model_evaluation(
logging.getLogger(EVEREST).info("User abort requested.")
optimizer.abort_optimization()

def _forward_model_evaluator(
self, control_values: NDArray[np.float64], evaluator_context: EvaluatorContext
) -> EvaluatorResult:
def _run_forward_model(
self,
control_values: NDArray[np.float64],
realizations: list[int],
active_control_vectors: list[bool],
) -> tuple[NDArray[np.float64], NDArray[np.float64] | None, NDArray[np.intc]]:
# Reset the current run status:
self._status = None

# Get cached_results:
cached_results = self._get_cached_results(control_values, evaluator_context)
cached_results = self._get_cached_results(control_values, realizations)

# Create the batch to run:
batch_data = self._init_batch_data(
control_values, evaluator_context, cached_results
control_values, active_control_vectors, cached_results
)

# Initialize a new ensemble in storage:
@@ -274,7 +277,7 @@ def _forward_model_evaluator(
self._setup_sim(sim_id, controls, ensemble)

# Evaluate the batch:
run_args = self._get_run_args(ensemble, evaluator_context, batch_data)
run_args = self._get_run_args(ensemble, realizations, batch_data)
self._context_env.update(
{
"_ERT_EXPERIMENT_ID": str(ensemble.experiment_id),
@@ -290,33 +293,53 @@ def _forward_model_evaluator(

# Gather the results and create the result for ropt:
results = self._gather_simulation_results(ensemble)
evaluator_result = self._make_evaluator_result(
objectives, constraints = self._get_objectives_and_constraints(
control_values, batch_data, results, cached_results
)

# Add the results from the evaluations to the cache:
self._add_results_to_cache(
control_values,
evaluator_context,
batch_data,
evaluator_result.objectives,
evaluator_result.constraints,
control_values, realizations, batch_data, objectives, constraints
)

# Increase the batch ID for the next evaluation:
self._batch_id += 1

return evaluator_result
sim_ids = np.full(control_values.shape[0], -1, dtype=np.intc)
sim_ids[list(batch_data.keys())] = np.arange(len(batch_data), dtype=np.intc)
return objectives, constraints, sim_ids

def _get_cached_results(
def _forward_model_evaluator(
self, control_values: NDArray[np.float64], evaluator_context: EvaluatorContext
) -> EvaluatorResult:
realizations = [
self._everest_config.model.realizations[evaluator_context.realizations[idx]]
for idx in range(control_values.shape[0])
]
active_control_vectors = [
evaluator_context.active is None
or bool(evaluator_context.active[evaluator_context.realizations[idx]])
for idx in range(control_values.shape[0])
]
batch_id = self._batch_id # Save the batch ID, it will be modified.
objectives, constraints, sim_ids = self._run_forward_model(
control_values, realizations, active_control_vectors
)
return EvaluatorResult(
objectives=objectives,
constraints=constraints,
batch_id=batch_id,
evaluation_ids=sim_ids,
)

def _get_cached_results(
self, control_values: NDArray[np.float64], realizations: list[int]
) -> dict[int, Any]:
cached_results: dict[int, Any] = {}
if self._simulator_cache is not None:
for control_idx, real_idx in enumerate(evaluator_context.realizations):
for control_idx, realization in enumerate(realizations):
cached_data = self._simulator_cache.get(
self._everest_config.model.realizations[real_idx],
control_values[control_idx, :],
realization, control_values[control_idx, :]
)
if cached_data is not None:
cached_results[control_idx] = cached_data
@@ -325,7 +348,7 @@ def _get_cached_results(
def _init_batch_data(
self,
control_values: NDArray[np.float64],
evaluator_context: EvaluatorContext,
active_control_vectors: list[bool],
cached_results: dict[int, Any],
) -> dict[int, dict[str, Any]]:
def _add_controls(
@@ -348,15 +371,10 @@ def _add_controls(
batch_data_item[control.name] = control_dict
return batch_data_item

active = evaluator_context.active
realizations = evaluator_context.realizations
return {
idx: _add_controls(self._everest_config.controls, control_values[idx, :])
for idx in range(control_values.shape[0])
if (
idx not in cached_results
and (active is None or active[realizations[idx]])
)
if idx not in cached_results and active_control_vectors[idx]
}

def _setup_sim(
@@ -416,18 +434,14 @@ def _check_suffix(
def _get_run_args(
self,
ensemble: Ensemble,
evaluator_context: EvaluatorContext,
realizations: list[int],
batch_data: dict[int, Any],
) -> list[RunArg]:
substitutions = self._substitutions
substitutions["<BATCH_NAME>"] = ensemble.name
self.active_realizations = [True] * len(batch_data)
for sim_id, control_idx in enumerate(batch_data.keys()):
substitutions[f"<GEO_ID_{sim_id}_0>"] = str(
self._everest_config.model.realizations[
evaluator_context.realizations[control_idx]
]
)
substitutions[f"<GEO_ID_{sim_id}_0>"] = str(realizations[control_idx])
run_paths = Runpaths(
jobname_format=self._model_config.jobname_format_string,
runpath_format=self._model_config.runpath_format_string,
@@ -483,13 +497,13 @@ def _gather_simulation_results(
result[fnc_name] = result[alias]
return results

def _make_evaluator_result(
def _get_objectives_and_constraints(
self,
control_values: NDArray[np.float64],
batch_data: dict[int, Any],
results: list[dict[str, NDArray[np.float64]]],
cached_results: dict[int, Any],
) -> EvaluatorResult:
) -> tuple[NDArray[np.float64], NDArray[np.float64] | None]:
# We minimize the negative of the objectives:
objectives = -self._get_simulation_results(
results, self._everest_config.objective_names, control_values, batch_data
@@ -514,14 +528,7 @@ def _make_evaluator_result(
assert cached_constraints is not None
constraints[control_idx, ...] = cached_constraints

sim_ids = np.full(control_values.shape[0], -1, dtype=np.intc)
sim_ids[list(batch_data.keys())] = np.arange(len(batch_data), dtype=np.intc)
return EvaluatorResult(
objectives=objectives,
constraints=constraints,
batch_id=self._batch_id,
evaluation_ids=sim_ids,
)
return objectives, constraints

@staticmethod
def _get_simulation_results(
@@ -542,17 +549,15 @@ def _get_simulation_results(
def _add_results_to_cache(
self,
control_values: NDArray[np.float64],
evaluator_context: EvaluatorContext,
realizations: list[int],
batch_data: dict[int, Any],
objectives: NDArray[np.float64],
constraints: NDArray[np.float64] | None,
) -> None:
if self._simulator_cache is not None:
for control_idx in batch_data:
self._simulator_cache.add(
self._everest_config.model.realizations[
evaluator_context.realizations[control_idx]
],
realizations[control_idx],
control_values[control_idx, ...],
objectives[control_idx, ...],
None if constraints is None else constraints[control_idx, ...],
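For orientation, here is a minimal, self-contained sketch (not the actual ert code, and not part of this commit) of the call flow the diff introduces: the ropt-facing _forward_model_evaluator translates the EvaluatorContext into plain Python lists and delegates the batch run to a _run_forward_model-style helper that returns raw objective/constraint arrays plus simulation ids. All class and function names below are illustrative stand-ins; the stub run_forward_model replaces the real ensemble evaluation.

    from dataclasses import dataclass

    import numpy as np
    from numpy.typing import NDArray


    @dataclass
    class EvaluatorContext:
        # Stand-in for the ropt EvaluatorContext used in the diff above.
        realizations: NDArray[np.intc]           # realization index per control vector
        active: NDArray[np.bool_] | None = None  # per-realization mask; None means all active


    @dataclass
    class EvaluatorResult:
        # Stand-in for the ropt EvaluatorResult returned to the optimizer.
        objectives: NDArray[np.float64]
        constraints: NDArray[np.float64] | None
        batch_id: int
        evaluation_ids: NDArray[np.intc]


    def run_forward_model(
        control_values: NDArray[np.float64],
        realizations: list[int],
        active: list[bool],
    ) -> tuple[NDArray[np.float64], NDArray[np.float64] | None, NDArray[np.intc]]:
        # Stub for the batch evaluation: only active control vectors get a simulation id.
        objectives = np.zeros((control_values.shape[0], 1), dtype=np.float64)
        sim_ids = np.full(control_values.shape[0], -1, dtype=np.intc)
        sim_ids[np.asarray(active)] = np.arange(sum(active), dtype=np.intc)
        return objectives, None, sim_ids


    def forward_model_evaluator(
        control_values: NDArray[np.float64],
        context: EvaluatorContext,
        model_realizations: list[int],
        batch_id: int,
    ) -> EvaluatorResult:
        # Map each control vector to its model realization and activity flag ...
        realizations = [
            model_realizations[context.realizations[idx]]
            for idx in range(control_values.shape[0])
        ]
        active = [
            context.active is None or bool(context.active[context.realizations[idx]])
            for idx in range(control_values.shape[0])
        ]
        # ... then delegate the batch run and repackage the result for the optimizer.
        objectives, constraints, sim_ids = run_forward_model(control_values, realizations, active)
        return EvaluatorResult(objectives, constraints, batch_id, sim_ids)


    if __name__ == "__main__":
        context = EvaluatorContext(realizations=np.array([0, 1], dtype=np.intc))
        result = forward_model_evaluator(np.zeros((2, 3)), context, [100, 200], batch_id=0)
        print(result.evaluation_ids)  # -> [0 1]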
