Skip to content

Commit

Permalink
refactor: duckling plugin utility.
Browse files Browse the repository at this point in the history
  • Loading branch information
ltbringer committed Jan 26, 2022
1 parent cb8ca20 commit ce9cf50
Showing 1 changed file with 61 additions and 56 deletions.
117 changes: 61 additions & 56 deletions dialogy/plugins/text/duckling_plugin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def update_workflow(workflow, entities):
import pydash as py_
import pytz
import requests
from numpy import isin
from pytz.tzinfo import BaseTzInfo
from tqdm import tqdm

Expand Down Expand Up @@ -423,7 +424,7 @@ def _get_entities(

def _get_entities_concurrent(
self,
texts: Union[str, List[str]],
texts: List[str],
locale: str = "en_IN",
reference_time: Optional[int] = None,
use_latent: bool = False,
Expand All @@ -444,89 +445,93 @@ def _get_entities_concurrent(
:return: Duckling entities as :code:`dicts`.
:rtype: List[List[Dict[str, Any]]]
"""
futures_ = []
workers = min(10, len(texts))
if isinstance(texts, str):
entities_list = [
self._get_entities(
texts,
with futures.ThreadPoolExecutor(max_workers=workers) as executor:
futures_ = [
executor.submit(
self._get_entities,
text,
locale,
reference_time=reference_time,
use_latent=use_latent,
sort_idx=i,
)
for i, text in enumerate(texts)
]
else:
with futures.ThreadPoolExecutor(max_workers=workers) as executor:
futures_ = [
executor.submit(
self._get_entities,
text,
locale,
reference_time=reference_time,
use_latent=use_latent,
sort_idx=i,
)
for i, text in enumerate(texts)
]
entities_list = [
future.result() for future in futures.as_completed(futures_)
]
entities_list = [future.result() for future in futures.as_completed(futures_)]
return [
entities[const.VALUE]
for entities in sorted(
entities_list, key=lambda entities: entities[const.IDX]
)
]

def utility(self, *args: Any) -> List[BaseEntity]:
"""
Produces Duckling entities, runs with a :ref:`Workflow's run<workflow_run>` method.
:param args: Expects a tuple of :code:`Tuple[natural language for parsing entities, reference time in milliseconds, locale]`
:type args: Tuple(Union[str, List[str]], int, str)
:return: A list of duckling entities.
:rtype: List[BaseEntity]
"""
list_of_entities: List[List[Dict[str, Any]]] = []
shaped_entities: List[List[BaseEntity]] = []

input_, reference_time, locale, use_latent = args
def apply_entity_classes(
self, list_of_entities: List[List[Dict[str, Any]]]
) -> List[BaseEntity]:
shaped_entities = []
for (alternative_index, entities) in enumerate(list_of_entities):
shaped_entities.append(self._reshape(entities, alternative_index))
return py_.flatten(shaped_entities)

def validate(
self, input_: Union[str, List[str]], reference_time: Optional[int]
) -> "DucklingPlugin":
input_is_str = isinstance(input_, str)
inputs_are_list_of_strings = isinstance(input_, list) and all(
isinstance(text, str) for text in input_
)
if not isinstance(reference_time, int) and self.datetime_filters:
raise TypeError(
"Duckling requires reference_time to be a unix timestamp (int) but"
f" {type(reference_time)} was found"
"https://stackoverflow.com/questions/20822821/what-is-a-unix-timestamp-and-why-use-it\n"
)

self.reference_time = reference_time
input_size = 1
args = (input_, locale)
kwargs = {"reference_time": reference_time, "use_latent": use_latent}
if not input_is_str and not inputs_are_list_of_strings:
raise TypeError(f"Expected {input_} to be a List[str] or str.")
return self

input_is_str = isinstance(input_, str)
inputs_are_list_of_strings = isinstance(input_, list) and all(
isinstance(text, str) for text in input_
)
if inputs_are_list_of_strings:
input_size = len(input_)
def extract(
self,
input_: Union[str, List[str]],
locale: str,
reference_time: Optional[int] = None,
use_latent: bool = False,
) -> List[BaseEntity]:
list_of_entities: List[List[Dict[str, Any]]] = []
entities: List[BaseEntity] = []

try:
if input_is_str or inputs_are_list_of_strings:
list_of_entities = self._get_entities_concurrent(*args, **kwargs)
else:
raise TypeError(f"Expected {input_} to be a List[str] or str.")
self.validate(input_, reference_time)
self.reference_time = reference_time

for (alternative_index, entities) in enumerate(list_of_entities):
shaped_entities.append(self._reshape(entities, alternative_index))
if isinstance(input_, str):
input_ = [input_]

shaped_entities_flattened = py_.flatten(shaped_entities)
aggregate_entities = self.entity_consensus(
shaped_entities_flattened, input_size
try:
list_of_entities = self._get_entities_concurrent(
input_, locale, reference_time=reference_time, use_latent=use_latent
)
return self.apply_filters(aggregate_entities)
entities = self.apply_entity_classes(list_of_entities)
entities = self.entity_consensus(entities, len(input_))
return self.apply_filters(entities)
except ValueError as value_error:
raise ValueError(str(value_error)) from value_error

def utility(self, *args: Any) -> List[BaseEntity]:
"""
Produces Duckling entities, runs with a :ref:`Workflow's run<workflow_run>` method.
:param args: Expects a tuple of :code:`Tuple[natural language for parsing entities, reference time in milliseconds, locale]`
:type args: Tuple(Union[str, List[str]], int, str)
:return: A list of duckling entities.
:rtype: List[BaseEntity]
"""
input_, reference_time, locale, use_latent = args
return self.extract(
input_, locale, reference_time=reference_time, use_latent=use_latent
)

def transform(self, training_data: pd.DataFrame) -> pd.DataFrame:
"""
Transform training data.
Expand Down

0 comments on commit ce9cf50

Please sign in to comment.