Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(vector-store): CLI commands for managing vector stores #244

Merged
merged 6 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions packages/ragbits-cli/src/ragbits/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,16 @@

import ragbits

from .app import CLI, OutputType
from .state import OutputType, cli_state, print_output

app = CLI(no_args_is_help=True)
__all__ = [
"OutputType",
"app",
"cli_state",
"print_output",
]

app = typer.Typer(no_args_is_help=True)


@app.callback()
Expand All @@ -23,12 +30,12 @@ def output_type(
Args:
output: type of output to be set
"""
app.set_output_type(output_type=output)
cli_state.output_type = output


def main() -> None:
def autodiscover() -> None:
kdziedzic68 marked this conversation as resolved.
Show resolved Hide resolved
"""
Main entry point for the CLI.
Autodiscover and register all the CLI modules in the ragbits packages.

This function registers all the CLI modules in the ragbits packages:
- iterates over every package in the ragbits.* namespace
Expand All @@ -46,4 +53,10 @@ def main() -> None:
register_func = importlib.import_module(f"ragbits.{module.name}.cli").register
register_func(app)


def main() -> None:
"""
Main entry point for the CLI. Registers all the CLI commands and runs the app.
"""
autodiscover()
app()
76 changes: 0 additions & 76 deletions packages/ragbits-cli/src/ragbits/cli/app.py

This file was deleted.

64 changes: 64 additions & 0 deletions packages/ragbits-cli/src/ragbits/cli/state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import json
ludwiktrammer marked this conversation as resolved.
Show resolved Hide resolved
from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum

from pydantic import BaseModel
from rich.console import Console
from rich.table import Table


class OutputType(Enum):
"""Indicates a type of CLI output formatting"""

text = "text"
json = "json"


@dataclass()
class CliState:
"""A dataclass describing CLI state"""

output_type: OutputType = OutputType.text


cli_state = CliState()


def print_output(data: Sequence[BaseModel] | BaseModel) -> None:
kdziedzic68 marked this conversation as resolved.
Show resolved Hide resolved
"""
Process and display output based on the current state's output type.

Args:
data: a list of pydantic models representing output of CLI function
"""
console = Console()
if isinstance(data, BaseModel):
data = [data]
if len(data) == 0:
_print_empty_list()
return
first_el_instance = type(data[0])
if any(not isinstance(datapoint, first_el_instance) for datapoint in data):
raise ValueError("All the rows need to be of the same type")
data_dicts: list[dict] = [output.model_dump(mode="python") for output in data]
output_type = cli_state.output_type
if output_type == OutputType.json:
console.print(json.dumps(data_dicts, indent=4))
elif output_type == OutputType.text:
table = Table(show_header=True, header_style="bold magenta")
properties = data[0].model_json_schema()["properties"]
for key in properties:
table.add_column(properties[key]["title"])
for row in data_dicts:
table.add_row(*[str(value) for value in row.values()])
console.print(table)
else:
raise ValueError(f"Output type: {output_type} not supported")


def _print_empty_list() -> None:
if cli_state.output_type == OutputType.text:
print("Empty data list")
elif cli_state.output_type == OutputType.json:
print(json.dumps([]))
97 changes: 4 additions & 93 deletions packages/ragbits-core/src/ragbits/core/cli.py
Original file line number Diff line number Diff line change
@@ -1,104 +1,15 @@
# pylint: disable=import-outside-toplevel
# pylint: disable=missing-param-doc
import asyncio
import json
from importlib import import_module
from pathlib import Path

import typer
from pydantic import BaseModel

from ragbits.cli.app import CLI
from ragbits.core.config import core_config
from ragbits.core.llms.base import LLM, LLMType
from ragbits.core.prompt.prompt import ChatFormat, Prompt


def _render(prompt_path: str, payload: str | None) -> Prompt:
module_stringified, object_stringified = prompt_path.split(":")
prompt_cls = getattr(import_module(module_stringified), object_stringified)

if payload is not None:
payload = json.loads(payload)
inputs = prompt_cls.input_type(**payload)
return prompt_cls(inputs)

return prompt_cls()


class LLMResponseCliOutput(BaseModel):
"""An output model for llm responses in CLI"""

question: ChatFormat
answer: str | BaseModel | None = None


prompts_app = typer.Typer(no_args_is_help=True)
from ragbits.core.prompt._cli import prompts_app
from ragbits.core.vector_stores._cli import vector_stores_app


def register(app: CLI) -> None:
def register(app: typer.Typer) -> None:
"""
Register the CLI commands for the package.

Args:
app: The Typer object to register the commands with.
"""

@prompts_app.command()
def lab(
file_pattern: str = core_config.prompt_path_pattern,
llm_factory: str = core_config.default_llm_factories[LLMType.TEXT],
) -> None:
"""
Launches the interactive application for listing, rendering, and testing prompts
defined within the current project.
"""
from ragbits.core.prompt.lab.app import lab_app

lab_app(file_pattern=file_pattern, llm_factory=llm_factory)

@prompts_app.command()
def generate_promptfoo_configs(
file_pattern: str = core_config.prompt_path_pattern,
root_path: Path = Path.cwd(), # noqa: B008
target_path: Path = Path("promptfooconfigs"),
) -> None:
"""
Generates the configuration files for the PromptFoo prompts.
"""
from ragbits.core.prompt.promptfoo import generate_configs

generate_configs(file_pattern=file_pattern, root_path=root_path, target_path=target_path)

@prompts_app.command()
def render(prompt_path: str, payload: str | None = None) -> None:
"""
Renders a prompt by loading a class from a module and initializing it with a given payload.
"""
prompt = _render(prompt_path=prompt_path, payload=payload)
response = LLMResponseCliOutput(question=prompt.chat)
app.print_output(response)

@prompts_app.command(name="exec")
def execute(
prompt_path: str,
payload: str | None = None,
llm_factory: str | None = core_config.default_llm_factories[LLMType.TEXT],
) -> None:
"""
Executes a prompt using the specified prompt class and LLM factory.

Raises:
ValueError: If `llm_factory` is not provided.
"""
prompt = _render(prompt_path=prompt_path, payload=payload)

if llm_factory is None:
raise ValueError("`llm_factory` must be provided")
llm: LLM = LLM.subclass_from_factory(llm_factory)

llm_output = asyncio.run(llm.generate(prompt))
response = LLMResponseCliOutput(question=prompt.chat, answer=llm_output)
app.print_output(response)

app.add_typer(prompts_app, name="prompts", help="Commands for managing prompts")
app.add_typer(vector_stores_app, name="vector-store", help="Commands for managing vector stores")
27 changes: 26 additions & 1 deletion packages/ragbits-core/src/ragbits/core/config.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
from functools import cached_property
from pathlib import Path

from pydantic import BaseModel

from ragbits.core.llms.base import LLMType
from ragbits.core.utils._pyproject import get_config_instance
from ragbits.core.utils._pyproject import get_config_from_yaml, get_config_instance


class CoreConfig(BaseModel):
"""
Configuration for the ragbits-core package, loaded from downstream projects' pyproject.toml files.
"""

# Path to the base directory of the project, defaults to the directory of the pyproject.toml file
project_base_path: Path | None = None

# Pattern used to search for prompt files
prompt_path_pattern: str = "**/prompt_*.py"

Expand All @@ -19,5 +25,24 @@ class CoreConfig(BaseModel):
LLMType.STRUCTURED_OUTPUT: "ragbits.core.llms.factory:simple_litellm_structured_output_factory",
}

