Skip to content

Commit

Permalink
Use cast to remove a type ignore comment
Browse files Browse the repository at this point in the history
  • Loading branch information
stellasia committed Dec 6, 2024
1 parent aedd9a3 commit 3122eda
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 60 deletions.
31 changes: 1 addition & 30 deletions src/neo4j_graphrag/experimental/pipeline/config/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,8 @@

from __future__ import annotations

import importlib
import logging
from typing import Any, Optional
from typing import Any

from pydantic import BaseModel, PrivateAttr

Expand All @@ -41,34 +40,6 @@ class AbstractConfig(BaseModel):
_global_data: dict[str, Any] = PrivateAttr({})
"""Additional parameter ignored by all Pydantic model_* methods."""

@classmethod
def _get_class(cls, class_path: str, optional_module: Optional[str] = None) -> type:
"""Get class from string and an optional module
Will first try to import the class from `class_path` alone. If it results in an ImportError,
will try to import from `f'{optional_module}.{class_path}'`
Args:
class_path (str): Class path with format 'my_module.MyClass'.
optional_module (Optional[str]): Optional module path. Used to provide a default path for some known objects and simplify the notation.
Raises:
ValueError: if the class can't be imported, even using the optional module.
"""
*modules, class_name = class_path.rsplit(".", 1)
module_name = modules[0] if modules else optional_module
if module_name is None:
raise ValueError("Must specify a module to import class from")
try:
module = importlib.import_module(module_name)
klass = getattr(module, class_name)
except (ImportError, AttributeError):
if optional_module and module_name != optional_module:
full_klass_path = optional_module + "." + class_path
return cls._get_class(full_klass_path)
raise ValueError(f"Could not find {class_name} in {module_name}")
return klass # type: ignore[no-any-return]

def resolve_param(self, param: ParamConfig) -> Any:
"""Finds the parameter value from its definition."""
if not isinstance(param, ParamToResolveConfig):
Expand Down
31 changes: 31 additions & 0 deletions src/neo4j_graphrag/experimental/pipeline/config/object_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,16 @@

from __future__ import annotations

import importlib
import logging
from typing import (
Any,
ClassVar,
Generic,
Optional,
TypeVar,
Union,
cast,
)

import neo4j
Expand Down Expand Up @@ -84,6 +87,34 @@ def get_module(self) -> str:
def get_interface(self) -> type:
return self.INTERFACE

@classmethod
def _get_class(cls, class_path: str, optional_module: Optional[str] = None) -> type:
"""Get class from string and an optional module
Will first try to import the class from `class_path` alone. If it results in an ImportError,
will try to import from `f'{optional_module}.{class_path}'`
Args:
class_path (str): Class path with format 'my_module.MyClass'.
optional_module (Optional[str]): Optional module path. Used to provide a default path for some known objects and simplify the notation.
Raises:
ValueError: if the class can't be imported, even using the optional module.
"""
*modules, class_name = class_path.rsplit(".", 1)
module_name = modules[0] if modules else optional_module
if module_name is None:
raise ValueError("Must specify a module to import class from")
try:
module = importlib.import_module(module_name)
klass = getattr(module, class_name)
except (ImportError, AttributeError):
if optional_module and module_name != optional_module:
full_klass_path = optional_module + "." + class_path
return cls._get_class(full_klass_path)
raise ValueError(f"Could not find {class_name} in {module_name}")
return cast(type, klass)

def parse(self, resolved_data: dict[str, Any] | None = None) -> T:
"""Import `class_`, resolve `params_` and instantiate object."""
self._global_data = resolved_data or {}
Expand Down
30 changes: 0 additions & 30 deletions tests/unit/experimental/pipeline/config/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,42 +14,12 @@
# limitations under the License.
from unittest.mock import patch

import pytest
from neo4j_graphrag.experimental.pipeline import Pipeline
from neo4j_graphrag.experimental.pipeline.config.base import AbstractConfig
from neo4j_graphrag.experimental.pipeline.config.param_resolver import (
ParamToResolveConfig,
)


def test_get_class_no_optional_module() -> None:
c = AbstractConfig()
klass = c._get_class("neo4j_graphrag.experimental.pipeline.Pipeline")
assert klass == Pipeline


def test_get_class_optional_module() -> None:
c = AbstractConfig()
klass = c._get_class(
"Pipeline", optional_module="neo4j_graphrag.experimental.pipeline"
)
assert klass == Pipeline


def test_get_class_path_and_optional_module() -> None:
c = AbstractConfig()
klass = c._get_class(
"pipeline.Pipeline", optional_module="neo4j_graphrag.experimental"
)
assert klass == Pipeline


def test_get_class_wrong_path() -> None:
c = AbstractConfig()
with pytest.raises(ValueError):
c._get_class("MyClass")


def test_resolve_param_with_param_to_resolve_object() -> None:
c = AbstractConfig()
with patch(
Expand Down
31 changes: 31 additions & 0 deletions tests/unit/experimental/pipeline/config/test_object_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,49 @@
from unittest.mock import patch

import neo4j
import pytest
from neo4j_graphrag.embeddings import Embedder, OpenAIEmbeddings
from neo4j_graphrag.experimental.pipeline import Pipeline
from neo4j_graphrag.experimental.pipeline.config.object_config import (
EmbedderConfig,
EmbedderType,
LLMConfig,
LLMType,
Neo4jDriverConfig,
Neo4jDriverType,
ObjectConfig,
)
from neo4j_graphrag.llm import LLMInterface, OpenAILLM


def test_get_class_no_optional_module() -> None:
c: ObjectConfig[object] = ObjectConfig()
klass = c._get_class("neo4j_graphrag.experimental.pipeline.Pipeline")
assert klass == Pipeline


def test_get_class_optional_module() -> None:
c: ObjectConfig[object] = ObjectConfig()
klass = c._get_class(
"Pipeline", optional_module="neo4j_graphrag.experimental.pipeline"
)
assert klass == Pipeline


def test_get_class_path_and_optional_module() -> None:
c: ObjectConfig[object] = ObjectConfig()
klass = c._get_class(
"pipeline.Pipeline", optional_module="neo4j_graphrag.experimental"
)
assert klass == Pipeline


def test_get_class_wrong_path() -> None:
c: ObjectConfig[object] = ObjectConfig()
with pytest.raises(ValueError):
c._get_class("MyClass")


def test_neo4j_driver_config() -> None:
config = Neo4jDriverConfig.model_validate(
{
Expand Down

0 comments on commit 3122eda

Please sign in to comment.