diff --git a/examples/customize/build_graph/pipeline/pipeline_from_config_file.py b/examples/customize/build_graph/pipeline/pipeline_from_config_file.py index 6f205267..917efefa 100644 --- a/examples/customize/build_graph/pipeline/pipeline_from_config_file.py +++ b/examples/customize/build_graph/pipeline/pipeline_from_config_file.py @@ -13,7 +13,7 @@ import os from pathlib import Path -from neo4j_graphrag.experimental.pipeline.config.config_poc import PipelineRunner +from neo4j_graphrag.experimental.pipeline.config.config_parser import PipelineRunner from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult os.environ["NEO4J_URI"] = "bolt://localhost:7687" diff --git a/examples/customize/build_graph/pipeline/simple_kg_pipeline_from_config_file.py b/examples/customize/build_graph/pipeline/simple_kg_pipeline_from_config_file.py index 19c74d99..837ce5d7 100644 --- a/examples/customize/build_graph/pipeline/simple_kg_pipeline_from_config_file.py +++ b/examples/customize/build_graph/pipeline/simple_kg_pipeline_from_config_file.py @@ -15,7 +15,7 @@ import os from pathlib import Path -from neo4j_graphrag.experimental.pipeline.config.config_poc import PipelineRunner +from neo4j_graphrag.experimental.pipeline.config.config_parser import PipelineRunner from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult os.environ["NEO4J_URI"] = "bolt://localhost:7687" diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py index ed6e0bdc..c22b00c5 100644 --- a/src/neo4j_graphrag/experimental/components/schema.py +++ b/src/neo4j_graphrag/experimental/components/schema.py @@ -21,6 +21,10 @@ from neo4j_graphrag.exceptions import SchemaValidationError from neo4j_graphrag.experimental.pipeline.component import Component, DataModel +from neo4j_graphrag.experimental.pipeline.types import ( + EntityInputType, + RelationInputType, +) class SchemaProperty(BaseModel): @@ -57,9 +61,7 @@ class SchemaEntity(BaseModel): properties: List[SchemaProperty] = [] @classmethod - def from_text_or_dict( - cls, input: SchemaEntity | str | dict[str, Union[str, dict[str, str]]] - ) -> Self: + def from_text_or_dict(cls, input: EntityInputType) -> Self: if isinstance(input, SchemaEntity): return input if isinstance(input, str): @@ -77,9 +79,7 @@ class SchemaRelation(BaseModel): properties: List[SchemaProperty] = [] @classmethod - def from_text_or_dict( - cls, input: SchemaRelation | str | dict[str, Union[str, dict[str, str]]] - ) -> Self: + def from_text_or_dict(cls, input: RelationInputType) -> Self: if isinstance(input, SchemaRelation): return input if isinstance(input, str): diff --git a/src/neo4j_graphrag/experimental/pipeline/config/config_poc.py b/src/neo4j_graphrag/experimental/pipeline/config/config_parser.py similarity index 97% rename from src/neo4j_graphrag/experimental/pipeline/config/config_poc.py rename to src/neo4j_graphrag/experimental/pipeline/config/config_parser.py index bd75f28f..ed3bfd0c 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/config_poc.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/config_parser.py @@ -14,6 +14,7 @@ Generic, Literal, Optional, + Sequence, TypeVar, Union, ) @@ -56,9 +57,8 @@ ) from neo4j_graphrag.experimental.components.types import LexicalGraphConfig from neo4j_graphrag.experimental.pipeline import Component, Pipeline -from neo4j_graphrag.experimental.pipeline.config.param_resolvers import PARAM_RESOLVERS -from neo4j_graphrag.experimental.pipeline.config.reader import ConfigReader -from neo4j_graphrag.experimental.pipeline.config.types import ( +from neo4j_graphrag.experimental.pipeline.config.config_reader import ConfigReader +from neo4j_graphrag.experimental.pipeline.config.param_resolver import ( ParamConfig, ParamToResolveConfig, ) @@ -123,14 +123,7 @@ def resolve_param(self, param: ParamConfig) -> Any: # values are already provided return param # all ParamToResolveConfig have a resolver_ field - resolver_name = param.resolver_ - if resolver_name not in PARAM_RESOLVERS: - raise ValueError( - f"Resolver {resolver_name} not found in {PARAM_RESOLVERS.keys()}" - ) - resolver_class = PARAM_RESOLVERS[resolver_name] - resolver = resolver_class(self._global_data) - return resolver.resolve(param) + return param.resolve(self._global_data) def resolve_params(self, params: dict[str, ParamConfig]) -> dict[str, Any]: """Resolve all parameters @@ -438,21 +431,21 @@ def get_run_params(self, user_input: dict[str, Any]) -> dict[str, Any]: def get_neo4j_driver_by_name(self, name: str) -> neo4j.Driver: drivers = self._global_data.get("neo4j_config", {}) - return drivers.get(name) + return drivers.get(name) # type: ignore[no-any-return] def get_default_neo4j_driver(self) -> neo4j.Driver: return self.get_neo4j_driver_by_name(self.DEFAULT_NAME) def get_llm_by_name(self, name: str) -> LLMInterface: llms = self._global_data.get("llm_config", {}) - return llms.get(name) + return llms.get(name) # type: ignore[no-any-return] def get_default_llm(self) -> LLMInterface: return self.get_llm_by_name(self.DEFAULT_NAME) def get_embedder_by_name(self, name: str) -> Embedder: embedders = self._global_data.get("embedder_config", {}) - return embedders.get(name) + return embedders.get(name) # type: ignore[no-any-return] def get_default_embedder(self) -> Embedder: return self.get_embedder_by_name(self.DEFAULT_NAME) @@ -525,8 +518,8 @@ class SimpleKGPipelineConfig(TemplatePipelineConfig): ) from_pdf: bool = False - entities: list[EntityInputType] = [] - relations: list[RelationInputType] = [] + entities: Sequence[EntityInputType] = [] + relations: Sequence[RelationInputType] = [] potential_schema: Optional[list[tuple[str, str, str]]] = None on_error: OnError = OnError.IGNORE prompt_template: Union[ERExtractionTemplate, str] = ERExtractionTemplate() diff --git a/src/neo4j_graphrag/experimental/pipeline/config/config_poc_test.py b/src/neo4j_graphrag/experimental/pipeline/config/config_poc_test.py index 1ca2ac43..6820fd5a 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/config_poc_test.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/config_poc_test.py @@ -3,7 +3,7 @@ import neo4j import pytest -from neo4j_graphrag.experimental.pipeline.config.config_poc import ( +from neo4j_graphrag.experimental.pipeline.config.config_parser import ( AbstractPipelineConfig, LLMConfig, LLMType, diff --git a/src/neo4j_graphrag/experimental/pipeline/config/reader.py b/src/neo4j_graphrag/experimental/pipeline/config/config_reader.py similarity index 90% rename from src/neo4j_graphrag/experimental/pipeline/config/reader.py rename to src/neo4j_graphrag/experimental/pipeline/config/config_reader.py index 9b489462..b1bacc37 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/reader.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/config_reader.py @@ -61,10 +61,11 @@ def read_yaml(file_path: Path) -> Any: def _guess_format_and_read(self, file_path: Path) -> dict[str, Any]: extension = file_path.suffix.lower() # Note: .suffix returns an empty string if Path has no extension + # if not returning a dict, pasing will fail later on if extension in [".json"]: - return self.read_json(file_path) + return self.read_json(file_path) # type: ignore[no-any-return] if extension in [".yaml", ".yml"]: - return self.read_yaml(file_path) + return self.read_yaml(file_path) # type: ignore[no-any-return] raise ValueError(f"Unsupported extension: {extension}") def read(self, file_path: Path) -> dict[str, Any]: diff --git a/src/neo4j_graphrag/experimental/pipeline/config/types.py b/src/neo4j_graphrag/experimental/pipeline/config/param_resolver.py similarity index 73% rename from src/neo4j_graphrag/experimental/pipeline/config/types.py rename to src/neo4j_graphrag/experimental/pipeline/config/param_resolver.py index 276f8030..24190add 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/types.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/param_resolver.py @@ -14,7 +14,8 @@ # limitations under the License. import enum -from typing import Any, Literal, Union +import os +from typing import Any, ClassVar, Literal, Union from pydantic import BaseModel @@ -25,18 +26,30 @@ class ParamResolverEnum(str, enum.Enum): class ParamToResolveConfig(BaseModel): - pass + def resolve(self, data: dict[str, Any]) -> Any: + raise NotImplementedError class ParamFromEnvConfig(ParamToResolveConfig): resolver_: Literal[ParamResolverEnum.ENV] = ParamResolverEnum.ENV var_: str + def resolve(self, data: dict[str, Any]) -> Any: + return os.environ.get(self.var_) + class ParamFromKeyConfig(ParamToResolveConfig): resolver_: Literal[ParamResolverEnum.CONFIG_KEY] = ParamResolverEnum.CONFIG_KEY key_: str + KEY_SEP: ClassVar[str] = "." + + def resolve(self, data: dict[str, Any]) -> Any: + d = data + for k in self.key_.split(self.KEY_SEP): + d = d[k] + return d + ParamConfig = Union[ float, diff --git a/src/neo4j_graphrag/experimental/pipeline/config/param_resolvers.py b/src/neo4j_graphrag/experimental/pipeline/config/param_resolvers.py deleted file mode 100644 index 117303e4..00000000 --- a/src/neo4j_graphrag/experimental/pipeline/config/param_resolvers.py +++ /dev/null @@ -1,105 +0,0 @@ -# Copyright (c) "Neo4j" -# Neo4j Sweden AB [https://neo4j.com] -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# https://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -from typing import Any - -from .types import ( - ParamFromEnvConfig, - ParamFromKeyConfig, - ParamResolverEnum, - ParamToResolveConfig, -) - - -class ParamResolver: - """A base class for all parameter resolvers.""" - - name: ParamResolverEnum - - def __init__(self, data: dict[str, Any]) -> None: - self.data = data - - def resolve(self, param: ParamToResolveConfig) -> Any: - raise NotImplementedError - - -class EnvParamResolver(ParamResolver): - """Resolve a parameter by reading its value - in the environment variables. - - Example: - - .. code-block:: python - - import os - os.environ["MY_ENV_VAR"] = "LOCAL" - - resolver = EnvParamResolver() - resolver.resolve("MY_ENV_VAR") - # Output: "LOCAL" - """ - - name = ParamResolverEnum.ENV - - def resolve(self, param: ParamFromEnvConfig) -> Any: - return os.environ.get(param.var_) - - -class ConfigKeyParamResolver(ParamResolver): - """Resolve a parameter by searching through the - config file. A parameter is defined by a `key_`. - - It is possible to access nested keys by separating - each key with dots. For instance: - - Example: - - .. code-block:: python - - data = { - "shared": { - "env": "LOCAL" - }, - "section": { - "env": { - "resolver_": "KEY", - "key_": "shared.env" - } - } - } - - resolver = ConfigKeyParamResolver(data) - resolver.resolve("shared.env") - # Output: "LOCAL" - """ - - name = ParamResolverEnum.CONFIG_KEY - KEY_SEP = "." - - def resolve(self, param: ParamFromKeyConfig) -> Any: - d = self.data - for k in param.key_.split(self.KEY_SEP): - d = d[k] - return d - - -PARAM_RESOLVERS = { - resolver.name: resolver - for resolver in [ - EnvParamResolver, - ConfigKeyParamResolver, - ] -} diff --git a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py index 26c71d7b..db4d4e8d 100644 --- a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py @@ -22,7 +22,7 @@ from neo4j_graphrag.embeddings import Embedder from neo4j_graphrag.experimental.components.types import LexicalGraphConfig -from neo4j_graphrag.experimental.pipeline.config.config_poc import ( +from neo4j_graphrag.experimental.pipeline.config.config_parser import ( PipelineRunner, SimpleKGPipelineConfig, ) @@ -88,9 +88,10 @@ def __init__( ): try: config = SimpleKGPipelineConfig( - llm_config=llm, - neo4j_config=driver, - embedder_config=embedder, + # argument type are fixed in the Config object + llm_config=llm, # type: ignore[arg-type] + neo4j_config=driver, # type: ignore[arg-type] + embedder_config=embedder, # type: ignore[arg-type] entities=entities or [], relations=relations or [], potential_schema=potential_schema, @@ -98,9 +99,8 @@ def __init__( pdf_loader=pdf_loader, kg_writer=kg_writer, text_splitter=text_splitter, - on_error=on_error, + on_error=on_error, # type: ignore[arg-type] prompt_template=prompt_template, - embedder=embedder, perform_entity_resolution=perform_entity_resolution, lexical_graph_config=lexical_graph_config, neo4j_database=neo4j_database, diff --git a/tests/unit/experimental/pipeline/test_kg_builder.py b/tests/unit/experimental/pipeline/test_kg_builder.py index 2c80c426..aba443ca 100644 --- a/tests/unit/experimental/pipeline/test_kg_builder.py +++ b/tests/unit/experimental/pipeline/test_kg_builder.py @@ -226,28 +226,6 @@ def test_simple_kg_pipeline_no_entity_resolution(_: Mock) -> None: assert "resolver" not in kg_builder.runner.pipeline -@mock.patch( - "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._get_version", - return_value=(5, 23, 0), -) -@pytest.mark.asyncio -def test_simple_kg_pipeline_lexical_graph_config_attribute(_: Mock) -> None: - llm = MagicMock(spec=LLMInterface) - driver = MagicMock(spec=neo4j.Driver) - embedder = MagicMock(spec=Embedder) - - lexical_graph_config = LexicalGraphConfig() - kg_builder = SimpleKGPipeline( - llm=llm, - driver=driver, - embedder=embedder, - on_error="IGNORE", - lexical_graph_config=lexical_graph_config, - ) - - assert kg_builder.runner.config.lexical_graph_config == lexical_graph_config - - @mock.patch( "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._get_version", return_value=(5, 23, 0),