diff --git a/libs/langchain/langchain/text_splitter.py b/libs/langchain/langchain/text_splitter.py
index c0ce1c74facb3..efb55c1b984d3 100644
--- a/libs/langchain/langchain/text_splitter.py
+++ b/libs/langchain/langchain/text_splitter.py
@@ -670,6 +670,8 @@ def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> List[str]:
     chunk_ids = input_ids[start_idx:cur_idx]
     while start_idx < len(input_ids):
         splits.append(tokenizer.decode(chunk_ids))
+        if cur_idx == len(input_ids):
+            break
         start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
         cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
         chunk_ids = input_ids[start_idx:cur_idx]
diff --git a/libs/langchain/tests/unit_tests/test_text_splitter.py b/libs/langchain/tests/unit_tests/test_text_splitter.py
index f09366f1539ce..2f9cf2ac600f4 100644
--- a/libs/langchain/tests/unit_tests/test_text_splitter.py
+++ b/libs/langchain/tests/unit_tests/test_text_splitter.py
@@ -13,6 +13,8 @@
     MarkdownHeaderTextSplitter,
     PythonCodeTextSplitter,
     RecursiveCharacterTextSplitter,
+    Tokenizer,
+    split_text_on_tokens,
 )
 
 FAKE_PYTHON_TEXT = """
@@ -1175,3 +1177,18 @@ def test_html_header_text_splitter(tmp_path: Path) -> None:
     docs_from_file = splitter.split_text_from_file(tmp_path / "doc.html")
 
     assert docs_from_file == expected
+
+
+def test_split_text_on_tokens() -> None:
+    """Test splitting by tokens per chunk."""
+    text = "foo bar baz 123"
+
+    tokenizer = Tokenizer(
+        chunk_overlap=3,
+        tokens_per_chunk=7,
+        decode=(lambda it: "".join(chr(i) for i in it)),
+        encode=(lambda it: [ord(c) for c in it]),
+    )
+    output = split_text_on_tokens(text=text, tokenizer=tokenizer)
+    expected_output = ["foo bar", "bar baz", "baz 123"]
+    assert output == expected_output
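
Why the early break matters: once cur_idx has already reached len(input_ids), the pre-fix loop still advances start_idx by tokens_per_chunk - chunk_overlap, and whenever that leaves start_idx < len(input_ids) it emits one more chunk that is a pure suffix of the previous one. Below is a self-contained sketch of the pre-fix behavior: the loop body restates the context lines from the first hunk, the Tokenizer stand-in mirrors the fields the new test constructs, and split_without_break is a hypothetical name for illustration, not library code.

from dataclasses import dataclass
from typing import Callable, List


@dataclass
class Tokenizer:
    # Stand-in mirroring the fields used by the test above.
    chunk_overlap: int                    # tokens shared between consecutive chunks
    tokens_per_chunk: int                 # maximum tokens per chunk
    decode: Callable[[List[int]], str]    # token ids -> text
    encode: Callable[[str], List[int]]    # text -> token ids


def split_without_break(text: str, tokenizer: Tokenizer) -> List[str]:
    """Restatement of the pre-fix loop: no early exit at the end of input_ids."""
    splits: List[str] = []
    input_ids = tokenizer.encode(text)
    start_idx = 0
    cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
    chunk_ids = input_ids[start_idx:cur_idx]
    while start_idx < len(input_ids):
        splits.append(tokenizer.decode(chunk_ids))
        # Without the patched `if cur_idx == len(input_ids): break`, this
        # advance can still leave start_idx < len(input_ids) after the last
        # full chunk, so one extra suffix-only chunk gets emitted.
        start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
        cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
        chunk_ids = input_ids[start_idx:cur_idx]
    return splits


tokenizer = Tokenizer(
    chunk_overlap=3,
    tokens_per_chunk=7,
    decode=lambda it: "".join(chr(i) for i in it),
    encode=lambda it: [ord(c) for c in it],
)
# Prints ['foo bar', 'bar baz', 'baz 123', '123']: the trailing '123' is a
# redundant suffix of 'baz 123'.
print(split_without_break("foo bar baz 123", tokenizer))

With the patched loop, the same input yields the new test's expected ["foo bar", "bar baz", "baz 123"], because the break fires as soon as a chunk ends exactly at the last token.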