Skip to content

Commit

Permalink
semantic-chunker: add min_chunk_size (#4)
Browse files Browse the repository at this point in the history
Moving code changes here from
langchain-ai/langchain#26398
  • Loading branch information
tibor-reiss authored Sep 26, 2024
1 parent 59a1111 commit 305bde3
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 1 deletion.
8 changes: 8 additions & 0 deletions libs/experimental/langchain_experimental/text_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def __init__(
breakpoint_threshold_amount: Optional[float] = None,
number_of_chunks: Optional[int] = None,
sentence_split_regex: str = r"(?<=[.?!])\s+",
min_chunk_size: Optional[int] = None,
):
self._add_start_index = add_start_index
self.embeddings = embeddings
Expand All @@ -130,6 +131,7 @@ def __init__(
]
else:
self.breakpoint_threshold_amount = breakpoint_threshold_amount
self.min_chunk_size = min_chunk_size

def _calculate_breakpoint_threshold(
self, distances: List[float]
Expand Down Expand Up @@ -250,6 +252,12 @@ def split_text(
# Slice the sentence_dicts from the current start index to the end index
group = sentences[start_index : end_index + 1]
combined_text = " ".join([d["sentence"] for d in group])
# If specified, merge together small chunks.
if (
self.min_chunk_size is not None
and len(combined_text) < self.min_chunk_size
):
continue
chunks.append(combined_text)

# Update the start index for the next group
Expand Down
26 changes: 25 additions & 1 deletion libs/experimental/tests/unit_tests/test_text_splitter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import re
from typing import List
from typing import List, Optional

import pytest
from langchain_core.embeddings import Embeddings
Expand Down Expand Up @@ -52,3 +52,27 @@ def test_split_text_gradient(input_length: int, expected_length: int) -> None:
chunks = chunker.split_text(" ".join(list_of_sentences))

assert len(chunks) == expected_length


@pytest.mark.parametrize(
    "min_chunk_size, expected_chunks",
    [
        (None, 4),
        (30, 4),
        (60, 3),
        (120, 3),
        (240, 2),
    ],
)
def test_min_chunk_size(min_chunk_size: Optional[int], expected_chunks: int) -> None:
    """Chunks below ``min_chunk_size`` get merged with their neighbours.

    With ``min_chunk_size=None`` merging is disabled; raising the threshold
    should monotonically reduce the number of chunks produced for the same
    sample text.
    """
    splitter = SemanticChunker(
        MockEmbeddings(),
        breakpoint_threshold_type="percentile",
        breakpoint_threshold_amount=50,
        min_chunk_size=min_chunk_size,
    )

    result = splitter.split_text(SAMPLE_TEXT)

    assert len(result) == expected_chunks

0 comments on commit 305bde3

Please sign in to comment.