diff --git a/libs/langchain/langchain/text_splitter.py b/libs/langchain/langchain/text_splitter.py
index c4ece25320455..8d49a3c94baeb 100644
--- a/libs/langchain/langchain/text_splitter.py
+++ b/libs/langchain/langchain/text_splitter.py
@@ -141,12 +141,15 @@ def create_documents(
         _metadatas = metadatas or [{}] * len(texts)
         documents = []
         for i, text in enumerate(texts):
-            index = -1
+            index = 0
+            previous_chunk_len = 0
             for chunk in self.split_text(text):
                 metadata = copy.deepcopy(_metadatas[i])
                 if self._add_start_index:
-                    index = text.find(chunk, index + 1)
+                    offset = index + previous_chunk_len - self._chunk_overlap
+                    index = text.find(chunk, max(0, offset))
                 metadata["start_index"] = index
+                previous_chunk_len = len(chunk)
                 new_doc = Document(page_content=chunk, metadata=metadata)
                 documents.append(new_doc)
         return documents
diff --git a/libs/langchain/tests/unit_tests/test_text_splitter.py b/libs/langchain/tests/unit_tests/test_text_splitter.py
index f099cc7cc2d21..6156d794d02e3 100644
--- a/libs/langchain/tests/unit_tests/test_text_splitter.py
+++ b/libs/langchain/tests/unit_tests/test_text_splitter.py
@@ -13,6 +13,7 @@
     MarkdownHeaderTextSplitter,
     PythonCodeTextSplitter,
     RecursiveCharacterTextSplitter,
+    TextSplitter,
     Tokenizer,
     split_text_on_tokens,
 )
@@ -169,19 +170,47 @@ def test_create_documents_with_metadata() -> None:
     assert docs == expected_docs
 
 
-def test_create_documents_with_start_index() -> None:
+@pytest.mark.parametrize(
+    "splitter, text, expected_docs",
+    [
+        (
+            CharacterTextSplitter(
+                separator=" ", chunk_size=7, chunk_overlap=3, add_start_index=True
+            ),
+            "foo bar baz 123",
+            [
+                Document(page_content="foo bar", metadata={"start_index": 0}),
+                Document(page_content="bar baz", metadata={"start_index": 4}),
+                Document(page_content="baz 123", metadata={"start_index": 8}),
+            ],
+        ),
+        (
+            RecursiveCharacterTextSplitter(
+                chunk_size=6,
+                chunk_overlap=0,
+                separators=["\n\n", "\n", " ", ""],
+                add_start_index=True,
+            ),
+            "w1 w1 w1 w1 w1 w1 w1 w1 w1",
+            [
+                Document(page_content="w1 w1", metadata={"start_index": 0}),
+                Document(page_content="w1 w1", metadata={"start_index": 6}),
+                Document(page_content="w1 w1", metadata={"start_index": 12}),
+                Document(page_content="w1 w1", metadata={"start_index": 18}),
+                Document(page_content="w1", metadata={"start_index": 24}),
+            ],
+        ),
+    ],
+)
+def test_create_documents_with_start_index(
+    splitter: TextSplitter, text: str, expected_docs: List[Document]
+) -> None:
     """Test create documents method."""
-    texts = ["foo bar baz 123"]
-    splitter = CharacterTextSplitter(
-        separator=" ", chunk_size=7, chunk_overlap=3, add_start_index=True
-    )
-    docs = splitter.create_documents(texts)
-    expected_docs = [
-        Document(page_content="foo bar", metadata={"start_index": 0}),
-        Document(page_content="bar baz", metadata={"start_index": 4}),
-        Document(page_content="baz 123", metadata={"start_index": 8}),
-    ]
+    docs = splitter.create_documents([text])
     assert docs == expected_docs
+    for doc in docs:
+        s_i = doc.metadata["start_index"]
+        assert text[s_i : s_i + len(doc.page_content)] == doc.page_content
 
 
 def test_metadata_not_shallow() -> None: