diff --git a/docs/source/api.rst b/docs/source/api.rst index 27c4bdb9..c475ddb5 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -27,6 +27,12 @@ TextSplitter .. autoclass:: neo4j_graphrag.experimental.components.text_splitters.base.TextSplitter :members: run +FixedSizeSplitter +================= + +.. autoclass:: neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter.FixedSizeSplitter + :members: run + LangChainTextSplitterAdapter ============================ diff --git a/docs/source/user_guide_kg_builder.rst b/docs/source/user_guide_kg_builder.rst index a5d438fa..34d5145e 100644 --- a/docs/source/user_guide_kg_builder.rst +++ b/docs/source/user_guide_kg_builder.rst @@ -96,16 +96,25 @@ Document Splitter ================= Document splitters, as the name indicates, split documents into smaller chunks -that can be processed within the LLM token limits. Wrappers for LangChain and LlamaIndex -text splitters are included in this package: +that can be processed within the LLM token limits: + +.. code:: python + + from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import FixedSizeSplitter + + splitter = FixedSizeSplitter(chunk_size=4000, chunk_overlap=200) + splitter.run(text="Hello World. Life is beautiful.") + + +Wrappers for LangChain and LlamaIndex text splitters are included in this package: .. code:: python from langchain_text_splitters import CharacterTextSplitter from neo4j_graphrag.experimental.components.text_splitters.langchain import LangChainTextSplitterAdapter splitter = LangChainTextSplitterAdapter( - CharacterTextSplitter(chunk_size=500, chunk_overlap=100, separator=".") + CharacterTextSplitter(chunk_size=4000, chunk_overlap=200, separator=".") ) splitter.run(text="Hello World. 
Life is beautiful.") diff --git a/examples/pipeline/kg_builder_from_pdf.py b/examples/pipeline/kg_builder_from_pdf.py index a6265fca..91aa1564 100644 --- a/examples/pipeline/kg_builder_from_pdf.py +++ b/examples/pipeline/kg_builder_from_pdf.py @@ -19,7 +19,6 @@ from typing import Any, Dict, List import neo4j -from langchain_text_splitters import CharacterTextSplitter from neo4j_graphrag.experimental.components.entity_relation_extractor import ( LLMEntityRelationExtractor, OnError, @@ -31,8 +30,8 @@ SchemaEntity, SchemaRelation, ) -from neo4j_graphrag.experimental.components.text_splitters.langchain import ( - LangChainTextSplitterAdapter, +from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import ( + FixedSizeSplitter, ) from neo4j_graphrag.experimental.pipeline import Component, DataModel from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult @@ -142,8 +141,7 @@ async def main(neo4j_driver: neo4j.Driver) -> PipelineResult: pipe = Pipeline() pipe.add_component(PdfLoader(), "pdf_loader") pipe.add_component( - LangChainTextSplitterAdapter(CharacterTextSplitter(separator=". 
\n")), - "splitter", + FixedSizeSplitter(chunk_size=4000, chunk_overlap=200), "splitter" ) pipe.add_component(SchemaBuilder(), "schema") pipe.add_component( diff --git a/examples/pipeline/kg_builder_from_text.py b/examples/pipeline/kg_builder_from_text.py index 313422d2..15817e18 100644 --- a/examples/pipeline/kg_builder_from_text.py +++ b/examples/pipeline/kg_builder_from_text.py @@ -18,7 +18,6 @@ import logging.config import neo4j -from langchain_text_splitters import CharacterTextSplitter from neo4j_graphrag.embeddings.openai import OpenAIEmbeddings from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder from neo4j_graphrag.experimental.components.entity_relation_extractor import ( @@ -32,8 +31,8 @@ SchemaProperty, SchemaRelation, ) -from neo4j_graphrag.experimental.components.text_splitters.langchain import ( - LangChainTextSplitterAdapter, +from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import ( + FixedSizeSplitter, ) from neo4j_graphrag.experimental.pipeline import Pipeline from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult @@ -63,7 +62,7 @@ async def main(neo4j_driver: neo4j.Driver) -> PipelineResult: """This is where we define and run the KG builder pipeline, instantiating a few components: - - Text Splitter: in this example we use a text splitter from the LangChain package + - Text Splitter: in this example we use the fixed size text splitter - Schema Builder: this component takes a list of entities, relationships and possible triplets as inputs, validate them and return a schema ready to use for the rest of the pipeline @@ -76,10 +75,8 @@ async def main(neo4j_driver: neo4j.Driver) -> PipelineResult: pipe = Pipeline() # define the components pipe.add_component( - LangChainTextSplitterAdapter( - # chunk_size=50 for the sake of this demo - CharacterTextSplitter(chunk_size=50, chunk_overlap=10, separator=".") - ), + # fixed-size chunks of 4000 characters with a 200-character overlap + 
FixedSizeSplitter(chunk_size=4000, chunk_overlap=200), "splitter", ) pipe.add_component(TextChunkEmbedder(embedder=OpenAIEmbeddings()), "chunk_embedder") diff --git a/src/neo4j_graphrag/experimental/components/text_splitters/base.py b/src/neo4j_graphrag/experimental/components/text_splitters/base.py index 0d25b448..3061712b 100644 --- a/src/neo4j_graphrag/experimental/components/text_splitters/base.py +++ b/src/neo4j_graphrag/experimental/components/text_splitters/base.py @@ -25,4 +25,12 @@ class TextSplitter(Component): @abstractmethod async def run(self, text: str) -> TextChunks: + """Splits a piece of text into chunks. + + Args: + text (str): The text to be split. + + Returns: + TextChunks: A list of chunks. + """ pass diff --git a/src/neo4j_graphrag/experimental/components/text_splitters/fixed_size_splitter.py b/src/neo4j_graphrag/experimental/components/text_splitters/fixed_size_splitter.py new file mode 100644 index 00000000..6add30d8 --- /dev/null +++ b/src/neo4j_graphrag/experimental/components/text_splitters/fixed_size_splitter.py @@ -0,0 +1,65 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class FixedSizeSplitter(TextSplitter):
    """Text splitter which splits the input text into fixed size chunks with optional overlap.

    Consecutive chunks start ``chunk_size - chunk_overlap`` characters apart, so
    each chunk after the first begins with the last ``chunk_overlap`` characters
    of the chunk before it. The final chunk may be shorter than ``chunk_size``.

    Args:
        chunk_size (int): The number of characters in each chunk. Must be strictly positive.
        chunk_overlap (int): The number of characters from the previous chunk to overlap with each chunk.
            Must be non-negative and strictly less than `chunk_size`.

    Raises:
        ValueError: If `chunk_size` is not strictly positive, or if `chunk_overlap`
            is negative or not strictly less than `chunk_size`.

    Example:

    .. code-block:: python

        from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import FixedSizeSplitter
        from neo4j_graphrag.experimental.pipeline import Pipeline

        pipeline = Pipeline()
        text_splitter = FixedSizeSplitter(chunk_size=4000, chunk_overlap=200)
        pipeline.add_component(text_splitter, "text_splitter")
    """

    @validate_call
    def __init__(self, chunk_size: int = 4000, chunk_overlap: int = 200) -> None:
        # A non-positive size can only produce empty chunks (or none at all).
        if chunk_size <= 0:
            raise ValueError("chunk_size must be strictly greater than 0")
        # A negative overlap would make the stride larger than the chunk and
        # silently drop the characters that fall between two chunks.
        if chunk_overlap < 0:
            raise ValueError("chunk_overlap must be greater than or equal to 0")
        # The stride (chunk_size - chunk_overlap) must be at least 1 to advance.
        if chunk_overlap >= chunk_size:
            raise ValueError("chunk_overlap must be strictly less than chunk_size")
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

    @validate_call
    async def run(self, text: str) -> TextChunks:
        """Splits a piece of text into chunks.

        Args:
            text (str): The text to be split.

        Returns:
            TextChunks: The chunks in document order, each carrying its
            zero-based position in ``index``. No chunks are produced for an
            empty `text`.
        """
        text_length = len(text)
        # Distance between the starts of two consecutive chunks; the checks in
        # __init__ guarantee it is at least 1.
        step = self.chunk_size - self.chunk_overlap
        chunks = []
        for index, start in enumerate(range(0, text_length, step)):
            end = min(start + self.chunk_size, text_length)
            chunks.append(TextChunk(text=text[start:end], index=index))
            if end == text_length:
                # This chunk already reaches the end of the text. Any further
                # window would hold nothing but overlap that was just emitted.
                break
        return TextChunks(chunks=chunks)
