Skip to content

Commit

Permalink
[feat] AsyncProcessPool (#26)
Browse files Browse the repository at this point in the history
  • Loading branch information
aquamatthias authored Apr 19, 2024
1 parent ec1fb85 commit 7eff8ac
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 1 deletion.
55 changes: 55 additions & 0 deletions fixcloudutils/asyncio/process_pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright (c) 2024. Some Engineering
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import logging
from concurrent.futures import ProcessPoolExecutor
from concurrent.futures.process import BrokenProcessPool
from typing import Optional, Any, Callable, TypeVar

from fixcloudutils.service import Service

log = logging.getLogger(__name__)
T = TypeVar("T")


class AsyncProcessPool(Service):
    """
    Asyncio-friendly wrapper around a ``ProcessPoolExecutor``.

    The executor is created in ``start`` and torn down in ``stop``.
    When a worker process dies, the executor becomes unusable and every
    in-flight call raises ``BrokenProcessPool``; ``submit`` then replaces
    the executor so that later submissions work again.
    """

    def __init__(self, **pool_args: Any):
        """
        :param pool_args: keyword arguments forwarded unchanged to every
            ``ProcessPoolExecutor`` this pool creates (e.g. ``max_workers``).
        """
        self.pool_args = pool_args
        # None until start() has been called and again after stop().
        self.executor: Optional[ProcessPoolExecutor] = None
        # Serializes creation, replacement and shutdown of the executor.
        self.lock = asyncio.Lock()

    async def start(self) -> Any:
        """Create the executor if there is none yet. Calling it again is a no-op."""
        async with self.lock:
            if self.executor is None:
                self.executor = ProcessPoolExecutor(**self.pool_args)

    async def stop(self) -> None:
        """Shut the executor down and wait until its pending work is done."""
        async with self.lock:
            # Detach first: a concurrent submit() then fails fast with
            # "Executor not started" instead of hitting a half-closed executor.
            executor, self.executor = self.executor, None
            if executor is not None:
                # shutdown(wait=True) blocks until all workers are finished, so it
                # is pushed to the loop's default thread pool to keep the event
                # loop responsive in the meantime.
                await asyncio.get_running_loop().run_in_executor(None, executor.shutdown, True)

    async def submit(self, func: Callable[..., T], *args: Any, timeout: Optional[float] = None) -> T:
        """
        Run ``func(*args)`` in a worker process and return its result.

        :param func: callable to execute in a worker process.
        :param args: positional arguments passed to ``func``.
        :param timeout: seconds to wait for the result; ``None`` waits forever.
        :return: whatever ``func`` returns.
        :raises AssertionError: if the pool has not been started.
        :raises asyncio.TimeoutError: if no result arrived within ``timeout``.
            A call that is already running in a worker is not interrupted by this.
        :raises BrokenProcessPool: if a worker died; the pool is replaced
            before the exception is re-raised.
        """
        # Keep a reference to the executor this call is bound to. Holding the
        # object (instead of only its id) makes the identity check below safe
        # against id reuse after garbage collection.
        executor = self.executor
        if executor is None:
            # Raised explicitly so the check also holds under ``python -O``,
            # where assert statements are removed.
            raise AssertionError("Executor not started")
        loop = asyncio.get_running_loop()
        try:
            return await asyncio.wait_for(loop.run_in_executor(executor, func, *args), timeout=timeout)
        except BrokenProcessPool:  # every running task will raise this exception
            async with self.lock:
                # Only the first failing task replaces the pool: does this
                # exception come from the executor that is still current?
                if self.executor is executor:
                    log.warning("A process in the pool died unexpectedly. Creating a new pool.", exc_info=True)
                    executor.shutdown(wait=False)
                    self.executor = ProcessPoolExecutor(**self.pool_args)
            raise
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "fixcloudutils"
version = "1.13.3"
version = "1.14.0"
authors = [{ name = "Some Engineering Inc." }]
description = "Utilities for fixcloud."
license = { file = "LICENSE" }
Expand Down
67 changes: 67 additions & 0 deletions tests/process_pool_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright (c) 2024. Some Engineering
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import os
import re
from asyncio import Task
from concurrent.futures.process import BrokenProcessPool
from time import sleep
from typing import Tuple, List

import pytest

from fixcloudutils.asyncio.process_pool import AsyncProcessPool


def do_work(i: int) -> Tuple[str, int]:
    """Simulate a short unit of work and report which process handled it.

    Sleeps for a tenth of a second, then returns a message containing ``i``
    together with the pid of the executing process.
    """
    sleep(0.1)
    message = f"GOT {i}"
    worker_pid = os.getpid()
    return message, worker_pid


def broken_work(i: int) -> Tuple[str, int]:
    """Behave like ``do_work`` except for input 3, which kills the process.

    The hard exit via ``os._exit`` terminates the worker without any cleanup.
    """
    if i != 3:
        return do_work(i)
    # Terminates the current process immediately with exit status 1.
    os._exit(1)


async def test_process_pool() -> None:
    """Submit five jobs to a five-worker pool and check every result.

    Each job must return the expected message, and no pid may be reported twice.
    """
    seen_pids = set()
    async with AsyncProcessPool(max_workers=5) as pool:
        # Schedule all jobs up front so they run concurrently.
        pending: List[Task[Tuple[str, int]]] = [
            asyncio.create_task(pool.submit(do_work, num)) for num in range(5)
        ]
        results = await asyncio.gather(*pending)
        for message, pid in results:
            assert re.match(r"GOT [0-4]", message) is not None
            assert pid not in seen_pids
            seen_pids.add(pid)


async def test_process_pool_broken() -> None:
    """Check that the pool recovers after one of its worker processes died.

    The first batch contains a job that kills its worker; the second batch
    is submitted to the same pool afterwards and has to succeed.
    """
    async with AsyncProcessPool(max_workers=5) as pool:
        failing = [asyncio.create_task(pool.submit(broken_work, num)) for num in range(5)]
        for finished in asyncio.as_completed(failing):
            # one failing task will make the whole batch fail
            with pytest.raises(BrokenProcessPool):
                await finished
        # we can still submit new tasks that get executed
        healthy = [asyncio.create_task(pool.submit(do_work, num)) for num in range(5)]
        for message, _pid in await asyncio.gather(*healthy):
            assert re.match(r"GOT [0-4]", message) is not None

0 comments on commit 7eff8ac

Please sign in to comment.