Skip to content

Commit

Permalink
feat: add cli event handler (#256)
Browse files Browse the repository at this point in the history
Co-authored-by: Mateusz Hordyński <[email protected]>
Co-authored-by: Michal Pstrag <[email protected]>
  • Loading branch information
3 people authored Jan 16, 2025
1 parent 6f3f08f commit 713ed0a
Show file tree
Hide file tree
Showing 18 changed files with 506 additions and 83 deletions.
3 changes: 3 additions & 0 deletions examples/document-search/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,14 @@

import asyncio

from ragbits.core import audit
from ragbits.core.embeddings.litellm import LiteLLMEmbeddings
from ragbits.core.vector_stores.in_memory import InMemoryVectorStore
from ragbits.document_search import DocumentSearch
from ragbits.document_search.documents.document import DocumentMeta

audit.set_trace_handlers("cli")

documents = [
DocumentMeta.create_text_document_from_literal(
"""
Expand Down
3 changes: 3 additions & 0 deletions examples/document-search/chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,15 @@

from chromadb import EphemeralClient

from ragbits.core import audit
from ragbits.core.embeddings.litellm import LiteLLMEmbeddings, LiteLLMEmbeddingsOptions
from ragbits.core.vector_stores import VectorStoreOptions
from ragbits.core.vector_stores.chroma import ChromaVectorStore
from ragbits.document_search import DocumentSearch, SearchConfig
from ragbits.document_search.documents.document import DocumentMeta

audit.set_trace_handlers("cli")

documents = [
DocumentMeta.create_text_document_from_literal(
"""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Ragbits Document Search Example: DocumentSearch from Config
Ragbits Document Search Example: Configurable DocumentSearch
This example demonstrates how to use the `DocumentSearch` class to search for documents with a more advanced setup.
We will use the `LiteLLMEmbeddings` class to embed the documents and the query, the `ChromaVectorStore` class to store
Expand Down Expand Up @@ -30,9 +30,12 @@ class to rephrase the query.

import asyncio

from ragbits.core import audit
from ragbits.document_search import DocumentSearch
from ragbits.document_search.documents.document import DocumentMeta

audit.set_trace_handlers("cli")

documents = [
DocumentMeta.create_text_document_from_literal(
"""
Expand Down
5 changes: 4 additions & 1 deletion examples/document-search/distributed.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Ragbits Document Search Example: Basic wtih distributed ingestion
Ragbits Document Search Example: Distributed Ingest
This example is based on the "Basic" example, but it demonstrates how to ingest documents in a distributed manner.
The distributed ingestion is provided by "DistributedProcessing" which uses Ray to parallelize the ingestion process.
Expand Down Expand Up @@ -31,12 +31,15 @@

import asyncio

from ragbits.core import audit
from ragbits.core.embeddings.litellm import LiteLLMEmbeddings
from ragbits.core.vector_stores.in_memory import InMemoryVectorStore
from ragbits.document_search import DocumentSearch
from ragbits.document_search.documents.document import DocumentMeta
from ragbits.document_search.ingestion.processor_strategies import DistributedProcessing

audit.set_trace_handlers("cli")

documents = [
DocumentMeta.create_text_document_from_literal(
"""
Expand Down
3 changes: 3 additions & 0 deletions examples/document-search/multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import asyncio
from pathlib import Path

from ragbits.core import audit
from ragbits.core.embeddings.vertex_multimodal import VertexAIMultimodelEmbeddings
from ragbits.core.vector_stores.in_memory import InMemoryVectorStore
from ragbits.document_search import DocumentSearch
Expand All @@ -41,6 +42,8 @@
from ragbits.document_search.ingestion.document_processor import DocumentProcessorRouter
from ragbits.document_search.ingestion.providers.dummy import DummyImageProvider

audit.set_trace_handlers("cli")

IMAGES_PATH = Path(__file__).parent / "images"


Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Ragbits Document Search Example: Chroma x OpenTelemetry
Ragbits Document Search Example: OpenTelemetry
This example demonstrates how to use the `DocumentSearch` class to search for documents with a more advanced setup.
We will use the `LiteLLMEmbeddings` class to embed the documents and the query, the `ChromaVectorStore` class to store
Expand Down
3 changes: 3 additions & 0 deletions examples/document-search/qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,14 @@

from qdrant_client import AsyncQdrantClient

from ragbits.core import audit
from ragbits.core.embeddings.litellm import LiteLLMEmbeddings
from ragbits.core.vector_stores.qdrant import QdrantVectorStore
from ragbits.document_search import DocumentSearch, SearchConfig
from ragbits.document_search.documents.document import DocumentMeta

audit.set_trace_handlers("cli")

documents = [
DocumentMeta.create_text_document_from_literal(
"""
Expand Down
7 changes: 7 additions & 0 deletions packages/ragbits-cli/src/ragbits/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typer.main import get_command

import ragbits
from ragbits.core import audit

from .state import OutputType, cli_state, print_output

Expand All @@ -28,9 +29,15 @@ def ragbits_cli(
output: Annotated[
OutputType, typer.Option("--output", "-o", help="Set the output type (text or json)")
] = OutputType.text.value, # type: ignore
verbose: bool = typer.Option(False, "--verbose", "-v", help="Print additional information"),
) -> None:
"""Common CLI arguments for all ragbits commands."""
cli_state.output_type = output
cli_state.verbose = verbose

if verbose == 1:
typer.echo("Verbose mode is enabled.")
audit.set_trace_handlers("cli")


def autoregister() -> None:
Expand Down
1 change: 1 addition & 0 deletions packages/ragbits-cli/src/ragbits/cli/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class OutputType(Enum):
class CliState:
"""A dataclass describing CLI state"""

verbose: bool = False
output_type: OutputType = OutputType.text


Expand Down
9 changes: 9 additions & 0 deletions packages/ragbits-core/src/ragbits/core/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import os

import typer

from ragbits.core import audit

if os.getenv("RAGBITS_VERBOSE", "0") == "1":
typer.echo('Verbose mode is enabled with environment variable "RAGBITS_VERBOSE".')
audit.set_trace_handlers("cli")
17 changes: 16 additions & 1 deletion packages/ragbits-core/src/ragbits/core/audit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,28 @@ def set_trace_handlers(handlers: Handler | list[Handler]) -> None:
if handler == "otel":
from ragbits.core.audit.otel import OtelTraceHandler

_trace_handlers.append(OtelTraceHandler())
if not any(isinstance(item, OtelTraceHandler) for item in _trace_handlers):
_trace_handlers.append(OtelTraceHandler())
elif handler == "cli":
from ragbits.core.audit.cli import CLITraceHandler

if not any(isinstance(item, CLITraceHandler) for item in _trace_handlers):
_trace_handlers.append(CLITraceHandler())
else:
raise ValueError(f"Handler {handler} not found.")
else:
raise TypeError(f"Invalid handler type: {type(handler)}")


def clear_event_handlers() -> None:
"""
Clear all trace handlers.
"""
global _trace_handlers # noqa: PLW0602

_trace_handlers.clear()


@contextmanager
def trace(name: str | None = None, **inputs: Any) -> Iterator[SimpleNamespace]: # noqa: ANN401
"""
Expand Down
33 changes: 33 additions & 0 deletions packages/ragbits-core/src/ragbits/core/audit/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,36 @@ def trace(self, name: str, **inputs: Any) -> Iterator[SimpleNamespace]: # noqa:

span = self._spans.get().pop()
self.stop(outputs=vars(outputs), current_span=span)


def format_attributes(data: dict, prefix: str | None = None) -> dict:
"""
Format attributes for CLI.
Args:
data: The data to format.
prefix: The prefix to use for the keys.
Returns:
The formatted attributes.
"""
flattened = {}

for key, value in data.items():
current_key = f"{prefix}.{key}" if prefix else key

if isinstance(value, dict):
flattened.update(format_attributes(value, current_key))
elif isinstance(value, list | tuple):
flattened[current_key] = repr(
[
item if isinstance(item, str | float | int | bool) else repr(item)
for item in value # type: ignore
]
)
elif isinstance(value, str | float | int | bool):
flattened[current_key] = value
else:
flattened[current_key] = repr(value)

return flattened
159 changes: 159 additions & 0 deletions packages/ragbits-core/src/ragbits/core/audit/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
import time
from enum import Enum

from rich.live import Live
from rich.tree import Tree

from ragbits.core.audit.base import TraceHandler, format_attributes


class SpanStatus(Enum):
"""
SpanStatus represents the status of the span.
"""

ERROR = "ERROR"
STARTED = "STARTED"
COMPLETED = "COMPLETED"


class PrintColor(str, Enum):
"""
SpanPrintColor represents the color of font for printing the span related information to the console.
"""

RUNNING_COLOR = "bold blue"
END_COLOR = "bold green"
ERROR_COLOR = "bold red"
TEXT_COLOR = "grey50"
KEY_COLOR = "plum4"


class CLISpan:
"""
CLI Span represents a single operation within a trace.
"""

def __init__(self, name: str, attributes: dict, parent: "CLISpan | None" = None) -> None:
"""
Constructs a new CLI Span.
Args:
name: The name of the span.
attributes: The attributes of the span.
parent: the parent of initiated span.
"""
self.name = name
self.parent = parent
self.attributes = attributes
self.start_time = time.perf_counter()
self.end_time: float | None = None
self.status = SpanStatus.STARTED
self.tree = Tree("")
if self.parent is not None:
self.parent.tree.add(self.tree)

def update(self) -> None:
"""
Updates tree label based on span state.
"""
elapsed = f": {(self.end_time - self.start_time):.3f}s" if self.end_time else " ..."
color = {
SpanStatus.ERROR: PrintColor.ERROR_COLOR,
SpanStatus.STARTED: PrintColor.RUNNING_COLOR,
SpanStatus.COMPLETED: PrintColor.END_COLOR,
}[self.status]
name = f"[{color}]{self.name}[/{color}][{PrintColor.TEXT_COLOR}]{elapsed}[/{PrintColor.TEXT_COLOR}]"

# TODO: Remove truncating after implementing better CLI formatting.
attrs = [
f"[{PrintColor.KEY_COLOR}]{k}:[/{PrintColor.KEY_COLOR}] "
f"[{PrintColor.TEXT_COLOR}]{str(v)[:120] + ' (...)' if len(str(v)) > 120 else v}[/{PrintColor.TEXT_COLOR}]" # noqa: PLR2004
for k, v in self.attributes.items()
]
self.tree.label = f"{name}\n{chr(10).join(attrs)}" if attrs else name

def end(self) -> None:
"""
Sets the current time as the span's end time.
The span's end time is the wall time at which the operation finished.
Only the first call to `end` should modify the span, further calls are ignored.
"""
if self.end_time is None:
self.end_time = time.perf_counter()


class CLITraceHandler(TraceHandler[CLISpan]):
"""
CLITraceHandler class for all trace handlers.
"""

def __init__(self) -> None:
super().__init__()
self.live = Live(auto_refresh=False)

def start(self, name: str, inputs: dict, current_span: CLISpan | None = None) -> CLISpan:
"""
Log input data at the beginning of the trace.
Args:
name: The name of the trace.
inputs: The input data.
current_span: The current trace span.
Returns:
The updated current trace span.
"""
attributes = format_attributes(inputs, prefix="inputs")
span = CLISpan(
name=name,
attributes=attributes,
parent=current_span,
)
if current_span is None:
self.live = Live(auto_refresh=False)
self.live.start()
self.tree = span.tree

span.update()
self.live.update(self.tree, refresh=True)

return span

def stop(self, outputs: dict, current_span: CLISpan) -> None:
"""
Log output data at the end of the trace.
Args:
outputs: The output data.
current_span: The current trace span.
"""
attributes = format_attributes(outputs, prefix="outputs")
current_span.attributes.update(attributes)
current_span.status = SpanStatus.COMPLETED
current_span.end()

current_span.update()
self.live.update(self.tree, refresh=True)

if current_span.parent is None:
self.live.stop()

def error(self, error: Exception, current_span: CLISpan) -> None:
"""
Log error during the trace.
Args:
error: The error that occurred.
current_span: The current trace span.
"""
attributes = format_attributes({"message": str(error), **vars(error)}, prefix="error")
current_span.attributes.update(attributes)
current_span.status = SpanStatus.ERROR
current_span.end()

current_span.update()
self.live.update(self.tree, refresh=True)

if current_span.parent is None:
self.live.stop()
Loading

0 comments on commit 713ed0a

Please sign in to comment.