Skip to content

Commit

Permalink
add audit core
Browse files Browse the repository at this point in the history
  • Loading branch information
micpst committed Oct 31, 2024
1 parent b245164 commit 37f9977
Show file tree
Hide file tree
Showing 8 changed files with 234 additions and 32 deletions.
46 changes: 34 additions & 12 deletions examples/document-search/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,44 @@
# requires-python = ">=3.10"
# dependencies = [
# "ragbits-document-search",
# "ragbits-core[litellm]",
# "ragbits-core[litellm,otel]",
# ]
# ///
import asyncio

from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.resources import SERVICE_NAME, Resource
from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from ragbits.core import audit
from ragbits.core.embeddings.litellm import LiteLLMEmbeddings
from ragbits.core.vector_stores.in_memory import InMemoryVectorStore
from ragbits.document_search import DocumentSearch
from ragbits.document_search.documents.document import DocumentMeta

provider = TracerProvider(resource=Resource({SERVICE_NAME: "ragbits"}))
provider.add_span_processor(BatchSpanProcessor(ConsoleSpanExporter()))
# provider.add_span_processor(BatchSpanProcessor(OTLPSpanExporter("http://localhost:4317", insecure=True)))
trace.set_tracer_provider(provider)

audit.set_trace_handlers("otel")
# audit.set_trace_handlers(["langsmith", "otel"])
documents = [
DocumentMeta.create_text_document_from_literal("RIP boiled water. You will be mist."),
DocumentMeta.create_text_document_from_literal(
"Why doesn't James Bond fart in bed? Because it would blow his cover."
"""
RIP boiled water. You will be mist.
"""
),
DocumentMeta.create_text_document_from_literal(
"""
Why doesn't James Bond fart in bed? Because it would blow his cover.
"""
),
DocumentMeta.create_text_document_from_literal(
"Why programmers don't like to swim? Because they're scared of the floating points."
"""
Why programmers don't like to swim? Because they're scared of the floating points.
"""
),
]

Expand All @@ -30,16 +51,17 @@ async def main() -> None:
embedder = LiteLLMEmbeddings(
model="text-embedding-3-small",
)
vector_store = InMemoryVectorStore()
document_search = DocumentSearch(
embedder=embedder,
vector_store=vector_store,
)
results = await embedder.embed_text(["I'm boiling my water and I need a", "joke"])
# vector_store = InMemoryVectorStore()
# document_search = DocumentSearch(
# embedder=embedder,
# vector_store=vector_store,
# )

await document_search.ingest(documents)
# await document_search.ingest(documents)

results = await document_search.search("I'm boiling my water and I need a joke")
print(results)
# results = await document_search.search("I'm boiling my water and I need a joke")
# print(results)


