Skip to content

Commit

Permalink
Add more comments to explain Config/Type models
Browse files Browse the repository at this point in the history
  • Loading branch information
stellasia committed Dec 12, 2024
1 parent 94483d8 commit 663df47
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from typing import Any, List, Optional, Union, cast

import json_repair

from pydantic import ValidationError, validate_call

from neo4j_graphrag.exceptions import LLMGenerationError
Expand Down
29 changes: 25 additions & 4 deletions src/neo4j_graphrag/experimental/pipeline/config/object_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Config for all parameters that can be both provided as object instance and
"""Config for all parameters that can be both provided as object instance or
config dict with 'class_' and 'params_' keys.
Nomenclature in this file:
- `*Config` models are used to represent "things" as dict to be used in a config file.
e.g.:
- neo4j.Driver => {"uri": "", "user": "", "password": ""}
- LLMInterface => {"class_": "OpenAI", "params_": {"model_name": "gpt-4o"}}
- `*Type` models are wrappers around an object and a 'Config' the object can be created
from. They are used to allow the instantiation of "PipelineConfig" either from
instantiated objects (when used in code) and from a config dict (when used to
load config from file).
"""

from __future__ import annotations
Expand Down Expand Up @@ -176,7 +187,7 @@ def parse(self, resolved_data: dict[str, Any] | None = None) -> neo4j.Driver:
if isinstance(self.root, neo4j.Driver):
return self.root
# self.root is a Neo4jDriverConfig object
return self.root.parse()
return self.root.parse(resolved_data)


class LLMConfig(ObjectConfig[LLMInterface]):
Expand All @@ -190,14 +201,19 @@ class LLMConfig(ObjectConfig[LLMInterface]):


class LLMType(RootModel): # type: ignore[type-arg]
"""A model to wrap LLMInterface and LLMConfig objects.
The `parse` method always returns an object inheriting from LLMInterface.
"""

root: Union[LLMInterface, LLMConfig]

model_config = ConfigDict(arbitrary_types_allowed=True)

def parse(self, resolved_data: dict[str, Any] | None = None) -> LLMInterface:
if isinstance(self.root, LLMInterface):
return self.root
return self.root.parse()
return self.root.parse(resolved_data)


class EmbedderConfig(ObjectConfig[Embedder]):
Expand All @@ -211,14 +227,19 @@ class EmbedderConfig(ObjectConfig[Embedder]):


class EmbedderType(RootModel): # type: ignore[type-arg]
"""A model to wrap Embedder and EmbedderConfig objects.
The `parse` method always returns an object inheriting from Embedder.
"""

root: Union[Embedder, EmbedderConfig]

model_config = ConfigDict(arbitrary_types_allowed=True)

def parse(self, resolved_data: dict[str, Any] | None = None) -> Embedder:
if isinstance(self.root, Embedder):
return self.root
return self.root.parse()
return self.root.parse(resolved_data)


class ComponentConfig(ObjectConfig[Component]):
Expand Down

0 comments on commit 663df47

Please sign in to comment.