diff --git a/src/learn_to_pick/base.py b/src/learn_to_pick/base.py index 28b86a9..5430c39 100644 --- a/src/learn_to_pick/base.py +++ b/src/learn_to_pick/base.py @@ -93,26 +93,21 @@ def _parse_lines(parser: "vw.TextFormatParser", input_str: str) -> List["vw.Exam return [parser.parse_line(line) for line in input_str.split("\n")] -def get_based_on_and_to_select_from(inputs: Dict[str, Any]) -> Tuple[Dict, Dict]: - to_select_from = { - k: inputs[k].value +def get_based_on(inputs: Dict[str, Any]) -> Dict: + return { + k: inputs[k].value if isinstance(inputs[k].value, list) else inputs[k].value for k in inputs.keys() - if isinstance(inputs[k], _ToSelectFrom) + if isinstance(inputs[k], _BasedOn) } - if not to_select_from: - raise ValueError( - "No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from." - ) - based_on = { - k: inputs[k].value if isinstance(inputs[k].value, list) else inputs[k].value +def get_to_select_from(inputs: Dict[str, Any]) -> Dict: + return { + k: inputs[k].value for k in inputs.keys() - if isinstance(inputs[k], _BasedOn) + if isinstance(inputs[k], _ToSelectFrom) } - return based_on, to_select_from - # end helper functions diff --git a/src/learn_to_pick/pick_best.py b/src/learn_to_pick/pick_best.py index 8ed4b36..3c5d727 100644 --- a/src/learn_to_pick/pick_best.py +++ b/src/learn_to_pick/pick_best.py @@ -223,24 +223,20 @@ class PickBest(base.RLLoop[PickBestEvent]): """ def _call_before_predict(self, inputs: Dict[str, Any]) -> PickBestEvent: - context, actions = base.get_based_on_and_to_select_from(inputs=inputs) - if not actions: + to_select_from = base.get_to_select_from(inputs) + if not to_select_from: raise ValueError( "No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from." ) - - if len(list(actions.values())) > 1: + if len(to_select_from) > 1: raise ValueError( "Only one variable using 'ToSelectFrom' can be provided in the inputs for PickBest run() call. Please provide only one variable containing a list to select from." ) - - if not context: - raise ValueError( - "No variables using 'BasedOn' found in the inputs. Please include at least one variable containing information to base the selected of ToSelectFrom on." - ) - - event = PickBestEvent(inputs=inputs, to_select_from=actions, based_on=context) - return event + return PickBestEvent( + inputs=inputs, + to_select_from=to_select_from, + based_on=base.get_based_on(inputs), + ) def _call_after_predict_before_scoring( self, diff --git a/tests/unit_tests/test_pick_best_call.py b/tests/unit_tests/test_pick_best_call.py index b283092..30c7343 100644 --- a/tests/unit_tests/test_pick_best_call.py +++ b/tests/unit_tests/test_pick_best_call.py @@ -36,7 +36,7 @@ def test_multiple_ToSelectFrom_throws() -> None: ) -def test_missing_basedOn_from_throws() -> None: +def test_missing_basedOn_from_dont_throw() -> None: pick = learn_to_pick.PickBest.create( llm=fake_llm_caller, featurizer=learn_to_pick.PickBestFeaturizer( @@ -44,8 +44,7 @@ def test_missing_basedOn_from_throws() -> None: ), ) actions = ["0", "1", "2"] - with pytest.raises(ValueError): - pick.run(action=learn_to_pick.ToSelectFrom(actions)) + pick.run(action=learn_to_pick.ToSelectFrom(actions)) def test_ToSelectFrom_not_a_list_throws() -> None: