diff --git a/src/learn_to_pick/base.py b/src/learn_to_pick/base.py index 5430c39..73cd9f8 100644 --- a/src/learn_to_pick/base.py +++ b/src/learn_to_pick/base.py @@ -20,6 +20,7 @@ from learn_to_pick.model_repository import ModelRepository from learn_to_pick.vw_logger import VwLogger from learn_to_pick.features import Featurized, DenseFeatures, SparseFeatures +from enum import Enum if TYPE_CHECKING: import vowpal_wabbit_next as vw @@ -27,23 +28,15 @@ logger = logging.getLogger(__name__) -class _BasedOn: - def __init__(self, value: Any): - self.value = value - - def __str__(self) -> str: - return str(self.value) - - __repr__ = __str__ - - -def BasedOn(anything: Any) -> _BasedOn: - return _BasedOn(anything) +class Role(Enum): + CONTEXT = 1 + ACTIONS = 2 -class _ToSelectFrom: - def __init__(self, value: Any): +class _Roled: + def __init__(self, value: Any, role: Role): self.value = value + self.role = role def __str__(self) -> str: return str(self.value) @@ -51,35 +44,42 @@ def __str__(self) -> str: __repr__ = __str__ -def ToSelectFrom(anything: Any) -> _ToSelectFrom: +def BasedOn(anything: Any) -> _Roled: + return _Roled(anything, Role.CONTEXT) + + +def ToSelectFrom(anything: Any) -> _Roled: if not isinstance(anything, list): raise ValueError("ToSelectFrom must be a list to select from") - return _ToSelectFrom(anything) + return _Roled(anything, Role.ACTIONS) -class _Embed: - def __init__(self, value: Any, keep: bool = False): +class _Input: + def __init__(self, value: Any, keep: bool = True, embed: bool = False): self.value = value self.keep = keep + self.embed = embed def __str__(self) -> str: return str(self.value) + @staticmethod + def create(value: Any, *args, **kwargs): + if isinstance(value, _Roled): + return _Roled(_Input.create(value.value, *args, **kwargs), value.role) + if isinstance(value, list): + return [_Input.create(v, *args, **kwargs) for v in value] + if isinstance(value, dict): + return {k: _Input.create(v, *args, **kwargs) for k, v in value.items()} + if isinstance(value, _Input): # should we swap? it will allow overwriting + return value + return _Input(value, *args, **kwargs) + __repr__ = __str__ def Embed(anything: Any, keep: bool = False) -> Any: - if isinstance(anything, _ToSelectFrom): - return ToSelectFrom(Embed(anything.value, keep=keep)) - elif isinstance(anything, _BasedOn): - return BasedOn(Embed(anything.value, keep=keep)) - if isinstance(anything, list): - return [Embed(v, keep=keep) for v in anything] - elif isinstance(anything, dict): - return {k: Embed(v, keep=keep) for k, v in anything.items()} - elif isinstance(anything, _Embed): - return anything - return _Embed(anything, keep=keep) + return _Input.create(anything, keep=keep, embed=True) def EmbedAndKeep(anything: Any) -> Any: @@ -93,19 +93,11 @@ def _parse_lines(parser: "vw.TextFormatParser", input_str: str) -> List["vw.Exam return [parser.parse_line(line) for line in input_str.split("\n")] -def get_based_on(inputs: Dict[str, Any]) -> Dict: - return { - k: inputs[k].value if isinstance(inputs[k].value, list) else inputs[k].value - for k in inputs.keys() - if isinstance(inputs[k], _BasedOn) - } - - -def get_to_select_from(inputs: Dict[str, Any]) -> Dict: +def filter_inputs(inputs: Dict[str, Any], role: Role) -> Dict[str, Any]: return { - k: inputs[k].value - for k in inputs.keys() - if isinstance(inputs[k], _ToSelectFrom) + k: v.value + for k, v in inputs.items() + if isinstance(v, _Roled) and v.role == role } @@ -480,14 +472,15 @@ def run(self, *args, **kwargs) -> Dict[str, Any]: def _embed_string_type( - item: Union[str, _Embed], model: Any, namespace: str + item: Union[str, _Input], model: Any, namespace: str ) -> Featurized: """Helper function to embed a string or an _Embed object.""" import re result = Featurized() - if isinstance(item, _Embed): - result[namespace] = DenseFeatures(model.encode(item.value)) + if isinstance(item, _Input): + if item.embed: + result[namespace] = DenseFeatures(model.encode(item.value)) if item.keep: keep_str = item.value.replace(" ", "_") result[namespace] = {"default_ft": re.sub(r"[\t\n\r\f\v]+", " ", keep_str)} @@ -529,7 +522,7 @@ def _embed_list_type( def embed( - to_embed: Union[Union[str, _Embed], Dict, List[Union[str, _Embed]], List[Dict]], + to_embed: Union[Union[str, _Input], Dict, List[Union[str, _Input]], List[Dict]], model: Any, namespace: Optional[str] = None, ) -> Union[Featurized, List[Featurized]]: @@ -543,7 +536,7 @@ def embed( Returns: List[Dict[str, str]]: A list of dictionaries where each dictionary has the namespace as the key and the embedded string as the value """ - if (isinstance(to_embed, _Embed) and isinstance(to_embed.value, str)) or isinstance( + if (isinstance(to_embed, _Input) and isinstance(to_embed.value, str)) or isinstance( to_embed, str ): return _embed_string_type(to_embed, model, namespace) diff --git a/src/learn_to_pick/pick_best.py b/src/learn_to_pick/pick_best.py index c85eefa..f51aedc 100644 --- a/src/learn_to_pick/pick_best.py +++ b/src/learn_to_pick/pick_best.py @@ -38,8 +38,8 @@ def __init__( selected: Optional[PickBestSelected] = None, ): super().__init__(inputs=inputs, selected=selected or PickBestSelected()) - self.to_select_from = base.get_to_select_from(inputs) - self.based_on = base.get_based_on(inputs) + self.to_select_from = base.filter_inputs(inputs, base.Role.ACTIONS) + self.based_on = base.filter_inputs(inputs, base.Role.CONTEXT) if not self.to_select_from: raise ValueError( "No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from."