Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make forwardmodelrunner async #9198

Open
wants to merge 23 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fix failed realization not being marked as failed
  • Loading branch information
jonathan-eq committed Nov 26, 2024
commit 2b8c74a9dd3cbf3eb807f3b031955e6c7978afd8
10 changes: 9 additions & 1 deletion src/_ert/forward_model_runner/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,12 @@ def handle_sigterm(*args, **kwargs):
await job_task


async def let_reporters_finish(reporters):
    """Wait for the event reporters among ``reporters`` to finish.

    Walks ``reporters`` in order and awaits ``join()`` on every entry that
    is an instance of ``reporting.Event``; all other entries are skipped.
    The joins are awaited one after another, not concurrently.
    """
    for candidate in reporters:
        if not isinstance(candidate, reporting.Event):
            # Only Event reporters are joined here.
            continue
        await candidate.join()


async def _main(
job_runner: ForwardModelRunner,
parsed_args,
Expand All @@ -161,14 +167,16 @@ async def _main(
for reporter in reporters:
try:
await reporter.report(job_status)
await asyncio.sleep(0)
except OSError as oserror:
print(
f"job_dispatch failed due to {oserror}. Stopping and cleaning up."
)
await let_reporters_finish(reporters)
raise SystemExit(1)

if isinstance(job_status, Finish) and not job_status.success():
await let_reporters_finish(reporters)
raise SystemExit(1)
except asyncio.CancelledError:
await let_reporters_finish(reporters)
raise SystemExit(1)
16 changes: 14 additions & 2 deletions src/_ert/forward_model_runner/forward_model_step.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import asyncio
import contextlib
import io
import json
Expand All @@ -12,7 +13,17 @@
from datetime import datetime as dt
from pathlib import Path
from subprocess import Popen, run
from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Sequence, Tuple, cast
from typing import (
TYPE_CHECKING,
AsyncGenerator,
Dict,
Generator,
List,
Optional,
Sequence,
Tuple,
cast,
)

from psutil import AccessDenied, NoSuchProcess, Process, TimeoutExpired, ZombieProcess

Expand Down Expand Up @@ -89,10 +100,11 @@ def __init__(
self.std_err = job_data.get("stderr")
self.std_out = job_data.get("stdout")

def run(self) -> Generator[Start | Exited | Running | None]:
async def run(self) -> AsyncGenerator[Start | Exited | Running | None]:
    """Asynchronously relay every message produced by ``self._run()``.

    Each message from the synchronous ``self._run()`` iterator is yielded
    unchanged, followed by a zero-length sleep. If an ``Exception`` is
    raised inside the ``try`` body, it is not propagated; instead a final
    ``Exited(self, exit_code=1)`` message is yielded with ``str(error)``
    passed to ``with_error``.
    """
    try:
        messages = self._run()
        for message in messages:
            yield message
            # Zero-length sleep: hands control back to the event loop so
            # other tasks get a chance to run between messages.
            await asyncio.sleep(0)
    except Exception as error:
        yield Exited(self, exit_code=1).with_error(str(error))

Expand Down
3 changes: 2 additions & 1 deletion src/_ert/forward_model_runner/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,15 @@ async def run(self, names_of_steps_to_run: List[str]):

yield init_message
for step in step_queue:
for status_update in step.run():
async for status_update in step.run():
yield status_update
if not status_update.success():
yield Checksum(checksum_dict={}, run_path=os.getcwd())
yield Finish().with_error(
"Not all forward model steps completed successfully."
)
return
await asyncio.sleep(0)

checksum_dict = self._populate_checksums(self._read_manifest())
yield Checksum(checksum_dict=checksum_dict, run_path=os.getcwd())
Expand Down
52 changes: 15 additions & 37 deletions tests/ert/unit_tests/ensemble_evaluator/test_ensemble_client.py
Original file line number Diff line number Diff line change
@@ -1,66 +1,44 @@
from functools import partial

import pytest

from _ert.forward_model_runner.client import Client, ClientConnectionError
from _ert.threading import ErtThread

from .ensemble_evaluator_utils import _mock_ws
from tests.ert.utils import _mock_ws_task


def test_invalid_server():
async def test_invalid_server():
port = 7777
host = "localhost"
url = f"ws://{host}:{port}"

with (
Client(url, max_retries=2, timeout_multiplier=2) as c1,
pytest.raises(ClientConnectionError),
):
c1.send("hei")
async with Client(url, max_retries=2, timeout_multiplier=2) as c1:
with pytest.raises(ClientConnectionError):
await c1.send("hei")


def test_successful_sending(unused_tcp_port):
async def test_successful_sending(unused_tcp_port):
host = "localhost"
url = f"ws://{host}:{unused_tcp_port}"
messages = []
mock_ws_thread = ErtThread(
target=partial(_mock_ws, messages=messages), args=(host, unused_tcp_port)
)

mock_ws_thread.start()
messages_c1 = ["test_1", "test_2", "test_3", "stop"]

with Client(url) as c1:
messages_c1 = ["test_1", "test_2", "test_3"]
async with _mock_ws_task(host, unused_tcp_port, messages), Client(url) as c1:
for msg in messages_c1:
c1.send(msg)

mock_ws_thread.join()
await c1.send(msg)

for msg in messages_c1:
assert msg in messages


def test_retry(unused_tcp_port):
async def test_retry(unused_tcp_port):
host = "localhost"
url = f"ws://{host}:{unused_tcp_port}"
messages = []
mock_ws_thread = ErtThread(
target=partial(_mock_ws, messages=messages, delay_startup=2),
args=(
host,
unused_tcp_port,
),
)

mock_ws_thread.start()
messages_c1 = ["test_1", "test_2", "test_3", "stop"]

with Client(url, max_retries=2, timeout_multiplier=2) as c1:
messages_c1 = ["test_1", "test_2", "test_3"]
async with _mock_ws_task(host, unused_tcp_port, messages, delay_startup=2), Client(
url, max_retries=2, timeout_multiplier=2
) as c1:
for msg in messages_c1:
c1.send(msg)

mock_ws_thread.join()
await c1.send(msg)

for msg in messages_c1:
assert msg in messages
6 changes: 4 additions & 2 deletions tests/ert/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,10 @@ async def _mock_ws_thread(host, port, messages):


@contextlib.asynccontextmanager
async def _mock_ws_task(host, port, messages):
mock_ws_task = asyncio.create_task(_mock_ws_async(host, port, messages))
async def _mock_ws_task(host, port, messages, delay_startup=0):
mock_ws_task = asyncio.create_task(
_mock_ws_async(host, port, messages, delay_startup)
)
try:
yield
# Make sure to join the thread even if an exception occurs
Expand Down