Feat(vectore-store): CLI commands for managing vector stores
ludwiktrammer committed Dec 10, 2024
1 parent 439e445 commit b43a64a
Showing 28 changed files with 641 additions and 194 deletions.
23 changes: 18 additions & 5 deletions packages/ragbits-cli/src/ragbits/cli/__init__.py
@@ -7,9 +7,16 @@

import ragbits

from .app import CLI, OutputType
from .state import OutputType, cli_state, print_output

app = CLI(no_args_is_help=True)
__all__ = [
"OutputType",
"app",
"cli_state",
"print_output",
]

app = typer.Typer(no_args_is_help=True)


@app.callback()
@@ -23,12 +30,12 @@ def output_type(
Args:
output: type of output to be set
"""
app.set_output_type(output_type=output)
cli_state.output_type = output


def main() -> None:
def autodiscover() -> None:
"""
Main entry point for the CLI.
Autodiscover and register all the CLI modules in the ragbits packages.
This function registers all the CLI modules in the ragbits packages:
- iterates over every package in the ragbits.* namespace
@@ -46,4 +53,10 @@ def main() -> None:
register_func = importlib.import_module(f"ragbits.{module.name}.cli").register
register_func(app)


def main() -> None:
"""
Main entry point for the CLI. Registers all the CLI commands and runs the app.
"""
autodiscover()
app()
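
For context, a minimal sketch of the shape each ragbits.<package>.cli module is assumed to expose so that autodiscover() can pick it up. The register name and signature follow the diff above; the example sub-app and its single command are purely hypothetical.

import typer

example_app = typer.Typer(no_args_is_help=True)


@example_app.command()
def ping() -> None:
    """Hypothetical command, present only to illustrate the registration pattern."""
    print("pong")


def register(app: typer.Typer) -> None:
    """Called by ragbits.cli's autodiscover() to attach this package's commands."""
    app.add_typer(example_app, name="example", help="Hypothetical example commands")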
76 changes: 0 additions & 76 deletions packages/ragbits-cli/src/ragbits/cli/app.py

This file was deleted.

64 changes: 64 additions & 0 deletions packages/ragbits-cli/src/ragbits/cli/state.py
@@ -0,0 +1,64 @@
import json
from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum

from pydantic import BaseModel
from rich.console import Console
from rich.table import Table


class OutputType(Enum):
"""Indicates a type of CLI output formatting"""

text = "text"
json = "json"


@dataclass()
class CliState:
"""A dataclass describing CLI state"""

output_type: OutputType = OutputType.text


cli_state = CliState()


def print_output(data: Sequence[BaseModel] | BaseModel) -> None:
"""
Process and display output based on the current state's output type.
Args:
data: a list of pydantic models representing output of CLI function
"""
console = Console()
if isinstance(data, BaseModel):
data = [data]
if len(data) == 0:
_print_empty_list()
return
first_el_instance = type(data[0])
if any(not isinstance(datapoint, first_el_instance) for datapoint in data):
raise ValueError("All the rows need to be of the same type")
data_dicts: list[dict] = [output.model_dump(mode="python") for output in data]
output_type = cli_state.output_type
if output_type == OutputType.json:
console.print(json.dumps(data_dicts, indent=4))
elif output_type == OutputType.text:
table = Table(show_header=True, header_style="bold magenta")
properties = data[0].model_json_schema()["properties"]
for key in properties:
table.add_column(properties[key]["title"])
for row in data_dicts:
table.add_row(*[str(value) for value in row.values()])
console.print(table)
else:
raise ValueError(f"Output type: {output_type} not supported")


def _print_empty_list() -> None:
if cli_state.output_type == OutputType.text:
print("Empty data list")
elif cli_state.output_type == OutputType.json:
print(json.dumps([]))
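
A quick usage sketch for the new state module (the Item model below is made up for illustration): with the default OutputType.text the rows render as a rich table, and switching cli_state.output_type to OutputType.json prints the same rows as a JSON array.

from pydantic import BaseModel

from ragbits.cli.state import OutputType, cli_state, print_output


class Item(BaseModel):
    name: str
    score: float


print_output([Item(name="a", score=0.9), Item(name="b", score=0.7)])  # rich table

cli_state.output_type = OutputType.json
print_output(Item(name="a", score=0.9))  # single model is wrapped and printed as a JSON array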
97 changes: 4 additions & 93 deletions packages/ragbits-core/src/ragbits/core/cli.py
@@ -1,104 +1,15 @@
# pylint: disable=import-outside-toplevel
# pylint: disable=missing-param-doc
import asyncio
import json
from importlib import import_module
from pathlib import Path

import typer
from pydantic import BaseModel

from ragbits.cli.app import CLI
from ragbits.core.config import core_config
from ragbits.core.llms.base import LLM, LLMType
from ragbits.core.prompt.prompt import ChatFormat, Prompt


def _render(prompt_path: str, payload: str | None) -> Prompt:
module_stringified, object_stringified = prompt_path.split(":")
prompt_cls = getattr(import_module(module_stringified), object_stringified)

if payload is not None:
payload = json.loads(payload)
inputs = prompt_cls.input_type(**payload)
return prompt_cls(inputs)

return prompt_cls()


class LLMResponseCliOutput(BaseModel):
"""An output model for llm responses in CLI"""

question: ChatFormat
answer: str | BaseModel | None = None


prompts_app = typer.Typer(no_args_is_help=True)
from ragbits.core.prompt._cli import prompts_app
from ragbits.core.vector_stores._cli import vector_stores_app


def register(app: CLI) -> None:
def register(app: typer.Typer) -> None:
"""
Register the CLI commands for the package.
Args:
app: The Typer object to register the commands with.
"""

@prompts_app.command()
def lab(
file_pattern: str = core_config.prompt_path_pattern,
llm_factory: str = core_config.default_llm_factories[LLMType.TEXT],
) -> None:
"""
Launches the interactive application for listing, rendering, and testing prompts
defined within the current project.
"""
from ragbits.core.prompt.lab.app import lab_app

lab_app(file_pattern=file_pattern, llm_factory=llm_factory)

@prompts_app.command()
def generate_promptfoo_configs(
file_pattern: str = core_config.prompt_path_pattern,
root_path: Path = Path.cwd(), # noqa: B008
target_path: Path = Path("promptfooconfigs"),
) -> None:
"""
Generates the configuration files for the PromptFoo prompts.
"""
from ragbits.core.prompt.promptfoo import generate_configs

generate_configs(file_pattern=file_pattern, root_path=root_path, target_path=target_path)

@prompts_app.command()
def render(prompt_path: str, payload: str | None = None) -> None:
"""
Renders a prompt by loading a class from a module and initializing it with a given payload.
"""
prompt = _render(prompt_path=prompt_path, payload=payload)
response = LLMResponseCliOutput(question=prompt.chat)
app.print_output(response)

@prompts_app.command(name="exec")
def execute(
prompt_path: str,
payload: str | None = None,
llm_factory: str | None = core_config.default_llm_factories[LLMType.TEXT],
) -> None:
"""
Executes a prompt using the specified prompt class and LLM factory.
Raises:
ValueError: If `llm_factory` is not provided.
"""
prompt = _render(prompt_path=prompt_path, payload=payload)

if llm_factory is None:
raise ValueError("`llm_factory` must be provided")
llm: LLM = LLM.subclass_from_factory(llm_factory)

llm_output = asyncio.run(llm.generate(prompt))
response = LLMResponseCliOutput(question=prompt.chat, answer=llm_output)
app.print_output(response)

app.add_typer(prompts_app, name="prompts", help="Commands for managing prompts")
app.add_typer(vector_stores_app, name="vector-store", help="Commands for managing vector stores")
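
For orientation, the vector_stores_app imported above is assumed to be a plain typer.Typer sub-application defined in ragbits/core/vector_stores/_cli.py, which is not shown in this excerpt. A minimal sketch of that shape, with a hypothetical command name only:

import typer

vector_stores_app = typer.Typer(no_args_is_help=True)


@vector_stores_app.command(name="list")
def list_entries() -> None:
    """Hypothetical placeholder; the real commands live in vector_stores/_cli.py."""
    ...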
27 changes: 26 additions & 1 deletion packages/ragbits-core/src/ragbits/core/config.py
@@ -1,14 +1,20 @@
from functools import cached_property
from pathlib import Path

from pydantic import BaseModel

from ragbits.core.llms.base import LLMType
from ragbits.core.utils._pyproject import get_config_instance
from ragbits.core.utils._pyproject import get_config_from_yaml, get_config_instance


class CoreConfig(BaseModel):
"""
Configuration for the ragbits-core package, loaded from downstream projects' pyproject.toml files.
"""

# Path to the base directory of the project, defaults to the directory of the pyproject.toml file
project_base_path: Path | None = None

# Pattern used to search for prompt files
prompt_path_pattern: str = "**/prompt_*.py"

@@ -19,5 +25,24 @@ class CoreConfig(BaseModel):
LLMType.STRUCTURED_OUTPUT: "ragbits.core.llms.factory:simple_litellm_structured_output_factory",
}

# Path to functions that return instances of different types of Ragbits objects
default_factories: dict[str, str] = {}

# Path to a YAML file with default configuration of various Ragbits objects
default_instaces_config_path: Path | None = None

@cached_property
def default_instances_config(self) -> dict:
"""
Get the configuration from the file specified in default_instaces_config_path.
Returns:
dict: The configuration from the file.
"""
if self.default_instaces_config_path is None or not self.project_base_path:
return {}

return get_config_from_yaml(self.project_base_path / self.default_instaces_config_path)


core_config = get_config_instance(CoreConfig, subproject="core")
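
A hedged usage sketch for the new default_instances_config property: per the diff above it returns an empty dict unless both project_base_path and default_instaces_config_path are set, and otherwise parses the YAML file resolved against the project base path. The access pattern below is illustrative only.

from ragbits.core.config import core_config

# {} when default_instaces_config_path or project_base_path is unset;
# otherwise the parsed YAML contents, resolved against project_base_path.
instances_config = core_config.default_instances_config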
1 change: 1 addition & 0 deletions packages/ragbits-core/src/ragbits/core/embeddings/base.py
@@ -27,6 +27,7 @@ class Embeddings(WithConstructionConfig, ABC):
"""

default_module: ClassVar = embeddings
configuration_key: ClassVar = "embedder"

@abstractmethod
async def embed_text(self, data: list[str]) -> list[list[float]]:
1 change: 1 addition & 0 deletions packages/ragbits-core/src/ragbits/core/llms/base.py
@@ -31,6 +31,7 @@ class LLM(WithConstructionConfig, Generic[LLMClientOptions], ABC):

_options_cls: type[LLMClientOptions]
default_module: ClassVar = llms
configuration_key: ClassVar = "llm"

def __init__(self, model_name: str, default_options: LLMOptions | None = None) -> None:
"""
@@ -11,6 +11,7 @@ class MetadataStore(WithConstructionConfig, ABC):
"""

default_module: ClassVar = metadata_stores
configuration_key: ClassVar = "metadata_store"

@abstractmethod
async def store(self, ids: list[str], metadatas: list[dict]) -> None:
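
The three one-line additions above each pin a configuration_key class variable on a base class. A small illustration of reading those keys and using one to index into the default-instances config from config.py; the lookup itself is an assumption of this sketch, not an API shown in the diff.

from ragbits.core.config import core_config
from ragbits.core.embeddings.base import Embeddings
from ragbits.core.llms.base import LLM

print(LLM.configuration_key)         # "llm"
print(Embeddings.configuration_key)  # "embedder"

llm_defaults = core_config.default_instances_config.get(LLM.configuration_key, {})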