Skip to content

Commit

Permalink
Splitt more
Browse files Browse the repository at this point in the history
  • Loading branch information
frode-aarstad committed Nov 27, 2024
1 parent 790f506 commit 237cca7
Show file tree
Hide file tree
Showing 2 changed files with 175 additions and 118 deletions.
199 changes: 144 additions & 55 deletions src/everest/detached/jobs/everest_server_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509.oid import NameOID
from dns import resolver, reversename
from fastapi import Depends, FastAPI, HTTPException, Request, status
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, status
from fastapi.encoders import jsonable_encoder
from fastapi.responses import (
JSONResponse,
Expand Down Expand Up @@ -50,103 +50,192 @@



def everest_server_api():

def __init__(output_dir:str, optimization_output_dir:str):
# same code is in ensemble evaluator
authentication = _generate_authentication()

# same code is in ensemble evaluator
cert_path, key_path, key_pw = _generate_certificate(
ServerConfig.get_certificate_dir(output_dir)
)
host = _get_machine_name()
port = _find_open_port(host, lower=5000, upper=5800)

host_file = ServerConfig.get_hostfile_path(output_dir)
_write_hostfile(host_file, host, port, cert_path, authentication)



def _sim_monitor(context_status, event=None, shared_data=None):
status = context_status["status"]
assert shared_data
shared_data[SIM_PROGRESS_ENDPOINT] = {
"batch_number": context_status["batch_number"],
"status": {
"running": status.running,
"waiting": status.waiting,
"pending": status.pending,
"complete": status.complete,
"failed": status.failed,
},
"progress": context_status["progress"],
"event": event,
}

if shared_data[STOP_ENDPOINT]:
return "stop_queue"


def _opt_monitor(shared_data=None):
assert shared_data
if shared_data[STOP_ENDPOINT]:
return "stop_optimization"







class ExperimentRunner(threading.Thread):
def __init__(self, everest_config, shared_data):
super().__init__()

self.everest_config = everest_config
self.shared_data = shared_data
self.exit_code = None


def run(self):
run_model = EverestRunModel.create(
config,
simulation_callback=partial(_sim_monitor, shared_data=shared_data),
optimization_callback=partial(_opt_monitor, shared_data=shared_data),
self.everest_config,
simulation_callback=partial(_sim_monitor, shared_data=self.shared_data),
optimization_callback=partial(_opt_monitor, shared_data=self.shared_data),
)

evaluator_server_config = EvaluatorServerConfig(
custom_port_range=range(49152, 51819)
if run_model.ert_config.queue_config.queue_system == QueueSystem.LOCAL
else None
)

run_model.run_experiment(evaluator_server_config)

self.exit_code =run_model.exit_code


def start(self):
uvicorn.run(
app,
host="0.0.0.0",
port=server_config["port"],
ssl_keyfile=server_config["key_path"],
ssl_certfile=server_config["cert_path"],
ssl_version=ssl.PROTOCOL_SSLv23,
ssl_keyfile_password=server_config["key_passwd"],
)
def exit_code(self):
return self.exit_code



security = HTTPBasic()


class EverestServerAPI(threading.Thread):

def __init__(self, everest_config: EverestConfig, shared_data:dict):
super().__init__()

self.app = FastAPI()

self.router = APIRouter()
self.router.add_api_route("/", self.get_status, methods=["GET"])
self.router.add_api_route("/stop", self.stop, methods=["POST"])
self.router.add_api_route("/sim_progress", self.get_sim_progress, methods=["GET"])
self.router.add_api_route("/opt_progress", self.get_opt_progress, methods=["GET"])
self.router.add_api_route("/start", self.start_experiment, methods=["POST"])


self.router.add_api_route("/exit_code", self.get_exit_code, methods=["GET"])

self.app.include_router(self.router)



self.shared_data =shared_data
self.everest_config =everest_config
self.output_dir = everest_config.output_dir
self.optimization_output_dir = everest_config.optimization_output_dir


# same code is in ensemble evaluator
self.authentication = _generate_authentication()

# same code is in ensemble evaluator
self.cert_path, self.key_path, self.key_pw = _generate_certificate(
ServerConfig.get_certificate_dir(self.output_dir)
)
self.host = _get_machine_name()
self.port = _find_open_port(self.host, lower=5000, upper=5800)

host_file = ServerConfig.get_hostfile_path(self.output_dir)
_write_hostfile(host_file, self.host, self.port, self.cert_path, self.authentication)



def run(self):

uvicorn.run(
self.app,
host="0.0.0.0",
port=self.port,
ssl_keyfile=self.key_path,
ssl_certfile=self.cert_path,
ssl_version=ssl.PROTOCOL_SSLv23,
ssl_keyfile_password=self.key_pw,
)

def _check_user(credentials: HTTPBasicCredentials) -> None:
if credentials.password != server_config["authentication"]:

def _check_user(self, credentials: HTTPBasicCredentials) -> None:
if credentials.password != self.authentication:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid credentials",
headers={"WWW-Authenticate": "Basic"},
)

def _log(request: Request) -> None:
def _log(self, request: Request) -> None:
logging.getLogger("everserver").info(
f"{request.scope['path']} entered from {request.client.host if request.client else 'unknown host'} with HTTP {request.method}"
)




app = FastAPI()
security = HTTPBasic()


