✨ Bidirectional streaming for regex sentence splitting #346

Open · wants to merge 3 commits into base: main
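For context, a minimal sketch of how the new bidirectional streaming task method can be driven, mirroring the tests added below (the regex and chunked text come from the test file; this snippet is illustrative usage, not part of the diff):

```python
# Illustrative usage sketch based on the tests below; not part of this diff.
from caikit.core import data_model
from caikit_nlp.modules.tokenization.regex_sentence_splitter import RegexSentenceSplitter

REGEX_STR = r"[^.!?\s][^.!?\n]*(?:[.!?](?!['\"]?\s|$)[^.!?]*)*[.!?]?['\"]?(?=\s|$)"
splitter = RegexSentenceSplitter.bootstrap(REGEX_STR)

chunks = (
    "What he told me ",
    "before, I have it in my heart. I am tired of fighting. ",
)
stream_input = data_model.DataStream.from_iterable(chunks)

# Each TokenizationStreamResult carries the newly finalized sentence(s) plus
# start_index / processed_index markers so the client can concatenate results.
for stream_result in splitter.run_bidi_stream(stream_input):
    for token in stream_result.results:
        print(token.start, token.end, repr(token.text))
```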
135 changes: 134 additions & 1 deletion caikit_nlp/modules/tokenization/regex_sentence_splitter.py
@@ -14,13 +14,19 @@
"""Module that provides capability to split documents into sentences via regex"""

# Standard
from typing import Iterable
import itertools
import os
import re

# First Party
from caikit.core.exceptions import error_handler
from caikit.core.modules import ModuleBase, ModuleConfig, ModuleSaver, module
from caikit.interfaces.nlp.data_model import Token, TokenizationResults
from caikit.interfaces.nlp.data_model import (
Token,
TokenizationResults,
TokenizationStreamResult,
)
from caikit.interfaces.nlp.tasks import TokenizationTask
import alog

@@ -98,6 +104,7 @@ def load(cls, model_path: str) -> "RegexSentenceSplitter":
config = ModuleConfig.load(os.path.abspath(model_path))
return cls(regex_str=config.regex_str)

@TokenizationTask.taskmethod()
def run(self, text: str) -> TokenizationResults:
"""Run sentence splitting regex on input text.

@@ -120,3 +127,129 @@ def run(self, text: str) -> TokenizationResults:
tokens.append(token)

return TokenizationResults(results=tokens)

@TokenizationTask.taskmethod(input_streaming=True, output_streaming=True)
def run_bidi_stream(
self, text_stream: Iterable[str]
) -> Iterable[TokenizationStreamResult]:
"""Run bi-directional streaming sentence splitting. Aggregates text
in the stream and returns back concatenable stream of sentences, with
surrounding whitespace included

Args:
text_stream: Iterable[str]
Text stream to run sentence splitting on

Returns:
Iterable[TokenizationStreamResult]
"""
# Avoid length check here since it can be time consuming to iterate through stream
# Tee stream to 2 - one to check emptiness, one for full iteration + analysis
text_streams = itertools.tee(text_stream, 2)
try:
next(text_streams[0])
except StopIteration:
# Empty text case
yield TokenizationStreamResult(results=[], start_index=0, processed_index=0)

for token_output in self._stream_token_output(text_streams[1]):
# start_index and processed_index here are simplified since each sentence
# is expected to be concatenable and will be streamed
yield TokenizationStreamResult(
results=[token_output],
start_index=token_output.start,
processed_index=token_output.end,
)

################################## Private functions ##########################################

def _stream_token_output(self, text_stream):
"""Function to yield token output from input text stream"""
# NOTE: Can potentially consolidate with parts of the filtered span classification function
# in the future but this implementation currently works for tokens and accounts for
# whitespace between sentences/tokens

stream_accumulator = ""
detected_tokens = None
token_start_offset = 0
len_missing_idx = 0
# Tracker of text up until tokens/sentences detected - accounts for only whitespace case
text_tracker = []

def __update_tokens(token, stream_accumulator, len_missing_idx):
# If the starting offset for the token is already at or beyond
# token_start_offset, we do not need to update the token
if token.start < token_start_offset:
# This indicates that the sentence's starting offset is off, since we
# expect the sentence to start at token_start_offset. So we recalibrate
# the sentence offsets to start at token_start_offset. This means we
# need to know the length of the sentence to adjust token.end, which we
# get by subtracting start from end
original_start = token.start
token.start = token_start_offset
token.end = (
token_start_offset + (token.end - original_start) + len_missing_idx
)
token.text = stream_accumulator[token.start : token.end]
return token

for text in text_stream:
error.type_check("<NLP38367928E>", str, text=text)
stream_accumulator += text
text_tracker.append(text)

# To avoid re-processing text that has already been emitted, we only run
# the regex on the accumulator after token_start_offset, so detected_tokens
# only holds tokens that are not yet finalized
matches = self.regex.finditer(stream_accumulator[token_start_offset:])
detected_tokens = []
for match_token in matches:
token = Token(
start=match_token.start(),
end=match_token.end(),
text=match_token.group(),
)
detected_tokens.append(token)

if len(detected_tokens) > 1:
# Optimization: once sentences are actually detected, we no longer
# need to keep track of all the text chunks
text_tracker = []

# We have detected more than 1 sentence
# Return 1st sentence
new_token = detected_tokens.pop(0)

new_token = __update_tokens(
new_token, stream_accumulator, len_missing_idx
)

# We have detected a new sentence; return it
yield new_token

# We only send out part of the text, so we need to track
# the starting point of the subsequent sentence
token_start_offset = new_token.end
next_token_len = detected_tokens[0].end - detected_tokens[0].start
len_missing_idx = (
len(stream_accumulator) - token_start_offset - next_token_len
)

# Return remaining sentence(s)
if detected_tokens:
for detected_token in detected_tokens:
new_token = __update_tokens(
detected_token, stream_accumulator, len_missing_idx
)
yield new_token

else:
# This keeps track of text that is only whitespace, which would
# otherwise not produce any tokens since the tokenizer is only used to
# detect sentences. This may have to be adjusted to also keep track of
# any generated trailing whitespace
token_start = 0
for text in text_tracker:
token_end = token_start + len(text)
yield Token(start=token_start, end=token_end, text=text)
token_start = token_end
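Before the test file, it may help to see the core idea behind `_stream_token_output` in isolation: chunks are appended to an accumulator, the regex is re-run only on the not-yet-emitted tail, and every match except the last (which may still grow as more text arrives) is flushed as a finalized sentence. The sketch below is a simplified standalone reading of that strategy; the helper name `stream_sentences` and its tuple output are hypothetical, and it omits the module's whitespace tracking and offset recalibration (`__update_tokens`):

```python
# Simplified, standalone sketch of the accumulate-and-rescan strategy.
# Hypothetical helper for illustration; omits the module's whitespace tracking
# and offset recalibration.
import re
from typing import Iterable, Iterator, Tuple

SENTENCE_RE = re.compile(
    r"[^.!?\s][^.!?\n]*(?:[.!?](?!['\"]?\s|$)[^.!?]*)*[.!?]?['\"]?(?=\s|$)"
)


def stream_sentences(chunks: Iterable[str]) -> Iterator[Tuple[int, int, str]]:
    """Yield (start, end, text) for each sentence as soon as it is finalized."""
    buffer = ""
    emitted_upto = 0  # buffer index up to which sentences were already yielded
    for chunk in chunks:
        buffer += chunk
        # Rescan only the portion of the buffer that has not been emitted yet
        matches = list(SENTENCE_RE.finditer(buffer, emitted_upto))
        # The last match may still grow when more text arrives, so hold it back
        for match in matches[:-1]:
            yield match.start(), match.end(), match.group()
            emitted_upto = match.end()
    # Stream exhausted: flush whatever remains
    for match in SENTENCE_RE.finditer(buffer, emitted_upto):
        yield match.start(), match.end(), match.group()
```

Feeding the chunked document from the tests through this sketch yields the first sentence as soon as a second match appears, and the remainder once the stream is exhausted, which is the incremental behavior the tests below exercise.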
78 changes: 77 additions & 1 deletion tests/modules/tokenization/test_regex_sentence_splitter.py
@@ -1,11 +1,17 @@
"""Tests for regex sentence splitter
"""
# Standard
from typing import Iterable
import os
import tempfile

# First Party
from caikit.interfaces.nlp.data_model import TokenizationResults
from caikit.core import data_model
from caikit.interfaces.nlp.data_model import (
Token,
TokenizationResults,
TokenizationStreamResult,
)

# Local
from caikit_nlp.modules.tokenization.regex_sentence_splitter import (
@@ -15,6 +21,7 @@
## Setup ########################################################################

# Regex sentence splitter model for reusability across tests
# NOTE: Regex may not be extremely accurate for sentence splitting needs.
REGEX_STR = "[^.!?\s][^.!?\n]*(?:[.!?](?!['\"]?\s|$)[^.!?]*)*[.!?]?['\"]?(?=\s|$)"
SENTENCE_TOKENIZER = RegexSentenceSplitter.bootstrap(REGEX_STR)
DOCUMENT = "What he told me before, I have it in my heart. I am tired of fighting."
@@ -39,3 +46,72 @@ def test_save_load_and_run_model():
tokenization_result = new_splitter.run(DOCUMENT)
assert isinstance(tokenization_result, TokenizationResults)
assert len(tokenization_result.results) == 2


### Streaming tests ##############################################################


def test_run_bidi_stream_model():
"""Check if model prediction works as expected for bi-directional stream"""

stream_input = data_model.DataStream.from_iterable(DOCUMENT)
streaming_tokenization_result = SENTENCE_TOKENIZER.run_bidi_stream(stream_input)
assert isinstance(streaming_tokenization_result, Iterable)
# Convert to list to more easily check outputs
result_list = list(streaming_tokenization_result)

first_result = result_list[0].results[0]
assert isinstance(first_result, Token)
assert first_result.start == 0
assert first_result.end == 46
assert first_result.text == "What he told me before, I have it in my heart."

# Check processed indices
assert result_list[0].processed_index == 46
assert result_list[1].processed_index == len(stream_input)

# Total number of results should equal the expected number of sentences
expected_number_of_sentences = 2 # Sentence tokenizer returns 2 results
count = len(result_list)
assert count == expected_number_of_sentences


def test_run_bidi_stream_chunk_stream_input():
"""Check if model prediction with tokenization
with chunks of text input works as expected for bi-directional stream"""

chunked_document_input = (
"What he told me ",
"before, I have it in my heart. I am tired of fighting. ",
" The cow jumped over the moon. ",
)
stream_input = data_model.DataStream.from_iterable(chunked_document_input)
streaming_tokenization_result = SENTENCE_TOKENIZER.run_bidi_stream(stream_input)
# Convert to list to more easily check outputs
result_list = list(streaming_tokenization_result)
first_result = result_list[0].results[0]
assert isinstance(first_result, Token)
assert first_result.start == 0
assert first_result.end == 46
assert first_result.text == "What he told me before, I have it in my heart."

# Check processed indices
assert result_list[0].processed_index == 46 # ...heart.
assert result_list[1].processed_index == 71 # ...fighting.
assert result_list[2].processed_index == 102 # end of doc

expected_results = 3
count = len(result_list)
assert count == expected_results


def test_run_bidi_stream_empty():
"""Check if tokenization can run with empty space for streaming"""
stream_input = data_model.DataStream.from_iterable("")
streaming_tokenization_result = SENTENCE_TOKENIZER.run_bidi_stream(stream_input)
assert isinstance(streaming_tokenization_result, Iterable)
# Convert to list to more easily check outputs
result_list = list(streaming_tokenization_result)
assert len(result_list) == 1
assert result_list[0].results == []
assert result_list[0].processed_index == 0