diff --git a/caikit_nlp/modules/tokenization/regex_sentence_splitter.py b/caikit_nlp/modules/tokenization/regex_sentence_splitter.py
index 5b8adf31..eef02d9e 100644
--- a/caikit_nlp/modules/tokenization/regex_sentence_splitter.py
+++ b/caikit_nlp/modules/tokenization/regex_sentence_splitter.py
@@ -14,13 +14,19 @@
 """Module that provides capability to split documents into sentences via regex"""

 # Standard
+from typing import Iterable
+import itertools
 import os
 import re

 # First Party
 from caikit.core.exceptions import error_handler
 from caikit.core.modules import ModuleBase, ModuleConfig, ModuleSaver, module
-from caikit.interfaces.nlp.data_model import Token, TokenizationResults
+from caikit.interfaces.nlp.data_model import (
+    Token,
+    TokenizationResults,
+    TokenizationStreamResult,
+)
 from caikit.interfaces.nlp.tasks import TokenizationTask
 import alog

@@ -98,6 +104,7 @@ def load(cls, model_path: str) -> "RegexSentenceSplitter":
         config = ModuleConfig.load(os.path.abspath(model_path))
         return cls(regex_str=config.regex_str)

+    @TokenizationTask.taskmethod()
     def run(self, text: str) -> TokenizationResults:
         """Run sentence splitting regex on input text.

@@ -120,3 +127,129 @@ def run(self, text: str) -> TokenizationResults:
                 tokens.append(token)

         return TokenizationResults(results=tokens)
+
+    @TokenizationTask.taskmethod(input_streaming=True, output_streaming=True)
+    def run_bidi_stream(
+        self, text_stream: Iterable[str]
+    ) -> Iterable[TokenizationStreamResult]:
+        """Run bi-directional streaming sentence splitting. Aggregates text
+        in the stream and returns a concatenable stream of sentences, with
+        surrounding whitespace included.
+
+        Args:
+            text_stream: Iterable[str]
+                Text stream to run sentence splitting on
+
+        Returns:
+            Iterable[TokenizationStreamResult]
+        """
+        # Avoid length check here since it can be time consuming to iterate through stream
+        # Tee stream to 2 - one to check emptiness, one for full iteration + analysis
+        text_streams = itertools.tee(text_stream, 2)
+        try:
+            next(text_streams[0])
+        except StopIteration:
+            # Empty text case
+            yield TokenizationStreamResult(results=[], start_index=0, processed_index=0)
+
+        for token_output in self._stream_token_output(text_streams[1]):
+            # start_index and processed_index here are simplified since each sentence
+            # is expected to be concatenable and will be streamed
+            yield TokenizationStreamResult(
+                results=[token_output],
+                start_index=token_output.start,
+                processed_index=token_output.end,
+            )
+
+    ################################## Private functions ##########################################
+
+    def _stream_token_output(self, text_stream):
+        """Function to yield token output from input text stream"""
+        # NOTE: Can potentially consolidate with parts of the filtered span classification
+        # function in the future but this implementation currently works for tokens and
+        # accounts for whitespace between sentences/tokens
+
+        stream_accumulator = ""
+        detected_tokens = None
+        token_start_offset = 0
+        len_missing_idx = 0
+        # Tracker of text up until tokens/sentences detected - accounts for only-whitespace case
+        text_tracker = []
+
+        def __update_tokens(token, stream_accumulator, len_missing_idx):
+            # Check whether the starting offset for the token is already at or beyond
+            # token_start_offset, in which case we need not update the token
+            if token.start < token_start_offset:
+                # This indicates that the starting offset of the sentence is off, as we
+                # expect the sentence to start at token_start_offset+1. So we need to
+                # recalibrate the sentence offsets and have them start at
+                # token_start_offset. This means we need to know the length of the
+                # sentence to manipulate token.end, which we get by subtracting
+                # end - start
+                original_start = token.start
+                token.start = token_start_offset
+                token.end = (
+                    token_start_offset + (token.end - original_start) + len_missing_idx
+                )
+                token.text = stream_accumulator[token.start : token.end]
+            return token
+
+        for text in text_stream:
+            error.type_check("", str, text=text)
+            stream_accumulator += text
+            text_tracker.append(text)
+
+            # In order to avoid processing all of the tokens again, we only
+            # send out the tokens that are not yet finalized in detected_tokens
+            matches = self.regex.finditer(stream_accumulator[token_start_offset:])
+            detected_tokens = []
+            for match_token in matches:
+                token = Token(
+                    start=match_token.start(),
+                    end=match_token.end(),
+                    text=match_token.group(),
+                )
+                detected_tokens.append(token)
+
+            if len(detected_tokens) > 1:
+                # Optimization to avoid keeping track of all text chunks in the case
+                # when there are actually sentences detected
+                text_tracker = []
+
+                # We have detected more than 1 sentence
+                # Return 1st sentence
+                new_token = detected_tokens.pop(0)
+
+                new_token = __update_tokens(
+                    new_token, stream_accumulator, len_missing_idx
+                )
+
+                # We have detected a new sentence, return the new sentence
+                yield new_token
+
+                # We only send out part of the text, so we need to track
+                # the starting point of the subsequent text
+                token_start_offset = new_token.end
+                next_token_len = detected_tokens[0].end - detected_tokens[0].start
+                len_missing_idx = (
+                    len(stream_accumulator) - token_start_offset - next_token_len
+                )

+        # Return remaining sentence(s)
+        if detected_tokens and len(detected_tokens) > 0:
+            for detected_token in detected_tokens:
+                new_token = __update_tokens(
+                    detected_token, stream_accumulator, len_missing_idx
+                )
+                yield new_token
+
+        else:
+            # This allows us to keep track of text that is only whitespace and would
+            # otherwise not return tokens, since the tokenizer is only used to detect
+            # sentences. This may have to be adjusted to keep track of any generated
+            # trailing whitespace
+            token_start = 0
+            for text in text_tracker:
+                token_end = token_start + len(text)
+                yield Token(start=token_start, end=token_end, text=text)
+                token_start = token_end
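The heart of the patch is an accumulate-and-rescan loop: incoming chunks are appended to a buffer, the sentence regex is re-run over the not-yet-emitted suffix, and every completed match except the last is emitted, since the last match may still grow as more text arrives. Below is a minimal self-contained sketch of that idea; SENTENCE_RE and stream_sentences are illustrative stand-ins, not part of the patch, and the simplified pattern (unlike REGEX_STR used in the tests) requires a sentence terminator, so unterminated trailing text is dropped rather than emitted.

import re

# Simplified stand-in for the patch's sentence regex, for illustration only.
SENTENCE_RE = re.compile(r"[^.!?\s][^.!?\n]*[.!?]")


def stream_sentences(chunks):
    """Yield (start, end, text) spans, emitting a sentence only once it is final."""
    buffer = ""
    emitted = 0  # offset of the first character not yet emitted
    for chunk in chunks:
        buffer += chunk
        matches = list(SENTENCE_RE.finditer(buffer, emitted))
        # Every match except the last is final; the last one may still be
        # extended by text that has not arrived yet, so hold it back.
        for match in matches[:-1]:
            # Spans start at `emitted` so leading whitespace is included and
            # the emitted text stays concatenable, as in the patch.
            yield emitted, match.end(), buffer[emitted : match.end()]
            emitted = match.end()
    # End of stream: whatever matches remain are final now.
    for match in SENTENCE_RE.finditer(buffer, emitted):
        yield emitted, match.end(), buffer[emitted : match.end()]
        emitted = match.end()


print(list(stream_sentences(["What he told me ", "before. I am ", "tired of fighting."])))
# [(0, 23, 'What he told me before.'), (23, 47, ' I am tired of fighting.')]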
diff --git a/tests/modules/tokenization/test_regex_sentence_splitter.py b/tests/modules/tokenization/test_regex_sentence_splitter.py
index f8590c29..c94e9d03 100644
--- a/tests/modules/tokenization/test_regex_sentence_splitter.py
+++ b/tests/modules/tokenization/test_regex_sentence_splitter.py
@@ -1,11 +1,17 @@
 """Tests for regex sentence splitter
 """
 # Standard
+from typing import Iterable
 import os
 import tempfile

 # First Party
-from caikit.interfaces.nlp.data_model import TokenizationResults
+from caikit.core import data_model
+from caikit.interfaces.nlp.data_model import (
+    Token,
+    TokenizationResults,
+    TokenizationStreamResult,
+)

 # Local
 from caikit_nlp.modules.tokenization.regex_sentence_splitter import (
@@ -15,6 +21,7 @@
 ## Setup ########################################################################

 # Regex sentence splitter model for reusability across tests
+# NOTE: Regex may not be extremely accurate for sentence splitting needs.
 REGEX_STR = "[^.!?\s][^.!?\n]*(?:[.!?](?!['\"]?\s|$)[^.!?]*)*[.!?]?['\"]?(?=\s|$)"
 SENTENCE_TOKENIZER = RegexSentenceSplitter.bootstrap(REGEX_STR)
 DOCUMENT = "What he told me before, I have it in my heart. I am tired of fighting."
@@ -39,3 +46,72 @@ def test_save_load_and_run_model():
     tokenization_result = new_splitter.run(DOCUMENT)
     assert isinstance(tokenization_result, TokenizationResults)
     assert len(tokenization_result.results) == 2
+
+
+### Streaming tests ##############################################################
+
+
+def test_run_bidi_stream_model():
+    """Check if model prediction works as expected for bi-directional stream"""
+
+    stream_input = data_model.DataStream.from_iterable(DOCUMENT)
+    streaming_tokenization_result = SENTENCE_TOKENIZER.run_bidi_stream(stream_input)
+    assert isinstance(streaming_tokenization_result, Iterable)
+    # Convert to list to more easily check outputs
+    result_list = list(streaming_tokenization_result)
+
+    first_result = result_list[0].results[0]
+    assert isinstance(first_result, Token)
+    assert first_result.start == 0
+    assert first_result.end == 46
+    assert first_result.text == "What he told me before, I have it in my heart."
+
+    # Check processed indices
+    assert result_list[0].processed_index == 46
+    assert result_list[1].processed_index == len(DOCUMENT)
+
+    # Total number of results should equal the expected number of sentences
+    expected_number_of_sentences = 2  # Sentence tokenizer returns 2 results
+    count = len(result_list)
+    assert count == expected_number_of_sentences
+
+
+def test_run_bidi_stream_chunk_stream_input():
+    """Check if model prediction works as expected for a bi-directional
+    stream whose tokenization input arrives in multi-character chunks"""
+
+    chunked_document_input = (
+        "What he told me ",
+        "before, I have it in my heart. I am tired of fighting. ",
+        " The cow jumped over the moon. ",
+    )
+    stream_input = data_model.DataStream.from_iterable(chunked_document_input)
+    streaming_tokenization_result = SENTENCE_TOKENIZER.run_bidi_stream(stream_input)
+    # Convert to list to more easily check outputs
+    result_list = list(streaming_tokenization_result)
+
+    first_result = result_list[0].results[0]
+    assert isinstance(first_result, Token)
+    assert first_result.start == 0
+    assert first_result.end == 46
+    assert first_result.text == "What he told me before, I have it in my heart."
+
+    # Check processed indices
+    assert result_list[0].processed_index == 46  # ...heart.
+    assert result_list[1].processed_index == 71  # ...fighting.
+    assert result_list[2].processed_index == 102  # ...moon. (end of last sentence)
+
+    expected_results = 3
+    count = len(result_list)
+    assert count == expected_results
+
+
+def test_run_bidi_stream_empty():
+    """Check if tokenization can run on an empty input stream"""
+    stream_input = data_model.DataStream.from_iterable("")
+    streaming_tokenization_result = SENTENCE_TOKENIZER.run_bidi_stream(stream_input)
+    assert isinstance(streaming_tokenization_result, Iterable)
+    # Convert to list to more easily check outputs
+    result_list = list(streaming_tokenization_result)
+    assert len(result_list) == 1
+    assert result_list[0].results == []
+    assert result_list[0].processed_index == 0
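For reference, a short usage sketch, not part of the patch, driving the new streaming task method the way the tests above do; the splitter and chunks names and the chunk values are arbitrary, and the regex is the tests' REGEX_STR.

from caikit.core import data_model

from caikit_nlp.modules.tokenization.regex_sentence_splitter import (
    RegexSentenceSplitter,
)

# Bootstrap a splitter with the same regex the tests use.
splitter = RegexSentenceSplitter.bootstrap(
    "[^.!?\s][^.!?\n]*(?:[.!?](?!['\"]?\s|$)[^.!?]*)*[.!?]?['\"]?(?=\s|$)"
)

chunks = (
    "What he told me ",
    "before, I have it in my heart. ",
    "I am tired of fighting.",
)
stream = data_model.DataStream.from_iterable(chunks)

# Each TokenizationStreamResult carries one finalized sentence token;
# processed_index reports how far into the accumulated text the splitter
# has emitted results.
for stream_result in splitter.run_bidi_stream(stream):
    for token in stream_result.results:
        print(f"[{token.start}:{token.end}] {token.text!r}")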