From e0290aa432491072bbc4a892aa99ffb4e9eee4ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=98yvind=20Eide?= Date: Fri, 24 Jan 2025 14:53:48 +0100 Subject: [PATCH 1/9] Fix bug where iteration was not set --- src/ert/run_models/everest_run_model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/ert/run_models/everest_run_model.py b/src/ert/run_models/everest_run_model.py index 810912b407b..3c222e2f994 100644 --- a/src/ert/run_models/everest_run_model.py +++ b/src/ert/run_models/everest_run_model.py @@ -269,6 +269,7 @@ def _forward_model_evaluator( ensemble = self._experiment.create_ensemble( name=f"batch_{self._batch_id}", ensemble_size=len(batch_data), + iteration=self._batch_id, ) for sim_id, controls in enumerate(batch_data.values()): self._setup_sim(sim_id, controls, ensemble) @@ -423,7 +424,7 @@ def _get_run_args( substitutions[""] = ensemble.name self.active_realizations = [True] * len(batch_data) for sim_id, control_idx in enumerate(batch_data.keys()): - substitutions[f""] = str( + substitutions[f""] = str( self._everest_config.model.realizations[ evaluator_context.realizations[control_idx] ] From 3bbcbbb49e6ee105cc712c9ee75311e62da9b8dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=98yvind=20Eide?= Date: Thu, 30 Jan 2025 10:43:31 +0100 Subject: [PATCH 2/9] Move function --- src/ert/run_models/everest_run_model.py | 35 +++++++++++++++++++- src/everest/detached/jobs/everserver.py | 43 ++++--------------------- 2 files changed, 40 insertions(+), 38 deletions(-) diff --git a/src/ert/run_models/everest_run_model.py b/src/ert/run_models/everest_run_model.py index 3c222e2f994..188a22669e5 100644 --- a/src/ert/run_models/everest_run_model.py +++ b/src/ert/run_models/everest_run_model.py @@ -30,10 +30,12 @@ from ert.storage import open_storage from everest.config import ControlConfig, ControlVariableGuessListConfig, EverestConfig from everest.everest_storage import EverestStorage, OptimalResult +from everest.detached import ServerStatus +from everest.detached.jobs.everserver import ExperimentStatus from everest.optimizer.everest2ropt import everest2ropt from everest.optimizer.opt_model_transforms import get_opt_model_transforms from everest.simulator.everest_to_ert import everest_to_ert_config -from everest.strings import EVEREST +from everest.strings import EVEREST, OPT_FAILURE_REALIZATIONS, STOP_ENDPOINT from ..run_arg import RunArg, create_run_arguments from .base_run_model import BaseRunModel, StatusEvents @@ -692,3 +694,34 @@ def get( if np.allclose(controls, control_values, rtol=0.0, atol=self.EPS): return objectives, constraints return None + + +def _get_optimization_status( + experiment_status: ExperimentStatus, shared_data: dict +) -> tuple[ServerStatus, str]: + match experiment_status.exit_code: + case EverestExitCode.MAX_BATCH_NUM_REACHED: + return ServerStatus.completed, "Maximum number of batches reached." + + case EverestExitCode.MAX_FUNCTIONS_REACHED: + return ( + ServerStatus.completed, + "Maximum number of function evaluations reached.", + ) + + case EverestExitCode.USER_ABORT: + return ServerStatus.stopped, "Optimization aborted." + + case EverestExitCode.EXCEPTION: + assert experiment_status.message is not None + return ServerStatus.failed, experiment_status.message + + case EverestExitCode.TOO_FEW_REALIZATIONS: + status = ( + ServerStatus.stopped + if shared_data[STOP_ENDPOINT] + else ServerStatus.failed + ) + return status, OPT_FAILURE_REALIZATIONS + case _: + return ServerStatus.completed, "Optimization completed." 
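For reference, a minimal sketch (not part of the patch) of how the relocated helper is exercised after this commit. It assumes the EverestExitCode, ServerStatus, ExperimentStatus, and STOP_ENDPOINT imports shown in the hunks above; at this point in the series the function still takes the full ExperimentStatus plus the shared_data dict:

# Illustrative only, not applied by "git am": exercise the helper as moved
# into everest_run_model.py by this patch.
from ert.run_models.everest_run_model import (
    EverestExitCode,
    _get_optimization_status,
)
from everest.detached import ServerStatus
from everest.detached.jobs.everserver import ExperimentStatus
from everest.strings import STOP_ENDPOINT

# An optimization that hit the batch limit maps to "completed"; shared_data
# is only consulted for the TOO_FEW_REALIZATIONS branch.
status, message = _get_optimization_status(
    ExperimentStatus(exit_code=EverestExitCode.MAX_BATCH_NUM_REACHED),
    shared_data={STOP_ENDPOINT: False},
)
assert status == ServerStatus.completed
assert message == "Maximum number of batches reached."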
diff --git a/src/everest/detached/jobs/everserver.py b/src/everest/detached/jobs/everserver.py index 64bef7b14ca..2bd3a66fd20 100755 --- a/src/everest/detached/jobs/everserver.py +++ b/src/everest/detached/jobs/everserver.py @@ -36,7 +36,11 @@ from ert.config.parsing.queue_system import QueueSystem from ert.ensemble_evaluator import EvaluatorServerConfig -from ert.run_models.everest_run_model import EverestExitCode, EverestRunModel +from ert.run_models.everest_run_model import ( + EverestExitCode, + EverestRunModel, + _get_optimization_status, +) from everest.config import EverestConfig, ServerConfig from everest.detached import ( PROXY, @@ -51,12 +55,11 @@ DEFAULT_LOGGING_FORMAT, EVEREST, EXPERIMENT_STATUS_ENDPOINT, - OPT_FAILURE_REALIZATIONS, OPT_PROGRESS_ENDPOINT, SHARED_DATA_ENDPOINT, SIM_PROGRESS_ENDPOINT, START_EXPERIMENT_ENDPOINT, - STOP_ENDPOINT, + STOP_ENDPOINT, OPT_FAILURE_REALIZATIONS, ) from everest.util import makedirs_if_needed, version_info @@ -466,40 +469,6 @@ def main(): update_everserver_status(status_path, ServerStatus.completed, message=message) -def _get_optimization_status( - experiment_status: ExperimentStatus, shared_data: dict -) -> tuple[ServerStatus, str]: - match experiment_status.exit_code: - case EverestExitCode.MAX_BATCH_NUM_REACHED: - return ServerStatus.completed, "Maximum number of batches reached." - - case EverestExitCode.MAX_FUNCTIONS_REACHED: - return ( - ServerStatus.completed, - "Maximum number of function evaluations reached.", - ) - - case EverestExitCode.USER_ABORT: - return ServerStatus.stopped, "Optimization aborted." - - case EverestExitCode.EXCEPTION: - assert experiment_status.message is not None - return ServerStatus.failed, experiment_status.message - - case EverestExitCode.TOO_FEW_REALIZATIONS: - status = ( - ServerStatus.stopped - if shared_data[STOP_ENDPOINT] - else ServerStatus.failed - ) - messages = _failed_realizations_messages(shared_data) - for msg in messages: - logging.getLogger(EVEREST).error(msg) - return status, "\n".join(messages) - case _: - return ServerStatus.completed, "Optimization completed." 
- - def _failed_realizations_messages(shared_data): messages = [OPT_FAILURE_REALIZATIONS] failed = shared_data[SIM_PROGRESS_ENDPOINT]["status"]["failed"] From 380310a80e8373b5c5f798f0ff5214bbf4bd1f6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=98yvind=20Eide?= Date: Thu, 30 Jan 2025 11:46:59 +0100 Subject: [PATCH 3/9] Move variable --- src/ert/run_models/everest_run_model.py | 17 ++++++----------- src/everest/detached/jobs/everserver.py | 7 ++++++- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/ert/run_models/everest_run_model.py b/src/ert/run_models/everest_run_model.py index 188a22669e5..6443ee346c2 100644 --- a/src/ert/run_models/everest_run_model.py +++ b/src/ert/run_models/everest_run_model.py @@ -31,11 +31,10 @@ from everest.config import ControlConfig, ControlVariableGuessListConfig, EverestConfig from everest.everest_storage import EverestStorage, OptimalResult from everest.detached import ServerStatus -from everest.detached.jobs.everserver import ExperimentStatus from everest.optimizer.everest2ropt import everest2ropt from everest.optimizer.opt_model_transforms import get_opt_model_transforms from everest.simulator.everest_to_ert import everest_to_ert_config -from everest.strings import EVEREST, OPT_FAILURE_REALIZATIONS, STOP_ENDPOINT +from everest.strings import EVEREST, OPT_FAILURE_REALIZATIONS from ..run_arg import RunArg, create_run_arguments from .base_run_model import BaseRunModel, StatusEvents @@ -697,9 +696,9 @@ def get( def _get_optimization_status( - experiment_status: ExperimentStatus, shared_data: dict + exit_code: int, exception: str, stopped: bool ) -> tuple[ServerStatus, str]: - match experiment_status.exit_code: + match exit_code: case EverestExitCode.MAX_BATCH_NUM_REACHED: return ServerStatus.completed, "Maximum number of batches reached." @@ -713,15 +712,11 @@ def _get_optimization_status( return ServerStatus.stopped, "Optimization aborted." case EverestExitCode.EXCEPTION: - assert experiment_status.message is not None - return ServerStatus.failed, experiment_status.message + assert exception is not None + return ServerStatus.failed, exception case EverestExitCode.TOO_FEW_REALIZATIONS: - status = ( - ServerStatus.stopped - if shared_data[STOP_ENDPOINT] - else ServerStatus.failed - ) + status = ServerStatus.stopped if stopped else ServerStatus.failed return status, OPT_FAILURE_REALIZATIONS case _: return ServerStatus.completed, "Optimization completed." 
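A minimal sketch (again, not part of the patch) of the decoupled signature introduced here: the helper now takes plain values, and the everserver.py call site in the next hunk unpacks ExperimentStatus itself before calling it.

# Illustrative only: after this patch the helper no longer depends on
# ExperimentStatus or the shared_data dict.
from ert.run_models.everest_run_model import (
    EverestExitCode,
    _get_optimization_status,
)
from everest.detached import ServerStatus

# Too few realizations counts as user-stopped when the stop flag was raised,
# and as a failure otherwise; both return OPT_FAILURE_REALIZATIONS as message.
status, msg = _get_optimization_status(
    EverestExitCode.TOO_FEW_REALIZATIONS, exception="", stopped=True
)
assert status == ServerStatus.stopped

status, msg = _get_optimization_status(
    EverestExitCode.TOO_FEW_REALIZATIONS, exception="", stopped=False
)
assert status == ServerStatus.failed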
diff --git a/src/everest/detached/jobs/everserver.py b/src/everest/detached/jobs/everserver.py index 2bd3a66fd20..1ad1d0ef20e 100755 --- a/src/everest/detached/jobs/everserver.py +++ b/src/everest/detached/jobs/everserver.py @@ -447,7 +447,12 @@ def main(): shared_data = json_body assert experiment_status is not None - status, message = _get_optimization_status(experiment_status, shared_data) + status, message = _get_optimization_status( + experiment_status.exit_code, + experiment_status.message, + shared_data[STOP_ENDPOINT], + ) + if status != ServerStatus.completed: update_everserver_status(status_path, status, message) return From 5fdd7730e793ac725949c150f7cf218b444f3443 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=98yvind=20Eide?= Date: Fri, 31 Jan 2025 16:57:49 +0100 Subject: [PATCH 4/9] Fix typo --- src/everest/detached/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/everest/detached/__init__.py b/src/everest/detached/__init__.py index b9e6104a8df..cab025e6f34 100644 --- a/src/everest/detached/__init__.py +++ b/src/everest/detached/__init__.py @@ -197,7 +197,7 @@ def get_opt_status(output_folder): def wait_for_server_to_stop(server_context: tuple[str, str, tuple[str, str]], timeout): """ - Checks everest server has stoped _HTTP_REQUEST_RETRY times. Waits + Checks everest server has stopped _HTTP_REQUEST_RETRY times. Waits progressively longer between each check. Raise an exception when the timeout is reached. From 20263a6db36961182f6c4ab536db505273cec6dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=98yvind=20Eide?= Date: Thu, 9 Jan 2025 10:25:26 +0100 Subject: [PATCH 5/9] Create end point for events --- pyproject.toml | 1 + src/ert/run_models/everest_run_model.py | 103 +------ src/everest/bin/utils.py | 110 ++++---- src/everest/detached/__init__.py | 66 +++-- src/everest/detached/jobs/everserver.py | 267 +++++++++--------- src/everest/strings.py | 1 - .../entry_points/test_everest_entry.py | 200 ++----------- tests/everest/test_everserver.py | 162 +++-------- tests/everest/test_logging.py | 31 +- tests/everest/test_monitor.py | 137 +++++++++ 10 files changed, 454 insertions(+), 624 deletions(-) create mode 100644 tests/everest/test_monitor.py diff --git a/pyproject.toml b/pyproject.toml index cc5b4be1e99..59e2d8a5f00 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -132,6 +132,7 @@ types = [ "types-setuptools", ] everest = [ + "websockets", "progressbar2", "ruamel.yaml", "fastapi", diff --git a/src/ert/run_models/everest_run_model.py b/src/ert/run_models/everest_run_model.py index 6443ee346c2..9359195be5a 100644 --- a/src/ert/run_models/everest_run_model.py +++ b/src/ert/run_models/everest_run_model.py @@ -23,14 +23,13 @@ from ropt.plan import Event as OptimizerEvent from typing_extensions import TypedDict -from _ert.events import EESnapshot, EESnapshotUpdate, Event from ert.config import ErtConfig, ExtParamConfig -from ert.ensemble_evaluator import EnsembleSnapshot, EvaluatorServerConfig +from ert.ensemble_evaluator import EndEvent, EvaluatorServerConfig from ert.runpaths import Runpaths from ert.storage import open_storage from everest.config import ControlConfig, ControlVariableGuessListConfig, EverestConfig -from everest.everest_storage import EverestStorage, OptimalResult from everest.detached import ServerStatus +from everest.everest_storage import EverestStorage, OptimalResult from everest.optimizer.everest2ropt import everest2ropt from everest.optimizer.opt_model_transforms import get_opt_model_transforms from 
everest.simulator.everest_to_ert import everest_to_ert_config @@ -84,12 +83,12 @@ def __init__( self, config: ErtConfig, everest_config: EverestConfig, - simulation_callback: SimulationCallback | None, optimization_callback: OptimizerCallback | None, + status_queue: queue.SimpleQueue[StatusEvents] | None = None, ): Path(everest_config.log_dir).mkdir(parents=True, exist_ok=True) Path(everest_config.optimization_output_dir).mkdir(parents=True, exist_ok=True) - + status_queue = queue.SimpleQueue() if status_queue is None else status_queue assert everest_config.environment is not None logging.getLogger(EVEREST).info( "Using random seed: %d. To deterministically reproduce this experiment, " @@ -103,7 +102,6 @@ def __init__( everest_config, transforms=self._opt_model_transforms ) - self._sim_callback = simulation_callback self._opt_callback = optimization_callback self._fm_errors: dict[int, dict[str, Any]] = {} self._result: OptimalResult | None = None @@ -119,11 +117,8 @@ def __init__( self._experiment: Experiment | None = None self._eval_server_cfg: EvaluatorServerConfig | None = None self._batch_id: int = 0 - self._status: SimulationStatus | None = None storage = open_storage(config.ens_path, mode="w") - status_queue: queue.SimpleQueue[StatusEvents] = queue.SimpleQueue() - super().__init__( storage, config.runpath_file, @@ -150,12 +145,13 @@ def create( ever_config: EverestConfig, simulation_callback: SimulationCallback | None = None, optimization_callback: OptimizerCallback | None = None, + status_queue: queue.SimpleQueue[StatusEvents] | None = None, ) -> EverestRunModel: return cls( config=everest_to_ert_config(ever_config), everest_config=ever_config, - simulation_callback=simulation_callback, optimization_callback=optimization_callback, + status_queue=status_queue, ) @classmethod @@ -214,6 +210,12 @@ def run_experiment( self._exit_code = EverestExitCode.TOO_FEW_REALIZATIONS case _: self._exit_code = EverestExitCode.COMPLETED + self.send_event( + EndEvent( + failed=self._exit_code != EverestExitCode.COMPLETED, + msg=_get_optimization_status(self._exit_code, "", False)[1], + ) + ) def _create_optimizer(self) -> BasicOptimizer: optimizer = BasicOptimizer( @@ -254,9 +256,6 @@ def _on_before_forward_model_evaluation( def _forward_model_evaluator( self, control_values: NDArray[np.float64], evaluator_context: EvaluatorContext ) -> EvaluatorResult: - # Reset the current run status: - self._status = None - # Get cached_results: cached_results = self._get_cached_results(control_values, evaluator_context) @@ -422,7 +421,7 @@ def _get_run_args( batch_data: dict[int, Any], ) -> list[RunArg]: substitutions = self._substitutions - substitutions[""] = ensemble.name + substitutions[""] = ensemble.name # Dark magic, should be fixed self.active_realizations = [True] * len(batch_data) for sim_id, control_idx in enumerate(batch_data.keys()): substitutions[f""] = str( @@ -567,82 +566,6 @@ def check_if_runpath_exists(self) -> bool: and any(os.listdir(self._everest_config.simulation_dir)) ) - def send_snapshot_event(self, event: Event, iteration: int) -> None: - super().send_snapshot_event(event, iteration) - if type(event) in {EESnapshot, EESnapshotUpdate}: - newstatus = self._simulation_status(self.get_current_snapshot()) - if self._status != newstatus: # No change in status - if self._sim_callback is not None: - self._sim_callback(newstatus) - self._status = newstatus - - def _simulation_status(self, snapshot: EnsembleSnapshot) -> SimulationStatus: - jobs_progress: list[list[JobProgress]] = [] - 
prev_realization = None - jobs: list[JobProgress] = [] - for (realization, simulation), fm_step in snapshot.get_all_fm_steps().items(): - if realization != prev_realization: - prev_realization = realization - if jobs: - jobs_progress.append(jobs) - jobs = [] - jobs.append( - { - "name": fm_step.get("name") or "Unknown", - "status": fm_step.get("status") or "Unknown", - "error": fm_step.get("error", ""), - "start_time": fm_step.get("start_time", None), - "end_time": fm_step.get("end_time", None), - "realization": realization, - "simulation": simulation, - } - ) - if fm_step.get("error", ""): - self._handle_errors( - batch=self._batch_id, - simulation=simulation, - realization=realization, - fm_name=fm_step.get("name", "Unknown"), # type: ignore - error_path=fm_step.get("stderr", ""), # type: ignore - fm_running_err=fm_step.get("error", ""), # type: ignore - ) - jobs_progress.append(jobs) - - return { - "status": self.get_current_status(), - "progress": jobs_progress, - "batch_number": self._batch_id, - } - - def _handle_errors( - self, - batch: int, - simulation: Any, - realization: str, - fm_name: str, - error_path: str, - fm_running_err: str, - ) -> None: - fm_id = f"b_{batch}_r_{realization}_s_{simulation}_{fm_name}" - fm_logger = logging.getLogger("forward_models") - if Path(error_path).is_file(): - error_str = Path(error_path).read_text(encoding="utf-8") or fm_running_err - else: - error_str = fm_running_err - error_hash = hash(error_str) - err_msg = "Batch: {} Realization: {} Simulation: {} Job: {} Failed {}".format( - batch, realization, simulation, fm_name, "\n Error: {} ID:{}" - ) - - if error_hash not in self._fm_errors: - error_id = len(self._fm_errors) - fm_logger.error(err_msg.format(error_str, error_id)) - self._fm_errors.update({error_hash: {"error_id": error_id, "ids": [fm_id]}}) - elif fm_id not in self._fm_errors[error_hash]["ids"]: - self._fm_errors[error_hash]["ids"].append(fm_id) - error_id = self._fm_errors[error_hash]["error_id"] - fm_logger.error(err_msg.format("Already reported as", error_id)) - class SimulatorCache: EPS = float(np.finfo(np.float32).eps) diff --git a/src/everest/bin/utils.py b/src/everest/bin/utils.py index ef73ce988ec..c1ecab8d442 100644 --- a/src/everest/bin/utils.py +++ b/src/everest/bin/utils.py @@ -9,17 +9,21 @@ import colorama from colorama import Fore +from ert.ensemble_evaluator import ( + EnsembleSnapshot, + FullSnapshotEvent, + SnapshotUpdateEvent, +) from ert.resources import all_shell_script_fm_steps from everest.detached import ( OPT_PROGRESS_ID, - SIM_PROGRESS_ID, ServerStatus, everserver_status, get_opt_status, start_monitor, ) from everest.simulator import JOB_FAILURE, JOB_RUNNING, JOB_SUCCESS -from everest.strings import EVEREST +from everest.strings import EVEREST, SIM_PROGRESS_ID def handle_keyboard_interrupt(signal, frame, options): @@ -111,6 +115,7 @@ def __init__(self, show_all_jobs): self._batches_done = set() self._last_reported_batch = -1 colorama.init(autoreset=True) + self._snapshots: dict[int, EnsembleSnapshot] = {} def update(self, status): try: @@ -126,16 +131,37 @@ def update(self, status): print(msg + "\n") self._clear_lines = 0 if SIM_PROGRESS_ID in status: - sim_progress = status[SIM_PROGRESS_ID] - sim_progress["progress"] = self._filter_jobs(sim_progress["progress"]) - msg, batch = self.get_fm_progress(sim_progress) - if msg.strip(): - # Clear the previous report if it is still the same batch: - if batch == self._last_reported_batch: - self._clear() - print(msg) - self._clear_lines = len(msg.split("\n")) - 
self._last_reported_batch = batch + match status[SIM_PROGRESS_ID]: + case FullSnapshotEvent(snapshot=snapshot, iteration=batch): + if snapshot is not None: + self._snapshots[batch] = snapshot + case ( + SnapshotUpdateEvent(snapshot=snapshot, iteration=batch) as event + ): + if snapshot is not None: + batch_number = event.iteration + self._snapshots[batch_number].merge_snapshot(snapshot) + header = self._make_header( + f"Running forward models (Batch #{batch_number})", + Fore.BLUE, + ) + summary = self._get_progress_summary(event.status_count) + job_states = self._get_job_states( + self._snapshots[batch_number], self._show_all_jobs + ) + msg = ( + self._join_two_newlines_indent( + (header, summary, job_states) + ) + + "\n" + ) + if batch == self._last_reported_batch: + self._clear() + print(msg) + self._clear_lines = len(msg.split("\n")) + self._last_reported_batch = max( + self._last_reported_batch, batch + ) except: logging.getLogger(EVEREST).debug(traceback.format_exc()) @@ -173,36 +199,28 @@ def _get_opt_progress_batch(self, cli_monitor_data, batch, idx): (header, controls, objectives, total_objective) ) - def get_fm_progress(self, context_status): - batch_number = int(context_status["batch_number"]) - header = self._make_header( - f"Running forward models (Batch #{batch_number})", Fore.BLUE - ) - summary = self._get_progress_summary(context_status["status"]) - job_states = self._get_job_states(context_status["progress"]) - msg = self._join_two_newlines_indent((header, summary, job_states)) + "\n" - return msg, batch_number - @staticmethod def _get_progress_summary(status): colors = [ Fore.BLACK, Fore.BLACK, - Fore.BLUE if status["running"] > 0 else Fore.BLACK, - Fore.GREEN if status["complete"] > 0 else Fore.BLACK, - Fore.RED if status["failed"] > 0 else Fore.BLACK, + Fore.BLUE if status.get("Running", 0) > 0 else Fore.BLACK, + Fore.GREEN if status.get("Finished", 0) > 0 else Fore.BLACK, + Fore.RED if status.get("Failed", 0) > 0 else Fore.BLACK, ] - labels = ("Waiting", "Pending", "Running", "Complete", "FAILED") - values = [status.get(ls.lower(), 0) for ls in labels] + labels = ("Waiting", "Pending", "Running", "Finished", "Failed") + values = [status.get(ls, 0) for ls in labels] return " | ".join( f"{color}{key}: {value}{Fore.RESET}" for color, key, value in zip(colors, labels, values, strict=False) ) @classmethod - def _get_job_states(cls, progress): + def _get_job_states(cls, snapshot: EnsembleSnapshot, show_all_jobs: bool): print_lines = "" - jobs_status = cls._get_jobs_status(progress) + jobs_status = cls._get_jobs_status(snapshot) + if not show_all_jobs: + jobs_status = cls._filter_jobs(jobs_status) if jobs_status: max_widths = { state: _get_max_width( @@ -218,29 +236,19 @@ def _get_job_states(cls, progress): return print_lines @staticmethod - def _get_jobs_status(progress): + def _get_jobs_status(snapshot: EnsembleSnapshot) -> list[JobProgress]: job_progress = {} - for queue in progress: - for job_idx, job in enumerate(queue): - if job_idx not in job_progress: - job_progress[job_idx] = JobProgress(name=job["name"]) - realization = int(job["realization"]) - status = job["status"] - if status in {JOB_RUNNING, JOB_SUCCESS, JOB_FAILURE}: - job_progress[job_idx].status[status].append(realization) - return job_progress.values() - - def _filter_jobs(self, progress): - if not self._show_all_jobs: - progress = [ - [ - job - for job in progress_list - if job["name"] not in all_shell_script_fm_steps - ] - for progress_list in progress - ] - return progress + for (realization, job_idx), 
job in snapshot.get_all_fm_steps().items(): + if job_idx not in job_progress: + job_progress[job_idx] = JobProgress(name=job["name"]) + status = job["status"] + if status in {JOB_RUNNING, JOB_SUCCESS, JOB_FAILURE}: + job_progress[job_idx].status[status].append(int(realization)) + return list(job_progress.values()) + + @staticmethod + def _filter_jobs(jobs: list[JobProgress]): + return [job for job in jobs if job.name not in all_shell_script_fm_steps] @classmethod def _join_one_newline_indent(cls, sequence): diff --git a/src/everest/detached/__init__.py b/src/everest/detached/__init__.py index cab025e6f34..d8ace4c12f8 100644 --- a/src/everest/detached/__init__.py +++ b/src/everest/detached/__init__.py @@ -4,16 +4,21 @@ import logging import os import re +import ssl import time import traceback from collections.abc import Mapping from enum import Enum from pathlib import Path -from typing import Literal +from typing import Annotated, Literal import polars as pl import requests +from pydantic import BaseModel, ConfigDict, Field, ValidationError +from websockets.sync.client import connect +from ert.ensemble_evaluator import EndEvent, EnsembleSnapshot +from ert.run_models import StatusEvents from ert.scheduler import create_driver from ert.scheduler.driver import Driver, FailedSubmit from ert.scheduler.event import StartedEvent @@ -23,7 +28,6 @@ EVEREST_SERVER_CONFIG, OPT_PROGRESS_ENDPOINT, OPT_PROGRESS_ID, - SIM_PROGRESS_ENDPOINT, SIM_PROGRESS_ID, START_EXPERIMENT_ENDPOINT, STOP_ENDPOINT, @@ -42,6 +46,14 @@ # everest.log file instead +logger = logging.getLogger(__name__) + + +class EventWrapper(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + event: Annotated[StatusEvents, Field(discriminator="event_type")] + + async def start_server(config: EverestConfig, debug: bool = False) -> Driver: """ Start an Everest server running the optimization defined in the config @@ -242,29 +254,45 @@ def start_monitor( interrupted by returning True from the callback """ url, cert, auth = server_context - sim_endpoint = "/".join([url, SIM_PROGRESS_ENDPOINT]) opt_endpoint = "/".join([url, OPT_PROGRESS_ENDPOINT]) - sim_status: dict = {} opt_status: dict = {} stop = False + ssl_context = ssl.create_default_context() + ssl_context.load_verify_locations(cafile=cert) + try: - while not stop: - new_sim_status = _query_server(cert, auth, sim_endpoint) - if new_sim_status != sim_status: - sim_status = new_sim_status - ret = bool(callback({SIM_PROGRESS_ID: sim_status})) - stop |= ret - # When the API will support it query only from a certain batch on - - # Check the optimization status - new_opt_status = _query_server(cert, auth, opt_endpoint) - if new_opt_status != opt_status: - opt_status = new_opt_status - ret = bool(callback({OPT_PROGRESS_ID: opt_status})) - stop |= ret - time.sleep(polling_interval) + with connect( + "wss://{username}:{password}@" + url.replace("https://", "") + "/events", + ssl=ssl_context, + open_timeout=30, + ) as websocket: + while not stop: + try: + message = websocket.recv(timeout=1.0) + except TimeoutError: + message = None + if message: + event_dict = json.loads(message) + if "snapshot" in event_dict: + event_dict["snapshot"] = EnsembleSnapshot.from_nested_dict( + event_dict["snapshot"] + ) + try: + event = EventWrapper(event=event_dict).event + except ValidationError as e: + logger.error("Error when processing event %s", exc_info=e) + if isinstance(event, EndEvent): + print(event.msg) + callback({SIM_PROGRESS_ID: event}) + # Check the optimization status + 
new_opt_status = _query_server(cert, auth, opt_endpoint) + if new_opt_status != opt_status: + opt_status = new_opt_status + ret = bool(callback({OPT_PROGRESS_ID: opt_status})) + stop |= ret + time.sleep(polling_interval) except: logging.debug(traceback.format_exc()) diff --git a/src/everest/detached/jobs/everserver.py b/src/everest/detached/jobs/everserver.py index 1ad1d0ef20e..07beec73263 100755 --- a/src/everest/detached/jobs/everserver.py +++ b/src/everest/detached/jobs/everserver.py @@ -1,19 +1,22 @@ import argparse +import asyncio import datetime import json import logging +import multiprocessing as mp import os +import queue import socket import ssl import threading import time import traceback +import uuid from base64 import b64encode from functools import partial from pathlib import Path from typing import Any -import requests import uvicorn from cryptography import x509 from cryptography.hazmat.backends import default_backend @@ -21,7 +24,15 @@ from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.x509.oid import NameOID from dns import resolver, reversename -from fastapi import Depends, FastAPI, HTTPException, Request, status +from fastapi import ( + BackgroundTasks, + Depends, + FastAPI, + HTTPException, + Request, + WebSocket, + status, +) from fastapi.encoders import jsonable_encoder from fastapi.responses import ( JSONResponse, @@ -35,7 +46,13 @@ from pydantic import BaseModel from ert.config.parsing.queue_system import QueueSystem -from ert.ensemble_evaluator import EvaluatorServerConfig +from ert.ensemble_evaluator import ( + EndEvent, + EvaluatorServerConfig, + FullSnapshotEvent, + SnapshotUpdateEvent, +) +from ert.run_models import StatusEvents from ert.run_models.everest_run_model import ( EverestExitCode, EverestRunModel, @@ -43,23 +60,18 @@ ) from everest.config import EverestConfig, ServerConfig from everest.detached import ( - PROXY, ServerStatus, get_opt_status, update_everserver_status, wait_for_server, ) from everest.plugins.everest_plugin_manager import EverestPluginManager -from everest.simulator import JOB_FAILURE from everest.strings import ( DEFAULT_LOGGING_FORMAT, EVEREST, - EXPERIMENT_STATUS_ENDPOINT, OPT_PROGRESS_ENDPOINT, - SHARED_DATA_ENDPOINT, - SIM_PROGRESS_ENDPOINT, START_EXPERIMENT_ENDPOINT, - STOP_ENDPOINT, OPT_FAILURE_REALIZATIONS, + STOP_ENDPOINT, ) from everest.util import makedirs_if_needed, version_info @@ -69,19 +81,19 @@ class ExperimentStatus(BaseModel): message: str | None = None -class ExperimentRunner(threading.Thread): +class ExperimentRunner: def __init__(self, everest_config, shared_data: dict): super().__init__() self._everest_config = everest_config - self._shared_data = shared_data - self._status: ExperimentStatus | None = None + self.shared_data = shared_data - def run(self): + async def run(self): + status_queue: mp.Queue[StatusEvents] = mp.Queue() run_model = EverestRunModel.create( self._everest_config, - simulation_callback=partial(_sim_monitor, shared_data=self._shared_data), - optimization_callback=partial(_opt_monitor, shared_data=self._shared_data), + optimization_callback=partial(_opt_monitor, shared_data=self.shared_data), + status_queue=status_queue, ) if run_model._queue_config.queue_system == QueueSystem.LOCAL: @@ -92,22 +104,50 @@ def run(self): ) try: - run_model.run_experiment(evaluator_server_config) - + loop = asyncio.get_running_loop() + simulation_future = loop.run_in_executor( + None, + lambda: run_model.run_experiment(evaluator_server_config), + ) + while True: + if 
self.shared_data[STOP_ENDPOINT]: + run_model.cancel() + raise ValueError("Optimization aborted") + try: + item: StatusEvents = status_queue.get(block=False) + except queue.Empty: + await asyncio.sleep(0.01) + continue + + self.shared_data["events"].append(item) + for sub in self.shared_data["subscribers"].values(): + sub.notify() + + if isinstance(item, EndEvent): + break + await asyncio.sleep(0.1) + await simulation_future assert run_model.exit_code is not None - self._status = ExperimentStatus(exit_code=run_model.exit_code) + self.shared_data["experiment_status"] = ExperimentStatus( + exit_code=run_model.exit_code + ) except Exception as e: - self._status = ExperimentStatus( + self.shared_data["experiment_status"] = ExperimentStatus( exit_code=EverestExitCode.EXCEPTION, message=str(e) ) - @property - def status(self) -> ExperimentStatus | None: - return self._status - @property - def shared_data(self) -> dict: - return self._shared_data +class Subscriber: + def __init__(self) -> None: + self.index = 0 + self._event = asyncio.Event() + + def notify(self): + self._event.set() + + async def wait_for_event(self): + await self._event.wait() + self._event.clear() def _get_machine_name() -> str: @@ -137,26 +177,6 @@ def _get_machine_name() -> str: return "localhost" -def _sim_monitor(context_status, shared_data=None): - assert shared_data is not None - - status = context_status["status"] - shared_data[SIM_PROGRESS_ENDPOINT] = { - "batch_number": context_status["batch_number"], - "status": { - "running": status.get("Running", 0), - "waiting": status.get("Waiting", 0), - "pending": status.get("Pending", 0), - "complete": status.get("Finished", 0), - "failed": status.get("Failed", 0), - }, - "progress": context_status["progress"], - } - - if shared_data[STOP_ENDPOINT]: - return "stop_queue" - - def _opt_monitor(shared_data=None): assert shared_data is not None if shared_data[STOP_ENDPOINT]: @@ -167,8 +187,6 @@ def _everserver_thread(shared_data, server_config) -> None: app = FastAPI() security = HTTPBasic() - runner: ExperimentRunner | None = None - def _check_user(credentials: HTTPBasicCredentials) -> None: if credentials.password != server_config["authentication"]: raise HTTPException( @@ -199,15 +217,6 @@ def stop( shared_data[STOP_ENDPOINT] = True return Response("Raise STOP flag succeeded. 
Everest initiates shutdown..", 200) - @app.get("/" + SIM_PROGRESS_ENDPOINT) - def get_sim_progress( - request: Request, credentials: HTTPBasicCredentials = Depends(security) - ) -> JSONResponse: - _log(request) - _check_user(credentials) - progress = shared_data[SIM_PROGRESS_ENDPOINT] - return JSONResponse(jsonable_encoder(progress)) - @app.get("/" + OPT_PROGRESS_ENDPOINT) def get_opt_progress( request: Request, credentials: HTTPBasicCredentials = Depends(security) @@ -218,50 +227,49 @@ def get_opt_progress( return JSONResponse(jsonable_encoder(progress)) @app.post("/" + START_EXPERIMENT_ENDPOINT) - def start_experiment( + async def start_experiment( config: EverestConfig, request: Request, + background_tasks: BackgroundTasks, credentials: HTTPBasicCredentials = Depends(security), ) -> Response: _log(request) _check_user(credentials) - nonlocal runner - if runner is None: + if not shared_data["started"]: runner = ExperimentRunner(config, shared_data) try: - runner.start() + background_tasks.add_task(runner.run) + shared_data["started"] = True return Response("Everest experiment started") except Exception as e: return Response(f"Could not start experiment: {e!s}", status_code=501) return Response("Everest experiment is running") - @app.get("/" + EXPERIMENT_STATUS_ENDPOINT) - def get_experiment_status( - request: Request, credentials: HTTPBasicCredentials = Depends(security) - ) -> Response: - _log(request) - _check_user(credentials) - if shared_data[STOP_ENDPOINT]: - return JSONResponse( - ExperimentStatus(exit_code=EverestExitCode.USER_ABORT).model_dump_json() - ) - if runner is None: - return Response(None, 204) - status = runner.status - if status is None: - return Response(None, 204) - return JSONResponse(status.model_dump_json()) - - @app.get("/" + SHARED_DATA_ENDPOINT) - def get_shared_data( - request: Request, credentials: HTTPBasicCredentials = Depends(security) - ) -> JSONResponse: - _log(request) - _check_user(credentials) - if runner is None: - return JSONResponse(jsonable_encoder(shared_data)) - return JSONResponse(jsonable_encoder(runner.shared_data)) + @app.websocket("/events") + async def websocket_endpoint(websocket: WebSocket): + subscriber_id = str(uuid.uuid4()) + await websocket.accept() + while True: + event = await get_event(subscriber_id=subscriber_id) + await websocket.send_json(jsonable_encoder(event)) + await asyncio.sleep(0.1) + if isinstance(event, EndEvent): + # Give some time for subscribers to get events + await asyncio.sleep(5) + break + + async def get_event(subscriber_id: str) -> StatusEvents: + if subscriber_id not in shared_data["subscribers"]: + shared_data["subscribers"][subscriber_id] = Subscriber() + subscriber = shared_data["subscribers"][subscriber_id] + + while subscriber.index >= len(shared_data["events"]): + await subscriber.wait_for_event() + + event = shared_data["events"][subscriber.index] + shared_data["subscribers"][subscriber_id].index += 1 + return event uvicorn.run( app, @@ -344,6 +352,17 @@ def make_handler_config( EverestPluginManager().add_log_handle_to_root() +def send_end_event(shared_data, failed, msg): + shared_data["events"].append( + EndEvent( + failed=failed, + msg=msg, + ) + ) + for sub in shared_data["subscribers"].values(): + sub.notify() + + def main(): arg_parser = argparse.ArgumentParser() arg_parser.add_argument("--config-file", type=str) @@ -378,8 +397,11 @@ def main(): _write_hostfile(host_file, host, port, cert_path, authentication) shared_data = { - SIM_PROGRESS_ENDPOINT: {}, STOP_ENDPOINT: False, + "started": False, + 
"experiment_status": None, + "events": [], + "subscribers": {}, } server_config = { @@ -405,56 +427,38 @@ def main(): message=traceback.format_exc(), ) return - try: wait_for_server(config.output_dir, 60) update_everserver_status(status_path, ServerStatus.running) - server_context = (ServerConfig.get_server_context(config.output_dir),) - url, cert, auth = server_context[0] - - done = False - experiment_status: ExperimentStatus | None = None # loop until the optimization is done - while not done: - response = requests.get( - "/".join([url, EXPERIMENT_STATUS_ENDPOINT]), - verify=cert, - auth=auth, - timeout=1, - proxies=PROXY, # type: ignore - ) - if response.status_code == requests.codes.OK: - json_body = json.loads( - response.text if hasattr(response, "text") else response.body - ) - experiment_status = ExperimentStatus.model_validate_json(json_body) - done = True - else: - time.sleep(1) - - response = requests.get( - "/".join([url, SHARED_DATA_ENDPOINT]), - verify=cert, - auth=auth, - timeout=1, - proxies=PROXY, # type: ignore - ) - if json_body := json.loads( - response.text if hasattr(response, "text") else response.body - ): - shared_data = json_body - + while not shared_data["experiment_status"]: + if shared_data[STOP_ENDPOINT] and not shared_data["started"]: + raise ValueError("Stopped by user") + time.sleep(1) + experiment_status = shared_data["experiment_status"] assert experiment_status is not None + snapshots = {} + for event in shared_data["events"]: + if isinstance(event, FullSnapshotEvent): + snapshots[event.iteration] = event + elif isinstance(event, SnapshotUpdateEvent): + snapshots[event.iteration].snapshot.merge_snapshot(event.snapshot) + logging.getLogger("forward_models").info("Status event") + fm_status = "\n Forward model status:\b" + for snapshot in snapshots.values(): + logging.getLogger("forward_models").info( + f"Status event: {jsonable_encoder(snapshot)}" + ) + fm_status += str(snapshot.snapshot.to_dict()) status, message = _get_optimization_status( experiment_status.exit_code, experiment_status.message, shared_data[STOP_ENDPOINT], ) - if status != ServerStatus.completed: - update_everserver_status(status_path, status, message) + update_everserver_status(status_path, status, message + fm_status) return except: if shared_data[STOP_ENDPOINT]: @@ -470,24 +474,9 @@ def main(): message=traceback.format_exc(), ) return - update_everserver_status(status_path, ServerStatus.completed, message=message) -def _failed_realizations_messages(shared_data): - messages = [OPT_FAILURE_REALIZATIONS] - failed = shared_data[SIM_PROGRESS_ENDPOINT]["status"]["failed"] - if failed > 0: - # Report each unique pair of failed job name and error - for queue in shared_data[SIM_PROGRESS_ENDPOINT]["progress"]: - for job in queue: - if job["status"] == JOB_FAILURE: - err_msg = f"{job['name']} Failed with: {job.get('error', '')}" - if err_msg not in messages: - messages.append(err_msg) - return messages - - def _generate_certificate(cert_folder: str): """Generate a private key and a certificate signed with it diff --git a/src/everest/strings.py b/src/everest/strings.py index 593f592fa35..d82f66579e2 100644 --- a/src/everest/strings.py +++ b/src/everest/strings.py @@ -27,7 +27,6 @@ SIMULATOR_START = "start" SIMULATOR_UPDATE = "update" SIMULATOR_END = "end" -SIM_PROGRESS_ENDPOINT = "sim_progress" SIM_PROGRESS_ID = "simulation_progress" STOP_ENDPOINT = "stop" STORAGE_DIR = "simulation_results" diff --git a/tests/everest/entry_points/test_everest_entry.py 
b/tests/everest/entry_points/test_everest_entry.py index 3b38b0496ab..efb79732853 100644 --- a/tests/everest/entry_points/test_everest_entry.py +++ b/tests/everest/entry_points/test_everest_entry.py @@ -1,69 +1,26 @@ import logging import os from functools import partial -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest +import yaml -from ert.resources import all_shell_script_fm_steps +import everest from everest.bin.everest_script import everest_entry from everest.bin.kill_script import kill_entry from everest.bin.monitor_script import monitor_entry from everest.config import EverestConfig, ServerConfig from everest.detached import ( - SIM_PROGRESS_ENDPOINT, ServerStatus, everserver_status, update_everserver_status, ) -from everest.simulator import JOB_SUCCESS from tests.everest.utils import capture_streams CONFIG_FILE_MINIMAL = "config_minimal.yml" -def query_server_mock(cert, auth, endpoint): - url = "localhost" - sim_endpoint = "/".join([url, SIM_PROGRESS_ENDPOINT]) - - def build_job( - status=JOB_SUCCESS, - start_time="begining", - end_time="end", - name="default_job", - error=None, - ): - return { - "status": status, - "start_time": start_time, - "end_time": end_time, - "name": name, - "error": error, - "realization": 0, - } - - shell_cmd_jobs = [build_job(name=command) for command in all_shell_script_fm_steps] - all_jobs = [ - *shell_cmd_jobs, - build_job(name="make_pancakes"), - build_job(name="make_scrambled_eggs"), - ] - if endpoint == sim_endpoint: - return { - "status": { - "failed": 0, - "running": 0, - "complete": 1, - "pending": 0, - "waiting": 0, - }, - "progress": [all_jobs], - "batch_number": "0", - } - else: - raise Exception("Stop! Hands in the air!") - - def run_detached_monitor_mock(status=ServerStatus.completed, error=None, **kwargs): optimization_output = kwargs.get("optimization_output_dir") path = os.path.join(optimization_output, "../detached_node_output/.session/status") @@ -290,145 +247,30 @@ def test_everest_entry_monitor_no_run( everserver_status_mock.assert_called() -@patch("everest.bin.everest_script.server_is_running", return_value=False) -@patch("everest.bin.everest_script.wait_for_server") -@patch("everest.bin.everest_script.start_server") -@patch("everest.detached._query_server", side_effect=query_server_mock) -@patch.object( - ServerConfig, - "get_server_context", - return_value=("localhost", "", ""), -) -@patch("everest.detached.get_opt_status", return_value={}) -@patch( - "everest.bin.everest_script.everserver_status", - return_value={"status": ServerStatus.never_run, "message": None}, -) -@patch("everest.bin.everest_script.start_experiment") -def test_everest_entry_show_all_jobs( - start_experiment_mock, - everserver_status_mock, - get_opt_status_mock, - get_server_context_mock, - query_server_mock, - start_server_mock, - wait_for_server_mock, - server_is_running_mock, - copy_math_func_test_data_to_tmp, -): - """Test running everest with --show-all-jobs""" - - # Test when --show-all-jobs flag is given shell command are in the list - # of forward model jobs - with capture_streams() as (out, _): - everest_entry([CONFIG_FILE_MINIMAL, "--show-all-jobs"]) - for cmd in all_shell_script_fm_steps: - assert cmd in out.getvalue() - - -@patch("everest.bin.everest_script.server_is_running", return_value=False) -@patch("everest.bin.everest_script.wait_for_server") -@patch("everest.bin.everest_script.start_server") -@patch("everest.detached._query_server", side_effect=query_server_mock) -@patch.object( - ServerConfig, - 
"get_server_context", - return_value=("localhost", "", ""), -) -@patch("everest.detached.get_opt_status", return_value={}) -@patch( - "everest.bin.everest_script.everserver_status", - return_value={"status": ServerStatus.never_run, "message": None}, -) -@patch("everest.bin.everest_script.start_experiment") -def test_everest_entry_no_show_all_jobs( - start_experiment_mock, - everserver_status_mock, - get_opt_status_mock, - get_server_context_mock, - query_server_mock, - start_server_mock, - wait_for_server_mock, - server_is_running_mock, - copy_math_func_test_data_to_tmp, -): - """Test running everest without --show-all-jobs""" - - # Test when --show-all-jobs flag is not given the shell command are not - # in the list of forward model jobs - with capture_streams() as (out, _): - everest_entry([CONFIG_FILE_MINIMAL]) - for cmd in all_shell_script_fm_steps: - assert cmd not in out.getvalue() - - # Check the other jobs are still there - assert "make_pancakes" in out.getvalue() - assert "make_scrambled_eggs" in out.getvalue() +@pytest.fixture(autouse=True) +def mock_ssl(monkeypatch): + monkeypatch.setattr(everest.detached, "ssl", MagicMock()) +@pytest.mark.parametrize("show_all_jobs", [True, False]) @patch("everest.bin.monitor_script.server_is_running", return_value=True) -@patch("everest.detached._query_server", side_effect=query_server_mock) -@patch.object( - ServerConfig, - "get_server_context", - return_value=("localhost", "", ""), -) -@patch("everest.detached.get_opt_status", return_value={}) -@patch( - "everest.bin.monitor_script.everserver_status", - return_value={"status": ServerStatus.never_run, "message": None}, -) def test_monitor_entry_show_all_jobs( - everserver_status_mock, - get_opt_status_mock, - get_server_context_mock, - query_server_mock, - server_is_running_mock, - copy_math_func_test_data_to_tmp, + _, + monkeypatch, + tmp_path, + min_config, + show_all_jobs, ): """Test running everest with and without --show-all-jobs""" - - # Test when --show-all-jobs flag is given shell command are in the list - # of forward model jobs - - with capture_streams() as (out, _): - monitor_entry([CONFIG_FILE_MINIMAL, "--show-all-jobs"]) - for cmd in all_shell_script_fm_steps: - assert cmd in out.getvalue() - - -@patch("everest.bin.monitor_script.server_is_running", return_value=True) -@patch("everest.detached._query_server", side_effect=query_server_mock) -@patch.object( - ServerConfig, - "get_server_context", - return_value=("localhost", "", ""), -) -@patch("everest.detached.get_opt_status", return_value={}) -@patch( - "everest.bin.monitor_script.everserver_status", - return_value={"status": ServerStatus.never_run, "message": None}, -) -def test_monitor_entry_no_show_all_jobs( - everserver_status_mock, - get_opt_status_mock, - get_server_context_mock, - query_server_mock, - server_is_running_mock, - copy_math_func_test_data_to_tmp, -): - """Test running everest without --show-all-jobs""" - - # Test when --show-all-jobs flag is not given the shell command are not - # in the list of forward model jobs - with capture_streams() as (out, _): - monitor_entry([CONFIG_FILE_MINIMAL]) - for cmd in all_shell_script_fm_steps: - assert cmd not in out.getvalue() - - # Check the other jobs are still there - assert "make_pancakes" in out.getvalue() - assert "make_scrambled_eggs" in out.getvalue() + monkeypatch.chdir(tmp_path) + with open("config.yml", "w", encoding="utf-8") as fout: + yaml.dump(min_config, fout) + detatched_mock = MagicMock() + monkeypatch.setattr(everest.bin.utils, "start_monitor", 
MagicMock()) + monkeypatch.setattr(everest.bin.utils, "_DetachedMonitor", detatched_mock) + args = ["config.yml"] if not show_all_jobs else ["config.yml", "--show-all-jobs"] + monitor_entry(args) + detatched_mock.assert_called_once_with(show_all_jobs) @patch( diff --git a/tests/everest/test_everserver.py b/tests/everest/test_everserver.py index 565bfb7d0bf..e7b923ed25d 100644 --- a/tests/everest/test_everserver.py +++ b/tests/everest/test_everserver.py @@ -2,13 +2,12 @@ import os import ssl from pathlib import Path -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest -import requests -from fastapi.encoders import jsonable_encoder -from fastapi.responses import JSONResponse +import everest +from ert.ensemble_evaluator import EndEvent from ert.run_models.everest_run_model import EverestExitCode from ert.scheduler.event import FinishedEvent from everest.config import EverestConfig, ServerConfig @@ -20,11 +19,11 @@ wait_for_server, ) from everest.detached.jobs import everserver +from everest.detached.jobs.everserver import ExperimentStatus, _everserver_thread from everest.everest_storage import EverestStorage -from everest.simulator import JOB_FAILURE, JOB_SUCCESS +from everest.simulator import JOB_FAILURE from everest.strings import ( OPT_FAILURE_REALIZATIONS, - SIM_PROGRESS_ENDPOINT, STOP_ENDPOINT, ) @@ -66,7 +65,6 @@ def fail_optimization(self, from_ropt=False): # call the provided simulation callback, which has access to the shared_data # variable in the eversever main function. Patch that callback to modify # shared_data (see set_shared_status() below). - self._sim_callback(None) if from_ropt: self._exit_code = EverestExitCode.TOO_FEW_REALIZATIONS return EverestExitCode.TOO_FEW_REALIZATIONS @@ -81,7 +79,7 @@ def set_shared_status(*args, progress, shared_data): [job for queue in progress for job in queue if job["status"] == JOB_FAILURE] ) - shared_data[SIM_PROGRESS_ENDPOINT] = { + shared_data["events"] = { "status": {"failed": failed}, "progress": progress, } @@ -137,37 +135,31 @@ def test_configure_logger_failure(mocked_logger, copy_math_func_test_data_to_tmp assert "Exception: Configuring logger failed" in status["message"] +@pytest.fixture +def mock_end_event(monkeypatch): + queue_mock = MagicMock() + queue_mock.get.return_value = EndEvent(failed=False, msg="Experiment complete") + monkeypatch.setattr( + everest.detached.jobs.everserver.mp, "Queue", MagicMock(return_value=queue_mock) + ) + yield queue_mock + + @patch("sys.argv", ["name", "--config-file", "config_minimal.yml"]) @patch("everest.detached.jobs.everserver._configure_loggers") -@patch("requests.get") def test_status_running_complete( - mocked_get, mocked_logger, copy_math_func_test_data_to_tmp + mocked_logger, copy_math_func_test_data_to_tmp, monkeypatch ): + def server_mock(shared_data, server_config): + shared_data["experiment_status"] = ExperimentStatus(exit_code=1) + _everserver_thread(shared_data, server_config) + + monkeypatch.setattr( + everest.detached.jobs.everserver, "_everserver_thread", server_mock + ) config_file = "config_minimal.yml" config = EverestConfig.load_file(config_file) - def mocked_server(url, verify, auth, timeout, proxies): - if "/experiment_status" in url: - return JSONResponse( - everserver.ExperimentStatus( - exit_code=EverestExitCode.COMPLETED - ).model_dump_json() - ) - if "/shared_data" in url: - return JSONResponse( - jsonable_encoder( - { - SIM_PROGRESS_ENDPOINT: {}, - STOP_ENDPOINT: False, - } - ) - ) - resp = requests.Response() - resp.status_code = 
200 - return resp - - mocked_get.side_effect = mocked_server - everserver.main() status = everserver_status( @@ -180,60 +172,17 @@ def mocked_server(url, verify, auth, timeout, proxies): @patch("sys.argv", ["name", "--config-file", "config_minimal.yml"]) @patch("everest.detached.jobs.everserver._configure_loggers") -@patch("requests.get") -def test_status_failed_job(mocked_get, mocked_logger, copy_math_func_test_data_to_tmp): +def test_status_failed_job(mocked_logger, copy_math_func_test_data_to_tmp, monkeypatch): config_file = "config_minimal.yml" config = EverestConfig.load_file(config_file) - def mocked_server(url, verify, auth, timeout, proxies): - if "/experiment_status" in url: - return JSONResponse( - everserver.ExperimentStatus( - exit_code=EverestExitCode.TOO_FEW_REALIZATIONS - ).model_dump_json() - ) - if "/shared_data" in url: - return JSONResponse( - jsonable_encoder( - { - SIM_PROGRESS_ENDPOINT: { - "status": {"failed": 3}, - "progress": [ - [ - { - "name": "job1", - "status": JOB_FAILURE, - "error": "job 1 error 1", - }, - { - "name": "job1", - "status": JOB_FAILURE, - "error": "job 1 error 2", - }, - ], - [ - { - "name": "job2", - "status": JOB_SUCCESS, - "error": "", - }, - { - "name": "job2", - "status": JOB_FAILURE, - "error": "job 2 error 1", - }, - ], - ], - }, - STOP_ENDPOINT: False, - } - ) - ) - resp = requests.Response() - resp.status_code = 200 - return resp - - mocked_get.side_effect = mocked_server + def server_mock(shared_data, server_config): + shared_data["experiment_status"] = ExperimentStatus(exit_code=2) + _everserver_thread(shared_data, server_config) + + monkeypatch.setattr( + everest.detached.jobs.everserver, "_everserver_thread", server_mock + ) everserver.main() @@ -244,42 +193,22 @@ def mocked_server(url, verify, auth, timeout, proxies): # The server should fail and store a user-friendly message. 
assert status["status"] == ServerStatus.failed assert OPT_FAILURE_REALIZATIONS in status["message"] - assert "job1 Failed with: job 1 error 1" in status["message"] - assert "job1 Failed with: job 1 error 2" in status["message"] - assert "job2 Failed with: job 2 error 1" in status["message"] @patch("sys.argv", ["name", "--config-file", "config_minimal.yml"]) -@patch("everest.detached.jobs.everserver._configure_loggers") -@patch("requests.get") -def test_status_exception(mocked_get, mocked_logger, copy_math_func_test_data_to_tmp): +def test_everserver_status_exception(copy_math_func_test_data_to_tmp, monkeypatch): config_file = "config_minimal.yml" config = EverestConfig.load_file(config_file) - def mocked_server(url, verify, auth, timeout, proxies): - if "/experiment_status" in url: - return JSONResponse( - everserver.ExperimentStatus( - exit_code=EverestExitCode.EXCEPTION, message="Some message" - ).model_dump_json() - ) - if "/shared_data" in url: - return JSONResponse( - jsonable_encoder( - { - SIM_PROGRESS_ENDPOINT: { - "status": {}, - "progress": [], - }, - STOP_ENDPOINT: False, - } - ) - ) - resp = requests.Response() - resp.status_code = 200 - return resp - - mocked_get.side_effect = mocked_server + def server_mock(shared_data, server_config): + shared_data["experiment_status"] = ExperimentStatus( + exit_code=6, message="Some message" + ) + _everserver_thread(shared_data, server_config) + + monkeypatch.setattr( + everest.detached.jobs.everserver, "_everserver_thread", server_mock + ) everserver.main() status = everserver_status( @@ -328,7 +257,7 @@ async def test_status_contains_max_runtime_failure( config_file = "config_minimal.yml" Path("SLEEP_job").write_text("EXECUTABLE sleep", encoding="utf-8") - min_config["simulator"] = {"max_runtime": 2} + min_config["simulator"] = {"max_runtime": 1} min_config["forward_model"] = ["sleep 5"] min_config["install_jobs"] = [{"name": "sleep", "source": "SLEEP_job"}] @@ -345,7 +274,4 @@ async def test_status_contains_max_runtime_failure( assert status["status"] == ServerStatus.failed print(status["message"]) - assert ( - "sleep Failed with: The run is cancelled due to reaching MAX_RUNTIME" - in status["message"] - ) + assert "The run is cancelled due to reaching MAX_RUNTIME" in status["message"] diff --git a/tests/everest/test_logging.py b/tests/everest/test_logging.py index 0ca1afe5aea..17976b74686 100644 --- a/tests/everest/test_logging.py +++ b/tests/everest/test_logging.py @@ -3,14 +3,12 @@ import pytest -from ert.scheduler.event import FinishedEvent +from everest.bin.main import start_everest from everest.config import ( EverestConfig, ServerConfig, ) from everest.config.install_job_config import InstallJobConfig -from everest.detached import start_experiment, start_server, wait_for_server -from everest.util import makedirs_if_needed def _string_exists_in_file(file_path, string): @@ -20,13 +18,7 @@ def _string_exists_in_file(file_path, string): @pytest.mark.timeout(240) # Simulation might not finish @pytest.mark.integration_test @pytest.mark.xdist_group(name="starts_everest") -async def test_logging_setup(copy_math_func_test_data_to_tmp): - async def server_running(): - while True: - event = await driver.event_queue.get() - if isinstance(event, FinishedEvent) and event.iens == 0: - return - +def test_logging_setup(copy_math_func_test_data_to_tmp): everest_config = EverestConfig.load_file("config_minimal.yml") everest_config.forward_model.append("toggle_failure --fail simulation_2") everest_config.install_jobs.append( @@ -39,19 +31,7 @@ async 
def server_running(): # start_server() loads config based on config_path, so we need to actually overwrite it everest_config.dump("config_minimal.yml") - - makedirs_if_needed(everest_config.output_dir, roll_if_exists=True) - driver = await start_server(everest_config, debug=True) - try: - wait_for_server(everest_config.output_dir, 120) - - start_experiment( - server_context=ServerConfig.get_server_context(everest_config.output_dir), - config=everest_config, - ) - except (SystemExit, RuntimeError) as e: - raise e - await server_running() + start_everest(["everest", "run", "config_minimal.yml"]) everest_output_path = os.path.join(os.getcwd(), "everest_output") everest_logs_dir_path = everest_config.log_dir @@ -68,10 +48,7 @@ async def server_running(): assert _string_exists_in_file(everest_log_path, "everest DEBUG:") assert _string_exists_in_file( - forward_model_log_path, "Exception: Failing simulation_2 by request!" - ) - assert _string_exists_in_file( - forward_model_log_path, "Exception: Failing simulation_2 by request!" + forward_model_log_path, "Process exited with status code 1" ) endpoint_logs = Path(endpoint_log_path).read_text(encoding="utf-8") diff --git a/tests/everest/test_monitor.py b/tests/everest/test_monitor.py new file mode 100644 index 00000000000..53579b7b876 --- /dev/null +++ b/tests/everest/test_monitor.py @@ -0,0 +1,137 @@ +import json +import string +from collections import defaultdict +from datetime import datetime +from functools import partial +from unittest.mock import MagicMock, patch + +import pytest +from websockets.sync.client import ClientConnection + +import everest +from ert.ensemble_evaluator import ( + EndEvent, + FullSnapshotEvent, + SnapshotUpdateEvent, + state, +) +from ert.ensemble_evaluator.snapshot import EnsembleSnapshotMetadata +from ert.resources import all_shell_script_fm_steps +from everest.bin.utils import run_detached_monitor +from tests.ert import SnapshotBuilder + +METADATA = EnsembleSnapshotMetadata( + aggr_fm_step_status_colors=defaultdict(dict), + real_status_colors={}, + sorted_real_ids=[], + sorted_fm_step_ids=defaultdict(list), +) + + +from fastapi.encoders import jsonable_encoder + + +@pytest.fixture +def full_snapshot_event(): + snapshot = SnapshotBuilder(metadata=METADATA) + snapshot.add_fm_step( + fm_step_id="0", + index="0", + name="fm_step_0", + status=state.FORWARD_MODEL_STATE_START, + current_memory_usage="500", + max_memory_usage="1000", + stdout="job_fm_step_0.stdout", + stderr="job_fm_step_0.stderr", + start_time=datetime(1999, 1, 1), + ) + for i, command in enumerate(all_shell_script_fm_steps): + snapshot.add_fm_step( + fm_step_id=str(i + 1), + index=str(i + 1), + name=command, + status=state.FORWARD_MODEL_STATE_START, + current_memory_usage="500", + max_memory_usage="1000", + stdout=None, + stderr=None, + start_time=datetime(1999, 1, 1), + ) + event = FullSnapshotEvent( + snapshot=snapshot.build( + real_ids=["0", "1"], + status=state.REALIZATION_STATE_PENDING, + start_time=datetime(1999, 1, 1), + exec_hosts="12121.121", + message="Some message", + ), + iteration_label="Foo", + total_iterations=1, + progress=0.25, + realization_count=4, + status_count={ + "Finished": 0, + "Pending": len(all_shell_script_fm_steps), + "Unknown": 0, + }, + iteration=0, + ) + yield json.dumps(jsonable_encoder(event)) + + +@pytest.fixture +def snapshot_update_event(): + event = SnapshotUpdateEvent( + snapshot=SnapshotBuilder(metadata=METADATA) + .add_fm_step( + fm_step_id="0", + name=None, + index="0", + 
+            status=state.FORWARD_MODEL_STATE_FINISHED,
+            end_time=datetime(2019, 1, 1),
+        )
+        .build(
+            real_ids=["1"],
+            status=state.REALIZATION_STATE_FINISHED,
+        ),
+        iteration_label="Foo",
+        total_iterations=1,
+        progress=0.5,
+        realization_count=4,
+        status_count={"Finished": 1, "Running": 0, "Unknown": 0},
+        iteration=0,
+    )
+    yield json.dumps(jsonable_encoder(event))
+
+
+@pytest.mark.parametrize("show_all_jobs", [True, False])
+def test_monitor(
+    monkeypatch, full_snapshot_event, snapshot_update_event, capsys, show_all_jobs
+):
+    server_mock = MagicMock()
+    connection_mock = MagicMock(spec=ClientConnection)
+    connection_mock.recv.side_effect = [
+        full_snapshot_event,
+        snapshot_update_event,
+        json.dumps(jsonable_encoder(EndEvent(failed=False, msg="Experiment complete"))),
+    ]
+    server_mock.return_value.__enter__.return_value = connection_mock
+    monkeypatch.setattr(everest.detached, "_query_server", MagicMock(return_value={}))
+    monkeypatch.setattr(everest.detached, "connect", server_mock)
+    monkeypatch.setattr(everest.detached, "ssl", MagicMock())
+    patched = partial(everest.detached.start_monitor, polling_interval=0.1)
+    with patch("everest.bin.utils.start_monitor", patched):
+        run_detached_monitor(("some/url", None, None), "output", show_all_jobs)
+    captured = capsys.readouterr()
+    expected = [
+        "===================== Running forward models (Batch #0) ======================\n",
+        "  Waiting: 0 | Pending: 0 | Running: 0 | Complete: 0 | Failed: 0\n",
+        "  fm_step_0: 0/1/0 | Finished: 1\n",
+        "Experiment complete\n",
+    ]
+    if show_all_jobs:
+        expected[-1:-1] = [f"{name}: 0/0/0" for name in all_shell_script_fm_steps]
+    # Ignore whitespace
+    assert captured.out.translate({ord(c): None for c in string.whitespace}) == "".join(
+        expected
+    ).translate({ord(c): None for c in string.whitespace})

From c58f689fe9b3e8b8a9fa31980ee6b97d8e02f876 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=C3=98yvind=20Eide?=
Date: Tue, 28 Jan 2025 14:28:03 +0100
Subject: [PATCH 6/9] Add monitoring of failed jobs

---
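Note (illustrative sketch, not applied by git am): the heart of this change is
grouping failed realizations per forward-model step, keyed by the error
message, so that identical failures collapse into a single monitor line. A
minimal self-contained sketch of the idea follows; summarize_errors and the
steps dicts are assumed stand-ins only, the real implementation being the
JobProgress.errors field and _get_jobs_status in the diff below.

    from collections import defaultdict


    def summarize_errors(steps):
        """Group failed realizations per step name, keyed by error message."""
        errors = {}
        for step in steps:
            per_step = errors.setdefault(step["name"], defaultdict(list))
            # Same walrus pattern as _get_jobs_status: only failed steps
            # carry a non-empty "error" field.
            if error := step.get("error"):
                per_step[error].append(step["realization"])
        return errors


    steps = [
        {"name": "fm_step_0", "realization": 1,
         "error": "The run is cancelled due to reaching MAX_RUNTIME"},
        {"name": "fm_step_0", "realization": 3,
         "error": "The run is cancelled due to reaching MAX_RUNTIME"},
        {"name": "fm_step_0", "realization": 2, "error": None},
    ]
    for name, per_error in summarize_errors(steps).items():
        for message, realizations in per_error.items():
            # Mirrors the monitor line: "<step>: Failed: <error>, realizations: [...]"
            print(f"{name}: Failed: {message}, realizations: {realizations}")

Keying on the error message rather than the realization is what keeps the
monitor output short: N realizations failing the same way produce one line.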
 src/everest/bin/utils.py         | 23 +++++++++---
 tests/ert/__init__.py            |  2 +
 tests/everest/test_everserver.py |  1 -
 tests/everest/test_monitor.py    | 63 ++++++++++++++++++++++++++++++--
 4 files changed, 79 insertions(+), 10 deletions(-)

diff --git a/src/everest/bin/utils.py b/src/everest/bin/utils.py
index c1ecab8d442..d28aa96fbde 100644
--- a/src/everest/bin/utils.py
+++ b/src/everest/bin/utils.py
@@ -2,6 +2,7 @@
 import os
 import sys
 import traceback
+from collections import defaultdict
 from dataclasses import dataclass, field
 from itertools import groupby
 from typing import ClassVar
@@ -77,6 +78,7 @@ class JobProgress:
             JOB_FAILURE: [],  # contains failed simulation numbers i.e [5,6]
         }
     )
+    errors: defaultdict[str, list[int]] = field(default_factory=lambda: defaultdict(list))
     STATUS_COLOR: ClassVar = {
         JOB_RUNNING: Fore.BLUE,
         JOB_SUCCESS: Fore.GREEN,
@@ -217,7 +219,7 @@ def _get_progress_summary(status):
 
     @classmethod
     def _get_job_states(cls, snapshot: EnsembleSnapshot, show_all_jobs: bool):
-        print_lines = ""
+        print_lines = []
         jobs_status = cls._get_jobs_status(snapshot)
         if not show_all_jobs:
             jobs_status = cls._filter_jobs(jobs_status)
@@ -229,11 +231,18 @@ def _get_job_states(cls, snapshot: EnsembleSnapshot, show_all_jobs: bool):
             for state in [JOB_RUNNING, JOB_SUCCESS, JOB_FAILURE]
         }
         width = _get_max_width([item.name for item in jobs_status])
-        print_lines = cls._join_one_newline_indent(
-            f"{item.name:>{width}}: {item.progress_str(max_widths)}{Fore.RESET}"
-            for item in jobs_status
-        )
-        return print_lines
+        for job in jobs_status:
+            print_lines.append(
+                f"{job.name:>{width}}: {job.progress_str(max_widths)}{Fore.RESET}"
+            )
+            if job.errors:
+                print_lines.extend(
+                    [
+                        f"{Fore.RED}{job.name:>{width}}: Failed: {err}, realizations: {_format_list(job.errors[err])}{Fore.RESET}"
+                        for err in job.errors
+                    ]
+                )
+        return cls._join_one_newline_indent(print_lines)
 
     @staticmethod
     def _get_jobs_status(snapshot: EnsembleSnapshot) -> list[JobProgress]:
@@ -244,6 +253,8 @@ def _get_jobs_status(snapshot: EnsembleSnapshot) -> list[JobProgress]:
             status = job["status"]
             if status in {JOB_RUNNING, JOB_SUCCESS, JOB_FAILURE}:
                 job_progress[job_idx].status[status].append(int(realization))
+                if error := job.get("error"):
+                    job_progress[job_idx].errors[error].append(int(realization))
         return list(job_progress.values())
 
     @staticmethod
diff --git a/tests/ert/__init__.py b/tests/ert/__init__.py
index e3db923e86d..5c1a39b9736 100644
--- a/tests/ert/__init__.py
+++ b/tests/ert/__init__.py
@@ -71,6 +71,7 @@ def add_fm_step(
         end_time: datetime | None = None,
         stdout: str | None = None,
         stderr: str | None = None,
+        error: str | None = None,
     ) -> "SnapshotBuilder":
         self.fm_steps[fm_step_id] = _filter_nones(
             FMStepSnapshot(
@@ -83,6 +84,7 @@ def add_fm_step(
                 stderr=stderr,
                 current_memory_usage=current_memory_usage,
                 max_memory_usage=max_memory_usage,
+                error=error,
             )
         )
         return self
diff --git a/tests/everest/test_everserver.py b/tests/everest/test_everserver.py
index e7b923ed25d..0485461548d 100644
--- a/tests/everest/test_everserver.py
+++ b/tests/everest/test_everserver.py
@@ -24,7 +24,6 @@
 from everest.simulator import JOB_FAILURE
 from everest.strings import (
     OPT_FAILURE_REALIZATIONS,
-    STOP_ENDPOINT,
 )
diff --git a/tests/everest/test_monitor.py b/tests/everest/test_monitor.py
index 53579b7b876..e6c078136ce 100644
--- a/tests/everest/test_monitor.py
+++ b/tests/everest/test_monitor.py
@@ -104,6 +104,63 @@ def snapshot_update_event():
     yield json.dumps(jsonable_encoder(event))
 
 
+@pytest.fixture
+def snapshot_update_failure_event():
+    event = SnapshotUpdateEvent(
+        snapshot=SnapshotBuilder(metadata=METADATA)
+        .add_fm_step(
+            fm_step_id="0",
+            name=None,
+            index="0",
+            status=state.FORWARD_MODEL_STATE_FAILURE,
+            end_time=datetime(2019, 1, 1),
+            error="The run is cancelled due to reaching MAX_RUNTIME",
+        )
+        .build(
+            real_ids=["1"],
+            status=state.REALIZATION_STATE_FAILED,
+        ),
+        iteration_label="Foo",
+        total_iterations=1,
+        progress=0.5,
+        realization_count=4,
+        status_count={"Finished": 0, "Running": 0, "Unknown": 0, "Failed": 1},
+        iteration=0,
+    )
+    yield json.dumps(jsonable_encoder(event))
+
+
+def test_failed_jobs_monitor(
+    monkeypatch, full_snapshot_event, snapshot_update_failure_event, capsys
+):
+    server_mock = MagicMock()
+    connection_mock = MagicMock(spec=ClientConnection)
+    connection_mock.recv.side_effect = [
+        full_snapshot_event,
+        snapshot_update_failure_event,
+        json.dumps(jsonable_encoder(EndEvent(failed=True, msg="Failed"))),
+    ]
+    server_mock.return_value.__enter__.return_value = connection_mock
+    monkeypatch.setattr(everest.detached, "_query_server", MagicMock(return_value={}))
+    monkeypatch.setattr(everest.detached, "connect", server_mock)
+    monkeypatch.setattr(everest.detached, "ssl", MagicMock())
+    patched = partial(everest.detached.start_monitor, polling_interval=0.1)
+    with patch("everest.bin.utils.start_monitor", patched):
+        run_detached_monitor(("some/url", None, None), "output", False)
+    captured = capsys.readouterr()
+    expected = [
+        "===================== Running forward models (Batch #0) ======================\n",
+        "  Waiting: 0 | Pending: 0 | Running: 0 | Finished: 0 | Failed: 1\n",
+        "  fm_step_0: 1/0/1 | Failed: 1\n",
+        "  fm_step_0: Failed: The run is cancelled due to reaching MAX_RUNTIME, realizations: 1\n",
+        "Failed\n",
+    ]
+    # Ignore whitespace
+    assert captured.out.translate({ord(c): None for c in string.whitespace}) == "".join(
+        expected
+    ).translate({ord(c): None for c in string.whitespace})
+
+
 @pytest.mark.parametrize("show_all_jobs", [True, False])
 def test_monitor(
     monkeypatch, full_snapshot_event, snapshot_update_event, capsys, show_all_jobs
@@ -125,12 +182,12 @@ def test_monitor(
     captured = capsys.readouterr()
     expected = [
         "===================== Running forward models (Batch #0) ======================\n",
-        "  Waiting: 0 | Pending: 0 | Running: 0 | Complete: 0 | Failed: 0\n",
-        "  fm_step_0: 0/1/0 | Finished: 1\n",
+        "  Waiting: 0 | Pending: 0 | Running: 0 | Finished: 1 | Failed: 0\n",
+        "  fm_step_0: 1/1/0 | Finished: 1\n",
         "Experiment complete\n",
     ]
     if show_all_jobs:
-        expected[-1:-1] = [f"{name}: 0/0/0" for name in all_shell_script_fm_steps]
+        expected[-1:-1] = [f"{name}: 2/0/0" for name in all_shell_script_fm_steps]
     # Ignore whitespace
     assert captured.out.translate({ord(c): None for c in string.whitespace}) == "".join(
         expected

From c92e178a3265bcc295f71c60bf316a840e153c08 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=C3=98yvind=20Eide?=
Date: Fri, 31 Jan 2025 08:38:57 +0100
Subject: [PATCH 7/9] Decrease polling interval

---
 src/everest/detached/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/everest/detached/__init__.py b/src/everest/detached/__init__.py
index d8ace4c12f8..d6890b2c903 100644
--- a/src/everest/detached/__init__.py
+++ b/src/everest/detached/__init__.py
@@ -245,7 +245,7 @@ def server_is_running(url: str, cert: str, auth: tuple[str, str]):
 
 
 def start_monitor(
-    server_context: tuple[str, str, tuple[str, str]], callback, polling_interval=5
+    server_context: tuple[str, str, tuple[str, str]], callback, polling_interval=0.1
 ):
     """
     Checks status on Everest server and calls callback when status changes

From 1eb5ca02da8e230d82a3ccae755a21891d2746e8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=C3=98yvind=20Eide?=
Date: Wed, 5 Feb 2025 14:18:26 +0100
Subject: [PATCH 8/9] Simplify everserver test

Running these tests leaves a lot of dangling threads because the server
is started but never stopped, leading to potential flakiness. This mocks
the server completely, which has the drawback that there is less testing
of the actual server, but the server is also covered by other tests.
---
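Note (illustrative sketch, not applied by git am): the mock_server fixture
introduced below follows pytest's factory-fixture pattern; the fixture yields
a callable, so each test configures the patched server thread with its own
exit code and message. A stripped-down sketch under assumed names
(FakeExperimentStatus and make_server_mock stand in for the real
ExperimentStatus and fixture):

    import pytest


    class FakeExperimentStatus:
        # Stand-in for everserver.ExperimentStatus; only what the sketch needs.
        def __init__(self, exit_code, message=""):
            self.exit_code = exit_code
            self.message = message


    @pytest.fixture
    def make_server_mock():
        def configure(exit_code, message=""):
            def server_mock(shared_data, server_config):
                # Like the patched _everserver_thread: publish a terminal
                # status instead of running a real optimization.
                shared_data["experiment_status"] = FakeExperimentStatus(
                    exit_code, message
                )

            return server_mock

        return configure


    def test_exit_code_is_propagated(make_server_mock):
        shared_data = {}
        make_server_mock(6, "Some message")(shared_data, None)
        assert shared_data["experiment_status"].exit_code == 6

Because the factory is called inside the test body, each test states its
scenario (exit code 1, 2, 6, ...) at the call site instead of duplicating a
server_mock definition per test.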
 tests/everest/test_everserver.py | 63 +++++++++++---------------------
 1 file changed, 22 insertions(+), 41 deletions(-)

diff --git a/tests/everest/test_everserver.py b/tests/everest/test_everserver.py
index 0485461548d..129da84c66f 100644
--- a/tests/everest/test_everserver.py
+++ b/tests/everest/test_everserver.py
@@ -19,9 +19,8 @@
     wait_for_server,
 )
 from everest.detached.jobs import everserver
-from everest.detached.jobs.everserver import ExperimentStatus, _everserver_thread
+from everest.detached.jobs.everserver import ExperimentStatus
 from everest.everest_storage import EverestStorage
-from everest.simulator import JOB_FAILURE
 from everest.strings import (
     OPT_FAILURE_REALIZATIONS,
 )
@@ -71,17 +70,22 @@ def fail_optimization(self, from_ropt=False):
     raise Exception("Failed optimization")
 
 
-def set_shared_status(*args, progress, shared_data):
-    # Patch _sim_monitor with this to access the shared_data variable in the
-    # everserver main function.
-    failed = len(
-        [job for queue in progress for job in queue if job["status"] == JOB_FAILURE]
-    )
+@pytest.fixture
+def mock_server(monkeypatch):
+    def func(exit_code: int, message: str = ""):
+        def server_mock(shared_data, server_config):
+            shared_data["experiment_status"] = ExperimentStatus(
+                exit_code=exit_code, message=message
+            )
 
-    shared_data["events"] = {
-        "status": {"failed": failed},
-        "progress": progress,
-    }
+        monkeypatch.setattr(
+            everest.detached.jobs.everserver, "_everserver_thread", server_mock
+        )
+        monkeypatch.setattr(
+            everest.detached.jobs.everserver, "wait_for_server", MagicMock()
+        )
+
+    yield func
 
 
 @pytest.mark.integration_test
@@ -147,15 +151,9 @@ def mock_end_event(monkeypatch):
 @patch("sys.argv", ["name", "--config-file", "config_minimal.yml"])
 @patch("everest.detached.jobs.everserver._configure_loggers")
 def test_status_running_complete(
-    mocked_logger, copy_math_func_test_data_to_tmp, monkeypatch
+    mocked_logger, copy_math_func_test_data_to_tmp, mock_server
 ):
-    def server_mock(shared_data, server_config):
-        shared_data["experiment_status"] = ExperimentStatus(exit_code=1)
-        _everserver_thread(shared_data, server_config)
-
-    monkeypatch.setattr(
-        everest.detached.jobs.everserver, "_everserver_thread", server_mock
-    )
+    mock_server(1)
 
     config_file = "config_minimal.yml"
     config = EverestConfig.load_file(config_file)
@@ -171,18 +169,11 @@
 
 @patch("sys.argv", ["name", "--config-file", "config_minimal.yml"])
 @patch("everest.detached.jobs.everserver._configure_loggers")
-def test_status_failed_job(mocked_logger, copy_math_func_test_data_to_tmp, monkeypatch):
+def test_status_failed_job(mocked_logger, copy_math_func_test_data_to_tmp, mock_server):
     config_file = "config_minimal.yml"
     config = EverestConfig.load_file(config_file)
 
-    def server_mock(shared_data, server_config):
-        shared_data["experiment_status"] = ExperimentStatus(exit_code=2)
-        _everserver_thread(shared_data, server_config)
-
-    monkeypatch.setattr(
-        everest.detached.jobs.everserver, "_everserver_thread", server_mock
-    )
-
+    mock_server(2)
     everserver.main()
 
     status = everserver_status(
@@ -195,20 +186,10 @@
 
 @patch("sys.argv", ["name", "--config-file", "config_minimal.yml"])
-def test_everserver_status_exception(copy_math_func_test_data_to_tmp, monkeypatch):
+def test_everserver_status_exception(copy_math_func_test_data_to_tmp, mock_server):
     config_file = "config_minimal.yml"
     config = EverestConfig.load_file(config_file)
-
-    def server_mock(shared_data, server_config):
-        shared_data["experiment_status"] = ExperimentStatus(
-            exit_code=6, message="Some message"
-        )
-        _everserver_thread(shared_data, server_config)
-
-    monkeypatch.setattr(
-        everest.detached.jobs.everserver, "_everserver_thread", server_mock
-    )
-
+    mock_server(6, "Some message")
     everserver.main()
     status = everserver_status(
         ServerConfig.get_everserver_status_path(config.output_dir)
     )

From a0827e3a5b66233c372621932a9a6cafb8b03c58 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=C3=98yvind=20Eide?=
Date: Wed, 5 Feb 2025 09:25:52 +0100
Subject: [PATCH 9/9] Remove print statement

---
 tests/everest/test_everserver.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/everest/test_everserver.py b/tests/everest/test_everserver.py
index 129da84c66f..ccddc21d713 100644
--- a/tests/everest/test_everserver.py
+++ b/tests/everest/test_everserver.py
@@ -253,5 +253,4 @@ async def test_status_contains_max_runtime_failure(
     )
 
     assert status["status"] == ServerStatus.failed
-    print(status["message"])
     assert "The run is cancelled due to reaching MAX_RUNTIME" in status["message"]
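
A closing note on the comparison idiom the monitor tests in this series rely
on: str.translate with a table that maps every whitespace code point to None
deletes all whitespace in one pass, which makes the expected-output assertions
insensitive to alignment and line wrapping. A minimal standalone example
(standard library only; the strings are made up for illustration):

    import string

    # Map every whitespace character's code point to None, deleting it.
    DROP_WS = {ord(c): None for c in string.whitespace}

    captured = "  fm_step_0:   0/1/0 | Finished: 1\n"
    expected = "fm_step_0: 0/1/0 | Finished: 1"
    assert captured.translate(DROP_WS) == expected.translate(DROP_WS)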