Skip to content

Commit

Permalink
exceptions in one place
Browse files Browse the repository at this point in the history
  • Loading branch information
ataymano committed Nov 16, 2023
1 parent 9f39ea9 commit ca7f469
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 28 deletions.
21 changes: 8 additions & 13 deletions src/learn_to_pick/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,26 +93,21 @@ def _parse_lines(parser: "vw.TextFormatParser", input_str: str) -> List["vw.Exam
return [parser.parse_line(line) for line in input_str.split("\n")]


def get_based_on_and_to_select_from(inputs: Dict[str, Any]) -> Tuple[Dict, Dict]:
to_select_from = {
k: inputs[k].value
def get_based_on(inputs: Dict[str, Any]) -> Dict:
return {
k: inputs[k].value if isinstance(inputs[k].value, list) else inputs[k].value
for k in inputs.keys()
if isinstance(inputs[k], _ToSelectFrom)
if isinstance(inputs[k], _BasedOn)
}

if not to_select_from:
raise ValueError(
"No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from."
)

based_on = {
k: inputs[k].value if isinstance(inputs[k].value, list) else inputs[k].value
def get_to_select_from(inputs: Dict[str, Any]) -> Dict:
return {
k: inputs[k].value
for k in inputs.keys()
if isinstance(inputs[k], _BasedOn)
if isinstance(inputs[k], _ToSelectFrom)
}

return based_on, to_select_from


# end helper functions

Expand Down
20 changes: 8 additions & 12 deletions src/learn_to_pick/pick_best.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,24 +223,20 @@ class PickBest(base.RLLoop[PickBestEvent]):
"""

def _call_before_predict(self, inputs: Dict[str, Any]) -> PickBestEvent:
context, actions = base.get_based_on_and_to_select_from(inputs=inputs)
if not actions:
to_select_from = base.get_to_select_from(inputs)
if not to_select_from:
raise ValueError(
"No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from."
)

if len(list(actions.values())) > 1:
if len(to_select_from) > 1:
raise ValueError(
"Only one variable using 'ToSelectFrom' can be provided in the inputs for PickBest run() call. Please provide only one variable containing a list to select from."
)

if not context:
raise ValueError(
"No variables using 'BasedOn' found in the inputs. Please include at least one variable containing information to base the selected of ToSelectFrom on."
)

event = PickBestEvent(inputs=inputs, to_select_from=actions, based_on=context)
return event
return PickBestEvent(
inputs=inputs,
to_select_from=to_select_from,
based_on=base.get_based_on(inputs),
)

def _call_after_predict_before_scoring(
self,
Expand Down
5 changes: 2 additions & 3 deletions tests/unit_tests/test_pick_best_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,15 @@ def test_multiple_ToSelectFrom_throws() -> None:
)


def test_missing_basedOn_from_throws() -> None:
def test_missing_basedOn_from_dont_throw() -> None:
pick = learn_to_pick.PickBest.create(
llm=fake_llm_caller,
featurizer=learn_to_pick.PickBestFeaturizer(
auto_embed=False, model=MockEncoder()
),
)
actions = ["0", "1", "2"]
with pytest.raises(ValueError):
pick.run(action=learn_to_pick.ToSelectFrom(actions))
pick.run(action=learn_to_pick.ToSelectFrom(actions))


def test_ToSelectFrom_not_a_list_throws() -> None:
Expand Down

0 comments on commit ca7f469

Please sign in to comment.