Skip to content

Commit

Permalink
Adds more param resolvers
Browse files Browse the repository at this point in the history
  • Loading branch information
stellasia committed Dec 1, 2024
1 parent 536c9ff commit 668fbc2
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,20 @@
from typing import Any, Optional, Type, ClassVar, Union, Annotated, Literal, Self

import neo4j
from pydantic import BaseModel, field_validator, Tag, Discriminator, Field
from pydantic import BaseModel, field_validator, Tag, Discriminator, Field, Extra
from pydantic.v1.utils import deep_update

from neo4j_graphrag.experimental.pipeline import Component, Pipeline
from neo4j_graphrag.experimental.pipeline.config.reader import ConfigReader
from neo4j_graphrag.experimental.pipeline.config.template_parser.param_resolvers import PARAM_RESOLVERS
from neo4j_graphrag.experimental.pipeline.config.param_resolvers import PARAM_RESOLVERS
from neo4j_graphrag.experimental.pipeline.config.types import ParamConfig, \
ParamToResolveConfig
from neo4j_graphrag.experimental.pipeline.types import ComponentDefinition, \
PipelineDefinition
PipelineDefinition, ConnectionDefinition
from neo4j_graphrag.llm import LLMInterface


class AbstractConfig(BaseModel, abc.ABC):
class AbstractConfig(BaseModel, abc.ABC, extra=Extra.allow):
"""Base class for all configs.
Provides methods to get a class from a string and resolve a parameter defined by
a dict with a 'resolver_' key.
Expand All @@ -28,6 +31,10 @@ class AbstractConfig(BaseModel, abc.ABC):
"""
RESOLVER_KEY: ClassVar[str] = "resolver_"

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.global_data = {}

@classmethod
def _get_class(cls, class_path: str, optional_module: Optional[str] = None) -> type:
"""Get class from string and an optional module"""
Expand All @@ -47,19 +54,24 @@ def _get_class(cls, class_path: str, optional_module: Optional[str] = None) -> t
raise ValueError(f"Could not find {class_name} in {module_name}")
return klass

@classmethod
def resolve_params(cls, param: dict[str, Any], global_data: dict[str, Any]) -> Any:
"""Resolve parameter"""
# if isinstance(param, list):
# return [cls.resolve_params(p, global_data) for p in param]
if not isinstance(param, dict):
def resolve_param(self, param: ParamConfig) -> Any:
if not isinstance(param, ParamToResolveConfig):
return param
if not cls.RESOLVER_KEY in param:
return {key: cls.resolve_params(param[key], global_data) for key in param}
resolver_name = param.pop(cls.RESOLVER_KEY)
resolver_name = param.resolver_
if resolver_name not in PARAM_RESOLVERS:
raise ValueError(
f"Resolver {resolver_name} not found in {PARAM_RESOLVERS.keys()}"
)
resolver_klass = PARAM_RESOLVERS[resolver_name]
resolver = resolver_klass(global_data)
return resolver.resolve(**param)
resolver = resolver_klass(self.global_data)
return resolver.resolve(param)

def resolve_params(self, params: dict[str, ParamConfig]) -> dict[str, Any]:
"""Resolve parameters"""
return {
param_name: self.resolve_param(param)
for param_name, param in params.items()
}

@abc.abstractmethod
def parse(self) -> Any:
Expand All @@ -77,7 +89,7 @@ class ObjectConfig(AbstractConfig):
"""Path to class to be instantiated."""
name_: str = "default"
"""Object name in an array of objects."""
params_: dict[str, Any] = {}
params_: dict[str, ParamConfig] = {}
"""Initialization parameters."""

DEFAULT_MODULE: ClassVar[str] = "."
Expand Down Expand Up @@ -108,7 +120,7 @@ def parse(self) -> Any:
if not issubclass(klass, self.get_interface()):
raise ValueError(
f"Invalid class {klass}. Expected a subclass of {self.get_interface()}")
params = self.resolve_params(self.params_, {})
params = self.resolve_params(self.params_)
obj = klass(**params)
return obj

Expand All @@ -126,7 +138,7 @@ def validate_class(cls, class_: Any) -> str:
return "not used"

