From 37f99770e73d7cd9d74e0be9519bdf04c8d35097 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= Date: Thu, 31 Oct 2024 21:55:56 +0100 Subject: [PATCH] add audit core --- examples/document-search/basic.py | 46 +++++++--- packages/ragbits-core/pyproject.toml | 6 ++ .../src/ragbits/core/audit/__init__.py | 90 +++++++++++++++++++ .../src/ragbits/core/audit/base.py | 35 ++++++++ .../src/ragbits/core/audit/otel.py | 29 ++++++ .../src/ragbits/core/embeddings/exceptions.py | 9 ++ .../src/ragbits/core/embeddings/litellm.py | 45 ++++++---- pyproject.toml | 6 +- 8 files changed, 234 insertions(+), 32 deletions(-) create mode 100644 packages/ragbits-core/src/ragbits/core/audit/__init__.py create mode 100644 packages/ragbits-core/src/ragbits/core/audit/base.py create mode 100644 packages/ragbits-core/src/ragbits/core/audit/otel.py diff --git a/examples/document-search/basic.py b/examples/document-search/basic.py index 3cf9e887c..68c7784ee 100644 --- a/examples/document-search/basic.py +++ b/examples/document-search/basic.py @@ -2,23 +2,44 @@ # requires-python = ">=3.10" # dependencies = [ # "ragbits-document-search", -# "ragbits-core[litellm]", +# "ragbits-core[litellm,otel]", # ] # /// import asyncio +from opentelemetry import trace +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.resources import SERVICE_NAME, Resource +from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter +from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter +from ragbits.core import audit from ragbits.core.embeddings.litellm import LiteLLMEmbeddings from ragbits.core.vector_stores.in_memory import InMemoryVectorStore from ragbits.document_search import DocumentSearch from ragbits.document_search.documents.document import DocumentMeta +provider = TracerProvider(resource=Resource({SERVICE_NAME: "ragbits"})) +provider.add_span_processor(BatchSpanProcessor(ConsoleSpanExporter())) +# provider.add_span_processor(BatchSpanProcessor(OTLPSpanExporter("http://localhost:4317", insecure=True))) +trace.set_tracer_provider(provider) + +audit.set_trace_handlers("otel") +# audit.set_trace_handlers(["langsmith", "otel"]) documents = [ - DocumentMeta.create_text_document_from_literal("RIP boiled water. You will be mist."), DocumentMeta.create_text_document_from_literal( - "Why doesn't James Bond fart in bed? Because it would blow his cover." + """ + RIP boiled water. You will be mist. + """ + ), + DocumentMeta.create_text_document_from_literal( + """ + Why doesn't James Bond fart in bed? Because it would blow his cover. + """ ), DocumentMeta.create_text_document_from_literal( - "Why programmers don't like to swim? Because they're scared of the floating points." + """ + Why programmers don't like to swim? Because they're scared of the floating points. + """ ), ] @@ -30,16 +51,17 @@ async def main() -> None: embedder = LiteLLMEmbeddings( model="text-embedding-3-small", ) - vector_store = InMemoryVectorStore() - document_search = DocumentSearch( - embedder=embedder, - vector_store=vector_store, - ) + results = await embedder.embed_text(["I'm boiling my water and I need a", "joke"]) + # vector_store = InMemoryVectorStore() + # document_search = DocumentSearch( + # embedder=embedder, + # vector_store=vector_store, + # ) - await document_search.ingest(documents) + # await document_search.ingest(documents) - results = await document_search.search("I'm boiling my water and I need a joke") - print(results) + # results = await document_search.search("I'm boiling my water and I need a joke") + # print(results) if __name__ == "__main__": diff --git a/packages/ragbits-core/pyproject.toml b/packages/ragbits-core/pyproject.toml index 048d812f9..df2a73749 100644 --- a/packages/ragbits-core/pyproject.toml +++ b/packages/ragbits-core/pyproject.toml @@ -56,6 +56,12 @@ lab = [ promptfoo = [ "PyYAML~=6.0.2", ] +langsmith = [ + "langsmith~=0.1.137", +] +otel = [ + "opentelemetry-api~=1.27.0", +] [tool.uv] dev-dependencies = [ diff --git a/packages/ragbits-core/src/ragbits/core/audit/__init__.py b/packages/ragbits-core/src/ragbits/core/audit/__init__.py new file mode 100644 index 000000000..ca2bad69b --- /dev/null +++ b/packages/ragbits-core/src/ragbits/core/audit/__init__.py @@ -0,0 +1,90 @@ +import asyncio +from collections.abc import AsyncIterator, Callable +from contextlib import asynccontextmanager +from functools import wraps +from types import SimpleNamespace +from typing import Any, ParamSpec, TypeVar +from ragbits.core.audit.base import TraceHandler + +_trace_handlers: list[TraceHandler] = [] + +Handler = str | TraceHandler + +P = ParamSpec("P") +R = TypeVar("R") + + +def traceable(func: Callable[P, R]) -> Callable[P, R]: + @wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + print(*args, **kwargs) + return func(*args, **kwargs) + + @wraps(func) + async def wrapper_async(*args: P.args, **kwargs: P.kwargs) -> R: + print("asuync", *args, **kwargs) + return await func(*args, **kwargs) # type: ignore + + if asyncio.iscoroutinefunction(func): + return wrapper_async # type: ignore + return wrapper + + +@asynccontextmanager +async def trace(**inputs: Any) -> AsyncIterator[SimpleNamespace]: + """ + Context manager for processing an event. + + Args: + event: The event to be processed. + + Yields: + The event being processed. + """ + for handler in _trace_handlers: + await handler.on_start(**inputs) + + try: + yield (outputs := SimpleNamespace()) + except Exception as exc: + for handler in _trace_handlers: + await handler.on_error(exc) + raise exc + + for handler in _trace_handlers: + await handler.on_end(**vars(outputs)) + + +def set_trace_handlers(handlers: Handler | list[Handler]) -> None: + """ + Setup event handlers. + + Args: + handlers: List of event handlers to be used. + + Raises: + ValueError: If handler is not found. + TypeError: If handler type is invalid. + """ + global _trace_handlers + + if isinstance(handlers, str): + handlers = [handlers] + + for handler in handlers: # type: ignore + if isinstance(handler, TraceHandler): + _trace_handlers.append(handler) + elif isinstance(handler, str): + if handler == "otel": + from ragbits.core.audit.otel import OtelTraceHandler + _trace_handlers.append(OtelTraceHandler()) + if handler == "langsmith": + from ragbits.core.audit.langsmith import LangSmithTraceHandler + _trace_handlers.append(LangSmithTraceHandler()) + else: + raise ValueError(f"Handler {handler} not found.") + else: + raise TypeError(f"Invalid handler type: {type(handler)}") + + +__all__ = ["TraceHandler", "traceable", "trace", "set_trace_handlers"] diff --git a/packages/ragbits-core/src/ragbits/core/audit/base.py b/packages/ragbits-core/src/ragbits/core/audit/base.py new file mode 100644 index 000000000..71f3bea84 --- /dev/null +++ b/packages/ragbits-core/src/ragbits/core/audit/base.py @@ -0,0 +1,35 @@ +from abc import ABC, abstractmethod +from typing import Any + + +class TraceHandler(ABC): + """ + Base class for all trace handlers. + """ + + @abstractmethod + async def on_start(self, **inputs: Any) -> None: + """ + Log input data at the start of the event. + + Args: + inputs: The input data. + """ + + @abstractmethod + async def on_end(self, **outputs: Any) -> None: + """ + Log output data at the end of the event. + + Args: + outputs: The output data. + """ + + @abstractmethod + async def on_error(self, error: Exception) -> None: + """ + Log error during the event. + + Args: + error: The error that occurred. + """ diff --git a/packages/ragbits-core/src/ragbits/core/audit/otel.py b/packages/ragbits-core/src/ragbits/core/audit/otel.py new file mode 100644 index 000000000..ee6129e10 --- /dev/null +++ b/packages/ragbits-core/src/ragbits/core/audit/otel.py @@ -0,0 +1,29 @@ +from typing import Any + +from opentelemetry import trace +from opentelemetry.trace import Span, StatusCode +from ragbits.core.audit.base import TraceHandler + + +class OtelTraceHandler(TraceHandler): + """ + OpenTelemetry trace handler. + """ + + def __init__(self) -> None: + self._tracer = trace.get_tracer("ragbits.events") + + async def on_start(self, **inputs: Any) -> None: + with self._tracer.start_as_current_span("request", end_on_exit=False) as span: + for key, value in inputs.items(): + span.set_attribute(key, value) + + async def on_end(self, **outputs: Any) -> None: + span = trace.get_current_span() + for key, value in outputs.items(): + span.set_attribute(key, value) + span.set_status(StatusCode.OK) + span.end() + + async def on_error(self, error: Exception) -> None: + print("on_error", error) diff --git a/packages/ragbits-core/src/ragbits/core/embeddings/exceptions.py b/packages/ragbits-core/src/ragbits/core/embeddings/exceptions.py index 4dd99ad1e..a8a25224a 100644 --- a/packages/ragbits-core/src/ragbits/core/embeddings/exceptions.py +++ b/packages/ragbits-core/src/ragbits/core/embeddings/exceptions.py @@ -34,3 +34,12 @@ class EmbeddingResponseError(EmbeddingError): def __init__(self, message: str = "Data returned by API invalid for expected schema.") -> None: super().__init__(message) + + +class EmbeddingEmptyResponseError(EmbeddingError): + """ + Raised when an API response has an empty response. + """ + + def __init__(self, message: str = "Empty response returned by API.") -> None: + super().__init__(message) diff --git a/packages/ragbits-core/src/ragbits/core/embeddings/litellm.py b/packages/ragbits-core/src/ragbits/core/embeddings/litellm.py index 1b3fbe5b7..0928e5056 100644 --- a/packages/ragbits-core/src/ragbits/core/embeddings/litellm.py +++ b/packages/ragbits-core/src/ragbits/core/embeddings/litellm.py @@ -5,9 +5,11 @@ except ImportError: HAS_LITELLM = False +from ragbits.core.audit import trace from ragbits.core.embeddings import Embeddings from ragbits.core.embeddings.exceptions import ( EmbeddingConnectionError, + EmbeddingEmptyResponseError, EmbeddingResponseError, EmbeddingStatusError, ) @@ -64,23 +66,34 @@ async def embed_text(self, data: list[str]) -> list[list[float]]: Raises: EmbeddingConnectionError: If there is a connection error with the embedding API. + EmbeddingEmptyResponseError: If the embedding API returns an empty response. EmbeddingStatusError: If the embedding API returns an error status code. EmbeddingResponseError: If the embedding API response is invalid. """ - try: - response = await litellm.aembedding( - input=data, - model=self.model, - api_base=self.api_base, - api_key=self.api_key, - api_version=self.api_version, - **self.options, - ) - except litellm.openai.APIConnectionError as exc: - raise EmbeddingConnectionError() from exc - except litellm.openai.APIStatusError as exc: - raise EmbeddingStatusError(exc.message, exc.status_code) from exc - except litellm.openai.APIResponseValidationError as exc: - raise EmbeddingResponseError() from exc + async with trace(texts=data) as outputs: + try: + response = await litellm.aembedding( + input=data, + model=self.model, + api_base=self.api_base, + api_key=self.api_key, + api_version=self.api_version, + **self.options, + ) + except litellm.openai.APIConnectionError as exc: + raise EmbeddingConnectionError() from exc + except litellm.openai.APIStatusError as exc: + raise EmbeddingStatusError(exc.message, exc.status_code) from exc + except litellm.openai.APIResponseValidationError as exc: + raise EmbeddingResponseError() from exc - return [embedding["embedding"] for embedding in response.data] + if not response.data: + raise EmbeddingEmptyResponseError() + + outputs.embeddings = [embedding["embedding"] for embedding in response.data] + if response.usage: + outputs.completion_tokens = response.usage.completion_tokens + outputs.prompt_tokens = response.usage.prompt_tokens + outputs.total_tokens = response.usage.total_tokens + + return outputs.embeddings diff --git a/pyproject.toml b/pyproject.toml index 3e2499ba0..5f31de9e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,8 +6,8 @@ readme = "README.md" requires-python = ">=3.10" dependencies = [ "ragbits-cli", - "ragbits-core[litellm,local,lab,chroma]", - "ragbits-document-search[gcs, huggingface]", + "ragbits-core[chroma,lab,litellm,local,langsmith,otel]", + "ragbits-document-search[gcs,huggingface]", "ragbits-evaluate[relari]", ] @@ -146,7 +146,6 @@ ignore = [ "PLR0913", ] - [tool.ruff.lint.pydocstyle] convention = "google" @@ -173,7 +172,6 @@ convention = "google" docstring-code-format = true docstring-code-line-length = 120 - [tool.ruff.lint.isort] known-first-party = ["ragbits"] known-third-party = [