Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update communication between the everest server job and the experiment server #10051

Merged
merged 4 commits (source and target branch names not captured in this page extract)
Feb 13, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/ert/run_models/everest_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ class EverestExitCode(IntEnum):
MAX_FUNCTIONS_REACHED = 3
MAX_BATCH_NUM_REACHED = 4
USER_ABORT = 5
EXCEPTION = 6


class EverestRunModel(BaseRunModel):
Expand Down
2 changes: 1 addition & 1 deletion src/everest/detached/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def start_experiment(
except:
logging.debug(traceback.format_exc())
time.sleep(retry)
raise ValueError("Failed to start experiment")
raise RuntimeError("Failed to start experiment")


def extract_errors_from_file(path: str):
Expand Down
241 changes: 97 additions & 144 deletions src/everest/detached/jobs/everserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
import socket
import ssl
import threading
import time
import traceback
from base64 import b64encode
from contextlib import asynccontextmanager
from functools import partial
from pathlib import Path
from queue import Empty, SimpleQueue
from typing import Any

import requests
import uvicorn
from cryptography import x509
from cryptography.hazmat.backends import default_backend
Expand All @@ -40,74 +40,87 @@
from ert.run_models.everest_run_model import EverestExitCode, EverestRunModel
from everest.config import EverestConfig, ServerConfig
from everest.detached import (
PROXY,
ServerStatus,
get_opt_status,
update_everserver_status,
wait_for_server,
)
from everest.plugins.everest_plugin_manager import EverestPluginManager
from everest.simulator import JOB_FAILURE
from everest.strings import (
DEFAULT_LOGGING_FORMAT,
EVEREST,
EXPERIMENT_STATUS_ENDPOINT,
OPT_FAILURE_REALIZATIONS,
OPT_PROGRESS_ENDPOINT,
OPTIMIZATION_LOG_DIR,
OPTIMIZATION_OUTPUT_DIR,
SHARED_DATA_ENDPOINT,
SIM_PROGRESS_ENDPOINT,
START_EXPERIMENT_ENDPOINT,
STOP_ENDPOINT,
)
from everest.util import makedirs_if_needed, version_info


class ExperimentStatus(BaseModel):
class EverestServerMsg(BaseModel):
msg: str | None = None


class ServerStarted(EverestServerMsg):
pass


class ServerStopped(EverestServerMsg):
pass


class ExperimentComplete(EverestServerMsg):
exit_code: EverestExitCode
message: str | None = None
data: dict[str, Any]


class ExperimentFailed(EverestServerMsg):
pass


class ExperimentRunner(threading.Thread):
def __init__(self, everest_config, shared_data: dict):
def __init__(
self,
everest_config,
shared_data: dict,
msg_queue: SimpleQueue[EverestServerMsg],
):
super().__init__()

self._everest_config = everest_config
self._shared_data = shared_data
self._status: ExperimentStatus | None = None
self._msg_queue = msg_queue

def run(self):
run_model = EverestRunModel.create(
self._everest_config,
simulation_callback=partial(_sim_monitor, shared_data=self._shared_data),
optimization_callback=partial(_opt_monitor, shared_data=self._shared_data),
)

if run_model._queue_config.queue_system == QueueSystem.LOCAL:
evaluator_server_config = EvaluatorServerConfig()
else:
evaluator_server_config = EvaluatorServerConfig(
port_range=(49152, 51819), use_ipc_protocol=False
try:
run_model = EverestRunModel.create(
self._everest_config,
simulation_callback=partial(
_sim_monitor, shared_data=self._shared_data
),
optimization_callback=partial(
_opt_monitor, shared_data=self._shared_data
),
)
if run_model._queue_config.queue_system == QueueSystem.LOCAL:
evaluator_server_config = EvaluatorServerConfig()
else:
evaluator_server_config = EvaluatorServerConfig(
port_range=(49152, 51819), use_ipc_protocol=False
)

try:
run_model.run_experiment(evaluator_server_config)

assert run_model.exit_code is not None
self._status = ExperimentStatus(exit_code=run_model.exit_code)
except Exception as e:
self._status = ExperimentStatus(
exit_code=EverestExitCode.EXCEPTION, message=str(e)
self._msg_queue.put(
ExperimentComplete(
exit_code=run_model.exit_code, data=self._shared_data
)
)

@property
def status(self) -> ExperimentStatus | None:
return self._status

@property
def shared_data(self) -> dict:
return self._shared_data
except Exception as e:
self._msg_queue.put(ExperimentFailed(msg=str(e)))


def _get_machine_name() -> str:
Expand Down Expand Up @@ -140,15 +153,15 @@ def _get_machine_name() -> str:
def _sim_monitor(context_status, shared_data=None):
assert shared_data is not None

status = context_status["status"]
status_ = context_status["status"]
shared_data[SIM_PROGRESS_ENDPOINT] = {
"batch_number": context_status["batch_number"],
"status": {
"running": status.get("Running", 0),
"waiting": status.get("Waiting", 0),
"pending": status.get("Pending", 0),
"complete": status.get("Finished", 0),
"failed": status.get("Failed", 0),
"running": status_.get("Running", 0),
"waiting": status_.get("Waiting", 0),
"pending": status_.get("Pending", 0),
"complete": status_.get("Finished", 0),
"failed": status_.get("Failed", 0),
},
"progress": context_status["progress"],
}
Expand All @@ -163,8 +176,17 @@ def _opt_monitor(shared_data=None):
return "stop_optimization"


def _everserver_thread(shared_data, server_config) -> None:
app = FastAPI()
def _everserver_thread(shared_data, server_config, msg_queue) -> None:
# ruff: noqa: RUF029
@asynccontextmanager
async def lifespan(app: FastAPI):
# Startup event
msg_queue.put(ServerStarted())
yield
# Shutdown event
msg_queue.put(ServerStopped())

app = FastAPI(lifespan=lifespan)
security = HTTPBasic()

runner: ExperimentRunner | None = None
Expand Down Expand Up @@ -197,6 +219,7 @@ def stop(
_log(request)
_check_user(credentials)
shared_data[STOP_ENDPOINT] = True
msg_queue.put(ServerStopped())
return Response("Raise STOP flag succeeded. Everest initiates shutdown..", 200)

@app.get("/" + SIM_PROGRESS_ENDPOINT)
Expand Down Expand Up @@ -228,41 +251,14 @@ def start_experiment(

nonlocal runner
if runner is None:
runner = ExperimentRunner(config, shared_data)
runner = ExperimentRunner(config, shared_data, msg_queue)
try:
runner.start()
return Response("Everest experiment started")
except Exception as e:
return Response(f"Could not start experiment: {e!s}", status_code=501)
return Response("Everest experiment is running")

@app.get("/" + EXPERIMENT_STATUS_ENDPOINT)
def get_experiment_status(
request: Request, credentials: HTTPBasicCredentials = Depends(security)
) -> Response:
_log(request)
_check_user(credentials)
if shared_data[STOP_ENDPOINT]:
return JSONResponse(
ExperimentStatus(exit_code=EverestExitCode.USER_ABORT).model_dump_json()
)
if runner is None:
return Response(None, 204)
status = runner.status
if status is None:
return Response(None, 204)
return JSONResponse(status.model_dump_json())

@app.get("/" + SHARED_DATA_ENDPOINT)
def get_shared_data(
request: Request, credentials: HTTPBasicCredentials = Depends(security)
) -> JSONResponse:
_log(request)
_check_user(credentials)
if runner is None:
return JSONResponse(jsonable_encoder(shared_data))
return JSONResponse(jsonable_encoder(runner.shared_data))

uvicorn.run(
app,
host="0.0.0.0",
Expand Down Expand Up @@ -364,6 +360,7 @@ def main():

status_path = ServerConfig.get_everserver_status_path(output_dir)
host_file = ServerConfig.get_hostfile_path(output_dir)
msg_queue: SimpleQueue[EverestServerMsg] = SimpleQueue()

try:
_configure_loggers(
Expand Down Expand Up @@ -397,89 +394,49 @@ def main():
"key_passwd": key_pw,
"authentication": authentication,
}

# Starting the server
everserver_instance = threading.Thread(
target=_everserver_thread,
args=(shared_data, server_config),
args=(shared_data, server_config, msg_queue),
)
everserver_instance.daemon = True
everserver_instance.start()

# Monitoring the server
while True:
try:
item = msg_queue.get(timeout=1) # Wait for data
match item:
case ServerStarted():
update_everserver_status(status_path, ServerStatus.running)
case ServerStopped():
update_everserver_status(status_path, ServerStatus.stopped)
return
case ExperimentFailed():
update_everserver_status(
status_path, ServerStatus.failed, item.msg
)
return
case ExperimentComplete():
status, message = _get_optimization_status(
item.exit_code, item.data
)
update_everserver_status(status_path, status, message)
return
except Empty:
continue
except:
update_everserver_status(
status_path,
ServerStatus.failed,
message=traceback.format_exc(),
)
return

try:
wait_for_server(output_dir, 60)

update_everserver_status(status_path, ServerStatus.running)

server_context = (ServerConfig.get_server_context(output_dir),)
url, cert, auth = server_context[0]

done = False
experiment_status: ExperimentStatus | None = None
# loop until the optimization is done
while not done:
response = requests.get(
"/".join([url, EXPERIMENT_STATUS_ENDPOINT]),
verify=cert,
auth=auth,
timeout=1,
proxies=PROXY, # type: ignore
)
if response.status_code == requests.codes.OK:
json_body = json.loads(
response.text if hasattr(response, "text") else response.body
)
experiment_status = ExperimentStatus.model_validate_json(json_body)
done = True
else:
time.sleep(1)

response = requests.get(
"/".join([url, SHARED_DATA_ENDPOINT]),
verify=cert,
auth=auth,
timeout=1,
proxies=PROXY, # type: ignore
)
if json_body := json.loads(
response.text if hasattr(response, "text") else response.body
):
shared_data = json_body

assert experiment_status is not None
status, message = _get_optimization_status(experiment_status, shared_data)
if status != ServerStatus.completed:
update_everserver_status(status_path, status, message)
return
except:
if shared_data[STOP_ENDPOINT]:
update_everserver_status(
status_path,
ServerStatus.stopped,
message="Optimization aborted.",
)
else:
update_everserver_status(
status_path,
ServerStatus.failed,
message=traceback.format_exc(),
)
return

update_everserver_status(status_path, ServerStatus.completed, message=message)


def _get_optimization_status(
experiment_status: ExperimentStatus, shared_data: dict
exit_code: EverestExitCode, shared_data: dict
) -> tuple[ServerStatus, str]:
match experiment_status.exit_code:
match exit_code:
case EverestExitCode.MAX_BATCH_NUM_REACHED:
return ServerStatus.completed, "Maximum number of batches reached."

Expand All @@ -492,20 +449,16 @@ def _get_optimization_status(
case EverestExitCode.USER_ABORT:
return ServerStatus.stopped, "Optimization aborted."

case EverestExitCode.EXCEPTION:
assert experiment_status.message is not None
return ServerStatus.failed, experiment_status.message

case EverestExitCode.TOO_FEW_REALIZATIONS:
status = (
status_ = (
ServerStatus.stopped
if shared_data[STOP_ENDPOINT]
else ServerStatus.failed
)
messages = _failed_realizations_messages(shared_data)
for msg in messages:
logging.getLogger(EVEREST).error(msg)
return status, "\n".join(messages)
return status_, "\n".join(messages)
case _:
return ServerStatus.completed, "Optimization completed."

Expand Down
Loading