if __name__ == "__main__":
Expand Down
6 changes: 6 additions & 0 deletions packages/ragbits-core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@ lab = [
promptfoo = [
"PyYAML~=6.0.2",
]
langsmith = [
"langsmith~=0.1.137",
]
otel = [
"opentelemetry-api~=1.27.0",
]

[tool.uv]
dev-dependencies = [
Expand Down
90 changes: 90 additions & 0 deletions packages/ragbits-core/src/ragbits/core/audit/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import asyncio
from collections.abc import AsyncIterator, Callable
from contextlib import asynccontextmanager
from functools import wraps
from types import SimpleNamespace
from typing import Any, ParamSpec, TypeVar
from ragbits.core.audit.base import TraceHandler

_trace_handlers: list[TraceHandler] = []

Handler = str | TraceHandler

P = ParamSpec("P")
R = TypeVar("R")


def traceable(func: Callable[P, R]) -> Callable[P, R]:
@wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
print(*args, **kwargs)
return func(*args, **kwargs)

@wraps(func)
async def wrapper_async(*args: P.args, **kwargs: P.kwargs) -> R:
print("asuync", *args, **kwargs)
return await func(*args, **kwargs) # type: ignore

if asyncio.iscoroutinefunction(func):
return wrapper_async # type: ignore
return wrapper


@asynccontextmanager
async def trace(**inputs: Any) -> AsyncIterator[SimpleNamespace]:
"""
Context manager for processing an event.
Args:
event: The event to be processed.
Yields:
The event being processed.
"""
for handler in _trace_handlers:
await handler.on_start(**inputs)

try:
yield (outputs := SimpleNamespace())
except Exception as exc:
for handler in _trace_handlers:
await handler.on_error(exc)
raise exc

for handler in _trace_handlers:
await handler.on_end(**vars(outputs))


def set_trace_handlers(handlers: Handler | list[Handler]) -> None:
"""
Setup event handlers.
Args:
handlers: List of event handlers to be used.
Raises:
ValueError: If handler is not found.
TypeError: If handler type is invalid.
"""
global _trace_handlers

if isinstance(handlers, str):
handlers = [handlers]

for handler in handlers: # type: ignore
if isinstance(handler, TraceHandler):
_trace_handlers.append(handler)
elif isinstance(handler, str):
if handler == "otel":
from ragbits.core.audit.otel import OtelTraceHandler
_trace_handlers.append(OtelTraceHandler())
if handler == "langsmith":
from ragbits.core.audit.langsmith import LangSmithTraceHandler
_trace_handlers.append(LangSmithTraceHandler())
else:
raise ValueError(f"Handler {handler} not found.")
else:
raise TypeError(f"Invalid handler type: {type(handler)}")


__all__ = ["TraceHandler", "traceable", "trace", "set_trace_handlers"]
35 changes: 35 additions & 0 deletions packages/ragbits-core/src/ragbits/core/audit/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from abc import ABC, abstractmethod
from typing import Any


class TraceHandler(ABC):
"""
Base class for all trace handlers.
"""

@abstractmethod
async def on_start(self, **inputs: Any) -> None:
"""
Log input data at the start of the event.
Args:
inputs: The input data.
"""

@abstractmethod
async def on_end(self, **outputs: Any) -> None:
"""
Log output data at the end of the event.
Args:
outputs: The output data.
"""

@abstractmethod
async def on_error(self, error: Exception) -> None:
"""
Log error during the event.
Args:
error: The error that occurred.
"""
29 changes: 29 additions & 0 deletions packages/ragbits-core/src/ragbits/core/audit/otel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from typing import Any

from opentelemetry import trace
from opentelemetry.trace import Span, StatusCode
from ragbits.core.audit.base import TraceHandler


class OtelTraceHandler(TraceHandler):
"""
OpenTelemetry trace handler.
"""

def __init__(self) -> None:
self._tracer = trace.get_tracer("ragbits.events")

async def on_start(self, **inputs: Any) -> None:
with self._tracer.start_as_current_span("request", end_on_exit=False) as span:
for key, value in inputs.items():
span.set_attribute(key, value)

async def on_end(self, **outputs: Any) -> None:
span = trace.get_current_span()
for key, value in outputs.items():
span.set_attribute(key, value)
span.set_status(StatusCode.OK)
span.end()

async def on_error(self, error: Exception) -> None:
print("on_error", error)
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,12 @@ class EmbeddingResponseError(EmbeddingError):

def __init__(self, message: str = "Data returned by API invalid for expected schema.") -> None:
super().__init__(message)


class EmbeddingEmptyResponseError(EmbeddingError):
"""
Raised when an API response has an empty response.
"""

def __init__(self, message: str = "Empty response returned by API.") -> None:
super().__init__(message)
45 changes: 29 additions & 16 deletions packages/ragbits-core/src/ragbits/core/embeddings/litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
except ImportError:
HAS_LITELLM = False

from ragbits.core.audit import trace
from ragbits.core.embeddings import Embeddings
from ragbits.core.embeddings.exceptions import (
EmbeddingConnectionError,
EmbeddingEmptyResponseError,
EmbeddingResponseError,
EmbeddingStatusError,
)
Expand Down Expand Up @@ -64,23 +66,34 @@ async def embed_text(self, data: list[str]) -> list[list[float]]:
Raises:
EmbeddingConnectionError: If there is a connection error with the embedding API.
EmbeddingEmptyResponseError: If the embedding API returns an empty response.
EmbeddingStatusError: If the embedding API returns an error status code.
EmbeddingResponseError: If the embedding API response is invalid.
"""
try:
response = await litellm.aembedding(
input=data,
model=self.model,
api_base=self.api_base,
api_key=self.api_key,
api_version=self.api_version,
**self.options,
)
except litellm.openai.APIConnectionError as exc:
raise EmbeddingConnectionError() from exc
except litellm.openai.APIStatusError as exc:
raise EmbeddingStatusError(exc.message, exc.status_code) from exc
except litellm.openai.APIResponseValidationError as exc:
raise EmbeddingResponseError() from exc
async with trace(texts=data) as outputs:
try:
response = await litellm.aembedding(
input=data,
model=self.model,
api_base=self.api_base,
api_key=self.api_key,
api_version=self.api_version,
**self.options,
)
except litellm.openai.APIConnectionError as exc:
raise EmbeddingConnectionError() from exc
except litellm.openai.APIStatusError as exc:
raise EmbeddingStatusError(exc.message, exc.status_code) from exc
except litellm.openai.APIResponseValidationError as exc:
raise EmbeddingResponseError() from exc

return [embedding["embedding"] for embedding in response.data]
if not response.data:
raise EmbeddingEmptyResponseError()

outputs.embeddings = [embedding["embedding"] for embedding in response.data]
if response.usage:
outputs.completion_tokens = response.usage.completion_tokens
outputs.prompt_tokens = response.usage.prompt_tokens
outputs.total_tokens = response.usage.total_tokens

return outputs.embeddings
6 changes: 2 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"ragbits-cli",
"ragbits-core[litellm,local,lab,chroma]",
"ragbits-document-search[gcs, huggingface]",
"ragbits-core[chroma,lab,litellm,local,langsmith,otel]",
"ragbits-document-search[gcs,huggingface]",
"ragbits-evaluate[relari]",
]

Expand Down Expand Up @@ -146,7 +146,6 @@ ignore = [
"PLR0913",
]


[tool.ruff.lint.pydocstyle]
convention = "google"

Expand All @@ -173,7 +172,6 @@ convention = "google"
docstring-code-format = true
docstring-code-line-length = 120


[tool.ruff.lint.isort]
known-first-party = ["ragbits"]
known-third-party = [
Expand Down

0 comments on commit 37f9977

Please sign in to comment.