Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(cli): add option to choose which columns to display #257

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 80 additions & 32 deletions packages/ragbits-cli/src/ragbits/cli/state.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import json
from collections.abc import Sequence
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from enum import Enum
from typing import TypeVar

import typer
from pydantic import BaseModel
from rich.console import Console
from rich.table import Table
from rich.table import Column, Table


class OutputType(Enum):
Expand All @@ -24,41 +26,87 @@ class CliState:

cli_state = CliState()

ModelT = TypeVar("ModelT", bound=BaseModel)

def print_output(data: Sequence[BaseModel] | BaseModel) -> None:

def print_output_table(
data: Sequence[ModelT], columns: Mapping[str, Column] | Sequence[str] | str | None = None
) -> None:
"""
Process and display output based on the current state's output type.
Display data from Pydantic models in a table format.

Args:
data: a list of pydantic models representing output of CLI function
columns: a list of columns to display in the output table: either as a list, string with comma separated names,
            or for greater control over how the data is displayed, a mapping of column names to Column objects.
If not provided, the columns will be inferred from the model schema.
"""
console = Console()
if isinstance(data, BaseModel):
data = [data]
if len(data) == 0:
_print_empty_list()

if not data:
console.print("No results")
return
first_el_instance = type(data[0])
if any(not isinstance(datapoint, first_el_instance) for datapoint in data):
raise ValueError("All the rows need to be of the same type")
data_dicts: list[dict] = [output.model_dump(mode="python") for output in data]
output_type = cli_state.output_type
if output_type == OutputType.json:
console.print(json.dumps(data_dicts, indent=4))
elif output_type == OutputType.text:
table = Table(show_header=True, header_style="bold magenta")
properties = data[0].model_json_schema()["properties"]
for key in properties:
table.add_column(properties[key]["title"])
for row in data_dicts:
table.add_row(*[str(value) for value in row.values()])
console.print(table)
else:
raise ValueError(f"Output type: {output_type} not supported")


def _print_empty_list() -> None:
    """Print a placeholder for an empty result set, honoring the active output type."""
    output_type = cli_state.output_type
    if output_type == OutputType.json:
        # Emit a valid (empty) JSON document so downstream parsers still succeed.
        print(json.dumps([]))
    elif output_type == OutputType.text:
        print("Empty data list")

fields = data[0].model_fields

# Human readable titles for columns
titles = {key: value.get("title", key) for key, value in data[0].model_json_schema()["properties"].items()}

# Normalize the list of columns
if columns is None:
columns = {key: Column() for key in fields}
elif isinstance(columns, str):
columns = {key: Column() for key in columns.split(",")}
elif isinstance(columns, Sequence):
columns = {key: Column() for key in columns}

# Add headers to columns if not provided
for key in columns:
if key not in fields:
Console(stderr=True).print(f"Unknown column: {key}")
raise typer.Exit(1)

column = columns[key]
if column.header == "":
column.header = titles.get(key, key)

# Create and print the table
table = Table(*columns.values(), show_header=True, header_style="bold magenta")
for row in data:
table.add_row(*[str(getattr(row, key)) for key in columns])
console.print(table)


def print_output_json(data: Sequence[ModelT]) -> None:
    """
    Render a sequence of Pydantic models as indented JSON on the console.

    Args:
        data: a list of pydantic models representing output of CLI function
    """
    # Serialize with mode="json" so non-primitive field types become JSON-safe values.
    serialized = [item.model_dump(mode="json") for item in data]
    Console().print(json.dumps(serialized, indent=4))


def print_output(
    data: Sequence[ModelT] | ModelT, columns: Mapping[str, Column] | Sequence[str] | str | None = None
) -> None:
    """
    Process and display output based on the current state's output type.

    Args:
        data: a list of pydantic models representing output of CLI function
        columns: a list of columns to display in the output table: either as a list, string with comma separated
            names, or for greater control over how the data is displayed, a mapping of column names to Column
            objects. If not provided, the columns will be inferred from the model schema.
    """
    # Accept a single model instance for convenience and normalize to a sequence.
    rows = data if isinstance(data, Sequence) else [data]

    output_type = cli_state.output_type
    if output_type == OutputType.text:
        print_output_table(rows, columns)
    elif output_type == OutputType.json:
        print_output_json(rows)
    else:
        raise ValueError(f"Unsupported output type: {output_type}")
53 changes: 42 additions & 11 deletions packages/ragbits-core/src/ragbits/core/vector_stores/_cli.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
from dataclasses import dataclass
from pathlib import Path
from typing import Annotated

import typer
from pydantic import BaseModel
Expand All @@ -21,11 +22,22 @@ class CLIState:

state: CLIState = CLIState()

# Default columns for commands that list entries
_default_columns = "id,key,metadata"