chunk_overlap=10, separator="\n\n") - ) +def text_splitter() -> FixedSizeSplitter: + return FixedSizeSplitter(chunk_size=500, chunk_overlap=100) @pytest.fixture @@ -89,7 +86,7 @@ def kg_writer(driver: neo4j.Driver) -> Neo4jWriter: @pytest.fixture def kg_builder_pipeline( - text_splitter: LangChainTextSplitterAdapter, + text_splitter: FixedSizeSplitter, chunk_embedder: TextChunkEmbedder, schema_builder: SchemaBuilder, entity_relation_extractor: LLMEntityRelationExtractor, diff --git a/tests/unit/experimental/components/text_splitters/test_fixed_size_splitter.py b/tests/unit/experimental/components/text_splitters/test_fixed_size_splitter.py new file mode 100644 index 00000000..9c559874 --- /dev/null +++ b/tests/unit/experimental/components/text_splitters/test_fixed_size_splitter.py @@ -0,0 +1,69 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
@pytest.mark.asyncio
async def test_split_text_no_overlap() -> None:
    """Without overlap the text is cut into back-to-back chunks of `chunk_size`."""
    text = "may thy knife chip and shatter"
    chunk_size = 5
    chunk_overlap = 0
    splitter = FixedSizeSplitter(chunk_size, chunk_overlap)
    chunks = await splitter.run(text)
    expected_chunks = [
        TextChunk(text="may t", index=0),
        TextChunk(text="hy kn", index=1),
        TextChunk(text="ife c", index=2),
        TextChunk(text="hip a", index=3),
        TextChunk(text="nd sh", index=4),
        TextChunk(text="atter", index=5),
    ]
    assert chunks.chunks == expected_chunks


@pytest.mark.asyncio
async def test_split_text_with_overlap() -> None:
    """Each chunk after the first repeats the last `chunk_overlap` characters of the previous one."""
    text = "may thy knife chip and shatter"
    chunk_size = 10
    chunk_overlap = 2
    splitter = FixedSizeSplitter(chunk_size, chunk_overlap)
    chunks = await splitter.run(text)
    expected_chunks = [
        TextChunk(text="may thy kn", index=0),
        TextChunk(text="knife chip", index=1),
        TextChunk(text="ip and sha", index=2),
        TextChunk(text="hatter", index=3),
    ]
    assert chunks.chunks == expected_chunks


@pytest.mark.asyncio
async def test_split_text_empty_string() -> None:
    """An empty input produces no chunks."""
    text = ""
    chunk_size = 5
    chunk_overlap = 1
    splitter = FixedSizeSplitter(chunk_size, chunk_overlap)
    chunks = await splitter.run(text)
    assert chunks.chunks == []


def test_invalid_chunk_overlap() -> None:
    """An overlap equal to the chunk size is rejected at construction time."""
    with pytest.raises(ValueError) as excinfo:
        FixedSizeSplitter(5, 5)
    # Inspect the raised exception itself: the string form of the ExceptionInfo
    # wrapper is a pytest implementation detail that has changed across versions.
    assert "chunk_overlap must be strictly less than chunk_size" in str(excinfo.value)