diff --git a/src/neo4j_graphrag/experimental/pipeline/config/config_poc.py b/src/neo4j_graphrag/experimental/pipeline/config/config_poc.py index 8b8ce220..50443749 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/config_poc.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/config_poc.py @@ -32,6 +32,7 @@ ) from pydantic.v1.utils import deep_update +from neo4j_graphrag.embeddings import Embedder from neo4j_graphrag.experimental.components.pdf_loader import PdfLoader from neo4j_graphrag.experimental.pipeline import Component, Pipeline from neo4j_graphrag.experimental.pipeline.config.param_resolvers import PARAM_RESOLVERS @@ -189,7 +190,6 @@ def validate_class(cls, class_: Any) -> str: return "not used" def parse(self, resolved_data: dict[str, Any] | None = None) -> neo4j.Driver: - print(self.params_) params = self.resolve_params(self.params_) uri = params.pop( "uri" @@ -241,6 +241,27 @@ def parse(self, resolved_data: dict[str, Any] | None = None) -> LLMInterface: return self.root.parse() +class EmbedderConfig(ObjectConfig[Embedder]): + """Configuration for any Embedder object. + + By default, will try to import from `neo4j_graphrag.embeddings`. + """ + + DEFAULT_MODULE = "neo4j_graphrag.embeddings" + INTERFACE = Embedder + + +class EmbedderType(RootModel): # type: ignore[type-arg] + root: Union[Embedder, EmbedderConfig] + + model_config = ConfigDict(arbitrary_types_allowed=True) + + def parse(self, resolved_data: dict[str, Any] | None = None) -> Embedder: + if isinstance(self.root, Embedder): + return self.root + return self.root.parse() + + class ComponentConfig(ObjectConfig[Component]): """A config model for all components. @@ -283,6 +304,7 @@ class AbstractPipelineConfig(AbstractConfig): """ neo4j_config: dict[str, Neo4jDriverType] = {} llm_config: dict[str, LLMType] = {} + embedder_config: dict[str, EmbedderType] = {} # extra parameters values that can be used in different places of the config file extras: dict[str, Any] = {} @@ -308,7 +330,16 @@ def validate_llms( return {cls.DEFAULT_NAME: llms} return llms - @field_validator("llm_config", "neo4j_config", mode="after") + @field_validator("embedder_config", mode="before") + @classmethod + def validate_embedders( + cls, embedders: Union[EmbedderType, dict[str, Any]] + ) -> dict[str, Any]: + if not isinstance(embedders, dict) or "params_" in embedders: + return {cls.DEFAULT_NAME: embedders} + return embedders + + @field_validator("llm_config", "neo4j_config", "embedder_config", mode="after") @classmethod def validate_names(cls, lst: dict[str, Any]) -> dict[str, Any]: if not lst: @@ -350,9 +381,14 @@ def _parse_global_data(self) -> dict[str, Any]: llm_name: llm_config.parse() for llm_name, llm_config in self.llm_config.items() } + embedders: dict[str, Embedder] = { + embedder_name: embedder_config.parse() + for embedder_name, embedder_config in self.embedder_config.items() + } return { "neo4j_config": drivers, "llm_config": llms, + "embedder_config": embedders, "extras": self.resolve_params(self.extras), }