Skip to content

Commit

Permalink
Adds a fixed size text splitter component (neo4j#139)
Browse files Browse the repository at this point in the history
* Added fixed size text splitter class

* Updated docs

* Updated examples

* Added init defaults to fixed size text splitter

* Fixed bug in example

* Updated E2E tests

* Update docs/source/api.rst

Co-authored-by: willtai <[email protected]>

* Updated fixed size splitter defaults

---------

Co-authored-by: willtai <[email protected]>
  • Loading branch information
alexthomas93 and willtai authored Sep 19, 2024
1 parent fc7d319 commit 6851a0b
Show file tree
Hide file tree
Showing 9 changed files with 173 additions and 24 deletions.
6 changes: 6 additions & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ TextSplitter
.. autoclass:: neo4j_graphrag.experimental.components.text_splitters.base.TextSplitter
:members: run

FixedSizeSplitter
=================

.. autoclass:: neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter.FixedSizeSplitter
:members: run

LangChainTextSplitterAdapter
============================

Expand Down
14 changes: 11 additions & 3 deletions docs/source/user_guide_kg_builder.rst
Original file line number Diff line number Diff line change
Expand Up @@ -96,16 +96,24 @@ Document Splitter
=================

Document splitters, as the name indicates, split documents into smaller chunks
that can be processed within the LLM token limits. Wrappers for LangChain and LlamaIndex
text splitters are included in this package:
that can be processed within the LLM token limits:

.. code:: python
from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import FixedSizeSplitter
splitter = FixedSizeSplitter(chunk_size=4000, chunk_overlap=200)
splitter.run(text="Hello World. Life is beautiful.")
Wrappers for LangChain and LlamaIndex text splitters are included in this package:

.. code:: python
from langchain_text_splitters import CharacterTextSplitter
from neo4j_graphrag.experimental.components.text_splitters.langchain import LangChainTextSplitterAdapter
splitter = LangChainTextSplitterAdapter(
CharacterTextSplitter(chunk_size=500, chunk_overlap=100, separator=".")
CharacterTextSplitter(chunk_size=4000, chunk_overlap=200, separator=".")
)
splitter.run(text="Hello World. Life is beautiful.")
Expand Down
8 changes: 3 additions & 5 deletions examples/pipeline/kg_builder_from_pdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from typing import Any, Dict, List

import neo4j
from langchain_text_splitters import CharacterTextSplitter
from neo4j_graphrag.experimental.components.entity_relation_extractor import (
LLMEntityRelationExtractor,
OnError,
Expand All @@ -31,8 +30,8 @@
SchemaEntity,
SchemaRelation,
)
from neo4j_graphrag.experimental.components.text_splitters.langchain import (
LangChainTextSplitterAdapter,
from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import (
FixedSizeSplitter,
)
from neo4j_graphrag.experimental.pipeline import Component, DataModel
from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult
Expand Down Expand Up @@ -142,8 +141,7 @@ async def main(neo4j_driver: neo4j.Driver) -> PipelineResult:
pipe = Pipeline()
pipe.add_component(PdfLoader(), "pdf_loader")
pipe.add_component(
LangChainTextSplitterAdapter(CharacterTextSplitter(separator=". \n")),
"splitter",
FixedSizeSplitter(chunk_size=4000, chunk_overlap=200), "splitter"
)
pipe.add_component(SchemaBuilder(), "schema")
pipe.add_component(
Expand Down
13 changes: 5 additions & 8 deletions examples/pipeline/kg_builder_from_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import logging.config

import neo4j
from langchain_text_splitters import CharacterTextSplitter
from neo4j_graphrag.embeddings.openai import OpenAIEmbeddings
from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder
from neo4j_graphrag.experimental.components.entity_relation_extractor import (
Expand All @@ -32,8 +31,8 @@
SchemaProperty,
SchemaRelation,
)
from neo4j_graphrag.experimental.components.text_splitters.langchain import (
LangChainTextSplitterAdapter,
from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import (
FixedSizeSplitter,
)
from neo4j_graphrag.experimental.pipeline import Pipeline
from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult
Expand Down Expand Up @@ -63,7 +62,7 @@
async def main(neo4j_driver: neo4j.Driver) -> PipelineResult:
"""This is where we define and run the KG builder pipeline, instantiating a few
components:
- Text Splitter: in this example we use a text splitter from the LangChain package
- Text Splitter: in this example we use the fixed size text splitter
- Schema Builder: this component takes a list of entities, relationships and
possible triplets as inputs, validates them and returns a schema ready to use
for the rest of the pipeline
Expand All @@ -76,10 +75,8 @@ async def main(neo4j_driver: neo4j.Driver) -> PipelineResult:
pipe = Pipeline()
# define the components
pipe.add_component(
LangChainTextSplitterAdapter(
# chunk_size=50 for the sake of this demo
CharacterTextSplitter(chunk_size=50, chunk_overlap=10, separator=".")
),
# split the text into chunks of 4000 characters with a 200 character overlap
FixedSizeSplitter(chunk_size=4000, chunk_overlap=200),
"splitter",
)
pipe.add_component(TextChunkEmbedder(embedder=OpenAIEmbeddings()), "chunk_embedder")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,12 @@ class TextSplitter(Component):

@abstractmethod
async def run(self, text: str) -> TextChunks:
    """Split a piece of text into chunks.

    Args:
        text (str): The text to be split.

    Returns:
        TextChunks: A list of chunks.
    """
    ...
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [https://neo4j.com]
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# https://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pydantic import validate_call

from neo4j_graphrag.experimental.components.text_splitters.base import TextSplitter
from neo4j_graphrag.experimental.components.types import TextChunk, TextChunks


class FixedSizeSplitter(TextSplitter):
    """Text splitter which splits the input text into fixed size chunks with optional overlap.

    Args:
        chunk_size (int): The number of characters in each chunk. Must be strictly positive.
        chunk_overlap (int): The number of characters from the previous chunk to overlap with each chunk.
            Must be non-negative and strictly less than `chunk_size`.

    Raises:
        ValueError: If `chunk_size` is not strictly positive, if `chunk_overlap` is negative,
            or if `chunk_overlap` is not strictly less than `chunk_size`.

    Example:

    .. code-block:: python

        from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import FixedSizeSplitter
        from neo4j_graphrag.experimental.pipeline import Pipeline

        pipeline = Pipeline()
        text_splitter = FixedSizeSplitter(chunk_size=4000, chunk_overlap=200)
        pipeline.add_component(text_splitter, "text_splitter")
    """

    @validate_call
    def __init__(self, chunk_size: int = 4000, chunk_overlap: int = 200) -> None:
        if chunk_size <= 0:
            raise ValueError("chunk_size must be strictly greater than 0")
        if chunk_overlap < 0:
            # A negative overlap would make the step larger than the chunk size
            # and silently drop characters between consecutive chunks.
            raise ValueError("chunk_overlap must be greater than or equal to 0")
        if chunk_overlap >= chunk_size:
            raise ValueError("chunk_overlap must be strictly less than chunk_size")
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

    @validate_call
    async def run(self, text: str) -> TextChunks:
        """Splits a piece of text into chunks.

        Consecutive chunks start `chunk_size - chunk_overlap` characters apart.
        The last chunk may be shorter than `chunk_size`. Splitting stops as soon
        as a chunk reaches the end of the text, so no trailing chunk made only of
        already-covered overlap characters is produced.

        Args:
            text (str): The text to be split.

        Returns:
            TextChunks: A list of chunks. Empty if `text` is empty.
        """
        # The constructor guarantees 0 <= chunk_overlap < chunk_size, so step >= 1.
        step = self.chunk_size - self.chunk_overlap
        text_length = len(text)
        chunks = []
        for index, start in enumerate(range(0, text_length, step)):
            end = start + self.chunk_size
            # Slicing clamps `end` to the text length, so the last chunk may be shorter.
            chunks.append(TextChunk(text=text[start:end], index=index))
            if end >= text_length:
                # This chunk already reaches the end of the text; another window
                # would only repeat characters that are already covered.
                break
        return TextChunks(chunks=chunks)
1 change: 1 addition & 0 deletions src/neo4j_graphrag/experimental/components/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class TextChunk(BaseModel):
Attributes:
text (str): The raw chunk text.
index (int): The position of this chunk in the original document.
metadata (Optional[dict[str, Any]]): Metadata associated with this chunk such as the id of the next chunk in the original document.
"""

Expand Down
13 changes: 5 additions & 8 deletions tests/e2e/test_kg_builder_pipeline_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

import neo4j
import pytest
from langchain_text_splitters import CharacterTextSplitter
from neo4j_graphrag.embeddings.base import Embedder
from neo4j_graphrag.exceptions import LLMGenerationError
from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder
Expand All @@ -35,8 +34,8 @@
SchemaProperty,
SchemaRelation,
)
from neo4j_graphrag.experimental.components.text_splitters.langchain import (
LangChainTextSplitterAdapter,
from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import (
FixedSizeSplitter,
)
from neo4j_graphrag.experimental.pipeline import Pipeline
from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult
Expand All @@ -63,10 +62,8 @@ def schema_builder() -> SchemaBuilder:


@pytest.fixture
def text_splitter() -> LangChainTextSplitterAdapter:
return LangChainTextSplitterAdapter(
CharacterTextSplitter(chunk_size=50, chunk_overlap=10, separator="\n\n")
)
def text_splitter() -> FixedSizeSplitter:
return FixedSizeSplitter(chunk_size=500, chunk_overlap=100)


@pytest.fixture
Expand All @@ -89,7 +86,7 @@ def kg_writer(driver: neo4j.Driver) -> Neo4jWriter:

@pytest.fixture
def kg_builder_pipeline(
text_splitter: LangChainTextSplitterAdapter,
text_splitter: FixedSizeSplitter,
chunk_embedder: TextChunkEmbedder,
schema_builder: SchemaBuilder,
entity_relation_extractor: LLMEntityRelationExtractor,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [https://neo4j.com]
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# https://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import (
FixedSizeSplitter,
)
from neo4j_graphrag.experimental.components.types import TextChunk


@pytest.mark.asyncio
async def test_split_text_no_overlap() -> None:
    """With no overlap the text is cut into consecutive 5-character chunks."""
    splitter = FixedSizeSplitter(chunk_size=5, chunk_overlap=0)
    result = await splitter.run("may thy knife chip and shatter")
    expected_texts = ["may t", "hy kn", "ife c", "hip a", "nd sh", "atter"]
    assert result.chunks == [
        TextChunk(text=piece, index=position)
        for position, piece in enumerate(expected_texts)
    ]


@pytest.mark.asyncio
async def test_split_text_with_overlap() -> None:
    """Each chunk repeats the last 2 characters of the chunk before it."""
    splitter = FixedSizeSplitter(chunk_size=10, chunk_overlap=2)
    result = await splitter.run("may thy knife chip and shatter")
    expected_texts = ["may thy kn", "knife chip", "ip and sha", "hatter"]
    assert result.chunks == [
        TextChunk(text=piece, index=position)
        for position, piece in enumerate(expected_texts)
    ]


@pytest.mark.asyncio
async def test_split_text_empty_string() -> None:
    """Splitting an empty string produces no chunks."""
    splitter = FixedSizeSplitter(chunk_size=5, chunk_overlap=1)
    result = await splitter.run("")
    assert result.chunks == []


def test_invalid_chunk_overlap() -> None:
    """An overlap equal to the chunk size is rejected with a ValueError."""
    # `match` checks the exception message itself instead of the string form
    # of the ExceptionInfo wrapper.
    with pytest.raises(
        ValueError, match="chunk_overlap must be strictly less than chunk_size"
    ):
        FixedSizeSplitter(5, 5)

0 comments on commit 6851a0b

Please sign in to comment.