Skip to content

Commit

Permalink
Simplify and mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
stellasia committed Dec 4, 2024
1 parent d192667 commit 137cd76
Show file tree
Hide file tree
Showing 10 changed files with 42 additions and 162 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import os
from pathlib import Path

from neo4j_graphrag.experimental.pipeline.config.config_poc import PipelineRunner
from neo4j_graphrag.experimental.pipeline.config.config_parser import PipelineRunner
from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult

os.environ["NEO4J_URI"] = "bolt://localhost:7687"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import os
from pathlib import Path

from neo4j_graphrag.experimental.pipeline.config.config_poc import PipelineRunner
from neo4j_graphrag.experimental.pipeline.config.config_parser import PipelineRunner
from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult

os.environ["NEO4J_URI"] = "bolt://localhost:7687"
Expand Down
12 changes: 6 additions & 6 deletions src/neo4j_graphrag/experimental/components/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@

from neo4j_graphrag.exceptions import SchemaValidationError
from neo4j_graphrag.experimental.pipeline.component import Component, DataModel
from neo4j_graphrag.experimental.pipeline.types import (
EntityInputType,
RelationInputType,
)


class SchemaProperty(BaseModel):
Expand Down Expand Up @@ -57,9 +61,7 @@ class SchemaEntity(BaseModel):
properties: List[SchemaProperty] = []

@classmethod
def from_text_or_dict(
cls, input: SchemaEntity | str | dict[str, Union[str, dict[str, str]]]
) -> Self:
def from_text_or_dict(cls, input: EntityInputType) -> Self:
if isinstance(input, SchemaEntity):
return input
if isinstance(input, str):
Expand All @@ -77,9 +79,7 @@ class SchemaRelation(BaseModel):
properties: List[SchemaProperty] = []

