From 305bde3a37638e128c8ddb2c136adbbc9826604a Mon Sep 17 00:00:00 2001 From: Tibor Reiss <75096465+tibor-reiss@users.noreply.github.com> Date: Thu, 26 Sep 2024 22:53:36 +0200 Subject: [PATCH] semantic-chunker: add min_chunk_size (#4) Moving code changes here from https://github.com/langchain-ai/langchain/pull/26398 --- .../langchain_experimental/text_splitter.py | 8 ++++++ .../tests/unit_tests/test_text_splitter.py | 26 ++++++++++++++++++- 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/libs/experimental/langchain_experimental/text_splitter.py b/libs/experimental/langchain_experimental/text_splitter.py index 0ea0eec..78c8437 100644 --- a/libs/experimental/langchain_experimental/text_splitter.py +++ b/libs/experimental/langchain_experimental/text_splitter.py @@ -117,6 +117,7 @@ def __init__( breakpoint_threshold_amount: Optional[float] = None, number_of_chunks: Optional[int] = None, sentence_split_regex: str = r"(?<=[.?!])\s+", + min_chunk_size: Optional[int] = None, ): self._add_start_index = add_start_index self.embeddings = embeddings @@ -130,6 +131,7 @@ def __init__( ] else: self.breakpoint_threshold_amount = breakpoint_threshold_amount + self.min_chunk_size = min_chunk_size def _calculate_breakpoint_threshold( self, distances: List[float] @@ -250,6 +252,12 @@ def split_text( # Slice the sentence_dicts from the current start index to the end index group = sentences[start_index : end_index + 1] combined_text = " ".join([d["sentence"] for d in group]) + # If specified, merge together small chunks. + if ( + self.min_chunk_size is not None + and len(combined_text) < self.min_chunk_size + ): + continue chunks.append(combined_text) # Update the start index for the next group diff --git a/libs/experimental/tests/unit_tests/test_text_splitter.py b/libs/experimental/tests/unit_tests/test_text_splitter.py index d016001..1fbc0ad 100644 --- a/libs/experimental/tests/unit_tests/test_text_splitter.py +++ b/libs/experimental/tests/unit_tests/test_text_splitter.py @@ -1,5 +1,5 @@ import re -from typing import List +from typing import List, Optional import pytest from langchain_core.embeddings import Embeddings @@ -52,3 +52,27 @@ def test_split_text_gradient(input_length: int, expected_length: int) -> None: chunks = chunker.split_text(" ".join(list_of_sentences)) assert len(chunks) == expected_length + + +@pytest.mark.parametrize( + "min_chunk_size, expected_chunks", + [ + (None, 4), + (30, 4), + (60, 3), + (120, 3), + (240, 2), + ], +) +def test_min_chunk_size(min_chunk_size: Optional[int], expected_chunks: int) -> None: + embeddings = MockEmbeddings() + chunker = SemanticChunker( + embeddings, + breakpoint_threshold_type="percentile", + breakpoint_threshold_amount=50, + min_chunk_size=min_chunk_size, + ) + + chunks = chunker.split_text(SAMPLE_TEXT) + + assert len(chunks) == expected_chunks