Skip to content

Commit

Permalink
A bit of mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
stellasia committed Dec 1, 2024
1 parent 668fbc2 commit 1d884a1
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 39 deletions.
77 changes: 40 additions & 37 deletions src/neo4j_graphrag/experimental/pipeline/config/config_poc.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,32 @@
"""Generic config for all pipelines + specific implementation for "templates"
such as the SimpleKGPipeline.
"""

import abc
import enum
import importlib
from collections import Counter
from pathlib import Path
from typing import Any, Optional, Type, ClassVar, Union, Annotated, Literal, Self
from typing import Annotated, Any, ClassVar, Literal, Optional, Self, Type, Union

import neo4j
from pydantic import BaseModel, field_validator, Tag, Discriminator, Field, Extra
from pydantic import BaseModel, Discriminator, Extra, Field, Tag, field_validator
from pydantic.v1.utils import deep_update

from neo4j_graphrag.experimental.components.pdf_loader import PdfLoader
from neo4j_graphrag.experimental.pipeline import Component, Pipeline
from neo4j_graphrag.experimental.pipeline.config.reader import ConfigReader
from neo4j_graphrag.experimental.pipeline.config.param_resolvers import PARAM_RESOLVERS
from neo4j_graphrag.experimental.pipeline.config.types import ParamConfig, \
ParamToResolveConfig
from neo4j_graphrag.experimental.pipeline.types import ComponentDefinition, \
PipelineDefinition, ConnectionDefinition
from neo4j_graphrag.experimental.pipeline.config.reader import ConfigReader
from neo4j_graphrag.experimental.pipeline.config.types import (
ParamConfig,
ParamToResolveConfig,
)
from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult
from neo4j_graphrag.experimental.pipeline.types import (
ComponentDefinition,
ConnectionDefinition,
PipelineDefinition,
)
from neo4j_graphrag.llm import LLMInterface


Expand All @@ -29,9 +37,10 @@ class AbstractConfig(BaseModel, abc.ABC, extra=Extra.allow):
Each subclass must implement a 'parse' method that returns the relevant object.
"""

RESOLVER_KEY: ClassVar[str] = "resolver_"

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.global_data = {}

Expand All @@ -52,7 +61,7 @@ def _get_class(cls, class_path: str, optional_module: Optional[str] = None) -> t
full_klass_path = optional_module + "." + class_path
return cls._get_class(full_klass_path)
raise ValueError(f"Could not find {class_name} in {module_name}")
return klass
return klass # type: ignore[no-any-return]

def resolve_param(self, param: ParamConfig) -> Any:
if not isinstance(param, ParamToResolveConfig):
Expand Down Expand Up @@ -85,6 +94,7 @@ class ObjectConfig(AbstractConfig):
Since they can be included in a list, objects must have a name
to uniquely identify them.
"""

class_: str | None = None
"""Path to class to be instantiated."""
name_: str = "default"
Expand All @@ -94,7 +104,7 @@ class ObjectConfig(AbstractConfig):

DEFAULT_MODULE: ClassVar[str] = "."
"""Default module to import the class from."""
INTERFACE: ClassVar[Type] = object
INTERFACE: ClassVar[type] = object
"""Constraint on the class (must be a subclass of)."""
REQUIRED_PARAMS: ClassVar[list[str]] = []
"""List of required parameters for this object constructor."""
Expand All @@ -110,7 +120,7 @@ def validate_params(cls, params_: dict[str, Any]) -> dict[str, Any]:
def get_module(self) -> str:
return self.DEFAULT_MODULE

def get_interface(self) -> Type:
def get_interface(self) -> type:
return self.INTERFACE

def parse(self) -> Any:
Expand All @@ -119,7 +129,8 @@ def parse(self) -> Any:
klass = self._get_class(self.class_, self.get_module())
if not issubclass(klass, self.get_interface()):
raise ValueError(
f"Invalid class {klass}. Expected a subclass of {self.get_interface()}")
f"Invalid class {klass}. Expected a subclass of {self.get_interface()}"
)
params = self.resolve_params(self.params_)
obj = klass(**params)
return obj
Expand All @@ -142,11 +153,7 @@ def parse(self) -> neo4j.Driver:
uri = params.pop("uri")
user = params.pop("user")
password = params.pop("password")
driver = neo4j.GraphDatabase.driver(
uri,
auth=(user, password),
**params
)
driver = neo4j.GraphDatabase.driver(uri, auth=(user, password), **params)
return driver