@classmethod
def from_text_or_dict(
cls, input: SchemaRelation | str | dict[str, Union[str, dict[str, str]]]
) -> Self:
def from_text_or_dict(cls, input: RelationInputType) -> Self:
if isinstance(input, SchemaRelation):
return input
if isinstance(input, str):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Generic,
Literal,
Optional,
Sequence,
TypeVar,
Union,
)
Expand Down Expand Up @@ -56,9 +57,8 @@
)
from neo4j_graphrag.experimental.components.types import LexicalGraphConfig
from neo4j_graphrag.experimental.pipeline import Component, Pipeline
from neo4j_graphrag.experimental.pipeline.config.param_resolvers import PARAM_RESOLVERS
from neo4j_graphrag.experimental.pipeline.config.reader import ConfigReader
from neo4j_graphrag.experimental.pipeline.config.types import (
from neo4j_graphrag.experimental.pipeline.config.config_reader import ConfigReader
from neo4j_graphrag.experimental.pipeline.config.param_resolver import (
ParamConfig,
ParamToResolveConfig,
)
Expand Down Expand Up @@ -123,14 +123,7 @@ def resolve_param(self, param: ParamConfig) -> Any:
# values are already provided
return param
# all ParamToResolveConfig have a resolver_ field
resolver_name = param.resolver_
if resolver_name not in PARAM_RESOLVERS:
raise ValueError(
f"Resolver {resolver_name} not found in {PARAM_RESOLVERS.keys()}"
)
resolver_class = PARAM_RESOLVERS[resolver_name]
resolver = resolver_class(self._global_data)
return resolver.resolve(param)
return param.resolve(self._global_data)

def resolve_params(self, params: dict[str, ParamConfig]) -> dict[str, Any]:
"""Resolve all parameters
Expand Down Expand Up @@ -438,21 +431,21 @@ def get_run_params(self, user_input: dict[str, Any]) -> dict[str, Any]:

def get_neo4j_driver_by_name(self, name: str) -> neo4j.Driver:
drivers = self._global_data.get("neo4j_config", {})
return drivers.get(name)
return drivers.get(name) # type: ignore[no-any-return]

def get_default_neo4j_driver(self) -> neo4j.Driver:
return self.get_neo4j_driver_by_name(self.DEFAULT_NAME)

def get_llm_by_name(self, name: str) -> LLMInterface:
llms = self._global_data.get("llm_config", {})
return llms.get(name)
return llms.get(name) # type: ignore[no-any-return]

def get_default_llm(self) -> LLMInterface:
return self.get_llm_by_name(self.DEFAULT_NAME)

def get_embedder_by_name(self, name: str) -> Embedder:
embedders = self._global_data.get("embedder_config", {})
return embedders.get(name)
return embedders.get(name) # type: ignore[no-any-return]

def get_default_embedder(self) -> Embedder:
return self.get_embedder_by_name(self.DEFAULT_NAME)
Expand Down Expand Up @@ -525,8 +518,8 @@ class SimpleKGPipelineConfig(TemplatePipelineConfig):
)

from_pdf: bool = False
entities: list[EntityInputType] = []
relations: list[RelationInputType] = []
entities: Sequence[EntityInputType] = []
relations: Sequence[RelationInputType] = []
potential_schema: Optional[list[tuple[str, str, str]]] = None
on_error: OnError = OnError.IGNORE
prompt_template: Union[ERExtractionTemplate, str] = ERExtractionTemplate()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import neo4j
import pytest

from neo4j_graphrag.experimental.pipeline.config.config_poc import (
from neo4j_graphrag.experimental.pipeline.config.config_parser import (
AbstractPipelineConfig,
LLMConfig,
LLMType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,11 @@ def read_yaml(file_path: Path) -> Any:
def _guess_format_and_read(self, file_path: Path) -> dict[str, Any]:
extension = file_path.suffix.lower()
# Note: .suffix returns an empty string if Path has no extension
# if not returning a dict, parsing will fail later on
if extension in [".json"]:
return self.read_json(file_path)
return self.read_json(file_path) # type: ignore[no-any-return]
if extension in [".yaml", ".yml"]:
return self.read_yaml(file_path)
return self.read_yaml(file_path) # type: ignore[no-any-return]
raise ValueError(f"Unsupported extension: {extension}")

def read(self, file_path: Path) -> dict[str, Any]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
# limitations under the License.

import enum
from typing import Any, Literal, Union
import os
from typing import Any, ClassVar, Literal, Union

from pydantic import BaseModel

Expand All @@ -25,18 +26,30 @@ class ParamResolverEnum(str, enum.Enum):


class ParamToResolveConfig(BaseModel):
pass
def resolve(self, data: dict[str, Any]) -> Any:
raise NotImplementedError


class ParamFromEnvConfig(ParamToResolveConfig):
    # Parameter whose value is read from an environment variable.
    # `resolver_` is the discriminator tag for this resolver kind and
    # `var_` is the name of the environment variable to look up.
    resolver_: Literal[ParamResolverEnum.ENV] = ParamResolverEnum.ENV
    var_: str

    def resolve(self, data: dict[str, Any]) -> Any:
        """Return the value of the environment variable named by ``var_``.

        Args:
            data: Accepted for signature compatibility with the parent
                class; it is not read by this resolver.

        Returns:
            The variable's value, or ``None`` when it is not set.
        """
        env_value = os.getenv(self.var_)
        return env_value


class ParamFromKeyConfig(ParamToResolveConfig):
    # Parameter whose value is looked up inside the data mapping handed
    # to `resolve`. `key_` is a path of keys joined by `KEY_SEP`.
    resolver_: Literal[ParamResolverEnum.CONFIG_KEY] = ParamResolverEnum.CONFIG_KEY
    key_: str

    # Separator between the successive keys of `key_`.
    KEY_SEP: ClassVar[str] = "."

    def resolve(self, data: dict[str, Any]) -> Any:
        """Follow ``key_`` through ``data`` and return the value found.

        ``key_`` is split on ``KEY_SEP`` and each resulting segment is used
        to subscript the value reached so far, starting from ``data``.

        Args:
            data: Mapping to search, subscripted one segment at a time.

        Returns:
            The object reached after the last segment has been applied.

        Raises:
            KeyError: If a segment is missing from a dict along the path
                (any other error raised by the subscript propagates too).
        """
        segments = self.key_.split(self.KEY_SEP)
        current: Any = data
        for segment in segments:
            current = current[segment]
        return current


ParamConfig = Union[
float,
Expand Down
105 changes: 0 additions & 105 deletions src/neo4j_graphrag/experimental/pipeline/config/param_resolvers.py

This file was deleted.

12 changes: 6 additions & 6 deletions src/neo4j_graphrag/experimental/pipeline/kg_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from neo4j_graphrag.embeddings import Embedder
from neo4j_graphrag.experimental.components.types import LexicalGraphConfig
from neo4j_graphrag.experimental.pipeline.config.config_poc import (
from neo4j_graphrag.experimental.pipeline.config.config_parser import (
PipelineRunner,
SimpleKGPipelineConfig,
)
Expand Down Expand Up @@ -88,19 +88,19 @@ def __init__(
):
try:
config = SimpleKGPipelineConfig(
llm_config=llm,
neo4j_config=driver,
embedder_config=embedder,
# argument types are fixed in the Config object
llm_config=llm, # type: ignore[arg-type]
neo4j_config=driver, # type: ignore[arg-type]
embedder_config=embedder, # type: ignore[arg-type]
entities=entities or [],
relations=relations or [],
potential_schema=potential_schema,
from_pdf=from_pdf,
pdf_loader=pdf_loader,
kg_writer=kg_writer,
text_splitter=text_splitter,
on_error=on_error,
on_error=on_error, # type: ignore[arg-type]
prompt_template=prompt_template,
embedder=embedder,
perform_entity_resolution=perform_entity_resolution,
lexical_graph_config=lexical_graph_config,
neo4j_database=neo4j_database,
Expand Down
22 changes: 0 additions & 22 deletions tests/unit/experimental/pipeline/test_kg_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,28 +226,6 @@ def test_simple_kg_pipeline_no_entity_resolution(_: Mock) -> None:
assert "resolver" not in kg_builder.runner.pipeline


@mock.patch(
"neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._get_version",
return_value=(5, 23, 0),
)
@pytest.mark.asyncio
def test_simple_kg_pipeline_lexical_graph_config_attribute(_: Mock) -> None:
llm = MagicMock(spec=LLMInterface)
driver = MagicMock(spec=neo4j.Driver)
embedder = MagicMock(spec=Embedder)

lexical_graph_config = LexicalGraphConfig()
kg_builder = SimpleKGPipeline(
llm=llm,
driver=driver,
embedder=embedder,
on_error="IGNORE",
lexical_graph_config=lexical_graph_config,
)

assert kg_builder.runner.config.lexical_graph_config == lexical_graph_config


@mock.patch(
"neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._get_version",
return_value=(5, 23, 0),
Expand Down

0 comments on commit 137cd76

Please sign in to comment.