Skip to content

Commit

Permalink
enabler: guardrails component design (#169)
Browse files Browse the repository at this point in the history
  • Loading branch information
konrad-czarnota-ds authored Nov 6, 2024
1 parent 2ea2875 commit 2cf123a
Show file tree
Hide file tree
Showing 12 changed files with 294 additions and 4 deletions.
29 changes: 29 additions & 0 deletions examples/guardrails/openai_moderation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "ragbits-guardrails",
#     "openai",
# ]
# ///
import asyncio
from argparse import ArgumentParser

from ragbits.guardrails.base import GuardrailManager
from ragbits.guardrails.openai_moderation import OpenAIModerationGuardrail


async def guardrail_run(message: str) -> None:
    """
    Run the OpenAI moderation guardrail against a single message and print the outcome.

    Requires the OPENAI_API_KEY environment variable to be set.

    Args:
        message: the text to be checked by the moderation guardrail.
    """
    guardrails = [OpenAIModerationGuardrail()]
    verification_results = await GuardrailManager(guardrails).verify(message)
    print(verification_results)


if __name__ == "__main__":
    args = ArgumentParser()
    args.add_argument("message", nargs="+", type=str, help="Message to validate")
    parsed_args = args.parse_args()

    # nargs="+" gives a list of whitespace-split words; join with a space so
    # "hello world" is moderated as "hello world", not "helloworld" as the
    # previous "".join produced.
    asyncio.run(guardrail_run(" ".join(parsed_args.message)))
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
to_text_element,
)

DEFAULT_LLM_IMAGE_SUMMARIZATION_MODEL = "gpt-4o-mini"
DEFAULT_IMAGE_QUESTION_PROMPT = "Describe the content of the image."
DEFAULT_LLM_IMAGE_DESCRIPTION_MODEL = "gpt-4o-mini"


class _ImagePrompt(Prompt):
Expand All @@ -34,9 +34,6 @@ class _ImagePromptInput(BaseModel):
images: list[bytes]


DEFAULT_LLM_IMAGE_DESCRIPTION_MODEL = "gpt-4o-mini"


