Skip to content

Commit

Permalink
Add embedder config
Browse files Browse the repository at this point in the history
  • Loading branch information
stellasia committed Dec 3, 2024
1 parent 390be0c commit adfc908
Showing 1 changed file with 38 additions and 2 deletions.
40 changes: 38 additions & 2 deletions src/neo4j_graphrag/experimental/pipeline/config/config_poc.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
)
from pydantic.v1.utils import deep_update

from neo4j_graphrag.embeddings import Embedder
from neo4j_graphrag.experimental.components.pdf_loader import PdfLoader
from neo4j_graphrag.experimental.pipeline import Component, Pipeline
from neo4j_graphrag.experimental.pipeline.config.param_resolvers import PARAM_RESOLVERS
Expand Down Expand Up @@ -189,7 +190,6 @@ def validate_class(cls, class_: Any) -> str:
return "not used"

def parse(self, resolved_data: dict[str, Any] | None = None) -> neo4j.Driver:
print(self.params_)
params = self.resolve_params(self.params_)
uri = params.pop(
"uri"
Expand Down Expand Up @@ -241,6 +241,27 @@ def parse(self, resolved_data: dict[str, Any] | None = None) -> LLMInterface:
return self.root.parse()


class EmbedderConfig(ObjectConfig[Embedder]):
    """Config model used to describe an ``Embedder`` instance.

    Unless told otherwise, the target class is looked up in the
    ``neo4j_graphrag.embeddings`` module.
    """

    # Both constants are read by the ``ObjectConfig`` base class, which is
    # defined outside this section of the file.
    INTERFACE = Embedder
    DEFAULT_MODULE = "neo4j_graphrag.embeddings"


class EmbedderType(RootModel):  # type: ignore[type-arg]
    """Root model holding either a ready ``Embedder`` or an ``EmbedderConfig``.

    ``arbitrary_types_allowed`` is enabled so that an ``Embedder`` instance
    can be stored directly as the root value.
    """

    root: Union[Embedder, EmbedderConfig]

    model_config = ConfigDict(arbitrary_types_allowed=True)

    def parse(self, resolved_data: dict[str, Any] | None = None) -> Embedder:
        """Return the wrapped embedder, building it from its config if needed.

        Args:
            resolved_data: Accepted as part of the signature; this method
                does not read it.

        Returns:
            ``self.root`` when it is already an ``Embedder``; otherwise the
            result of calling ``parse()`` on the wrapped ``EmbedderConfig``.
        """
        wrapped = self.root
        return wrapped if isinstance(wrapped, Embedder) else wrapped.parse()


class ComponentConfig(ObjectConfig[Component]):
"""A config model for all components.
Expand Down Expand Up @@ -283,6 +304,7 @@ class AbstractPipelineConfig(AbstractConfig):
"""
neo4j_config: dict[str, Neo4jDriverType] = {}
llm_config: dict[str, LLMType] = {}
embedder_config: dict[str, EmbedderType] = {}
# extra parameters values that can be used in different places of the config file
extras: dict[str, Any] = {}

Expand All @@ -308,7 +330,16 @@ def validate_llms(
return {cls.DEFAULT_NAME: llms}
return llms

@field_validator("llm_config", "neo4j_config", mode="after")
@field_validator("embedder_config", mode="before")
@classmethod
def validate_embedders(
    cls, embedders: Union[EmbedderType, dict[str, Any]]
) -> dict[str, Any]:
    """Normalise the raw ``embedder_config`` value into a name -> value dict.

    Runs in "before" mode, i.e. on the raw input value.

    Args:
        embedders: Either a single embedder definition (any non-dict value,
            or a dict that contains a ``"params_"`` key) or a dict that is
            passed through unchanged.

    Returns:
        The input dict unchanged when it is a dict without a ``"params_"``
        key; otherwise a one-entry dict mapping ``cls.DEFAULT_NAME`` to the
        input.
    """
    # A dict carrying "params_" is treated as one embedder definition, not
    # as a mapping of several embedders.
    already_keyed = isinstance(embedders, dict) and "params_" not in embedders
    if already_keyed:
        return embedders
    return {cls.DEFAULT_NAME: embedders}

@field_validator("llm_config", "neo4j_config", "embedder_config", mode="after")
@classmethod
def validate_names(cls, lst: dict[str, Any]) -> dict[str, Any]:
if not lst:
Expand Down Expand Up @@ -350,9 +381,14 @@ def _parse_global_data(self) -> dict[str, Any]:
llm_name: llm_config.parse()
for llm_name, llm_config in self.llm_config.items()
}
embedders: dict[str, Embedder] = {
embedder_name: embedder_config.parse()
for embedder_name, embedder_config in self.embedder_config.items()
}
return {
"neo4j_config": drivers,
"llm_config": llms,
"embedder_config": embedders,
"extras": self.resolve_params(self.extras),
}

Expand Down

0 comments on commit adfc908

Please sign in to comment.