Expand Down Expand Up @@ -178,14 +185,14 @@ class AbstractPipelineConfig(AbstractConfig, abc.ABC):
def validate_drivers(cls, drivers: Union[Any, list[Any]]) -> list[Any]:
if not isinstance(drivers, list):
drivers = [drivers]
return drivers
return drivers # type: ignore[no-any-return]

@field_validator("llm_config", mode="before")
@classmethod
def validate_llms(cls, llms: Union[Any, list[Any]]) -> list[Any]:
if not isinstance(llms, list):
llms = [llms]
return llms
return llms # type: ignore[no-any-return]

@field_validator("llm_config", "neo4j_config", mode="after")
@classmethod
Expand All @@ -201,6 +208,8 @@ def validate_names(cls, lst: list[Any]) -> list[Any]:

def _resolve_component(self, config: ComponentConfig) -> ComponentDefinition:
klass_path = config.class_
if klass_path is None:
raise ValueError(f"Class {klass_path} is not defined")
try:
klass = self._get_class(
klass_path, optional_module="neo4j_graphrag.experimental.components"
Expand All @@ -217,25 +226,19 @@ def _resolve_component(self, config: ComponentConfig) -> ComponentDefinition:
)

def _parse_global_data(self) -> dict[str, Any]:
drivers = {
d.name_: d.parse() for d in self.neo4j_config
}
llms = {
llm.name_: llm.parse() for llm in self.llm_config
}
drivers = {d.name_: d.parse() for d in self.neo4j_config}
llms = {llm.name_: llm.parse() for llm in self.llm_config}
return {
"neo4j_config": drivers,
"llm_config": llms,
"extras": self.resolve_params(self.extras),
}

@abc.abstractmethod
def _get_components(self) -> list[ComponentDefinition]:
...
def _get_components(self) -> list[ComponentDefinition]: ...

@abc.abstractmethod
def _get_connections(self) -> list[ConnectionDefinition]:
...
def _get_connections(self) -> list[ConnectionDefinition]: ...

def parse(self) -> PipelineDefinition:
self.global_data = self._parse_global_data()
Expand Down Expand Up @@ -314,6 +317,7 @@ def get_pdf_loader(self) -> Component | None:
return None
if self.pdf_loader:
return self._resolve_component(self.pdf_loader)
return PdfLoader()


def get_discriminator_value(model: dict[str, Any]) -> str:
Expand All @@ -325,7 +329,9 @@ def get_discriminator_value(model: dict[str, Any]) -> str:
class PipelineConfigWrapper(BaseModel):
config: Union[
Annotated[PipelineConfig, Tag(PipelineTemplateType.NONE.value)],
Annotated[SimpleKGPipelineConfig, Tag(PipelineTemplateType.SIMPLE_KG_PIPELINE.value)],
Annotated[
SimpleKGPipelineConfig, Tag(PipelineTemplateType.SIMPLE_KG_PIPELINE.value)
],
] = Field(discriminator=Discriminator(get_discriminator_value))

def parse(self) -> PipelineDefinition:
Expand All @@ -350,24 +356,21 @@ def _parse(cls, file_path: Union[str, Path]) -> PipelineDefinition:
wrapper = PipelineConfigWrapper.model_validate({"config": data})
return wrapper.parse()

async def run(self, data):
async def run(self, data: dict[str, Any]) -> PipelineResult:
run_param = deep_update(self.run_params, data)
return await self.pipeline.run(data=run_param)


if __name__ == "__main__":

from dotenv import load_dotenv

load_dotenv()

import asyncio

file_path = "examples/customize/build_graph/pipeline/pipeline_config.json"
runner = PipelineRunner.from_config_file(file_path)
print(
asyncio.run(
runner.run({"splitter": {"text": "blabla"}})
)
)
print(asyncio.run(runner.run({"splitter": {"text": "blabla"}})))

"""
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@
import os
from typing import Any

from .types import ParamFromEnvConfig, ParamResolverEnum, ParamToResolveConfig, \
ParamFromKeyConfig
from .types import (
ParamFromEnvConfig,
ParamFromKeyConfig,
ParamResolverEnum,
ParamToResolveConfig,
)


class ParamResolver:
Expand Down

0 comments on commit 1d884a1

Please sign in to comment.