Skip to content

Commit

Permalink
feat(vector-store): CLI commands for managing vector stores (#244)
Browse files Browse the repository at this point in the history
  • Loading branch information
ludwiktrammer authored Dec 18, 2024
1 parent 1d0fdf8 commit 94b1e94
Show file tree
Hide file tree
Showing 33 changed files with 710 additions and 199 deletions.
11 changes: 11 additions & 0 deletions docs/cli/main.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Ragbits CLI

Ragbits comes with a command line interface (CLI) that provides a number of commands for working with the Ragbits platform. It can be accessed by running the `ragbits` command in your terminal.

::: mkdocs-click
:module: ragbits.cli
:command: _click_app
:prog_name: ragbits
:style: table
:list_subcommands: true
:depth: 1
7 changes: 7 additions & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ nav:
- how-to/evaluate/custom_evaluation_pipeline.md
- how-to/evaluate/custom_metric.md
- how-to/evaluate/custom_dataloader.md
- CLI:
- cli/main.md
- API Reference:
- Core:
- api_reference/core/prompt.md
Expand All @@ -41,6 +43,8 @@ nav:
- Ingestion:
- api_reference/document_search/processing.md
- api_reference/document_search/execution_strategies.md
hooks:
- mkdocs_hooks.py
theme:
name: material
icon:
Expand Down Expand Up @@ -69,6 +73,8 @@ theme:
- navigation.top
- content.code.annotate
- content.code.copy
- toc.integrate
- toc.follow
extra_css:
- stylesheets/extra.css
markdown_extensions:
Expand All @@ -94,6 +100,7 @@ markdown_extensions:
alternate_style: true
- toc:
permalink: "#"
- mkdocs-click
plugins:
- search
- autorefs
Expand Down
14 changes: 14 additions & 0 deletions mkdocs_hooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from typing import Literal

from ragbits import cli


def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool) -> None:
"""
Hook that runs during mkdocs startup.
Args:
command: The command that is being run.
dirty: whether --dirty flag was passed.
"""
cli._init_for_mkdocs()
46 changes: 36 additions & 10 deletions packages/ragbits-cli/src/ragbits/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,39 @@
from pathlib import Path
from typing import Annotated

import click
import typer
from typer.main import get_command

import ragbits

from .app import CLI, OutputType
from .state import OutputType, cli_state, print_output

app = CLI(no_args_is_help=True)
__all__ = [
"OutputType",
"app",
"cli_state",
"print_output",
]

app = typer.Typer(no_args_is_help=True)
_click_app: click.Command | None = None # initialized in the `init_for_mkdocs` function


