Skip to content

Commit

Permalink
featurization cleanup intermediate changes (#36)
Browse files Browse the repository at this point in the history
* unified role

* _Embed -> _Featurize

* _Featurize -> _Input

* black
  • Loading branch information
ataymano authored Nov 20, 2023
1 parent 4aee78d commit 213e931
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 48 deletions.
85 changes: 39 additions & 46 deletions src/learn_to_pick/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,66 +20,66 @@
from learn_to_pick.model_repository import ModelRepository
from learn_to_pick.vw_logger import VwLogger
from learn_to_pick.features import Featurized, DenseFeatures, SparseFeatures
from enum import Enum

if TYPE_CHECKING:
import vowpal_wabbit_next as vw

logger = logging.getLogger(__name__)


class _BasedOn:
def __init__(self, value: Any):
self.value = value

def __str__(self) -> str:
return str(self.value)

__repr__ = __str__


def BasedOn(anything: Any) -> _BasedOn:
return _BasedOn(anything)
class Role(Enum):
CONTEXT = 1
ACTIONS = 2


class _ToSelectFrom:
def __init__(self, value: Any):
class _Roled:
def __init__(self, value: Any, role: Role):
self.value = value
self.role = role

def __str__(self) -> str:
return str(self.value)

__repr__ = __str__


def ToSelectFrom(anything: Any) -> _ToSelectFrom:
def BasedOn(anything: Any) -> _Roled:
return _Roled(anything, Role.CONTEXT)


def ToSelectFrom(anything: Any) -> _Roled:
if not isinstance(anything, list):
raise ValueError("ToSelectFrom must be a list to select from")
return _ToSelectFrom(anything)
return _Roled(anything, Role.ACTIONS)


class _Embed:
def __init__(self, value: Any, keep: bool = False):
class _Input:
def __init__(self, value: Any, keep: bool = True, embed: bool = False):
self.value = value
self.keep = keep
self.embed = embed

def __str__(self) -> str:
return str(self.value)

@staticmethod
def create(value: Any, *args, **kwargs):
if isinstance(value, _Roled):
return _Roled(_Input.create(value.value, *args, **kwargs), value.role)
if isinstance(value, list):
return [_Input.create(v, *args, **kwargs) for v in value]
if isinstance(value, dict):
return {k: _Input.create(v, *args, **kwargs) for k, v in value.items()}
if isinstance(value, _Input): # should we swap? it will allow overwriting
return value
return _Input(value, *args, **kwargs)

__repr__ = __str__


def Embed(anything: Any, keep: bool = False) -> Any:
if isinstance(anything, _ToSelectFrom):
return ToSelectFrom(Embed(anything.value, keep=keep))
elif isinstance(anything, _BasedOn):
return BasedOn(Embed(anything.value, keep=keep))
if isinstance(anything, list):
return [Embed(v, keep=keep) for v in anything]
elif isinstance(anything, dict):
return {k: Embed(v, keep=keep) for k, v in anything.items()}
elif isinstance(anything, _Embed):
return anything
return _Embed(anything, keep=keep)
return _Input.create(anything, keep=keep, embed=True)


def EmbedAndKeep(anything: Any) -> Any:
Expand All @@ -93,19 +93,11 @@ def _parse_lines(parser: "vw.TextFormatParser", input_str: str) -> List["vw.Exam
return [parser.parse_line(line) for line in input_str.split("\n")]


def get_based_on(inputs: Dict[str, Any]) -> Dict:
return {
k: inputs[k].value if isinstance(inputs[k].value, list) else inputs[k].value
for k in inputs.keys()
if isinstance(inputs[k], _BasedOn)
}


def get_to_select_from(inputs: Dict[str, Any]) -> Dict:
def filter_inputs(inputs: Dict[str, Any], role: Role) -> Dict[str, Any]:
return {
k: inputs[k].value
for k in inputs.keys()
if isinstance(inputs[k], _ToSelectFrom)
k: v.value
for k, v in inputs.items()
if isinstance(v, _Roled) and v.role == role
}


Expand Down Expand Up @@ -480,14 +472,15 @@ def run(self, *args, **kwargs) -> Dict[str, Any]:


def _embed_string_type(
item: Union[str, _Embed], model: Any, namespace: str
item: Union[str, _Input], model: Any, namespace: str
) -> Featurized:
"""Helper function to embed a string or an _Embed object."""
import re

result = Featurized()
if isinstance(item, _Embed):
result[namespace] = DenseFeatures(model.encode(item.value))
if isinstance(item, _Input):
if item.embed:
result[namespace] = DenseFeatures(model.encode(item.value))
if item.keep:
keep_str = item.value.replace(" ", "_")
result[namespace] = {"default_ft": re.sub(r"[\t\n\r\f\v]+", " ", keep_str)}
Expand Down Expand Up @@ -529,7 +522,7 @@ def _embed_list_type(


def embed(
to_embed: Union[Union[str, _Embed], Dict, List[Union[str, _Embed]], List[Dict]],
to_embed: Union[Union[str, _Input], Dict, List[Union[str, _Input]], List[Dict]],
model: Any,
namespace: Optional[str] = None,
) -> Union[Featurized, List[Featurized]]:
Expand All @@ -543,7 +536,7 @@ def embed(
Returns:
List[Dict[str, str]]: A list of dictionaries where each dictionary has the namespace as the key and the embedded string as the value
"""
if (isinstance(to_embed, _Embed) and isinstance(to_embed.value, str)) or isinstance(
if (isinstance(to_embed, _Input) and isinstance(to_embed.value, str)) or isinstance(
to_embed, str
):
return _embed_string_type(to_embed, model, namespace)
Expand Down
4 changes: 2 additions & 2 deletions src/learn_to_pick/pick_best.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ def __init__(
selected: Optional[PickBestSelected] = None,
):
super().__init__(inputs=inputs, selected=selected or PickBestSelected())
self.to_select_from = base.get_to_select_from(inputs)
self.based_on = base.get_based_on(inputs)
self.to_select_from = base.filter_inputs(inputs, base.Role.ACTIONS)
self.based_on = base.filter_inputs(inputs, base.Role.CONTEXT)
if not self.to_select_from:
raise ValueError(
"No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from."
Expand Down

0 comments on commit 213e931

Please sign in to comment.