Skip to content

Commit

Permalink
Raise error in Choice on duplicate outlets (#545)
Browse files Browse the repository at this point in the history
* Raise error in `Choice` on duplicate outlets

* Move validation

* Add tests, improve error messages
  • Loading branch information
gtopper authored Nov 17, 2024
1 parent 22765a0 commit d1b8dcf
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 3 deletions.
11 changes: 8 additions & 3 deletions storey/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from asyncio import Task
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Union
from typing import Any, Callable, Collection, Dict, Iterable, List, Optional, Set, Union

import aiohttp

Expand Down Expand Up @@ -363,12 +363,12 @@ def _init(self):
# TODO: hacky way of supporting mlrun preview, which replaces targets with a DFTarget
self._passthrough_for_preview = list(self._name_to_outlet) == ["dataframe"]

def select_outlets(self, event) -> List[str]:
def select_outlets(self, event) -> Collection[str]:
"""
Override this method to route events based on a customer logic. The default implementation will route all
events to all outlets.
"""
return list(self._name_to_outlet.keys())
return self._name_to_outlet.keys()

async def _do(self, event):
if event is _termination_obj:
Expand All @@ -381,6 +381,11 @@ async def _do(self, event):
outlet = self._name_to_outlet["dataframe"]
outlets.append(outlet)
else:
if len(set(outlet_names)) != len(outlet_names):
raise ValueError(
"select_outlets() returned duplicate outlets among the defined outlets: "
+ ", ".join(outlet_names)
)
for outlet_name in outlet_names:
if outlet_name not in self._name_to_outlet:
raise ValueError(
Expand Down
44 changes: 44 additions & 0 deletions tests/test_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1728,6 +1728,50 @@ def select_outlets(self, event):
assert termination_result == expected


def test_duplicate_choice():
class DuplicateChoice(Choice):
def select_outlets(self, event):
outlets = ["all_events", "all_events"]
return outlets

source = SyncEmitSource()
duplicate_choice = DuplicateChoice(termination_result_fn=lambda x, y: x + y)
all_events = Map(lambda x: x, name="all_events")

source.to(duplicate_choice).to(all_events)

controller = source.run()
controller.emit(0)
controller.terminate()
with pytest.raises(
ValueError,
match=r"select_outlets\(\) returned duplicate outlets among the defined outlets: all_events, all_events",
):
controller.await_termination()


def test_nonexistent_choice():
class NonexistentChoice(Choice):
def select_outlets(self, event):
outlets = ["wrong"]
return outlets

source = SyncEmitSource()
nonexistent_choice = NonexistentChoice(termination_result_fn=lambda x, y: x + y)
all_events = Map(lambda x: x, name="all_events")

source.to(nonexistent_choice).to(all_events)

controller = source.run()
controller.emit(0)
controller.terminate()
with pytest.raises(
ValueError,
match=r"select_outlets\(\) returned outlet name 'wrong', which is not one of the defined outlets: all_events",
):
controller.await_termination()


def test_metadata():
def mapf(x):
x.key = x.key + 1
Expand Down

0 comments on commit d1b8dcf

Please sign in to comment.