Feature/content safety runnable #13
base: master
@@ -0,0 +1,138 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# `AzureAIContentSafetyChain`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> [Azure AI Content Safety Chain](https://learn.microsoft.com/python/api/overview/azure/ai-contentsafety-readme?view=azure-python) is a wrapper around\n",
    "> the Azure AI Content Safety service. It is implemented in LangChain as a\n",
    "> [Runnable](https://python.langchain.com/docs/how_to/lcel_cheatsheet/), so it can be used in a RunnableSequence.\n",
    "\n",
    "The class can be used to stop or filter content based on the Azure AI Content Safety policy."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example Usage"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Get the required imports. Here we use a `ChatPromptTemplate` for convenience and `AzureChatOpenAI`; however, any LangChain-integrated model will work in a chain."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from langchain_community.chains.azure_content_safety_chain import (\n",
    "    AzureAIContentSafetyChain,\n",
    ")\n",
    "from langchain_core.prompts import ChatPromptTemplate\n",
    "from langchain_openai import AzureChatOpenAI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "moderate = AzureAIContentSafetyChain()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = AzureChatOpenAI(\n",
    "    openai_api_version=os.environ[\"OPENAI_API_VERSION\"],\n",
    "    azure_deployment=os.environ[\"COMPLETIONS_MODEL\"],\n",
    "    azure_endpoint=os.environ[\"AZURE_OPENAI_ENDPOINT\"],\n",
    "    api_key=os.environ[\"AZURE_OPENAI_API_KEY\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = ChatPromptTemplate.from_messages([(\"system\", \"repeat after me: {input}\")])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Combine the objects to create a LangChain RunnableSequence."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "moderated_chain = moderate | prompt | model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "response = moderated_chain.invoke({\"input\": \"I hate you!\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "response.content"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
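One thing the notebook does not spell out: when `AzureAIContentSafetyChain()` is constructed with no arguments, the chain's validator (see the implementation below) resolves its credentials from the `CONTENT_SAFETY_API_KEY` and `CONTENT_SAFETY_ENDPOINT` environment variables. A minimal sketch of the explicit alternative, using placeholder values rather than real credentials:

# Sketch: passing the Content Safety credentials explicitly instead of relying
# on the CONTENT_SAFETY_API_KEY / CONTENT_SAFETY_ENDPOINT environment variables.
from langchain_community.chains.azure_content_safety_chain import (
    AzureAIContentSafetyChain,
)

moderate = AzureAIContentSafetyChain(
    content_safety_key="your-api-key",  # placeholder
    content_safety_endpoint="https://your-endpoint.azure.com/",  # placeholder
)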
@@ -0,0 +1,163 @@
"""Pass input through an azure content safety resource.""" | ||
|
||
from typing import Any, Dict, List, Optional | ||
|
||
from langchain.chains.base import Chain | ||
from langchain_core.callbacks import ( | ||
CallbackManagerForChainRun, | ||
) | ||
from langchain_core.exceptions import LangChainException | ||
from langchain_core.utils import get_from_dict_or_env | ||
from pydantic import model_validator | ||
|
||
|
||
class AzureHarmfulContentError(LangChainException): | ||
"""Exception for handling harmful content detected | ||
in input for a model or chain according to Azure's | ||
content safety policy.""" | ||
|
||
def __init__( | ||
self, | ||
input: str, | ||
): | ||
"""Constructor | ||
|
||
Args: | ||
input (str): The input given by the user to the model. | ||
""" | ||
self.input = input | ||
self.message = "The input has breached Azure's Content Safety Policy" | ||
super().__init__(self.message) | ||
|
||
|
||
class AzureAIContentSafetyChain(Chain):
    """
    A wrapper for the Azure AI Content Safety API in a Runnable form.
    Allows for harmful content detection and filtering before input is
    provided to a model.

    **Note**:
        This service will filter input that shows any sign of harmful content;
        this is not configurable.

    Attributes:
        error (bool): Whether to raise an error if harmful content is detected.
        content_safety_key (Optional[str]): API key for Azure Content Safety.
        content_safety_endpoint (Optional[str]): Endpoint URL for Azure Content Safety.

    Setup:
        1. Follow the instructions here to deploy Azure AI Content Safety:
           https://learn.microsoft.com/azure/ai-services/content-safety/overview

        2. Install ``langchain`` and ``langchain_community``, and set the following
           environment variables:

        .. code-block:: bash

            pip install -U langchain langchain-community

            export CONTENT_SAFETY_API_KEY="your-api-key"
            export CONTENT_SAFETY_ENDPOINT="https://your-endpoint.azure.com/"


    Example Usage:
        .. code-block:: python

            from langchain_community.chains.azure_content_safety_chain import (
                AzureAIContentSafetyChain,
            )
            from langchain_core.prompts import ChatPromptTemplate
            from langchain_openai import AzureChatOpenAI

            moderate = AzureAIContentSafetyChain()
            prompt = ChatPromptTemplate.from_messages(
                [("system", "repeat after me: {input}")]
            )
            model = AzureChatOpenAI()

            moderated_chain = moderate | prompt | model

            moderated_chain.invoke({"input": "Hey, How are you?"})
    """

Review comment: It might be helpful to show the full example here with the code you would have to write reacting to harmful content.

    client: Any = None  #: :meta private:
    error: bool = False
    """Whether or not to error if bad content was found."""

Review comment: I would default to True. If a user was going through the trouble of adding this to their chain, it feels to me it's because they want to react to harmful content, i.e. do something if an exception is raised.

    input_key: str = "input"  #: :meta private:
    output_key: str = "output"  #: :meta private:
    content_safety_key: Optional[str] = None
    content_safety_endpoint: Optional[str] = None

    @property
    def input_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return output key.

        :meta private:
        """
        return [self.output_key]

    @model_validator(mode="before")
    @classmethod
    def validate_environment(cls, values: Dict) -> Any:
        """Validate that the api key and python package exist in the environment."""
        content_safety_key = get_from_dict_or_env(
            values, "content_safety_key", "CONTENT_SAFETY_API_KEY"
        )
        content_safety_endpoint = get_from_dict_or_env(
            values, "content_safety_endpoint", "CONTENT_SAFETY_ENDPOINT"
        )
        try:
            import azure.ai.contentsafety as sdk
            from azure.core.credentials import AzureKeyCredential

            values["client"] = sdk.ContentSafetyClient(
                endpoint=content_safety_endpoint,
                credential=AzureKeyCredential(content_safety_key),
            )

        except ImportError:
            raise ImportError(
                "azure-ai-contentsafety is not installed. "
                "Run `pip install azure-ai-contentsafety` to install."
            )
        return values

    def _detect_harmful_content(self, text: str, results: Any) -> str:
        contains_harmful_content = False

        for category in results:
            if category["severity"] > 0:
                contains_harmful_content = True

        if contains_harmful_content:
            error_str = (
                "The input text contains harmful content "
                "according to Azure's content safety policy"
            )
            if self.error:
                raise AzureHarmfulContentError(input=text)
            else:
                return error_str

        return text

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        text = inputs[self.input_key]

        from azure.ai.contentsafety.models import AnalyzeTextOptions

        request = AnalyzeTextOptions(text=text)
        response = self.client.analyze_text(request)

        result = response.categories_analysis
        output = self._detect_harmful_content(text, result)

        return {self.input_key: output, self.output_key: output}
Review comment: What does returning the input_key with the output do?

Reply: This makes it so the input of the next step of the chain is the filtered content. If the user doesn't want an error to be thrown, the model will not receive the harmful content but instead the message "The input has breached Azure's content safety policy". Without this, the input at the next step would still be the original harmful content.
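To make the reply above concrete, here is a rough sketch of what `_call` returns for flagged input with the default `error=False`. It mirrors the mocking style of the unit tests below and assumes `azure-ai-contentsafety` is installed; no real Azure resource is needed:

from unittest import mock

from langchain_community.chains.azure_content_safety_chain import (
    AzureAIContentSafetyChain,
)

# Patch the SDK client during construction, as the unit tests do.
with mock.patch("azure.ai.contentsafety.ContentSafetyClient", autospec=True):
    chain = AzureAIContentSafetyChain(
        content_safety_key="key", content_safety_endpoint="endpoint"
    )

# Simulate the service flagging the text (any category with severity > 0).
chain.client = mock.MagicMock()
chain.client.analyze_text.return_value.categories_analysis = [
    {"Category": "Hate", "severity": 4}
]

result = chain._call({chain.input_key: "some flagged text"})
print(result)
# Both keys now carry the replacement message, so the next Runnable in the
# sequence receives it instead of the original harmful input.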
@@ -0,0 +1,72 @@
"""Tests for the Azure AI Content Safety Chain.""" | ||
|
||
from typing import Any | ||
|
||
import pytest | ||
|
||
from langchain_community.chains.azure_content_safety_chain import ( | ||
AzureAIContentSafetyChain, | ||
AzureHarmfulContentError, | ||
) | ||
|
||
|
||
@pytest.mark.requires("azure.ai.contentsafety") | ||
def test_content_safety(mocker: Any) -> None: | ||
mocker.patch("azure.ai.contentsafety.ContentSafetyClient", autospec=True) | ||
mocker.patch("azure.core.credentials.AzureKeyCredential", autospec=True) | ||
|
||
key = "key" | ||
endpoint = "endpoint" | ||
|
||
chain = AzureAIContentSafetyChain( | ||
content_safety_key=key, content_safety_endpoint=endpoint | ||
) | ||
assert chain.content_safety_key == key | ||
assert chain.content_safety_endpoint == endpoint | ||
|
||
|
||
@pytest.mark.requires("azure.ai.contentsafety") | ||
def test_raise_error_when_harmful_content_detected(mocker: Any) -> None: | ||
key = "key" | ||
endpoint = "endpoint" | ||
|
||
mocker.patch("azure.core.credentials.AzureKeyCredential", autospec=True) | ||
mocker.patch("azure.ai.contentsafety.ContentSafetyClient", autospec=True) | ||
chain = AzureAIContentSafetyChain( | ||
content_safety_key=key, content_safety_endpoint=endpoint, error=True | ||
) | ||
|
||
mock_content_client = mocker.Mock() | ||
mock_content_client.analyze_text.return_value.categories_analysis = [ | ||
{"Category": "Harm", "severity": 1} | ||
] | ||
|
||
chain.client = mock_content_client | ||
|
||
text = "This text contains harmful content" | ||
with pytest.raises(AzureHarmfulContentError): | ||
chain._call({chain.input_key: text}) | ||
|
||
|
||
@pytest.mark.requires("azure.ai.contentsafety") | ||
def test_no_harmful_content_detected(mocker: Any) -> None: | ||
key = "key" | ||
endpoint = "endpoint" | ||
|
||
mocker.patch("azure.core.credentials.AzureKeyCredential", autospec=True) | ||
mocker.patch("azure.ai.contentsafety.ContentSafetyClient", autospec=True) | ||
chain = AzureAIContentSafetyChain( | ||
content_safety_key=key, content_safety_endpoint=endpoint, error=True | ||
) | ||
|
||
mock_content_client = mocker.Mock() | ||
mock_content_client.analyze_text.return_value.categories_analysis = [ | ||
{"Category": "Harm", "severity": 0} | ||
] | ||
|
||
chain.client = mock_content_client | ||
|
||
text = "This text contains no harmful content" | ||
output = chain._call({chain.input_key: text}) | ||
|
||
assert output[chain.output_key] == text |
Review comment: It would be great to show how a user would write code to handle an exception from input violating the content policy.
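A possible snippet for the docs, sketched under the assumption that the chain is constructed with `error=True` so that flagged input raises `AzureHarmfulContentError` (class and exception names as defined in this PR, Azure OpenAI settings as in the notebook):

import os

from langchain_community.chains.azure_content_safety_chain import (
    AzureAIContentSafetyChain,
    AzureHarmfulContentError,
)
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import AzureChatOpenAI

# error=True makes the chain raise instead of replacing flagged input.
moderate = AzureAIContentSafetyChain(error=True)
prompt = ChatPromptTemplate.from_messages([("system", "repeat after me: {input}")])
model = AzureChatOpenAI(
    openai_api_version=os.environ["OPENAI_API_VERSION"],
    azure_deployment=os.environ["COMPLETIONS_MODEL"],
    azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
    api_key=os.environ["AZURE_OPENAI_API_KEY"],
)

moderated_chain = moderate | prompt | model

try:
    response = moderated_chain.invoke({"input": "I hate you!"})
    print(response.content)
except AzureHarmfulContentError as exc:
    # exc.input carries the original text that violated the policy.
    print(f"Rejected by Azure Content Safety: {exc.input!r}")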