@vector_stores_app.callback()
def common_args(
factory_path: str | None = None,
yaml_path: Path | None = None,
factory_path: Annotated[
str | None,
typer.Option(
            help="Python path to a function that creates a vector store, in a format 'module.submodule:function'"
),
] = None,
yaml_path: Annotated[
Path | None,
typer.Option(help="Path to a YAML configuration file for the vector store", exists=True, resolve_path=True),
] = None,
) -> None:
state.vector_store = get_instance_or_exit(
VectorStore,
Expand All @@ -35,7 +47,13 @@ def common_args(


@vector_stores_app.command(name="list")
def list_entries(limit: int = 10, offset: int = 0) -> None:
def list_entries(
limit: Annotated[int, typer.Option(help="Maximum number of entries to list")] = 10,
offset: Annotated[int, typer.Option(help="How many entries to skip")] = 0,
columns: Annotated[
        str, typer.Option(help="Comma-separated list of columns to display, available: id, key, vector, metadata")
] = _default_columns,
) -> None:
"""
List all objects in the chosen vector store.
"""
Expand All @@ -45,7 +63,7 @@ async def run() -> None:
raise ValueError("Vector store not initialized")

entries = await state.vector_store.list(limit=limit, offset=offset)
print_output(entries)
print_output(entries, columns=columns)

asyncio.run(run())

Expand All @@ -55,7 +73,9 @@ class RemovedItem(BaseModel):


@vector_stores_app.command()
def remove(ids: list[str]) -> None:
def remove(
ids: Annotated[list[str], typer.Argument(help="IDs of the entries to remove from the vector store")],
) -> None:
"""
Remove objects from the chosen vector store.
"""
Expand All @@ -75,11 +95,22 @@ async def run() -> None:

@vector_stores_app.command()
def query(
text: str,
k: int = 5,
max_distance: float | None = None,
embedder_factory_path: str | None = None,
embedder_yaml_path: Path | None = None,
text: Annotated[str, typer.Argument(help="Text to query the vector store with")],
k: Annotated[int, typer.Option(help="Number of entries to retrieve")] = 5,
max_distance: Annotated[float | None, typer.Option(help="Maximum distance to the query vector")] = None,
embedder_factory_path: Annotated[
str | None,
typer.Option(
help="Python path to a function that creates an embedder, in a format 'module.submodule:function'"
),
] = None,
embedder_yaml_path: Annotated[
Path | None,
typer.Option(help="Path to a YAML configuration file for the embedder", exists=True, resolve_path=True),
] = None,
columns: Annotated[
        str, typer.Option(help="Comma-separated list of columns to display, available: id, key, vector, metadata")
] = _default_columns,
) -> None:
"""
Query the chosen vector store.
Expand All @@ -104,6 +135,6 @@ async def run() -> None:
vector=search_vector[0],
options=options,
)
print_output(entries)
print_output(entries, columns=columns)

asyncio.run(run())
41 changes: 41 additions & 0 deletions packages/ragbits-core/tests/cli/test_vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,47 @@ def test_vector_store_list_limit_offset():
assert "entry 3" not in result.stdout


def test_vector_store_list_columns():
    """Listing with --columns shows exactly the requested columns and hides the rest."""
    runner = CliRunner(mix_stderr=False)
    base_args = ["--factory-path", "cli.test_vector_store:vector_store_factory", "list", "--columns"]

    # All three columns requested: every entry and header (plus metadata content) is shown.
    result = runner.invoke(vector_stores_app, [*base_args, "id,key,metadata"])
    assert result.exit_code == 0
    for entry in ("entry 1", "entry 2", "entry 3"):
        assert entry in result.stdout
    assert "Vector" not in result.stdout
    for header in ("Id", "Key", "Metadata", "another_key"):
        assert header in result.stdout

    # Narrower selection: metadata column and its contents must disappear.
    result = runner.invoke(vector_stores_app, [*base_args, "id,key"])
    assert result.exit_code == 0
    for entry in ("entry 1", "entry 2", "entry 3"):
        assert entry in result.stdout
    assert "Vector" not in result.stdout
    assert "Id" in result.stdout
    assert "Key" in result.stdout
    for hidden in ("Metadata", "another_key"):
        assert hidden not in result.stdout


def test_vector_store_list_columns_non_existent():
    """Requesting an unknown column exits with code 1 and reports the column on stderr."""
    runner = CliRunner(mix_stderr=False)
    args = [
        "--factory-path",
        "cli.test_vector_store:vector_store_factory",
        "list",
        "--columns",
        "id,key,non_existent",
    ]
    result = runner.invoke(vector_stores_app, args)
    assert result.exit_code == 1
    assert "Unknown column: non_existent" in result.stderr


def test_vector_store_remove():
runner = CliRunner(mix_stderr=False)
result = runner.invoke(
Expand Down
Loading