diff --git a/dialogy/plugins/text/duckling_plugin/__init__.py b/dialogy/plugins/text/duckling_plugin/__init__.py
index 89bdfef8..456e932a 100644
--- a/dialogy/plugins/text/duckling_plugin/__init__.py
+++ b/dialogy/plugins/text/duckling_plugin/__init__.py
@@ -58,6 +58,7 @@ def update_workflow(workflow, entities):
 import pydash as py_
 import pytz
 import requests
+from numpy import isin
 from pytz.tzinfo import BaseTzInfo
 from tqdm import tqdm
 
@@ -423,7 +424,7 @@ def _get_entities(
 
     def _get_entities_concurrent(
         self,
-        texts: Union[str, List[str]],
+        texts: List[str],
         locale: str = "en_IN",
         reference_time: Optional[int] = None,
         use_latent: bool = False,
@@ -444,33 +445,20 @@ def _get_entities_concurrent(
         :return: Duckling entities as :code:`dicts`.
         :rtype: List[List[Dict[str, Any]]]
         """
-        futures_ = []
         workers = min(10, len(texts))
-        if isinstance(texts, str):
-            entities_list = [
-                self._get_entities(
-                    texts,
+        with futures.ThreadPoolExecutor(max_workers=workers) as executor:
+            futures_ = [
+                executor.submit(
+                    self._get_entities,
+                    text,
                     locale,
                     reference_time=reference_time,
                     use_latent=use_latent,
+                    sort_idx=i,
                 )
+                for i, text in enumerate(texts)
             ]
-        else:
-            with futures.ThreadPoolExecutor(max_workers=workers) as executor:
-                futures_ = [
-                    executor.submit(
-                        self._get_entities,
-                        text,
-                        locale,
-                        reference_time=reference_time,
-                        use_latent=use_latent,
-                        sort_idx=i,
-                    )
-                    for i, text in enumerate(texts)
-                ]
-                entities_list = [
-                    future.result() for future in futures.as_completed(futures_)
-                ]
+            entities_list = [future.result() for future in futures.as_completed(futures_)]
         return [
             entities[const.VALUE]
             for entities in sorted(
@@ -478,19 +466,21 @@ def _get_entities_concurrent(
             )
         ]
 
-    def utility(self, *args: Any) -> List[BaseEntity]:
-        """
-        Produces Duckling entities, runs with a :ref:`Workflow's run` method.
-
-        :param args: Expects a tuple of :code:`Tuple[natural language for parsing entities, reference time in milliseconds, locale]`
-        :type args: Tuple(Union[str, List[str]], int, str)
-        :return: A list of duckling entities.
-        :rtype: List[BaseEntity]
-        """
-        list_of_entities: List[List[Dict[str, Any]]] = []
-        shaped_entities: List[List[BaseEntity]] = []
-
-        input_, reference_time, locale, use_latent = args
+    def apply_entity_classes(
+        self, list_of_entities: List[List[Dict[str, Any]]]
+    ) -> List[BaseEntity]:
+        shaped_entities = []
+        for (alternative_index, entities) in enumerate(list_of_entities):
+            shaped_entities.append(self._reshape(entities, alternative_index))
+        return py_.flatten(shaped_entities)
+
+    def validate(
+        self, input_: Union[str, List[str]], reference_time: Optional[int]
+    ) -> "DucklingPlugin":
+        input_is_str = isinstance(input_, str)
+        inputs_are_list_of_strings = isinstance(input_, list) and all(
+            isinstance(text, str) for text in input_
+        )
         if not isinstance(reference_time, int) and self.datetime_filters:
             raise TypeError(
                 "Duckling requires reference_time to be a unix timestamp (int) but"
@@ -498,35 +488,50 @@ def utility(self, *args: Any) -> List[BaseEntity]:
                 "https://stackoverflow.com/questions/20822821/what-is-a-unix-timestamp-and-why-use-it\n"
             )
 
-        self.reference_time = reference_time
-        input_size = 1
-        args = (input_, locale)
-        kwargs = {"reference_time": reference_time, "use_latent": use_latent}
+        if not input_is_str and not inputs_are_list_of_strings:
+            raise TypeError(f"Expected {input_} to be a List[str] or str.")
+        return self
 
-        input_is_str = isinstance(input_, str)
-        inputs_are_list_of_strings = isinstance(input_, list) and all(
-            isinstance(text, str) for text in input_
-        )
-        if inputs_are_list_of_strings:
-            input_size = len(input_)
+    def extract(
+        self,
+        input_: Union[str, List[str]],
+        locale: str,
+        reference_time: Optional[int] = None,
+        use_latent: bool = False,
+    ) -> List[BaseEntity]:
+        list_of_entities: List[List[Dict[str, Any]]] = []
+        entities: List[BaseEntity] = []
 
-        try:
-            if input_is_str or inputs_are_list_of_strings:
-                list_of_entities = self._get_entities_concurrent(*args, **kwargs)
-            else:
-                raise TypeError(f"Expected {input_} to be a List[str] or str.")
+        self.validate(input_, reference_time)
+        self.reference_time = reference_time
 
-            for (alternative_index, entities) in enumerate(list_of_entities):
-                shaped_entities.append(self._reshape(entities, alternative_index))
+        if isinstance(input_, str):
+            input_ = [input_]
 
-            shaped_entities_flattened = py_.flatten(shaped_entities)
-            aggregate_entities = self.entity_consensus(
-                shaped_entities_flattened, input_size
+        try:
+            list_of_entities = self._get_entities_concurrent(
+                input_, locale, reference_time=reference_time, use_latent=use_latent
             )
-            return self.apply_filters(aggregate_entities)
+            entities = self.apply_entity_classes(list_of_entities)
+            entities = self.entity_consensus(entities, len(input_))
+            return self.apply_filters(entities)
         except ValueError as value_error:
             raise ValueError(str(value_error)) from value_error
 
+    def utility(self, *args: Any) -> List[BaseEntity]:
+        """
+        Produces Duckling entities, runs with a :ref:`Workflow's run` method.
+
+        :param args: Expects a tuple of :code:`Tuple[natural language for parsing entities, reference time in milliseconds, locale]`
+        :type args: Tuple(Union[str, List[str]], int, str)
+        :return: A list of duckling entities.
+        :rtype: List[BaseEntity]
+        """
+        input_, reference_time, locale, use_latent = args
+        return self.extract(
+            input_, locale, reference_time=reference_time, use_latent=use_latent
+        )
+
     def transform(self, training_data: pd.DataFrame) -> pd.DataFrame:
         """
         Transform training data.
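
Below is a minimal, hypothetical usage sketch of the refactored surface: only the `extract(input_, locale, reference_time=..., use_latent=...)` signature and the class name come from this diff; the constructor arguments, the import path (taken from the diff header), and the presence of a reachable Duckling server are assumptions for illustration.

```python
# Hypothetical sketch of calling the refactored DucklingPlugin directly.
# Only extract()'s signature is taken from this diff; constructor kwargs
# and the Duckling server setup are assumptions.
import time

from dialogy.plugins.text.duckling_plugin import DucklingPlugin  # path from the diff header

duckling_plugin = DucklingPlugin(
    dimensions=["time", "number"],  # assumed configuration
    locale="en_IN",
    timezone="Asia/Kolkata",
)

# Duckling expects a unix timestamp in milliseconds; validate() raises a
# TypeError when datetime filters are configured and this is not an int.
reference_time = int(time.time() * 1000)

# extract() wraps a plain string into a single-element list, so a lone
# utterance and a list of ASR alternatives share one code path.
entities = duckling_plugin.extract(
    "book a cab at 5 pm tomorrow", "en_IN", reference_time=reference_time
)
entities_across_alternatives = duckling_plugin.extract(
    ["book a cab at 5 pm tomorrow", "book a cab at 5 pm to morrow"],
    "en_IN",
    reference_time=reference_time,
)
```

Keeping `utility` as a thin wrapper that unpacks the workflow tuple and delegates to `extract` means the same validation, concurrency, consensus, and filtering path is exercised whether the plugin runs inside a :ref:`Workflow's run` or is called directly as above.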