def parse(self) -> neo4j.Driver:
params = self.resolve_params(self.params_, {})
params = self.resolve_params(self.params_)
uri = params.pop("uri")
user = params.pop("user")
password = params.pop("password")
Expand All @@ -144,7 +156,7 @@ class LLMConfig(ObjectConfig):


class ComponentConfig(ObjectConfig):
run_params_: dict[str, Any] = {}
run_params_: dict[str, ParamConfig] = {}

DEFAULT_MODULE = "neo4j_graphrag.experimental.components"
INTERFACE = Component
Expand All @@ -158,6 +170,7 @@ class PipelineTemplateType(str, enum.Enum):
class AbstractPipelineConfig(AbstractConfig, abc.ABC):
neo4j_config: list[Neo4jDriverConfig]
llm_config: list[LLMConfig]
# extra parameter values that can be referenced from other places in the config file
extras: dict[str, Any] = {}

@field_validator("neo4j_config", mode="before")
Expand Down Expand Up @@ -186,17 +199,17 @@ def validate_names(cls, lst: list[Any]) -> list[Any]:
raise ValueError(f"names must be unique {most_common_item}")
return lst

def _resolve_component(self, config: ComponentConfig, global_data: dict[str, Any]) -> ComponentDefinition:
def _resolve_component(self, config: ComponentConfig) -> ComponentDefinition:
klass_path = config.class_
try:
klass = self._get_class(
klass_path, optional_module="neo4j_graphrag.experimental.components"
)
except ValueError:
raise ValueError(f"Component '{klass_path}' not found")
component_init_params = self.resolve_params(config.params_, global_data)
component_init_params = self.resolve_params(config.params_)
component = klass(**component_init_params)
component_run_params = self.resolve_params(config.run_params_, global_data)
component_run_params = self.resolve_params(config.run_params_)
return ComponentDefinition(
name=config.name_,
component=component,
Expand All @@ -213,36 +226,44 @@ def _parse_global_data(self) -> dict[str, Any]:
return {
"neo4j_config": drivers,
"llm_config": llms,
"extras": self.resolve_params(self.extras, {}),
"extras": self.resolve_params(self.extras),
}

@abc.abstractmethod
def _get_components(self, global_data: dict[str, Any]) -> list[ComponentDefinition]:
def _get_components(self) -> list[ComponentDefinition]:
...

@abc.abstractmethod
def _get_connections(self) -> list[ConnectionDefinition]:
...

def parse(self) -> PipelineDefinition:
global_data = self._parse_global_data()
self.global_data = self._parse_global_data()
return PipelineDefinition(
components=self._get_components(global_data),
connections=[]
components=self._get_components(),
connections=self._get_connections(),
)


class PipelineConfig(AbstractPipelineConfig):
component_config: list[ComponentConfig]
connection_config: list[ConnectionDefinition]
template_: Literal[PipelineTemplateType.NONE] = PipelineTemplateType.NONE

def _get_components(self, global_data: dict[str, Any]) -> list[ComponentDefinition]:
def _get_connections(self) -> list[ConnectionDefinition]:
return self.connection_config

def _get_components(self) -> list[ComponentDefinition]:
return [
self._resolve_component(component_config, global_data)
self._resolve_component(component_config)
for component_config in self.component_config
]


class TemplatePipelineConfig(AbstractPipelineConfig):
class TemplatePipelineConfig(AbstractPipelineConfig, abc.ABC):
COMPONENTS: ClassVar[list[str]] = []

def _get_components(self, global_data: dict[str, Any]) -> list[ComponentDefinition]:
def _get_components(self) -> list[ComponentDefinition]:
components = []
for component_name in self.COMPONENTS:
method = getattr(self, f"_get_{component_name}")
Expand Down Expand Up @@ -291,6 +312,8 @@ class SimpleKGPipelineConfig(TemplatePipelineConfig):
def get_pdf_loader(self) -> "ComponentDefinition | None":
    """Build the PDF loader component definition, when one is required.

    Returns:
        ``None`` when ``from_pdf`` is falsy or when no ``pdf_loader`` config
        is set; otherwise the result of ``_resolve_component`` applied to
        ``pdf_loader``.
    """
    # Both "not reading from PDF" and "no loader configured" yield no
    # component; make that explicit instead of falling off the end.
    if not self.from_pdf or not self.pdf_loader:
        return None
    return self._resolve_component(self.pdf_loader)


def get_discriminator_value(model: dict[str, Any]) -> str:
Expand Down Expand Up @@ -324,23 +347,27 @@ def _parse(cls, file_path: Union[str, Path]) -> PipelineDefinition:
if not isinstance(file_path, Path):
file_path = Path(file_path)
data = ConfigReader().read(file_path)
wrapper = PipelineConfigWrapper.model_validate(config=data)
wrapper = PipelineConfigWrapper.model_validate({"config": data})
return wrapper.parse()

def run(self, **data):
print(self.pipeline._nodes)
print(self.pipeline._edges)
return self.run_params
async def run(self, data):
run_param = deep_update(self.run_params, data)
return await self.pipeline.run(data=run_param)


if __name__ == "__main__":

from dotenv import load_dotenv
load_dotenv()

import asyncio
file_path = "examples/customize/build_graph/pipeline/pipeline_config.json"
runner = PipelineRunner.from_config_file(file_path)
print(runner.run())
print(
asyncio.run(
runner.run({"splitter": {"text": "blabla"}})
)
)

"""
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,18 @@
import os
from typing import Any

from .types import ParamFromEnvConfig, ParamResolverEnum, ParamToResolveConfig
from .types import ParamFromEnvConfig, ParamResolverEnum, ParamToResolveConfig, \
ParamFromKeyConfig


class ParamResolver:
"""A base class for all parameter resolvers."""

name: ParamResolverEnum

def __init__(self, data: dict[str, Any]) -> None:
self.data = data

def resolve(self, param: ParamToResolveConfig) -> Any:
raise NotImplementedError

Expand All @@ -50,9 +54,48 @@ def resolve(self, param: ParamFromEnvConfig) -> Any:
return os.environ.get(param.var_)


class ConfigKeyParamResolver(ParamResolver):
    """Resolve a parameter by looking it up in the resolver's data.

    A parameter is defined by a ``key_``. Nested keys are accessed by
    separating each key with ``KEY_SEP`` (a dot).

    Example:

    .. code-block:: python

        data = {
            "shared": {
                "env": "LOCAL"
            },
        }
        resolver = ConfigKeyParamResolver(data)
        resolver.resolve(ParamFromKeyConfig(key_="shared.env"))
        # Output: "LOCAL"

    In a config file, the parameter to resolve is written as:

    .. code-block:: python

        {
            "resolver_": "CONFIG_KEY",
            "key_": "shared.env"
        }
    """

    name = ParamResolverEnum.CONFIG_KEY
    KEY_SEP = "."

    def resolve(self, param: ParamFromKeyConfig) -> Any:
        """Return the value stored in ``self.data`` under ``param.key_``.

        Args:
            param: The parameter description; its ``key_`` is split on
                ``KEY_SEP`` and each part is used as a successive lookup.

        Raises:
            KeyError: If one of the key parts is missing.
        """
        d = self.data
        for k in param.key_.split(self.KEY_SEP):
            try:
                d = d[k]
            except KeyError:
                # Report the full path, not only the failing part, so the
                # offending config entry can be located.
                raise KeyError(
                    f"Key '{k}' not found while resolving '{param.key_}'"
                ) from None
        return d


# Registry mapping each resolver's ``name`` to its resolver class.
PARAM_RESOLVERS = {
    resolver_cls.name: resolver_cls
    for resolver_cls in (
        EnvParamResolver,
        ConfigKeyParamResolver,
    )
}
8 changes: 8 additions & 0 deletions src/neo4j_graphrag/experimental/pipeline/config/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@

class ParamResolverEnum(str, enum.Enum):
ENV = "ENV"
CONFIG_ARRAY = "CONFIG_ARRAY"
CONFIG_KEY = "CONFIG_KEY"


class ParamToResolveConfig(BaseModel):
Expand All @@ -40,10 +42,16 @@ class ParamFromEnvConfig(ParamToResolveConfig):
var_: str


class ParamFromKeyConfig(ParamToResolveConfig):
resolver_: Literal[ParamResolverEnum.CONFIG_KEY] = ParamResolverEnum.CONFIG_KEY
key_: str


ParamConfig = Union[
float,
str,
ParamFromEnvConfig,
ParamFromKeyConfig,
dict[str, Any],
]

Expand Down

0 comments on commit 668fbc2

Please sign in to comment.