Skip to content

Commit

Permalink
feat(cli): better error when no instance configured
Browse files Browse the repository at this point in the history
  • Loading branch information
ludwiktrammer committed Dec 18, 2024
1 parent 94b1e94 commit ee4a1a9
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 25 deletions.
51 changes: 51 additions & 0 deletions packages/ragbits-cli/src/ragbits/cli/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from pathlib import Path
from typing import TypeVar

import typer
from pydantic.alias_generators import to_snake
from rich.console import Console

from ragbits.core.config import core_config
from ragbits.core.utils.config_handling import InvalidConfigError, NoDefaultConfigError, WithConstructionConfig

WithConstructionConfigT = TypeVar("WithConstructionConfigT", bound=WithConstructionConfig)


def get_instance_or_exit(
cls: type[WithConstructionConfigT],
type_name: str | None = None,
yaml_path: Path | None = None,
factory_path: str | None = None,
yaml_path_argument_name: str = "--yaml-path",
factory_path_argument_name: str = "--factory-path",
) -> WithConstructionConfigT:
"""
Returns an instance of the provided class, initialized using its `subclass_from_defaults` method.
If the instance can't be created, prints an error message and exits the program.
Args:
cls: The class to create an instance of.
type_name: The name to use in error messages. If None, inferred from the class name.
yaml_path: Path to a YAML configuration file to use for initialization.
factory_path: Python path to a factory function to use for initialization.
yaml_path_argument_name: The name of the argument to use in error messages for the YAML path.
factory_path_argument_name: The name of the argument to use in error messages for the factory path.
"""
type_name = type_name or to_snake(cls.__name__).replace("_", " ")
try:
return cls.subclass_from_defaults(
core_config,
factory_path_override=factory_path,
yaml_path_override=yaml_path,
)
except NoDefaultConfigError as e:
Console(
stderr=True
).print(f"""You need to provide the [b]{type_name}[/b] instance be used. You can do this by either:
- providing a path to a YAML configuration file with the [b]{yaml_path_argument_name}[/b] option
- providing a Python path to a function that creates a vector store with the [b]{factory_path_argument_name}[/b] option
- setting the default configuration or factory function in your project's [b]pyproject.toml[/b] file""")
raise typer.Exit(1) from e
except InvalidConfigError as e:
Console(stderr=True).print(e)
raise typer.Exit(1) from e
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ class InvalidConfigError(Exception):
"""


class NoDefaultConfigError(InvalidConfigError):
"""
An exception to be raised when no falling back to default configuration is not possible.
"""


def import_by_path(path: str, default_module: ModuleType | None) -> Any: # noqa: ANN401
"""
Retrieves and returns an object based on the string in the format of "module.submodule:object_name".
Expand Down Expand Up @@ -155,7 +161,7 @@ def subclass_from_defaults(
if default_config := defaults.default_instances_config.get(cls.configuration_key):
return cls.subclass_from_config(ObjectContructionConfig.model_validate(default_config))

raise InvalidConfigError(f"Could not find default factory or configuration for {cls.configuration_key}")
raise NoDefaultConfigError(f"Could not find default factory or configuration for {cls.configuration_key}")

@classmethod
def from_config(cls, config: dict) -> Self:
Expand Down
40 changes: 16 additions & 24 deletions packages/ragbits-core/src/ragbits/core/vector_stores/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,11 @@

import typer
from pydantic import BaseModel
from rich.console import Console

from ragbits.cli import cli_state, print_output
from ragbits.cli._utils import get_instance_or_exit
from ragbits.cli.state import OutputType
from ragbits.core.config import core_config
from ragbits.core.embeddings.base import Embeddings
from ragbits.core.utils.config_handling import InvalidConfigError
from ragbits.core.vector_stores.base import VectorStore, VectorStoreOptions

vector_stores_app = typer.Typer(no_args_is_help=True)
Expand All @@ -27,17 +25,13 @@ class CLIState:
@vector_stores_app.callback()
def common_args(
factory_path: str | None = None,
yaml_path: str | None = None,
yaml_path: Path | None = None,
) -> None:
try:
state.vector_store = VectorStore.subclass_from_defaults(
core_config,
factory_path_override=factory_path,
yaml_path_override=Path.cwd() / yaml_path if yaml_path else None,
)
except InvalidConfigError as e:
Console(stderr=True).print(e)
raise typer.Exit(1) from e
state.vector_store = get_instance_or_exit(
VectorStore,
factory_path=factory_path,
yaml_path=yaml_path,
)


@vector_stores_app.command(name="list")
Expand Down Expand Up @@ -85,7 +79,7 @@ def query(
k: int = 5,
max_distance: float | None = None,
embedder_factory_path: str | None = None,
embedder_yaml_path: str | None = None,
embedder_yaml_path: Path | None = None,
) -> None:
"""
Query the chosen vector store.
Expand All @@ -95,16 +89,14 @@ async def run() -> None:
if state.vector_store is None:
raise ValueError("Vector store not initialized")

try:
embedder = Embeddings.subclass_from_defaults(
core_config,
factory_path_override=embedder_factory_path,
yaml_path_override=Path.cwd() / embedder_yaml_path if embedder_yaml_path else None,
)
except InvalidConfigError as e:
Console(stderr=True).print(e)
raise typer.Exit(1) from e

embedder = get_instance_or_exit(
Embeddings,
factory_path=embedder_factory_path,
yaml_path=embedder_yaml_path,
factory_path_argument_name="--embedder_factory_path",
yaml_path_argument_name="--embedder_yaml_path",
type_name="embedder",
)
search_vector = await embedder.embed_text([text])

options = VectorStoreOptions(k=k, max_distance=max_distance)
Expand Down

0 comments on commit ee4a1a9

Please sign in to comment.