Skip to content

Commit

Permalink
Move potential nltk download to warm_up
Browse files Browse the repository at this point in the history
  • Loading branch information
sjrl committed Dec 16, 2024
1 parent a5b57f4 commit e861c85
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 14 deletions.
27 changes: 19 additions & 8 deletions haystack/components/preprocessors/document_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,15 +107,10 @@ def __init__( # pylint: disable=too-many-positional-arguments
splitting_function=splitting_function,
respect_sentence_boundary=respect_sentence_boundary,
)

if split_by == "sentence" or (respect_sentence_boundary and split_by == "word"):
self._use_sentence_splitter = split_by == "sentence" or (respect_sentence_boundary and split_by == "word")
if self._use_sentence_splitter:
nltk_imports.check()
self.sentence_splitter = SentenceSplitter(
language=language,
use_split_rules=use_split_rules,
extend_abbreviations=extend_abbreviations,
keep_white_spaces=True,
)
self.sentence_splitter = None

if split_by == "sentence":
# ToDo: remove this warning in the next major release
Expand Down Expand Up @@ -164,6 +159,18 @@ def _init_checks(
)
self.respect_sentence_boundary = False

def warm_up(self):
"""
Warm up the DocumentSplitter by loading the sentence tokenizer.
"""
if self._use_sentence_splitter and self.sentence_splitter is None:
self.sentence_splitter = SentenceSplitter(
language=self.language,
use_split_rules=self.use_split_rules,
extend_abbreviations=self.extend_abbreviations,
keep_white_spaces=True,
)

@component.output_types(documents=List[Document])
def run(self, documents: List[Document]):
"""
Expand All @@ -182,6 +189,10 @@ def run(self, documents: List[Document]):
:raises TypeError: if the input is not a list of Documents.
:raises ValueError: if the content of a document is None.
"""
if self._use_sentence_splitter and self.sentence_splitter is None:
raise RuntimeError(
"The component DocumentSplitter wasn't warmed up. Run 'warm_up()' before calling 'run()'."
)

if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)):
raise TypeError("DocumentSplitter expects a List of Documents as input.")
Expand Down
19 changes: 13 additions & 6 deletions haystack/components/preprocessors/nltk_document_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,21 @@ def __init__( # pylint: disable=too-many-positional-arguments
self.respect_sentence_boundary = respect_sentence_boundary
self.use_split_rules = use_split_rules
self.extend_abbreviations = extend_abbreviations
self.sentence_splitter = SentenceSplitter(
language=language,
use_split_rules=use_split_rules,
extend_abbreviations=extend_abbreviations,
keep_white_spaces=True,
)
self.sentence_splitter = None
self.language = language

def warm_up(self):
"""
Warm up the NLTKDocumentSplitter by loading the sentence tokenizer.
"""
if self.sentence_splitter is None:
self.sentence_splitter = SentenceSplitter(
language=self.language,
use_split_rules=self.use_split_rules,
extend_abbreviations=self.extend_abbreviations,
keep_white_spaces=True,
)

def _split_into_units(
self, text: str, split_by: Literal["function", "page", "passage", "sentence", "word", "line"]
) -> List[str]:
Expand Down

0 comments on commit e861c85

Please sign in to comment.