Skip to content

Commit

Permalink
langchain[patch]: inconsistent results with `RecursiveCharacterTextSp…
Browse files Browse the repository at this point in the history
…litter`'s `add_start_index=True` (#16583)

This PR fixes issue #16579
  • Loading branch information
antoniolanza1996 authored Jan 25, 2024
1 parent 42db964 commit 08d3fd7
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 13 deletions.
7 changes: 5 additions & 2 deletions libs/langchain/langchain/text_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,12 +141,15 @@ def create_documents(
_metadatas = metadatas or [{}] * len(texts)
documents = []
for i, text in enumerate(texts):
index = -1
index = 0
previous_chunk_len = 0
for chunk in self.split_text(text):
metadata = copy.deepcopy(_metadatas[i])
if self._add_start_index:
index = text.find(chunk, index + 1)
offset = index + previous_chunk_len - self._chunk_overlap
index = text.find(chunk, max(0, offset))
metadata["start_index"] = index
previous_chunk_len = len(chunk)
new_doc = Document(page_content=chunk, metadata=metadata)
documents.append(new_doc)
return documents
Expand Down
51 changes: 40 additions & 11 deletions libs/langchain/tests/unit_tests/test_text_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
MarkdownHeaderTextSplitter,
PythonCodeTextSplitter,
RecursiveCharacterTextSplitter,
TextSplitter,
Tokenizer,
split_text_on_tokens,
)
Expand Down Expand Up @@ -169,19 +170,47 @@ def test_create_documents_with_metadata() -> None:
assert docs == expected_docs


def test_create_documents_with_start_index() -> None:
@pytest.mark.parametrize(
"splitter, text, expected_docs",
[
(
CharacterTextSplitter(
separator=" ", chunk_size=7, chunk_overlap=3, add_start_index=True
),
"foo bar baz 123",
[
Document(page_content="foo bar", metadata={"start_index": 0}),
Document(page_content="bar baz", metadata={"start_index": 4}),
Document(page_content="baz 123", metadata={"start_index": 8}),
],
),
(
RecursiveCharacterTextSplitter(
chunk_size=6,
chunk_overlap=0,
separators=["\n\n", "\n", " ", ""],
add_start_index=True,
),
"w1 w1 w1 w1 w1 w1 w1 w1 w1",
[
Document(page_content="w1 w1", metadata={"start_index": 0}),
Document(page_content="w1 w1", metadata={"start_index": 6}),
Document(page_content="w1 w1", metadata={"start_index": 12}),
Document(page_content="w1 w1", metadata={"start_index": 18}),
Document(page_content="w1", metadata={"start_index": 24}),
],
),
],
)
def test_create_documents_with_start_index(
splitter: TextSplitter, text: str, expected_docs: List[Document]
) -> None:
"""Test create documents method."""
texts = ["foo bar baz 123"]
splitter = CharacterTextSplitter(
separator=" ", chunk_size=7, chunk_overlap=3, add_start_index=True
)
docs = splitter.create_documents(texts)
expected_docs = [
Document(page_content="foo bar", metadata={"start_index": 0}),
Document(page_content="bar baz", metadata={"start_index": 4}),
Document(page_content="baz 123", metadata={"start_index": 8}),
]
docs = splitter.create_documents([text])
assert docs == expected_docs
for doc in docs:
s_i = doc.metadata["start_index"]
assert text[s_i : s_i + len(doc.page_content)] == doc.page_content


def test_metadata_not_shallow() -> None:
Expand Down

0 comments on commit 08d3fd7

Please sign in to comment.