diff --git a/pyproject.toml b/pyproject.toml index cc5b4be1e99..59e2d8a5f00 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -132,6 +132,7 @@ types = [ "types-setuptools", ] everest = [ + "websockets", "progressbar2", "ruamel.yaml", "fastapi", diff --git a/src/ert/run_models/everest_run_model.py b/src/ert/run_models/everest_run_model.py index 810912b407b..9359195be5a 100644 --- a/src/ert/run_models/everest_run_model.py +++ b/src/ert/run_models/everest_run_model.py @@ -23,17 +23,17 @@ from ropt.plan import Event as OptimizerEvent from typing_extensions import TypedDict -from _ert.events import EESnapshot, EESnapshotUpdate, Event from ert.config import ErtConfig, ExtParamConfig -from ert.ensemble_evaluator import EnsembleSnapshot, EvaluatorServerConfig +from ert.ensemble_evaluator import EndEvent, EvaluatorServerConfig from ert.runpaths import Runpaths from ert.storage import open_storage from everest.config import ControlConfig, ControlVariableGuessListConfig, EverestConfig +from everest.detached import ServerStatus from everest.everest_storage import EverestStorage, OptimalResult from everest.optimizer.everest2ropt import everest2ropt from everest.optimizer.opt_model_transforms import get_opt_model_transforms from everest.simulator.everest_to_ert import everest_to_ert_config -from everest.strings import EVEREST +from everest.strings import EVEREST, OPT_FAILURE_REALIZATIONS from ..run_arg import RunArg, create_run_arguments from .base_run_model import BaseRunModel, StatusEvents @@ -83,12 +83,12 @@ def __init__( self, config: ErtConfig, everest_config: EverestConfig, - simulation_callback: SimulationCallback | None, optimization_callback: OptimizerCallback | None, + status_queue: queue.SimpleQueue[StatusEvents] | None = None, ): Path(everest_config.log_dir).mkdir(parents=True, exist_ok=True) Path(everest_config.optimization_output_dir).mkdir(parents=True, exist_ok=True) - + status_queue = queue.SimpleQueue() if status_queue is None else status_queue assert everest_config.environment is not None logging.getLogger(EVEREST).info( "Using random seed: %d. 
To deterministically reproduce this experiment, " @@ -102,7 +102,6 @@ def __init__( everest_config, transforms=self._opt_model_transforms ) - self._sim_callback = simulation_callback self._opt_callback = optimization_callback self._fm_errors: dict[int, dict[str, Any]] = {} self._result: OptimalResult | None = None @@ -118,11 +117,8 @@ def __init__( self._experiment: Experiment | None = None self._eval_server_cfg: EvaluatorServerConfig | None = None self._batch_id: int = 0 - self._status: SimulationStatus | None = None storage = open_storage(config.ens_path, mode="w") - status_queue: queue.SimpleQueue[StatusEvents] = queue.SimpleQueue() - super().__init__( storage, config.runpath_file, @@ -149,12 +145,13 @@ def create( ever_config: EverestConfig, simulation_callback: SimulationCallback | None = None, optimization_callback: OptimizerCallback | None = None, + status_queue: queue.SimpleQueue[StatusEvents] | None = None, ) -> EverestRunModel: return cls( config=everest_to_ert_config(ever_config), everest_config=ever_config, - simulation_callback=simulation_callback, optimization_callback=optimization_callback, + status_queue=status_queue, ) @classmethod @@ -213,6 +210,12 @@ def run_experiment( self._exit_code = EverestExitCode.TOO_FEW_REALIZATIONS case _: self._exit_code = EverestExitCode.COMPLETED + self.send_event( + EndEvent( + failed=self._exit_code != EverestExitCode.COMPLETED, + msg=_get_optimization_status(self._exit_code, "", False)[1], + ) + ) def _create_optimizer(self) -> BasicOptimizer: optimizer = BasicOptimizer( @@ -253,9 +256,6 @@ def _on_before_forward_model_evaluation( def _forward_model_evaluator( self, control_values: NDArray[np.float64], evaluator_context: EvaluatorContext ) -> EvaluatorResult: - # Reset the current run status: - self._status = None - # Get cached_results: cached_results = self._get_cached_results(control_values, evaluator_context) @@ -269,6 +269,7 @@ def _forward_model_evaluator( ensemble = self._experiment.create_ensemble( name=f"batch_{self._batch_id}", ensemble_size=len(batch_data), + iteration=self._batch_id, ) for sim_id, controls in enumerate(batch_data.values()): self._setup_sim(sim_id, controls, ensemble) @@ -420,10 +421,10 @@ def _get_run_args( batch_data: dict[int, Any], ) -> list[RunArg]: substitutions = self._substitutions - substitutions[""] = ensemble.name + substitutions[""] = ensemble.name # Dark magic, should be fixed self.active_realizations = [True] * len(batch_data) for sim_id, control_idx in enumerate(batch_data.keys()): - substitutions[f""] = str( + substitutions[f""] = str( self._everest_config.model.realizations[ evaluator_context.realizations[control_idx] ] @@ -565,82 +566,6 @@ def check_if_runpath_exists(self) -> bool: and any(os.listdir(self._everest_config.simulation_dir)) ) - def send_snapshot_event(self, event: Event, iteration: int) -> None: - super().send_snapshot_event(event, iteration) - if type(event) in {EESnapshot, EESnapshotUpdate}: - newstatus = self._simulation_status(self.get_current_snapshot()) - if self._status != newstatus: # No change in status - if self._sim_callback is not None: - self._sim_callback(newstatus) - self._status = newstatus - - def _simulation_status(self, snapshot: EnsembleSnapshot) -> SimulationStatus: - jobs_progress: list[list[JobProgress]] = [] - prev_realization = None - jobs: list[JobProgress] = [] - for (realization, simulation), fm_step in snapshot.get_all_fm_steps().items(): - if realization != prev_realization: - prev_realization = realization - if jobs: - jobs_progress.append(jobs) - 
jobs = [] - jobs.append( - { - "name": fm_step.get("name") or "Unknown", - "status": fm_step.get("status") or "Unknown", - "error": fm_step.get("error", ""), - "start_time": fm_step.get("start_time", None), - "end_time": fm_step.get("end_time", None), - "realization": realization, - "simulation": simulation, - } - ) - if fm_step.get("error", ""): - self._handle_errors( - batch=self._batch_id, - simulation=simulation, - realization=realization, - fm_name=fm_step.get("name", "Unknown"), # type: ignore - error_path=fm_step.get("stderr", ""), # type: ignore - fm_running_err=fm_step.get("error", ""), # type: ignore - ) - jobs_progress.append(jobs) - - return { - "status": self.get_current_status(), - "progress": jobs_progress, - "batch_number": self._batch_id, - } - - def _handle_errors( - self, - batch: int, - simulation: Any, - realization: str, - fm_name: str, - error_path: str, - fm_running_err: str, - ) -> None: - fm_id = f"b_{batch}_r_{realization}_s_{simulation}_{fm_name}" - fm_logger = logging.getLogger("forward_models") - if Path(error_path).is_file(): - error_str = Path(error_path).read_text(encoding="utf-8") or fm_running_err - else: - error_str = fm_running_err - error_hash = hash(error_str) - err_msg = "Batch: {} Realization: {} Simulation: {} Job: {} Failed {}".format( - batch, realization, simulation, fm_name, "\n Error: {} ID:{}" - ) - - if error_hash not in self._fm_errors: - error_id = len(self._fm_errors) - fm_logger.error(err_msg.format(error_str, error_id)) - self._fm_errors.update({error_hash: {"error_id": error_id, "ids": [fm_id]}}) - elif fm_id not in self._fm_errors[error_hash]["ids"]: - self._fm_errors[error_hash]["ids"].append(fm_id) - error_id = self._fm_errors[error_hash]["error_id"] - fm_logger.error(err_msg.format("Already reported as", error_id)) - class SimulatorCache: EPS = float(np.finfo(np.float32).eps) @@ -691,3 +616,30 @@ def get( if np.allclose(controls, control_values, rtol=0.0, atol=self.EPS): return objectives, constraints return None + + +def _get_optimization_status( + exit_code: int, exception: str, stopped: bool +) -> tuple[ServerStatus, str]: + match exit_code: + case EverestExitCode.MAX_BATCH_NUM_REACHED: + return ServerStatus.completed, "Maximum number of batches reached." + + case EverestExitCode.MAX_FUNCTIONS_REACHED: + return ( + ServerStatus.completed, + "Maximum number of function evaluations reached.", + ) + + case EverestExitCode.USER_ABORT: + return ServerStatus.stopped, "Optimization aborted." + + case EverestExitCode.EXCEPTION: + assert exception is not None + return ServerStatus.failed, exception + + case EverestExitCode.TOO_FEW_REALIZATIONS: + status = ServerStatus.stopped if stopped else ServerStatus.failed + return status, OPT_FAILURE_REALIZATIONS + case _: + return ServerStatus.completed, "Optimization completed." 
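
The everest_run_model.py changes above drop the simulation_callback in favour of a status_queue of StatusEvents, with run_experiment() publishing an EndEvent when it finishes and _get_optimization_status() mapping the exit code to a user-facing message. A rough sketch of how a caller can consume that queue (not part of the patch; the config file name, the worker thread, and the default EvaluatorServerConfig() constructor are illustrative assumptions):

import queue
import threading

from ert.ensemble_evaluator import EndEvent, EvaluatorServerConfig
from ert.run_models.everest_run_model import EverestRunModel
from everest.config import EverestConfig

# Hypothetical config file; any valid Everest config would do.
ever_config = EverestConfig.load_file("config_minimal.yml")

# The run model now publishes StatusEvents (snapshots, updates, EndEvent) on
# this queue instead of invoking a simulation_callback.
status_queue = queue.SimpleQueue()
run_model = EverestRunModel.create(ever_config, status_queue=status_queue)

# Run the experiment in a worker thread and drain events until the EndEvent
# emitted at the end of run_experiment() arrives.
worker = threading.Thread(
    target=run_model.run_experiment, args=(EvaluatorServerConfig(),)
)
worker.start()
while True:
    event = status_queue.get()
    if isinstance(event, EndEvent):
        print(event.msg)
        break
worker.join()

The detached everserver and the CLI monitor below follow the same pattern, but forward the queued events to clients over a websocket instead of reading the queue directly.
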
diff --git a/src/everest/bin/utils.py b/src/everest/bin/utils.py index ef73ce988ec..d28aa96fbde 100644 --- a/src/everest/bin/utils.py +++ b/src/everest/bin/utils.py @@ -2,6 +2,7 @@ import os import sys import traceback +from collections import defaultdict from dataclasses import dataclass, field from itertools import groupby from typing import ClassVar @@ -9,17 +10,21 @@ import colorama from colorama import Fore +from ert.ensemble_evaluator import ( + EnsembleSnapshot, + FullSnapshotEvent, + SnapshotUpdateEvent, +) from ert.resources import all_shell_script_fm_steps from everest.detached import ( OPT_PROGRESS_ID, - SIM_PROGRESS_ID, ServerStatus, everserver_status, get_opt_status, start_monitor, ) from everest.simulator import JOB_FAILURE, JOB_RUNNING, JOB_SUCCESS -from everest.strings import EVEREST +from everest.strings import EVEREST, SIM_PROGRESS_ID def handle_keyboard_interrupt(signal, frame, options): @@ -73,6 +78,7 @@ class JobProgress: JOB_FAILURE: [], # contains failed simulation numbers i.e [5,6] } ) + errors: defaultdict[list] = field(default_factory=lambda: defaultdict(list)) STATUS_COLOR: ClassVar = { JOB_RUNNING: Fore.BLUE, JOB_SUCCESS: Fore.GREEN, @@ -111,6 +117,7 @@ def __init__(self, show_all_jobs): self._batches_done = set() self._last_reported_batch = -1 colorama.init(autoreset=True) + self._snapshots: dict[int, EnsembleSnapshot] = {} def update(self, status): try: @@ -126,16 +133,37 @@ def update(self, status): print(msg + "\n") self._clear_lines = 0 if SIM_PROGRESS_ID in status: - sim_progress = status[SIM_PROGRESS_ID] - sim_progress["progress"] = self._filter_jobs(sim_progress["progress"]) - msg, batch = self.get_fm_progress(sim_progress) - if msg.strip(): - # Clear the previous report if it is still the same batch: - if batch == self._last_reported_batch: - self._clear() - print(msg) - self._clear_lines = len(msg.split("\n")) - self._last_reported_batch = batch + match status[SIM_PROGRESS_ID]: + case FullSnapshotEvent(snapshot=snapshot, iteration=batch): + if snapshot is not None: + self._snapshots[batch] = snapshot + case ( + SnapshotUpdateEvent(snapshot=snapshot, iteration=batch) as event + ): + if snapshot is not None: + batch_number = event.iteration + self._snapshots[batch_number].merge_snapshot(snapshot) + header = self._make_header( + f"Running forward models (Batch #{batch_number})", + Fore.BLUE, + ) + summary = self._get_progress_summary(event.status_count) + job_states = self._get_job_states( + self._snapshots[batch_number], self._show_all_jobs + ) + msg = ( + self._join_two_newlines_indent( + (header, summary, job_states) + ) + + "\n" + ) + if batch == self._last_reported_batch: + self._clear() + print(msg) + self._clear_lines = len(msg.split("\n")) + self._last_reported_batch = max( + self._last_reported_batch, batch + ) except: logging.getLogger(EVEREST).debug(traceback.format_exc()) @@ -173,36 +201,28 @@ def _get_opt_progress_batch(self, cli_monitor_data, batch, idx): (header, controls, objectives, total_objective) ) - def get_fm_progress(self, context_status): - batch_number = int(context_status["batch_number"]) - header = self._make_header( - f"Running forward models (Batch #{batch_number})", Fore.BLUE - ) - summary = self._get_progress_summary(context_status["status"]) - job_states = self._get_job_states(context_status["progress"]) - msg = self._join_two_newlines_indent((header, summary, job_states)) + "\n" - return msg, batch_number - @staticmethod def _get_progress_summary(status): colors = [ Fore.BLACK, Fore.BLACK, - Fore.BLUE if status["running"] 
> 0 else Fore.BLACK, - Fore.GREEN if status["complete"] > 0 else Fore.BLACK, - Fore.RED if status["failed"] > 0 else Fore.BLACK, + Fore.BLUE if status.get("Running", 0) > 0 else Fore.BLACK, + Fore.GREEN if status.get("Finished", 0) > 0 else Fore.BLACK, + Fore.RED if status.get("Failed", 0) > 0 else Fore.BLACK, ] - labels = ("Waiting", "Pending", "Running", "Complete", "FAILED") - values = [status.get(ls.lower(), 0) for ls in labels] + labels = ("Waiting", "Pending", "Running", "Finished", "Failed") + values = [status.get(ls, 0) for ls in labels] return " | ".join( f"{color}{key}: {value}{Fore.RESET}" for color, key, value in zip(colors, labels, values, strict=False) ) @classmethod - def _get_job_states(cls, progress): - print_lines = "" - jobs_status = cls._get_jobs_status(progress) + def _get_job_states(cls, snapshot: EnsembleSnapshot, show_all_jobs: bool): + print_lines = [] + jobs_status = cls._get_jobs_status(snapshot) + if not show_all_jobs: + jobs_status = cls._filter_jobs(jobs_status) if jobs_status: max_widths = { state: _get_max_width( @@ -211,36 +231,35 @@ def _get_job_states(cls, progress): for state in [JOB_RUNNING, JOB_SUCCESS, JOB_FAILURE] } width = _get_max_width([item.name for item in jobs_status]) - print_lines = cls._join_one_newline_indent( - f"{item.name:>{width}}: {item.progress_str(max_widths)}{Fore.RESET}" - for item in jobs_status - ) - return print_lines + for job in jobs_status: + print_lines.append( + f"{job.name:>{width}}: {job.progress_str(max_widths)}{Fore.RESET}" + ) + if job.errors: + print_lines.extend( + [ + f"{Fore.RED}{job.name:>{width}}: Failed: {err}, realizations: {_format_list(job.errors[err])}{Fore.RESET}" + for err in job.errors + ] + ) + return cls._join_one_newline_indent(print_lines) @staticmethod - def _get_jobs_status(progress): + def _get_jobs_status(snapshot: EnsembleSnapshot) -> list[JobProgress]: job_progress = {} - for queue in progress: - for job_idx, job in enumerate(queue): - if job_idx not in job_progress: - job_progress[job_idx] = JobProgress(name=job["name"]) - realization = int(job["realization"]) - status = job["status"] - if status in {JOB_RUNNING, JOB_SUCCESS, JOB_FAILURE}: - job_progress[job_idx].status[status].append(realization) - return job_progress.values() - - def _filter_jobs(self, progress): - if not self._show_all_jobs: - progress = [ - [ - job - for job in progress_list - if job["name"] not in all_shell_script_fm_steps - ] - for progress_list in progress - ] - return progress + for (realization, job_idx), job in snapshot.get_all_fm_steps().items(): + if job_idx not in job_progress: + job_progress[job_idx] = JobProgress(name=job["name"]) + status = job["status"] + if status in {JOB_RUNNING, JOB_SUCCESS, JOB_FAILURE}: + job_progress[job_idx].status[status].append(int(realization)) + if error := job.get("error"): + job_progress[job_idx].errors[error].append(int(realization)) + return list(job_progress.values()) + + @staticmethod + def _filter_jobs(jobs: list[JobProgress]): + return [job for job in jobs if job.name not in all_shell_script_fm_steps] @classmethod def _join_one_newline_indent(cls, sequence): diff --git a/src/everest/detached/__init__.py b/src/everest/detached/__init__.py index b9e6104a8df..d6890b2c903 100644 --- a/src/everest/detached/__init__.py +++ b/src/everest/detached/__init__.py @@ -4,16 +4,21 @@ import logging import os import re +import ssl import time import traceback from collections.abc import Mapping from enum import Enum from pathlib import Path -from typing import Literal +from typing import 
Annotated, Literal
 
 import polars as pl
 import requests
+from pydantic import BaseModel, ConfigDict, Field, ValidationError
+from websockets.sync.client import connect
 
+from ert.ensemble_evaluator import EndEvent, EnsembleSnapshot
+from ert.run_models import StatusEvents
 from ert.scheduler import create_driver
 from ert.scheduler.driver import Driver, FailedSubmit
 from ert.scheduler.event import StartedEvent
@@ -23,7 +28,6 @@
     EVEREST_SERVER_CONFIG,
     OPT_PROGRESS_ENDPOINT,
     OPT_PROGRESS_ID,
-    SIM_PROGRESS_ENDPOINT,
     SIM_PROGRESS_ID,
     START_EXPERIMENT_ENDPOINT,
     STOP_ENDPOINT,
@@ -42,6 +46,14 @@
 # everest.log file instead
 
 
+logger = logging.getLogger(__name__)
+
+
+class EventWrapper(BaseModel):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+    event: Annotated[StatusEvents, Field(discriminator="event_type")]
+
+
 async def start_server(config: EverestConfig, debug: bool = False) -> Driver:
     """
     Start an Everest server running the optimization defined in the config
@@ -197,7 +209,7 @@ def get_opt_status(output_folder):
 
 def wait_for_server_to_stop(server_context: tuple[str, str, tuple[str, str]], timeout):
     """
-    Checks everest server has stoped _HTTP_REQUEST_RETRY times. Waits
+    Checks everest server has stopped _HTTP_REQUEST_RETRY times. Waits
     progressively longer between each check. Raise an exception when the
     timeout is reached.
 
@@ -233,7 +245,7 @@ def server_is_running(url: str, cert: str, auth: tuple[str, str]):
 
 
 def start_monitor(
-    server_context: tuple[str, str, tuple[str, str]], callback, polling_interval=5
+    server_context: tuple[str, str, tuple[str, str]], callback, polling_interval=0.1
 ):
     """
     Checks status on Everest server and calls callback when status changes
@@ -242,29 +254,45 @@ def start_monitor(
     interrupted by returning True from the callback
     """
     url, cert, auth = server_context
-    sim_endpoint = "/".join([url, SIM_PROGRESS_ENDPOINT])
     opt_endpoint = "/".join([url, OPT_PROGRESS_ENDPOINT])
 
-    sim_status: dict = {}
     opt_status: dict = {}
     stop = False
 
+    ssl_context = ssl.create_default_context()
+    ssl_context.load_verify_locations(cafile=cert)
+
     try:
-        while not stop:
-            new_sim_status = _query_server(cert, auth, sim_endpoint)
-            if new_sim_status != sim_status:
-                sim_status = new_sim_status
-                ret = bool(callback({SIM_PROGRESS_ID: sim_status}))
-                stop |= ret
-            # When the API will support it query only from a certain batch on
-
-            # Check the optimization status
-            new_opt_status = _query_server(cert, auth, opt_endpoint)
-            if new_opt_status != opt_status:
-                opt_status = new_opt_status
-                ret = bool(callback({OPT_PROGRESS_ID: opt_status}))
-                stop |= ret
-            time.sleep(polling_interval)
+        with connect(
+            "wss://{username}:{password}@" + url.replace("https://", "") + "/events",
+            ssl=ssl_context,
+            open_timeout=30,
+        ) as websocket:
+            while not stop:
+                try:
+                    message = websocket.recv(timeout=1.0)
+                except TimeoutError:
+                    message = None
+                if message:
+                    event_dict = json.loads(message)
+                    if "snapshot" in event_dict:
+                        event_dict["snapshot"] = EnsembleSnapshot.from_nested_dict(
+                            event_dict["snapshot"]
+                        )
+                    try:
+                        event = EventWrapper(event=event_dict).event
+                    except ValidationError as e:
+                        logger.error("Error when processing event %s", message, exc_info=e)
+                    if isinstance(event, EndEvent):
+                        print(event.msg)
+                    callback({SIM_PROGRESS_ID: event})
+                # Check the optimization status
+                new_opt_status = _query_server(cert, auth, opt_endpoint)
+                if new_opt_status != opt_status:
+                    opt_status = new_opt_status
+                    ret = bool(callback({OPT_PROGRESS_ID: opt_status}))
+                    stop |= ret
+                time.sleep(polling_interval)
except: logging.debug(traceback.format_exc()) diff --git a/src/everest/detached/jobs/everserver.py b/src/everest/detached/jobs/everserver.py index 64bef7b14ca..07beec73263 100755 --- a/src/everest/detached/jobs/everserver.py +++ b/src/everest/detached/jobs/everserver.py @@ -1,19 +1,22 @@ import argparse +import asyncio import datetime import json import logging +import multiprocessing as mp import os +import queue import socket import ssl import threading import time import traceback +import uuid from base64 import b64encode from functools import partial from pathlib import Path from typing import Any -import requests import uvicorn from cryptography import x509 from cryptography.hazmat.backends import default_backend @@ -21,7 +24,15 @@ from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.x509.oid import NameOID from dns import resolver, reversename -from fastapi import Depends, FastAPI, HTTPException, Request, status +from fastapi import ( + BackgroundTasks, + Depends, + FastAPI, + HTTPException, + Request, + WebSocket, + status, +) from fastapi.encoders import jsonable_encoder from fastapi.responses import ( JSONResponse, @@ -35,26 +46,30 @@ from pydantic import BaseModel from ert.config.parsing.queue_system import QueueSystem -from ert.ensemble_evaluator import EvaluatorServerConfig -from ert.run_models.everest_run_model import EverestExitCode, EverestRunModel +from ert.ensemble_evaluator import ( + EndEvent, + EvaluatorServerConfig, + FullSnapshotEvent, + SnapshotUpdateEvent, +) +from ert.run_models import StatusEvents +from ert.run_models.everest_run_model import ( + EverestExitCode, + EverestRunModel, + _get_optimization_status, +) from everest.config import EverestConfig, ServerConfig from everest.detached import ( - PROXY, ServerStatus, get_opt_status, update_everserver_status, wait_for_server, ) from everest.plugins.everest_plugin_manager import EverestPluginManager -from everest.simulator import JOB_FAILURE from everest.strings import ( DEFAULT_LOGGING_FORMAT, EVEREST, - EXPERIMENT_STATUS_ENDPOINT, - OPT_FAILURE_REALIZATIONS, OPT_PROGRESS_ENDPOINT, - SHARED_DATA_ENDPOINT, - SIM_PROGRESS_ENDPOINT, START_EXPERIMENT_ENDPOINT, STOP_ENDPOINT, ) @@ -66,19 +81,19 @@ class ExperimentStatus(BaseModel): message: str | None = None -class ExperimentRunner(threading.Thread): +class ExperimentRunner: def __init__(self, everest_config, shared_data: dict): super().__init__() self._everest_config = everest_config - self._shared_data = shared_data - self._status: ExperimentStatus | None = None + self.shared_data = shared_data - def run(self): + async def run(self): + status_queue: mp.Queue[StatusEvents] = mp.Queue() run_model = EverestRunModel.create( self._everest_config, - simulation_callback=partial(_sim_monitor, shared_data=self._shared_data), - optimization_callback=partial(_opt_monitor, shared_data=self._shared_data), + optimization_callback=partial(_opt_monitor, shared_data=self.shared_data), + status_queue=status_queue, ) if run_model._queue_config.queue_system == QueueSystem.LOCAL: @@ -89,22 +104,50 @@ def run(self): ) try: - run_model.run_experiment(evaluator_server_config) - + loop = asyncio.get_running_loop() + simulation_future = loop.run_in_executor( + None, + lambda: run_model.run_experiment(evaluator_server_config), + ) + while True: + if self.shared_data[STOP_ENDPOINT]: + run_model.cancel() + raise ValueError("Optimization aborted") + try: + item: StatusEvents = status_queue.get(block=False) + except queue.Empty: + await asyncio.sleep(0.01) + continue + + 
self.shared_data["events"].append(item) + for sub in self.shared_data["subscribers"].values(): + sub.notify() + + if isinstance(item, EndEvent): + break + await asyncio.sleep(0.1) + await simulation_future assert run_model.exit_code is not None - self._status = ExperimentStatus(exit_code=run_model.exit_code) + self.shared_data["experiment_status"] = ExperimentStatus( + exit_code=run_model.exit_code + ) except Exception as e: - self._status = ExperimentStatus( + self.shared_data["experiment_status"] = ExperimentStatus( exit_code=EverestExitCode.EXCEPTION, message=str(e) ) - @property - def status(self) -> ExperimentStatus | None: - return self._status - @property - def shared_data(self) -> dict: - return self._shared_data +class Subscriber: + def __init__(self) -> None: + self.index = 0 + self._event = asyncio.Event() + + def notify(self): + self._event.set() + + async def wait_for_event(self): + await self._event.wait() + self._event.clear() def _get_machine_name() -> str: @@ -134,26 +177,6 @@ def _get_machine_name() -> str: return "localhost" -def _sim_monitor(context_status, shared_data=None): - assert shared_data is not None - - status = context_status["status"] - shared_data[SIM_PROGRESS_ENDPOINT] = { - "batch_number": context_status["batch_number"], - "status": { - "running": status.get("Running", 0), - "waiting": status.get("Waiting", 0), - "pending": status.get("Pending", 0), - "complete": status.get("Finished", 0), - "failed": status.get("Failed", 0), - }, - "progress": context_status["progress"], - } - - if shared_data[STOP_ENDPOINT]: - return "stop_queue" - - def _opt_monitor(shared_data=None): assert shared_data is not None if shared_data[STOP_ENDPOINT]: @@ -164,8 +187,6 @@ def _everserver_thread(shared_data, server_config) -> None: app = FastAPI() security = HTTPBasic() - runner: ExperimentRunner | None = None - def _check_user(credentials: HTTPBasicCredentials) -> None: if credentials.password != server_config["authentication"]: raise HTTPException( @@ -196,15 +217,6 @@ def stop( shared_data[STOP_ENDPOINT] = True return Response("Raise STOP flag succeeded. 
Everest initiates shutdown..", 200) - @app.get("/" + SIM_PROGRESS_ENDPOINT) - def get_sim_progress( - request: Request, credentials: HTTPBasicCredentials = Depends(security) - ) -> JSONResponse: - _log(request) - _check_user(credentials) - progress = shared_data[SIM_PROGRESS_ENDPOINT] - return JSONResponse(jsonable_encoder(progress)) - @app.get("/" + OPT_PROGRESS_ENDPOINT) def get_opt_progress( request: Request, credentials: HTTPBasicCredentials = Depends(security) @@ -215,50 +227,49 @@ def get_opt_progress( return JSONResponse(jsonable_encoder(progress)) @app.post("/" + START_EXPERIMENT_ENDPOINT) - def start_experiment( + async def start_experiment( config: EverestConfig, request: Request, + background_tasks: BackgroundTasks, credentials: HTTPBasicCredentials = Depends(security), ) -> Response: _log(request) _check_user(credentials) - nonlocal runner - if runner is None: + if not shared_data["started"]: runner = ExperimentRunner(config, shared_data) try: - runner.start() + background_tasks.add_task(runner.run) + shared_data["started"] = True return Response("Everest experiment started") except Exception as e: return Response(f"Could not start experiment: {e!s}", status_code=501) return Response("Everest experiment is running") - @app.get("/" + EXPERIMENT_STATUS_ENDPOINT) - def get_experiment_status( - request: Request, credentials: HTTPBasicCredentials = Depends(security) - ) -> Response: - _log(request) - _check_user(credentials) - if shared_data[STOP_ENDPOINT]: - return JSONResponse( - ExperimentStatus(exit_code=EverestExitCode.USER_ABORT).model_dump_json() - ) - if runner is None: - return Response(None, 204) - status = runner.status - if status is None: - return Response(None, 204) - return JSONResponse(status.model_dump_json()) - - @app.get("/" + SHARED_DATA_ENDPOINT) - def get_shared_data( - request: Request, credentials: HTTPBasicCredentials = Depends(security) - ) -> JSONResponse: - _log(request) - _check_user(credentials) - if runner is None: - return JSONResponse(jsonable_encoder(shared_data)) - return JSONResponse(jsonable_encoder(runner.shared_data)) + @app.websocket("/events") + async def websocket_endpoint(websocket: WebSocket): + subscriber_id = str(uuid.uuid4()) + await websocket.accept() + while True: + event = await get_event(subscriber_id=subscriber_id) + await websocket.send_json(jsonable_encoder(event)) + await asyncio.sleep(0.1) + if isinstance(event, EndEvent): + # Give some time for subscribers to get events + await asyncio.sleep(5) + break + + async def get_event(subscriber_id: str) -> StatusEvents: + if subscriber_id not in shared_data["subscribers"]: + shared_data["subscribers"][subscriber_id] = Subscriber() + subscriber = shared_data["subscribers"][subscriber_id] + + while subscriber.index >= len(shared_data["events"]): + await subscriber.wait_for_event() + + event = shared_data["events"][subscriber.index] + shared_data["subscribers"][subscriber_id].index += 1 + return event uvicorn.run( app, @@ -341,6 +352,17 @@ def make_handler_config( EverestPluginManager().add_log_handle_to_root() +def send_end_event(shared_data, failed, msg): + shared_data["events"].append( + EndEvent( + failed=failed, + msg=msg, + ) + ) + for sub in shared_data["subscribers"].values(): + sub.notify() + + def main(): arg_parser = argparse.ArgumentParser() arg_parser.add_argument("--config-file", type=str) @@ -375,8 +397,11 @@ def main(): _write_hostfile(host_file, host, port, cert_path, authentication) shared_data = { - SIM_PROGRESS_ENDPOINT: {}, STOP_ENDPOINT: False, + "started": False, + 
"experiment_status": None, + "events": [], + "subscribers": {}, } server_config = { @@ -402,51 +427,38 @@ def main(): message=traceback.format_exc(), ) return - try: wait_for_server(config.output_dir, 60) update_everserver_status(status_path, ServerStatus.running) - server_context = (ServerConfig.get_server_context(config.output_dir),) - url, cert, auth = server_context[0] - - done = False - experiment_status: ExperimentStatus | None = None # loop until the optimization is done - while not done: - response = requests.get( - "/".join([url, EXPERIMENT_STATUS_ENDPOINT]), - verify=cert, - auth=auth, - timeout=1, - proxies=PROXY, # type: ignore + while not shared_data["experiment_status"]: + if shared_data[STOP_ENDPOINT] and not shared_data["started"]: + raise ValueError("Stopped by user") + time.sleep(1) + experiment_status = shared_data["experiment_status"] + assert experiment_status is not None + snapshots = {} + for event in shared_data["events"]: + if isinstance(event, FullSnapshotEvent): + snapshots[event.iteration] = event + elif isinstance(event, SnapshotUpdateEvent): + snapshots[event.iteration].snapshot.merge_snapshot(event.snapshot) + logging.getLogger("forward_models").info("Status event") + fm_status = "\n Forward model status:\b" + for snapshot in snapshots.values(): + logging.getLogger("forward_models").info( + f"Status event: {jsonable_encoder(snapshot)}" ) - if response.status_code == requests.codes.OK: - json_body = json.loads( - response.text if hasattr(response, "text") else response.body - ) - experiment_status = ExperimentStatus.model_validate_json(json_body) - done = True - else: - time.sleep(1) - - response = requests.get( - "/".join([url, SHARED_DATA_ENDPOINT]), - verify=cert, - auth=auth, - timeout=1, - proxies=PROXY, # type: ignore + fm_status += str(snapshot.snapshot.to_dict()) + status, message = _get_optimization_status( + experiment_status.exit_code, + experiment_status.message, + shared_data[STOP_ENDPOINT], ) - if json_body := json.loads( - response.text if hasattr(response, "text") else response.body - ): - shared_data = json_body - - assert experiment_status is not None - status, message = _get_optimization_status(experiment_status, shared_data) if status != ServerStatus.completed: - update_everserver_status(status_path, status, message) + update_everserver_status(status_path, status, message + fm_status) return except: if shared_data[STOP_ENDPOINT]: @@ -462,58 +474,9 @@ def main(): message=traceback.format_exc(), ) return - update_everserver_status(status_path, ServerStatus.completed, message=message) -def _get_optimization_status( - experiment_status: ExperimentStatus, shared_data: dict -) -> tuple[ServerStatus, str]: - match experiment_status.exit_code: - case EverestExitCode.MAX_BATCH_NUM_REACHED: - return ServerStatus.completed, "Maximum number of batches reached." - - case EverestExitCode.MAX_FUNCTIONS_REACHED: - return ( - ServerStatus.completed, - "Maximum number of function evaluations reached.", - ) - - case EverestExitCode.USER_ABORT: - return ServerStatus.stopped, "Optimization aborted." 
- - case EverestExitCode.EXCEPTION: - assert experiment_status.message is not None - return ServerStatus.failed, experiment_status.message - - case EverestExitCode.TOO_FEW_REALIZATIONS: - status = ( - ServerStatus.stopped - if shared_data[STOP_ENDPOINT] - else ServerStatus.failed - ) - messages = _failed_realizations_messages(shared_data) - for msg in messages: - logging.getLogger(EVEREST).error(msg) - return status, "\n".join(messages) - case _: - return ServerStatus.completed, "Optimization completed." - - -def _failed_realizations_messages(shared_data): - messages = [OPT_FAILURE_REALIZATIONS] - failed = shared_data[SIM_PROGRESS_ENDPOINT]["status"]["failed"] - if failed > 0: - # Report each unique pair of failed job name and error - for queue in shared_data[SIM_PROGRESS_ENDPOINT]["progress"]: - for job in queue: - if job["status"] == JOB_FAILURE: - err_msg = f"{job['name']} Failed with: {job.get('error', '')}" - if err_msg not in messages: - messages.append(err_msg) - return messages - - def _generate_certificate(cert_folder: str): """Generate a private key and a certificate signed with it diff --git a/src/everest/strings.py b/src/everest/strings.py index 593f592fa35..d82f66579e2 100644 --- a/src/everest/strings.py +++ b/src/everest/strings.py @@ -27,7 +27,6 @@ SIMULATOR_START = "start" SIMULATOR_UPDATE = "update" SIMULATOR_END = "end" -SIM_PROGRESS_ENDPOINT = "sim_progress" SIM_PROGRESS_ID = "simulation_progress" STOP_ENDPOINT = "stop" STORAGE_DIR = "simulation_results" diff --git a/tests/ert/__init__.py b/tests/ert/__init__.py index e3db923e86d..5c1a39b9736 100644 --- a/tests/ert/__init__.py +++ b/tests/ert/__init__.py @@ -71,6 +71,7 @@ def add_fm_step( end_time: datetime | None = None, stdout: str | None = None, stderr: str | None = None, + error: str | None = None, ) -> "SnapshotBuilder": self.fm_steps[fm_step_id] = _filter_nones( FMStepSnapshot( @@ -83,6 +84,7 @@ def add_fm_step( stderr=stderr, current_memory_usage=current_memory_usage, max_memory_usage=max_memory_usage, + error=error, ) ) return self diff --git a/tests/everest/entry_points/test_everest_entry.py b/tests/everest/entry_points/test_everest_entry.py index 3b38b0496ab..efb79732853 100644 --- a/tests/everest/entry_points/test_everest_entry.py +++ b/tests/everest/entry_points/test_everest_entry.py @@ -1,69 +1,26 @@ import logging import os from functools import partial -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest +import yaml -from ert.resources import all_shell_script_fm_steps +import everest from everest.bin.everest_script import everest_entry from everest.bin.kill_script import kill_entry from everest.bin.monitor_script import monitor_entry from everest.config import EverestConfig, ServerConfig from everest.detached import ( - SIM_PROGRESS_ENDPOINT, ServerStatus, everserver_status, update_everserver_status, ) -from everest.simulator import JOB_SUCCESS from tests.everest.utils import capture_streams CONFIG_FILE_MINIMAL = "config_minimal.yml" -def query_server_mock(cert, auth, endpoint): - url = "localhost" - sim_endpoint = "/".join([url, SIM_PROGRESS_ENDPOINT]) - - def build_job( - status=JOB_SUCCESS, - start_time="begining", - end_time="end", - name="default_job", - error=None, - ): - return { - "status": status, - "start_time": start_time, - "end_time": end_time, - "name": name, - "error": error, - "realization": 0, - } - - shell_cmd_jobs = [build_job(name=command) for command in all_shell_script_fm_steps] - all_jobs = [ - *shell_cmd_jobs, - build_job(name="make_pancakes"), 
- build_job(name="make_scrambled_eggs"), - ] - if endpoint == sim_endpoint: - return { - "status": { - "failed": 0, - "running": 0, - "complete": 1, - "pending": 0, - "waiting": 0, - }, - "progress": [all_jobs], - "batch_number": "0", - } - else: - raise Exception("Stop! Hands in the air!") - - def run_detached_monitor_mock(status=ServerStatus.completed, error=None, **kwargs): optimization_output = kwargs.get("optimization_output_dir") path = os.path.join(optimization_output, "../detached_node_output/.session/status") @@ -290,145 +247,30 @@ def test_everest_entry_monitor_no_run( everserver_status_mock.assert_called() -@patch("everest.bin.everest_script.server_is_running", return_value=False) -@patch("everest.bin.everest_script.wait_for_server") -@patch("everest.bin.everest_script.start_server") -@patch("everest.detached._query_server", side_effect=query_server_mock) -@patch.object( - ServerConfig, - "get_server_context", - return_value=("localhost", "", ""), -) -@patch("everest.detached.get_opt_status", return_value={}) -@patch( - "everest.bin.everest_script.everserver_status", - return_value={"status": ServerStatus.never_run, "message": None}, -) -@patch("everest.bin.everest_script.start_experiment") -def test_everest_entry_show_all_jobs( - start_experiment_mock, - everserver_status_mock, - get_opt_status_mock, - get_server_context_mock, - query_server_mock, - start_server_mock, - wait_for_server_mock, - server_is_running_mock, - copy_math_func_test_data_to_tmp, -): - """Test running everest with --show-all-jobs""" - - # Test when --show-all-jobs flag is given shell command are in the list - # of forward model jobs - with capture_streams() as (out, _): - everest_entry([CONFIG_FILE_MINIMAL, "--show-all-jobs"]) - for cmd in all_shell_script_fm_steps: - assert cmd in out.getvalue() - - -@patch("everest.bin.everest_script.server_is_running", return_value=False) -@patch("everest.bin.everest_script.wait_for_server") -@patch("everest.bin.everest_script.start_server") -@patch("everest.detached._query_server", side_effect=query_server_mock) -@patch.object( - ServerConfig, - "get_server_context", - return_value=("localhost", "", ""), -) -@patch("everest.detached.get_opt_status", return_value={}) -@patch( - "everest.bin.everest_script.everserver_status", - return_value={"status": ServerStatus.never_run, "message": None}, -) -@patch("everest.bin.everest_script.start_experiment") -def test_everest_entry_no_show_all_jobs( - start_experiment_mock, - everserver_status_mock, - get_opt_status_mock, - get_server_context_mock, - query_server_mock, - start_server_mock, - wait_for_server_mock, - server_is_running_mock, - copy_math_func_test_data_to_tmp, -): - """Test running everest without --show-all-jobs""" - - # Test when --show-all-jobs flag is not given the shell command are not - # in the list of forward model jobs - with capture_streams() as (out, _): - everest_entry([CONFIG_FILE_MINIMAL]) - for cmd in all_shell_script_fm_steps: - assert cmd not in out.getvalue() - - # Check the other jobs are still there - assert "make_pancakes" in out.getvalue() - assert "make_scrambled_eggs" in out.getvalue() +@pytest.fixture(autouse=True) +def mock_ssl(monkeypatch): + monkeypatch.setattr(everest.detached, "ssl", MagicMock()) +@pytest.mark.parametrize("show_all_jobs", [True, False]) @patch("everest.bin.monitor_script.server_is_running", return_value=True) -@patch("everest.detached._query_server", side_effect=query_server_mock) -@patch.object( - ServerConfig, - "get_server_context", - return_value=("localhost", "", 
""), -) -@patch("everest.detached.get_opt_status", return_value={}) -@patch( - "everest.bin.monitor_script.everserver_status", - return_value={"status": ServerStatus.never_run, "message": None}, -) def test_monitor_entry_show_all_jobs( - everserver_status_mock, - get_opt_status_mock, - get_server_context_mock, - query_server_mock, - server_is_running_mock, - copy_math_func_test_data_to_tmp, + _, + monkeypatch, + tmp_path, + min_config, + show_all_jobs, ): """Test running everest with and without --show-all-jobs""" - - # Test when --show-all-jobs flag is given shell command are in the list - # of forward model jobs - - with capture_streams() as (out, _): - monitor_entry([CONFIG_FILE_MINIMAL, "--show-all-jobs"]) - for cmd in all_shell_script_fm_steps: - assert cmd in out.getvalue() - - -@patch("everest.bin.monitor_script.server_is_running", return_value=True) -@patch("everest.detached._query_server", side_effect=query_server_mock) -@patch.object( - ServerConfig, - "get_server_context", - return_value=("localhost", "", ""), -) -@patch("everest.detached.get_opt_status", return_value={}) -@patch( - "everest.bin.monitor_script.everserver_status", - return_value={"status": ServerStatus.never_run, "message": None}, -) -def test_monitor_entry_no_show_all_jobs( - everserver_status_mock, - get_opt_status_mock, - get_server_context_mock, - query_server_mock, - server_is_running_mock, - copy_math_func_test_data_to_tmp, -): - """Test running everest without --show-all-jobs""" - - # Test when --show-all-jobs flag is not given the shell command are not - # in the list of forward model jobs - with capture_streams() as (out, _): - monitor_entry([CONFIG_FILE_MINIMAL]) - for cmd in all_shell_script_fm_steps: - assert cmd not in out.getvalue() - - # Check the other jobs are still there - assert "make_pancakes" in out.getvalue() - assert "make_scrambled_eggs" in out.getvalue() + monkeypatch.chdir(tmp_path) + with open("config.yml", "w", encoding="utf-8") as fout: + yaml.dump(min_config, fout) + detatched_mock = MagicMock() + monkeypatch.setattr(everest.bin.utils, "start_monitor", MagicMock()) + monkeypatch.setattr(everest.bin.utils, "_DetachedMonitor", detatched_mock) + args = ["config.yml"] if not show_all_jobs else ["config.yml", "--show-all-jobs"] + monitor_entry(args) + detatched_mock.assert_called_once_with(show_all_jobs) @patch( diff --git a/tests/everest/test_everserver.py b/tests/everest/test_everserver.py index 565bfb7d0bf..ccddc21d713 100644 --- a/tests/everest/test_everserver.py +++ b/tests/everest/test_everserver.py @@ -2,13 +2,12 @@ import os import ssl from pathlib import Path -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest -import requests -from fastapi.encoders import jsonable_encoder -from fastapi.responses import JSONResponse +import everest +from ert.ensemble_evaluator import EndEvent from ert.run_models.everest_run_model import EverestExitCode from ert.scheduler.event import FinishedEvent from everest.config import EverestConfig, ServerConfig @@ -20,12 +19,10 @@ wait_for_server, ) from everest.detached.jobs import everserver +from everest.detached.jobs.everserver import ExperimentStatus from everest.everest_storage import EverestStorage -from everest.simulator import JOB_FAILURE, JOB_SUCCESS from everest.strings import ( OPT_FAILURE_REALIZATIONS, - SIM_PROGRESS_ENDPOINT, - STOP_ENDPOINT, ) @@ -66,7 +63,6 @@ def fail_optimization(self, from_ropt=False): # call the provided simulation callback, which has access to the shared_data # variable in the 
eversever main function. Patch that callback to modify # shared_data (see set_shared_status() below). - self._sim_callback(None) if from_ropt: self._exit_code = EverestExitCode.TOO_FEW_REALIZATIONS return EverestExitCode.TOO_FEW_REALIZATIONS @@ -74,17 +70,22 @@ def fail_optimization(self, from_ropt=False): raise Exception("Failed optimization") -def set_shared_status(*args, progress, shared_data): - # Patch _sim_monitor with this to access the shared_data variable in the - # everserver main function. - failed = len( - [job for queue in progress for job in queue if job["status"] == JOB_FAILURE] - ) +@pytest.fixture +def mock_server(monkeypatch): + def func(exit_code: int, message: str = ""): + def server_mock(shared_data, server_config): + shared_data["experiment_status"] = ExperimentStatus( + exit_code=exit_code, message=message + ) - shared_data[SIM_PROGRESS_ENDPOINT] = { - "status": {"failed": failed}, - "progress": progress, - } + monkeypatch.setattr( + everest.detached.jobs.everserver, "_everserver_thread", server_mock + ) + monkeypatch.setattr( + everest.detached.jobs.everserver, "wait_for_server", MagicMock() + ) + + yield func @pytest.mark.integration_test @@ -137,37 +138,25 @@ def test_configure_logger_failure(mocked_logger, copy_math_func_test_data_to_tmp assert "Exception: Configuring logger failed" in status["message"] +@pytest.fixture +def mock_end_event(monkeypatch): + queue_mock = MagicMock() + queue_mock.get.return_value = EndEvent(failed=False, msg="Experiment complete") + monkeypatch.setattr( + everest.detached.jobs.everserver.mp, "Queue", MagicMock(return_value=queue_mock) + ) + yield queue_mock + + @patch("sys.argv", ["name", "--config-file", "config_minimal.yml"]) @patch("everest.detached.jobs.everserver._configure_loggers") -@patch("requests.get") def test_status_running_complete( - mocked_get, mocked_logger, copy_math_func_test_data_to_tmp + mocked_logger, copy_math_func_test_data_to_tmp, mock_server ): + mock_server(1) config_file = "config_minimal.yml" config = EverestConfig.load_file(config_file) - def mocked_server(url, verify, auth, timeout, proxies): - if "/experiment_status" in url: - return JSONResponse( - everserver.ExperimentStatus( - exit_code=EverestExitCode.COMPLETED - ).model_dump_json() - ) - if "/shared_data" in url: - return JSONResponse( - jsonable_encoder( - { - SIM_PROGRESS_ENDPOINT: {}, - STOP_ENDPOINT: False, - } - ) - ) - resp = requests.Response() - resp.status_code = 200 - return resp - - mocked_get.side_effect = mocked_server - everserver.main() status = everserver_status( @@ -180,61 +169,11 @@ def mocked_server(url, verify, auth, timeout, proxies): @patch("sys.argv", ["name", "--config-file", "config_minimal.yml"]) @patch("everest.detached.jobs.everserver._configure_loggers") -@patch("requests.get") -def test_status_failed_job(mocked_get, mocked_logger, copy_math_func_test_data_to_tmp): +def test_status_failed_job(mocked_logger, copy_math_func_test_data_to_tmp, mock_server): config_file = "config_minimal.yml" config = EverestConfig.load_file(config_file) - def mocked_server(url, verify, auth, timeout, proxies): - if "/experiment_status" in url: - return JSONResponse( - everserver.ExperimentStatus( - exit_code=EverestExitCode.TOO_FEW_REALIZATIONS - ).model_dump_json() - ) - if "/shared_data" in url: - return JSONResponse( - jsonable_encoder( - { - SIM_PROGRESS_ENDPOINT: { - "status": {"failed": 3}, - "progress": [ - [ - { - "name": "job1", - "status": JOB_FAILURE, - "error": "job 1 error 1", - }, - { - "name": "job1", - "status": JOB_FAILURE, 
- "error": "job 1 error 2", - }, - ], - [ - { - "name": "job2", - "status": JOB_SUCCESS, - "error": "", - }, - { - "name": "job2", - "status": JOB_FAILURE, - "error": "job 2 error 1", - }, - ], - ], - }, - STOP_ENDPOINT: False, - } - ) - ) - resp = requests.Response() - resp.status_code = 200 - return resp - - mocked_get.side_effect = mocked_server - + mock_server(2) everserver.main() status = everserver_status( @@ -244,43 +183,13 @@ def mocked_server(url, verify, auth, timeout, proxies): # The server should fail and store a user-friendly message. assert status["status"] == ServerStatus.failed assert OPT_FAILURE_REALIZATIONS in status["message"] - assert "job1 Failed with: job 1 error 1" in status["message"] - assert "job1 Failed with: job 1 error 2" in status["message"] - assert "job2 Failed with: job 2 error 1" in status["message"] @patch("sys.argv", ["name", "--config-file", "config_minimal.yml"]) -@patch("everest.detached.jobs.everserver._configure_loggers") -@patch("requests.get") -def test_status_exception(mocked_get, mocked_logger, copy_math_func_test_data_to_tmp): +def test_everserver_status_exception(copy_math_func_test_data_to_tmp, mock_server): config_file = "config_minimal.yml" config = EverestConfig.load_file(config_file) - - def mocked_server(url, verify, auth, timeout, proxies): - if "/experiment_status" in url: - return JSONResponse( - everserver.ExperimentStatus( - exit_code=EverestExitCode.EXCEPTION, message="Some message" - ).model_dump_json() - ) - if "/shared_data" in url: - return JSONResponse( - jsonable_encoder( - { - SIM_PROGRESS_ENDPOINT: { - "status": {}, - "progress": [], - }, - STOP_ENDPOINT: False, - } - ) - ) - resp = requests.Response() - resp.status_code = 200 - return resp - - mocked_get.side_effect = mocked_server - + mock_server(6, "Some message") everserver.main() status = everserver_status( ServerConfig.get_everserver_status_path(config.output_dir) @@ -328,7 +237,7 @@ async def test_status_contains_max_runtime_failure( config_file = "config_minimal.yml" Path("SLEEP_job").write_text("EXECUTABLE sleep", encoding="utf-8") - min_config["simulator"] = {"max_runtime": 2} + min_config["simulator"] = {"max_runtime": 1} min_config["forward_model"] = ["sleep 5"] min_config["install_jobs"] = [{"name": "sleep", "source": "SLEEP_job"}] @@ -344,8 +253,4 @@ async def test_status_contains_max_runtime_failure( ) assert status["status"] == ServerStatus.failed - print(status["message"]) - assert ( - "sleep Failed with: The run is cancelled due to reaching MAX_RUNTIME" - in status["message"] - ) + assert "The run is cancelled due to reaching MAX_RUNTIME" in status["message"] diff --git a/tests/everest/test_logging.py b/tests/everest/test_logging.py index 0ca1afe5aea..17976b74686 100644 --- a/tests/everest/test_logging.py +++ b/tests/everest/test_logging.py @@ -3,14 +3,12 @@ import pytest -from ert.scheduler.event import FinishedEvent +from everest.bin.main import start_everest from everest.config import ( EverestConfig, ServerConfig, ) from everest.config.install_job_config import InstallJobConfig -from everest.detached import start_experiment, start_server, wait_for_server -from everest.util import makedirs_if_needed def _string_exists_in_file(file_path, string): @@ -20,13 +18,7 @@ def _string_exists_in_file(file_path, string): @pytest.mark.timeout(240) # Simulation might not finish @pytest.mark.integration_test @pytest.mark.xdist_group(name="starts_everest") -async def test_logging_setup(copy_math_func_test_data_to_tmp): - async def server_running(): - while True: - 
event = await driver.event_queue.get() - if isinstance(event, FinishedEvent) and event.iens == 0: - return - +def test_logging_setup(copy_math_func_test_data_to_tmp): everest_config = EverestConfig.load_file("config_minimal.yml") everest_config.forward_model.append("toggle_failure --fail simulation_2") everest_config.install_jobs.append( @@ -39,19 +31,7 @@ async def server_running(): # start_server() loads config based on config_path, so we need to actually overwrite it everest_config.dump("config_minimal.yml") - - makedirs_if_needed(everest_config.output_dir, roll_if_exists=True) - driver = await start_server(everest_config, debug=True) - try: - wait_for_server(everest_config.output_dir, 120) - - start_experiment( - server_context=ServerConfig.get_server_context(everest_config.output_dir), - config=everest_config, - ) - except (SystemExit, RuntimeError) as e: - raise e - await server_running() + start_everest(["everest", "run", "config_minimal.yml"]) everest_output_path = os.path.join(os.getcwd(), "everest_output") everest_logs_dir_path = everest_config.log_dir @@ -68,10 +48,7 @@ async def server_running(): assert _string_exists_in_file(everest_log_path, "everest DEBUG:") assert _string_exists_in_file( - forward_model_log_path, "Exception: Failing simulation_2 by request!" - ) - assert _string_exists_in_file( - forward_model_log_path, "Exception: Failing simulation_2 by request!" + forward_model_log_path, "Process exited with status code 1" ) endpoint_logs = Path(endpoint_log_path).read_text(encoding="utf-8") diff --git a/tests/everest/test_monitor.py b/tests/everest/test_monitor.py new file mode 100644 index 00000000000..e6c078136ce --- /dev/null +++ b/tests/everest/test_monitor.py @@ -0,0 +1,194 @@ +import json +import string +from collections import defaultdict +from datetime import datetime +from functools import partial +from unittest.mock import MagicMock, patch + +import pytest +from websockets.sync.client import ClientConnection + +import everest +from ert.ensemble_evaluator import ( + EndEvent, + FullSnapshotEvent, + SnapshotUpdateEvent, + state, +) +from ert.ensemble_evaluator.snapshot import EnsembleSnapshotMetadata +from ert.resources import all_shell_script_fm_steps +from everest.bin.utils import run_detached_monitor +from tests.ert import SnapshotBuilder + +METADATA = EnsembleSnapshotMetadata( + aggr_fm_step_status_colors=defaultdict(dict), + real_status_colors={}, + sorted_real_ids=[], + sorted_fm_step_ids=defaultdict(list), +) + + +from fastapi.encoders import jsonable_encoder + + +@pytest.fixture +def full_snapshot_event(): + snapshot = SnapshotBuilder(metadata=METADATA) + snapshot.add_fm_step( + fm_step_id="0", + index="0", + name="fm_step_0", + status=state.FORWARD_MODEL_STATE_START, + current_memory_usage="500", + max_memory_usage="1000", + stdout="job_fm_step_0.stdout", + stderr="job_fm_step_0.stderr", + start_time=datetime(1999, 1, 1), + ) + for i, command in enumerate(all_shell_script_fm_steps): + snapshot.add_fm_step( + fm_step_id=str(i + 1), + index=str(i + 1), + name=command, + status=state.FORWARD_MODEL_STATE_START, + current_memory_usage="500", + max_memory_usage="1000", + stdout=None, + stderr=None, + start_time=datetime(1999, 1, 1), + ) + event = FullSnapshotEvent( + snapshot=snapshot.build( + real_ids=["0", "1"], + status=state.REALIZATION_STATE_PENDING, + start_time=datetime(1999, 1, 1), + exec_hosts="12121.121", + message="Some message", + ), + iteration_label="Foo", + total_iterations=1, + progress=0.25, + realization_count=4, + status_count={ + 
"Finished": 0, + "Pending": len(all_shell_script_fm_steps), + "Unknown": 0, + }, + iteration=0, + ) + yield json.dumps(jsonable_encoder(event)) + + +@pytest.fixture +def snapshot_update_event(): + event = SnapshotUpdateEvent( + snapshot=SnapshotBuilder(metadata=METADATA) + .add_fm_step( + fm_step_id="0", + name=None, + index="0", + status=state.FORWARD_MODEL_STATE_FINISHED, + end_time=datetime(2019, 1, 1), + ) + .build( + real_ids=["1"], + status=state.REALIZATION_STATE_FINISHED, + ), + iteration_label="Foo", + total_iterations=1, + progress=0.5, + realization_count=4, + status_count={"Finished": 1, "Running": 0, "Unknown": 0}, + iteration=0, + ) + yield json.dumps(jsonable_encoder(event)) + + +@pytest.fixture +def snapshot_update_failure_event(): + event = SnapshotUpdateEvent( + snapshot=SnapshotBuilder(metadata=METADATA) + .add_fm_step( + fm_step_id="0", + name=None, + index="0", + status=state.FORWARD_MODEL_STATE_FAILURE, + end_time=datetime(2019, 1, 1), + error="The run is cancelled due to reaching MAX_RUNTIME", + ) + .build( + real_ids=["1"], + status=state.REALIZATION_STATE_FAILED, + ), + iteration_label="Foo", + total_iterations=1, + progress=0.5, + realization_count=4, + status_count={"Finished": 0, "Running": 0, "Unknown": 0, "Failed": 1}, + iteration=0, + ) + yield json.dumps(jsonable_encoder(event)) + + +def test_failed_jobs_monitor( + monkeypatch, full_snapshot_event, snapshot_update_failure_event, capsys +): + server_mock = MagicMock() + connection_mock = MagicMock(spec=ClientConnection) + connection_mock.recv.side_effect = [ + full_snapshot_event, + snapshot_update_failure_event, + json.dumps(jsonable_encoder(EndEvent(failed=True, msg="Failed"))), + ] + server_mock.return_value.__enter__.return_value = connection_mock + monkeypatch.setattr(everest.detached, "_query_server", MagicMock(return_value={})) + monkeypatch.setattr(everest.detached, "connect", server_mock) + monkeypatch.setattr(everest.detached, "ssl", MagicMock()) + patched = partial(everest.detached.start_monitor, polling_interval=0.1) + with patch("everest.bin.utils.start_monitor", patched): + run_detached_monitor(("some/url", None, None), "output", False) + captured = capsys.readouterr() + expected = [ + "===================== Running forward models (Batch #0) ======================\n", + " Waiting: 0 | Pending: 0 | Running: 0 | Finished: 0 | Failed: 1\n", + " fm_step_0: 1/0/1 | Failed: 1" + " fm_step_0: Failed: The run is cancelled due to reaching MAX_RUNTIME, realizations: 1\n", + "Failed\n", + ] + # Ignore whitespace + assert captured.out.translate({ord(c): None for c in string.whitespace}) == "".join( + expected + ).translate({ord(c): None for c in string.whitespace}) + + +@pytest.mark.parametrize("show_all_jobs", [True, False]) +def test_monitor( + monkeypatch, full_snapshot_event, snapshot_update_event, capsys, show_all_jobs +): + server_mock = MagicMock() + connection_mock = MagicMock(spec=ClientConnection) + connection_mock.recv.side_effect = [ + full_snapshot_event, + snapshot_update_event, + json.dumps(jsonable_encoder(EndEvent(failed=False, msg="Experiment complete"))), + ] + server_mock.return_value.__enter__.return_value = connection_mock + monkeypatch.setattr(everest.detached, "_query_server", MagicMock(return_value={})) + monkeypatch.setattr(everest.detached, "connect", server_mock) + monkeypatch.setattr(everest.detached, "ssl", MagicMock()) + patched = partial(everest.detached.start_monitor, polling_interval=0.1) + with patch("everest.bin.utils.start_monitor", patched): + 
run_detached_monitor(("some/url", None, None), "output", show_all_jobs) + captured = capsys.readouterr() + expected = [ + "===================== Running forward models (Batch #0) ======================\n", + " Waiting: 0 | Pending: 0 | Running: 0 | Finished: 1 | Failed: 0\n", + " fm_step_0: 1/1/0 | Finished: 1\n", + "Experiment complete\n", + ] + if show_all_jobs: + expected[-1:-1] = [f"{name}: 2/0/0" for name in all_shell_script_fm_steps] + # Ignore whitespace + assert captured.out.translate({ord(c): None for c in string.whitespace}) == "".join( + expected + ).translate({ord(c): None for c in string.whitespace})
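
For reference, the fan-out mechanism behind the new /events websocket endpoint in everserver.py reduces to the pattern sketched below: events are appended to one shared list, each Subscriber keeps its own read index, and an asyncio.Event wakes it when something new arrives. This is a self-contained sketch, not code from the patch; plain strings stand in for the StatusEvents payloads and the producer/consumer pair is illustrative only.

import asyncio


class Subscriber:
    def __init__(self) -> None:
        self.index = 0  # next position in the shared event list this subscriber reads
        self._event = asyncio.Event()

    def notify(self) -> None:
        self._event.set()

    async def wait_for_event(self) -> None:
        await self._event.wait()
        self._event.clear()


async def main() -> None:
    events: list[str] = []
    subscribers: dict[str, Subscriber] = {"cli": Subscriber()}

    async def consumer(name: str) -> None:
        sub = subscribers[name]
        while True:
            while sub.index >= len(events):  # nothing new yet, park on the Event
                await sub.wait_for_event()
            item = events[sub.index]
            sub.index += 1
            print(f"{name} got: {item}")
            if item == "end":  # stand-in for EndEvent
                return

    async def producer() -> None:
        for item in ["snapshot", "update", "end"]:
            events.append(item)
            for sub in subscribers.values():
                sub.notify()
            await asyncio.sleep(0.01)

    await asyncio.gather(consumer("cli"), producer())


if __name__ == "__main__":
    asyncio.run(main())
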