diff --git a/.gitignore b/.gitignore
index 43d821a4..0f181039 100644
--- a/.gitignore
+++ b/.gitignore
@@ -87,11 +87,11 @@ cmake-build-*/
 **/.terraform.lock.hcl
 **/.terraform
 
-# benchmarks
-benchmarks/sql/data/
-
 # mkdocs generated files
 site/
 
 # build artifacts
 dist/
+
+# examples
+chroma/
diff --git a/README.md b/README.md
index c47afb68..974b6d23 100644
--- a/README.md
+++ b/README.md
@@ -19,11 +19,11 @@
 - [X] **[Core](packages/ragbits-core)** - Fundamental tools for working with prompts and LLMs.
 - [X] **[Document Search](packages/ragbits-document-search)** - Handles vector search to retrieve relevant documents.
 - [X] **[CLI](packages/ragbits-cli)** - The `ragbits` shell command, enabling tools such as GUI prompt management.
+- [X] **[Guardrails](packages/ragbits-guardrails)** - Ensures response safety and relevance.
+- [X] **[Evaluation](packages/ragbits-evaluate)** - Unified evaluation framework for Ragbits components.
 - [ ] **Flow Controls** - Manages multi-stage chat flows for performing advanced actions *(coming soon)*.
 - [ ] **Structured Querying** - Queries structured data sources in a predictable manner *(coming soon)*.
 - [ ] **Caching** - Adds a caching layer to reduce costs and response times *(coming soon)*.
-- [ ] **Observability & Audit** - Tracks user queries and events for easier troubleshooting *(coming soon)*.
-- [ ] **Guardrails** - Ensures response safety and relevance *(coming soon)*.
 
 ## Installation
diff --git a/docs/how-to/use_guardrails.md b/docs/how-to/use_guardrails.md
new file mode 100644
index 00000000..430e883c
--- /dev/null
+++ b/docs/how-to/use_guardrails.md
@@ -0,0 +1,48 @@
+# How-To: Use Guardrails
+
+Ragbits offers an expandable guardrails system. You can use one of the available guardrails or create your own to prevent toxic language, PII leaks, etc.
+
+In this guide, we will show you how to use a guardrail based on OpenAI moderation and how to create your own guardrail.
+
+
+## Using an existing guardrail
+To use one of the existing guardrails, import it together with `GuardrailManager`. Then pass a list of guardrails to the manager
+and call the `verify()` method, which checks the input (`str` or `Prompt`) against all provided guardrails asynchronously.
+
+```python
+import asyncio
+from ragbits.guardrails.base import GuardrailManager, GuardrailVerificationResult
+from ragbits.guardrails.openai_moderation import OpenAIModerationGuardrail
+
+
+async def verify_message(message: str) -> list[GuardrailVerificationResult]:
+    manager = GuardrailManager([OpenAIModerationGuardrail()])
+    return await manager.verify(message)
+
+
+if __name__ == "__main__":
+    print(asyncio.run(verify_message("Test message")))
+```
+
+The expected output is an object with the following properties:
+```python
+    guardrail_name: str
+    succeeded: bool
+    fail_reason: str | None
+```
+It lets you see which guardrail was used, whether the check succeeded, and, optionally, the reason for failure.
+
+## Implementing a custom guardrail
+To implement a custom guardrail, create a new class that inherits from `Guardrail` and implements the abstract `verify` method.
+
+```python
+from ragbits.core.prompt import Prompt
+from ragbits.guardrails.base import Guardrail, GuardrailVerificationResult
+
+class CustomGuardrail(Guardrail):
+
+    async def verify(self, input_to_verify: Prompt | str) -> GuardrailVerificationResult:
+        pass
+```
+
+With that, you can pass your `CustomGuardrail` to the `GuardrailManager` as shown in the [using an existing guardrail](#using-an-existing-guardrail) section.
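+
+For example, a minimal custom guardrail that blocks a list of banned phrases could look like the sketch below. This is an illustrative sketch, not part of Ragbits: the class name and the banned-phrase check are made up, and it assumes a rendered `Prompt` exposes its user message as `rendered_user_prompt`.
+
+```python
+from ragbits.core.prompt import Prompt
+from ragbits.guardrails.base import Guardrail, GuardrailVerificationResult
+
+
+class BannedPhrasesGuardrail(Guardrail):
+    def __init__(self, banned_phrases: list[str]) -> None:
+        self._banned_phrases = banned_phrases
+
+    async def verify(self, input_to_verify: Prompt | str) -> GuardrailVerificationResult:
+        # Assumption: a rendered `Prompt` exposes its user message as `rendered_user_prompt`.
+        text = input_to_verify if isinstance(input_to_verify, str) else input_to_verify.rendered_user_prompt
+        found = [phrase for phrase in self._banned_phrases if phrase.lower() in text.lower()]
+        return GuardrailVerificationResult(
+            guardrail_name=self.__class__.__name__,
+            succeeded=not found,
+            fail_reason=None if not found else f"Banned phrases found: {', '.join(found)}",
+        )
+```
+
+A `GuardrailManager([BannedPhrasesGuardrail(["forbidden"])])` would then run this check alongside any other configured guardrails.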
\ No newline at end of file
diff --git a/examples/document-search/basic.py b/examples/document-search/basic.py
index 3cf9e887..0d3667f9 100644
--- a/examples/document-search/basic.py
+++ b/examples/document-search/basic.py
@@ -1,3 +1,27 @@
+"""
+Ragbits Document Search Example: Basic
+
+This example demonstrates how to use the `DocumentSearch` class to search for documents with a minimal setup.
+We will use the `LiteLLMEmbeddings` class to embed the documents and the query, and the `InMemoryVectorStore` class
+to store the embeddings.
+
+The script performs the following steps:
+
+    1. Create a list of documents.
+    2. Initialize the `LiteLLMEmbeddings` class with the OpenAI `text-embedding-3-small` embedding model.
+    3. Initialize the `InMemoryVectorStore` class.
+    4. Initialize the `DocumentSearch` class with the embedder and the vector store.
+    5. Ingest the documents into the `DocumentSearch` instance.
+    6. Search for documents using a query.
+    7. Print the search results.
+
+To run the script, execute the following command:
+
+    ```bash
+    uv run examples/document-search/basic.py
+    ```
+"""
+
 # /// script
 # requires-python = ">=3.10"
 # dependencies = [
@@ -5,6 +29,7 @@
 #     "ragbits-core[litellm]",
 # ]
 # ///
+
 import asyncio
 
 from ragbits.core.embeddings.litellm import LiteLLMEmbeddings
@@ -13,12 +38,25 @@ from ragbits.document_search.documents.document import DocumentMeta
 
 documents = [
-    DocumentMeta.create_text_document_from_literal("RIP boiled water. You will be mist."),
     DocumentMeta.create_text_document_from_literal(
-        "Why doesn't James Bond fart in bed? Because it would blow his cover."
+        """
+        RIP boiled water. You will be mist.
+        """
+    ),
+    DocumentMeta.create_text_document_from_literal(
+        """
+        Why doesn't James Bond fart in bed? Because it would blow his cover.
+        """
+    ),
+    DocumentMeta.create_text_document_from_literal(
+        """
+        Why programmers don't like to swim? Because they're scared of the floating points.
+        """
     ),
     DocumentMeta.create_text_document_from_literal(
-        "Why programmers don't like to swim? Because they're scared of the floating points."
+        """
+        This one is completely unrelated.
+        """
     ),
 ]
diff --git a/examples/document-search/chroma.py b/examples/document-search/chroma.py
index 88b876ce..7becdff6 100644
--- a/examples/document-search/chroma.py
+++ b/examples/document-search/chroma.py
@@ -1,3 +1,28 @@
+"""
+Ragbits Document Search Example: Chroma
+
+This example demonstrates how to use the `DocumentSearch` class to search for documents with a more advanced setup.
+We will use the `LiteLLMEmbeddings` class to embed the documents and the query, and the `ChromaVectorStore` class
+to store the embeddings.
+
+The script performs the following steps:
+
+    1. Create a list of documents.
+    2. Initialize the `LiteLLMEmbeddings` class with the OpenAI `text-embedding-3-small` embedding model.
+    3. Initialize the `ChromaVectorStore` class with a `PersistentClient` instance and an index name.
+    4. Initialize the `DocumentSearch` class with the embedder and the vector store.
+    5. Ingest the documents into the `DocumentSearch` instance.
+    6. List all documents in the vector store.
+    7. Search for documents using a query.
+    8. Print the list of all documents and the search results.
+
+To run the script, execute the following command:
+
+    ```bash
+    uv run examples/document-search/chroma.py
+    ```
+"""
+
 # /// script
 # requires-python = ">=3.10"
 # dependencies = [
@@ -5,6 +30,7 @@
 #     "ragbits-core[chroma,litellm]",
 # ]
 # ///
+
 import asyncio
 
 from chromadb import PersistentClient
@@ -15,11 +41,26 @@ from ragbits.document_search.documents.document import DocumentMeta
 
 documents = [
-    DocumentMeta.create_text_document_from_literal("RIP boiled water. You will be mist."),
     DocumentMeta.create_text_document_from_literal(
-        "Why programmers don't like to swim? Because they're scared of the floating points."
+        """
+        RIP boiled water. You will be mist.
+        """
+    ),
+    DocumentMeta.create_text_document_from_literal(
+        """
+        Why doesn't James Bond fart in bed? Because it would blow his cover.
+        """
+    ),
+    DocumentMeta.create_text_document_from_literal(
+        """
+        Why programmers don't like to swim? Because they're scared of the floating points.
+        """
+    ),
+    DocumentMeta.create_text_document_from_literal(
+        """
+        This one is completely unrelated.
+        """
     ),
-    DocumentMeta.create_text_document_from_literal("This one is completely unrelated."),
 ]
diff --git a/examples/document-search/chroma_otel.py b/examples/document-search/chroma_otel.py
new file mode 100644
index 00000000..d137bd42
--- /dev/null
+++ b/examples/document-search/chroma_otel.py
@@ -0,0 +1,137 @@
+"""
+Ragbits Document Search Example: Chroma x OpenTelemetry
+
+This example demonstrates how to use the `DocumentSearch` class to search for documents with a more advanced setup.
+We will use the `LiteLLMEmbeddings` class to embed the documents and the query, the `ChromaVectorStore` class to store
+the embeddings, and the OpenTelemetry SDK to trace the operations.
+
+The script performs the following steps:
+
+    1. Create a list of documents.
+    2. Initialize the `LiteLLMEmbeddings` class with the OpenAI `text-embedding-3-small` embedding model.
+    3. Initialize the `ChromaVectorStore` class with a `PersistentClient` instance and an index name.
+    4. Initialize the `DocumentSearch` class with the embedder and the vector store.
+    5. Ingest the documents into the `DocumentSearch` instance.
+    6. List all documents in the vector store.
+    7. Search for documents using a query.
+    8. Print the list of all documents and the search results.
+
+To run the script, execute the following command:
+
+    ```bash
+    uv run examples/document-search/chroma_otel.py
+    ```
+
+The script exports traces to the local OTLP collector running on `http://localhost:4317`. To visualize the traces,
+you can use Jaeger. The recommended way to run it is using the official Docker image:
+
+    1. Run the Jaeger Docker container:
+
+        ```bash
+        docker run -d --rm --name jaeger \
+            -p 16686:16686 \
+            -p 4317:4317 \
+            jaegertracing/all-in-one:1.62.0
+        ```
+
+    2.
Open the Jaeger UI in your browser: + + ``` + http://localhost:16686 + ``` +""" + +# /// script +# requires-python = ">=3.10" +# dependencies = [ +# "ragbits-document-search", +# "ragbits-core[chroma,litellm,otel]", +# ] +# /// + +import asyncio + +from chromadb import PersistentClient +from opentelemetry import trace +from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter +from opentelemetry.sdk.resources import SERVICE_NAME, Resource +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import BatchSpanProcessor + +from ragbits.core import audit +from ragbits.core.embeddings.litellm import LiteLLMEmbeddings +from ragbits.core.vector_stores.chroma import ChromaVectorStore +from ragbits.document_search import DocumentSearch, SearchConfig +from ragbits.document_search.documents.document import DocumentMeta + +provider = TracerProvider(resource=Resource({SERVICE_NAME: "ragbits"})) +provider.add_span_processor(BatchSpanProcessor(OTLPSpanExporter("http://localhost:4317", insecure=True))) +trace.set_tracer_provider(provider) + +audit.set_trace_handlers("otel") + +documents = [ + DocumentMeta.create_text_document_from_literal( + """ + RIP boiled water. You will be mist. + """ + ), + DocumentMeta.create_text_document_from_literal( + """ + Why doesn't James Bond fart in bed? Because it would blow his cover. + """ + ), + DocumentMeta.create_text_document_from_literal( + """ + Why programmers don't like to swim? Because they're scared of the floating points. + """ + ), + DocumentMeta.create_text_document_from_literal( + """ + This one is completely unrelated. + """ + ), +] + + +async def main() -> None: + """ + Run the example. + """ + embedder = LiteLLMEmbeddings( + model="text-embedding-3-small", + ) + vector_store = ChromaVectorStore( + client=PersistentClient("./chroma"), + index_name="jokes", + ) + document_search = DocumentSearch( + embedder=embedder, + vector_store=vector_store, + ) + + await document_search.ingest(documents) + + all_documents = await vector_store.list() + + print() + print("All documents:") + print([doc.metadata["content"] for doc in all_documents]) + + query = "I'm boiling my water and I need a joke" + vector_store_kwargs = { + "k": 2, + "max_distance": None, + } + results = await document_search.search( + query, + config=SearchConfig(vector_store_kwargs=vector_store_kwargs), + ) + + print() + print(f"Documents similar to: {query}") + print([element.get_key() for element in results]) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/document-search/from_config.py b/examples/document-search/from_config.py index 6907a1dd..1b5912f8 100644 --- a/examples/document-search/from_config.py +++ b/examples/document-search/from_config.py @@ -1,3 +1,25 @@ +""" +Ragbits Document Search Example: DocumentSearch from Config + +This example demonstrates how to use the `DocumentSearch` class to search for documents with a more advanced setup. +We will use the `LiteLLMEmbeddings` class to embed the documents and the query, the `ChromaVectorStore` class to store +the embeddings, and the `LiteLLMReranker` class to rerank the search results. We will also use the `LLMQueryRephraser` +class to rephrase the query. + +The script performs the following steps: + + 1. Create a list of documents. + 2. Initialize the `DocumentSearch` class with the predefined configuration. + 3. Ingest the documents into the `DocumentSearch` instance. + 4. Search for documents using a query. + 5. Print the search results. 
+
+To run the script, execute the following command:
+
+    ```bash
+    uv run examples/document-search/from_config.py
+    ```
+"""
 # /// script
 # requires-python = ">=3.10"
 # dependencies = [
@@ -5,23 +27,39 @@
 #     "ragbits-core[chroma,litellm]",
 # ]
 # ///
+
 import asyncio
 
 from ragbits.document_search import DocumentSearch
 from ragbits.document_search.documents.document import DocumentMeta
 
 documents = [
-    DocumentMeta.create_text_document_from_literal("RIP boiled water. You will be mist."),
     DocumentMeta.create_text_document_from_literal(
-        "Why doesn't James Bond fart in bed? Because it would blow his cover."
+        """
+        RIP boiled water. You will be mist.
+        """
+    ),
+    DocumentMeta.create_text_document_from_literal(
+        """
+        Why doesn't James Bond fart in bed? Because it would blow his cover.
+        """
     ),
     DocumentMeta.create_text_document_from_literal(
-        "Why programmers don't like to swim? Because they're scared of the floating points."
+        """
+        Why programmers don't like to swim? Because they're scared of the floating points.
+        """
+    ),
+    DocumentMeta.create_text_document_from_literal(
+        """
+        This one is completely unrelated.
+        """
     ),
 ]
 
 config = {
-    "embedder": {"type": "ragbits.core.embeddings.litellm:LiteLLMEmbeddings"},
+    "embedder": {
+        "type": "ragbits.core.embeddings.litellm:LiteLLMEmbeddings",
+    },
     "vector_store": {
         "type": "ragbits.core.vector_stores.chroma:ChromaVectorStore",
         "config": {
@@ -42,7 +80,16 @@
             },
         },
     },
-    "reranker": {"type": "ragbits.document_search.retrieval.rerankers.noop:NoopReranker"},
+    "reranker": {
+        "type": "ragbits.document_search.retrieval.rerankers.litellm:LiteLLMReranker",
+        "config": {
+            "model": "cohere/rerank-english-v3.0",
+            "default_options": {
+                "top_n": 3,
+                "max_chunks_per_doc": None,
+            },
+        },
+    },
     "providers": {"txt": {"type": "DummyProvider"}},
     "rephraser": {
         "type": "LLMQueryRephraser",
diff --git a/examples/guardrails/openai_moderation.py b/examples/guardrails/openai_moderation.py
new file mode 100644
index 00000000..75e42155
--- /dev/null
+++ b/examples/guardrails/openai_moderation.py
@@ -0,0 +1,29 @@
+# /// script
+# requires-python = ">=3.10"
+# dependencies = [
+#     "ragbits-guardrails",
+#     "openai",
+# ]
+# ///
+import asyncio
+from argparse import ArgumentParser
+
+from ragbits.guardrails.base import GuardrailManager
+from ragbits.guardrails.openai_moderation import OpenAIModerationGuardrail
+
+
+async def guardrail_run(message: str) -> None:
+    """
+    Example of using the OpenAIModerationGuardrail. Requires the OPENAI_API_KEY environment variable to be set.
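+
+    Usage (illustrative):
+        uv run examples/guardrails/openai_moderation.py "message to check"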
+    """
+    manager = GuardrailManager([OpenAIModerationGuardrail()])
+    res = await manager.verify(message)
+    print(res)
+
+
+if __name__ == "__main__":
+    args = ArgumentParser()
+    args.add_argument("message", nargs="+", type=str, help="Message to validate")
+    parsed_args = args.parse_args()
+
+    asyncio.run(guardrail_run(" ".join(parsed_args.message)))
diff --git a/packages/ragbits-core/pyproject.toml b/packages/ragbits-core/pyproject.toml
index 048d812f..2271fb8e 100644
--- a/packages/ragbits-core/pyproject.toml
+++ b/packages/ragbits-core/pyproject.toml
@@ -56,6 +56,9 @@ lab = [
 promptfoo = [
     "PyYAML~=6.0.2",
 ]
+otel = [
+    "opentelemetry-api~=1.27.0",
+]
 
 [tool.uv]
 dev-dependencies = [
diff --git a/packages/ragbits-core/src/ragbits/core/audit/__init__.py b/packages/ragbits-core/src/ragbits/core/audit/__init__.py
new file mode 100644
index 00000000..cf12564d
--- /dev/null
+++ b/packages/ragbits-core/src/ragbits/core/audit/__init__.py
@@ -0,0 +1,142 @@
+import asyncio
+import inspect
+from collections.abc import Callable, Iterator
+from contextlib import ExitStack, contextmanager
+from functools import wraps
+from types import SimpleNamespace
+from typing import Any, ParamSpec, TypeVar
+
+from ragbits.core.audit.base import TraceHandler
+
+__all__ = ["TraceHandler", "set_trace_handlers", "trace", "traceable"]
+
+_trace_handlers: list[TraceHandler] = []
+
+Handler = str | TraceHandler
+
+P = ParamSpec("P")
+R = TypeVar("R")
+
+
+def set_trace_handlers(handlers: Handler | list[Handler]) -> None:
+    """
+    Set up trace handlers.
+
+    Args:
+        handlers: List of trace handlers to be used.
+
+    Raises:
+        ValueError: If a handler is not found.
+        TypeError: If a handler type is invalid.
+    """
+    global _trace_handlers  # noqa: PLW0602
+
+    if isinstance(handlers, Handler):
+        handlers = [handlers]
+
+    for handler in handlers:  # type: ignore
+        if isinstance(handler, TraceHandler):
+            _trace_handlers.append(handler)
+        elif isinstance(handler, str):
+            if handler == "otel":
+                from ragbits.core.audit.otel import OtelTraceHandler
+
+                _trace_handlers.append(OtelTraceHandler())
+            else:
+                raise ValueError(f"Handler {handler} not found.")
+        else:
+            raise TypeError(f"Invalid handler type: {type(handler)}")
+
+
+@contextmanager
+def trace(name: str | None = None, **inputs: Any) -> Iterator[SimpleNamespace]:  # noqa: ANN401
+    """
+    Context manager for processing a trace.
+
+    Args:
+        name: The name of the trace.
+        inputs: The input data.
+
+    Yields:
+        The output data.
+    """
+    # We need to go up 2 frames (trace() and __enter__()) to get the parent function.
+    parent_frame = inspect.stack()[2].frame
+    name = (
+        (
+            f"{cls.__class__.__qualname__}.{parent_frame.f_code.co_name}"
+            if (cls := parent_frame.f_locals.get("self"))
+            else parent_frame.f_code.co_name
+        )
+        if name is None
+        else name
+    )
+
+    with ExitStack() as stack:
+        outputs = [stack.enter_context(handler.trace(name, **inputs)) for handler in _trace_handlers]
+        yield (out := SimpleNamespace())
+        for output in outputs:
+            output.__dict__.update(vars(out))
+
+
+def traceable(func: Callable[P, R]) -> Callable[P, R]:
+    """
+    Decorator for making a function traceable.
+
+    Args:
+        func: The function to be decorated.
+
+    Returns:
+        The decorated function.
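+
+    Example (an illustrative sketch; the recorded span names and attributes depend on the configured handlers):
+        >>> @traceable
+        ... def add(a: int, b: int) -> int:
+        ...     return a + b
+        >>> add(1, 2)  # traced with inputs {"a": 1, "b": 2} and output {"returned": 3}
+        3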
+ """ + + @wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + inputs = _get_function_inputs(func, args, kwargs) + with trace(name=func.__qualname__, **inputs) as outputs: + returned = func(*args, **kwargs) + if returned is not None: + outputs.returned = returned + return returned + + @wraps(func) + async def wrapper_async(*args: P.args, **kwargs: P.kwargs) -> R: + inputs = _get_function_inputs(func, args, kwargs) + with trace(name=func.__qualname__, **inputs) as outputs: + returned = await func(*args, **kwargs) # type: ignore + if returned is not None: + outputs.returned = returned + return returned + + return wrapper_async if asyncio.iscoroutinefunction(func) else wrapper # type: ignore + + +def _get_function_inputs(func: Callable, args: tuple, kwargs: dict) -> dict: + """ + Get the dictionary of inputs for a function based on positional and keyword arguments. + + Args: + func: The function to get inputs for. + args: The positional arguments. + kwargs: The keyword arguments. + + Returns: + The dictionary of inputs. + """ + sig_params = inspect.signature(func).parameters + merged = {} + pos_args_used = 0 + + for param_name, param in sig_params.items(): + if param_name in kwargs: + merged[param_name] = kwargs[param_name] + elif pos_args_used < len(args): + if param_name not in ("self", "cls", "args", "kwargs"): + merged[param_name] = args[pos_args_used] + pos_args_used += 1 + elif param.default is not param.empty: + merged[param_name] = param.default + + merged.update({k: v for k, v in kwargs.items() if k not in merged}) + + return merged diff --git a/packages/ragbits-core/src/ragbits/core/audit/base.py b/packages/ragbits-core/src/ragbits/core/audit/base.py new file mode 100644 index 00000000..d0e9f7fd --- /dev/null +++ b/packages/ragbits-core/src/ragbits/core/audit/base.py @@ -0,0 +1,87 @@ +from abc import ABC, abstractmethod +from collections.abc import Iterator +from contextlib import contextmanager +from contextvars import ContextVar +from types import SimpleNamespace +from typing import Any, Generic, TypeVar + +SpanT = TypeVar("SpanT") + + +class TraceHandler(Generic[SpanT], ABC): + """ + Base class for all trace handlers. + """ + + def __init__(self) -> None: + """ + Constructs a new TraceHandler instance. + """ + super().__init__() + self._spans = ContextVar[list[SpanT]]("_spans", default=[]) + + @abstractmethod + def start(self, name: str, inputs: dict, current_span: SpanT | None = None) -> SpanT: + """ + Log input data at the beginning of the trace. + + Args: + name: The name of the trace. + inputs: The input data. + current_span: The current trace span. + + Returns: + The updated current trace span. + """ + + @abstractmethod + def stop(self, outputs: dict, current_span: SpanT) -> None: + """ + Log output data at the end of the trace. + + Args: + outputs: The output data. + current_span: The current trace span. + """ + + @abstractmethod + def error(self, error: Exception, current_span: SpanT) -> None: + """ + Log error during the trace. + + Args: + error: The error that occurred. + current_span: The current trace span. + """ + + @contextmanager + def trace(self, name: str, **inputs: Any) -> Iterator[SimpleNamespace]: # noqa: ANN401 + """ + Context manager for processing a trace. + + Args: + name: The name of the trace. + inputs: The input data. + + Yields: + The output data. 
+ """ + self._spans.set(self._spans.get()[:]) + current_span = self._spans.get()[-1] if self._spans.get() else None + + span = self.start( + name=name, + inputs=inputs, + current_span=current_span, + ) + self._spans.get().append(span) + + try: + yield (outputs := SimpleNamespace()) + except Exception as exc: + span = self._spans.get().pop() + self.error(error=exc, current_span=span) + raise exc + + span = self._spans.get().pop() + self.stop(outputs=vars(outputs), current_span=span) diff --git a/packages/ragbits-core/src/ragbits/core/audit/otel.py b/packages/ragbits-core/src/ragbits/core/audit/otel.py new file mode 100644 index 00000000..1d909033 --- /dev/null +++ b/packages/ragbits-core/src/ragbits/core/audit/otel.py @@ -0,0 +1,98 @@ +from opentelemetry import trace +from opentelemetry.trace import Span, StatusCode, TracerProvider +from opentelemetry.util.types import AttributeValue + +from ragbits.core.audit.base import TraceHandler + + +class OtelTraceHandler(TraceHandler[Span]): + """ + OpenTelemetry trace handler. + """ + + def __init__(self, provider: TracerProvider | None = None) -> None: + """ + Constructs a new OtelTraceHandler instance. + + Args: + provider: The tracer provider to use. + """ + super().__init__() + self._tracer = trace.get_tracer(instrumenting_module_name=__name__, tracer_provider=provider) + + def start(self, name: str, inputs: dict, current_span: Span | None = None) -> Span: + """ + Log input data at the beginning of the trace. + + Args: + name: The name of the trace. + inputs: The input data. + current_span: The current trace span. + + Returns: + The updated current trace span. + """ + context = trace.set_span_in_context(current_span) if current_span else None + + with self._tracer.start_as_current_span(name, context=context, end_on_exit=False) as span: + attributes = _format_attributes(inputs, prefix="inputs") + span.set_attributes(attributes) + + return span + + def stop(self, outputs: dict, current_span: Span) -> None: # noqa: PLR6301 + """ + Log output data at the end of the trace. + + Args: + outputs: The output data. + current_span: The current trace span. + """ + attributes = _format_attributes(outputs, prefix="outputs") + current_span.set_attributes(attributes) + current_span.set_status(StatusCode.OK) + current_span.end() + + def error(self, error: Exception, current_span: Span) -> None: # noqa: PLR6301 + """ + Log error during the trace. + + Args: + error: The error that occurred. + current_span: The current trace span. + """ + attributes = _format_attributes(vars(error), prefix="error") + current_span.set_attributes(attributes) + current_span.set_status(StatusCode.ERROR) + current_span.end() + + +def _format_attributes(data: dict, prefix: str | None = None) -> dict[str, AttributeValue]: + """ + Format attributes for OpenTelemetry. + + Args: + data: The data to format. + prefix: The prefix to use for the keys. + + Returns: + The formatted attributes. 
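+
+    Example (illustrative):
+        >>> _format_attributes({"a": 1, "nested": {"b": "x"}}, prefix="inputs")
+        {'inputs.a': 1, 'inputs.nested.b': 'x'}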
+    """
+    flattened = {}
+
+    for key, value in data.items():
+        current_key = f"{prefix}.{key}" if prefix else key
+
+        if isinstance(value, dict):
+            flattened.update(_format_attributes(value, current_key))
+        elif isinstance(value, list | tuple):
+            flattened[current_key] = [
+                item if isinstance(item, str | float | int | bool) else repr(item)
+                for item in value  # type: ignore
+            ]
+        elif isinstance(value, str | float | int | bool):
+            flattened[current_key] = value
+        else:
+            flattened[current_key] = repr(value)
+
+    return flattened
diff --git a/packages/ragbits-core/src/ragbits/core/embeddings/exceptions.py b/packages/ragbits-core/src/ragbits/core/embeddings/exceptions.py
index 4dd99ad1..a8a25224 100644
--- a/packages/ragbits-core/src/ragbits/core/embeddings/exceptions.py
+++ b/packages/ragbits-core/src/ragbits/core/embeddings/exceptions.py
@@ -34,3 +34,12 @@ class EmbeddingResponseError(EmbeddingError):
 
     def __init__(self, message: str = "Data returned by API invalid for expected schema.") -> None:
         super().__init__(message)
+
+
+class EmbeddingEmptyResponseError(EmbeddingError):
+    """
+    Raised when the API returns an empty response.
+    """
+
+    def __init__(self, message: str = "Empty response returned by API.") -> None:
+        super().__init__(message)
diff --git a/packages/ragbits-core/src/ragbits/core/embeddings/litellm.py b/packages/ragbits-core/src/ragbits/core/embeddings/litellm.py
index 1b3fbe5b..ad2d0929 100644
--- a/packages/ragbits-core/src/ragbits/core/embeddings/litellm.py
+++ b/packages/ragbits-core/src/ragbits/core/embeddings/litellm.py
@@ -5,9 +5,11 @@
 except ImportError:
     HAS_LITELLM = False
 
+from ragbits.core.audit import trace
 from ragbits.core.embeddings import Embeddings
 from ragbits.core.embeddings.exceptions import (
     EmbeddingConnectionError,
+    EmbeddingEmptyResponseError,
     EmbeddingResponseError,
     EmbeddingStatusError,
 )
@@ -64,23 +66,40 @@ async def embed_text(self, data: list[str]) -> list[list[float]]:
 
         Raises:
             EmbeddingConnectionError: If there is a connection error with the embedding API.
+            EmbeddingEmptyResponseError: If the embedding API returns an empty response.
             EmbeddingStatusError: If the embedding API returns an error status code.
             EmbeddingResponseError: If the embedding API response is invalid.
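+
+        Example (an illustrative sketch; requires the litellm extra and a configured API key):
+            >>> embedder = LiteLLMEmbeddings(model="text-embedding-3-small")
+            >>> vectors = await embedder.embed_text(["RIP boiled water."])  # inside an async function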
""" - try: - response = await litellm.aembedding( - input=data, - model=self.model, - api_base=self.api_base, - api_key=self.api_key, - api_version=self.api_version, - **self.options, - ) - except litellm.openai.APIConnectionError as exc: - raise EmbeddingConnectionError() from exc - except litellm.openai.APIStatusError as exc: - raise EmbeddingStatusError(exc.message, exc.status_code) from exc - except litellm.openai.APIResponseValidationError as exc: - raise EmbeddingResponseError() from exc + with trace( + data=data, + model=self.model, + api_base=self.api_base, + api_version=self.api_version, + options=self.options, + ) as outputs: + try: + response = await litellm.aembedding( + input=data, + model=self.model, + api_base=self.api_base, + api_key=self.api_key, + api_version=self.api_version, + **self.options, + ) + except litellm.openai.APIConnectionError as exc: + raise EmbeddingConnectionError() from exc + except litellm.openai.APIStatusError as exc: + raise EmbeddingStatusError(exc.message, exc.status_code) from exc + except litellm.openai.APIResponseValidationError as exc: + raise EmbeddingResponseError() from exc - return [embedding["embedding"] for embedding in response.data] + if not response.data: + raise EmbeddingEmptyResponseError() + + outputs.embeddings = [embedding["embedding"] for embedding in response.data] + if response.usage: + outputs.completion_tokens = response.usage.completion_tokens + outputs.prompt_tokens = response.usage.prompt_tokens + outputs.total_tokens = response.usage.total_tokens + + return outputs.embeddings diff --git a/packages/ragbits-core/src/ragbits/core/embeddings/noop.py b/packages/ragbits-core/src/ragbits/core/embeddings/noop.py index 386b8e02..07d7b3d0 100644 --- a/packages/ragbits-core/src/ragbits/core/embeddings/noop.py +++ b/packages/ragbits-core/src/ragbits/core/embeddings/noop.py @@ -1,3 +1,4 @@ +from ragbits.core.audit import traceable from ragbits.core.embeddings.base import Embeddings @@ -10,6 +11,7 @@ class NoopEmbeddings(Embeddings): or as a placeholder when an actual embedding model is not required. """ + @traceable async def embed_text(self, data: list[str]) -> list[list[float]]: # noqa: PLR6301 """ Embeds a list of strings into a list of vectors. diff --git a/packages/ragbits-core/src/ragbits/core/llms/clients/exceptions.py b/packages/ragbits-core/src/ragbits/core/llms/clients/exceptions.py index 0f1106ba..19ea8ee7 100644 --- a/packages/ragbits-core/src/ragbits/core/llms/clients/exceptions.py +++ b/packages/ragbits-core/src/ragbits/core/llms/clients/exceptions.py @@ -34,3 +34,12 @@ class LLMResponseError(LLMError): def __init__(self, message: str = "Data returned by API invalid for expected schema.") -> None: super().__init__(message) + + +class LLMEmptyResponseError(LLMError): + """ + Raised when an API response is empty. 
+ """ + + def __init__(self, message: str = "Empty response returned by API.") -> None: + super().__init__(message) diff --git a/packages/ragbits-core/src/ragbits/core/llms/clients/litellm.py b/packages/ragbits-core/src/ragbits/core/llms/clients/litellm.py index 11b9c3f8..92efe65c 100644 --- a/packages/ragbits-core/src/ragbits/core/llms/clients/litellm.py +++ b/packages/ragbits-core/src/ragbits/core/llms/clients/litellm.py @@ -10,11 +10,12 @@ HAS_LITELLM = False +from ragbits.core.audit import trace from ragbits.core.prompt import ChatFormat from ..types import NOT_GIVEN, NotGiven from .base import LLMClient, LLMOptions -from .exceptions import LLMConnectionError, LLMResponseError, LLMStatusError +from .exceptions import LLMConnectionError, LLMEmptyResponseError, LLMResponseError, LLMStatusError @dataclass @@ -108,21 +109,38 @@ async def call( elif json_mode: response_format = {"type": "json_object"} - try: - response = await litellm.acompletion( - messages=conversation, - model=self.model_name, - base_url=self.base_url, - api_key=self.api_key, - api_version=self.api_version, - response_format=response_format, - **options.dict(), - ) - except litellm.openai.APIConnectionError as exc: - raise LLMConnectionError() from exc - except litellm.openai.APIStatusError as exc: - raise LLMStatusError(exc.message, exc.status_code) from exc - except litellm.openai.APIResponseValidationError as exc: - raise LLMResponseError() from exc - - return response.choices[0].message.content + with trace( + messages=conversation, + model=self.model_name, + base_url=self.base_url, + api_version=self.api_version, + response_format=response_format, + options=options.dict(), + ) as outputs: + try: + response = await litellm.acompletion( + messages=conversation, + model=self.model_name, + base_url=self.base_url, + api_key=self.api_key, + api_version=self.api_version, + response_format=response_format, + **options.dict(), + ) + except litellm.openai.APIConnectionError as exc: + raise LLMConnectionError() from exc + except litellm.openai.APIStatusError as exc: + raise LLMStatusError(exc.message, exc.status_code) from exc + except litellm.openai.APIResponseValidationError as exc: + raise LLMResponseError() from exc + + if not response.choices: # type: ignore + raise LLMEmptyResponseError() + + outputs.response = response.choices[0].message.content # type: ignore + if response.usage: # type: ignore + outputs.completion_tokens = response.usage.completion_tokens # type: ignore + outputs.prompt_tokens = response.usage.prompt_tokens # type: ignore + outputs.total_tokens = response.usage.total_tokens # type: ignore + + return outputs.response # type: ignore diff --git a/packages/ragbits-core/src/ragbits/core/metadata_stores/in_memory.py b/packages/ragbits-core/src/ragbits/core/metadata_stores/in_memory.py index 6f2a9890..0cd1f89f 100644 --- a/packages/ragbits-core/src/ragbits/core/metadata_stores/in_memory.py +++ b/packages/ragbits-core/src/ragbits/core/metadata_stores/in_memory.py @@ -1,3 +1,4 @@ +from ragbits.core.audit import traceable from ragbits.core.metadata_stores.base import MetadataStore from ragbits.core.metadata_stores.exceptions import MetadataNotFoundError @@ -13,6 +14,7 @@ def __init__(self) -> None: """ self._storage: dict[str, dict] = {} + @traceable async def store(self, ids: list[str], metadatas: list[dict]) -> None: """ Store metadatas under ids in metadata store. 
@@ -24,6 +26,7 @@ async def store(self, ids: list[str], metadatas: list[dict]) -> None: for _id, metadata in zip(ids, metadatas, strict=False): self._storage[_id] = metadata + @traceable async def get(self, ids: list[str]) -> list[dict]: """ Returns metadatas associated with a given ids. diff --git a/packages/ragbits-core/src/ragbits/core/vector_stores/chroma.py b/packages/ragbits-core/src/ragbits/core/vector_stores/chroma.py index c8500e7a..f3a981f6 100644 --- a/packages/ragbits-core/src/ragbits/core/vector_stores/chroma.py +++ b/packages/ragbits-core/src/ragbits/core/vector_stores/chroma.py @@ -8,6 +8,7 @@ from chromadb import Collection from chromadb.api import ClientAPI +from ragbits.core.audit import traceable from ragbits.core.metadata_stores import get_metadata_store from ragbits.core.metadata_stores.base import MetadataStore from ragbits.core.utils.config_handling import get_cls_from_config @@ -75,6 +76,7 @@ def from_config(cls, config: dict) -> ChromaVectorStore: metadata_store=get_metadata_store(config.get("metadata_store")), ) + @traceable async def store(self, entries: list[VectorStoreEntry]) -> None: """ Stores entries in the ChromaDB collection. @@ -94,6 +96,7 @@ async def store(self, entries: list[VectorStoreEntry]) -> None: ) self._collection.add(ids=ids, embeddings=embeddings, metadatas=metadatas, documents=documents) # type: ignore + @traceable async def retrieve(self, vector: list[float], options: VectorStoreOptions | None = None) -> list[VectorStoreEntry]: """ Retrieves entries from the ChromaDB collection. @@ -138,6 +141,7 @@ async def retrieve(self, vector: list[float], options: VectorStoreOptions | None if options.max_distance is None or distance <= options.max_distance ] + @traceable async def list( self, where: WhereQuery | None = None, limit: int | None = None, offset: int = 0 ) -> list[VectorStoreEntry]: diff --git a/packages/ragbits-core/src/ragbits/core/vector_stores/in_memory.py b/packages/ragbits-core/src/ragbits/core/vector_stores/in_memory.py index aa0af760..a8391991 100644 --- a/packages/ragbits-core/src/ragbits/core/vector_stores/in_memory.py +++ b/packages/ragbits-core/src/ragbits/core/vector_stores/in_memory.py @@ -2,6 +2,7 @@ import numpy as np +from ragbits.core.audit import traceable from ragbits.core.metadata_stores.base import MetadataStore from ragbits.core.vector_stores.base import VectorStore, VectorStoreEntry, VectorStoreOptions, WhereQuery @@ -26,6 +27,7 @@ def __init__( super().__init__(default_options=default_options, metadata_store=metadata_store) self._storage: list[VectorStoreEntry] = [] + @traceable async def store(self, entries: list[VectorStoreEntry]) -> None: """ Store entries in the vector store. @@ -35,6 +37,7 @@ async def store(self, entries: list[VectorStoreEntry]) -> None: """ self._storage.extend(entries) + @traceable async def retrieve(self, vector: list[float], options: VectorStoreOptions | None = None) -> list[VectorStoreEntry]: """ Retrieve entries from the vector store. 
@@ -57,6 +60,7 @@ async def retrieve(self, vector: list[float], options: VectorStoreOptions | None if options.max_distance is None or distance <= options.max_distance ] + @traceable async def list( self, where: WhereQuery | None = None, limit: int | None = None, offset: int = 0 ) -> list[VectorStoreEntry]: diff --git a/packages/ragbits-core/tests/unit/audit/__init__.py b/packages/ragbits-core/tests/unit/audit/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/packages/ragbits-core/tests/unit/audit/test_otel.py b/packages/ragbits-core/tests/unit/audit/test_otel.py new file mode 100644 index 00000000..8ea2fcd9 --- /dev/null +++ b/packages/ragbits-core/tests/unit/audit/test_otel.py @@ -0,0 +1,42 @@ +from datetime import datetime + +import pytest + +from ragbits.core.audit.otel import _format_attributes + + +@pytest.mark.parametrize( + ("input_data", "prefix", "expected"), + [ + # Empty dict + ({}, None, {}), + ({}, "test", {}), + # Simple types + ( + {"str": "value", "int": 42, "float": 3.14, "bool": True}, + None, + {"str": "value", "int": 42, "float": 3.14, "bool": True}, + ), + # With prefix + ({"str": "value", "int": 42}, "prefix", {"prefix.str": "value", "prefix.int": 42}), + # Nested dict + ({"nested": {"key1": "value1", "key2": 42}}, None, {"nested.key1": "value1", "nested.key2": 42}), + # Lists and tuples + ({"list": [1, 2, 3], "tuple": ("a", "b", "c")}, None, {"list": [1, 2, 3], "tuple": ["a", "b", "c"]}), + # Complex objects in lists + ( + {"objects": [{"a": 1}, datetime(2023, 1, 1)]}, + None, + {"objects": ["{'a': 1}", "datetime.datetime(2023, 1, 1, 0, 0)"]}, + ), + # Mixed nested structure + ( + {"level1": {"level2": {"string": "value", "list": [1, {"x": "y"}]}}}, + "test", + {"test.level1.level2.string": "value", "test.level1.level2.list": [1, "{'x': 'y'}"]}, + ), + ], +) +def test_format_attributes(input_data: dict, prefix: str, expected: dict) -> None: + result = _format_attributes(input_data, prefix) + assert result == expected diff --git a/packages/ragbits-core/tests/unit/audit/test_trace.py b/packages/ragbits-core/tests/unit/audit/test_trace.py new file mode 100644 index 00000000..7ac8350e --- /dev/null +++ b/packages/ragbits-core/tests/unit/audit/test_trace.py @@ -0,0 +1,160 @@ +import asyncio +from collections.abc import Callable +from unittest.mock import MagicMock + +import pytest + +from ragbits.core.audit import _get_function_inputs, set_trace_handlers, trace, traceable +from ragbits.core.audit.base import TraceHandler + + +class MockTraceHandler(TraceHandler): + def start(self, name: str, inputs: dict, current_span: None = None) -> None: + pass + + def stop(self, outputs: dict, current_span: None) -> None: + pass + + def error(self, error: Exception, current_span: None) -> None: + pass + + +@pytest.fixture +def mock_handler() -> MockTraceHandler: + handler = MockTraceHandler() + set_trace_handlers(handler) + return handler + + +def test_trace_context_with_name(mock_handler: MockTraceHandler) -> None: + current_span = MagicMock() + mock_handler.start = MagicMock(return_value=current_span) # type: ignore + mock_handler.stop = MagicMock() # type: ignore + mock_handler.error = MagicMock() # type: ignore + + with trace(name="test", input1="value1") as outputs: + outputs.result = "success" + + mock_handler.start.assert_called_once_with(name="test", inputs={"input1": "value1"}, current_span=None) + mock_handler.stop.assert_called_once_with(outputs={"result": "success"}, current_span=current_span) + mock_handler.error.assert_not_called() + + +def 
test_trace_context_without_name(mock_handler: MockTraceHandler) -> None: + current_span = MagicMock() + mock_handler.start = MagicMock(return_value=current_span) # type: ignore + mock_handler.stop = MagicMock() # type: ignore + mock_handler.error = MagicMock() # type: ignore + + with trace() as outputs: + outputs.result = "success" + + mock_handler.start.assert_called_once_with(name="test_trace_context_without_name", inputs={}, current_span=None) + mock_handler.stop.assert_called_once_with(outputs={"result": "success"}, current_span=current_span) + mock_handler.error.assert_not_called() + + +def test_trace_context_exception(mock_handler: MockTraceHandler) -> None: + current_span = MagicMock() + mock_handler.start = MagicMock(return_value=current_span) # type: ignore + mock_handler.stop = MagicMock() # type: ignore + mock_handler.error = MagicMock() # type: ignore + + with pytest.raises(ValueError), trace(name="test"): + raise (error := ValueError("test error")) + + mock_handler.start.assert_called_once_with(name="test", inputs={}, current_span=None) + mock_handler.error.assert_called_once_with(error=error, current_span=current_span) + mock_handler.stop.assert_not_called() + + +def test_traceable_sync(mock_handler: MockTraceHandler) -> None: + current_span = MagicMock() + mock_handler.start = MagicMock(return_value=current_span) # type: ignore + mock_handler.stop = MagicMock() # type: ignore + mock_handler.error = MagicMock() # type: ignore + + @traceable + def sample_sync_function(a: int, b: str = "default") -> str: + return f"{a}-{b}" + + result = sample_sync_function(1, b="test") + assert result == "1-test" + + mock_handler.start.assert_called_once_with( + name="test_traceable_sync..sample_sync_function", + inputs={"a": 1, "b": "test"}, + current_span=None, + ) + mock_handler.stop.assert_called_once_with(outputs={"returned": "1-test"}, current_span=current_span) + + +async def test_traceable_async(mock_handler: MockTraceHandler) -> None: + current_span = MagicMock() + mock_handler.start = MagicMock(return_value=current_span) # type: ignore + mock_handler.stop = MagicMock() # type: ignore + mock_handler.error = MagicMock() # type: ignore + + @traceable + async def sample_async_function(x: int) -> int: + await asyncio.sleep(0.01) + return x * 2 + + result = await sample_async_function(5) + assert result == 10 + + mock_handler.start.assert_called_once_with( + name="test_traceable_async..sample_async_function", + inputs={"x": 5}, + current_span=None, + ) + mock_handler.stop.assert_called_once_with(outputs={"returned": 10}, current_span=current_span) + + +def test_traceable_no_return(mock_handler: MockTraceHandler) -> None: + current_span = MagicMock() + mock_handler.start = MagicMock(return_value=current_span) # type: ignore + mock_handler.stop = MagicMock() # type: ignore + mock_handler.error = MagicMock() # type: ignore + + @traceable + def void_function(x: int) -> None: + pass + + void_function(1) + mock_handler.start.assert_called_once_with( + name="test_traceable_no_return..void_function", + inputs={"x": 1}, + current_span=None, + ) + mock_handler.stop.assert_called_once_with(outputs={}, current_span=current_span) + + +@pytest.mark.parametrize( + ("func", "args", "kwargs", "expected"), + [ + # Test no args and no kwargs + (lambda: None, (), {}, {}), + # Test only args + (lambda a, b: None, (1, 2), {}, {"a": 1, "b": 2}), + # Test only kwargs + (lambda a, b: None, (), {"a": 1, "b": 2}, {"a": 1, "b": 2}), + # Test args and kwargs + (lambda a, b, c: None, (1,), {"b": 2, "c": 3}, {"a": 1, 
"b": 2, "c": 3}), + # Test with defaults + (lambda a, b=2: None, (1,), {}, {"a": 1, "b": 2}), + # Test extra kwargs + (lambda a: None, (), {"a": 1, "b": 2}, {"a": 1, "b": 2}), + # Test empty signature + (lambda: None, (1, 2, 3), {"a": 1, "b": 2}, {"a": 1, "b": 2}), + # Test *args + (lambda *args: None, (1, 2, 3), {}, {}), + # Test **kwargs + (lambda **kwargs: None, (), {"a": 1, "b": 2}, {"a": 1, "b": 2}), + # Test args, kwargs, and defaults + (lambda a, b=2, c=3: None, (1,), {"c": 4}, {"a": 1, "b": 2, "c": 4}), + ], +) +def test_get_function_inputs(func: Callable, args: tuple, kwargs: dict, expected: dict) -> None: + result = _get_function_inputs(func, args, kwargs) + assert result == expected diff --git a/packages/ragbits-document-search/src/ragbits/document_search/_main.py b/packages/ragbits-document-search/src/ragbits/document_search/_main.py index d8231677..21deb945 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/_main.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/_main.py @@ -4,6 +4,7 @@ from pydantic import BaseModel, Field +from ragbits.core.audit import traceable from ragbits.core.embeddings import Embeddings, get_embeddings from ragbits.core.vector_stores import VectorStore, get_vector_store from ragbits.core.vector_stores.base import VectorStoreOptions @@ -16,7 +17,7 @@ from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser from ragbits.document_search.retrieval.rephrasers.noop import NoopQueryRephraser from ragbits.document_search.retrieval.rerankers import get_reranker -from ragbits.document_search.retrieval.rerankers.base import Reranker +from ragbits.document_search.retrieval.rerankers.base import Reranker, RerankerOptions from ragbits.document_search.retrieval.rerankers.noop import NoopReranker @@ -84,7 +85,8 @@ def from_config(cls, config: dict) -> "DocumentSearch": return cls(embedder, vector_store, query_rephraser, reranker, document_processor_router) - async def search(self, query: str, config: SearchConfig | None = None) -> list[Element]: + @traceable + async def search(self, query: str, config: SearchConfig | None = None) -> Sequence[Element]: """ Search for the most relevant chunks for a query. 
@@ -106,7 +108,11 @@ async def search(self, query: str, config: SearchConfig | None = None) -> list[E ) elements.extend([Element.from_vector_db_entry(entry) for entry in entries]) - return self.reranker.rerank(elements) + return await self.reranker.rerank( + elements=elements, + query=query, + options=RerankerOptions(**config.reranker_kwargs), + ) async def _process_document( self, @@ -137,6 +143,7 @@ async def _process_document( document_processor = self.document_processor_router.get_provider(document_meta) return await document_processor.process(document_meta) + @traceable async def ingest( self, documents: Sequence[DocumentMeta | Document | Source], diff --git a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured/default.py b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured/default.py index c786df94..9a7d37ff 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured/default.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured/default.py @@ -7,6 +7,7 @@ from unstructured.staging.base import elements_from_dicts from unstructured_client import UnstructuredClient +from ragbits.core.audit import traceable from ragbits.document_search.documents.document import DocumentMeta, DocumentType from ragbits.document_search.documents.element import Element from ragbits.document_search.ingestion.providers.base import BaseProvider @@ -102,6 +103,7 @@ def client(self) -> UnstructuredClient: self._client = UnstructuredClient(api_key_auth=api_key, server_url=api_server) return self._client + @traceable async def process(self, document_meta: DocumentMeta) -> list[Element]: """ Process the document using the Unstructured API. diff --git a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured/images.py b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured/images.py index 4c2280a5..674a1584 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured/images.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured/images.py @@ -21,8 +21,8 @@ to_text_element, ) -DEFAULT_LLM_IMAGE_SUMMARIZATION_MODEL = "gpt-4o-mini" DEFAULT_IMAGE_QUESTION_PROMPT = "Describe the content of the image." 
+DEFAULT_LLM_IMAGE_DESCRIPTION_MODEL = "gpt-4o-mini" class _ImagePrompt(Prompt): @@ -34,9 +34,6 @@ class _ImagePromptInput(BaseModel): images: list[bytes] -DEFAULT_LLM_IMAGE_DESCRIPTION_MODEL = "gpt-4o-mini" - - class UnstructuredImageProvider(UnstructuredDefaultProvider): """ A specialized provider that handles pngs and jpgs using the Unstructured diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py index 8398b2a5..a7627dd0 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py @@ -1,5 +1,6 @@ from typing import Any +from ragbits.core.audit import traceable from ragbits.core.llms import get_llm from ragbits.core.llms.base import LLM from ragbits.core.prompt import Prompt @@ -27,6 +28,7 @@ def __init__(self, llm: LLM, prompt: type[Prompt[QueryRephraserInput, Any]] | No self._llm = llm self._prompt = prompt or QueryRephraserPrompt + @traceable async def rephrase(self, query: str) -> list[str]: """ Rephrase a given query using the LLM. diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/noop.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/noop.py index 2201e6da..cec7ed8e 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/noop.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/noop.py @@ -1,3 +1,4 @@ +from ragbits.core.audit import traceable from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser @@ -6,6 +7,7 @@ class NoopQueryRephraser(QueryRephraser): A no-op query paraphraser that does not change the query. """ + @traceable async def rephrase(self, query: str) -> list[str]: # noqa: PLR6301 """ Mock implementation which outputs the same query as in input. diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/__init__.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/__init__.py index 95a4cfab..a9026279 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/__init__.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/__init__.py @@ -1,30 +1,30 @@ import sys from ragbits.core.utils.config_handling import get_cls_from_config - -from .base import Reranker -from .noop import NoopReranker +from ragbits.document_search.retrieval.rerankers.base import Reranker +from ragbits.document_search.retrieval.rerankers.noop import NoopReranker __all__ = ["NoopReranker", "Reranker"] -module = sys.modules[__name__] - -def get_reranker(reranker_config: dict | None) -> Reranker: +def get_reranker(config: dict | None = None) -> Reranker: """ Initializes and returns a Reranker object based on the provided configuration. Args: - reranker_config: A dictionary containing configuration details for the Reranker. + config: A dictionary containing configuration details for the Reranker. Returns: An instance of the specified Reranker class, initialized with the provided config (if any) or default arguments. + + Raises: + KeyError: If the provided configuration does not contain a valid "type" key. + InvalidConfigurationError: If the provided configuration is invalid. 
+ NotImplementedError: If the specified Reranker class cannot be created from the provided configuration. """ - if reranker_config is None: + if config is None: return NoopReranker() - reranker_cls = get_cls_from_config(reranker_config["type"], module) - config = reranker_config.get("config", {}) - - return reranker_cls(**config) + reranker_cls = get_cls_from_config(config["type"], sys.modules[__name__]) + return reranker_cls.from_config(config.get("config", {})) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/base.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/base.py index dec88647..11b0c86e 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/base.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/base.py @@ -1,22 +1,65 @@ -import abc +from abc import ABC, abstractmethod +from collections.abc import Sequence + +from pydantic import BaseModel from ragbits.document_search.documents.element import Element -class Reranker(abc.ABC): +class RerankerOptions(BaseModel): + """ + Options for the reranker. + """ + + top_n: int | None = None + max_chunks_per_doc: int | None = None + + +class Reranker(ABC): """ - Reranks chunks retrieved from vector store. + Reranks elements retrieved from vector store. """ - @staticmethod - @abc.abstractmethod - def rerank(chunks: list[Element]) -> list[Element]: + def __init__(self, default_options: RerankerOptions | None = None) -> None: + """ + Constructs a new Reranker instance. + + Args: + default_options: The default options for reranking. + """ + self._default_options = default_options or RerankerOptions() + + @classmethod + def from_config(cls, config: dict) -> "Reranker": + """ + Creates and returns an instance of the Reranker class from the given configuration. + + Args: + config: A dictionary containing the configuration for initializing the Reranker instance. + + Returns: + An initialized instance of the Reranker class. + + Raises: + NotImplementedError: If the class cannot be created from the provided configuration. + """ + raise NotImplementedError(f"Cannot create class {cls.__name__} from config.") + + @abstractmethod + async def rerank( + self, + elements: Sequence[Element], + query: str, + options: RerankerOptions | None = None, + ) -> Sequence[Element]: """ - Rerank chunks. + Rerank elements. Args: - chunks: The chunks to rerank. + elements: The elements to rerank. + query: The query to rerank the elements against. + options: The options for reranking. Returns: - The reranked chunks. + The reranked elements. """ diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/litellm.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/litellm.py new file mode 100644 index 00000000..c83906ce --- /dev/null +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/litellm.py @@ -0,0 +1,71 @@ +from collections.abc import Sequence + +import litellm + +from ragbits.core.audit import traceable +from ragbits.document_search.documents.element import Element +from ragbits.document_search.retrieval.rerankers.base import Reranker, RerankerOptions + + +class LiteLLMReranker(Reranker): + """ + A [LiteLLM](https://docs.litellm.ai/docs/rerank) reranker for providers such as Cohere, Together AI, Azure AI. 
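+
+    Example (an illustrative sketch; assumes a configured Cohere API key):
+        reranker = LiteLLMReranker("cohere/rerank-english-v3.0")
+        reranked = await reranker.rerank(elements, query="...")  # inside an async function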
+ """ + + def __init__(self, model: str, default_options: RerankerOptions | None = None) -> None: + """ + Constructs a new LiteLLMReranker instance. + + Args: + model: The reranker model to use. + default_options: The default options for reranking. + """ + super().__init__(default_options) + self.model = model + + @classmethod + def from_config(cls, config: dict) -> "LiteLLMReranker": + """ + Creates and returns an instance of the LiteLLMReranker class from the given configuration. + + Args: + config: A dictionary containing the configuration for initializing the LiteLLMReranker instance. + + Returns: + An initialized instance of the LiteLLMReranker class. + """ + return cls( + model=config["model"], + default_options=RerankerOptions(**config.get("default_options", {})), + ) + + @traceable + async def rerank( + self, + elements: Sequence[Element], + query: str, + options: RerankerOptions | None = None, + ) -> Sequence[Element]: + """ + Rerank elements with LiteLLM API. + + Args: + elements: The elements to rerank. + query: The query to rerank the elements against. + options: The options for reranking. + + Returns: + The reranked elements. + """ + options = self._default_options if options is None else options + documents = [element.get_text_representation() for element in elements] + + response = await litellm.arerank( + model=self.model, + query=query, + documents=documents, # type: ignore + top_n=options.top_n, + max_chunks_per_doc=options.max_chunks_per_doc, + ) + + return [elements[result["index"]] for result in response.results] # type: ignore diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/noop.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/noop.py index e417a561..fe4686de 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/noop.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/noop.py @@ -1,21 +1,44 @@ +from collections.abc import Sequence + +from ragbits.core.audit import traceable from ragbits.document_search.documents.element import Element -from ragbits.document_search.retrieval.rerankers.base import Reranker +from ragbits.document_search.retrieval.rerankers.base import Reranker, RerankerOptions class NoopReranker(Reranker): """ - A no-op reranker that does not change the order of the chunks. + A no-op reranker that does not change the order of the elements. """ - @staticmethod - def rerank(chunks: list[Element]) -> list[Element]: + @classmethod + def from_config(cls, config: dict) -> "NoopReranker": + """ + Creates and returns an instance of the NoopReranker class from the given configuration. + + Args: + config: A dictionary containing the configuration for initializing the NoopReranker instance. + + Returns: + An initialized instance of the NoopReranker class. + """ + return cls(default_options=RerankerOptions(**config.get("default_options", {}))) + + @traceable + async def rerank( # noqa: PLR6301 + self, + elements: Sequence[Element], + query: str, + options: RerankerOptions | None = None, + ) -> Sequence[Element]: """ - No reranking, returning the same chunks as in input. + No reranking, returning the elements in the same order. Args: - chunks: The chunks to rerank. + elements: The elements to rerank. + query: The query to rerank the elements against. + options: The options for reranking. Returns: - The reranked chunks. + The reranked elements. 
""" - return chunks + return elements diff --git a/packages/ragbits-document-search/tests/integration/test_rerankers.py b/packages/ragbits-document-search/tests/integration/test_rerankers.py new file mode 100644 index 00000000..a457fa3d --- /dev/null +++ b/packages/ragbits-document-search/tests/integration/test_rerankers.py @@ -0,0 +1,38 @@ +import pytest + +from ragbits.document_search.documents.document import DocumentMeta +from ragbits.document_search.documents.element import TextElement +from ragbits.document_search.retrieval.rerankers.base import RerankerOptions +from ragbits.document_search.retrieval.rerankers.litellm import LiteLLMReranker + +from ..helpers import env_vars_not_set + +COHERE_API_KEY_ENV = "COHERE_API_KEY" # noqa: S105 + + +@pytest.mark.skipif( + env_vars_not_set([COHERE_API_KEY_ENV]), + reason="Cohere API KEY environment variables not set", +) +async def test_litellm_cohere_reranker_rerank() -> None: + options = RerankerOptions(top_n=2, max_chunks_per_doc=None) + reranker = LiteLLMReranker( + model="cohere/rerank-english-v3.0", + default_options=options, + ) + elements = [ + TextElement( + content="Element 1", document_meta=DocumentMeta.create_text_document_from_literal("Mock document 1") + ), + TextElement( + content="Element 2", document_meta=DocumentMeta.create_text_document_from_literal("Mock document 1") + ), + TextElement( + content="Element 3", document_meta=DocumentMeta.create_text_document_from_literal("Mock document 1") + ), + ] + query = "Test query" + + results = await reranker.rerank(elements, query) + + assert len(results) == 2 diff --git a/packages/ragbits-document-search/tests/unit/test_rerankers.py b/packages/ragbits-document-search/tests/unit/test_rerankers.py new file mode 100644 index 00000000..3a557f3d --- /dev/null +++ b/packages/ragbits-document-search/tests/unit/test_rerankers.py @@ -0,0 +1,82 @@ +from argparse import Namespace +from collections.abc import Sequence +from unittest.mock import patch + +import pytest + +from ragbits.document_search.documents.document import DocumentMeta +from ragbits.document_search.documents.element import Element, TextElement +from ragbits.document_search.retrieval.rerankers.base import Reranker, RerankerOptions +from ragbits.document_search.retrieval.rerankers.litellm import LiteLLMReranker + + +class CustomReranker(Reranker): + """ + Custom implementation of Reranker for testing. 
+ """ + + async def rerank( # noqa: PLR6301 + self, elements: Sequence[Element], query: str, options: RerankerOptions | None = None + ) -> Sequence[Element]: + return elements + + +def test_custom_reranker_from_config() -> None: + with pytest.raises(NotImplementedError) as exc_info: + CustomReranker.from_config({}) + + assert "Cannot create class CustomReranker from config" in str(exc_info.value) + + +def test_litellm_reranker_from_config() -> None: + reranker = LiteLLMReranker.from_config( + { + "model": "test-provder/test-model", + "default_options": { + "top_n": 2, + "max_chunks_per_doc": None, + }, + } + ) + + assert reranker.model == "test-provder/test-model" + assert reranker._default_options == RerankerOptions(top_n=2, max_chunks_per_doc=None) + + +async def test_litellm_reranker_rerank() -> None: + options = RerankerOptions(top_n=2, max_chunks_per_doc=None) + reranker = LiteLLMReranker( + model="test-provder/test-model", + default_options=options, + ) + documents = [ + DocumentMeta.create_text_document_from_literal("Mock document Element 1"), + DocumentMeta.create_text_document_from_literal("Mock document Element 2"), + DocumentMeta.create_text_document_from_literal("Mock document Element 3"), + ] + elements = [ + TextElement(content="Element 1", document_meta=documents[0]), + TextElement(content="Element 2", document_meta=documents[1]), + TextElement(content="Element 3", document_meta=documents[2]), + ] + reranked_elements = [ + TextElement(content="Element 2", document_meta=documents[1]), + TextElement(content="Element 3", document_meta=documents[2]), + TextElement(content="Element 1", document_meta=documents[0]), + ] + reranker_output = Namespace(results=[{"index": 1}, {"index": 2}, {"index": 0}]) + query = "Test query" + + with patch( + "ragbits.document_search.retrieval.rerankers.litellm.litellm.arerank", return_value=reranker_output + ) as mock_arerank: + results = await reranker.rerank(elements, query) + + assert results == reranked_elements + mock_arerank.assert_called_once_with( + model="test-provder/test-model", + query=query, + documents=["Element 1", "Element 2", "Element 3"], + top_n=2, + max_chunks_per_doc=None, + ) diff --git a/packages/ragbits-guardrails/CHANGELOG.md b/packages/ragbits-guardrails/CHANGELOG.md new file mode 100644 index 00000000..e69de29b diff --git a/packages/ragbits-guardrails/README.md b/packages/ragbits-guardrails/README.md new file mode 100644 index 00000000..4dcb2479 --- /dev/null +++ b/packages/ragbits-guardrails/README.md @@ -0,0 +1 @@ +# Ragbits Guardrails diff --git a/packages/ragbits-guardrails/py.typed b/packages/ragbits-guardrails/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/packages/ragbits-guardrails/pyproject.toml b/packages/ragbits-guardrails/pyproject.toml new file mode 100644 index 00000000..d398e8b0 --- /dev/null +++ b/packages/ragbits-guardrails/pyproject.toml @@ -0,0 +1,61 @@ +[project] +name = "ragbits-guardrails" +version = "0.2.0" +description = "Guardrails module for Ragbits components" +readme = "README.md" +requires-python = ">=3.10" +license = "MIT" +authors = [ + { name = "deepsense.ai", email = "ragbits@deepsense.ai"} +] +keywords = [ + "Retrieval Augmented Generation", + "RAG", + "Large Language Models", + "LLMs", + "Generative AI", + "GenAI", + "Evaluation" +] +classifiers = [ + "Development Status :: 4 - Beta", + "Environment :: Console", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: MIT License", + "Natural Language :: English", + "Operating System :: OS 
Independent", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development :: Libraries :: Python Modules", +] +dependencies = ["ragbits-core==0.2.0"] + +[project.optional-dependencies] +openai = [ + "openai~=1.51.0", +] + +[tool.uv] +dev-dependencies = [ + "pre-commit~=3.8.0", + "pytest~=8.3.3", + "pytest-cov~=5.0.0", + "pytest-asyncio~=0.24.0", + "pip-licenses>=4.0.0,<5.0.0" +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.metadata] +allow-direct-references = true + +[tool.hatch.build.targets.wheel] +packages = ["src/ragbits"] + +[tool.pytest.ini_options] +asyncio_mode = "auto" diff --git a/packages/ragbits-guardrails/src/ragbits/guardrails/__init__.py b/packages/ragbits-guardrails/src/ragbits/guardrails/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/packages/ragbits-guardrails/src/ragbits/guardrails/base.py b/packages/ragbits-guardrails/src/ragbits/guardrails/base.py new file mode 100644 index 00000000..ccd12d25 --- /dev/null +++ b/packages/ragbits-guardrails/src/ragbits/guardrails/base.py @@ -0,0 +1,54 @@ +from abc import ABC, abstractmethod + +from pydantic import BaseModel + +from ragbits.core.prompt import Prompt + + +class GuardrailVerificationResult(BaseModel): + """ + Class representing result of guardrail verification + """ + + guardrail_name: str + succeeded: bool + fail_reason: str | None + + +class Guardrail(ABC): + """ + Abstract class representing guardrail + """ + + @abstractmethod + async def verify(self, input_to_verify: Prompt | str) -> GuardrailVerificationResult: + """ + Verifies whether provided input meets certain criteria + + Args: + input_to_verify: prompt or output of the model to check + + Returns: + verification result + """ + + +class GuardrailManager: + """ + Class responsible for running guardrails + """ + + def __init__(self, guardrails: list[Guardrail]): + self._guardrails = guardrails + + async def verify(self, input_to_verify: Prompt | str) -> list[GuardrailVerificationResult]: + """ + Verifies whether provided input meets certain criteria + + Args: + input_to_verify: prompt or output of the model to check + + Returns: + list of verification result + """ + return [await guardrail.verify(input_to_verify) for guardrail in self._guardrails] diff --git a/packages/ragbits-guardrails/src/ragbits/guardrails/openai_moderation.py b/packages/ragbits-guardrails/src/ragbits/guardrails/openai_moderation.py new file mode 100644 index 00000000..bbc36a2a --- /dev/null +++ b/packages/ragbits-guardrails/src/ragbits/guardrails/openai_moderation.py @@ -0,0 +1,51 @@ +import base64 + +from openai import AsyncOpenAI + +from ragbits.core.prompt import Prompt +from ragbits.guardrails.base import Guardrail, GuardrailVerificationResult + + +class OpenAIModerationGuardrail(Guardrail): + """ + Guardrail based on OpenAI moderation + """ + + def __init__(self, moderation_model: str = "omni-moderation-latest"): + self._openai_client = AsyncOpenAI() + self._moderation_model = moderation_model + + async def verify(self, input_to_verify: Prompt | str) -> GuardrailVerificationResult: + """ + Verifies whether provided input meets certain criteria + + Args: + input_to_verify: prompt or output of the model to check + + Returns: + verification result + """ + if isinstance(input_to_verify, Prompt): + inputs = 
[{"type": "text", "text": input_to_verify.rendered_user_prompt}] + if input_to_verify.rendered_system_prompt is not None: + inputs.append({"type": "text", "text": input_to_verify.rendered_system_prompt}) + if images := input_to_verify.images: + inputs.extend( + [ + { + "type": "image_url", + "image_url": {"url": f"data:image/jpeg;base64,{base64.b64encode(im).decode('utf-8')}"}, # type: ignore + } + for im in images + ] + ) + else: + inputs = [{"type": "text", "text": input_to_verify}] + response = await self._openai_client.moderations.create(model=self._moderation_model, input=inputs) # type: ignore + + fail_reasons = [result for result in response.results if result.flagged] + return GuardrailVerificationResult( + guardrail_name=self.__class__.__name__, + succeeded=len(fail_reasons) == 0, + fail_reason=None if len(fail_reasons) == 0 else str(fail_reasons), + ) diff --git a/packages/ragbits-guardrails/tests/unit/test_openai_moderation.py b/packages/ragbits-guardrails/tests/unit/test_openai_moderation.py new file mode 100644 index 00000000..4831946b --- /dev/null +++ b/packages/ragbits-guardrails/tests/unit/test_openai_moderation.py @@ -0,0 +1,53 @@ +import os +from unittest.mock import AsyncMock, patch + +from pydantic import BaseModel + +from ragbits.guardrails.base import GuardrailManager, GuardrailVerificationResult +from ragbits.guardrails.openai_moderation import OpenAIModerationGuardrail + + +class MockedModeration(BaseModel): + flagged: bool + fail_reason: str | None + + +class MockedModerationCreateResponse(BaseModel): + results: list[MockedModeration] + + +async def test_manager(): + guardrail_mock = AsyncMock() + guardrail_mock.verify.return_value = GuardrailVerificationResult( + guardrail_name=".", succeeded=True, fail_reason=None + ) + manager = GuardrailManager([guardrail_mock]) + results = await manager.verify("test") + assert guardrail_mock.verify.call_count == 1 + assert len(results) == 1 + + +@patch.dict(os.environ, {"OPENAI_API_KEY": "."}, clear=True) +async def test_not_flagged(): + guardrail = OpenAIModerationGuardrail() + guardrail._openai_client = AsyncMock() + guardrail._openai_client.moderations.create.return_value = MockedModerationCreateResponse( + results=[MockedModeration(flagged=False, fail_reason=None)] + ) + results = await guardrail.verify("Test") + assert results.succeeded is True + assert results.fail_reason is None + assert results.guardrail_name == "OpenAIModerationGuardrail" + + +@patch.dict(os.environ, {"OPENAI_API_KEY": "."}, clear=True) +async def test_flagged(): + guardrail = OpenAIModerationGuardrail() + guardrail._openai_client = AsyncMock() + guardrail._openai_client.moderations.create.return_value = MockedModerationCreateResponse( + results=[MockedModeration(flagged=True, fail_reason="Harmful content")] + ) + results = await guardrail.verify("Test") + assert results.succeeded is False + assert results.fail_reason == "[MockedModeration(flagged=True, fail_reason='Harmful content')]" + assert results.guardrail_name == "OpenAIModerationGuardrail" diff --git a/pyproject.toml b/pyproject.toml index 3e2499ba..b56cc198 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,9 +6,10 @@ readme = "README.md" requires-python = ">=3.10" dependencies = [ "ragbits-cli", - "ragbits-core[litellm,local,lab,chroma]", - "ragbits-document-search[gcs, huggingface]", + "ragbits-core[chroma,lab,litellm,local,otel]", + "ragbits-document-search[gcs,huggingface]", "ragbits-evaluate[relari]", + "ragbits-guardrails[openai]", ] [tool.uv] @@ -35,6 +36,7 @@ ragbits-cli = { 
diff --git a/pyproject.toml b/pyproject.toml
index 3e2499ba..b56cc198 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -6,9 +6,10 @@ readme = "README.md"
 requires-python = ">=3.10"
 dependencies = [
     "ragbits-cli",
-    "ragbits-core[litellm,local,lab,chroma]",
-    "ragbits-document-search[gcs, huggingface]",
+    "ragbits-core[chroma,lab,litellm,local,otel]",
+    "ragbits-document-search[gcs,huggingface]",
     "ragbits-evaluate[relari]",
+    "ragbits-guardrails[openai]",
 ]
 
 [tool.uv]
@@ -35,6 +36,7 @@ ragbits-cli = { workspace = true }
 ragbits-core = { workspace = true }
 ragbits-document-search = { workspace = true }
 ragbits-evaluate = {workspace = true}
+ragbits-guardrails = {workspace = true}
 
 [tool.uv.workspace]
 members = [
@@ -42,6 +44,7 @@ members = [
     "packages/ragbits-core",
     "packages/ragbits-document-search",
     "packages/ragbits-evaluate",
+    "packages/ragbits-guardrails",
 ]
 
 [tool.pytest]
@@ -88,6 +91,7 @@ mypy_path = [
     "packages/ragbits-core/src",
     "packages/ragbits-document-search/src",
     "packages/ragbits-evaluate/src",
+    "packages/ragbits-guardrails/src",
 ]
 exclude = ["scripts"]
 
@@ -146,7 +150,6 @@ ignore = [
     "PLR0913",
 ]
 
-
 [tool.ruff.lint.pydocstyle]
 convention = "google"
 
@@ -173,7 +176,6 @@ convention = "google"
 docstring-code-format = true
 docstring-code-line-length = 120
 
-
 [tool.ruff.lint.isort]
 known-first-party = ["ragbits"]
 known-third-party = [
diff --git a/uv.lock b/uv.lock
index 3472f9e0..3c70fd68 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1,18 +1,11 @@
 version = 1
 requires-python = ">=3.10"
 resolution-markers = [
-    "python_full_version < '3.11' and platform_system == 'Darwin'",
-    "python_full_version < '3.11' and platform_machine == 'aarch64' and platform_system == 'Linux'",
-    "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version < '3.11' and platform_system != 'Darwin' and platform_system != 'Linux')",
-    "python_full_version == '3.11.*' and platform_system == 'Darwin'",
-    "python_full_version == '3.11.*' and platform_machine == 'aarch64' and platform_system == 'Linux'",
-    "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version == '3.11.*' and platform_system != 'Darwin' and platform_system != 'Linux')",
-    "python_full_version == '3.12.*' and platform_system == 'Darwin'",
-    "python_full_version == '3.12.*' and platform_machine == 'aarch64' and platform_system == 'Linux'",
-    "(python_full_version == '3.12.*' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version == '3.12.*' and platform_system != 'Darwin' and platform_system != 'Linux')",
-    "python_full_version >= '3.13' and platform_system == 'Darwin'",
-    "python_full_version >= '3.13' and platform_machine == 'aarch64' and platform_system == 'Linux'",
-    "(python_full_version >= '3.13' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version >= '3.13' and platform_system != 'Darwin' and platform_system != 'Linux')",
+    "python_full_version < '3.11'",
+    "python_full_version == '3.11.*'",
+    "python_full_version >= '3.12' and python_full_version < '3.12.4'",
+    "python_full_version >= '3.12.4' and python_full_version < '3.13'",
+    "python_full_version >= '3.13'",
 ]
 
 [manifest]
@@ -21,6 +14,7 @@ members = [
     "ragbits-core",
     "ragbits-document-search",
    "ragbits-evaluate",
+    "ragbits-guardrails",
     "ragbits-workspace",
 ]
 
@@ -3602,6 +3596,9 @@ local = [
     { name = "torch" },
     { name = "transformers" },
 ]
+otel = [
+    { name = "opentelemetry-api" },
+]
 promptfoo = [
     { name = "pyyaml" },
 ]
@@ -3622,6 +3619,7 @@ requires-dist = [
     { name = "jinja2", specifier = ">=3.1.4" },
     { name = "litellm", marker = "extra == 'litellm'", specifier = "~=1.46.0" },
     { name = "numpy", marker = "extra == 'local'", specifier = "~=1.26.0" },
+    { name = "opentelemetry-api", marker = "extra == 'otel'", specifier = "~=1.27.0" },
     { name = "pydantic", specifier = ">=2.9.1" },
     { name = "pyyaml", marker = "extra == 'promptfoo'", specifier = "~=6.0.2" },
     { name = "tomli", specifier = "~=2.0.2" },
"~=2.0.2" }, @@ -3729,15 +3727,53 @@ dev = [ { name = "pytest-cov", specifier = "~=5.0.0" }, ] +[[package]] +name = "ragbits-guardrails" +version = "0.2.0" +source = { editable = "packages/ragbits-guardrails" } +dependencies = [ + { name = "ragbits-core" }, +] + +[package.optional-dependencies] +openai = [ + { name = "openai" }, +] + +[package.dev-dependencies] +dev = [ + { name = "pip-licenses" }, + { name = "pre-commit" }, + { name = "pytest" }, + { name = "pytest-asyncio" }, + { name = "pytest-cov" }, +] + +[package.metadata] +requires-dist = [ + { name = "openai", marker = "extra == 'openai'", specifier = "~=1.51.0" }, + { name = "ragbits-core", editable = "packages/ragbits-core" }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "pip-licenses", specifier = ">=4.0.0,<5.0.0" }, + { name = "pre-commit", specifier = "~=3.8.0" }, + { name = "pytest", specifier = "~=8.3.3" }, + { name = "pytest-asyncio", specifier = "~=0.24.0" }, + { name = "pytest-cov", specifier = "~=5.0.0" }, +] + [[package]] name = "ragbits-workspace" version = "0.1.0" source = { virtual = "." } dependencies = [ { name = "ragbits-cli" }, - { name = "ragbits-core", extra = ["chroma", "lab", "litellm", "local"] }, + { name = "ragbits-core", extra = ["chroma", "lab", "litellm", "local", "otel"] }, { name = "ragbits-document-search", extra = ["gcs", "huggingface"] }, { name = "ragbits-evaluate", extra = ["relari"] }, + { name = "ragbits-guardrails", extra = ["openai"] }, ] [package.dev-dependencies] @@ -3762,9 +3798,10 @@ dev = [ [package.metadata] requires-dist = [ { name = "ragbits-cli", editable = "packages/ragbits-cli" }, - { name = "ragbits-core", extras = ["litellm", "local", "lab", "chroma"], editable = "packages/ragbits-core" }, + { name = "ragbits-core", extras = ["chroma", "lab", "litellm", "local", "otel"], editable = "packages/ragbits-core" }, { name = "ragbits-document-search", extras = ["gcs", "huggingface"], editable = "packages/ragbits-document-search" }, { name = "ragbits-evaluate", extras = ["relari"], editable = "packages/ragbits-evaluate" }, + { name = "ragbits-guardrails", extras = ["openai"], editable = "packages/ragbits-guardrails" }, ] [package.metadata.requires-dev]