Skip to content

Commit

Permalink
Add flexibility for lexical graph config to SimpleKGPipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
willtai committed Oct 31, 2024
1 parent 508323a commit 308084f
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ class LLMEntityRelationExtractor(EntityRelationExtractor):
Args:
llm (LLMInterface): The language model to use for extraction.
prompt_template (ERExtractionTemplate | str): A custom prompt template to use for extraction.
lexical_graph_config (Optional[LexicalGraphConfig], optional): Lexical graph configuration to customize node labels and relationship types in the lexical graph.
create_lexical_graph (bool): Whether to include the text chunks in the graph in addition to the extracted entities and relations. Defaults to True.
on_error (OnError): What to do when an error occurs during extraction. Defaults to raising an error.
max_concurrency (int): The maximum number of concurrent tasks which can be used to make requests to the LLM.
Expand All @@ -201,6 +202,7 @@ def __init__(
self,
llm: LLMInterface,
prompt_template: ERExtractionTemplate | str = ERExtractionTemplate(),
lexical_graph_config: Optional[LexicalGraphConfig] = None,
create_lexical_graph: bool = True,
on_error: OnError = OnError.RAISE,
max_concurrency: int = 5,
Expand All @@ -213,6 +215,7 @@ def __init__(
else:
template = prompt_template
self.prompt_template = template
self.lexical_graph_config = lexical_graph_config

async def extract_for_chunk(
self, schema: SchemaConfig, examples: str, chunk: TextChunk
Expand Down Expand Up @@ -334,7 +337,11 @@ async def run(
lexical_graph_builder = None
lexical_graph = None
if self.create_lexical_graph:
config = lexical_graph_config or LexicalGraphConfig()
config = (
lexical_graph_config
or self.lexical_graph_config
or LexicalGraphConfig()
)
lexical_graph_builder = LexicalGraphBuilder(config=config)
lexical_graph_result = await lexical_graph_builder.run(
text_chunks=chunks, document_info=document_info
Expand Down
7 changes: 7 additions & 0 deletions src/neo4j_graphrag/experimental/pipeline/kg_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import (
FixedSizeSplitter,
)
from neo4j_graphrag.experimental.components.types import LexicalGraphConfig
from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError
from neo4j_graphrag.experimental.pipeline.pipeline import Pipeline, PipelineResult
from neo4j_graphrag.generation.prompts import ERExtractionTemplate
Expand All @@ -59,6 +60,7 @@ class SimpleKGPipelineConfig(BaseModel):
on_error: OnError = OnError.RAISE
prompt_template: Union[ERExtractionTemplate, str] = ERExtractionTemplate()
perform_entity_resolution: bool = True
lexical_graph_config: Optional[LexicalGraphConfig] = None

model_config = ConfigDict(arbitrary_types_allowed=True)

Expand All @@ -84,6 +86,7 @@ class SimpleKGPipeline:
on_error (str): Error handling strategy for the Entity and relation extractor. Defaults to "IGNORE", where chunk will be ignored if extraction fails. Possible values: "RAISE" or "IGNORE".
perform_entity_resolution (bool): Merge entities with same label and name. Default: True
prompt_template (Union[ERExtractionTemplate, str]): A custom prompt template to use for extraction.
lexical_graph_config (Optional[LexicalGraphConfig], optional): Lexical graph configuration to customize node labels and relationship types in the lexical graph.
"""

def __init__(
Expand All @@ -101,6 +104,7 @@ def __init__(
on_error: str = "IGNORE",
prompt_template: Union[ERExtractionTemplate, str] = ERExtractionTemplate(),
perform_entity_resolution: bool = True,
lexical_graph_config: Optional[LexicalGraphConfig] = None,
):
self.entities = [SchemaEntity(label=label) for label in entities or []]
self.relations = [SchemaRelation(label=label) for label in relations or []]
Expand All @@ -127,6 +131,7 @@ def __init__(
prompt_template=prompt_template,
embedder=embedder,
perform_entity_resolution=perform_entity_resolution,
lexical_graph_config=lexical_graph_config,
)

self.from_pdf = config.from_pdf
Expand All @@ -141,6 +146,7 @@ def __init__(
)
self.prompt_template = config.prompt_template
self.perform_entity_resolution = config.perform_entity_resolution
self.lexical_graph_config = config.lexical_graph_config

self.pipeline = self._build_pipeline()

Expand All @@ -154,6 +160,7 @@ def _build_pipeline(self) -> Pipeline:
llm=self.llm,
on_error=self.on_error,
prompt_template=self.prompt_template,
lexical_graph_config=self.lexical_graph_config,
),
"extractor",
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
)
from neo4j_graphrag.experimental.components.pdf_loader import DocumentInfo
from neo4j_graphrag.experimental.components.types import (
LexicalGraphConfig,
Neo4jGraph,
TextChunk,
TextChunks,
Expand All @@ -51,6 +52,28 @@ async def test_extractor_happy_path_no_entities_no_document() -> None:
assert result.relationships == []


@pytest.mark.asyncio
async def test_extractor_happy_path_with_lexical_graph_config() -> None:
    """Labels from a constructor-supplied LexicalGraphConfig appear on the result nodes."""
    custom_config = LexicalGraphConfig(
        document_node_label="testDocumentNode",
        chunk_node_label="testChunkNode",
    )
    mocked_llm = MagicMock(spec=LLMInterface)
    # The mocked LLM reports no entities or relations, so the only nodes that
    # can show up in the result come from the lexical graph.
    mocked_llm.ainvoke.return_value = LLMResponse(
        content='{"nodes": [], "relationships": []}'
    )
    extractor = LLMEntityRelationExtractor(
        llm=mocked_llm,
        lexical_graph_config=custom_config,
    )

    result = await extractor.run(
        chunks=TextChunks(chunks=[TextChunk(text="some text", index=0)]),
        document_info=DocumentInfo(path="path"),
    )

    assert isinstance(result, Neo4jGraph)
    # Exactly one chunk node plus one document node.
    assert len(result.nodes) == 2
    assert {node.label for node in result.nodes} == {
        "testDocumentNode",
        "testChunkNode",
    }


@pytest.mark.asyncio
async def test_extractor_happy_path_no_entities() -> None:
llm = MagicMock(spec=LLMInterface)
Expand Down
23 changes: 23 additions & 0 deletions tests/unit/experimental/pipeline/test_kg_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from neo4j_graphrag.embeddings import Embedder
from neo4j_graphrag.experimental.components.entity_relation_extractor import OnError
from neo4j_graphrag.experimental.components.schema import SchemaEntity, SchemaRelation
from neo4j_graphrag.experimental.components.types import LexicalGraphConfig
from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError
from neo4j_graphrag.experimental.pipeline.kg_builder import SimpleKGPipeline
from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult
Expand Down Expand Up @@ -316,3 +317,25 @@ def test_simple_kg_pipeline_no_entity_resolution(_: Mock) -> None:
)

assert "resolver" not in kg_builder.pipeline


@mock.patch(
    "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._get_version",
    return_value=(5, 23, 0),
)
def test_simple_kg_pipeline_lexical_graph_config_attribute(_: Mock) -> None:
    """SimpleKGPipeline exposes the lexical graph config it was constructed with.

    ``Neo4jWriter._get_version`` is patched to return a fixed version tuple for
    the duration of the test.

    The test is synchronous and awaits nothing, so it carries no
    ``pytest.mark.asyncio`` marker.
    """
    llm = MagicMock(spec=LLMInterface)
    driver = MagicMock(spec=neo4j.Driver)
    embedder = MagicMock(spec=Embedder)

    lexical_graph_config = LexicalGraphConfig()
    kg_builder = SimpleKGPipeline(
        llm=llm,
        driver=driver,
        embedder=embedder,
        on_error="IGNORE",
        lexical_graph_config=lexical_graph_config,
    )

    # The config passed in must be the one stored on the pipeline.
    assert kg_builder.lexical_graph_config == lexical_graph_config

0 comments on commit 308084f

Please sign in to comment.