From 460cc7d1fae0648bb21357f45c0030ea90ab7650 Mon Sep 17 00:00:00 2001 From: "David S. Batista" Date: Thu, 19 Dec 2024 12:33:53 +0100 Subject: [PATCH] handling both default strategies --- .../preprocessors/recursive_splitter.py | 35 ++++++++----------- .../preprocessors/test_recursive_splitter.py | 10 ++++++ 2 files changed, 24 insertions(+), 21 deletions(-) diff --git a/haystack/components/preprocessors/recursive_splitter.py b/haystack/components/preprocessors/recursive_splitter.py index 8223294f5f..35e87ebdb6 100644 --- a/haystack/components/preprocessors/recursive_splitter.py +++ b/haystack/components/preprocessors/recursive_splitter.py @@ -237,36 +237,29 @@ def _chunk_text(self, text: str) -> List[str]: return chunks # if no separator worked, fall back to word- or character-level chunking - if self.split_units == "word": - return self.fall_back_to_word_level_chunking(text) - - return self.fall_back_to_char_level_chunking(text) + return self.fall_back_to_fixed_chunking(text, self.split_units) - def fall_back_to_word_level_chunking(self, text: str) -> List[str]: + def fall_back_to_fixed_chunking(self, text: str, split_units: Literal["word", "char"]) -> List[str]: """ - Fall back to word-level chunking if no separator works. + Fall back to a fixed chunking approach if no separator works for the text. :param text: The text to be split into chunks. + :param split_units: The unit of the split_length parameter. It can be either "word" or "char". :returns: A list of text chunks. """ - return [ - " ".join(text.split()[i : i + self.split_length]) - for i in range(0, self._chunk_length(text), self.split_length - self.split_overlap) - ] + chunks = [] + step = self.split_length - self.split_overlap - def fall_back_to_char_level_chunking(self, text: str) -> List[str]: - """ - Fall back to character-level chunking if no separator works. 
+ if split_units == "word": + words = text.split() + for i in range(0, self._chunk_length(text), step): + chunks.append(" ".join(words[i : i + self.split_length])) + else: + for i in range(0, self._chunk_length(text), step): + chunks.append(text[i : i + self.split_length]) - :param text: The text to be split into chunks. - :returns: - A list of text chunks. - """ - return [ - text[i : i + self.split_length] - for i in range(0, self._chunk_length(text), self.split_length - self.split_overlap) - ] + return chunks def _add_overlap_info(self, curr_pos: int, new_doc: Document, new_docs: List[Document]) -> None: prev_doc = new_docs[-1] diff --git a/test/components/preprocessors/test_recursive_splitter.py b/test/components/preprocessors/test_recursive_splitter.py index 4ac5582386..d73da66a9e 100644 --- a/test/components/preprocessors/test_recursive_splitter.py +++ b/test/components/preprocessors/test_recursive_splitter.py @@ -409,6 +409,16 @@ def test_run_fallback_to_character_chunking_by_default_length_too_short(): assert len(chunk.content) <= 2 +def test_run_fallback_to_word_chunking_by_default_length_too_short(): + text = "This is some text. This is some more text, and even more text." + separators = ["\n\n", "\n", "."] + splitter = RecursiveDocumentSplitter(split_length=2, separators=separators, split_unit="word") + doc = Document(content=text) + chunks = splitter.run([doc])["documents"] + for chunk in chunks: + assert splitter._chunk_length(chunk.content) <= 2 + + def test_run_custom_sentence_tokenizer_document_and_overlap_char_unit(): """Test that RecursiveDocumentSplitter works correctly with custom sentence tokenizer and overlap""" splitter = RecursiveDocumentSplitter(split_length=25, split_overlap=5, separators=["sentence"], split_unit="char")