Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(cli): better error when no instance configured #259

Merged
merged 3 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions packages/ragbits-cli/src/ragbits/cli/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from pathlib import Path
from typing import Protocol, TypeVar

import typer
from pydantic.alias_generators import to_snake
from rich.console import Console

from ragbits.core.config import CoreConfig, core_config
from ragbits.core.utils.config_handling import InvalidConfigError, NoDefaultConfigError, WithConstructionConfig

WithConstructionConfigT_co = TypeVar("WithConstructionConfigT_co", bound=WithConstructionConfig, covariant=True)


# Using a Protocol instead of simply typing the `cls` argument to `get_instance_or_exit`
# as `type[WithConstructionConfigT]` in order to workaround the issue of mypy not allowing abstract classes
# to be used as types: https://github.com/python/mypy/issues/4717
class WithConstructionConfigProtocol(Protocol[WithConstructionConfigT_co]):
@classmethod
def subclass_from_defaults(
cls, defaults: CoreConfig, factory_path_override: str | None = None, yaml_path_override: Path | None = None
) -> WithConstructionConfigT_co: ...


def get_instance_or_exit(
cls: WithConstructionConfigProtocol[WithConstructionConfigT_co],
type_name: str | None = None,
yaml_path: Path | None = None,
factory_path: str | None = None,
yaml_path_argument_name: str = "--yaml-path",
factory_path_argument_name: str = "--factory-path",
) -> WithConstructionConfigT_co:
"""
Returns an instance of the provided class, initialized using its `subclass_from_defaults` method.
If the instance can't be created, prints an error message and exits the program.

Args:
cls: The class to create an instance of.
type_name: The name to use in error messages. If None, inferred from the class name.
yaml_path: Path to a YAML configuration file to use for initialization.
factory_path: Python path to a factory function to use for initialization.
yaml_path_argument_name: The name of the argument to use in error messages for the YAML path.
factory_path_argument_name: The name of the argument to use in error messages for the factory path.
"""
if not isinstance(cls, type):
raise TypeError(f"get_instance_or_exit expects the `cls` argument to be a class, got {cls}")

type_name = type_name or to_snake(cls.__name__).replace("_", " ")
try:
return cls.subclass_from_defaults(
core_config,
factory_path_override=factory_path,
yaml_path_override=yaml_path,
)
except NoDefaultConfigError as e:
Console(
stderr=True
).print(f"""You need to provide the [b]{type_name}[/b] instance be used. You can do this by either:
- providing a path to a YAML configuration file with the [b]{yaml_path_argument_name}[/b] option
- providing a Python path to a function that creates a vector store with the [b]{factory_path_argument_name}[/b] option
- setting the default configuration or factory function in your project's [b]pyproject.toml[/b] file""")
raise typer.Exit(1) from e
except InvalidConfigError as e:
Console(stderr=True).print(e)
raise typer.Exit(1) from e
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ class InvalidConfigError(Exception):
"""


class NoDefaultConfigError(InvalidConfigError):
"""
An exception to be raised when no falling back to default configuration is not possible.
"""


def import_by_path(path: str, default_module: ModuleType | None) -> Any: # noqa: ANN401
"""
Retrieves and returns an object based on the string in the format of "module.submodule:object_name".
Expand Down Expand Up @@ -156,7 +162,7 @@ def subclass_from_defaults(
if default_config := defaults.default_instances_config.get(cls.configuration_key):
return cls.subclass_from_config(ObjectContructionConfig.model_validate(default_config))

raise InvalidConfigError(f"Could not find default factory or configuration for {cls.configuration_key}")
raise NoDefaultConfigError(f"Could not find default factory or configuration for {cls.configuration_key}")

@classmethod
def from_config(cls, config: dict) -> Self:
Expand Down
40 changes: 16 additions & 24 deletions packages/ragbits-core/src/ragbits/core/vector_stores/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,11 @@

import typer
from pydantic import BaseModel
from rich.console import Console

from ragbits.cli import cli_state, print_output
from ragbits.cli._utils import get_instance_or_exit
from ragbits.cli.state import OutputType
from ragbits.core.config import core_config
from ragbits.core.embeddings.base import Embeddings
from ragbits.core.utils.config_handling import InvalidConfigError
from ragbits.core.vector_stores.base import VectorStore, VectorStoreOptions

vector_stores_app = typer.Typer(no_args_is_help=True)
Expand All @@ -27,17 +25,13 @@ class CLIState:
@vector_stores_app.callback()
def common_args(
factory_path: str | None = None,
yaml_path: str | None = None,
yaml_path: Path | None = None,
) -> None:
try:
state.vector_store = VectorStore.subclass_from_defaults(
core_config,
factory_path_override=factory_path,
yaml_path_override=Path.cwd() / yaml_path if yaml_path else None,
)
except InvalidConfigError as e:
Console(stderr=True).print(e)
raise typer.Exit(1) from e
state.vector_store = get_instance_or_exit(
VectorStore,
factory_path=factory_path,
yaml_path=yaml_path,
)


@vector_stores_app.command(name="list")
Expand Down Expand Up @@ -85,7 +79,7 @@ def query(
k: int = 5,
max_distance: float | None = None,
embedder_factory_path: str | None = None,
embedder_yaml_path: str | None = None,
embedder_yaml_path: Path | None = None,
) -> None:
"""
Query the chosen vector store.
Expand All @@ -95,16 +89,14 @@ async def run() -> None:
if state.vector_store is None:
raise ValueError("Vector store not initialized")

try:
embedder: Embeddings = Embeddings.subclass_from_defaults(
core_config,
factory_path_override=embedder_factory_path,
yaml_path_override=Path.cwd() / embedder_yaml_path if embedder_yaml_path else None,
)
except InvalidConfigError as e:
Console(stderr=True).print(e)
raise typer.Exit(1) from e

embedder = get_instance_or_exit(
Embeddings,
factory_path=embedder_factory_path,
yaml_path=embedder_yaml_path,
factory_path_argument_name="--embedder_factory_path",
yaml_path_argument_name="--embedder_yaml_path",
type_name="embedder",
)
search_vector = await embedder.embed_text([text])

options = VectorStoreOptions(k=k, max_distance=max_distance)
Expand Down
2 changes: 1 addition & 1 deletion packages/ragbits-core/tests/cli/test_vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def test_vector_store_cli_no_store():
"""
runner = CliRunner(mix_stderr=False)
result = runner.invoke(vector_stores_app, ["list"])
assert "Could not find default factory or configuration" in result.stderr
assert "You need to provide the vector store instance be used" in result.stderr


def test_vector_store_list():
Expand Down
Loading