Skip to content

Commit

Permalink
Create end point for events
Browse files Browse the repository at this point in the history
  • Loading branch information
oyvindeide committed Jan 29, 2025
1 parent dfea8b3 commit b446199
Show file tree
Hide file tree
Showing 9 changed files with 404 additions and 407 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ types = [
"types-setuptools",
]
everest = [
"websockets",
"progressbar2",
"ruamel.yaml",
"fastapi",
Expand Down
95 changes: 6 additions & 89 deletions src/ert/run_models/everest_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,8 @@
from seba_sqlite import SqliteStorage
from typing_extensions import TypedDict

from _ert.events import EESnapshot, EESnapshotUpdate, Event
from ert.config import ErtConfig, ExtParamConfig
from ert.ensemble_evaluator import EnsembleSnapshot, EvaluatorServerConfig
from ert.ensemble_evaluator import EvaluatorServerConfig
from ert.runpaths import Runpaths
from ert.storage import open_storage
from everest.config import ControlConfig, ControlVariableGuessListConfig, EverestConfig
Expand Down Expand Up @@ -102,12 +101,12 @@ def __init__(
self,
config: ErtConfig,
everest_config: EverestConfig,
simulation_callback: SimulationCallback | None,
optimization_callback: OptimizerCallback | None,
status_queue: queue.SimpleQueue[StatusEvents] | None = None,
):
Path(everest_config.log_dir).mkdir(parents=True, exist_ok=True)
Path(everest_config.optimization_output_dir).mkdir(parents=True, exist_ok=True)

status_queue = queue.SimpleQueue() if status_queue is None else status_queue
assert everest_config.environment is not None
logging.getLogger(EVEREST).info(
"Using random seed: %d. To deterministically reproduce this experiment, "
Expand All @@ -118,7 +117,6 @@ def __init__(
self._everest_config = everest_config
self._ropt_config = everest2ropt(everest_config)

self._sim_callback = simulation_callback
self._opt_callback = optimization_callback
self._fm_errors: dict[int, dict[str, Any]] = {}
self._result: OptimalResult | None = None
Expand All @@ -134,11 +132,8 @@ def __init__(
self._experiment: Experiment | None = None
self._eval_server_cfg: EvaluatorServerConfig | None = None
self._batch_id: int = 0
self._status: SimulationStatus | None = None

storage = open_storage(config.ens_path, mode="w")
status_queue: queue.SimpleQueue[StatusEvents] = queue.SimpleQueue()

super().__init__(
storage,
config.runpath_file,
Expand All @@ -165,12 +160,13 @@ def create(
ever_config: EverestConfig,
simulation_callback: SimulationCallback | None = None,
optimization_callback: OptimizerCallback | None = None,
status_queue: queue.SimpleQueue[StatusEvents] | None = None,
) -> EverestRunModel:
return cls(
config=everest_to_ert_config(ever_config),
everest_config=ever_config,
simulation_callback=simulation_callback,
optimization_callback=optimization_callback,
status_queue=status_queue,
)

@classmethod
Expand Down Expand Up @@ -359,9 +355,6 @@ def _on_before_forward_model_evaluation(
def _forward_model_evaluator(
self, control_values: NDArray[np.float64], evaluator_context: EvaluatorContext
) -> EvaluatorResult:
# Reset the current run status:
self._status = None

# Get cached_results:
cached_results = self._get_cached_results(control_values, evaluator_context)

Expand Down Expand Up @@ -527,7 +520,7 @@ def _get_run_args(
batch_data: dict[int, Any],
) -> list[RunArg]:
substitutions = self._substitutions
substitutions["<BATCH_NAME>"] = ensemble.name
substitutions["<BATCH_NAME>"] = ensemble.name # Dark magic, should be fixed
self.active_realizations = [True] * len(batch_data)
for sim_id, control_idx in enumerate(batch_data.keys()):
substitutions[f"<GEO_ID_{sim_id}_{ensemble.iteration}>"] = str(
Expand Down Expand Up @@ -672,82 +665,6 @@ def check_if_runpath_exists(self) -> bool:
and any(os.listdir(self._everest_config.simulation_dir))
)

def send_snapshot_event(self, event: Event, iteration: int) -> None:
super().send_snapshot_event(event, iteration)
if type(event) in {EESnapshot, EESnapshotUpdate}:
newstatus = self._simulation_status(self.get_current_snapshot())
if self._status != newstatus: # No change in status
if self._sim_callback is not None:
self._sim_callback(newstatus)
self._status = newstatus

def _simulation_status(self, snapshot: EnsembleSnapshot) -> SimulationStatus:
jobs_progress: list[list[JobProgress]] = []
prev_realization = None
jobs: list[JobProgress] = []
for (realization, simulation), fm_step in snapshot.get_all_fm_steps().items():
if realization != prev_realization:
prev_realization = realization
if jobs:
jobs_progress.append(jobs)
jobs = []
jobs.append(
{
"name": fm_step.get("name") or "Unknown",
"status": fm_step.get("status") or "Unknown",
"error": fm_step.get("error", ""),
"start_time": fm_step.get("start_time", None),
"end_time": fm_step.get("end_time", None),
"realization": realization,
"simulation": simulation,
}
)
if fm_step.get("error", ""):
self._handle_errors(
batch=self._batch_id,
simulation=simulation,
realization=realization,
fm_name=fm_step.get("name", "Unknown"), # type: ignore
error_path=fm_step.get("stderr", ""), # type: ignore
fm_running_err=fm_step.get("error", ""), # type: ignore
)
jobs_progress.append(jobs)

return {
"status": self.get_current_status(),
"progress": jobs_progress,
"batch_number": self._batch_id,
}

def _handle_errors(
self,
batch: int,
simulation: Any,
realization: str,
fm_name: str,
error_path: str,
fm_running_err: str,
) -> None:
fm_id = f"b_{batch}_r_{realization}_s_{simulation}_{fm_name}"
fm_logger = logging.getLogger("forward_models")
if Path(error_path).is_file():
error_str = Path(error_path).read_text(encoding="utf-8") or fm_running_err
else:
error_str = fm_running_err
error_hash = hash(error_str)
err_msg = "Batch: {} Realization: {} Simulation: {} Job: {} Failed {}".format(
batch, realization, simulation, fm_name, "\n Error: {} ID:{}"
)

if error_hash not in self._fm_errors:
error_id = len(self._fm_errors)
fm_logger.error(err_msg.format(error_str, error_id))
self._fm_errors.update({error_hash: {"error_id": error_id, "ids": [fm_id]}})
elif fm_id not in self._fm_errors[error_hash]["ids"]:
self._fm_errors[error_hash]["ids"].append(fm_id)
error_id = self._fm_errors[error_hash]["error_id"]
fm_logger.error(err_msg.format("Already reported as", error_id))


class SimulatorCache:
EPS = float(np.finfo(np.float32).eps)
Expand Down
110 changes: 59 additions & 51 deletions src/everest/bin/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,23 @@
from colorama import Fore
from pandas import DataFrame

from ert.ensemble_evaluator import (
EnsembleSnapshot,
FullSnapshotEvent,
SnapshotUpdateEvent,
)
from ert.resources import all_shell_script_fm_steps
from everest.config import EverestConfig, ExportConfig
from everest.detached import (
OPT_PROGRESS_ID,
SIM_PROGRESS_ID,
ServerStatus,
everserver_status,
get_opt_status,
start_monitor,
)
from everest.export import export_data
from everest.simulator import JOB_FAILURE, JOB_RUNNING, JOB_SUCCESS
from everest.strings import EVEREST
from everest.strings import EVEREST, SIM_PROGRESS_ID

try:
from progressbar import AdaptiveETA, Bar, Percentage, ProgressBar, Timer
Expand Down Expand Up @@ -147,6 +151,7 @@ def __init__(self, show_all_jobs):
self._batches_done = set()
self._last_reported_batch = -1
colorama.init(autoreset=True)
self._snapshots: dict[int, EnsembleSnapshot] = {}

def update(self, status):
try:
Expand All @@ -162,16 +167,37 @@ def update(self, status):
print(msg + "\n")
self._clear_lines = 0
if SIM_PROGRESS_ID in status:
sim_progress = status[SIM_PROGRESS_ID]
sim_progress["progress"] = self._filter_jobs(sim_progress["progress"])
msg, batch = self.get_fm_progress(sim_progress)
if msg.strip():
# Clear the previous report if it is still the same batch:
if batch == self._last_reported_batch:
self._clear()
print(msg)
self._clear_lines = len(msg.split("\n"))
self._last_reported_batch = batch
match status[SIM_PROGRESS_ID]:
case FullSnapshotEvent(snapshot=snapshot, iteration=batch):
if snapshot is not None:
self._snapshots[batch] = snapshot
case (
SnapshotUpdateEvent(snapshot=snapshot, iteration=batch) as event
):
if snapshot is not None:
self._snapshots[event.iteration].merge_snapshot(snapshot)
batch_number = event.iteration
header = self._make_header(
f"Running forward models (Batch #{batch_number})",
Fore.BLUE,
)
summary = self._get_progress_summary(event.status_count)
job_states = self._get_job_states(
self._snapshots[batch_number], self._show_all_jobs
)
msg = (
self._join_two_newlines_indent(
(header, summary, job_states)
)
+ "\n"
)
if batch == self._last_reported_batch:
self._clear()
print(msg)
self._clear_lines = len(msg.split("\n"))
self._last_reported_batch = max(
self._last_reported_batch, batch
)
except:
logging.getLogger(EVEREST).debug(traceback.format_exc())

Expand Down Expand Up @@ -209,36 +235,28 @@ def _get_opt_progress_batch(self, cli_monitor_data, batch, idx):
(header, controls, objectives, total_objective)
)

def get_fm_progress(self, context_status):
batch_number = int(context_status["batch_number"])
header = self._make_header(
f"Running forward models (Batch #{batch_number})", Fore.BLUE
)
summary = self._get_progress_summary(context_status["status"])
job_states = self._get_job_states(context_status["progress"])
msg = self._join_two_newlines_indent((header, summary, job_states)) + "\n"
return msg, batch_number

@staticmethod
def _get_progress_summary(status):
colors = [
Fore.BLACK,
Fore.BLACK,
Fore.BLUE if status["running"] > 0 else Fore.BLACK,
Fore.GREEN if status["complete"] > 0 else Fore.BLACK,
Fore.RED if status["failed"] > 0 else Fore.BLACK,
Fore.BLUE if status.get("Running", 0) > 0 else Fore.BLACK,
Fore.GREEN if status.get("Finished", 0) > 0 else Fore.BLACK,
Fore.RED if status.get("Failed", 0) > 0 else Fore.BLACK,
]
labels = ("Waiting", "Pending", "Running", "Complete", "FAILED")
values = [status.get(ls.lower(), 0) for ls in labels]
labels = ("Waiting", "Pending", "Running", "Finished", "Failed")
values = [status.get(ls, 0) for ls in labels]
return " | ".join(
f"{color}{key}: {value}{Fore.RESET}"
for color, key, value in zip(colors, labels, values, strict=False)
)

@classmethod
def _get_job_states(cls, progress):
def _get_job_states(cls, snapshot: EnsembleSnapshot, show_all_jobs: bool):
print_lines = ""
jobs_status = cls._get_jobs_status(progress)
jobs_status = cls._get_jobs_status(snapshot)
if not show_all_jobs:
jobs_status = cls._filter_jobs(jobs_status)
if jobs_status:
max_widths = {
state: _get_max_width(
Expand All @@ -254,29 +272,19 @@ def _get_job_states(cls, progress):
return print_lines

@staticmethod
def _get_jobs_status(progress):
def _get_jobs_status(snapshot: EnsembleSnapshot) -> list[JobProgress]:
job_progress = {}
for queue in progress:
for job_idx, job in enumerate(queue):
if job_idx not in job_progress:
job_progress[job_idx] = JobProgress(name=job["name"])
realization = int(job["realization"])
status = job["status"]
if status in {JOB_RUNNING, JOB_SUCCESS, JOB_FAILURE}:
job_progress[job_idx].status[status].append(realization)
return job_progress.values()

def _filter_jobs(self, progress):
if not self._show_all_jobs:
progress = [
[
job
for job in progress_list
if job["name"] not in all_shell_script_fm_steps
]
for progress_list in progress
]
return progress
for (realization, job_idx), job in snapshot.get_all_fm_steps().items():
if job_idx not in job_progress:
job_progress[job_idx] = JobProgress(name=job["name"])
status = job["status"]
if status in {JOB_RUNNING, JOB_SUCCESS, JOB_FAILURE}:
job_progress[job_idx].status[status].append(int(realization))
return list(job_progress.values())

@staticmethod
def _filter_jobs(jobs: list[JobProgress]):
return [job for job in jobs if job.name not in all_shell_script_fm_steps]

@classmethod
def _join_one_newline_indent(cls, sequence):
Expand Down
Loading

0 comments on commit b446199

Please sign in to comment.