Skip to content

Commit

Permalink
Add support for passing context to function handler (#503)
Browse files Browse the repository at this point in the history
  • Loading branch information
gtopper authored Feb 14, 2024
1 parent 67ffa55 commit f17a341
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
8 changes: 6 additions & 2 deletions storey/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ async def _do(self, event):


class _UnaryFunctionFlow(Flow):
def __init__(self, fn, long_running=None, **kwargs):
def __init__(self, fn, long_running=None, pass_context=None, **kwargs):
super().__init__(**kwargs)
if not callable(fn):
raise TypeError(f"Expected a callable, got {type(fn)}")
Expand All @@ -402,12 +402,16 @@ def __init__(self, fn, long_running=None, **kwargs):
raise ValueError("long_running=True cannot be used in conjunction with a coroutine")
self._long_running = long_running
self._fn = fn
self._pass_context = pass_context

async def _call(self, element):
if self._long_running:
res = await asyncio.get_running_loop().run_in_executor(None, self._fn, element)
else:
res = self._fn(element)
kwargs = {}
if self._pass_context:
kwargs = {"context": self.context}
res = self._fn(element, **kwargs)
if self._is_async:
res = await res
return res
Expand Down
16 changes: 16 additions & 0 deletions tests/test_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,22 @@ def test_functional_flow():
assert termination_result == 3300


def test_pass_context_to_function():
controller = build_flow(
[
SyncEmitSource(),
Map(lambda x, context: x + context, pass_context=True, context=10),
Reduce(0, lambda acc, x: acc + x),
]
).run()

for i in range(5):
controller.emit(i)
controller.terminate()
termination_result = controller.await_termination()
assert termination_result == 60


class Committer:
def __init__(self):
self.offsets = {}
Expand Down

0 comments on commit f17a341

Please sign in to comment.