forked from langchain-ai/langchain
Feature/content safety runnable (#19)
* create a content safety chain
* format file for linting
* change constructor format
* reformat line length for linting
* update content safety chain to use `pydantic` decorator
* Create tests for the azure_content_safety_chain
* Create docs for the class and correct hyperlinks in docstrings
* Run linting checks
* Change test files to have return type annotations
* Reformat file
* Reformat files and add required testing dependencies
* Lint file
* Add review changes
* Update docs
* Lint file
1 parent 15aa72d, commit bc1fd7b
Showing 3 changed files with 448 additions and 0 deletions.
194 changes: 194 additions & 0 deletions
docs/docs/integrations/chains/azure_ai_content_safety.ipynb
@@ -0,0 +1,194 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# `AzureAIContentSafetyChain`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> [Azure AI Content Safety Chain](https://learn.microsoft.com/python/api/overview/azure/ai-contentsafety-readme?view=azure-python) is a wrapper around\n",
    "> the Azure AI Content Safety service, implemented in LangChain using the LangChain\n",
    "> [Runnables](https://python.langchain.com/docs/how_to/lcel_cheatsheet/) base class so that it can be used within a `RunnableSequence`.\n",
    "\n",
    "The class can be used to stop or filter content based on the Azure AI Content Safety policy."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example Usage"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
"Get the required imports, here we will use a `ChatPromptTemplate` for convenience and the `AzureChatOpenAI`, however, any LangChain integrated model will work in a chain." | ||
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from langchain_community.chains.azure_content_safety_chain import (\n",
    "    AzureAIContentSafetyChain,\n",
    "    AzureHarmfulContentError,\n",
    ")\n",
    "from langchain_core.prompts import ChatPromptTemplate\n",
    "from langchain_openai import AzureChatOpenAI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "moderate = AzureAIContentSafetyChain()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = AzureChatOpenAI(\n",
    "    openai_api_version=os.environ[\"OPENAI_API_VERSION\"],\n",
    "    azure_deployment=os.environ[\"COMPLETIONS_MODEL\"],\n",
    "    azure_endpoint=os.environ[\"AZURE_OPENAI_ENDPOINT\"],\n",
    "    api_key=os.environ[\"AZURE_OPENAI_API_KEY\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = ChatPromptTemplate.from_messages([(\"system\", \"repeat after me: {input}\")])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
"Combine the objects to create a LangChain RunnablesSequence" | ||
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "moderated_chain = moderate | prompt | model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "response = moderated_chain.invoke({\"input\": \"I like you!\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'I like you!'"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "response.content"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With harmful content"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Harmful content: I hate you!\n"
     ]
    },
    {
     "ename": "AzureHarmfulContentError",
     "evalue": "The input has breached Azure's Content Safety Policy",
     "output_type": "error",
     "traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | ||
"\u001b[0;31mAzureHarmfulContentError\u001b[0m Traceback (most recent call last)", | ||
"Cell \u001b[0;32mIn[17], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m----> 2\u001b[0m response \u001b[38;5;241m=\u001b[39m \u001b[43mmoderated_chain\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minvoke\u001b[49m\u001b[43m(\u001b[49m\u001b[43m{\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43minput\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mI hate you!\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m}\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m AzureHarmfulContentError \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mHarmful content: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00me\u001b[38;5;241m.\u001b[39minput\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m)\n", | ||
"File \u001b[0;32m~/.local/lib/python3.11/site-packages/langchain_core/runnables/base.py:3020\u001b[0m, in \u001b[0;36mRunnableSequence.invoke\u001b[0;34m(self, input, config, **kwargs)\u001b[0m\n\u001b[1;32m 3018\u001b[0m context\u001b[38;5;241m.\u001b[39mrun(_set_config_context, config)\n\u001b[1;32m 3019\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m i \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[0;32m-> 3020\u001b[0m \u001b[38;5;28minput\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[43mcontext\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43mstep\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minvoke\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mconfig\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3021\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 3022\u001b[0m \u001b[38;5;28minput\u001b[39m \u001b[38;5;241m=\u001b[39m context\u001b[38;5;241m.\u001b[39mrun(step\u001b[38;5;241m.\u001b[39minvoke, \u001b[38;5;28minput\u001b[39m, config)\n", | ||
"File \u001b[0;32m~/.local/lib/python3.11/site-packages/langchain/chains/base.py:170\u001b[0m, in \u001b[0;36mChain.invoke\u001b[0;34m(self, input, config, **kwargs)\u001b[0m\n\u001b[1;32m 168\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mBaseException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 169\u001b[0m run_manager\u001b[38;5;241m.\u001b[39mon_chain_error(e)\n\u001b[0;32m--> 170\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m e\n\u001b[1;32m 171\u001b[0m run_manager\u001b[38;5;241m.\u001b[39mon_chain_end(outputs)\n\u001b[1;32m 173\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m include_run_info:\n", | ||
"File \u001b[0;32m~/.local/lib/python3.11/site-packages/langchain/chains/base.py:160\u001b[0m, in \u001b[0;36mChain.invoke\u001b[0;34m(self, input, config, **kwargs)\u001b[0m\n\u001b[1;32m 157\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 158\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_validate_inputs(inputs)\n\u001b[1;32m 159\u001b[0m outputs \u001b[38;5;241m=\u001b[39m (\n\u001b[0;32m--> 160\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrun_manager\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrun_manager\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 161\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m new_arg_supported\n\u001b[1;32m 162\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call(inputs)\n\u001b[1;32m 163\u001b[0m )\n\u001b[1;32m 165\u001b[0m final_outputs: Dict[\u001b[38;5;28mstr\u001b[39m, Any] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprep_outputs(\n\u001b[1;32m 166\u001b[0m inputs, outputs, return_only_outputs\n\u001b[1;32m 167\u001b[0m )\n\u001b[1;32m 168\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mBaseException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n", | ||
"File \u001b[0;32m/workspaces/langchain/docs/docs/integrations/chains/../../../../libs/community/langchain_community/chains/azure_content_safety_chain.py:161\u001b[0m, in \u001b[0;36mAzureAIContentSafetyChain._call\u001b[0;34m(self, inputs, run_manager)\u001b[0m\n\u001b[1;32m 158\u001b[0m response \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mclient\u001b[38;5;241m.\u001b[39manalyze_text(request)\n\u001b[1;32m 160\u001b[0m result \u001b[38;5;241m=\u001b[39m response\u001b[38;5;241m.\u001b[39mcategories_analysis\n\u001b[0;32m--> 161\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_detect_harmful_content\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtext\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mresult\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 163\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m {\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minput_key: output, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput_key: output}\n", | ||
"File \u001b[0;32m/workspaces/langchain/docs/docs/integrations/chains/../../../../libs/community/langchain_community/chains/azure_content_safety_chain.py:142\u001b[0m, in \u001b[0;36mAzureAIContentSafetyChain._detect_harmful_content\u001b[0;34m(self, text, results)\u001b[0m\n\u001b[1;32m 137\u001b[0m error_str \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 138\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mThe input text contains harmful content \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 139\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124maccording to Azure OpenAI\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124ms content policy\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 140\u001b[0m )\n\u001b[1;32m 141\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39merror:\n\u001b[0;32m--> 142\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m AzureHarmfulContentError(\u001b[38;5;28minput\u001b[39m\u001b[38;5;241m=\u001b[39mtext)\n\u001b[1;32m 143\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 144\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m error_str\n", | ||
"\u001b[0;31mAzureHarmfulContentError\u001b[0m: The input has breached Azure's Content Safety Policy" | ||
     ]
    }
   ],
   "source": [
    "try:\n",
    "    response = moderated_chain.invoke({\"input\": \"I hate you!\"})\n",
    "except AzureHarmfulContentError as e:\n",
    "    print(f\"Harmful content: {e.input}\")\n",
    "    raise"
   ]
  }
 ],
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.4" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
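Beyond the notebook flow above, the chain can also be invoked on its own. The sketch below is illustrative only (it is not a cell from the committed notebook) and relies on the `error=False` behaviour described in the class docstring in the next file, where flagged input is replaced with a notice string instead of raising `AzureHarmfulContentError`. It assumes the same `CONTENT_SAFETY_API_KEY` / `CONTENT_SAFETY_ENDPOINT` environment variables that the chain reads.

```python
# Illustrative sketch, not part of the committed notebook.
# With error=False, the chain returns a notice string instead of raising,
# so the caller can branch on the output value.
# Assumes CONTENT_SAFETY_API_KEY and CONTENT_SAFETY_ENDPOINT are set.
from langchain_community.chains.azure_content_safety_chain import (
    AzureAIContentSafetyChain,
)

moderate = AzureAIContentSafetyChain(error=False)

result = moderate.invoke({"input": "I hate you!"})
# result["output"] holds either the original text or the harmful-content notice,
# depending on what the Content Safety service reports.
print(result["output"])
```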
182 changes: 182 additions & 0 deletions
libs/community/langchain_community/chains/azure_content_safety_chain.py
@@ -0,0 +1,182 @@
"""Pass input through an Azure AI Content Safety resource."""

from typing import Any, Dict, List, Optional

from langchain.chains.base import Chain
from langchain_core.callbacks import (
    CallbackManagerForChainRun,
)
from langchain_core.exceptions import LangChainException
from langchain_core.utils import get_from_dict_or_env
from pydantic import model_validator


class AzureHarmfulContentError(LangChainException):
    """Exception for handling harmful content detected
    in input for a model or chain according to Azure's
    content safety policy."""

    def __init__(
        self,
        input: str,
    ):
        """Constructor

        Args:
            input (str): The input given by the user to the model.
        """
        self.input = input
        self.message = "The input has breached Azure's Content Safety Policy"
        super().__init__(self.message)


class AzureAIContentSafetyChain(Chain):
    """
    A wrapper for the Azure AI Content Safety API in a Runnable form.
    Allows for harmful content detection and filtering before input is
    provided to a model.

    **Note**:
        This service will filter input that shows any sign of harmful content;
        this is not configurable.

    Attributes:
        error (bool): Whether to raise an error if harmful content is detected.
        content_safety_key (Optional[str]): API key for Azure Content Safety.
        content_safety_endpoint (Optional[str]): Endpoint URL for Azure Content Safety.

    Setup:
        1. Follow the instructions here to deploy Azure AI Content Safety:
           https://learn.microsoft.com/azure/ai-services/content-safety/overview

        2. Install ``langchain`` and ``langchain_community``, then set the
           following environment variables:

           .. code-block:: bash

                pip install -U langchain langchain-community

                export CONTENT_SAFETY_API_KEY="your-api-key"
                export CONTENT_SAFETY_ENDPOINT="https://your-endpoint.azure.com/"

    Example Usage (with safe content):
        .. code-block:: python

            from langchain_community.chains.azure_content_safety_chain import (
                AzureAIContentSafetyChain,
            )
            from langchain_core.prompts import ChatPromptTemplate
            from langchain_openai import AzureChatOpenAI

            moderate = AzureAIContentSafetyChain()
            prompt = ChatPromptTemplate.from_messages(
                [("system", "repeat after me: {input}")]
            )
            model = AzureChatOpenAI()

            moderated_chain = moderate | prompt | model
            moderated_chain.invoke({"input": "Hey, how are you?"})

    Example Usage (with harmful content):
        .. code-block:: python

            from langchain_community.chains.azure_content_safety_chain import (
                AzureAIContentSafetyChain,
                AzureHarmfulContentError,
            )
            from langchain_core.prompts import ChatPromptTemplate
            from langchain_openai import AzureChatOpenAI

            moderate = AzureAIContentSafetyChain()
            prompt = ChatPromptTemplate.from_messages(
                [("system", "repeat after me: {input}")]
            )
            model = AzureChatOpenAI()

            moderated_chain = moderate | prompt | model

            try:
                response = moderated_chain.invoke({"input": "I hate you!"})
            except AzureHarmfulContentError as e:
                print(f"Harmful content: {e.input}")
                raise
    """

    client: Any = None  #: :meta private:
    error: bool = True
    """Whether or not to error if bad content was found."""
    input_key: str = "input"  #: :meta private:
    output_key: str = "output"  #: :meta private:
    content_safety_key: Optional[str] = None
    content_safety_endpoint: Optional[str] = None

    @property
    def input_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return output key.

        :meta private:
        """
        return [self.output_key]

    @model_validator(mode="before")
    @classmethod
    def validate_environment(cls, values: Dict) -> Any:
        """Validate that the API key and Python package exist in the environment."""
        content_safety_key = get_from_dict_or_env(
            values, "content_safety_key", "CONTENT_SAFETY_API_KEY"
        )
        content_safety_endpoint = get_from_dict_or_env(
            values, "content_safety_endpoint", "CONTENT_SAFETY_ENDPOINT"
        )
        try:
            import azure.ai.contentsafety as sdk
            from azure.core.credentials import AzureKeyCredential

            values["client"] = sdk.ContentSafetyClient(
                endpoint=content_safety_endpoint,
                credential=AzureKeyCredential(content_safety_key),
            )

        except ImportError:
            raise ImportError(
                "azure-ai-contentsafety is not installed. "
                "Run `pip install azure-ai-contentsafety` to install."
            )
        return values

    def _detect_harmful_content(self, text: str, results: Any) -> str:
        contains_harmful_content = False

        # Any category reported with a severity above zero is treated as
        # harmful; this threshold is not configurable.
        for category in results:
            if category["severity"] > 0:
                contains_harmful_content = True

        if contains_harmful_content:
            error_str = (
                "The input text contains harmful content "
                "according to Azure OpenAI's content policy"
            )
            if self.error:
                raise AzureHarmfulContentError(input=text)
            else:
                return error_str

        return text

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        text = inputs[self.input_key]

        # Imported lazily so the optional SDK is only needed when the chain runs.
        from azure.ai.contentsafety.models import AnalyzeTextOptions

        request = AnalyzeTextOptions(text=text)
        response = self.client.analyze_text(request)

        result = response.categories_analysis
        output = self._detect_harmful_content(text, result)

        return {self.input_key: output, self.output_key: output}
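The commit message mentions tests for this chain, but the third changed file is not rendered on this page. As a rough idea of how the chain could be exercised without calling the live service, here is a hypothetical sketch; the stub classes, dummy credentials, and test names are illustrative assumptions, not the committed tests. It assumes `azure-ai-contentsafety` is installed so the `pydantic` validator can build the initial client, and that the `client` attribute can be reassigned after construction.

```python
# Hypothetical test sketch (not part of this commit): swap the chain's `client`
# for a stub that mimics the SDK response shape used by _call and
# _detect_harmful_content (an object with `categories_analysis`, a list of
# {"severity": int} entries), so harmful and safe paths run offline.
import pytest

from langchain_community.chains.azure_content_safety_chain import (
    AzureAIContentSafetyChain,
    AzureHarmfulContentError,
)


class _FakeResponse:
    def __init__(self, severities):
        self.categories_analysis = [{"severity": s} for s in severities]


class _FakeClient:
    def __init__(self, severities):
        self._severities = severities

    def analyze_text(self, request):
        # Ignore the request and return the canned analysis.
        return _FakeResponse(self._severities)


def test_harmful_input_raises() -> None:
    chain = AzureAIContentSafetyChain(
        content_safety_key="dummy-key",
        content_safety_endpoint="https://dummy.endpoint/",
    )
    chain.client = _FakeClient(severities=[4])  # pretend the service flagged the text
    with pytest.raises(AzureHarmfulContentError):
        chain.invoke({"input": "I hate you!"})


def test_safe_input_passes_through() -> None:
    chain = AzureAIContentSafetyChain(
        content_safety_key="dummy-key",
        content_safety_endpoint="https://dummy.endpoint/",
    )
    chain.client = _FakeClient(severities=[0, 0, 0, 0])
    assert chain.invoke({"input": "I like you!"})["output"] == "I like you!"
```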