# Path to functions that returns instances of diffrent types of Ragbits objects
ludwiktrammer marked this conversation as resolved.
Show resolved Hide resolved
default_factories: dict[str, str] = {}

# Path to a YAML file with default configuration of varius Ragbits objects
default_instaces_config_path: Path | None = None

@cached_property
def default_instances_config(self) -> dict:
"""
Get the configuration from the file specified in default_instaces_config_path.

Returns:
dict: The configuration from the file.
"""
if self.default_instaces_config_path is None or not self.project_base_path:
return {}

return get_config_from_yaml(self.project_base_path / self.default_instaces_config_path)


core_config = get_config_instance(CoreConfig, subproject="core")
1 change: 1 addition & 0 deletions packages/ragbits-core/src/ragbits/core/embeddings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class Embeddings(WithConstructionConfig, ABC):
"""

default_module: ClassVar = embeddings
configuration_key: ClassVar = "embedder"

@abstractmethod
async def embed_text(self, data: list[str]) -> list[list[float]]:
Expand Down
1 change: 1 addition & 0 deletions packages/ragbits-core/src/ragbits/core/llms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class LLM(WithConstructionConfig, Generic[LLMClientOptions], ABC):

_options_cls: type[LLMClientOptions]
default_module: ClassVar = llms
configuration_key: ClassVar = "llm"

def __init__(self, model_name: str, default_options: LLMOptions | None = None) -> None:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class MetadataStore(WithConstructionConfig, ABC):
"""

default_module: ClassVar = metadata_stores
configuration_key: ClassVar = "metadata_store"

@abstractmethod
async def store(self, ids: list[str], metadatas: list[dict]) -> None:
Expand Down
Loading
Loading