@app.callback()
def output_type(
def ragbits_cli(
# `OutputType.text.value` used as a workaround for the issue with `typer.Option` not accepting Enum values
output: Annotated[
OutputType, typer.Option("--output", "-o", help="Set the output type (text or json)")
] = OutputType.text.value, # type: ignore
) -> None:
"""Sets an output type for the CLI
Args:
output: type of output to be set
"""
app.set_output_type(output_type=output)
"""Common CLI arguments for all ragbits commands."""
cli_state.output_type = output


def main() -> None:
def autoregister() -> None:
"""
Main entry point for the CLI.
Autodiscover and register all the CLI modules in the ragbits packages.
This function registers all the CLI modules in the ragbits packages:
- iterates over every package in the ragbits.* namespace
Expand All @@ -46,4 +53,23 @@ def main() -> None:
register_func = importlib.import_module(f"ragbits.{module.name}.cli").register
register_func(app)


def _init_for_mkdocs() -> None:
"""
Initializes the CLI app for the mkdocs environment.
This function registers all the CLI commands and sets the `_click_app` variable to a click
command object containing all the CLI commands. This way the `mkdocs-click` plugin can
create an automatic CLI documentation.
"""
global _click_app # noqa: PLW0603
autoregister()
_click_app = get_command(app)


def main() -> None:
"""
Main entry point for the CLI. Registers all the CLI commands and runs the app.
"""
autoregister()
app()
76 changes: 0 additions & 76 deletions packages/ragbits-cli/src/ragbits/cli/app.py

This file was deleted.

64 changes: 64 additions & 0 deletions packages/ragbits-cli/src/ragbits/cli/state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import json
from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum

from pydantic import BaseModel
from rich.console import Console
from rich.table import Table


class OutputType(Enum):
"""Indicates a type of CLI output formatting"""

text = "text"
json = "json"


@dataclass()
class CliState:
"""A dataclass describing CLI state"""

output_type: OutputType = OutputType.text


cli_state = CliState()


def print_output(data: Sequence[BaseModel] | BaseModel) -> None:
"""
Process and display output based on the current state's output type.
Args:
data: a list of pydantic models representing output of CLI function
"""
console = Console()
if isinstance(data, BaseModel):
data = [data]
if len(data) == 0:
_print_empty_list()
return
first_el_instance = type(data[0])
if any(not isinstance(datapoint, first_el_instance) for datapoint in data):
raise ValueError("All the rows need to be of the same type")
data_dicts: list[dict] = [output.model_dump(mode="python") for output in data]
output_type = cli_state.output_type
if output_type == OutputType.json:
console.print(json.dumps(data_dicts, indent=4))
elif output_type == OutputType.text:
table = Table(show_header=True, header_style="bold magenta")
properties = data[0].model_json_schema()["properties"]
for key in properties:
table.add_column(properties[key]["title"])
for row in data_dicts:
table.add_row(*[str(value) for value in row.values()])
console.print(table)
else:
raise ValueError(f"Output type: {output_type} not supported")


def _print_empty_list() -> None:
if cli_state.output_type == OutputType.text:
print("Empty data list")
elif cli_state.output_type == OutputType.json:
print(json.dumps([]))
97 changes: 4 additions & 93 deletions packages/ragbits-core/src/ragbits/core/cli.py
Original file line number Diff line number Diff line change
@@ -1,104 +1,15 @@
# pylint: disable=import-outside-toplevel
# pylint: disable=missing-param-doc
import asyncio
import json
from importlib import import_module
from pathlib import Path

import typer
from pydantic import BaseModel

from ragbits.cli.app import CLI
from ragbits.core.config import core_config
from ragbits.core.llms.base import LLM, LLMType
from ragbits.core.prompt.prompt import ChatFormat, Prompt


def _render(prompt_path: str, payload: str | None) -> Prompt:
module_stringified, object_stringified = prompt_path.split(":")
prompt_cls = getattr(import_module(module_stringified), object_stringified)

if payload is not None:
payload = json.loads(payload)
inputs = prompt_cls.input_type(**payload)
return prompt_cls(inputs)

return prompt_cls()


class LLMResponseCliOutput(BaseModel):
"""An output model for llm responses in CLI"""

question: ChatFormat
answer: str | BaseModel | None = None


prompts_app = typer.Typer(no_args_is_help=True)
from ragbits.core.prompt._cli import prompts_app
from ragbits.core.vector_stores._cli import vector_stores_app


def register(app: CLI) -> None:
def register(app: typer.Typer) -> None:
"""
Register the CLI commands for the package.
Args:
app: The Typer object to register the commands with.
"""

@prompts_app.command()
def lab(
file_pattern: str = core_config.prompt_path_pattern,
llm_factory: str = core_config.default_llm_factories[LLMType.TEXT],
) -> None:
"""
Launches the interactive application for listing, rendering, and testing prompts
defined within the current project.
"""
from ragbits.core.prompt.lab.app import lab_app

lab_app(file_pattern=file_pattern, llm_factory=llm_factory)

@prompts_app.command()
def generate_promptfoo_configs(
file_pattern: str = core_config.prompt_path_pattern,
root_path: Path = Path.cwd(), # noqa: B008
target_path: Path = Path("promptfooconfigs"),
) -> None:
"""
Generates the configuration files for the PromptFoo prompts.
"""
from ragbits.core.prompt.promptfoo import generate_configs

generate_configs(file_pattern=file_pattern, root_path=root_path, target_path=target_path)

@prompts_app.command()
def render(prompt_path: str, payload: str | None = None) -> None:
"""
Renders a prompt by loading a class from a module and initializing it with a given payload.
"""
prompt = _render(prompt_path=prompt_path, payload=payload)
response = LLMResponseCliOutput(question=prompt.chat)
app.print_output(response)

@prompts_app.command(name="exec")
def execute(
prompt_path: str,
payload: str | None = None,
llm_factory: str | None = core_config.default_llm_factories[LLMType.TEXT],
) -> None:
"""
Executes a prompt using the specified prompt class and LLM factory.
Raises:
ValueError: If `llm_factory` is not provided.
"""
prompt = _render(prompt_path=prompt_path, payload=payload)

if llm_factory is None:
raise ValueError("`llm_factory` must be provided")
llm: LLM = LLM.subclass_from_factory(llm_factory)

llm_output = asyncio.run(llm.generate(prompt))
response = LLMResponseCliOutput(question=prompt.chat, answer=llm_output)
app.print_output(response)

app.add_typer(prompts_app, name="prompts", help="Commands for managing prompts")
app.add_typer(vector_stores_app, name="vector-store", help="Commands for managing vector stores")
Loading

0 comments on commit 94b1e94

Please sign in to comment.