Skip to content

Commit

Permalink
Add root types to allow instantiation from python object directly
Browse files Browse the repository at this point in the history
  • Loading branch information
stellasia committed Dec 3, 2024
1 parent 1d884a1 commit e2eb5d0
Showing 1 changed file with 119 additions and 52 deletions.
171 changes: 119 additions & 52 deletions src/neo4j_graphrag/experimental/pipeline/config/config_poc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,20 @@
import importlib
from collections import Counter
from pathlib import Path
from typing import Annotated, Any, ClassVar, Literal, Optional, Self, Type, Union
from typing import Annotated, Any, ClassVar, Literal, Optional, Self, Union

import neo4j
from pydantic import BaseModel, Discriminator, Extra, Field, Tag, field_validator
from pydantic.v1.utils import deep_update
from pydantic import (
BaseModel,
ConfigDict,
Discriminator,
Field,
PrivateAttr,
RootModel,
Tag,
field_validator,
)
from pydantic.utils import deep_update

from neo4j_graphrag.experimental.components.pdf_loader import PdfLoader
from neo4j_graphrag.experimental.pipeline import Component, Pipeline
Expand All @@ -30,7 +39,7 @@
from neo4j_graphrag.llm import LLMInterface


class AbstractConfig(BaseModel, abc.ABC, extra=Extra.allow):
class AbstractConfig(BaseModel, abc.ABC):
"""Base class for all configs.
Provides methods to get a class from a string and resolve a parameter defined by
a dict with a 'resolver_' key.
Expand All @@ -40,9 +49,7 @@ class AbstractConfig(BaseModel, abc.ABC, extra=Extra.allow):

RESOLVER_KEY: ClassVar[str] = "resolver_"

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.global_data = {}
_global_data: dict[str, Any] = PrivateAttr({})

@classmethod
def _get_class(cls, class_path: str, optional_module: Optional[str] = None) -> type:
Expand Down Expand Up @@ -72,7 +79,7 @@ def resolve_param(self, param: ParamConfig) -> Any:
f"Resolver {resolver_name} not found in {PARAM_RESOLVERS.keys()}"
)
resolver_klass = PARAM_RESOLVERS[resolver_name]
resolver = resolver_klass(self.global_data)
resolver = resolver_klass(self._global_data)
return resolver.resolve(param)

def resolve_params(self, params: dict[str, ParamConfig]) -> dict[str, Any]:
Expand All @@ -95,10 +102,8 @@ class ObjectConfig(AbstractConfig):
to uniquely identify them.
"""

class_: str | None = None
class_: str | None = Field(default=None, validate_default=True)
"""Path to class to be instantiated."""
name_: str = "default"
"""Object name in an array of objects."""
params_: dict[str, ParamConfig] = {}
"""Initialization parameters."""

Expand Down Expand Up @@ -157,56 +162,90 @@ def parse(self) -> neo4j.Driver:
return driver


class Neo4jDriverType(RootModel):
root: Union[neo4j.Driver, Neo4jDriverConfig]

model_config = ConfigDict(arbitrary_types_allowed=True)

def parse(self) -> neo4j.Driver:
if isinstance(self.root, neo4j.Driver):
return self.root
return self.root.parse()


class LLMConfig(ObjectConfig):
DEFAULT_MODULE = "neo4j_graphrag.llm"
INTERFACE = LLMInterface


class LLMType(RootModel):
root: Union[LLMInterface, LLMConfig]

model_config = ConfigDict(arbitrary_types_allowed=True)

def parse(self) -> LLMInterface:
if isinstance(self.root, LLMInterface):
return self.root
return self.root.parse()


class ComponentConfig(ObjectConfig):
run_params_: dict[str, ParamConfig] = {}

DEFAULT_MODULE = "neo4j_graphrag.experimental.components"
INTERFACE = Component


class ComponentType(RootModel):
root: Union[Component, ComponentConfig]

model_config = ConfigDict(arbitrary_types_allowed=True)


class PipelineTemplateType(str, enum.Enum):
NONE = "none"
SIMPLE_KG_PIPELINE = "SimpleKGPipeline"


