From d1b8dcfa931961e658828ae66f0afa3217efd499 Mon Sep 17 00:00:00 2001 From: Gal Topper Date: Sun, 17 Nov 2024 15:33:32 +0800 Subject: [PATCH] Raise error in `Choice` on duplicate outlets (#545) * Raise error in `Choice` on duplicate outlets * Move validation * Add tests, improve error messages --- storey/flow.py | 11 ++++++++--- tests/test_flow.py | 44 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 3 deletions(-) diff --git a/storey/flow.py b/storey/flow.py index 7f2df3f0..6956df30 100644 --- a/storey/flow.py +++ b/storey/flow.py @@ -23,7 +23,7 @@ from asyncio import Task from collections import defaultdict from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor -from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Union +from typing import Any, Callable, Collection, Dict, Iterable, List, Optional, Set, Union import aiohttp @@ -363,12 +363,12 @@ def _init(self): # TODO: hacky way of supporting mlrun preview, which replaces targets with a DFTarget self._passthrough_for_preview = list(self._name_to_outlet) == ["dataframe"] - def select_outlets(self, event) -> List[str]: + def select_outlets(self, event) -> Collection[str]: """ Override this method to route events based on a customer logic. The default implementation will route all events to all outlets. """ - return list(self._name_to_outlet.keys()) + return self._name_to_outlet.keys() async def _do(self, event): if event is _termination_obj: @@ -381,6 +381,11 @@ async def _do(self, event): outlet = self._name_to_outlet["dataframe"] outlets.append(outlet) else: + if len(set(outlet_names)) != len(outlet_names): + raise ValueError( + "select_outlets() returned duplicate outlets among the defined outlets: " + + ", ".join(outlet_names) + ) for outlet_name in outlet_names: if outlet_name not in self._name_to_outlet: raise ValueError( diff --git a/tests/test_flow.py b/tests/test_flow.py index 755597ce..9911efab 100644 --- a/tests/test_flow.py +++ b/tests/test_flow.py @@ -1728,6 +1728,50 @@ def select_outlets(self, event): assert termination_result == expected +def test_duplicate_choice(): + class DuplicateChoice(Choice): + def select_outlets(self, event): + outlets = ["all_events", "all_events"] + return outlets + + source = SyncEmitSource() + duplicate_choice = DuplicateChoice(termination_result_fn=lambda x, y: x + y) + all_events = Map(lambda x: x, name="all_events") + + source.to(duplicate_choice).to(all_events) + + controller = source.run() + controller.emit(0) + controller.terminate() + with pytest.raises( + ValueError, + match=r"select_outlets\(\) returned duplicate outlets among the defined outlets: all_events, all_events", + ): + controller.await_termination() + + +def test_nonexistent_choice(): + class NonexistentChoice(Choice): + def select_outlets(self, event): + outlets = ["wrong"] + return outlets + + source = SyncEmitSource() + nonexistent_choice = NonexistentChoice(termination_result_fn=lambda x, y: x + y) + all_events = Map(lambda x: x, name="all_events") + + source.to(nonexistent_choice).to(all_events) + + controller = source.run() + controller.emit(0) + controller.terminate() + with pytest.raises( + ValueError, + match=r"select_outlets\(\) returned outlet name 'wrong', which is not one of the defined outlets: all_events", + ): + controller.await_termination() + + def test_metadata(): def mapf(x): x.key = x.key + 1