Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
frode-aarstad committed Jan 27, 2025
1 parent 8334307 commit 5c92cb5
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 30 deletions.
59 changes: 40 additions & 19 deletions src/everest/detached/jobs/everserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
HTTPBasic,
HTTPBasicCredentials,
)
from pydantic import BaseModel

from ert.config.parsing.queue_system import QueueSystem
from ert.ensemble_evaluator import EvaluatorServerConfig
Expand Down Expand Up @@ -62,13 +63,18 @@
from everest.util import makedirs_if_needed, version_info


class ExperimentStatus(BaseModel):
exit_code: EverestExitCode
message: str | None = None


class ExperimentRunner(threading.Thread):
def __init__(self, everest_config, shared_data: dict):
super().__init__()

self._everest_config = everest_config
self._shared_data = shared_data
self._exit_code: EverestExitCode | None = None
self._status: ExperimentStatus | None = None

def run(self):
run_model = EverestRunModel.create(
Expand All @@ -86,13 +92,17 @@ def run(self):

try:
run_model.run_experiment(evaluator_server_config)
self._exit_code = run_model.exit_code
except Exception:
self._exit_code = EverestExitCode.EXCEPTION

assert run_model.exit_code is not None
self._status = ExperimentStatus(exit_code=run_model.exit_code)
except Exception as e:
self._status = ExperimentStatus(
exit_code=EverestExitCode.EXCEPTION, message=str(e)
)

@property
def exit_code(self) -> EverestExitCode | None:
return self._exit_code
def status(self) -> ExperimentStatus | None:
return self._status

@property
def shared_data(self) -> dict:
Expand Down Expand Up @@ -216,10 +226,14 @@ def start_experiment(
_check_user(credentials)

nonlocal runner
runner = ExperimentRunner(config, shared_data)
runner.start()

return Response("Everest experiment started", 200)
if runner is None:
runner = ExperimentRunner(config, shared_data)
try:
runner.start()
return Response("Everest experiment started")
except Exception as e:
return Response(f"Could not start experiment: {e!s}", status_code=501)
return Response("Everest experiment is running")

@app.get("/" + EXPERIMENT_STATUS_ENDPOINT)
def get_experiment_status(
Expand All @@ -228,13 +242,15 @@ def get_experiment_status(
_log(request)
_check_user(credentials)
if shared_data[STOP_ENDPOINT]:
return Response(f"{EverestExitCode.USER_ABORT}", 200)
return JSONResponse(
ExperimentStatus(exit_code=EverestExitCode.USER_ABORT).model_dump_json()
)
if runner is None:
return Response(None, 204)
status = runner.exit_code
status = runner.status
if status is None:
return Response(None, 204)
return Response(f"{status}", 200)
return JSONResponse(status.model_dump_json())

@app.get("/" + SHARED_DATA_ENDPOINT)
def get_shared_data(
Expand Down Expand Up @@ -398,7 +414,7 @@ def main():
url, cert, auth = server_context[0]

done = False
exit_code = None
experiment_status: ExperimentStatus | None = None
# loop until the optimization is done
while not done:
response = requests.get(
Expand All @@ -409,9 +425,10 @@ def main():
proxies=PROXY, # type: ignore
)
if response.status_code == requests.codes.OK:
exit_code = int(
json_body = json.loads(
response.text if hasattr(response, "text") else response.body
)
experiment_status = ExperimentStatus.model_validate_json(json_body)
done = True
else:
time.sleep(1)
Expand All @@ -428,7 +445,8 @@ def main():
):
shared_data = json_body

status, message = _get_optimization_status(exit_code, shared_data)
assert experiment_status is not None
status, message = _get_optimization_status(experiment_status, shared_data)
if status != ServerStatus.completed:
update_everserver_status(status_path, status, message)
return
Expand Down Expand Up @@ -478,8 +496,10 @@ def main():
update_everserver_status(status_path, ServerStatus.completed, message=message)


def _get_optimization_status(exit_code, shared_data):
match exit_code:
def _get_optimization_status(
experiment_status: ExperimentStatus, shared_data: dict
) -> tuple[ServerStatus, str]:
match experiment_status.exit_code:
case EverestExitCode.MAX_BATCH_NUM_REACHED:
return ServerStatus.completed, "Maximum number of batches reached."

Expand All @@ -493,7 +513,8 @@ def _get_optimization_status(exit_code, shared_data):
return ServerStatus.stopped, "Optimization aborted."

case EverestExitCode.EXCEPTION:
return ServerStatus.failed, "Optimization failed with exception."
assert experiment_status.message is not None
return ServerStatus.failed, experiment_status.message

case EverestExitCode.TOO_FEW_REALIZATIONS:
status = (
Expand Down
28 changes: 17 additions & 11 deletions tests/everest/test_everserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest
import requests
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse, Response
from fastapi.responses import JSONResponse
from seba_sqlite.snapshot import SebaSnapshot

from ert.run_models.everest_run_model import EverestExitCode
Expand Down Expand Up @@ -149,7 +149,11 @@ def test_status_running_complete(

def mocked_server(url, verify, auth, timeout, proxies):
if "/experiment_status" in url:
return Response(f"{EverestExitCode.COMPLETED}", 200)
return JSONResponse(
everserver.ExperimentStatus(
exit_code=EverestExitCode.COMPLETED
).model_dump_json()
)
if "/shared_data" in url:
return JSONResponse(
jsonable_encoder(
Expand All @@ -159,7 +163,6 @@ def mocked_server(url, verify, auth, timeout, proxies):
}
)
)

resp = requests.Response()
resp.status_code = 200
return resp
Expand All @@ -185,8 +188,11 @@ def test_status_failed_job(mocked_get, mocked_logger, copy_math_func_test_data_t

def mocked_server(url, verify, auth, timeout, proxies):
if "/experiment_status" in url:
return Response(f"{EverestExitCode.TOO_FEW_REALIZATIONS}", 200)

return JSONResponse(
everserver.ExperimentStatus(
exit_code=EverestExitCode.TOO_FEW_REALIZATIONS
).model_dump_json()
)
if "/shared_data" in url:
return JSONResponse(
jsonable_encoder(
Expand Down Expand Up @@ -253,8 +259,11 @@ def test_status_exception(mocked_get, mocked_logger, copy_math_func_test_data_to

def mocked_server(url, verify, auth, timeout, proxies):
if "/experiment_status" in url:
return Response(f"{EverestExitCode.EXCEPTION}", 200)

return JSONResponse(
everserver.ExperimentStatus(
exit_code=EverestExitCode.EXCEPTION, message="Some message"
).model_dump_json()
)
if "/shared_data" in url:
return JSONResponse(
jsonable_encoder(
Expand All @@ -267,7 +276,6 @@ def mocked_server(url, verify, auth, timeout, proxies):
}
)
)

resp = requests.Response()
resp.status_code = 200
return resp
Expand All @@ -279,10 +287,8 @@ def mocked_server(url, verify, auth, timeout, proxies):
ServerConfig.get_everserver_status_path(config.output_dir)
)

# The server should fail, and store the exception that
# start_optimization raised.
assert status["status"] == ServerStatus.failed
assert "Optimization failed with exception." in status["message"]
assert "Some message" in status["message"]


@pytest.mark.integration_test
Expand Down

0 comments on commit 5c92cb5

Please sign in to comment.