From 30796aa4b3d632b0b66a6fcb6ea2eff0a931d576 Mon Sep 17 00:00:00 2001 From: Ludwik Trammer Date: Thu, 19 Dec 2024 15:20:30 +0100 Subject: [PATCH] feat(cli): better error when no instance configured (#259) --- .../ragbits-cli/src/ragbits/cli/_utils.py | 64 +++++++++++++++++++ .../src/ragbits/core/utils/config_handling.py | 8 ++- .../src/ragbits/core/vector_stores/_cli.py | 40 +++++------- .../tests/cli/test_vector_store.py | 2 +- 4 files changed, 88 insertions(+), 26 deletions(-) create mode 100644 packages/ragbits-cli/src/ragbits/cli/_utils.py diff --git a/packages/ragbits-cli/src/ragbits/cli/_utils.py b/packages/ragbits-cli/src/ragbits/cli/_utils.py new file mode 100644 index 000000000..f72f06f01 --- /dev/null +++ b/packages/ragbits-cli/src/ragbits/cli/_utils.py @@ -0,0 +1,64 @@ +from pathlib import Path +from typing import Protocol, TypeVar + +import typer +from pydantic.alias_generators import to_snake +from rich.console import Console + +from ragbits.core.config import CoreConfig, core_config +from ragbits.core.utils.config_handling import InvalidConfigError, NoDefaultConfigError, WithConstructionConfig + +WithConstructionConfigT_co = TypeVar("WithConstructionConfigT_co", bound=WithConstructionConfig, covariant=True) + + +# Using a Protocol instead of simply typing the `cls` argument to `get_instance_or_exit` +# as `type[WithConstructionConfigT]` in order to work around the issue of mypy not allowing abstract classes +# to be used as types: https://github.com/python/mypy/issues/4717 +class WithConstructionConfigProtocol(Protocol[WithConstructionConfigT_co]): + @classmethod + def subclass_from_defaults( + cls, defaults: CoreConfig, factory_path_override: str | None = None, yaml_path_override: Path | None = None + ) -> WithConstructionConfigT_co: ...
+ + +def get_instance_or_exit( + cls: WithConstructionConfigProtocol[WithConstructionConfigT_co], + type_name: str | None = None, + yaml_path: Path | None = None, + factory_path: str | None = None, + yaml_path_argument_name: str = "--yaml-path", + factory_path_argument_name: str = "--factory-path", +) -> WithConstructionConfigT_co: + """ + Returns an instance of the provided class, initialized using its `subclass_from_defaults` method. + If the instance can't be created, prints an error message and exits the program. + + Args: + cls: The class to create an instance of. + type_name: The name to use in error messages. If None, inferred from the class name. + yaml_path: Path to a YAML configuration file to use for initialization. + factory_path: Python path to a factory function to use for initialization. + yaml_path_argument_name: The name of the argument to use in error messages for the YAML path. + factory_path_argument_name: The name of the argument to use in error messages for the factory path. + """ + if not isinstance(cls, type): + raise TypeError(f"get_instance_or_exit expects the `cls` argument to be a class, got {cls}") + + type_name = type_name or to_snake(cls.__name__).replace("_", " ") + try: + return cls.subclass_from_defaults( + core_config, + factory_path_override=factory_path, + yaml_path_override=yaml_path, + ) + except NoDefaultConfigError as e: + Console( + stderr=True + ).print(f"""You need to provide the [b]{type_name}[/b] instance to be used. 
You can do this by either: +- providing a path to a YAML configuration file with the [b]{yaml_path_argument_name}[/b] option +- providing a Python path to a function that creates the {type_name} with the [b]{factory_path_argument_name}[/b] option +- setting the default configuration or factory function in your project's [b]pyproject.toml[/b] file""") + raise typer.Exit(1) from e + except InvalidConfigError as e: + Console(stderr=True).print(e) + raise typer.Exit(1) from e diff --git a/packages/ragbits-core/src/ragbits/core/utils/config_handling.py b/packages/ragbits-core/src/ragbits/core/utils/config_handling.py index 928cc16bc..de55be64a 100644 --- a/packages/ragbits-core/src/ragbits/core/utils/config_handling.py +++ b/packages/ragbits-core/src/ragbits/core/utils/config_handling.py @@ -22,6 +22,12 @@ class InvalidConfigError(Exception): """ +class NoDefaultConfigError(InvalidConfigError): + """ + An exception to be raised when falling back to default configuration is not possible. + """ + + def import_by_path(path: str, default_module: ModuleType | None) -> Any: # noqa: ANN401 """ Retrieves and returns an object based on the string in the format of "module.submodule:object_name". 
@@ -156,7 +162,7 @@ def subclass_from_defaults( if default_config := defaults.default_instances_config.get(cls.configuration_key): return cls.subclass_from_config(ObjectContructionConfig.model_validate(default_config)) - raise InvalidConfigError(f"Could not find default factory or configuration for {cls.configuration_key}") + raise NoDefaultConfigError(f"Could not find default factory or configuration for {cls.configuration_key}") @classmethod def from_config(cls, config: dict) -> Self: diff --git a/packages/ragbits-core/src/ragbits/core/vector_stores/_cli.py b/packages/ragbits-core/src/ragbits/core/vector_stores/_cli.py index 6e49488de..7fd77b343 100644 --- a/packages/ragbits-core/src/ragbits/core/vector_stores/_cli.py +++ b/packages/ragbits-core/src/ragbits/core/vector_stores/_cli.py @@ -4,13 +4,11 @@ import typer from pydantic import BaseModel -from rich.console import Console from ragbits.cli import cli_state, print_output +from ragbits.cli._utils import get_instance_or_exit from ragbits.cli.state import OutputType -from ragbits.core.config import core_config from ragbits.core.embeddings.base import Embeddings -from ragbits.core.utils.config_handling import InvalidConfigError from ragbits.core.vector_stores.base import VectorStore, VectorStoreOptions vector_stores_app = typer.Typer(no_args_is_help=True) @@ -27,17 +25,13 @@ class CLIState: @vector_stores_app.callback() def common_args( factory_path: str | None = None, - yaml_path: str | None = None, + yaml_path: Path | None = None, ) -> None: - try: - state.vector_store = VectorStore.subclass_from_defaults( - core_config, - factory_path_override=factory_path, - yaml_path_override=Path.cwd() / yaml_path if yaml_path else None, - ) - except InvalidConfigError as e: - Console(stderr=True).print(e) - raise typer.Exit(1) from e + state.vector_store = get_instance_or_exit( + VectorStore, + factory_path=factory_path, + yaml_path=yaml_path, + ) @vector_stores_app.command(name="list") @@ -85,7 +79,7 @@ def query( k: int 
= 5, max_distance: float | None = None, embedder_factory_path: str | None = None, - embedder_yaml_path: str | None = None, + embedder_yaml_path: Path | None = None, ) -> None: """ Query the chosen vector store. @@ -95,16 +89,14 @@ async def run() -> None: if state.vector_store is None: raise ValueError("Vector store not initialized") - try: - embedder: Embeddings = Embeddings.subclass_from_defaults( - core_config, - factory_path_override=embedder_factory_path, - yaml_path_override=Path.cwd() / embedder_yaml_path if embedder_yaml_path else None, - ) - except InvalidConfigError as e: - Console(stderr=True).print(e) - raise typer.Exit(1) from e - + embedder = get_instance_or_exit( + Embeddings, + factory_path=embedder_factory_path, + yaml_path=embedder_yaml_path, + factory_path_argument_name="--embedder-factory-path", + yaml_path_argument_name="--embedder-yaml-path", + type_name="embedder", + ) search_vector = await embedder.embed_text([text]) options = VectorStoreOptions(k=k, max_distance=max_distance) diff --git a/packages/ragbits-core/tests/cli/test_vector_store.py b/packages/ragbits-core/tests/cli/test_vector_store.py index 2b8a9e570..09aee4532 100644 --- a/packages/ragbits-core/tests/cli/test_vector_store.py +++ b/packages/ragbits-core/tests/cli/test_vector_store.py @@ -87,7 +87,7 @@ def test_vector_store_cli_no_store(): """ runner = CliRunner(mix_stderr=False) result = runner.invoke(vector_stores_app, ["list"]) - assert "Could not find default factory or configuration" in result.stderr + assert "You need to provide the vector store instance to be used" in result.stderr def test_vector_store_list():