class AbstractPipelineConfig(AbstractConfig, abc.ABC):
neo4j_config: list[Neo4jDriverConfig]
llm_config: list[LLMConfig]
class AbstractPipelineConfig(AbstractConfig):
neo4j_config: dict[str, Neo4jDriverType] = {}
llm_config: dict[str, LLMConfig] = {}
# extra parameters values that can be used in different places of the config file
extras: dict[str, Any] = {}

DEFAULT_NAME: ClassVar[str] = "default"

@field_validator("neo4j_config", mode="before")
@classmethod
def validate_drivers(cls, drivers: Union[Any, list[Any]]) -> list[Any]:
if not isinstance(drivers, list):
drivers = [drivers]
return drivers # type: ignore[no-any-return]
def validate_drivers(
cls, drivers: Union[Neo4jDriverType, dict[str, Neo4jDriverType]]
) -> dict[str, Neo4jDriverType]:
if not isinstance(drivers, dict) or "params_" in drivers:
return {cls.DEFAULT_NAME: drivers}
return drivers

@field_validator("llm_config", mode="before")
@classmethod
def validate_llms(cls, llms: Union[Any, list[Any]]) -> list[Any]:
if not isinstance(llms, list):
llms = [llms]
return llms # type: ignore[no-any-return]
def validate_llms(
cls, llms: Union[LLMType, dict[str, LLMType]]
) -> dict[str, LLMType]:
if not isinstance(llms, dict) or "params_" in llms:
return {cls.DEFAULT_NAME: llms}
return llms

@field_validator("llm_config", "neo4j_config", mode="after")
@classmethod
def validate_names(cls, lst: list[Any]) -> list[Any]:
def validate_names(cls, lst: dict[str, Any]) -> dict[str, Any]:
if not lst:
return lst
c = Counter([item.name_ for item in lst])
c = Counter(lst.keys())
most_common_item = c.most_common(1)
most_common_count = most_common_item[0][1]
if most_common_count > 1:
raise ValueError(f"names must be unique {most_common_item}")
return lst

def _resolve_component(self, config: ComponentConfig) -> ComponentDefinition:
def _resolve_component(self, config: ComponentConfig) -> Component:
klass_path = config.class_
if klass_path is None:
raise ValueError(f"Class {klass_path} is not defined")
Expand All @@ -218,38 +257,47 @@ def _resolve_component(self, config: ComponentConfig) -> ComponentDefinition:
raise ValueError(f"Component '{klass_path}' not found")
component_init_params = self.resolve_params(config.params_)
component = klass(**component_init_params)
component_run_params = self.resolve_params(config.run_params_)
return component

def _resolve_component_definition(
self, name: str, config: ComponentType
) -> ComponentDefinition:
component = config.root
component_run_params = {}
if not isinstance(component, Component):
component = self._resolve_component(config.root)
component_run_params = self.resolve_params(config.root.run_params_)
return ComponentDefinition(
name=config.name_,
name=name,
component=component,
run_params=component_run_params,
)

def _parse_global_data(self) -> dict[str, Any]:
drivers = {d.name_: d.parse() for d in self.neo4j_config}
llms = {llm.name_: llm.parse() for llm in self.llm_config}
drivers = {d: config.parse() for d, config in self.neo4j_config.items()}
llms = {llm: config.parse() for llm, config in self.llm_config.items()}
return {
"neo4j_config": drivers,
"llm_config": llms,
"extras": self.resolve_params(self.extras),
}

@abc.abstractmethod
def _get_components(self) -> list[ComponentDefinition]: ...
def _get_components(self) -> list[ComponentDefinition]:
return []

@abc.abstractmethod
def _get_connections(self) -> list[ConnectionDefinition]: ...
def _get_connections(self) -> list[ConnectionDefinition]:
return []

def parse(self) -> PipelineDefinition:
self.global_data = self._parse_global_data()
self._global_data = self._parse_global_data()
return PipelineDefinition(
components=self._get_components(),
connections=self._get_connections(),
)


class PipelineConfig(AbstractPipelineConfig):
component_config: list[ComponentConfig]
component_config: dict[str, ComponentType]
connection_config: list[ConnectionDefinition]
template_: Literal[PipelineTemplateType.NONE] = PipelineTemplateType.NONE

