-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
enabler: guardrails component design (#169)
- Loading branch information
1 parent
2ea2875
commit 2cf123a
Showing
12 changed files
with
294 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# /// script | ||
# requires-python = ">=3.10" | ||
# dependencies = [ | ||
# "ragbits-core", | ||
# "openai", | ||
# ] | ||
# /// | ||
import asyncio | ||
from argparse import ArgumentParser | ||
|
||
from ragbits.guardrails.base import GuardrailManager | ||
from ragbits.guardrails.openai_moderation import OpenAIModerationGuardrail | ||
|
||
|
||
async def guardrail_run(message: str) -> None: | ||
""" | ||
Example of using the OpenAIModerationGuardrail. Requires the OPENAI_API_KEY environment variable to be set. | ||
""" | ||
manager = GuardrailManager([OpenAIModerationGuardrail()]) | ||
res = await manager.verify(message) | ||
print(res) | ||
|
||
|
||
if __name__ == "__main__": | ||
args = ArgumentParser() | ||
args.add_argument("message", nargs="+", type=str, help="Message to validate") | ||
parsed_args = args.parse_args() | ||
|
||
asyncio.run(guardrail_run("".join(parsed_args.message))) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# Ragbits Guardrails |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
[project] | ||
name = "ragbits-guardrails" | ||
version = "0.2.0" | ||
description = "Guardrails module for Ragbits components" | ||
readme = "README.md" | ||
requires-python = ">=3.10" | ||
license = "MIT" | ||
authors = [ | ||
{ name = "deepsense.ai", email = "[email protected]"} | ||
] | ||
keywords = [ | ||
"Retrieval Augmented Generation", | ||
"RAG", | ||
"Large Language Models", | ||
"LLMs", | ||
"Generative AI", | ||
"GenAI", | ||
"Evaluation" | ||
] | ||
classifiers = [ | ||
"Development Status :: 4 - Beta", | ||
"Environment :: Console", | ||
"Intended Audience :: Science/Research", | ||
"License :: OSI Approved :: MIT License", | ||
"Natural Language :: English", | ||
"Operating System :: OS Independent", | ||
"Programming Language :: Python :: 3.10", | ||
"Programming Language :: Python :: 3.11", | ||
"Programming Language :: Python :: 3.12", | ||
"Programming Language :: Python :: 3.13", | ||
"Topic :: Scientific/Engineering :: Artificial Intelligence", | ||
"Topic :: Software Development :: Libraries :: Python Modules", | ||
] | ||
dependencies = ["ragbits-core==0.2.0"] | ||
|
||
[project.optional-dependencies] | ||
openai = [ | ||
"openai~=1.51.0", | ||
] | ||
|
||
[tool.uv] | ||
dev-dependencies = [ | ||
"pre-commit~=3.8.0", | ||
"pytest~=8.3.3", | ||
"pytest-cov~=5.0.0", | ||
"pytest-asyncio~=0.24.0", | ||
"pip-licenses>=4.0.0,<5.0.0" | ||
] | ||
|
||
[build-system] | ||
requires = ["hatchling"] | ||
build-backend = "hatchling.build" | ||
|
||
[tool.hatch.metadata] | ||
allow-direct-references = true | ||
|
||
[tool.hatch.build.targets.wheel] | ||
packages = ["src/ragbits"] | ||
|
||
[tool.pytest.ini_options] | ||
asyncio_mode = "auto" |
Empty file.
54 changes: 54 additions & 0 deletions
54
packages/ragbits-guardrails/src/ragbits/guardrails/base.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
from abc import ABC, abstractmethod | ||
|
||
from pydantic import BaseModel | ||
|
||
from ragbits.core.prompt import Prompt | ||
|
||
|
||
class GuardrailVerificationResult(BaseModel): | ||
""" | ||
Class representing result of guardrail verification | ||
""" | ||
|
||
guardrail_name: str | ||
succeeded: bool | ||
fail_reason: str | None | ||
|
||
|
||
class Guardrail(ABC): | ||
""" | ||
Abstract class representing guardrail | ||
""" | ||
|
||
@abstractmethod | ||
async def verify(self, input_to_verify: Prompt | str) -> GuardrailVerificationResult: | ||
""" | ||
Verifies whether provided input meets certain criteria | ||
Args: | ||
input_to_verify: prompt or output of the model to check | ||
Returns: | ||
verification result | ||
""" | ||
|
||
|
||
class GuardrailManager: | ||
""" | ||
Class responsible for running guardrails | ||
""" | ||
|
||
def __init__(self, guardrails: list[Guardrail]): | ||
self._guardrails = guardrails | ||
|
||
async def verify(self, input_to_verify: Prompt | str) -> list[GuardrailVerificationResult]: | ||
""" | ||
Verifies whether provided input meets certain criteria | ||
Args: | ||
input_to_verify: prompt or output of the model to check | ||
Returns: | ||
list of verification result | ||
""" | ||
return [await guardrail.verify(input_to_verify) for guardrail in self._guardrails] |
51 changes: 51 additions & 0 deletions
51
packages/ragbits-guardrails/src/ragbits/guardrails/openai_moderation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import base64 | ||
|
||
from openai import AsyncOpenAI | ||
|
||
from ragbits.core.prompt import Prompt | ||
from ragbits.guardrails.base import Guardrail, GuardrailVerificationResult | ||
|
||
|
||
class OpenAIModerationGuardrail(Guardrail): | ||
""" | ||
Guardrail based on OpenAI moderation | ||
""" | ||
|
||
def __init__(self, moderation_model: str = "omni-moderation-latest"): | ||
self._openai_client = AsyncOpenAI() | ||
self._moderation_model = moderation_model | ||
|
||
async def verify(self, input_to_verify: Prompt | str) -> GuardrailVerificationResult: | ||
""" | ||
Verifies whether provided input meets certain criteria | ||
Args: | ||
input_to_verify: prompt or output of the model to check | ||
Returns: | ||
verification result | ||
""" | ||
if isinstance(input_to_verify, Prompt): | ||
inputs = [{"type": "text", "text": input_to_verify.rendered_user_prompt}] | ||
if input_to_verify.rendered_system_prompt is not None: | ||
inputs.append({"type": "text", "text": input_to_verify.rendered_system_prompt}) | ||
if images := input_to_verify.images: | ||
inputs.extend( | ||
[ | ||
{ | ||
"type": "image_url", | ||
"image_url": {"url": f"data:image/jpeg;base64,{base64.b64encode(im).decode('utf-8')}"}, # type: ignore | ||
} | ||
for im in images | ||
] | ||
) | ||
else: | ||
inputs = [{"type": "text", "text": input_to_verify}] | ||
response = await self._openai_client.moderations.create(model=self._moderation_model, input=inputs) # type: ignore | ||
|
||
fail_reasons = [result for result in response.results if result.flagged] | ||
return GuardrailVerificationResult( | ||
guardrail_name=self.__class__.__name__, | ||
succeeded=len(fail_reasons) == 0, | ||
fail_reason=None if len(fail_reasons) == 0 else str(fail_reasons), | ||
) |
53 changes: 53 additions & 0 deletions
53
packages/ragbits-guardrails/tests/unit/test_openai_moderation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import os | ||
from unittest.mock import AsyncMock, patch | ||
|
||
from pydantic import BaseModel | ||
|
||
from ragbits.guardrails.base import GuardrailManager, GuardrailVerificationResult | ||
from ragbits.guardrails.openai_moderation import OpenAIModerationGuardrail | ||
|
||
|
||
class MockedModeration(BaseModel): | ||
flagged: bool | ||
fail_reason: str | None | ||
|
||
|
||
class MockedModerationCreateResponse(BaseModel): | ||
results: list[MockedModeration] | ||
|
||
|
||
async def test_manager(): | ||
guardrail_mock = AsyncMock() | ||
guardrail_mock.verify.return_value = GuardrailVerificationResult( | ||
guardrail_name=".", succeeded=True, fail_reason=None | ||
) | ||
manager = GuardrailManager([guardrail_mock]) | ||
results = await manager.verify("test") | ||
assert guardrail_mock.verify.call_count == 1 | ||
assert len(results) == 1 | ||
|
||
|
||
@patch.dict(os.environ, {"OPENAI_API_KEY": "."}, clear=True) | ||
async def test_not_flagged(): | ||
guardrail = OpenAIModerationGuardrail() | ||
guardrail._openai_client = AsyncMock() | ||
guardrail._openai_client.moderations.create.return_value = MockedModerationCreateResponse( | ||
results=[MockedModeration(flagged=False, fail_reason=None)] | ||
) | ||
results = await guardrail.verify("Test") | ||
assert results.succeeded is True | ||
assert results.fail_reason is None | ||
assert results.guardrail_name == "OpenAIModerationGuardrail" | ||
|
||
|
||
@patch.dict(os.environ, {"OPENAI_API_KEY": "."}, clear=True) | ||
async def test_flagged(): | ||
guardrail = OpenAIModerationGuardrail() | ||
guardrail._openai_client = AsyncMock() | ||
guardrail._openai_client.moderations.create.return_value = MockedModerationCreateResponse( | ||
results=[MockedModeration(flagged=True, fail_reason="Harmful content")] | ||
) | ||
results = await guardrail.verify("Test") | ||
assert results.succeeded is False | ||
assert results.fail_reason == "[MockedModeration(flagged=True, fail_reason='Harmful content')]" | ||
assert results.guardrail_name == "OpenAIModerationGuardrail" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.