Skip to content

Commit

Permalink
feat(cli): better error when no instance configured (#259)
Browse files Browse the repository at this point in the history
  • Loading branch information
ludwiktrammer authored Dec 19, 2024
1 parent 80de16a commit 30796aa
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 26 deletions.
64 changes: 64 additions & 0 deletions packages/ragbits-cli/src/ragbits/cli/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from pathlib import Path
from typing import Protocol, TypeVar

import typer
from pydantic.alias_generators import to_snake
from rich.console import Console

from ragbits.core.config import CoreConfig, core_config
from ragbits.core.utils.config_handling import InvalidConfigError, NoDefaultConfigError, WithConstructionConfig

WithConstructionConfigT_co = TypeVar("WithConstructionConfigT_co", bound=WithConstructionConfig, covariant=True)


# Using a Protocol instead of simply typing the `cls` argument to `get_instance_or_exit`
# as `type[WithConstructionConfigT]` in order to workaround the issue of mypy not allowing abstract classes
# to be used as types: https://github.com/python/mypy/issues/4717
class WithConstructionConfigProtocol(Protocol[WithConstructionConfigT_co]):
@classmethod
def subclass_from_defaults(
cls, defaults: CoreConfig, factory_path_override: str | None = None, yaml_path_override: Path | None = None
) -> WithConstructionConfigT_co: ...


def get_instance_or_exit(
cls: WithConstructionConfigProtocol[WithConstructionConfigT_co],
type_name: str | None = None,
yaml_path: Path | None = None,
factory_path: str | None = None,
yaml_path_argument_name: str = "--yaml-path",
factory_path_argument_name: str = "--factory-path",
) -> WithConstructionConfigT_co:
"""
Returns an instance of the provided class, initialized using its `subclass_from_defaults` method.
If the instance can't be created, prints an error message and exits the program.
Args:
cls: The class to create an instance of.
type_name: The name to use in error messages. If None, inferred from the class name.
yaml_path: Path to a YAML configuration file to use for initialization.
factory_path: Python path to a factory function to use for initialization.
yaml_path_argument_name: The name of the argument to use in error messages for the YAML path.
factory_path_argument_name: The name of the argument to use in error messages for the factory path.
"""
if not isinstance(cls, type):
raise TypeError(f"get_instance_or_exit expects the `cls` argument to be a class, got {cls}")

type_name = type_name or to_snake(cls.__name__).replace("_", " ")
try:
return cls.subclass_from_defaults(
core_config,
factory_path_override=factory_path,
yaml_path_override=yaml_path,
)
except NoDefaultConfigError as e:
Console(
stderr=True
).print(f"""You need to provide the [b]{type_name}[/b] instance be used. You can do this by either:
- providing a path to a YAML configuration file with the [b]{yaml_path_argument_name}[/b] option
- providing a Python path to a function that creates a vector store with the [b]{factory_path_argument_name}[/b] option
- setting the default configuration or factory function in your project's [b]pyproject.toml[/b] file""")
raise typer.Exit(1) from e
except InvalidConfigError as e:
Console(stderr=True).print(e)
raise typer.Exit(1) from e
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ class InvalidConfigError(Exception):
"""


class NoDefaultConfigError(InvalidConfigError):
"""
An exception to be raised when no falling back to default configuration is not possible.
"""


def import_by_path(path: str, default_module: ModuleType | None) -> Any: # noqa: ANN401
"""
Retrieves and returns an object based on the string in the format of "module.submodule:object_name".
Expand Down Expand Up @@ -156,7 +162,7 @@ def subclass_from_defaults(
if default_config := defaults.default_instances_config.get(cls.configuration_key):
return cls.subclass_from_config(ObjectContructionConfig.model_validate(default_config))

raise InvalidConfigError(f"Could not find default factory or configuration for {cls.configuration_key}")
raise NoDefaultConfigError(f"Could not find default factory or configuration for {cls.configuration_key}")

@classmethod
def from_config(cls, config: dict) -> Self:
Expand Down
40 changes: 16 additions & 24 deletions packages/ragbits-core/src/ragbits/core/vector_stores/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,11 @@

import typer
from pydantic import BaseModel
from rich.console import Console

from ragbits.cli import cli_state, print_output
from ragbits.cli._utils import get_instance_or_exit
from ragbits.cli.state import OutputType
from ragbits.core.config import core_config
from ragbits.core.embeddings.base import Embeddings
from ragbits.core.utils.config_handling import InvalidConfigError
from ragbits.core.vector_stores.base import VectorStore, VectorStoreOptions

vector_stores_app = typer.Typer(no_args_is_help=True)
Expand All @@ -27,17 +25,13 @@ class CLIState:
@vector_stores_app.callback()
def common_args(
factory_path: str | None = None,
yaml_path: str | None = None,
yaml_path: Path | None = None,
) -> None:
try:
state.vector_store = VectorStore.subclass_from_defaults(
core_config,
factory_path_override=factory_path,
yaml_path_override=Path.cwd() / yaml_path if yaml_path else None,
)
except InvalidConfigError as e:
Console(stderr=True).print(e)
raise typer.Exit(1) from e
state.vector_store = get_instance_or_exit(
VectorStore,
factory_path=factory_path,
yaml_path=yaml_path,
)


@vector_stores_app.command(name="list")
Expand Down Expand Up @@ -85,7 +79,7 @@ def query(
k: int = 5,
max_distance: float | None = None,
embedder_factory_path: str | None = None,
embedder_yaml_path: str | None = None,
embedder_yaml_path: Path | None = None,
) -> None:
"""
Query the chosen vector store.
Expand All @@ -95,16 +89,14 @@ async def run() -> None:
if state.vector_store is None:
raise ValueError("Vector store not initialized")

try:
embedder: Embeddings = Embeddings.subclass_from_defaults(
core_config,
factory_path_override=embedder_factory_path,
yaml_path_override=Path.cwd() / embedder_yaml_path if embedder_yaml_path else None,
)
except InvalidConfigError as e:
Console(stderr=True).print(e)
raise typer.Exit(1) from e

embedder = get_instance_or_exit(
Embeddings,
factory_path=embedder_factory_path,
yaml_path=embedder_yaml_path,
factory_path_argument_name="--embedder_factory_path",
yaml_path_argument_name="--embedder_yaml_path",
type_name="embedder",
)
search_vector = await embedder.embed_text([text])

options = VectorStoreOptions(k=k, max_distance=max_distance)
Expand Down
2 changes: 1 addition & 1 deletion packages/ragbits-core/tests/cli/test_vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def test_vector_store_cli_no_store():
"""
runner = CliRunner(mix_stderr=False)
result = runner.invoke(vector_stores_app, ["list"])
assert "Could not find default factory or configuration" in result.stderr
assert "You need to provide the vector store instance be used" in result.stderr


def test_vector_store_list():
Expand Down

0 comments on commit 30796aa

Please sign in to comment.