From a7d37e42a672cf1ef42dbb84b4909a9c9bdfcbc9 Mon Sep 17 00:00:00 2001 From: estelle Date: Thu, 5 Dec 2024 14:31:10 +0100 Subject: [PATCH] Restructure files + increase tests coverage --- .../pipeline/pipeline_from_config_file.py | 2 +- .../simple_kg_pipeline_from_config_file.py | 2 +- .../experimental/pipeline/config/base.py | 91 +++ .../pipeline/config/config_parser.py | 739 ------------------ .../pipeline/config/config_poc_test.py | 248 ------ .../pipeline/config/config_reader.py | 5 +- .../pipeline/config/object_config.py | 212 +++++ .../pipeline/config/pipeline_config.py | 182 +++++ .../experimental/pipeline/config/runner.py | 115 +++ .../config/template_pipeline/__init__.py | 20 + .../pipeline/config/template_pipeline/base.py | 55 ++ .../template_pipeline/simple_kg_builder.py | 224 ++++++ .../experimental/pipeline/config/types.py | 26 + .../experimental/pipeline/kg_builder.py | 4 +- tests/unit/conftest.py | 6 + .../experimental/pipeline/config/__init__.py | 0 .../config/template_pipeline/__init__.py | 0 .../config/template_pipeline/test_base.py | 0 .../test_simple_kg_builder.py | 0 .../experimental/pipeline/config/test_base.py | 67 ++ .../pipeline/config/test_object_config.py | 133 ++++ .../pipeline/config/test_param_resolver.py | 56 ++ .../pipeline/config/test_pipeline_config.py | 378 +++++++++ .../pipeline/config/test_runner.py | 0 24 files changed, 1573 insertions(+), 992 deletions(-) create mode 100644 src/neo4j_graphrag/experimental/pipeline/config/base.py delete mode 100644 src/neo4j_graphrag/experimental/pipeline/config/config_parser.py delete mode 100644 src/neo4j_graphrag/experimental/pipeline/config/config_poc_test.py create mode 100644 src/neo4j_graphrag/experimental/pipeline/config/object_config.py create mode 100644 src/neo4j_graphrag/experimental/pipeline/config/pipeline_config.py create mode 100644 src/neo4j_graphrag/experimental/pipeline/config/runner.py create mode 100644 src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/__init__.py create mode 100644 src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/base.py create mode 100644 src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py create mode 100644 src/neo4j_graphrag/experimental/pipeline/config/types.py create mode 100644 tests/unit/experimental/pipeline/config/__init__.py create mode 100644 tests/unit/experimental/pipeline/config/template_pipeline/__init__.py create mode 100644 tests/unit/experimental/pipeline/config/template_pipeline/test_base.py create mode 100644 tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py create mode 100644 tests/unit/experimental/pipeline/config/test_base.py create mode 100644 tests/unit/experimental/pipeline/config/test_object_config.py create mode 100644 tests/unit/experimental/pipeline/config/test_param_resolver.py create mode 100644 tests/unit/experimental/pipeline/config/test_pipeline_config.py create mode 100644 tests/unit/experimental/pipeline/config/test_runner.py diff --git a/examples/customize/build_graph/pipeline/pipeline_from_config_file.py b/examples/customize/build_graph/pipeline/pipeline_from_config_file.py index 917efefa..9a2a1680 100644 --- a/examples/customize/build_graph/pipeline/pipeline_from_config_file.py +++ b/examples/customize/build_graph/pipeline/pipeline_from_config_file.py @@ -13,7 +13,7 @@ import os from pathlib import Path -from neo4j_graphrag.experimental.pipeline.config.config_parser import PipelineRunner +from neo4j_graphrag.experimental.pipeline.config.runner import PipelineRunner from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult os.environ["NEO4J_URI"] = "bolt://localhost:7687" diff --git a/examples/customize/build_graph/pipeline/simple_kg_pipeline_from_config_file.py b/examples/customize/build_graph/pipeline/simple_kg_pipeline_from_config_file.py index 837ce5d7..b4f26709 100644 --- a/examples/customize/build_graph/pipeline/simple_kg_pipeline_from_config_file.py +++ b/examples/customize/build_graph/pipeline/simple_kg_pipeline_from_config_file.py @@ -15,7 +15,7 @@ import os from pathlib import Path -from neo4j_graphrag.experimental.pipeline.config.config_parser import PipelineRunner +from neo4j_graphrag.experimental.pipeline.config.runner import PipelineRunner from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult os.environ["NEO4J_URI"] = "bolt://localhost:7687" diff --git a/src/neo4j_graphrag/experimental/pipeline/config/base.py b/src/neo4j_graphrag/experimental/pipeline/config/base.py new file mode 100644 index 00000000..e386848e --- /dev/null +++ b/src/neo4j_graphrag/experimental/pipeline/config/base.py @@ -0,0 +1,91 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Abstract class for all pipeline configs.""" + +from __future__ import annotations + +import importlib +import logging +from typing import Any, Optional + +from pydantic import BaseModel, PrivateAttr + +from neo4j_graphrag.experimental.pipeline.config.param_resolver import ( + ParamConfig, + ParamToResolveConfig, +) + +logger = logging.getLogger(__name__) + + +class AbstractConfig(BaseModel): + """Base class for all configs. + Provides methods to get a class from a string and resolve a parameter defined by + a dict with a 'resolver_' key. + + Each subclass must implement a 'parse' method that returns the relevant object. + """ + + _global_data: dict[str, Any] = PrivateAttr({}) + """Additional parameter ignored by all Pydantic model_* methods.""" + + @classmethod + def _get_class(cls, class_path: str, optional_module: Optional[str] = None) -> type: + """Get class from string and an optional module + + Will first try to import the class from `class_path` alone. If it results in an ImportError, + will try to import from `f'{optional_module}.{class_path}'` + + Args: + class_path (str): Class path with format 'my_module.MyClass'. + optional_module (Optional[str]): Optional module path. Used to provide a default path for some known objects and simplify the notation. + + Raises: + ValueError: if the class can't be imported, even using the optional module. + """ + *modules, class_name = class_path.rsplit(".", 1) + module_name = modules[0] if modules else optional_module + if module_name is None: + raise ValueError("Must specify a module to import class from") + try: + module = importlib.import_module(module_name) + klass = getattr(module, class_name) + except (ImportError, AttributeError): + if optional_module and module_name != optional_module: + full_klass_path = optional_module + "." + class_path + return cls._get_class(full_klass_path) + raise ValueError(f"Could not find {class_name} in {module_name}") + return klass # type: ignore[no-any-return] + + def resolve_param(self, param: ParamConfig) -> Any: + """Finds the parameter value from its definition.""" + if not isinstance(param, ParamToResolveConfig): + # some parameters do not have to be resolved, real + # values are already provided + return param + return param.resolve(self._global_data) + + def resolve_params(self, params: dict[str, ParamConfig]) -> dict[str, Any]: + """Resolve all parameters + + Returning dict[str, Any] because parameters can be anything (str, float, list, dict...) + """ + return { + param_name: self.resolve_param(param) + for param_name, param in params.items() + } + + def parse(self, resolved_data: dict[str, Any] | None = None) -> Any: + raise NotImplementedError() diff --git a/src/neo4j_graphrag/experimental/pipeline/config/config_parser.py b/src/neo4j_graphrag/experimental/pipeline/config/config_parser.py deleted file mode 100644 index 11cc514d..00000000 --- a/src/neo4j_graphrag/experimental/pipeline/config/config_parser.py +++ /dev/null @@ -1,739 +0,0 @@ -"""Generic config for all pipelines + specific implementation for "templates" -such as the SimpleKGPipeline. -""" - -from __future__ import annotations - -import abc -import enum -import importlib -from collections import Counter -from pathlib import Path -from typing import ( - Annotated, - Any, - ClassVar, - Generic, - Literal, - Optional, - Sequence, - TypeVar, - Union, -) - -import neo4j -from pydantic import ( - BaseModel, - ConfigDict, - Discriminator, - Field, - PrivateAttr, - RootModel, - Tag, - field_validator, -) -from pydantic.v1.utils import deep_update -from typing_extensions import Self - -from neo4j_graphrag.embeddings import Embedder -from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder -from neo4j_graphrag.experimental.components.entity_relation_extractor import ( - EntityRelationExtractor, - LLMEntityRelationExtractor, - OnError, -) -from neo4j_graphrag.experimental.components.kg_writer import KGWriter, Neo4jWriter -from neo4j_graphrag.experimental.components.pdf_loader import PdfLoader -from neo4j_graphrag.experimental.components.resolver import ( - EntityResolver, - SinglePropertyExactMatchResolver, -) -from neo4j_graphrag.experimental.components.schema import ( - SchemaBuilder, - SchemaEntity, - SchemaRelation, -) -from neo4j_graphrag.experimental.components.text_splitters.base import TextSplitter -from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import ( - FixedSizeSplitter, -) -from neo4j_graphrag.experimental.components.types import LexicalGraphConfig -from neo4j_graphrag.experimental.pipeline import Component, Pipeline -from neo4j_graphrag.experimental.pipeline.config.config_reader import ConfigReader -from neo4j_graphrag.experimental.pipeline.config.param_resolver import ( - ParamConfig, - ParamToResolveConfig, -) -from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError -from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult -from neo4j_graphrag.experimental.pipeline.types import ( - ComponentDefinition, - ConnectionDefinition, - EntityInputType, - PipelineDefinition, - RelationInputType, -) -from neo4j_graphrag.generation.prompts import ERExtractionTemplate -from neo4j_graphrag.llm import LLMInterface - - -class AbstractConfig(BaseModel, abc.ABC): - """Base class for all configs. - Provides methods to get a class from a string and resolve a parameter defined by - a dict with a 'resolver_' key. - - Each subclass must implement a 'parse' method that returns the relevant object. - """ - - RESOLVER_KEY: ClassVar[str] = "resolver_" - - _global_data: dict[str, Any] = PrivateAttr({}) - """Additional parameter ignored by all model_* methods.""" - - @classmethod - def _get_class(cls, class_path: str, optional_module: Optional[str] = None) -> type: - """Get class from string and an optional module - - Will first try to import the class from `class_path` alone. If it results in an ImportError, - will try to import from `f'{optional_module}.{class_path}'` - - Args: - class_path (str): Class path with format 'my_module.MyClass'. - optional_module (Optional[str]): Optional module path. Used to provide a default path for some known objects and simplify the notation. - - Raises: - ValueError: if the class can't be imported, even using the optional module. - """ - *modules, class_name = class_path.rsplit(".", 1) - module_name = modules[0] if modules else optional_module - if module_name is None: - raise ValueError("Must specify a module to import class from") - try: - module = importlib.import_module(module_name) - klass = getattr(module, class_name) - except (ImportError, AttributeError): - if optional_module and module_name != optional_module: - full_klass_path = optional_module + "." + class_path - return cls._get_class(full_klass_path) - raise ValueError(f"Could not find {class_name} in {module_name}") - return klass # type: ignore[no-any-return] - - def resolve_param(self, param: ParamConfig) -> Any: - """Finds the parameter value from its definition.""" - if not isinstance(param, ParamToResolveConfig): - # some parameters do not have to be resolved, real - # values are already provided - return param - # all ParamToResolveConfig have a resolver_ field - return param.resolve(self._global_data) - - def resolve_params(self, params: dict[str, ParamConfig]) -> dict[str, Any]: - """Resolve all parameters - - Returning dict[str, Any] because parameters can be anything (str, float, list, dict...) - """ - return { - param_name: self.resolve_param(param) - for param_name, param in params.items() - } - - @abc.abstractmethod - def parse(self, resolved_data: dict[str, Any] | None = None) -> Any: - raise NotImplementedError() - - -T = TypeVar("T") -"""Generic type to help (a bit) mypy with the return type of the parse method""" - - -class ObjectConfig(AbstractConfig, Generic[T]): - """A config class to represent an object from a class name - and its constructor parameters. - """ - - class_: str | None = Field(default=None, validate_default=True) - """Path to class to be instantiated.""" - params_: dict[str, ParamConfig] = {} - """Initialization parameters.""" - - DEFAULT_MODULE: ClassVar[str] = "." - """Default module to import the class from.""" - INTERFACE: ClassVar[type] = object - """Constraint on the class (must be a subclass of).""" - REQUIRED_PARAMS: ClassVar[list[str]] = [] - """List of required parameters for this object constructor.""" - - @field_validator("params_") - @classmethod - def validate_params(cls, params_: dict[str, Any]) -> dict[str, Any]: - """Make sure all required parameters are provided.""" - for p in cls.REQUIRED_PARAMS: - if p not in params_: - raise ValueError(f"Missing parameter {p}") - return params_ - - def get_module(self) -> str: - return self.DEFAULT_MODULE - - def get_interface(self) -> type: - return self.INTERFACE - - def parse(self, resolved_data: dict[str, Any] | None = None) -> T: - """Import `class_`, resolve `params_` and instantiate object.""" - self._global_data = resolved_data or {} - if self.class_ is None: - raise ValueError("`class_` is not defined") - klass = self._get_class(self.class_, self.get_module()) - if not issubclass(klass, self.get_interface()): - raise ValueError( - f"Invalid class '{klass}'. Expected a subclass of '{self.get_interface()}'" - ) - params = self.resolve_params(self.params_) - try: - obj = klass(**params) - except TypeError as e: - raise e - return obj # type: ignore[return-value] - - -class Neo4jDriverConfig(ObjectConfig[neo4j.Driver]): - REQUIRED_PARAMS = ["uri", "user", "password"] - - @field_validator("class_", mode="before") - @classmethod - def validate_class(cls, class_: Any) -> str: - """`class_` parameter is not used because we're always using the sync driver.""" - if class_: - # logger.info("Parameter class_ is not used") - ... - # not used - return "not used" - - def parse(self, resolved_data: dict[str, Any] | None = None) -> neo4j.Driver: - params = self.resolve_params(self.params_) - uri = params.pop( - "uri" - ) # we know these params are there because of the required params validator - user = params.pop("user") - password = params.pop("password") - driver = neo4j.GraphDatabase.driver(uri, auth=(user, password), **params) - return driver - - -# note: using the notation with RootModel + root: field -# instead of RootModel[] for clarity -# but this requires the type: ignore comment below -class Neo4jDriverType(RootModel): # type: ignore[type-arg] - """A model to wrap neo4j.Driver and Neo4jDriverConfig objects. - - The `parse` method always returns a neo4j.Driver. - """ - - root: Union[neo4j.Driver, Neo4jDriverConfig] - - model_config = ConfigDict(arbitrary_types_allowed=True) - - def parse(self, resolved_data: dict[str, Any] | None = None) -> neo4j.Driver: - if isinstance(self.root, neo4j.Driver): - return self.root - # self.root is a Neo4jDriverConfig object - return self.root.parse() - - -class LLMConfig(ObjectConfig[LLMInterface]): - """Configuration for any LLMInterface object. - - By default, will try to import from `neo4j_graphrag.llm`. - """ - - DEFAULT_MODULE = "neo4j_graphrag.llm" - INTERFACE = LLMInterface - - -class LLMType(RootModel): # type: ignore[type-arg] - root: Union[LLMInterface, LLMConfig] - - model_config = ConfigDict(arbitrary_types_allowed=True) - - def parse(self, resolved_data: dict[str, Any] | None = None) -> LLMInterface: - if isinstance(self.root, LLMInterface): - return self.root - return self.root.parse() - - -class EmbedderConfig(ObjectConfig[Embedder]): - """Configuration for any Embedder object. - - By default, will try to import from `neo4j_graphrag.embeddings`. - """ - - DEFAULT_MODULE = "neo4j_graphrag.embeddings" - INTERFACE = Embedder - - -class EmbedderType(RootModel): # type: ignore[type-arg] - root: Union[Embedder, EmbedderConfig] - - model_config = ConfigDict(arbitrary_types_allowed=True) - - def parse(self, resolved_data: dict[str, Any] | None = None) -> Embedder: - if isinstance(self.root, Embedder): - return self.root - return self.root.parse() - - -class ComponentConfig(ObjectConfig[Component]): - """A config model for all components. - - In addition to the object config, components can have pre-defined parameters - that will be passed to the `run` method, ie `run_params_`. - """ - - run_params_: dict[str, ParamConfig] = {} - - DEFAULT_MODULE = "neo4j_graphrag.experimental.components" - INTERFACE = Component - - -class ComponentType(RootModel): # type: ignore[type-arg] - root: Union[Component, ComponentConfig] - - model_config = ConfigDict(arbitrary_types_allowed=True) - - def parse(self, resolved_data: dict[str, Any] | None = None) -> Component: - if isinstance(self.root, Component): - return self.root - return self.root.parse(resolved_data) - - -class PipelineType(str, enum.Enum): - """Pipeline type: - - NONE => Pipeline - SIMPLE_KG_PIPELINE ~> SimpleKGPipeline - """ - - NONE = "none" - SIMPLE_KG_PIPELINE = "SimpleKGPipeline" - - -class AbstractPipelineConfig(AbstractConfig): - """This class defines the fields possibly used by all pipelines: neo4j drivers, LLMs... - - neo4j_config, llm_config can be provided by user as a single item or a dict of items. - Validators deal with type conversion so that the field in all instances is a dict of items. - """ - - neo4j_config: dict[str, Neo4jDriverType] = {} - llm_config: dict[str, LLMType] = {} - embedder_config: dict[str, EmbedderType] = {} - # extra parameters values that can be used in different places of the config file - extras: dict[str, Any] = {} - - DEFAULT_NAME: ClassVar[str] = "default" - """Name of the default item in dict - """ - - @field_validator("neo4j_config", mode="before") - @classmethod - def validate_drivers( - cls, drivers: Union[Neo4jDriverType, dict[str, Any]] - ) -> dict[str, Any]: - if not isinstance(drivers, dict) or "params_" in drivers: - return {cls.DEFAULT_NAME: drivers} - return drivers - - @field_validator("llm_config", mode="before") - @classmethod - def validate_llms(cls, llms: Union[LLMType, dict[str, Any]]) -> dict[str, Any]: - if not isinstance(llms, dict) or "params_" in llms: - return {cls.DEFAULT_NAME: llms} - return llms - - @field_validator("embedder_config", mode="before") - @classmethod - def validate_embedders( - cls, embedders: Union[EmbedderType, dict[str, Any]] - ) -> dict[str, Any]: - if not isinstance(embedders, dict) or "params_" in embedders: - return {cls.DEFAULT_NAME: embedders} - return embedders - - @field_validator("llm_config", "neo4j_config", "embedder_config", mode="after") - @classmethod - def validate_names(cls, lst: dict[str, Any]) -> dict[str, Any]: - if not lst: - return lst - c = Counter(lst.keys()) - most_common_item = c.most_common(1) - most_common_count = most_common_item[0][1] - if most_common_count > 1: - raise ValueError(f"names must be unique {most_common_item}") - return lst - - def _resolve_component(self, config: ComponentConfig) -> Component: - return config.parse(self._global_data) - - def _resolve_component_definition( - self, name: str, config: ComponentType - ) -> ComponentDefinition: - component = config.parse(self._global_data) - if hasattr(config, "run_params_"): - component_run_params = self.resolve_params(config.run_params_) - else: - component_run_params = {} - return ComponentDefinition( - name=name, - component=component, - run_params=component_run_params, - ) - - def _parse_global_data(self) -> dict[str, Any]: - """Global data contains data that can be referenced in other parts of the - config. Typically, neo4j drivers and llms can be referenced in component input - parameters (see ConfigKeyParamResolver) - """ - drivers: dict[str, neo4j.Driver] = { - driver_name: driver_config.parse() - for driver_name, driver_config in self.neo4j_config.items() - } - llms: dict[str, LLMInterface] = { - llm_name: llm_config.parse() - for llm_name, llm_config in self.llm_config.items() - } - embedders: dict[str, Embedder] = { - embedder_name: embedder_config.parse() - for embedder_name, embedder_config in self.embedder_config.items() - } - return { - "neo4j_config": drivers, - "llm_config": llms, - "embedder_config": embedders, - "extras": self.resolve_params(self.extras), - } - - def _get_components(self) -> list[ComponentDefinition]: - return [] - - def _get_connections(self) -> list[ConnectionDefinition]: - return [] - - def parse(self, resolved_data: dict[str, Any] | None = None) -> PipelineDefinition: - """Parse the full config and returns a PipelineDefinition object, containing instantiated - components and a list of connections. - """ - self._global_data = self._parse_global_data() - return PipelineDefinition( - components=self._get_components(), - connections=self._get_connections(), - ) - - def get_run_params(self, user_input: dict[str, Any]) -> dict[str, Any]: - return user_input - - def get_neo4j_driver_by_name(self, name: str) -> neo4j.Driver: - drivers = self._global_data.get("neo4j_config", {}) - return drivers.get(name) # type: ignore[no-any-return] - - def get_default_neo4j_driver(self) -> neo4j.Driver: - return self.get_neo4j_driver_by_name(self.DEFAULT_NAME) - - def get_llm_by_name(self, name: str) -> LLMInterface: - llms = self._global_data.get("llm_config", {}) - return llms.get(name) # type: ignore[no-any-return] - - def get_default_llm(self) -> LLMInterface: - return self.get_llm_by_name(self.DEFAULT_NAME) - - def get_embedder_by_name(self, name: str) -> Embedder: - embedders = self._global_data.get("embedder_config", {}) - return embedders.get(name) # type: ignore[no-any-return] - - def get_default_embedder(self) -> Embedder: - return self.get_embedder_by_name(self.DEFAULT_NAME) - - -class PipelineConfig(AbstractPipelineConfig): - """Configuration class for raw pipelines. Config must contain all components and connections.""" - - component_config: dict[str, ComponentType] - connection_config: list[ConnectionDefinition] - template_: Literal[PipelineType.NONE] = PipelineType.NONE - - def _get_connections(self) -> list[ConnectionDefinition]: - return self.connection_config - - def _get_components(self) -> list[ComponentDefinition]: - return [ - self._resolve_component_definition(name, component_config) - for name, component_config in self.component_config.items() - ] - - -class TemplatePipelineConfig(AbstractPipelineConfig): - """This class represent a 'template' pipeline, ie pipeline with pre-defined default - components and fixed connections. - - Component names are defined in the COMPONENTS class var. For each of them, - a `_get_` method must be implemented that returns the proper - component. Optionally, `_get__run_params` can be implemented to - deal with parameters required by the component's run method. - """ - - COMPONENTS: ClassVar[list[str]] = [] - - def _get_components(self) -> list[ComponentDefinition]: - components = [] - for component_name in self.COMPONENTS: - method = getattr(self, f"_get_{component_name}") - component = method() - if component is None: - continue - method = getattr(self, f"_get_run_params_for_{component_name}", None) - run_params = method() if method else {} - components.append( - ComponentDefinition( - name=component_name, - component=component, - run_params=run_params, - ) - ) - return components - - def get_run_params(self, user_input: dict[str, Any]) -> dict[str, Any]: - return {} - - -class SimpleKGPipelineConfig(TemplatePipelineConfig): - COMPONENTS: ClassVar[list[str]] = [ - "pdf_loader", - "splitter", - "chunk_embedder", - "schema", - "extractor", - "writer", - "resolver", - ] - - template_: Literal[PipelineType.SIMPLE_KG_PIPELINE] = ( - PipelineType.SIMPLE_KG_PIPELINE - ) - - from_pdf: bool = False - entities: Sequence[EntityInputType] = [] - relations: Sequence[RelationInputType] = [] - potential_schema: Optional[list[tuple[str, str, str]]] = None - on_error: OnError = OnError.IGNORE - prompt_template: Union[ERExtractionTemplate, str] = ERExtractionTemplate() - perform_entity_resolution: bool = True - lexical_graph_config: Optional[LexicalGraphConfig] = None - neo4j_database: Optional[str] = None - - pdf_loader: ComponentConfig | None = None - kg_writer: ComponentConfig | None = None - text_splitter: ComponentConfig | None = None - - model_config = ConfigDict(arbitrary_types_allowed=True) - - def _get_pdf_loader(self) -> PdfLoader | None: - if not self.from_pdf: - return None - if self.pdf_loader: - return self.pdf_loader.parse(self._global_data) # type: ignore - return PdfLoader() - - def _get_splitter(self) -> TextSplitter: - if self.text_splitter: - return self.text_splitter.parse(self._global_data) # type: ignore - return FixedSizeSplitter() - - def _get_chunk_embedder(self) -> TextChunkEmbedder: - return TextChunkEmbedder(embedder=self.get_default_embedder()) - - def _get_schema(self) -> SchemaBuilder: - return SchemaBuilder() - - def _get_run_params_for_schema(self) -> dict[str, Any]: - return { - "entities": [SchemaEntity.from_text_or_dict(e) for e in self.entities], - "relations": [SchemaRelation.from_text_or_dict(r) for r in self.relations], - "potential_schema": self.potential_schema, - } - - def _get_extractor(self) -> EntityRelationExtractor: - return LLMEntityRelationExtractor( - llm=self.get_default_llm(), - prompt_template=self.prompt_template, - on_error=self.on_error, - ) - - def _get_writer(self) -> KGWriter: - if self.kg_writer: - return self.kg_writer.parse(self._global_data) # type: ignore - return Neo4jWriter(driver=self.get_default_neo4j_driver()) - - def _get_resolver(self) -> EntityResolver | None: - if not self.perform_entity_resolution: - return None - return SinglePropertyExactMatchResolver( - driver=self.get_default_neo4j_driver(), - ) - - def _get_connections(self) -> list[ConnectionDefinition]: - connections = [] - if self.from_pdf: - connections.append( - ConnectionDefinition( - start="pdf_loader", - end="splitter", - input_config={"text": "pdf_loader.text"}, - ) - ) - connections.append( - ConnectionDefinition( - start="schema", - end="extractor", - input_config={ - "schema": "schema", - "document_info": "pdf_loader.document_info", - }, - ) - ) - else: - connections.append( - ConnectionDefinition( - start="schema", - end="extractor", - input_config={ - "schema": "schema", - }, - ) - ) - connections.append( - ConnectionDefinition( - start="splitter", - end="chunk_embedder", - input_config={ - "text_chunks": "splitter", - }, - ) - ) - connections.append( - ConnectionDefinition( - start="chunk_embedder", - end="extractor", - input_config={ - "chunks": "chunk_embedder", - }, - ) - ) - connections.append( - ConnectionDefinition( - start="extractor", - end="writer", - input_config={ - "graph": "extractor", - }, - ) - ) - - if self.perform_entity_resolution: - connections.append( - ConnectionDefinition( - start="writer", - end="resolver", - input_config={}, - ) - ) - - return connections - - def get_run_params(self, user_input: dict[str, Any]) -> dict[str, Any]: - run_params = {} - if self.lexical_graph_config: - run_params["extractor"] = { - "lexical_graph_config": self.lexical_graph_config - } - text = user_input.get("text") - file_path = user_input.get("file_path") - if not ((text is None) ^ (file_path is None)): - # exactly one of text or user_input must be set - raise PipelineDefinitionError( - "Use either 'text' (when from_pdf=False) or 'file_path' (when from_pdf=True) argument." - ) - if self.from_pdf: - if not file_path: - raise PipelineDefinitionError( - "Expected 'file_path' argument when 'from_pdf' is True." - ) - run_params["pdf_loader"] = {"filepath": file_path} - else: - if not text: - raise PipelineDefinitionError( - "Expected 'text' argument when 'from_pdf' is False." - ) - run_params["splitter"] = {"text": text} - return run_params - - -def _get_discriminator_value(model: Any) -> PipelineType: - template_ = None - if "template_" in model: - template_ = model["template_"] - if hasattr(model, "template_"): - template_ = model.template_ - return PipelineType(template_) or PipelineType.NONE - - -class PipelineConfigWrapper(BaseModel): - """The pipeline config wrapper will parse the right pipeline config based on the `template_` field.""" - - config: Union[ - Annotated[PipelineConfig, Tag(PipelineType.NONE)], - Annotated[SimpleKGPipelineConfig, Tag(PipelineType.SIMPLE_KG_PIPELINE)], - ] = Field(discriminator=Discriminator(_get_discriminator_value)) - - def parse(self, resolved_data: dict[str, Any] | None = None) -> PipelineDefinition: - return self.config.parse(resolved_data) - - def get_run_params(self, user_input: dict[str, Any]) -> dict[str, Any]: - return self.config.get_run_params(user_input) - - -class PipelineRunner: - """Pipeline runner builds a pipeline from different objects and exposes a run method to run pipeline - - Pipeline can be built from: - - A PipelineDefinition (`__init__` method) - - A PipelineConfig (`from_config` method) - - A config file (`from_config_file` method) - """ - - def __init__( - self, - pipeline_definition: PipelineDefinition, - config: AbstractPipelineConfig | None = None, - ) -> None: - self.config = config - self.pipeline = Pipeline.from_definition(pipeline_definition) - self.run_params = pipeline_definition.get_run_params() - - @classmethod - def from_config(cls, config: AbstractPipelineConfig | dict[str, Any]) -> Self: - wrapper = PipelineConfigWrapper.model_validate({"config": config}) - return cls(wrapper.parse(), config=wrapper.config) - - @classmethod - def from_config_file(cls, file_path: Union[str, Path]) -> Self: - if not isinstance(file_path, Path): - file_path = Path(file_path) - data = ConfigReader().read(file_path) - return cls.from_config(data) - - async def run(self, data: dict[str, Any]) -> PipelineResult: - # pipeline_conditional_run_params = self. - if self.config: - run_param = deep_update(self.run_params, self.config.get_run_params(data)) - else: - run_param = deep_update(self.run_params, data) - return await self.pipeline.run(data=run_param) diff --git a/src/neo4j_graphrag/experimental/pipeline/config/config_poc_test.py b/src/neo4j_graphrag/experimental/pipeline/config/config_poc_test.py deleted file mode 100644 index 6820fd5a..00000000 --- a/src/neo4j_graphrag/experimental/pipeline/config/config_poc_test.py +++ /dev/null @@ -1,248 +0,0 @@ -from unittest.mock import MagicMock, Mock, patch - -import neo4j -import pytest - -from neo4j_graphrag.experimental.pipeline.config.config_parser import ( - AbstractPipelineConfig, - LLMConfig, - LLMType, - Neo4jDriverConfig, - Neo4jDriverType, -) -from neo4j_graphrag.llm import LLMInterface, OpenAILLM - - -@pytest.fixture(scope="function") -def driver() -> MagicMock: - return MagicMock(spec=neo4j.Driver) - - -def test_neo4j_driver_config() -> None: - config = Neo4jDriverConfig.model_validate( - { - "params_": { - "uri": "bolt://", - "user": "a user", - "password": "a password", - } - } - ) - assert config.class_ == "not used" - assert config.params_ == { - "uri": "bolt://", - "user": "a user", - "password": "a password", - } - with patch( - "neo4j_graphrag.experimental.pipeline.config.config_poc.neo4j.GraphDatabase.driver" - ) as driver_mock: - driver_mock.return_value = "a driver" - d = config.parse() - driver_mock.assert_called_once_with("bolt://", auth=("a user", "a password")) - assert d == "a driver" # type: ignore - - -def test_neo4j_driver_type_with_driver(driver: neo4j.Driver) -> None: - driver_type = Neo4jDriverType(driver) - assert driver_type.parse() == driver - - -def test_neo4j_driver_type_with_config() -> None: - driver_type = Neo4jDriverType( - Neo4jDriverConfig( - params_={ - "uri": "bolt://", - "user": "", - "password": "", - } - ) - ) - driver = driver_type.parse() - assert isinstance(driver, neo4j.Driver) - - -@patch.multiple(AbstractPipelineConfig, __abstractmethods__=set()) -@patch("neo4j_graphrag.experimental.pipeline.config.config_poc.Neo4jDriverConfig.parse") -def test_abstract_pipeline_config_neo4j_config_is_a_dict_with_params_( - mock_neo4j_config: Mock, -) -> None: - mock_neo4j_config.return_value = "text" - config = AbstractPipelineConfig.model_validate( - { - "neo4j_config": { - "params_": { - "uri": "bolt://", - "user": "", - "password": "", - } - } - } - ) - assert isinstance(config.neo4j_config, dict) - assert "default" in config.neo4j_config - config.parse() - mock_neo4j_config.assert_called_once() - assert config._global_data["neo4j_config"]["default"] == "text" - - -@patch.multiple(AbstractPipelineConfig, __abstractmethods__=set()) -@patch("neo4j_graphrag.experimental.pipeline.config.config_poc.Neo4jDriverConfig.parse") -def test_abstract_pipeline_config_neo4j_config_is_a_dict_with_names( - mock_neo4j_config: Mock, -) -> None: - mock_neo4j_config.return_value = "text" - config = AbstractPipelineConfig.model_validate( - { - "neo4j_config": { - "my_driver": { - "params_": { - "uri": "bolt://", - "user": "", - "password": "", - } - } - } - } - ) - assert isinstance(config.neo4j_config, dict) - assert "my_driver" in config.neo4j_config - config.parse() - mock_neo4j_config.assert_called_once() - assert config._global_data["neo4j_config"]["my_driver"] == "text" - - -@patch.multiple(AbstractPipelineConfig, __abstractmethods__=set()) -@patch("neo4j_graphrag.experimental.pipeline.config.config_poc.Neo4jDriverConfig.parse") -def test_abstract_pipeline_config_neo4j_config_is_a_dict_with_driver( - mock_neo4j_config: Mock, driver: neo4j.Driver -) -> None: - config = AbstractPipelineConfig.model_validate( - { - "neo4j_config": { - "my_driver": driver, - } - } - ) - assert isinstance(config.neo4j_config, dict) - assert "my_driver" in config.neo4j_config - config.parse() - assert not mock_neo4j_config.called - assert config._global_data["neo4j_config"]["my_driver"] == driver - - -@patch.multiple(AbstractPipelineConfig, __abstractmethods__=set()) -@patch("neo4j_graphrag.experimental.pipeline.config.config_poc.Neo4jDriverConfig.parse") -def test_abstract_pipeline_config_neo4j_config_is_a_driver( - mock_neo4j_config: Mock, driver: neo4j.Driver -) -> None: - config = AbstractPipelineConfig.model_validate( - { - "neo4j_config": driver, - } - ) - assert isinstance(config.neo4j_config, dict) - assert "default" in config.neo4j_config - config.parse() - assert not mock_neo4j_config.called - assert config._global_data["neo4j_config"]["default"] == driver - - -@pytest.fixture(scope="function") -def llm() -> LLMInterface: - return MagicMock(spec=LLMInterface) - - -def test_llm_config() -> None: - config = LLMConfig.model_validate( - {"class_": "OpenAILLM", "params_": {"model_name": "gpt-4o"}} - ) - assert config.class_ == "OpenAILLM" - assert config.get_module() == "neo4j_graphrag.llm" - assert config.get_interface() == LLMInterface - assert config.params_ == {"model_name": "gpt-4o"} - d = config.parse() - assert isinstance(d, OpenAILLM) - - -def test_llm_type_with_driver(llm: LLMInterface) -> None: - llm_type = LLMType(llm) - assert llm_type.parse() == llm - - -def test_llm_type_with_config() -> None: - llm_type = LLMType(LLMConfig(class_="OpenAILLM", params_={"model_name": "gpt-4o"})) - llm = llm_type.parse() - assert isinstance(llm, LLMInterface) - - -@patch.multiple(AbstractPipelineConfig, __abstractmethods__=set()) -@patch("neo4j_graphrag.experimental.pipeline.config.config_poc.LLMConfig.parse") -def test_abstract_pipeline_config_llm_config_is_a_dict_with_params_( - mock_llm_config: Mock, -) -> None: - mock_llm_config.return_value = "text" - config = AbstractPipelineConfig.model_validate( - {"llm_config": {"class_": "OpenAILLM", "params_": {"model_name": "gpt-4o"}}} - ) - assert isinstance(config.llm_config, dict) - assert "default" in config.llm_config - config.parse() - mock_llm_config.assert_called_once() - assert config._global_data["llm_config"]["default"] == "text" - - -@patch.multiple(AbstractPipelineConfig, __abstractmethods__=set()) -@patch("neo4j_graphrag.experimental.pipeline.config.config_poc.LLMConfig.parse") -def test_abstract_pipeline_config_llm_config_is_a_dict_with_names( - mock_llm_config: Mock, -) -> None: - mock_llm_config.return_value = "text" - config = AbstractPipelineConfig.model_validate( - { - "llm_config": { - "my_llm": {"class_": "OpenAILLM", "params_": {"model_name": "gpt-4o"}} - } - } - ) - assert isinstance(config.llm_config, dict) - assert "my_llm" in config.llm_config - config.parse() - mock_llm_config.assert_called_once() - assert config._global_data["llm_config"]["my_llm"] == "text" - - -@patch.multiple(AbstractPipelineConfig, __abstractmethods__=set()) -@patch("neo4j_graphrag.experimental.pipeline.config.config_poc.LLMConfig.parse") -def test_abstract_pipeline_config_llm_config_is_a_dict_with_llm( - mock_llm_config: Mock, llm: LLMInterface -) -> None: - config = AbstractPipelineConfig.model_validate( - { - "llm_config": { - "my_llm": llm, - } - } - ) - assert isinstance(config.llm_config, dict) - assert "my_llm" in config.llm_config - config.parse() - assert not mock_llm_config.called - assert config._global_data["llm_config"]["my_llm"] == llm - - -@patch.multiple(AbstractPipelineConfig, __abstractmethods__=set()) -@patch("neo4j_graphrag.experimental.pipeline.config.config_poc.LLMConfig.parse") -def test_abstract_pipeline_config_llm_config_is_a_llm( - mock_llm_config: Mock, llm: LLMInterface -) -> None: - config = AbstractPipelineConfig.model_validate( - { - "llm_config": llm, - } - ) - assert isinstance(config.llm_config, dict) - assert "default" in config.llm_config - config.parse() - assert not mock_llm_config.called - assert config._global_data["llm_config"]["default"] == llm diff --git a/src/neo4j_graphrag/experimental/pipeline/config/config_reader.py b/src/neo4j_graphrag/experimental/pipeline/config/config_reader.py index b1bacc37..0d4af265 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/config_reader.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/config_reader.py @@ -12,6 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Read JSON or YAML files and returns a dict. +No data validation performed at this stage. +""" import json from pathlib import Path @@ -61,7 +64,7 @@ def read_yaml(file_path: Path) -> Any: def _guess_format_and_read(self, file_path: Path) -> dict[str, Any]: extension = file_path.suffix.lower() # Note: .suffix returns an empty string if Path has no extension - # if not returning a dict, pasing will fail later on + # if not returning a dict, parsing will fail later on if extension in [".json"]: return self.read_json(file_path) # type: ignore[no-any-return] if extension in [".yaml", ".yml"]: diff --git a/src/neo4j_graphrag/experimental/pipeline/config/object_config.py b/src/neo4j_graphrag/experimental/pipeline/config/object_config.py new file mode 100644 index 00000000..0f3ae921 --- /dev/null +++ b/src/neo4j_graphrag/experimental/pipeline/config/object_config.py @@ -0,0 +1,212 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Config for all parameters that can be both provided as object instance and +config dict with 'class_' and 'params_' keys. +""" + +from __future__ import annotations + +import logging +from typing import ( + Any, + ClassVar, + Generic, + TypeVar, + Union, +) + +import neo4j +from pydantic import ( + ConfigDict, + Field, + RootModel, + field_validator, +) + +from neo4j_graphrag.embeddings import Embedder +from neo4j_graphrag.experimental.pipeline import Component +from neo4j_graphrag.experimental.pipeline.config.base import AbstractConfig +from neo4j_graphrag.experimental.pipeline.config.param_resolver import ( + ParamConfig, +) +from neo4j_graphrag.llm import LLMInterface + +logger = logging.getLogger(__name__) + + +T = TypeVar("T") +"""Generic type to help mypy with the parse method when we know the exact +expected return type (e.g. for the Neo4jDriverConfig below). +""" + + +class ObjectConfig(AbstractConfig, Generic[T]): + """A config class to represent an object from a class name + and its constructor parameters. + """ + + class_: str | None = Field(default=None, validate_default=True) + """Path to class to be instantiated.""" + params_: dict[str, ParamConfig] = {} + """Initialization parameters.""" + + DEFAULT_MODULE: ClassVar[str] = "." + """Default module to import the class from.""" + INTERFACE: ClassVar[type] = object + """Constraint on the class (must be a subclass of).""" + REQUIRED_PARAMS: ClassVar[list[str]] = [] + """List of required parameters for this object constructor.""" + + @field_validator("params_") + @classmethod + def validate_params(cls, params_: dict[str, Any]) -> dict[str, Any]: + """Make sure all required parameters are provided.""" + for p in cls.REQUIRED_PARAMS: + if p not in params_: + raise ValueError(f"Missing parameter {p}") + return params_ + + def get_module(self) -> str: + return self.DEFAULT_MODULE + + def get_interface(self) -> type: + return self.INTERFACE + + def parse(self, resolved_data: dict[str, Any] | None = None) -> T: + """Import `class_`, resolve `params_` and instantiate object.""" + self._global_data = resolved_data or {} + if self.class_ is None: + raise ValueError("`class_` is not defined") + klass = self._get_class(self.class_, self.get_module()) + if not issubclass(klass, self.get_interface()): + raise ValueError( + f"Invalid class '{klass}'. Expected a subclass of '{self.get_interface()}'" + ) + params = self.resolve_params(self.params_) + try: + obj = klass(**params) + except TypeError as e: + raise e + # here we still need to ignore type because _get_class returns Any + return obj # type: ignore[return-value] + + +class Neo4jDriverConfig(ObjectConfig[neo4j.Driver]): + REQUIRED_PARAMS = ["uri", "user", "password"] + + @field_validator("class_", mode="before") + @classmethod + def validate_class(cls, class_: Any) -> str: + """`class_` parameter is not used because we're always using the sync driver.""" + if class_: + logger.info("Parameter class_ is not used for Neo4jDriverConfig") + # not used + return "not used" + + def parse(self, resolved_data: dict[str, Any] | None = None) -> neo4j.Driver: + params = self.resolve_params(self.params_) + uri = params.pop( + "uri" + ) # we know these params are there because of the required params validator + user = params.pop("user") + password = params.pop("password") + driver = neo4j.GraphDatabase.driver(uri, auth=(user, password), **params) + return driver + + +# note: using the notation with RootModel + root: field +# instead of RootModel[] for clarity +# but this requires the type: ignore comment below +class Neo4jDriverType(RootModel): # type: ignore[type-arg] + """A model to wrap neo4j.Driver and Neo4jDriverConfig objects. + + The `parse` method always returns a neo4j.Driver. + """ + + root: Union[neo4j.Driver, Neo4jDriverConfig] + + model_config = ConfigDict(arbitrary_types_allowed=True) + + def parse(self, resolved_data: dict[str, Any] | None = None) -> neo4j.Driver: + if isinstance(self.root, neo4j.Driver): + return self.root + # self.root is a Neo4jDriverConfig object + return self.root.parse() + + +class LLMConfig(ObjectConfig[LLMInterface]): + """Configuration for any LLMInterface object. + + By default, will try to import from `neo4j_graphrag.llm`. + """ + + DEFAULT_MODULE = "neo4j_graphrag.llm" + INTERFACE = LLMInterface + + +class LLMType(RootModel): # type: ignore[type-arg] + root: Union[LLMInterface, LLMConfig] + + model_config = ConfigDict(arbitrary_types_allowed=True) + + def parse(self, resolved_data: dict[str, Any] | None = None) -> LLMInterface: + if isinstance(self.root, LLMInterface): + return self.root + return self.root.parse() + + +class EmbedderConfig(ObjectConfig[Embedder]): + """Configuration for any Embedder object. + + By default, will try to import from `neo4j_graphrag.embeddings`. + """ + + DEFAULT_MODULE = "neo4j_graphrag.embeddings" + INTERFACE = Embedder + + +class EmbedderType(RootModel): # type: ignore[type-arg] + root: Union[Embedder, EmbedderConfig] + + model_config = ConfigDict(arbitrary_types_allowed=True) + + def parse(self, resolved_data: dict[str, Any] | None = None) -> Embedder: + if isinstance(self.root, Embedder): + return self.root + return self.root.parse() + + +class ComponentConfig(ObjectConfig[Component]): + """A config model for all components. + + In addition to the object config, components can have pre-defined parameters + that will be passed to the `run` method, ie `run_params_`. + """ + + run_params_: dict[str, ParamConfig] = {} + + DEFAULT_MODULE = "neo4j_graphrag.experimental.components" + INTERFACE = Component + + +class ComponentType(RootModel): # type: ignore[type-arg] + root: Union[Component, ComponentConfig] + + model_config = ConfigDict(arbitrary_types_allowed=True) + + def parse(self, resolved_data: dict[str, Any] | None = None) -> Component: + if isinstance(self.root, Component): + return self.root + return self.root.parse(resolved_data) diff --git a/src/neo4j_graphrag/experimental/pipeline/config/pipeline_config.py b/src/neo4j_graphrag/experimental/pipeline/config/pipeline_config.py new file mode 100644 index 00000000..e7869de8 --- /dev/null +++ b/src/neo4j_graphrag/experimental/pipeline/config/pipeline_config.py @@ -0,0 +1,182 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, ClassVar, Literal, Union + +import neo4j +from pydantic import field_validator + +from neo4j_graphrag.embeddings import Embedder +from neo4j_graphrag.experimental.pipeline.config.base import AbstractConfig +from neo4j_graphrag.experimental.pipeline.config.object_config import ( + ComponentType, + EmbedderType, + LLMType, + Neo4jDriverType, +) +from neo4j_graphrag.experimental.pipeline.config.types import PipelineType +from neo4j_graphrag.experimental.pipeline.types import ( + ComponentDefinition, + ConnectionDefinition, + PipelineDefinition, +) +from neo4j_graphrag.llm import LLMInterface + + +class AbstractPipelineConfig(AbstractConfig): + """This class defines the fields possibly used by all pipelines: neo4j drivers, LLMs... + neo4j_config, llm_config can be provided by user as a single item or a dict of items. + Validators deal with type conversion so that the field in all instances is a dict of items. + """ + + neo4j_config: dict[str, Neo4jDriverType] = {} + llm_config: dict[str, LLMType] = {} + embedder_config: dict[str, EmbedderType] = {} + # extra parameters values that can be used in different places of the config file + extras: dict[str, Any] = {} + + DEFAULT_NAME: ClassVar[str] = "default" + """Name of the default item in dict + """ + + @field_validator("neo4j_config", mode="before") + @classmethod + def validate_drivers( + cls, drivers: Union[Neo4jDriverType, dict[str, Any]] + ) -> dict[str, Any]: + if not isinstance(drivers, dict) or "params_" in drivers: + return {cls.DEFAULT_NAME: drivers} + return drivers + + @field_validator("llm_config", mode="before") + @classmethod + def validate_llms(cls, llms: Union[LLMType, dict[str, Any]]) -> dict[str, Any]: + if not isinstance(llms, dict) or "class_" in llms: + return {cls.DEFAULT_NAME: llms} + return llms + + @field_validator("embedder_config", mode="before") + @classmethod + def validate_embedders( + cls, embedders: Union[EmbedderType, dict[str, Any]] + ) -> dict[str, Any]: + if not isinstance(embedders, dict) or "class_" in embedders: + return {cls.DEFAULT_NAME: embedders} + return embedders + + def _resolve_component_definition( + self, name: str, config: ComponentType + ) -> ComponentDefinition: + component = config.parse(self._global_data) + if hasattr(config.root, "run_params_"): + component_run_params = self.resolve_params(config.root.run_params_) + else: + component_run_params = {} + return ComponentDefinition( + name=name, + component=component, + run_params=component_run_params, + ) + + def _parse_global_data(self) -> dict[str, Any]: + """Global data contains data that can be referenced in other parts of the + config. + + Typically, neo4j drivers, LLMs and embedders can be referenced in component + input parameters. + """ + # 'extras' parameters can be referenced in other configs, + # that's why they are parsed before the others + # e.g., an API key used for both LLM and Embedder can be stored only + # once in extras. + extra_data = { + "extras": self.resolve_params(self.extras), + } + drivers: dict[str, neo4j.Driver] = { + driver_name: driver_config.parse(extra_data) + for driver_name, driver_config in self.neo4j_config.items() + } + llms: dict[str, LLMInterface] = { + llm_name: llm_config.parse(extra_data) + for llm_name, llm_config in self.llm_config.items() + } + embedders: dict[str, Embedder] = { + embedder_name: embedder_config.parse(extra_data) + for embedder_name, embedder_config in self.embedder_config.items() + } + return { + **extra_data, + "neo4j_config": drivers, + "llm_config": llms, + "embedder_config": embedders, + } + + def _get_components(self) -> list[ComponentDefinition]: + return [] + + def _get_connections(self) -> list[ConnectionDefinition]: + return [] + + def parse(self, resolved_data: dict[str, Any] | None = None) -> PipelineDefinition: + """Parse the full config and returns a PipelineDefinition object, containing instantiated + components and a list of connections. + """ + self._global_data = self._parse_global_data() + return PipelineDefinition( + components=self._get_components(), + connections=self._get_connections(), + ) + + def get_run_params(self, user_input: dict[str, Any]) -> dict[str, Any]: + return user_input + + def get_neo4j_driver_by_name(self, name: str) -> neo4j.Driver: + drivers: dict[str, neo4j.Driver] = self._global_data.get("neo4j_config", {}) + return drivers[name] + + def get_default_neo4j_driver(self) -> neo4j.Driver: + return self.get_neo4j_driver_by_name(self.DEFAULT_NAME) + + def get_llm_by_name(self, name: str) -> LLMInterface: + llms: dict[str, LLMInterface] = self._global_data.get("llm_config", {}) + return llms[name] + + def get_default_llm(self) -> LLMInterface: + return self.get_llm_by_name(self.DEFAULT_NAME) + + def get_embedder_by_name(self, name: str) -> Embedder: + embedders: dict[str, Embedder] = self._global_data.get("embedder_config", {}) + return embedders[name] + + def get_default_embedder(self) -> Embedder: + return self.get_embedder_by_name(self.DEFAULT_NAME) + + +class PipelineConfig(AbstractPipelineConfig): + """Configuration class for raw pipelines. + Config must contain all components and connections.""" + + component_config: dict[str, ComponentType] + connection_config: list[ConnectionDefinition] + template_: Literal[PipelineType.NONE] = PipelineType.NONE + + def _get_connections(self) -> list[ConnectionDefinition]: + return self.connection_config + + def _get_components(self) -> list[ComponentDefinition]: + return [ + self._resolve_component_definition(name, component_config) + for name, component_config in self.component_config.items() + ] diff --git a/src/neo4j_graphrag/experimental/pipeline/config/runner.py b/src/neo4j_graphrag/experimental/pipeline/config/runner.py new file mode 100644 index 00000000..7d4299ff --- /dev/null +++ b/src/neo4j_graphrag/experimental/pipeline/config/runner.py @@ -0,0 +1,115 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pipeline config wrapper (router based on 'template_' key) +and pipeline runner. +""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import ( + Annotated, + Any, + Union, +) + +from pydantic import ( + BaseModel, + Discriminator, + Field, + Tag, +) +from pydantic.v1.utils import deep_update +from typing_extensions import Self + +from neo4j_graphrag.experimental.pipeline import Pipeline +from neo4j_graphrag.experimental.pipeline.config.config_reader import ConfigReader +from neo4j_graphrag.experimental.pipeline.config.pipeline_config import ( + AbstractPipelineConfig, + PipelineConfig, +) +from neo4j_graphrag.experimental.pipeline.config.template_pipeline.simple_kg_builder import ( + SimpleKGPipelineConfig, +) +from neo4j_graphrag.experimental.pipeline.config.types import PipelineType +from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult +from neo4j_graphrag.experimental.pipeline.types import PipelineDefinition + +logger = logging.getLogger(__name__) + + +def _get_discriminator_value(model: Any) -> PipelineType: + template_ = None + if "template_" in model: + template_ = model["template_"] + if hasattr(model, "template_"): + template_ = model.template_ + return PipelineType(template_) or PipelineType.NONE + + +class PipelineConfigWrapper(BaseModel): + """The pipeline config wrapper will parse the right pipeline config based on the `template_` field.""" + + config: Union[ + Annotated[PipelineConfig, Tag(PipelineType.NONE)], + Annotated[SimpleKGPipelineConfig, Tag(PipelineType.SIMPLE_KG_PIPELINE)], + ] = Field(discriminator=Discriminator(_get_discriminator_value)) + + def parse(self, resolved_data: dict[str, Any] | None = None) -> PipelineDefinition: + return self.config.parse(resolved_data) + + def get_run_params(self, user_input: dict[str, Any]) -> dict[str, Any]: + return self.config.get_run_params(user_input) + + +class PipelineRunner: + """Pipeline runner builds a pipeline from different objects and exposes a run method to run pipeline + + Pipeline can be built from: + - A PipelineDefinition (`__init__` method) + - A PipelineConfig (`from_config` method) + - A config file (`from_config_file` method) + """ + + def __init__( + self, + pipeline_definition: PipelineDefinition, + config: AbstractPipelineConfig | None = None, + ) -> None: + self.config = config + self.pipeline = Pipeline.from_definition(pipeline_definition) + self.run_params = pipeline_definition.get_run_params() + + @classmethod + def from_config(cls, config: AbstractPipelineConfig | dict[str, Any]) -> Self: + wrapper = PipelineConfigWrapper.model_validate({"config": config}) + return cls(wrapper.parse(), config=wrapper.config) + + @classmethod + def from_config_file(cls, file_path: Union[str, Path]) -> Self: + if not isinstance(file_path, Path): + file_path = Path(file_path) + data = ConfigReader().read(file_path) + return cls.from_config(data) + + async def run(self, data: dict[str, Any]) -> PipelineResult: + # pipeline_conditional_run_params = self. + if self.config: + run_param = deep_update(self.run_params, self.config.get_run_params(data)) + else: + run_param = deep_update(self.run_params, data) + return await self.pipeline.run(data=run_param) diff --git a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/__init__.py b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/__init__.py new file mode 100644 index 00000000..125a1c87 --- /dev/null +++ b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .simple_kg_builder import SimpleKGPipelineConfig + +__all__ = [ + "SimpleKGPipelineConfig", +] diff --git a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/base.py b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/base.py new file mode 100644 index 00000000..e0e7fca7 --- /dev/null +++ b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/base.py @@ -0,0 +1,55 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, ClassVar + +from neo4j_graphrag.experimental.pipeline.config.pipeline_config import ( + AbstractPipelineConfig, +) +from neo4j_graphrag.experimental.pipeline.types import ComponentDefinition + + +class TemplatePipelineConfig(AbstractPipelineConfig): + """This class represent a 'template' pipeline, ie pipeline with pre-defined default + components and fixed connections. + + Component names are defined in the COMPONENTS class var. For each of them, + a `_get_` method must be implemented that returns the proper + component. Optionally, `_get_run_params_for_` can be implemented + to deal with parameters required by the component's run method and predefined on + template initialization. + """ + + COMPONENTS: ClassVar[list[str]] = [] + + def _get_components(self) -> list[ComponentDefinition]: + components = [] + for component_name in self.COMPONENTS: + method = getattr(self, f"_get_{component_name}") + component = method() + if component is None: + continue + method = getattr(self, f"_get_run_params_for_{component_name}", None) + run_params = method() if method else {} + components.append( + ComponentDefinition( + name=component_name, + component=component, + run_params=run_params, + ) + ) + return components + + def get_run_params(self, user_input: dict[str, Any]) -> dict[str, Any]: + return {} diff --git a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py new file mode 100644 index 00000000..16c6ac0c --- /dev/null +++ b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py @@ -0,0 +1,224 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, ClassVar, Literal, Optional, Sequence, Union + +from pydantic import ConfigDict + +from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder +from neo4j_graphrag.experimental.components.entity_relation_extractor import ( + EntityRelationExtractor, + LLMEntityRelationExtractor, + OnError, +) +from neo4j_graphrag.experimental.components.kg_writer import KGWriter, Neo4jWriter +from neo4j_graphrag.experimental.components.pdf_loader import PdfLoader +from neo4j_graphrag.experimental.components.resolver import ( + EntityResolver, + SinglePropertyExactMatchResolver, +) +from neo4j_graphrag.experimental.components.schema import ( + SchemaBuilder, + SchemaEntity, + SchemaRelation, +) +from neo4j_graphrag.experimental.components.text_splitters.base import TextSplitter +from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import ( + FixedSizeSplitter, +) +from neo4j_graphrag.experimental.components.types import LexicalGraphConfig +from neo4j_graphrag.experimental.pipeline.config.object_config import ComponentConfig +from neo4j_graphrag.experimental.pipeline.config.template_pipeline.base import ( + TemplatePipelineConfig, +) +from neo4j_graphrag.experimental.pipeline.config.types import PipelineType +from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError +from neo4j_graphrag.experimental.pipeline.types import ( + ConnectionDefinition, + EntityInputType, + RelationInputType, +) +from neo4j_graphrag.generation.prompts import ERExtractionTemplate + + +class SimpleKGPipelineConfig(TemplatePipelineConfig): + COMPONENTS: ClassVar[list[str]] = [ + "pdf_loader", + "splitter", + "chunk_embedder", + "schema", + "extractor", + "writer", + "resolver", + ] + + template_: Literal[PipelineType.SIMPLE_KG_PIPELINE] = ( + PipelineType.SIMPLE_KG_PIPELINE + ) + + from_pdf: bool = False + entities: Sequence[EntityInputType] = [] + relations: Sequence[RelationInputType] = [] + potential_schema: Optional[list[tuple[str, str, str]]] = None + on_error: OnError = OnError.IGNORE + prompt_template: Union[ERExtractionTemplate, str] = ERExtractionTemplate() + perform_entity_resolution: bool = True + lexical_graph_config: Optional[LexicalGraphConfig] = None + neo4j_database: Optional[str] = None + + pdf_loader: ComponentConfig | None = None + kg_writer: ComponentConfig | None = None + text_splitter: ComponentConfig | None = None + + model_config = ConfigDict(arbitrary_types_allowed=True) + + def _get_pdf_loader(self) -> PdfLoader | None: + if not self.from_pdf: + return None + if self.pdf_loader: + return self.pdf_loader.parse(self._global_data) # type: ignore + return PdfLoader() + + def _get_splitter(self) -> TextSplitter: + if self.text_splitter: + return self.text_splitter.parse(self._global_data) # type: ignore + return FixedSizeSplitter() + + def _get_chunk_embedder(self) -> TextChunkEmbedder: + return TextChunkEmbedder(embedder=self.get_default_embedder()) + + def _get_schema(self) -> SchemaBuilder: + return SchemaBuilder() + + def _get_run_params_for_schema(self) -> dict[str, Any]: + return { + "entities": [SchemaEntity.from_text_or_dict(e) for e in self.entities], + "relations": [SchemaRelation.from_text_or_dict(r) for r in self.relations], + "potential_schema": self.potential_schema, + } + + def _get_extractor(self) -> EntityRelationExtractor: + return LLMEntityRelationExtractor( + llm=self.get_default_llm(), + prompt_template=self.prompt_template, + on_error=self.on_error, + ) + + def _get_writer(self) -> KGWriter: + if self.kg_writer: + return self.kg_writer.parse(self._global_data) # type: ignore + return Neo4jWriter(driver=self.get_default_neo4j_driver()) + + def _get_resolver(self) -> EntityResolver | None: + if not self.perform_entity_resolution: + return None + return SinglePropertyExactMatchResolver( + driver=self.get_default_neo4j_driver(), + ) + + def _get_connections(self) -> list[ConnectionDefinition]: + connections = [] + if self.from_pdf: + connections.append( + ConnectionDefinition( + start="pdf_loader", + end="splitter", + input_config={"text": "pdf_loader.text"}, + ) + ) + connections.append( + ConnectionDefinition( + start="schema", + end="extractor", + input_config={ + "schema": "schema", + "document_info": "pdf_loader.document_info", + }, + ) + ) + else: + connections.append( + ConnectionDefinition( + start="schema", + end="extractor", + input_config={ + "schema": "schema", + }, + ) + ) + connections.append( + ConnectionDefinition( + start="splitter", + end="chunk_embedder", + input_config={ + "text_chunks": "splitter", + }, + ) + ) + connections.append( + ConnectionDefinition( + start="chunk_embedder", + end="extractor", + input_config={ + "chunks": "chunk_embedder", + }, + ) + ) + connections.append( + ConnectionDefinition( + start="extractor", + end="writer", + input_config={ + "graph": "extractor", + }, + ) + ) + + if self.perform_entity_resolution: + connections.append( + ConnectionDefinition( + start="writer", + end="resolver", + input_config={}, + ) + ) + + return connections + + def get_run_params(self, user_input: dict[str, Any]) -> dict[str, Any]: + run_params = {} + if self.lexical_graph_config: + run_params["extractor"] = { + "lexical_graph_config": self.lexical_graph_config + } + text = user_input.get("text") + file_path = user_input.get("file_path") + if not ((text is None) ^ (file_path is None)): + # exactly one of text or user_input must be set + raise PipelineDefinitionError( + "Use either 'text' (when from_pdf=False) or 'file_path' (when from_pdf=True) argument." + ) + if self.from_pdf: + if not file_path: + raise PipelineDefinitionError( + "Expected 'file_path' argument when 'from_pdf' is True." + ) + run_params["pdf_loader"] = {"filepath": file_path} + else: + if not text: + raise PipelineDefinitionError( + "Expected 'text' argument when 'from_pdf' is False." + ) + run_params["splitter"] = {"text": text} + return run_params diff --git a/src/neo4j_graphrag/experimental/pipeline/config/types.py b/src/neo4j_graphrag/experimental/pipeline/config/types.py new file mode 100644 index 00000000..48f91f48 --- /dev/null +++ b/src/neo4j_graphrag/experimental/pipeline/config/types.py @@ -0,0 +1,26 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import enum + + +class PipelineType(str, enum.Enum): + """Pipeline type: + + NONE => Pipeline + SIMPLE_KG_PIPELINE ~> SimpleKGPipeline + """ + + NONE = "none" + SIMPLE_KG_PIPELINE = "SimpleKGPipeline" diff --git a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py index db4d4e8d..3fca0215 100644 --- a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py @@ -22,8 +22,8 @@ from neo4j_graphrag.embeddings import Embedder from neo4j_graphrag.experimental.components.types import LexicalGraphConfig -from neo4j_graphrag.experimental.pipeline.config.config_parser import ( - PipelineRunner, +from neo4j_graphrag.experimental.pipeline.config.runner import PipelineRunner +from neo4j_graphrag.experimental.pipeline.config.template_pipeline import ( SimpleKGPipelineConfig, ) from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 829cad23..5069ab6e 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -19,6 +19,7 @@ import neo4j import pytest from neo4j_graphrag.embeddings.base import Embedder +from neo4j_graphrag.experimental.pipeline import Component from neo4j_graphrag.llm import LLMInterface from neo4j_graphrag.retrievers import ( HybridRetriever, @@ -98,3 +99,8 @@ def format_function(record: neo4j.Record) -> RetrieverResultItem: ) return format_function + + +@pytest.fixture(scope="function") +def component() -> MagicMock: + return MagicMock(spec=Component) diff --git a/tests/unit/experimental/pipeline/config/__init__.py b/tests/unit/experimental/pipeline/config/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/experimental/pipeline/config/template_pipeline/__init__.py b/tests/unit/experimental/pipeline/config/template_pipeline/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/experimental/pipeline/config/template_pipeline/test_base.py b/tests/unit/experimental/pipeline/config/template_pipeline/test_base.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py b/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/experimental/pipeline/config/test_base.py b/tests/unit/experimental/pipeline/config/test_base.py new file mode 100644 index 00000000..049e8c9b --- /dev/null +++ b/tests/unit/experimental/pipeline/config/test_base.py @@ -0,0 +1,67 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from unittest.mock import patch + +import pytest +from neo4j_graphrag.experimental.pipeline import Pipeline +from neo4j_graphrag.experimental.pipeline.config.base import AbstractConfig +from neo4j_graphrag.experimental.pipeline.config.param_resolver import ( + ParamToResolveConfig, +) + + +def test_get_class_no_optional_module() -> None: + c = AbstractConfig() + klass = c._get_class("neo4j_graphrag.experimental.pipeline.Pipeline") + assert klass == Pipeline + + +def test_get_class_optional_module() -> None: + c = AbstractConfig() + klass = c._get_class( + "Pipeline", optional_module="neo4j_graphrag.experimental.pipeline" + ) + assert klass == Pipeline + + +def test_get_class_path_and_optional_module() -> None: + c = AbstractConfig() + klass = c._get_class( + "pipeline.Pipeline", optional_module="neo4j_graphrag.experimental" + ) + assert klass == Pipeline + + +def test_get_class_wrong_path() -> None: + c = AbstractConfig() + with pytest.raises(ValueError): + c._get_class("MyClass") + + +def test_resolve_param_with_param_to_resolve_object() -> None: + c = AbstractConfig() + with patch( + "neo4j_graphrag.experimental.pipeline.config.param_resolver.ParamToResolveConfig", + spec=ParamToResolveConfig, + ) as mock_param_class: + mock_param = mock_param_class.return_value + mock_param.resolve.return_value = 1 + assert c.resolve_param(mock_param) == 1 + mock_param.resolve.assert_called_once_with({}) + + +def test_resolve_param_with_other_object() -> None: + c = AbstractConfig() + assert c.resolve_param("value") == "value" diff --git a/tests/unit/experimental/pipeline/config/test_object_config.py b/tests/unit/experimental/pipeline/config/test_object_config.py new file mode 100644 index 00000000..f9cc4906 --- /dev/null +++ b/tests/unit/experimental/pipeline/config/test_object_config.py @@ -0,0 +1,133 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from unittest.mock import patch + +import neo4j +from neo4j_graphrag.embeddings import Embedder, OpenAIEmbeddings +from neo4j_graphrag.experimental.pipeline.config.object_config import ( + EmbedderConfig, + EmbedderType, + LLMConfig, + LLMType, + Neo4jDriverConfig, + Neo4jDriverType, +) +from neo4j_graphrag.llm import LLMInterface, OpenAILLM + + +def test_neo4j_driver_config() -> None: + config = Neo4jDriverConfig.model_validate( + { + "params_": { + "uri": "bolt://", + "user": "a user", + "password": "a password", + } + } + ) + assert config.class_ == "not used" + assert config.params_ == { + "uri": "bolt://", + "user": "a user", + "password": "a password", + } + with patch( + "neo4j_graphrag.experimental.pipeline.config.object_config.neo4j.GraphDatabase.driver" + ) as driver_mock: + driver_mock.return_value = "a driver" + d = config.parse() + driver_mock.assert_called_once_with("bolt://", auth=("a user", "a password")) + assert d == "a driver" # type: ignore + + +def test_neo4j_driver_type_with_driver(driver: neo4j.Driver) -> None: + driver_type = Neo4jDriverType(driver) + assert driver_type.parse() == driver + + +def test_neo4j_driver_type_with_config() -> None: + driver_type = Neo4jDriverType( + Neo4jDriverConfig( + params_={ + "uri": "bolt://", + "user": "", + "password": "", + } + ) + ) + driver = driver_type.parse() + assert isinstance(driver, neo4j.Driver) + + +def test_llm_config() -> None: + config = LLMConfig.model_validate( + { + "class_": "OpenAILLM", + "params_": {"model_name": "gpt-4o", "api_key": "my-api-key"}, + } + ) + assert config.class_ == "OpenAILLM" + assert config.get_module() == "neo4j_graphrag.llm" + assert config.get_interface() == LLMInterface + assert config.params_ == {"model_name": "gpt-4o", "api_key": "my-api-key"} + d = config.parse() + assert isinstance(d, OpenAILLM) + + +def test_llm_type_with_driver(llm: LLMInterface) -> None: + llm_type = LLMType(llm) + assert llm_type.parse() == llm + + +def test_llm_type_with_config() -> None: + llm_type = LLMType( + LLMConfig( + class_="OpenAILLM", + params_={"model_name": "gpt-4o", "api_key": "my-api-key"}, + ) + ) + llm = llm_type.parse() + assert isinstance(llm, OpenAILLM) + + +def test_embedder_config() -> None: + config = EmbedderConfig.model_validate( + { + "class_": "OpenAIEmbeddings", + "params_": {"api_key": "my-api-key"}, + } + ) + assert config.class_ == "OpenAIEmbeddings" + assert config.get_module() == "neo4j_graphrag.embeddings" + assert config.get_interface() == Embedder + assert config.params_ == {"api_key": "my-api-key"} + d = config.parse() + assert isinstance(d, OpenAIEmbeddings) + + +def test_embedder_type_with_embedder(embedder: Embedder) -> None: + embedder_type = EmbedderType(embedder) + assert embedder_type.parse() == embedder + + +def test_embedder_type_with_config() -> None: + embedder_type = EmbedderType( + EmbedderConfig( + class_="OpenAIEmbeddings", + params_={"api_key": "my-api-key"}, + ) + ) + embedder = embedder_type.parse() + assert isinstance(embedder, OpenAIEmbeddings) diff --git a/tests/unit/experimental/pipeline/config/test_param_resolver.py b/tests/unit/experimental/pipeline/config/test_param_resolver.py new file mode 100644 index 00000000..efd4d7e9 --- /dev/null +++ b/tests/unit/experimental/pipeline/config/test_param_resolver.py @@ -0,0 +1,56 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from unittest.mock import patch + +import pytest +from neo4j_graphrag.experimental.pipeline.config.param_resolver import ( + ParamFromEnvConfig, + ParamFromKeyConfig, +) + + +@patch.dict(os.environ, {"MY_KEY": "my_value"}, clear=True) +def test_env_param_config_happy_path() -> None: + resolver = ParamFromEnvConfig(var_="MY_KEY") + assert resolver.resolve({}) == "my_value" + + +@patch.dict(os.environ, {}, clear=True) +def test_env_param_config_missing_env_var() -> None: + resolver = ParamFromEnvConfig(var_="MY_KEY") + assert resolver.resolve({}) is None + + +def test_config_key_param_simple_key() -> None: + resolver = ParamFromKeyConfig(key_="my_key") + assert resolver.resolve({"my_key": "my_value"}) == "my_value" + + +def test_config_key_param_missing_key() -> None: + resolver = ParamFromKeyConfig(key_="my_key") + with pytest.raises(KeyError): + resolver.resolve({}) + + +def test_config_complex_key_param() -> None: + resolver = ParamFromKeyConfig(key_="my_key.my_sub_key") + assert resolver.resolve({"my_key": {"my_sub_key": "value"}}) == "value" + + +def test_config_complex_key_param_missing_subkey() -> None: + resolver = ParamFromKeyConfig(key_="my_key.my_sub_key") + with pytest.raises(KeyError): + resolver.resolve({"my_key": {}}) diff --git a/tests/unit/experimental/pipeline/config/test_pipeline_config.py b/tests/unit/experimental/pipeline/config/test_pipeline_config.py new file mode 100644 index 00000000..4de5874b --- /dev/null +++ b/tests/unit/experimental/pipeline/config/test_pipeline_config.py @@ -0,0 +1,378 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from unittest.mock import Mock, patch + +import neo4j +from neo4j_graphrag.embeddings import Embedder +from neo4j_graphrag.experimental.pipeline import Component +from neo4j_graphrag.experimental.pipeline.config.object_config import ( + ComponentConfig, + ComponentType, + Neo4jDriverConfig, + Neo4jDriverType, +) +from neo4j_graphrag.experimental.pipeline.config.param_resolver import ( + ParamFromEnvConfig, + ParamFromKeyConfig, +) +from neo4j_graphrag.experimental.pipeline.config.pipeline_config import ( + AbstractPipelineConfig, +) +from neo4j_graphrag.experimental.pipeline.types import ComponentDefinition +from neo4j_graphrag.llm import LLMInterface + + +@patch( + "neo4j_graphrag.experimental.pipeline.config.object_config.Neo4jDriverConfig.parse" +) +def test_abstract_pipeline_config_neo4j_config_is_a_dict_with_params_( + mock_neo4j_config: Mock, +) -> None: + mock_neo4j_config.return_value = "text" + config = AbstractPipelineConfig.model_validate( + { + "neo4j_config": { + "params_": { + "uri": "bolt://", + "user": "", + "password": "", + } + } + } + ) + assert isinstance(config.neo4j_config, dict) + assert "default" in config.neo4j_config + config.parse() + mock_neo4j_config.assert_called_once() + assert config._global_data["neo4j_config"]["default"] == "text" + + +@patch( + "neo4j_graphrag.experimental.pipeline.config.object_config.Neo4jDriverConfig.parse" +) +def test_abstract_pipeline_config_neo4j_config_is_a_dict_with_names( + mock_neo4j_config: Mock, +) -> None: + mock_neo4j_config.return_value = "text" + config = AbstractPipelineConfig.model_validate( + { + "neo4j_config": { + "my_driver": { + "params_": { + "uri": "bolt://", + "user": "", + "password": "", + } + } + } + } + ) + assert isinstance(config.neo4j_config, dict) + assert "my_driver" in config.neo4j_config + config.parse() + mock_neo4j_config.assert_called_once() + assert config._global_data["neo4j_config"]["my_driver"] == "text" + + +@patch( + "neo4j_graphrag.experimental.pipeline.config.object_config.Neo4jDriverConfig.parse" +) +def test_abstract_pipeline_config_neo4j_config_is_a_dict_with_driver( + mock_neo4j_config: Mock, driver: neo4j.Driver +) -> None: + config = AbstractPipelineConfig.model_validate( + { + "neo4j_config": { + "my_driver": driver, + } + } + ) + assert isinstance(config.neo4j_config, dict) + assert "my_driver" in config.neo4j_config + config.parse() + assert not mock_neo4j_config.called + assert config._global_data["neo4j_config"]["my_driver"] == driver + + +@patch( + "neo4j_graphrag.experimental.pipeline.config.object_config.Neo4jDriverConfig.parse" +) +def test_abstract_pipeline_config_neo4j_config_is_a_driver( + mock_neo4j_config: Mock, driver: neo4j.Driver +) -> None: + config = AbstractPipelineConfig.model_validate( + { + "neo4j_config": driver, + } + ) + assert isinstance(config.neo4j_config, dict) + assert "default" in config.neo4j_config + config.parse() + assert not mock_neo4j_config.called + assert config._global_data["neo4j_config"]["default"] == driver + + +@patch("neo4j_graphrag.experimental.pipeline.config.object_config.LLMConfig.parse") +def test_abstract_pipeline_config_llm_config_is_a_dict_with_params_( + mock_llm_config: Mock, +) -> None: + mock_llm_config.return_value = "text" + config = AbstractPipelineConfig.model_validate( + {"llm_config": {"class_": "OpenAILLM", "params_": {"model_name": "gpt-4o"}}} + ) + assert isinstance(config.llm_config, dict) + assert "default" in config.llm_config + config.parse() + mock_llm_config.assert_called_once() + assert config._global_data["llm_config"]["default"] == "text" + + +@patch("neo4j_graphrag.experimental.pipeline.config.object_config.LLMConfig.parse") +def test_abstract_pipeline_config_llm_config_is_a_dict_with_names( + mock_llm_config: Mock, +) -> None: + mock_llm_config.return_value = "text" + config = AbstractPipelineConfig.model_validate( + { + "llm_config": { + "my_llm": {"class_": "OpenAILLM", "params_": {"model_name": "gpt-4o"}} + } + } + ) + assert isinstance(config.llm_config, dict) + assert "my_llm" in config.llm_config + config.parse() + mock_llm_config.assert_called_once() + assert config._global_data["llm_config"]["my_llm"] == "text" + + +@patch("neo4j_graphrag.experimental.pipeline.config.object_config.LLMConfig.parse") +def test_abstract_pipeline_config_llm_config_is_a_dict_with_llm( + mock_llm_config: Mock, llm: LLMInterface +) -> None: + config = AbstractPipelineConfig.model_validate( + { + "llm_config": { + "my_llm": llm, + } + } + ) + assert isinstance(config.llm_config, dict) + assert "my_llm" in config.llm_config + config.parse() + assert not mock_llm_config.called + assert config._global_data["llm_config"]["my_llm"] == llm + + +@patch("neo4j_graphrag.experimental.pipeline.config.object_config.LLMConfig.parse") +def test_abstract_pipeline_config_llm_config_is_a_llm( + mock_llm_config: Mock, llm: LLMInterface +) -> None: + config = AbstractPipelineConfig.model_validate( + { + "llm_config": llm, + } + ) + assert isinstance(config.llm_config, dict) + assert "default" in config.llm_config + config.parse() + assert not mock_llm_config.called + assert config._global_data["llm_config"]["default"] == llm + + +@patch("neo4j_graphrag.experimental.pipeline.config.object_config.EmbedderConfig.parse") +def test_abstract_pipeline_config_embedder_config_is_a_dict_with_params_( + mock_embedder_config: Mock, +) -> None: + mock_embedder_config.return_value = "text" + config = AbstractPipelineConfig.model_validate( + {"embedder_config": {"class_": "OpenAIEmbeddings", "params_": {}}} + ) + assert isinstance(config.embedder_config, dict) + assert "default" in config.embedder_config + config.parse() + mock_embedder_config.assert_called_once() + assert config._global_data["embedder_config"]["default"] == "text" + + +@patch("neo4j_graphrag.experimental.pipeline.config.object_config.EmbedderConfig.parse") +def test_abstract_pipeline_config_embedder_config_is_a_dict_with_names( + mock_embedder_config: Mock, +) -> None: + mock_embedder_config.return_value = "text" + config = AbstractPipelineConfig.model_validate( + { + "embedder_config": { + "my_embedder": {"class_": "OpenAIEmbeddings", "params_": {}} + } + } + ) + assert isinstance(config.embedder_config, dict) + assert "my_embedder" in config.embedder_config + config.parse() + mock_embedder_config.assert_called_once() + assert config._global_data["embedder_config"]["my_embedder"] == "text" + + +@patch("neo4j_graphrag.experimental.pipeline.config.object_config.EmbedderConfig.parse") +def test_abstract_pipeline_config_embedder_config_is_a_dict_with_llm( + mock_embedder_config: Mock, embedder: Embedder +) -> None: + config = AbstractPipelineConfig.model_validate( + { + "embedder_config": { + "my_embedder": embedder, + } + } + ) + assert isinstance(config.embedder_config, dict) + assert "my_embedder" in config.embedder_config + config.parse() + assert not mock_embedder_config.called + assert config._global_data["embedder_config"]["my_embedder"] == embedder + + +@patch("neo4j_graphrag.experimental.pipeline.config.object_config.EmbedderConfig.parse") +def test_abstract_pipeline_config_embedder_config_is_an_embedder( + mock_embedder_config: Mock, embedder: Embedder +) -> None: + config = AbstractPipelineConfig.model_validate( + { + "embedder_config": embedder, + } + ) + assert isinstance(config.embedder_config, dict) + assert "default" in config.embedder_config + config.parse() + assert not mock_embedder_config.called + assert config._global_data["embedder_config"]["default"] == embedder + + +def test_abstract_pipeline_config_parse_global_data_no_extras(driver: Mock) -> None: + config = AbstractPipelineConfig( + neo4j_config={"my_driver": Neo4jDriverType(driver)}, + ) + gd = config._parse_global_data() + assert gd == { + "extras": {}, + "neo4j_config": { + "my_driver": driver, + }, + "llm_config": {}, + "embedder_config": {}, + } + + +@patch( + "neo4j_graphrag.experimental.pipeline.config.param_resolver.ParamFromEnvConfig.resolve" +) +def test_abstract_pipeline_config_parse_global_data_extras( + mock_param_resolver: Mock, +) -> None: + mock_param_resolver.return_value = "my value" + config = AbstractPipelineConfig( + extras={"my_extra_var": ParamFromEnvConfig(var_="some key")}, + ) + gd = config._parse_global_data() + assert gd == { + "extras": {"my_extra_var": "my value"}, + "neo4j_config": {}, + "llm_config": {}, + "embedder_config": {}, + } + + +@patch( + "neo4j_graphrag.experimental.pipeline.config.param_resolver.ParamFromEnvConfig.resolve" +) +@patch( + "neo4j_graphrag.experimental.pipeline.config.object_config.Neo4jDriverType.parse" +) +def test_abstract_pipeline_config_parse_global_data_use_extras_in_other_config( + mock_neo4j_parser: Mock, + mock_param_resolver: Mock, +) -> None: + """Parser is able to read variables in the 'extras' section of config + to instantiate another object (neo4j.Driver in this test case) + """ + mock_param_resolver.side_effect = ["bolt://myhost", "myuser", "mypwd"] + mock_neo4j_parser.return_value = "my driver" + config = AbstractPipelineConfig( + extras={ + "my_extra_uri": ParamFromEnvConfig(var_="some key"), + "my_extra_user": ParamFromEnvConfig(var_="some key"), + "my_extra_pwd": ParamFromEnvConfig(var_="some key"), + }, + neo4j_config={ + "my_driver": Neo4jDriverType( + Neo4jDriverConfig( + params_=dict( + uri=ParamFromKeyConfig(key_="extras.my_extra_uri"), + user=ParamFromKeyConfig(key_="extras.my_extra_user"), + password=ParamFromKeyConfig(key_="extras.my_extra_pwd"), + ) + ) + ) + }, + ) + gd = config._parse_global_data() + expected_extras = { + "my_extra_uri": "bolt://myhost", + "my_extra_user": "myuser", + "my_extra_pwd": "mypwd", + } + assert gd["extras"] == expected_extras + assert gd["neo4j_config"] == {"my_driver": "my driver"} + mock_neo4j_parser.assert_called_once_with({"extras": expected_extras}) + + +@patch("neo4j_graphrag.experimental.pipeline.config.object_config.ComponentType.parse") +def test_abstract_pipeline_config_resolve_component_definition_no_run_params( + mock_component_parse: Mock, + component: Component, +) -> None: + mock_component_parse.return_value = component + config = AbstractPipelineConfig() + component_type = ComponentType(component) + component_definition = config._resolve_component_definition("name", component_type) + assert isinstance(component_definition, ComponentDefinition) + mock_component_parse.assert_called_once_with({}) + assert component_definition.name == "name" + assert component_definition.component == component + assert component_definition.run_params == {} + + +@patch( + "neo4j_graphrag.experimental.pipeline.config.pipeline_config.AbstractPipelineConfig.resolve_params" +) +@patch("neo4j_graphrag.experimental.pipeline.config.object_config.ComponentType.parse") +def test_abstract_pipeline_config_resolve_component_definition_with_run_params( + mock_component_parse: Mock, + mock_resolve_params: Mock, + component: Component, +) -> None: + mock_component_parse.return_value = component + mock_resolve_params.return_value = {"param": "resolver param result"} + config = AbstractPipelineConfig() + component_type = ComponentType( + ComponentConfig(class_="", params_={}, run_params_={"param1": "value1"}) + ) + component_definition = config._resolve_component_definition("name", component_type) + assert isinstance(component_definition, ComponentDefinition) + mock_component_parse.assert_called_once_with({}) + assert component_definition.name == "name" + assert component_definition.component == component + assert component_definition.run_params == {"param": "resolver param result"} + mock_resolve_params.assert_called_once_with({"param1": "value1"}) diff --git a/tests/unit/experimental/pipeline/config/test_runner.py b/tests/unit/experimental/pipeline/config/test_runner.py new file mode 100644 index 00000000..e69de29b