From ecda1741c7dfc727f726eb8f7130efb8ae42c270 Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 3 Dec 2024 22:45:06 +0100 Subject: [PATCH] Implement SimpleKGBuilder with this setup --- .../pipeline/simple_kg_pipeline_config.json | 105 +++++++--- .../simple_kg_pipeline_from_config_file.py | 6 +- .../experimental/components/schema.py | 8 +- .../pipeline/config/config_poc.py | 194 +++++++++++++++--- 4 files changed, 254 insertions(+), 59 deletions(-) diff --git a/examples/customize/build_graph/pipeline/simple_kg_pipeline_config.json b/examples/customize/build_graph/pipeline/simple_kg_pipeline_config.json index fffcc895..ef251624 100644 --- a/examples/customize/build_graph/pipeline/simple_kg_pipeline_config.json +++ b/examples/customize/build_graph/pipeline/simple_kg_pipeline_config.json @@ -1,36 +1,38 @@ { "version_": "1", + "template_": "SimpleKGPipeline", "neo4j_config": { - "uri": { - "resolver_": "ENV", - "var_": "NEO4J_URI" - }, - "user": { - "resolver_": "ENV", - "var_": "NEO4J_USER" - }, - "password": { - "resolver_": "ENV", - "var_": "NEO4J_PASSWORD" - }, - "database": { - "resolver_": "ENV", - "var_": "NEO4J_DATABASE" + "params_": { + "uri": { + "resolver_": "ENV", + "var_": "NEO4J_URI" + }, + "user": { + "resolver_": "ENV", + "var_": "NEO4J_USER" + }, + "password": { + "resolver_": "ENV", + "var_": "NEO4J_PASSWORD" + } } }, "llm_config": { - "name_": "openai", "class_": "OpenAILLM", "params_": { "api_key": { "resolver_": "ENV", "var_": "OPENAI_API_KEY" }, - "model_name": "gpt-4o" + "model_name": "gpt-4o", + "model_params": { + "temperature": 0, + "max_tokens": 2000, + "response_format": {"type": "json_object"} + } } }, "embedder_config": { - "name_": "openai", "class_": "OpenAIEmbeddings", "params_": { "api_key": { @@ -40,18 +42,71 @@ } }, "from_pdf": false, - "entities": ["Person", {"label": "Organization"}], - "relations": ["WORKS_FOR", {"label": "DIRECTED_BY"}], + "entities": [ + "Person", + { + "label": "House", + "description": "Family the person belongs to", + "properties": [ + { + "name": "name", + "type": "STRING" + } + ] + }, + { + "label": "Planet", + "properties": [ + { + "name": "name", + "type": "STRING" + }, + { + "name": "weather", + "type": "STRING" + } + ] + } + ], + "relations": [ + "PARENT_OF", + { + "label": "HEIR_OF", + "description": "Used for inheritor relationship between father and sons" + }, + { + "label": "RULES", + "properties": [ + { + "name": "fromYear", + "type": "INTEGER" + } + ] + } + ], "potential_schema": [ - ["Person", "WORKS_FOR", "Organization"], - ["Organization", "DIRECTED_BY", "Person"] + [ + "Person", + "PARENT_OF", + "Person" + ], + [ + "Person", + "HEIR_OF", + "House" + ], + [ + "House", + "RULES", + "Planet" + ] ], "text_splitter": { - "class_": "fixed_size_splitter.FixedSizeSplitter", + "class_": "text_splitters.fixed_size_splitter.FixedSizeSplitter", "params_": { - "chunk_size": 100, + "chunk_size": 100, "chunk_overlap": 10 } }, - "perform_entity_resolution": false + "perform_entity_resolution": true } diff --git a/examples/customize/build_graph/pipeline/simple_kg_pipeline_from_config_file.py b/examples/customize/build_graph/pipeline/simple_kg_pipeline_from_config_file.py index fedd5e86..462c2a11 100644 --- a/examples/customize/build_graph/pipeline/simple_kg_pipeline_from_config_file.py +++ b/examples/customize/build_graph/pipeline/simple_kg_pipeline_from_config_file.py @@ -12,7 +12,7 @@ # env vars manually set for testing: import os -from neo4j_graphrag.experimental.pipeline.config.parser import SimpleKGPipelineBuilder +from neo4j_graphrag.experimental.pipeline.config.config_poc import PipelineRunner from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult os.environ["NEO4J_URI"] = "bolt://localhost:7687" @@ -28,8 +28,8 @@ async def main() -> PipelineResult: file_path = "examples/customize/build_graph/pipeline/simple_kg_pipeline_config.json" - pipeline = SimpleKGPipelineBuilder.from_config_file(file_path) - return await pipeline.run_async(text=TEXT) + pipeline = PipelineRunner.from_config_file(file_path) + return await pipeline.run({"text":TEXT}) if __name__ == "__main__": diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py index 439c60c6..cb0bab8d 100644 --- a/src/neo4j_graphrag/experimental/components/schema.py +++ b/src/neo4j_graphrag/experimental/components/schema.py @@ -57,8 +57,10 @@ class SchemaEntity(BaseModel): @classmethod def from_text_or_dict( - cls, input: str | dict[str, Union[str, dict[str, str]]] + cls, input: SchemaEntity | str | dict[str, Union[str, dict[str, str]]] ) -> Self: + if isinstance(input, SchemaEntity): + return input if isinstance(input, str): return cls(label=input) return cls.model_validate(input) @@ -75,8 +77,10 @@ class SchemaRelation(BaseModel): @classmethod def from_text_or_dict( - cls, input: str | dict[str, Union[str, dict[str, str]]] + cls, input: SchemaRelation | str | dict[str, Union[str, dict[str, str]]] ) -> Self: + if isinstance(input, SchemaRelation): + return input if isinstance(input, str): return cls(label=input) return cls.model_validate(input) diff --git a/src/neo4j_graphrag/experimental/pipeline/config/config_poc.py b/src/neo4j_graphrag/experimental/pipeline/config/config_poc.py index 50443749..7214e7e7 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/config_poc.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/config_poc.py @@ -33,7 +33,16 @@ from pydantic.v1.utils import deep_update from neo4j_graphrag.embeddings import Embedder +from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder +from neo4j_graphrag.experimental.components.entity_relation_extractor import OnError, EntityRelationExtractor, \ + LLMEntityRelationExtractor +from neo4j_graphrag.experimental.components.kg_writer import KGWriter, Neo4jWriter from neo4j_graphrag.experimental.components.pdf_loader import PdfLoader +from neo4j_graphrag.experimental.components.resolver import EntityResolver, SinglePropertyExactMatchResolver +from neo4j_graphrag.experimental.components.schema import SchemaBuilder, SchemaEntity, SchemaRelation +from neo4j_graphrag.experimental.components.text_splitters.base import TextSplitter +from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import FixedSizeSplitter +from neo4j_graphrag.experimental.components.types import LexicalGraphConfig from neo4j_graphrag.experimental.pipeline import Component, Pipeline from neo4j_graphrag.experimental.pipeline.config.param_resolvers import PARAM_RESOLVERS from neo4j_graphrag.experimental.pipeline.config.reader import ConfigReader @@ -45,8 +54,9 @@ from neo4j_graphrag.experimental.pipeline.types import ( ComponentDefinition, ConnectionDefinition, - PipelineDefinition, + PipelineDefinition, EntityInputType, RelationInputType, ) +from neo4j_graphrag.generation.prompts import ERExtractionTemplate from neo4j_graphrag.llm import LLMInterface @@ -409,7 +419,28 @@ def parse(self, resolved_data: dict[str, Any] | None = None) -> PipelineDefiniti ) def get_run_params(self, user_input: dict[str, Any]) -> dict[str, Any]: - return {} + return user_input + + def get_neo4j_driver_by_name(self, name: str) -> neo4j.Driver: + drivers = self._global_data.get("neo4j_config", {}) + return drivers.get(name) + + def get_default_neo4j_driver(self) -> neo4j.Driver: + return self.get_neo4j_driver_by_name(self.DEFAULT_NAME) + + def get_llm_by_name(self, name: str) -> LLMInterface: + llms = self._global_data.get("llm_config", {}) + return llms.get(name) + + def get_default_llm(self) -> LLMInterface: + return self.get_llm_by_name(self.DEFAULT_NAME) + + def get_embedder_by_name(self, name: str) -> Embedder: + embedders = self._global_data.get("embedder_config", {}) + return embedders.get(name) + + def get_default_embedder(self) -> Embedder: + return self.get_embedder_by_name(self.DEFAULT_NAME) class PipelineConfig(AbstractPipelineConfig): @@ -465,42 +496,149 @@ def get_run_params(self, user_input: dict[str, Any]) -> dict[str, Any]: class SimpleKGPipelineConfig(TemplatePipelineConfig): COMPONENTS: ClassVar[list[str]] = [ "pdf_loader", - # "splitter", - # "chunk_embedder", - # "extractor", - # "writer", - # "entity_resolver", + "splitter", + "chunk_embedder", + "schema_builder", + "extractor", + "writer", + "resolver", ] template_: Literal[PipelineType.SIMPLE_KG_PIPELINE] = ( PipelineType.SIMPLE_KG_PIPELINE ) + from_pdf: bool = False + entities: list[EntityInputType] = [] + relations: list[RelationInputType] = [] potential_schema: Optional[list[tuple[str, str, str]]] = None - # on_error: OnError = OnError.IGNORE - # prompt_template: Union[ERExtractionTemplate, str] = ERExtractionTemplate() + on_error: OnError = OnError.IGNORE + prompt_template: Union[ERExtractionTemplate, str] = ERExtractionTemplate() perform_entity_resolution: bool = True - # lexical_graph_config: Optional[LexicalGraphConfig] = None + lexical_graph_config: Optional[LexicalGraphConfig] = None neo4j_database: Optional[str] = None pdf_loader: ComponentConfig | None = None kg_writer: ComponentConfig | None = None text_splitter: ComponentConfig | None = None - # entities: list[SchemaEntity] = [] - # relations: list[SchemaRelation] = [] - def _get_pdf_loader(self) -> Component | None: + model_config = ConfigDict(arbitrary_types_allowed=True) + + def _get_pdf_loader(self) -> PdfLoader | None: if not self.from_pdf: return None if self.pdf_loader: - return self._resolve_component(self.pdf_loader) + return self.pdf_loader.parse(self._global_data) # type: ignore return PdfLoader() + def _get_splitter(self) -> TextSplitter: + if self.text_splitter: + return self.text_splitter.parse(self._global_data) # type: ignore + return FixedSizeSplitter() + + def _get_chunk_embedder(self) -> TextChunkEmbedder: + return TextChunkEmbedder(embedder=self.get_default_embedder()) + + def _get_schema_builder(self) -> SchemaBuilder: + return SchemaBuilder() + + def _get_run_params_for_schema_builder(self) -> dict[str, Any]: + return { + "entities": [SchemaEntity.from_text_or_dict(e) for e in self.entities], + "relations": [SchemaRelation.from_text_or_dict(r) for r in self.relations], + "potential_schema": self.potential_schema, + } + + def _get_extractor(self) -> EntityRelationExtractor: + return LLMEntityRelationExtractor( + llm=self.get_default_llm(), + prompt_template=self.prompt_template, + on_error=self.on_error, + ) + + def _get_writer(self) -> KGWriter: + if self.kg_writer: + return self.kg_writer.parse(self._global_data) # type: ignore + return Neo4jWriter( + driver=self.get_default_neo4j_driver() + ) + + def _get_resolver(self) -> EntityResolver | None: + if not self.perform_entity_resolution: + return None + return SinglePropertyExactMatchResolver( + driver=self.get_default_neo4j_driver(), + ) + + def _get_connections(self) -> list[ConnectionDefinition]: + connections = [] + if self.from_pdf: + connections.append(ConnectionDefinition( + start="pdf_loader", + end="splitter", + input_config={"text": "pdf_loader.text"}, + )) + connections.append(ConnectionDefinition( + start="schema_builder", + end="extractor", + input_config={ + "schema": "schema_builder", + "document_info": "pdf_loader.document_info", + }, + )) + else: + connections.append(ConnectionDefinition( + start="schema_builder", + end="extractor", + input_config={ + "schema": "schema_builder", + }, + )) + connections.append(ConnectionDefinition( + start="splitter", + end="chunk_embedder", + input_config={ + "text_chunks": "splitter", + }, + )) + connections.append(ConnectionDefinition( + start="chunk_embedder", + end="extractor", + input_config={ + "chunks": "chunk_embedder", + }, + )) + connections.append(ConnectionDefinition( + start="extractor", + end="writer", + input_config={ + "graph": "extractor", + }, + )) + + if self.perform_entity_resolution: + connections.append(ConnectionDefinition( + start="writer", + end="resolver", + input_config={}, + )) + + return connections + def get_run_params(self, user_input: dict[str, Any]) -> dict[str, Any]: - return {} + run_params = {} + if self.lexical_graph_config: + run_params["extractor"] = { + "lexical_graph_config": self.lexical_graph_config + } + if self.from_pdf: + run_params["pdf_loader"] = {"filepath": user_input["filepath"]} + else: + run_params["splitter"] = {"text": user_input["text"]} + return run_params -def get_discriminator_value(model: Any) -> PipelineType: +def _get_discriminator_value(model: Any) -> PipelineType: template_ = None if "template_" in model: template_ = model["template_"] @@ -515,13 +653,13 @@ class PipelineConfigWrapper(BaseModel): config: Union[ Annotated[PipelineConfig, Tag(PipelineType.NONE)], Annotated[SimpleKGPipelineConfig, Tag(PipelineType.SIMPLE_KG_PIPELINE)], - ] = Field(discriminator=Discriminator(get_discriminator_value)) + ] = Field(discriminator=Discriminator(_get_discriminator_value)) def parse(self, resolved_data: dict[str, Any] | None = None) -> PipelineDefinition: return self.config.parse(resolved_data) def get_run_params(self, user_input: dict[str, Any]) -> dict[str, Any]: - return self.config.get_run_params({}) + return self.config.get_run_params(user_input) class PipelineRunner: @@ -532,31 +670,29 @@ class PipelineRunner: - A PipelineConfig (`from_config` method) - A config file (`from_config_file` method) """ - def __init__(self, pipeline_definition: PipelineDefinition) -> None: + def __init__(self, pipeline_definition: PipelineDefinition, config: PipelineConfigWrapper | None = None) -> None: + self.config = config self.pipeline = Pipeline.from_definition(pipeline_definition) self.run_params = pipeline_definition.get_run_params() @classmethod - def from_config(cls, config: AbstractPipelineConfig) -> Self: + def from_config(cls, config: AbstractPipelineConfig | dict[str, Any]) -> Self: wrapper = PipelineConfigWrapper.model_validate({"config": config}) - return cls(wrapper.parse()) + return cls(wrapper.parse(), config=wrapper) @classmethod def from_config_file(cls, file_path: Union[str, Path]) -> Self: - pipeline_definition = cls._parse(file_path) - return cls(pipeline_definition) - - @classmethod - def _parse(cls, file_path: Union[str, Path]) -> PipelineDefinition: if not isinstance(file_path, Path): file_path = Path(file_path) data = ConfigReader().read(file_path) - wrapper = PipelineConfigWrapper.model_validate({"config": data}) - return wrapper.parse() + return cls.from_config(data) async def run(self, data: dict[str, Any]) -> PipelineResult: # pipeline_conditional_run_params = self. - run_param = deep_update(self.run_params, data) + if self.config: + run_param = deep_update(self.run_params, self.config.get_run_params(data)) + else: + run_param = deep_update(self.run_params, data) return await self.pipeline.run(data=run_param)