From 308084fda3a921c465e3f331e2ec21456c7d2a30 Mon Sep 17 00:00:00 2001 From: Will Tai Date: Thu, 31 Oct 2024 16:41:13 +0000 Subject: [PATCH] Add flexibility for lexical graph config to SimpleKGPipeline --- .../components/entity_relation_extractor.py | 9 +++++++- .../experimental/pipeline/kg_builder.py | 7 ++++++ .../test_entity_relation_extractor.py | 23 +++++++++++++++++++ .../experimental/pipeline/test_kg_builder.py | 23 +++++++++++++++++++ 4 files changed, 61 insertions(+), 1 deletion(-) diff --git a/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py b/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py index 1fe49465..8849837a 100644 --- a/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py +++ b/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py @@ -177,6 +177,7 @@ class LLMEntityRelationExtractor(EntityRelationExtractor): Args: llm (LLMInterface): The language model to use for extraction. prompt_template (ERExtractionTemplate | str): A custom prompt template to use for extraction. + lexical_graph_config (Optional[LexicalGraphConfig], optional): Lexical graph configuration to customize node labels and relationship types in the lexical graph. create_lexical_graph (bool): Whether to include the text chunks in the graph in addition to the extracted entities and relations. Defaults to True. on_error (OnError): What to do when an error occurs during extraction. Defaults to raising an error. max_concurrency (int): The maximum number of concurrent tasks which can be used to make requests to the LLM. @@ -201,6 +202,7 @@ def __init__( self, llm: LLMInterface, prompt_template: ERExtractionTemplate | str = ERExtractionTemplate(), + lexical_graph_config: Optional[LexicalGraphConfig] = None, create_lexical_graph: bool = True, on_error: OnError = OnError.RAISE, max_concurrency: int = 5, @@ -213,6 +215,7 @@ def __init__( else: template = prompt_template self.prompt_template = template + self.lexical_graph_config = lexical_graph_config async def extract_for_chunk( self, schema: SchemaConfig, examples: str, chunk: TextChunk @@ -334,7 +337,11 @@ async def run( lexical_graph_builder = None lexical_graph = None if self.create_lexical_graph: - config = lexical_graph_config or LexicalGraphConfig() + config = ( + lexical_graph_config + or self.lexical_graph_config + or LexicalGraphConfig() + ) lexical_graph_builder = LexicalGraphBuilder(config=config) lexical_graph_result = await lexical_graph_builder.run( text_chunks=chunks, document_info=document_info diff --git a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py index 9f7ff488..207952cc 100644 --- a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py @@ -39,6 +39,7 @@ from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import ( FixedSizeSplitter, ) +from neo4j_graphrag.experimental.components.types import LexicalGraphConfig from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError from neo4j_graphrag.experimental.pipeline.pipeline import Pipeline, PipelineResult from neo4j_graphrag.generation.prompts import ERExtractionTemplate @@ -59,6 +60,7 @@ class SimpleKGPipelineConfig(BaseModel): on_error: OnError = OnError.RAISE prompt_template: Union[ERExtractionTemplate, str] = ERExtractionTemplate() perform_entity_resolution: bool = True + lexical_graph_config: Optional[LexicalGraphConfig] = None model_config = ConfigDict(arbitrary_types_allowed=True) @@ -84,6 +86,7 @@ class SimpleKGPipeline: on_error (str): Error handling strategy for the Entity and relation extractor. Defaults to "IGNORE", where chunk will be ignored if extraction fails. Possible values: "RAISE" or "IGNORE". perform_entity_resolution (bool): Merge entities with same label and name. Default: True prompt_template (str): A custom prompt template to use for extraction. + lexical_graph_config (Optional[LexicalGraphConfig], optional): Lexical graph configuration to customize node labels and relationship types in the lexical graph. """ def __init__( @@ -101,6 +104,7 @@ def __init__( on_error: str = "IGNORE", prompt_template: Union[ERExtractionTemplate, str] = ERExtractionTemplate(), perform_entity_resolution: bool = True, + lexical_graph_config: Optional[LexicalGraphConfig] = None, ): self.entities = [SchemaEntity(label=label) for label in entities or []] self.relations = [SchemaRelation(label=label) for label in relations or []] @@ -127,6 +131,7 @@ def __init__( prompt_template=prompt_template, embedder=embedder, perform_entity_resolution=perform_entity_resolution, + lexical_graph_config=lexical_graph_config, ) self.from_pdf = config.from_pdf @@ -141,6 +146,7 @@ def __init__( ) self.prompt_template = config.prompt_template self.perform_entity_resolution = config.perform_entity_resolution + self.lexical_graph_config = config.lexical_graph_config self.pipeline = self._build_pipeline() @@ -154,6 +160,7 @@ def _build_pipeline(self) -> Pipeline: llm=self.llm, on_error=self.on_error, prompt_template=self.prompt_template, + lexical_graph_config=self.lexical_graph_config, ), "extractor", ) diff --git a/tests/unit/experimental/components/test_entity_relation_extractor.py b/tests/unit/experimental/components/test_entity_relation_extractor.py index 1ec56153..861e92f4 100644 --- a/tests/unit/experimental/components/test_entity_relation_extractor.py +++ b/tests/unit/experimental/components/test_entity_relation_extractor.py @@ -27,6 +27,7 @@ ) from neo4j_graphrag.experimental.components.pdf_loader import DocumentInfo from neo4j_graphrag.experimental.components.types import ( + LexicalGraphConfig, Neo4jGraph, TextChunk, TextChunks, @@ -51,6 +52,28 @@ async def test_extractor_happy_path_no_entities_no_document() -> None: assert result.relationships == [] +@pytest.mark.asyncio +async def test_extractor_happy_path_with_lexical_graph_config() -> None: + llm = MagicMock(spec=LLMInterface) + llm.ainvoke.return_value = LLMResponse(content='{"nodes": [], "relationships": []}') + + extractor = LLMEntityRelationExtractor( + llm=llm, + lexical_graph_config=LexicalGraphConfig( + document_node_label="testDocumentNode", + chunk_node_label="testChunkNode", + ), + ) + chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) + document_info = DocumentInfo(path="path") + result = await extractor.run(chunks=chunks, document_info=document_info) + + assert isinstance(result, Neo4jGraph) + # one Chunk node and one Document node + assert len(result.nodes) == 2 + assert set(n.label for n in result.nodes) == {"testDocumentNode", "testChunkNode"} + + @pytest.mark.asyncio async def test_extractor_happy_path_no_entities() -> None: llm = MagicMock(spec=LLMInterface) diff --git a/tests/unit/experimental/pipeline/test_kg_builder.py b/tests/unit/experimental/pipeline/test_kg_builder.py index 64da47a0..47789187 100644 --- a/tests/unit/experimental/pipeline/test_kg_builder.py +++ b/tests/unit/experimental/pipeline/test_kg_builder.py @@ -20,6 +20,7 @@ from neo4j_graphrag.embeddings import Embedder from neo4j_graphrag.experimental.components.entity_relation_extractor import OnError from neo4j_graphrag.experimental.components.schema import SchemaEntity, SchemaRelation +from neo4j_graphrag.experimental.components.types import LexicalGraphConfig from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError from neo4j_graphrag.experimental.pipeline.kg_builder import SimpleKGPipeline from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult @@ -316,3 +317,25 @@ def test_simple_kg_pipeline_no_entity_resolution(_: Mock) -> None: ) assert "resolver" not in kg_builder.pipeline + + +@mock.patch( + "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._get_version", + return_value=(5, 23, 0), +) +@pytest.mark.asyncio +def test_simple_kg_pipeline_lexical_graph_config_attribute(_: Mock) -> None: + llm = MagicMock(spec=LLMInterface) + driver = MagicMock(spec=neo4j.Driver) + embedder = MagicMock(spec=Embedder) + + lexical_graph_config = LexicalGraphConfig() + kg_builder = SimpleKGPipeline( + llm=llm, + driver=driver, + embedder=embedder, + on_error="IGNORE", + lexical_graph_config=lexical_graph_config, + ) + + assert kg_builder.lexical_graph_config == lexical_graph_config