feat(vector-store): CLI commands for managing vector stores (#244)
Commit 94b1e94 (1 parent: 1d0fdf8)
Showing 33 changed files with 710 additions and 199 deletions.
@@ -0,0 +1,11 @@
# Ragbits CLI

Ragbits comes with a command-line interface (CLI) that provides a number of commands for working with the Ragbits platform. It can be accessed by running the `ragbits` command in your terminal.

::: mkdocs-click
    :module: ragbits.cli
    :command: _click_app
    :prog_name: ragbits
    :style: table
    :list_subcommands: true
    :depth: 1
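The `:command: _click_app` option points mkdocs-click at a click command object. Since the CLI below is built with Typer, one plausible way to expose such an object is via Typer's click bridge; this is a sketch under that assumption, not a confirmed excerpt of `ragbits.cli`:

```python
# Sketch: exposing a Typer app as the click command that mkdocs-click renders.
# Assumption: ragbits.cli builds its CLI with Typer; the real module may differ.
import typer

app = typer.Typer(no_args_is_help=True)

# Typer is built on click, so the underlying click.Command can be extracted
# and referenced from the docs page via `:command: _click_app`.
_click_app = typer.main.get_command(app)
```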
@@ -0,0 +1,14 @@
from typing import Literal

from ragbits import cli


def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool) -> None:
    """
    Hook that runs during mkdocs startup.

    Args:
        command: The command that is being run.
        dirty: Whether the --dirty flag was passed.
    """
    cli._init_for_mkdocs()
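For context, MkDocs picks up event hooks like `on_startup` from files listed under the `hooks:` key of `mkdocs.yml` and calls them by name. At startup it effectively makes a call like the one below (a hypothetical sketch; `docs.hooks` is an assumed location for the file above, not confirmed by this diff):

```python
# Hypothetical sketch of what MkDocs does with the hook at startup.
from docs import hooks  # assumed module path, not confirmed by this diff

# `command` is the mkdocs subcommand being run; `dirty` mirrors the --dirty flag.
hooks.on_startup(command="serve", dirty=False)
```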
This file was deleted.
@@ -0,0 +1,64 @@
import json
from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum

from pydantic import BaseModel
from rich.console import Console
from rich.table import Table


class OutputType(Enum):
    """Indicates a type of CLI output formatting"""

    text = "text"
    json = "json"


@dataclass()
class CliState:
    """A dataclass describing CLI state"""

    output_type: OutputType = OutputType.text


cli_state = CliState()


def print_output(data: Sequence[BaseModel] | BaseModel) -> None:
    """
    Process and display output based on the current state's output type.

    Args:
        data: A list of pydantic models representing the output of a CLI function.
    """
    console = Console()
    if isinstance(data, BaseModel):
        data = [data]
    if len(data) == 0:
        _print_empty_list()
        return
    first_el_instance = type(data[0])
    if any(not isinstance(datapoint, first_el_instance) for datapoint in data):
        raise ValueError("All the rows need to be of the same type")
    data_dicts: list[dict] = [output.model_dump(mode="python") for output in data]
    output_type = cli_state.output_type
    if output_type == OutputType.json:
        console.print(json.dumps(data_dicts, indent=4))
    elif output_type == OutputType.text:
        table = Table(show_header=True, header_style="bold magenta")
        properties = data[0].model_json_schema()["properties"]
        for key in properties:
            table.add_column(properties[key]["title"])
        for row in data_dicts:
            table.add_row(*[str(value) for value in row.values()])
        console.print(table)
    else:
        raise ValueError(f"Output type: {output_type} not supported")


def _print_empty_list() -> None:
    if cli_state.output_type == OutputType.text:
        print("Empty data list")
    elif cli_state.output_type == OutputType.json:
        print(json.dumps([]))
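A minimal usage sketch for the helpers above (the `StoreRow` model is hypothetical, for illustration only; within the module itself `print_output`, `cli_state`, and `OutputType` are already in scope):

```python
# Hypothetical usage of print_output with the module-level cli_state.
from pydantic import BaseModel


class StoreRow(BaseModel):  # illustrative model, not part of this commit
    name: str
    entries: int


rows = [StoreRow(name="chroma", entries=42), StoreRow(name="qdrant", entries=7)]

# Default state: renders a rich table with columns taken from the model schema.
print_output(rows)

# Flipping the shared state switches the same call to JSON output.
cli_state.output_type = OutputType.json
print_output(rows)
```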
@@ -1,104 +1,15 @@
-# pylint: disable=import-outside-toplevel
-# pylint: disable=missing-param-doc
-import asyncio
-import json
-from importlib import import_module
-from pathlib import Path
-
 import typer
-from pydantic import BaseModel
 
-from ragbits.cli.app import CLI
-from ragbits.core.config import core_config
-from ragbits.core.llms.base import LLM, LLMType
-from ragbits.core.prompt.prompt import ChatFormat, Prompt
-
-
-def _render(prompt_path: str, payload: str | None) -> Prompt:
-    module_stringified, object_stringified = prompt_path.split(":")
-    prompt_cls = getattr(import_module(module_stringified), object_stringified)
-
-    if payload is not None:
-        payload = json.loads(payload)
-        inputs = prompt_cls.input_type(**payload)
-        return prompt_cls(inputs)
-
-    return prompt_cls()
-
-
-class LLMResponseCliOutput(BaseModel):
-    """An output model for llm responses in CLI"""
-
-    question: ChatFormat
-    answer: str | BaseModel | None = None
-
-
-prompts_app = typer.Typer(no_args_is_help=True)
+from ragbits.core.prompt._cli import prompts_app
+from ragbits.core.vector_stores._cli import vector_stores_app
 
 
-def register(app: CLI) -> None:
+def register(app: typer.Typer) -> None:
     """
     Register the CLI commands for the package.
 
     Args:
         app: The Typer object to register the commands with.
     """
-
-    @prompts_app.command()
-    def lab(
-        file_pattern: str = core_config.prompt_path_pattern,
-        llm_factory: str = core_config.default_llm_factories[LLMType.TEXT],
-    ) -> None:
-        """
-        Launches the interactive application for listing, rendering, and testing prompts
-        defined within the current project.
-        """
-        from ragbits.core.prompt.lab.app import lab_app
-
-        lab_app(file_pattern=file_pattern, llm_factory=llm_factory)
-
-    @prompts_app.command()
-    def generate_promptfoo_configs(
-        file_pattern: str = core_config.prompt_path_pattern,
-        root_path: Path = Path.cwd(),  # noqa: B008
-        target_path: Path = Path("promptfooconfigs"),
-    ) -> None:
-        """
-        Generates the configuration files for the PromptFoo prompts.
-        """
-        from ragbits.core.prompt.promptfoo import generate_configs
-
-        generate_configs(file_pattern=file_pattern, root_path=root_path, target_path=target_path)
-
-    @prompts_app.command()
-    def render(prompt_path: str, payload: str | None = None) -> None:
-        """
-        Renders a prompt by loading a class from a module and initializing it with a given payload.
-        """
-        prompt = _render(prompt_path=prompt_path, payload=payload)
-        response = LLMResponseCliOutput(question=prompt.chat)
-        app.print_output(response)
-
-    @prompts_app.command(name="exec")
-    def execute(
-        prompt_path: str,
-        payload: str | None = None,
-        llm_factory: str | None = core_config.default_llm_factories[LLMType.TEXT],
-    ) -> None:
-        """
-        Executes a prompt using the specified prompt class and LLM factory.
-
-        Raises:
-            ValueError: If `llm_factory` is not provided.
-        """
-        prompt = _render(prompt_path=prompt_path, payload=payload)
-
-        if llm_factory is None:
-            raise ValueError("`llm_factory` must be provided")
-        llm: LLM = LLM.subclass_from_factory(llm_factory)
-
-        llm_output = asyncio.run(llm.generate(prompt))
-        response = LLMResponseCliOutput(question=prompt.chat, answer=llm_output)
-        app.print_output(response)
-
     app.add_typer(prompts_app, name="prompts", help="Commands for managing prompts")
+    app.add_typer(vector_stores_app, name="vector-store", help="Commands for managing vector stores")
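The effect of the slimmed-down `register` is easiest to see from the caller's side. A minimal sketch, assuming the root `ragbits` app wires packages up like this (the entry-point module is not part of this diff, and the import path is inferred from context):

```python
# Hypothetical sketch of a root CLI consuming register().
import typer

from ragbits.core._cli import register  # assumed import path

root_app = typer.Typer(no_args_is_help=True)
register(root_app)

if __name__ == "__main__":
    root_app()  # exposes the `prompts` and `vector-store` command groups
```

With this wiring, `ragbits vector-store --help` would list the commands contributed by `vector_stores_app`, since `add_typer` mounts it as a sub-command group under the `vector-store` name.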