@app.get("/")
def get_status(
def get_status(self,
request: Request, credentials: HTTPBasicCredentials = Depends(security)
) -> PlainTextResponse:
_log(request)
_check_user(credentials)
self._log(request)
self._check_user(credentials)
return PlainTextResponse("Everest is running")

@app.post("/" + STOP_ENDPOINT)
def stop(
def stop(self,
request: Request, credentials: HTTPBasicCredentials = Depends(security)
) -> Response:
_log(request)
_check_user(credentials)
shared_data[STOP_ENDPOINT] = True
self._log(request)
self._check_user(credentials)
self.shared_data[STOP_ENDPOINT] = True
return Response("Raise STOP flag succeeded. Everest initiates shutdown..", 200)

@app.get("/" + SIM_PROGRESS_ENDPOINT)
def get_sim_progress(
def get_sim_progress(self,
request: Request, credentials: HTTPBasicCredentials = Depends(security)
) -> JSONResponse:
_log(request)
_check_user(credentials)
progress = shared_data[SIM_PROGRESS_ENDPOINT]
self._log(request)
self._check_user(credentials)
progress = self.shared_data[SIM_PROGRESS_ENDPOINT]
print(self.runner.exit_code)
return JSONResponse(jsonable_encoder(progress))

@app.get("/" + OPT_PROGRESS_ENDPOINT)
def get_opt_progress(
def get_exit_code(self,
request: Request, credentials: HTTPBasicCredentials = Depends(security)
) -> JSONResponse:
return JSONResponse({"exit_code" : self.runner.exit_code})


def get_opt_progress(self,
request: Request, credentials: HTTPBasicCredentials = Depends(security)
) -> JSONResponse:
_log(request)
_check_user(credentials)
progress = get_opt_status(server_config["optimization_output_dir"])
self._log(request)
self._check_user(credentials)
progress = get_opt_status(self.optimization_output_dir)
return JSONResponse(jsonable_encoder(progress))


def start_experiment(self,
request: Request, credentials: HTTPBasicCredentials = Depends(security)
) -> JSONResponse:

self.runner = ExperimentRunner(self.everest_config, self.shared_data)
self.runner.start()


return JSONResponse("ok")

94 changes: 31 additions & 63 deletions src/everest/detached/jobs/everserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
import threading
import traceback
from base64 import b64encode
import time
from datetime import datetime, timedelta
from functools import partial

import requests
import uvicorn
from cryptography import x509
from cryptography.hazmat.backends import default_backend
Expand All @@ -35,7 +36,8 @@
from ert.run_models.everest_run_model import EverestRunModel
from everest import export_to_csv, export_with_progress
from everest.config import EverestConfig, ServerConfig
from everest.detached import ServerStatus, get_opt_status, update_everserver_status
from everest.detached import PROXY, ServerStatus, _query_server, get_opt_status, update_everserver_status
from everest.detached.jobs.everest_server_api import EverestServerAPI
from everest.export import check_for_errors
from everest.simulator import JOB_FAILURE
from everest.strings import (
Expand All @@ -49,33 +51,6 @@




def _sim_monitor(context_status, event=None, shared_data=None):
status = context_status["status"]
shared_data[SIM_PROGRESS_ENDPOINT] = {
"batch_number": context_status["batch_number"],
"status": {
"running": status.running,
"waiting": status.waiting,
"pending": status.pending,
"complete": status.complete,
"failed": status.failed,
},
"progress": context_status["progress"],
"event": event,
}

if shared_data[STOP_ENDPOINT]:
return "stop_queue"


def _opt_monitor(shared_data=None):
if shared_data[STOP_ENDPOINT]:
return "stop_optimization"




def _get_optimization_status(exit_code, shared_data):
if exit_code == "max_batch_num_reached":
return ServerStatus.completed, "Maximum number of batches reached."
Expand Down Expand Up @@ -190,26 +165,13 @@ def main():
STOP_ENDPOINT: False,
}

server_config = {
"optimization_output_dir": config.optimization_output_dir,
"port": port,
"cert_path": cert_path,
"key_path": key_path,
"key_passwd": key_pw,
"authentication": authentication,
}

everest_server_api = everest_server_api(config.output_dir, config.optimization_output_dir)


everserver_instance = threading.Thread(
target=_everserver_thread,
args=(shared_data, server_config),
)
everserver_instance.daemon = True
everserver_instance.start()

everest_server_api = EverestServerAPI(config, shared_data)
everest_server_api.daemon = True
everest_server_api.start()

server_context=ServerConfig.get_server_context(config.output_dir),
url, cert, auth = server_context[0]


except:
Expand All @@ -221,25 +183,31 @@ def main():
return

try:
# wait until the server is running
is_running= False
while not is_running:
try:
requests.get(url + "/", verify=cert, auth=auth, proxies=PROXY )
except:
time.sleep(1)
pass
is_running =True

update_everserver_status(status_path, ServerStatus.running)


response= requests.post(url + "/start", verify=cert, auth=auth, proxies=PROXY )

evaluator_server_config = EvaluatorServerConfig(
custom_port_range=range(49152, 51819)
if run_model.ert_config.queue_config.queue_system == QueueSystem.LOCAL
else None
)

run_model.run_experiment(evaluator_server_config)

## yield






status, message = _get_optimization_status(run_model.exit_code, shared_data)
is_running= True
while is_running:
response= requests.get(url + "/exit_code", verify=cert, auth=auth, proxies=PROXY )
if json_body:= json.loads(response.text):
if exit_code:=json_body["exit_code"]:
is_running= False
else:
time.sleep(1)

status, message = _get_optimization_status(exit_code, shared_data)
if status != ServerStatus.completed:
update_everserver_status(status_path, status, message)
return
Expand Down

0 comments on commit 237cca7

Please sign in to comment.