Skip to content

Commit

Permalink
handing both default strategies
Browse files Browse the repository at this point in the history
  • Loading branch information
davidsbatista committed Dec 19, 2024
1 parent 0807902 commit 460cc7d
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 21 deletions.
35 changes: 14 additions & 21 deletions haystack/components/preprocessors/recursive_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,36 +237,29 @@ def _chunk_text(self, text: str) -> List[str]:
return chunks

# if no separator worked, fall back to word- or character-level chunking
if self.split_units == "word":
return self.fall_back_to_word_level_chunking(text)

return self.fall_back_to_char_level_chunking(text)
return self.fall_back_to_fixed_chunking(text, self.split_units)

def fall_back_to_word_level_chunking(self, text: str) -> List[str]:
def fall_back_to_fixed_chunking(self, text: str, split_units: Literal["word", "char"]) -> List[str]:
"""
Fall back to word-level chunking if no separator works.
Fall back to a fixed chunking approach if no separator works for the text.
:param text: The text to be split into chunks.
:param split_units: The unit of the split_length parameter. It can be either "word" or "char".
:returns:
A list of text chunks.
"""
return [
" ".join(text.split()[i : i + self.split_length])
for i in range(0, self._chunk_length(text), self.split_length - self.split_overlap)
]
chunks = []
step = self.split_length - self.split_overlap

def fall_back_to_char_level_chunking(self, text: str) -> List[str]:
"""
Fall back to character-level chunking if no separator works.
if split_units == "word":
words = text.split()
for i in range(0, self._chunk_length(text), step):
chunks.append(" ".join(words[i : i + self.split_length]))
else:
for i in range(0, self._chunk_length(text), step):
chunks.append(text[i : i + self.split_length])

:param text: The text to be split into chunks.
:returns:
A list of text chunks.
"""
return [
text[i : i + self.split_length]
for i in range(0, self._chunk_length(text), self.split_length - self.split_overlap)
]
return chunks

def _add_overlap_info(self, curr_pos: int, new_doc: Document, new_docs: List[Document]) -> None:
prev_doc = new_docs[-1]
Expand Down
10 changes: 10 additions & 0 deletions test/components/preprocessors/test_recursive_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,16 @@ def test_run_fallback_to_character_chunking_by_default_length_too_short():
assert len(chunk.content) <= 2


def test_run_fallback_to_word_chunking_by_default_length_too_short():
text = "This is some text. This is some more text, and even more text."
separators = ["\n\n", "\n", "."]
splitter = RecursiveDocumentSplitter(split_length=2, separators=separators, split_unit="word")
doc = Document(content=text)
chunks = splitter.run([doc])["documents"]
for chunk in chunks:
assert splitter._chunk_length(chunk.content) <= 2


def test_run_custom_sentence_tokenizer_document_and_overlap_char_unit():
"""Test that RecursiveDocumentSplitter works correctly with custom sentence tokenizer and overlap"""
splitter = RecursiveDocumentSplitter(split_length=25, split_overlap=5, separators=["sentence"], split_unit="char")
Expand Down

0 comments on commit 460cc7d

Please sign in to comment.