Add ConcurrentExecution step (#511)
* Add `ConcurrentExecution` step

Based on the preexisting internal `_ConcurrentJobExecution` class, with support for three concurrency mechanisms: asyncio, threading, and multiprocessing. A usage sketch follows the change list below.

* Fix typo

* Expose `ConcurrentExecution`

* Allow `ConcurrentExecution` to pass context to user function

* Fix list append

* Fail early on non-serializable context

* Fix passing of default `max_in_flight` as `max_workers`

* Fix

* Support passing context to multiprocessing step

* Minor refactoring

* Change event processor back to function because of mlrun serialization issues

* Move function to avoid serialization issue

* Add documentation

* Revert attempts to pass context to multiprocessing, add docs

* Remove dill requirement
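
Below is a minimal usage sketch based on the test added in this commit. The `fetch` function, its sleep duration, and the reducer lambda are illustrative stand-ins rather than part of the diff:

import asyncio
import time

from storey import AsyncEmitSource
from storey.flow import ConcurrentExecution, Reduce, build_flow


def fetch(event):
    time.sleep(0.1)  # illustrative stand-in for blocking I/O
    return event


async def main():
    controller = build_flow(
        [
            AsyncEmitSource(),
            ConcurrentExecution(
                event_processor=fetch,
                concurrency_mechanism="threading",  # offload blocking calls to a thread pool
                max_in_flight=10,
            ),
            Reduce([], lambda acc, event: acc + [event]),
        ]
    ).run()

    for i in range(8):
        await controller.emit(i)
    await controller.terminate()
    print(await controller.await_termination())  # [0, 1, ..., 7]


asyncio.run(main())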
gtopper authored Apr 21, 2024
1 parent 1c86107 commit bbc644f
Showing 3 changed files with 150 additions and 2 deletions.
1 change: 1 addition & 0 deletions storey/__init__.py
@@ -36,6 +36,7 @@
from .flow import Batch # noqa: F401
from .flow import Choice # noqa: F401
from .flow import Complete # noqa: F401
from .flow import ConcurrentExecution # noqa: F401
from .flow import Context # noqa: F401
from .flow import Extend # noqa: F401
from .flow import Filter # noqa: F401
73 changes: 71 additions & 2 deletions storey/flow.py
@@ -16,10 +16,12 @@
import copy
import datetime
import inspect
import pickle
import time
import traceback
from asyncio import Task
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Union

import aiohttp
@@ -800,11 +802,11 @@ def __init__(self, max_in_flight=None, retries=None, backoff_factor=None, **kwargs):
        Flow.__init__(self, **kwargs)
        if max_in_flight is not None and max_in_flight < 1:
            raise ValueError(f"max_in_flight may not be less than 1 (got {max_in_flight})")
        self.max_in_flight = max_in_flight
        self.retries = retries
        self.backoff_factor = backoff_factor

        self._queue_size = max_in_flight - 1 if max_in_flight else 8
        self._max_in_flight = max_in_flight or 8
        self._queue_size = self._max_in_flight - 1

    def _init(self):
        super()._init()
@@ -916,6 +918,73 @@ async def _do(self, event):
        await self._worker_awaitable


class ConcurrentExecution(_ConcurrentJobExecution):
    """
    Processes events concurrently by running a user-provided function on each event.

    :param event_processor: Function that will be run on each event
    :param concurrency_mechanism: One of:
        * "asyncio" (default) – for I/O implemented using asyncio
        * "threading" – for blocking I/O
        * "multiprocessing" – for processing-intensive tasks
    :param max_in_flight: Maximum number of events to be processed at a time (default 8)
    :param retries: Maximum number of retries per event (default 0)
    :param backoff_factor: Wait time in seconds between retries (default 1)
    :param pass_context: If False, the event_processor function will be called with just one parameter (event).
        If True, it will be called with two parameters (event, context). Defaults to False.
    """

    _supported_concurrency_mechanisms = ["asyncio", "threading", "multiprocessing"]

    def __init__(
        self,
        event_processor: Union[Callable[[Event], Any], Callable[[Event, Any], Any]],
        concurrency_mechanism=None,
        pass_context=None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self._event_processor = event_processor

        if concurrency_mechanism and concurrency_mechanism not in self._supported_concurrency_mechanisms:
            raise ValueError(f"Concurrency mechanism '{concurrency_mechanism}' is not supported")

        # Fail early: a worker process can only receive the context if it can be pickled
        if concurrency_mechanism == "multiprocessing" and pass_context:
            try:
                pickle.dumps(self.context)
            except Exception as ex:
                raise ValueError(
                    'When concurrency_mechanism="multiprocessing" is used in conjunction with '
                    "pass_context=True, context must be serializable"
                ) from ex

        self._executor = None
        if concurrency_mechanism == "threading":
            self._executor = ThreadPoolExecutor(max_workers=self._max_in_flight)
        elif concurrency_mechanism == "multiprocessing":
            self._executor = ProcessPoolExecutor(max_workers=self._max_in_flight)

        self._pass_context = pass_context

    async def _process_event(self, event):
        args = [event]
        if self._pass_context:
            args.append(self.context)
        if self._executor:
            # "threading"/"multiprocessing": offload the user function to the pool
            result = await asyncio.get_running_loop().run_in_executor(self._executor, self._event_processor, *args)
        else:
            # "asyncio" (default): call inline; a coroutine result is awaited below
            result = self._event_processor(*args)
        if asyncio.iscoroutine(result):
            result = await result
        return result

    async def _handle_completed(self, event, response):
        await self._do_downstream(response)

class SendToHttp(_ConcurrentJobExecution):
"""Joins each event with data from any HTTP source. Used for event augmentation.
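As an aside, the dispatch in `_process_event` above follows the standard asyncio pattern for offloading synchronous work to an executor. A minimal standalone illustration of that pattern (not part of this commit; names and durations are illustrative):

import asyncio
import time
from concurrent.futures import ThreadPoolExecutor


def blocking_work(x):
    time.sleep(0.1)  # stands in for blocking I/O
    return x * 2


async def main():
    executor = ThreadPoolExecutor(max_workers=8)
    loop = asyncio.get_running_loop()
    # Same mechanism as the "threading"/"multiprocessing" branches above:
    # the synchronous function runs in the pool while the event loop stays free
    futures = [loop.run_in_executor(executor, blocking_work, i) for i in range(8)]
    print(await asyncio.gather(*futures))  # ~0.1s total, not 0.8s


asyncio.run(main())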
78 changes: 78 additions & 0 deletions tests/test_concurrent_execution.py
@@ -0,0 +1,78 @@
import asyncio
import time

import pytest

from storey import AsyncEmitSource
from storey.flow import ConcurrentExecution, Reduce, build_flow
from tests.test_flow import append_and_return

event_processing_duration = 0.5


class SomeContext:
    def __init__(self):
        self.fn = lambda x: x


async def process_event_slow_asyncio(event, context):
    assert isinstance(context, SomeContext) and callable(context.fn)
    await asyncio.sleep(event_processing_duration)
    return event


def process_event_slow_io(event, context):
    assert isinstance(context, SomeContext) and callable(context.fn)
    time.sleep(event_processing_duration)
    return event


def process_event_slow_processing(event):
    start = time.monotonic()
    while time.monotonic() - start < event_processing_duration:
        pass
    return event


async def async_test_concurrent_execution(concurrency_mechanism, event_processor, pass_context):
    controller = build_flow(
        [
            AsyncEmitSource(),
            ConcurrentExecution(
                event_processor=event_processor,
                concurrency_mechanism=concurrency_mechanism,
                pass_context=pass_context,
                max_in_flight=10,
                context=SomeContext(),
            ),
            Reduce([], append_and_return),
        ]
    ).run()

    num_events = 8

    start = time.monotonic()
    for counter in range(num_events):
        await controller.emit(counter)

    await controller.terminate()
    result = await controller.await_termination()
    end = time.monotonic()

    assert result == list(range(num_events))
    assert end - start > event_processing_duration, "Run time cannot be less than the time to process a single event"
    assert (
        end - start < event_processing_duration * num_events
    ), "Run time must be less than the time to process all events in serial"


@pytest.mark.parametrize(
    ["concurrency_mechanism", "event_processor", "pass_context"],
    [
        ("asyncio", process_event_slow_asyncio, True),
        ("threading", process_event_slow_io, True),
        ("multiprocessing", process_event_slow_processing, False),
    ],
)
def test_concurrent_execution(concurrency_mechanism, event_processor, pass_context):
    asyncio.run(async_test_concurrent_execution(concurrency_mechanism, event_processor, pass_context))
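
The timing assertions are what make this a concurrency test: the total run time must exceed the time to process a single event (0.5 s) but stay below the serial time for all eight events (4 s), which can only happen if the events overlap in processing.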
