-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
ec1fb85
commit 7eff8ac
Showing
3 changed files
with
123 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
# Copyright (c) 2024. Some Engineering | ||
# This program is free software: you can redistribute it and/or modify | ||
# it under the terms of the GNU Affero General Public License as published by | ||
# the Free Software Foundation, either version 3 of the License, or | ||
# (at your option) any later version. | ||
# | ||
# This program is distributed in the hope that it will be useful, | ||
# but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
# GNU Affero General Public License for more details. | ||
# | ||
# You should have received a copy of the GNU Affero General Public License | ||
# along with this program. If not, see <http://www.gnu.org/licenses/>. | ||
import asyncio | ||
import logging | ||
from concurrent.futures import ProcessPoolExecutor | ||
from concurrent.futures.process import BrokenProcessPool | ||
from typing import Optional, Any, Callable, TypeVar | ||
|
||
from fixcloudutils.service import Service | ||
|
||
log = logging.getLogger(__name__) | ||
T = TypeVar("T") | ||
|
||
|
||
class AsyncProcessPool(Service): | ||
def __init__(self, **pool_args: Any): | ||
self.pool_args = pool_args | ||
self.executor: Optional[ProcessPoolExecutor] = None | ||
self.lock = asyncio.Lock() | ||
|
||
async def start(self) -> Any: | ||
async with self.lock: | ||
if not self.executor: | ||
self.executor = ProcessPoolExecutor(**self.pool_args) | ||
|
||
async def stop(self) -> None: | ||
async with self.lock: | ||
if self.executor: | ||
self.executor.shutdown(wait=True) | ||
self.executor = None | ||
|
||
async def submit(self, func: Callable[..., T], *args: Any, timeout: Optional[float] = None) -> T: | ||
assert self.executor is not None, "Executor not started" | ||
loop = asyncio.get_running_loop() | ||
executor_id = id(self.executor) | ||
try: | ||
return await asyncio.wait_for(loop.run_in_executor(self.executor, func, *args), timeout=timeout) | ||
except BrokenProcessPool: # every running task will raise this exception | ||
async with self.lock: | ||
if id(self.executor) == executor_id: # does this exception comes from the current executor? | ||
log.warning("A process in the pool died unexpectedly. Creating a new pool.", exc_info=True) | ||
self.executor.shutdown(wait=False) | ||
self.executor = ProcessPoolExecutor(**self.pool_args) | ||
raise |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
# Copyright (c) 2024. Some Engineering | ||
# This program is free software: you can redistribute it and/or modify | ||
# it under the terms of the GNU Affero General Public License as published by | ||
# the Free Software Foundation, either version 3 of the License, or | ||
# (at your option) any later version. | ||
# | ||
# This program is distributed in the hope that it will be useful, | ||
# but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
# GNU Affero General Public License for more details. | ||
# | ||
# You should have received a copy of the GNU Affero General Public License | ||
# along with this program. If not, see <http://www.gnu.org/licenses/>. | ||
import asyncio | ||
import os | ||
import re | ||
from asyncio import Task | ||
from concurrent.futures.process import BrokenProcessPool | ||
from time import sleep | ||
from typing import Tuple, List | ||
|
||
import pytest | ||
|
||
from fixcloudutils.asyncio.process_pool import AsyncProcessPool | ||
|
||
|
||
def do_work(i: int) -> Tuple[str, int]: | ||
sleep(0.1) | ||
return f"GOT {i}", os.getpid() | ||
|
||
|
||
def broken_work(i: int) -> Tuple[str, int]: | ||
if i == 3: | ||
os._exit(1) | ||
else: | ||
return do_work(i) | ||
|
||
|
||
async def test_process_pool() -> None: | ||
tasks: List[Task[Tuple[str, int]]] = [] | ||
pids = set() | ||
task: str | ||
async with AsyncProcessPool(max_workers=5) as pool: | ||
for num in range(5): | ||
create_task = asyncio.create_task(pool.submit(do_work, num)) | ||
tasks.append(create_task) | ||
for task, pid in await asyncio.gather(*tasks): | ||
assert re.match(r"GOT [0-4]", task) is not None | ||
assert pid not in pids | ||
pids.add(pid) | ||
|
||
|
||
async def test_process_pool_broken() -> None: | ||
tasks = [] | ||
async with AsyncProcessPool(max_workers=5) as pool: | ||
for num in range(5): | ||
tasks.append(asyncio.create_task(pool.submit(broken_work, num))) | ||
for t in asyncio.as_completed(tasks): | ||
# one failing task will make the whole batch fail | ||
with pytest.raises(BrokenProcessPool): | ||
await t | ||
# we can still submit new tasks that get executed | ||
tasks = [] | ||
for num in range(5): | ||
tasks.append(asyncio.create_task(pool.submit(do_work, num))) | ||
for task, pid in await asyncio.gather(*tasks): | ||
assert re.match(r"GOT [0-4]", task) is not None |