Expand All @@ -258,12 +306,12 @@ def _get_connections(self) -> list[ConnectionDefinition]:

def _get_components(self) -> list[ComponentDefinition]:
return [
self._resolve_component(component_config)
for component_config in self.component_config
self._resolve_component_definition(name, component_config)
for name, component_config in self.component_config.items()
]


class TemplatePipelineConfig(AbstractPipelineConfig, abc.ABC):
class TemplatePipelineConfig(AbstractPipelineConfig):
COMPONENTS: ClassVar[list[str]] = []

def _get_components(self) -> list[ComponentDefinition]:
Expand All @@ -288,11 +336,11 @@ def _get_components(self) -> list[ComponentDefinition]:
class SimpleKGPipelineConfig(TemplatePipelineConfig):
COMPONENTS: ClassVar[list[str]] = [
"pdf_loader",
"splitter",
"chunk_embedder",
"extractor",
"writer",
"entity_resolver",
# "splitter",
# "chunk_embedder",
# "extractor",
# "writer",
# "entity_resolver",
]

template_: Literal[PipelineTemplateType.SIMPLE_KG_PIPELINE] = (
Expand All @@ -312,26 +360,27 @@ class SimpleKGPipelineConfig(TemplatePipelineConfig):
# entities: list[SchemaEntity] = []
# relations: list[SchemaRelation] = []

def get_pdf_loader(self) -> Component | None:
def _get_pdf_loader(self) -> Component | None:
if not self.from_pdf:
return None
if self.pdf_loader:
return self._resolve_component(self.pdf_loader)
return PdfLoader()


def get_discriminator_value(model: dict[str, Any]) -> str:
def get_discriminator_value(model: Any) -> PipelineTemplateType:
template_ = None
if "template_" in model:
return model["template_"] or PipelineTemplateType.SIMPLE_KG_PIPELINE.value
return PipelineTemplateType.NONE.value
template_ = model["template_"]
if hasattr(model, "template_"):
template_ = model.template_
return PipelineTemplateType(template_) or PipelineTemplateType.NONE


class PipelineConfigWrapper(BaseModel):
config: Union[
Annotated[PipelineConfig, Tag(PipelineTemplateType.NONE.value)],
Annotated[
SimpleKGPipelineConfig, Tag(PipelineTemplateType.SIMPLE_KG_PIPELINE.value)
],
Annotated[PipelineConfig, Tag(PipelineTemplateType.NONE)],
Annotated[SimpleKGPipelineConfig, Tag(PipelineTemplateType.SIMPLE_KG_PIPELINE)],
] = Field(discriminator=Discriminator(get_discriminator_value))

def parse(self) -> PipelineDefinition:
Expand All @@ -343,6 +392,11 @@ def __init__(self, pipeline_definition: PipelineDefinition) -> None:
self.pipeline = Pipeline.from_definition(pipeline_definition)
self.run_params = pipeline_definition.get_run_params()

@classmethod
def from_config(cls, config: AbstractPipelineConfig) -> Self:
wrapper = PipelineConfigWrapper.model_validate({"config": config})
return cls(wrapper.parse())

@classmethod
def from_config_file(cls, file_path: Union[str, Path]) -> Self:
pipeline_definition = cls._parse(file_path)
Expand Down Expand Up @@ -370,7 +424,20 @@ async def run(self, data: dict[str, Any]) -> PipelineResult:

file_path = "examples/customize/build_graph/pipeline/pipeline_config.json"
runner = PipelineRunner.from_config_file(file_path)
print(asyncio.run(runner.run({"splitter": {"text": "blabla"}})))
print(runner)
# print(asyncio.run(runner.run({"splitter": {"text": "blabla"}})))

config = SimpleKGPipelineConfig.model_validate(
{
"template_": PipelineTemplateType.SIMPLE_KG_PIPELINE.value,
"neo4j_config": neo4j.GraphDatabase.driver("bolt://", auth=("", "")),
"from_pdf": True,
}
)
print(config)
runner = PipelineRunner.from_config(config)
print(runner.pipeline._nodes)


"""
{
Expand Down

0 comments on commit e2eb5d0

Please sign in to comment.