class UnstructuredImageProvider(UnstructuredDefaultProvider):
"""
A specialized provider that handles pngs and jpgs using the Unstructured
Expand Down
Empty file.
1 change: 1 addition & 0 deletions packages/ragbits-guardrails/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Ragbits Guardrails
Empty file.
61 changes: 61 additions & 0 deletions packages/ragbits-guardrails/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
[project]
name = "ragbits-guardrails"
version = "0.2.0"
description = "Guardrails module for Ragbits components"
readme = "README.md"
requires-python = ">=3.10"
license = "MIT"
authors = [
{ name = "deepsense.ai", email = "[email protected]"}
]
keywords = [
"Retrieval Augmented Generation",
"RAG",
"Large Language Models",
"LLMs",
"Generative AI",
"GenAI",
"Guardrails"
]
classifiers = [
"Development Status :: 4 - Beta",
"Environment :: Console",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: MIT License",
"Natural Language :: English",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Software Development :: Libraries :: Python Modules",
]
dependencies = ["ragbits-core==0.2.0"]

[project.optional-dependencies]
openai = [
"openai~=1.51.0",
]

[tool.uv]
dev-dependencies = [
"pre-commit~=3.8.0",
"pytest~=8.3.3",
"pytest-cov~=5.0.0",
"pytest-asyncio~=0.24.0",
"pip-licenses>=4.0.0,<5.0.0"
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.metadata]
allow-direct-references = true

[tool.hatch.build.targets.wheel]
packages = ["src/ragbits"]

[tool.pytest.ini_options]
asyncio_mode = "auto"
Empty file.
54 changes: 54 additions & 0 deletions packages/ragbits-guardrails/src/ragbits/guardrails/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from abc import ABC, abstractmethod

from pydantic import BaseModel

from ragbits.core.prompt import Prompt


class GuardrailVerificationResult(BaseModel):
    """
    Class representing result of guardrail verification
    """

    # Name of the guardrail class that produced this result
    guardrail_name: str
    # True when the verified input passed the guardrail's criteria
    succeeded: bool
    # Explanation of the failure; None when the check succeeded
    fail_reason: str | None


class Guardrail(ABC):
    """
    Abstract class representing guardrail
    """

    @abstractmethod
    async def verify(self, input_to_verify: Prompt | str) -> GuardrailVerificationResult:
        """
        Verifies whether provided input meets certain criteria

        Args:
            input_to_verify: prompt or output of the model to check

        Returns:
            verification result
        """


class GuardrailManager:
    """
    Class responsible for running guardrails
    """

    def __init__(self, guardrails: list[Guardrail]):
        # Guardrails are run in the order given here.
        self._guardrails = guardrails

    async def verify(self, input_to_verify: Prompt | str) -> list[GuardrailVerificationResult]:
        """
        Verifies whether provided input meets certain criteria

        Args:
            input_to_verify: prompt or output of the model to check

        Returns:
            list of verification result
        """
        # Run each guardrail sequentially, collecting one result per guardrail.
        results: list[GuardrailVerificationResult] = []
        for guardrail in self._guardrails:
            results.append(await guardrail.verify(input_to_verify))
        return results
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import base64

from openai import AsyncOpenAI

from ragbits.core.prompt import Prompt
from ragbits.guardrails.base import Guardrail, GuardrailVerificationResult


class OpenAIModerationGuardrail(Guardrail):
    """
    Guardrail based on OpenAI moderation
    """

    def __init__(self, moderation_model: str = "omni-moderation-latest"):
        # AsyncOpenAI() reads OPENAI_API_KEY from the environment.
        self._openai_client = AsyncOpenAI()
        self._moderation_model = moderation_model

    async def verify(self, input_to_verify: Prompt | str) -> GuardrailVerificationResult:
        """
        Verifies whether provided input meets certain criteria

        Args:
            input_to_verify: prompt or output of the model to check

        Returns:
            verification result
        """
        if isinstance(input_to_verify, Prompt):
            # Moderate the user prompt, the system prompt (if any), and any images.
            moderation_input = [{"type": "text", "text": input_to_verify.rendered_user_prompt}]
            system_prompt = input_to_verify.rendered_system_prompt
            if system_prompt is not None:
                moderation_input.append({"type": "text", "text": system_prompt})
            for image in input_to_verify.images or []:
                encoded = base64.b64encode(image).decode("utf-8")
                moderation_input.append(
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/jpeg;base64,{encoded}"},  # type: ignore
                    }
                )
        else:
            # Plain strings are moderated as a single text item.
            moderation_input = [{"type": "text", "text": input_to_verify}]

        response = await self._openai_client.moderations.create(  # type: ignore
            model=self._moderation_model, input=moderation_input
        )

        # Any flagged result means the verification failed.
        flagged_results = [result for result in response.results if result.flagged]
        return GuardrailVerificationResult(
            guardrail_name=self.__class__.__name__,
            succeeded=not flagged_results,
            fail_reason=str(flagged_results) if flagged_results else None,
        )
53 changes: 53 additions & 0 deletions packages/ragbits-guardrails/tests/unit/test_openai_moderation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import os
from unittest.mock import AsyncMock, patch

from pydantic import BaseModel

from ragbits.guardrails.base import GuardrailManager, GuardrailVerificationResult
from ragbits.guardrails.openai_moderation import OpenAIModerationGuardrail


# Minimal stand-in for a single OpenAI moderation result: exposes only the
# fields the guardrail under test reads.
class MockedModeration(BaseModel):
    # Whether the moderation model flagged the content
    flagged: bool
    # Reason attached to a flagged result; None when not flagged
    fail_reason: str | None


# Mimics the `.results` attribute of the OpenAI moderations.create() response.
class MockedModerationCreateResponse(BaseModel):
    results: list[MockedModeration]


async def test_manager():
    # A single mocked guardrail that always reports success.
    mock_guardrail = AsyncMock()
    mock_guardrail.verify.return_value = GuardrailVerificationResult(
        guardrail_name=".", succeeded=True, fail_reason=None
    )

    verification_results = await GuardrailManager([mock_guardrail]).verify("test")

    # The manager must call each guardrail exactly once and return one result per guardrail.
    assert mock_guardrail.verify.call_count == 1
    assert len(verification_results) == 1


@patch.dict(os.environ, {"OPENAI_API_KEY": "."}, clear=True)
async def test_not_flagged():
    # Replace the real OpenAI client with a mock that reports no flags.
    guardrail = OpenAIModerationGuardrail()
    guardrail._openai_client = AsyncMock()
    guardrail._openai_client.moderations.create.return_value = MockedModerationCreateResponse(
        results=[MockedModeration(flagged=False, fail_reason=None)]
    )

    result = await guardrail.verify("Test")

    # Unflagged content succeeds with no failure reason.
    assert result.guardrail_name == "OpenAIModerationGuardrail"
    assert result.succeeded is True
    assert result.fail_reason is None


@patch.dict(os.environ, {"OPENAI_API_KEY": "."}, clear=True)
async def test_flagged():
    # Replace the real OpenAI client with a mock that flags the content.
    guardrail = OpenAIModerationGuardrail()
    guardrail._openai_client = AsyncMock()
    guardrail._openai_client.moderations.create.return_value = MockedModerationCreateResponse(
        results=[MockedModeration(flagged=True, fail_reason="Harmful content")]
    )

    result = await guardrail.verify("Test")

    # Flagged content fails, and the fail reason is the str() of the flagged results list.
    assert result.guardrail_name == "OpenAIModerationGuardrail"
    assert result.succeeded is False
    assert result.fail_reason == "[MockedModeration(flagged=True, fail_reason='Harmful content')]"
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ dependencies = [
"ragbits-core[litellm,local,lab,chroma]",
"ragbits-document-search[gcs, huggingface]",
"ragbits-evaluate[relari]",
"ragbits-guardrails[openai]",
]

[tool.uv]
Expand All @@ -35,13 +36,15 @@ ragbits-cli = { workspace = true }
ragbits-core = { workspace = true }
ragbits-document-search = { workspace = true }
ragbits-evaluate = {workspace = true}
ragbits-guardrails = {workspace = true}

[tool.uv.workspace]
members = [
"packages/ragbits-cli",
"packages/ragbits-core",
"packages/ragbits-document-search",
"packages/ragbits-evaluate",
"packages/ragbits-guardrails",
]

[tool.pytest]
Expand Down Expand Up @@ -88,6 +91,7 @@ mypy_path = [
"packages/ragbits-core/src",
"packages/ragbits-document-search/src",
"packages/ragbits-evaluate/src",
"packages/ragbits-guardrails/src",
]
exclude = ["scripts"]

Expand Down
40 changes: 40 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 2cf123a

Please sign in to comment.