Skip to content

Commit

Permalink
checks
Browse files Browse the repository at this point in the history
  • Loading branch information
kdziedzic68 committed Dec 4, 2024
1 parent 636bb9e commit 6e03d49
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 25 deletions.
31 changes: 11 additions & 20 deletions packages/ragbits-cli/src/ragbits/cli/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,10 @@ class CliState:
output_type: str = "text"


CliOutputFormat = list[dict[str, Any]] | list[BaseModel]


class CLI(typer.Typer):
"""A CLI class with output formatting"""

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any): # noqa: ANN401
    """
    Initialize the CLI application with default output state.

    Args:
        *args: positional arguments forwarded unchanged to typer.Typer.
        **kwargs: keyword arguments forwarded unchanged to typer.Typer.
    """
    super().__init__(*args, **kwargs)
    self.state: CliState = CliState()  # tracks the currently selected output type ("text" by default)
    self.console: Console = Console()  # rich console used to render table output
Expand All @@ -38,30 +35,24 @@ def set_output_type(self, output_type: str) -> None:
raise ValueError("Output type must be either 'text' or 'json'")
self.state.output_type = output_type

def print_output(self, data: list[BaseModel]) -> None:
    """
    Process and display output based on the current state's output type.

    Args:
        data: list of pydantic models representing output of CLI function

    Raises:
        ValueError: If `data` is empty, if the rows are not all of the same
            type, or if the data is not JSON serializable in "json" mode.
    """
    # Guard before data[0] below: fail with a clear message instead of IndexError.
    if not data:
        raise ValueError("No data to display")
    first_el_type = type(data[0])
    if any(not isinstance(datapoint, first_el_type) for datapoint in data):
        raise ValueError("All the rows need to be of the same type")
    # Normalize the models to plain dicts once; both output modes consume dicts.
    data_dicts: list[dict] = [output.model_dump(mode="python") for output in data]
    output_type = self.state.output_type
    if output_type == "json":
        try:
            print(json.dumps(data_dicts, indent=4))
        except TypeError as err:
            # model_dump(mode="python") may yield values json cannot encode
            # (e.g. datetimes); surface that as a clear ValueError.
            raise ValueError("Output data is not JSON serializable") from err
    elif output_type == "text":
        table = Table(show_header=True, header_style="bold magenta")
        # Column headers come from the first row's keys; all rows share a type,
        # so they share the same keys.
        for key in data_dicts[0]:
            table.add_column(key.title())
        for row in data_dicts:
            table.add_row(*[str(value) for value in row.values()])
        self.console.print(table)
    # NOTE(review): unknown output types fall through silently here;
    # set_output_type validates values upstream, so this branch is unreachable in practice.
8 changes: 3 additions & 5 deletions packages/ragbits-core/src/ragbits/core/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@
import json
from importlib import import_module
from pathlib import Path
from typing import cast

import typer
from pydantic import BaseModel

from ragbits.cli.app import CLI, CliOutputFormat
from ragbits.cli.app import CLI
from ragbits.core.config import core_config
from ragbits.core.llms.base import LLMType
from ragbits.core.prompt.prompt import ChatFormat, Prompt
Expand Down Expand Up @@ -78,7 +77,7 @@ def render(prompt_path: str, payload: str | None = None) -> None:
"""
prompt = _render(prompt_path=prompt_path, payload=payload)
response = CliOutput(question=prompt.chat)
app.print_output(cast(CliOutputFormat, [response]))
app.print_output([response])

@prompts_app.command(name="exec")
def execute(
Expand All @@ -92,7 +91,6 @@ def execute(
Raises:
ValueError: If `llm_factory` is not provided.
"""
print(app.state)
from ragbits.core.llms.factory import get_llm_from_factory

prompt = _render(prompt_path=prompt_path, payload=payload)
Expand All @@ -103,6 +101,6 @@ def execute(

llm_output = asyncio.run(llm.generate(prompt))
response = CliOutput(question=prompt.chat, answer=llm_output)
app.print_output(cast(CliOutputFormat, [response]))
app.print_output([response])

app.add_typer(prompts_app, name="prompts", help="Commands for managing prompts")

0 comments on commit 6e03d49

Please sign in to comment.