Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Let zmq select port in EnsembleEvaluator #9991

Merged
merged 1 commit into from
Feb 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/ert/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,10 @@ def run_cli(args: Namespace, plugin_manager: ErtPluginManager | None = None) ->

use_ipc_protocol = model.queue_system == QueueSystem.LOCAL
evaluator_server_config = EvaluatorServerConfig(
custom_port_range=args.port_range, use_ipc_protocol=use_ipc_protocol
port_range=None
if args.port_range is None
else (min(args.port_range), max(args.port_range) + 1),
use_ipc_protocol=use_ipc_protocol,
)

if model.check_if_runpath_exists():
Expand Down
4 changes: 2 additions & 2 deletions src/ert/ensemble_evaluator/_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ async def evaluate(
ce_unary_send_method_name,
partialmethod(
self.__class__.send_event,
self._config.get_connection_info().router_uri,
self._config.get_uri(),
token=self._config.token,
),
)
Expand Down Expand Up @@ -267,7 +267,7 @@ async def _evaluate_inner( # pylint: disable=too-many-branches
max_running=self._queue_config.max_running,
submit_sleep=self._queue_config.submit_sleep,
ens_id=self.id_,
ee_uri=self._config.get_connection_info().router_uri,
ee_uri=self._config.get_uri(),
ee_token=self._config.token,
)
logger.info(
Expand Down
53 changes: 26 additions & 27 deletions src/ert/ensemble_evaluator/config.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
import logging
import socket
import uuid
import warnings

import zmq

from ert.shared import find_available_socket
from ert.shared import get_machine_name as ert_shared_get_machine_name

from .evaluator_connection_info import EvaluatorConnectionInfo
from ert.shared.net_utils import get_ip_address

logger = logging.getLogger(__name__)

Expand All @@ -25,39 +22,41 @@ def get_machine_name() -> str:
class EvaluatorServerConfig:
def __init__(
self,
custom_port_range: range | None = None,
port_range: tuple[int, int] | None = None,
use_token: bool = True,
custom_host: str | None = None,
host: str | None = None,
use_ipc_protocol: bool = True,
) -> None:
self.host: str | None = None
self.host: str | None = host
self.router_port: int | None = None
self.url = f"ipc:///tmp/socket-{uuid.uuid4().hex[:8]}"
self.token: str | None = None
self._socket_handle: socket.socket | None = None

self.server_public_key: bytes | None = None
self.server_secret_key: bytes | None = None
if not use_ipc_protocol:
self._socket_handle = find_available_socket(
custom_range=custom_port_range,
custom_host=custom_host,
will_close_then_reopen_socket=True,
)
self.host, self.router_port = self._socket_handle.getsockname()
self.url = f"tcp://{self.host}:{self.router_port}"
self.use_ipc_protocol: bool = use_ipc_protocol

if port_range is None:
port_range = (51820, 51840 + 1)
else:
if port_range[0] > port_range[1]:
raise ValueError("Minimum port in range is higher than maximum port")

if port_range[0] == port_range[1]:
port_range = (port_range[0], port_range[0] + 1)

self.min_port = port_range[0]
self.max_port = port_range[1]

if use_ipc_protocol:
self.uri = f"ipc:///tmp/socket-{uuid.uuid4().hex[:8]}"
elif self.host is None:
self.host = get_ip_address()

if use_token:
self.server_public_key, self.server_secret_key = zmq.curve_keypair()
self.token = self.server_public_key.decode("utf-8")

def get_socket(self) -> socket.socket | None:
if self._socket_handle:
return self._socket_handle.dup()
return None
def get_uri(self) -> str:
if not self.use_ipc_protocol:
return f"tcp://{self.host}:{self.router_port}"

def get_connection_info(self) -> EvaluatorConnectionInfo:
return EvaluatorConnectionInfo(
xjules marked this conversation as resolved.
Show resolved Hide resolved
self.url,
self.token,
)
return self.uri
13 changes: 9 additions & 4 deletions src/ert/ensemble_evaluator/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,12 +305,17 @@ async def _server(self) -> None:
self._router_socket.curve_publickey = self._config.server_public_key
self._router_socket.curve_server = True

if self._config.router_port:
self._router_socket.bind(f"tcp://*:{self._config.router_port}")
if self._config.use_ipc_protocol:
self._router_socket.bind(self._config.get_uri())
else:
self._router_socket.bind(self._config.url)
self._config.router_port = self._router_socket.bind_to_random_port(
"tcp://*",
min_port=self._config.min_port,
max_port=self._config.max_port,
)

self._server_started.set_result(None)
except zmq.error.ZMQError as e:
except zmq.error.ZMQBaseError as e:
logger.error(f"ZMQ error encountered {e} during evaluator initialization")
self._server_started.set_exception(e)
zmq_context.destroy(linger=0)
Expand Down
9 changes: 0 additions & 9 deletions src/ert/ensemble_evaluator/evaluator_connection_info.py

This file was deleted.

14 changes: 3 additions & 11 deletions src/ert/ensemble_evaluator/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
import uuid
from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING, Final
from typing import Final

from _ert.events import (
EETerminated,
Expand All @@ -16,10 +16,6 @@
)
from _ert.forward_model_runner.client import Client

if TYPE_CHECKING:
from ert.ensemble_evaluator.evaluator_connection_info import EvaluatorConnectionInfo


logger = logging.getLogger(__name__)


Expand All @@ -30,15 +26,11 @@ class EventSentinel:
class Monitor(Client):
_sentinel: Final = EventSentinel()

def __init__(self, ee_con_info: EvaluatorConnectionInfo) -> None:
def __init__(self, uri: str, token: str | None = None) -> None:
self._id = str(uuid.uuid1()).split("-", maxsplit=1)[0]
self._event_queue: asyncio.Queue[Event | EventSentinel] = asyncio.Queue()
self._receiver_timeout: float = 60.0
super().__init__(
ee_con_info.router_uri,
ee_con_info.token,
dealer_name=f"client-{self._id}",
)
super().__init__(uri, token, dealer_name=f"client-{self._id}")

async def process_message(self, msg: str) -> None:
event = event_from_json(msg)
Expand Down
7 changes: 1 addition & 6 deletions src/ert/gui/simulation/run_dialog.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,13 +348,8 @@ def run_experiment(self, restart: bool = False) -> None:
self._snapshot_model.reset()
self._tab_widget.clear()

port_range = None
use_ipc_protocol = False
if self._run_model.queue_system == QueueSystem.LOCAL:
port_range = range(49152, 51819)
use_ipc_protocol = True
evaluator_server_config = EvaluatorServerConfig(
custom_port_range=port_range, use_ipc_protocol=use_ipc_protocol
use_ipc_protocol=self._run_model.queue_system == QueueSystem.LOCAL
)

def run() -> None:
Expand Down
2 changes: 1 addition & 1 deletion src/ert/run_models/base_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ async def run_monitor(
) -> bool:
try:
logger.debug("connecting to new monitor...")
async with Monitor(ee_config.get_connection_info()) as monitor:
async with Monitor(ee_config.get_uri(), ee_config.token) as monitor:
logger.debug("connected")
async for event in monitor.track(heartbeat_interval=0.1):
if type(event) in {
Expand Down
8 changes: 2 additions & 6 deletions src/ert/services/_storage_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,7 @@ def run_server(
if args is None:
args = parse_args()

if "ERT_STORAGE_TOKEN" in os.environ:
authtoken = os.environ["ERT_STORAGE_TOKEN"]
else:
if (authtoken := os.environ.get("ERT_STORAGE_TOKEN")) is None:
authtoken = generate_authtoken()
os.environ["ERT_STORAGE_TOKEN"] = authtoken

Expand All @@ -106,9 +104,7 @@ def run_server(
config_args.update(reload=True, reload_dirs=[os.path.dirname(ert_shared_path)])
os.environ["ERT_STORAGE_DEBUG"] = "1"

sock = find_available_socket(
custom_host=args.host, custom_range=range(51850, 51870)
)
sock = find_available_socket(host=args.host, port_range=range(51850, 51870 + 1))
connection_info = _create_connection_info(sock, authtoken)

# Appropriated from uvicorn.main:run
Expand Down
61 changes: 21 additions & 40 deletions src/ert/shared/net_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,8 @@ def get_machine_name() -> str:


def find_available_socket(
custom_host: str | None = None,
custom_range: range | None = None,
will_close_then_reopen_socket: bool = False,
host: str | None = None,
port_range: range = range(51820, 51840 + 1),
) -> socket.socket:
"""
The default and recommended approach here is to return a bound socket to the
Expand All @@ -70,49 +69,30 @@ def find_available_socket(

See e.g. implementation and comments in EvaluatorServerConfig
"""
current_host = custom_host if custom_host is not None else _get_ip_address()
current_range = (
custom_range if custom_range is not None else range(51820, 51840 + 1)
)
current_host = host if host is not None else get_ip_address()

if current_range.start == current_range.stop:
ports = list(range(current_range.start, current_range.stop + 1))
if port_range.start == port_range.stop:
JHolba marked this conversation as resolved.
Show resolved Hide resolved
ports = list(range(port_range.start, port_range.stop + 1))
else:
ports = list(range(current_range.start, current_range.stop))
ports = list(range(port_range.start, port_range.stop))

random.shuffle(ports)
for port in ports:
try:
return _bind_socket(
host=current_host,
port=port,
will_close_then_reopen_socket=will_close_then_reopen_socket,
)
except PortAlreadyInUseException:
continue

raise NoPortsInRangeException(f"No available ports in range {current_range}.")
raise NoPortsInRangeException(f"No available ports in range {port_range}.")


def _bind_socket(
host: str, port: int, will_close_then_reopen_socket: bool = False
) -> socket.socket:
def _bind_socket(host: str, port: int) -> socket.socket:
try:
family = get_family(host=host)
sock = socket.socket(family=family, type=socket.SOCK_STREAM)

# Setting flags like SO_REUSEADDR and/or SO_REUSEPORT may have
# undesirable side-effects but we allow it if caller insists. Refer to
# comment on find_available_socket()
#
# See e.g. https://stackoverflow.com/a/14388707 for an extensive
# explanation of these flags, in particular the part about TIME_WAIT

if will_close_then_reopen_socket:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
else:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 0)

sock.bind((host, port))
return sock
except socket.gaierror as err_info:
Expand All @@ -139,18 +119,19 @@ def get_family(host: str) -> socket.AddressFamily:


# See https://stackoverflow.com/a/28950776
def _get_ip_address() -> str:
def get_ip_address() -> str:
try:
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.settimeout(0)
# try pinging a reserved, internal address in order
# to determine IP representing the default route
s.connect(("10.255.255.255", 1))
retval = s.getsockname()[0]
try:
s.settimeout(0)
# try pinging a reserved, internal address in order
# to determine IP representing the default route
s.connect(("10.255.255.255", 1))
address = s.getsockname()[0]
finally:
s.close()
except BaseException:
logger.warning("Cannot determine ip-address. Fallback to localhost...")
retval = "127.0.0.1"
finally:
s.close()
logger.debug(f"ip-address: {retval}")
return retval
logger.warning("Cannot determine ip-address. Falling back to localhost.")
address = "127.0.0.1"
logger.debug(f"ip-address: {address}")
return address
2 changes: 1 addition & 1 deletion src/everest/detached/jobs/everserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def run(self):
evaluator_server_config = EvaluatorServerConfig()
else:
evaluator_server_config = EvaluatorServerConfig(
custom_port_range=range(49152, 51819), use_ipc_protocol=False
port_range=(49152, 51819), use_ipc_protocol=False
)

try:
Expand Down
2 changes: 1 addition & 1 deletion tests/ert/unit_tests/ensemble_evaluator/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,6 @@ def _dump_forward_model(forward_model, index):
@pytest.fixture(name="make_ee_config")
def make_ee_config_fixture():
def _ee_config(**kwargs):
return EvaluatorServerConfig(custom_port_range=range(1024, 65535), **kwargs)
return EvaluatorServerConfig(**kwargs)

return _ee_config
Loading