
Commit
Feature/content safety runnable (#19)
* create a content safety chain

* format file for linting

* change constructor format

* reformat line length for linting

* update content safety chain to use `pydantic` decorator

* Create tests for the azure_content_safety_chain

* Create docs for the class and correct hyperlinks in docstrings

* Run linting checks

* Change test files to have return type annotations

* Reformat file

* Reformat files and add required testing dependencies

* Lint file

* Add review changes

* Update docs

* Lint file
Sheepsta300 authored Oct 19, 2024
1 parent 15aa72d commit bc1fd7b
Showing 3 changed files with 448 additions and 0 deletions.
194 changes: 194 additions & 0 deletions docs/docs/integrations/chains/azure_ai_content_safety.ipynb
@@ -0,0 +1,194 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# `AzureAIContentSafetyChain`"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"> [Azure AI Content Safety Chain](https://learn.microsoft.com/python/api/overview/azure/ai-contentsafety-readme?view=azure-python) is a wrapper around\n",
"> the Azure AI Content Safety service, implemented in LangChain using the LangChain \n",
"> [Runnables](https://python.langchain.com/docs/how_to/lcel_cheatsheet/) base class to allow use in a Runnables Sequence.\n",
"\n",
"The Class can be used to stop or filter content based on the Azure AI Content Safety policy."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Example Usage"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Get the required imports, here we will use a `ChatPromptTemplate` for convenience and the `AzureChatOpenAI`, however, any LangChain integrated model will work in a chain."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"from langchain_community.chains.azure_content_safety_chain import (\n",
" AzureAIContentSafetyChain,\n",
" AzureHarmfulContentError,\n",
")\n",
"from langchain_core.prompts import ChatPromptTemplate\n",
"from langchain_openai import AzureChatOpenAI"
]
},
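{
"cell_type": "markdown",
"metadata": {},
"source": [
"When constructed with no arguments, as below, the chain reads its credentials from the environment. A minimal sketch of setting them, assuming the `CONTENT_SAFETY_API_KEY` and `CONTENT_SAFETY_ENDPOINT` variables read by the chain's validator (the values shown are placeholders):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Placeholder values; substitute your own resource's key and endpoint.\n",
"os.environ[\"CONTENT_SAFETY_API_KEY\"] = \"<your-api-key>\"\n",
"os.environ[\"CONTENT_SAFETY_ENDPOINT\"] = \"https://your-endpoint.azure.com/\""
]
},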
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"moderate = AzureAIContentSafetyChain()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"model = AzureChatOpenAI(\n",
" openai_api_version=os.environ[\"OPENAI_API_VERSION\"],\n",
" azure_deployment=os.environ[\"COMPLETIONS_MODEL\"],\n",
" azure_endpoint=os.environ[\"AZURE_OPENAI_ENDPOINT\"],\n",
" api_key=os.environ[\"AZURE_OPENAI_API_KEY\"],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"prompt = ChatPromptTemplate.from_messages([(\"system\", \"repeat after me: {input}\")])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Combine the objects to create a LangChain RunnablesSequence"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"moderated_chain = moderate | prompt | model"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"response = moderated_chain.invoke({\"input\": \"I like you!\"})"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'I like you!'"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"response.content"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With harmful content"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Harmful content: I hate you!\n"
]
},
{
"ename": "AzureHarmfulContentError",
"evalue": "The input has breached Azure's Content Safety Policy",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mAzureHarmfulContentError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[17], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m----> 2\u001b[0m response \u001b[38;5;241m=\u001b[39m \u001b[43mmoderated_chain\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minvoke\u001b[49m\u001b[43m(\u001b[49m\u001b[43m{\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43minput\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mI hate you!\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m}\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m AzureHarmfulContentError \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mHarmful content: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00me\u001b[38;5;241m.\u001b[39minput\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m)\n",
"File \u001b[0;32m~/.local/lib/python3.11/site-packages/langchain_core/runnables/base.py:3020\u001b[0m, in \u001b[0;36mRunnableSequence.invoke\u001b[0;34m(self, input, config, **kwargs)\u001b[0m\n\u001b[1;32m 3018\u001b[0m context\u001b[38;5;241m.\u001b[39mrun(_set_config_context, config)\n\u001b[1;32m 3019\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m i \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[0;32m-> 3020\u001b[0m \u001b[38;5;28minput\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[43mcontext\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43mstep\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minvoke\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mconfig\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3021\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 3022\u001b[0m \u001b[38;5;28minput\u001b[39m \u001b[38;5;241m=\u001b[39m context\u001b[38;5;241m.\u001b[39mrun(step\u001b[38;5;241m.\u001b[39minvoke, \u001b[38;5;28minput\u001b[39m, config)\n",
"File \u001b[0;32m~/.local/lib/python3.11/site-packages/langchain/chains/base.py:170\u001b[0m, in \u001b[0;36mChain.invoke\u001b[0;34m(self, input, config, **kwargs)\u001b[0m\n\u001b[1;32m 168\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mBaseException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 169\u001b[0m run_manager\u001b[38;5;241m.\u001b[39mon_chain_error(e)\n\u001b[0;32m--> 170\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m e\n\u001b[1;32m 171\u001b[0m run_manager\u001b[38;5;241m.\u001b[39mon_chain_end(outputs)\n\u001b[1;32m 173\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m include_run_info:\n",
"File \u001b[0;32m~/.local/lib/python3.11/site-packages/langchain/chains/base.py:160\u001b[0m, in \u001b[0;36mChain.invoke\u001b[0;34m(self, input, config, **kwargs)\u001b[0m\n\u001b[1;32m 157\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 158\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_validate_inputs(inputs)\n\u001b[1;32m 159\u001b[0m outputs \u001b[38;5;241m=\u001b[39m (\n\u001b[0;32m--> 160\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrun_manager\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrun_manager\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 161\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m new_arg_supported\n\u001b[1;32m 162\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call(inputs)\n\u001b[1;32m 163\u001b[0m )\n\u001b[1;32m 165\u001b[0m final_outputs: Dict[\u001b[38;5;28mstr\u001b[39m, Any] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprep_outputs(\n\u001b[1;32m 166\u001b[0m inputs, outputs, return_only_outputs\n\u001b[1;32m 167\u001b[0m )\n\u001b[1;32m 168\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mBaseException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n",
"File \u001b[0;32m/workspaces/langchain/docs/docs/integrations/chains/../../../../libs/community/langchain_community/chains/azure_content_safety_chain.py:161\u001b[0m, in \u001b[0;36mAzureAIContentSafetyChain._call\u001b[0;34m(self, inputs, run_manager)\u001b[0m\n\u001b[1;32m 158\u001b[0m response \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mclient\u001b[38;5;241m.\u001b[39manalyze_text(request)\n\u001b[1;32m 160\u001b[0m result \u001b[38;5;241m=\u001b[39m response\u001b[38;5;241m.\u001b[39mcategories_analysis\n\u001b[0;32m--> 161\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_detect_harmful_content\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtext\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mresult\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 163\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m {\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minput_key: output, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput_key: output}\n",
"File \u001b[0;32m/workspaces/langchain/docs/docs/integrations/chains/../../../../libs/community/langchain_community/chains/azure_content_safety_chain.py:142\u001b[0m, in \u001b[0;36mAzureAIContentSafetyChain._detect_harmful_content\u001b[0;34m(self, text, results)\u001b[0m\n\u001b[1;32m 137\u001b[0m error_str \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 138\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mThe input text contains harmful content \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 139\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124maccording to Azure OpenAI\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124ms content policy\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 140\u001b[0m )\n\u001b[1;32m 141\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39merror:\n\u001b[0;32m--> 142\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m AzureHarmfulContentError(\u001b[38;5;28minput\u001b[39m\u001b[38;5;241m=\u001b[39mtext)\n\u001b[1;32m 143\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 144\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m error_str\n",
"\u001b[0;31mAzureHarmfulContentError\u001b[0m: The input has breached Azure's Content Safety Policy"
]
}
],
"source": [
"try:\n",
" response = moderated_chain.invoke({\"input\": \"I hate you!\"})\n",
"except AzureHarmfulContentError as e:\n",
" print(f\"Harmful content: {e.input}\")\n",
" raise"
]
}
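,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The chain raises by default. Constructing it with `error=False` instead substitutes a policy-violation message for the output rather than raising; a brief sketch, invoking the moderation step on its own:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"lenient_moderate = AzureAIContentSafetyChain(error=False)\n",
"lenient_moderate.invoke({\"input\": \"I hate you!\"})"
]
}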
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
182 changes: 182 additions & 0 deletions libs/community/langchain_community/chains/azure_content_safety_chain.py
@@ -0,0 +1,182 @@
"""Pass input through an azure content safety resource."""

from typing import Any, Dict, List, Optional

from langchain.chains.base import Chain
from langchain_core.callbacks import (
    CallbackManagerForChainRun,
)
from langchain_core.exceptions import LangChainException
from langchain_core.utils import get_from_dict_or_env
from pydantic import model_validator


class AzureHarmfulContentError(LangChainException):
    """Exception raised when harmful content is detected in the input
    to a model or chain, according to Azure's content safety policy."""

    def __init__(
        self,
        input: str,
    ):
        """Constructor

        Args:
            input (str): The input given by the user to the model.
        """
        self.input = input
        self.message = "The input has breached Azure's Content Safety Policy"
        super().__init__(self.message)


class AzureAIContentSafetyChain(Chain):
    """
    A wrapper for the Azure AI Content Safety API in Runnable form.
    Allows for harmful content detection and filtering before input is
    provided to a model.

    **Note**:
        This service will filter input that shows any sign of harmful
        content; this is non-configurable.

    Attributes:
        error (bool): Whether to raise an error if harmful content is detected.
        content_safety_key (Optional[str]): API key for Azure Content Safety.
        content_safety_endpoint (Optional[str]): Endpoint URL for Azure Content Safety.

    Setup:
        1. Follow the instructions here to deploy Azure AI Content Safety:
           https://learn.microsoft.com/azure/ai-services/content-safety/overview

        2. Install ``langchain`` and ``langchain-community``, and set the
           following environment variables:

            .. code-block:: bash

                pip install -U langchain langchain-community

                export CONTENT_SAFETY_API_KEY="your-api-key"
                export CONTENT_SAFETY_ENDPOINT="https://your-endpoint.azure.com/"

    Example Usage (with safe content):
        .. code-block:: python

            from langchain_community.chains import AzureAIContentSafetyChain
            from langchain_core.prompts import ChatPromptTemplate
            from langchain_openai import AzureChatOpenAI

            moderate = AzureAIContentSafetyChain()
            prompt = ChatPromptTemplate.from_messages(
                [("system", "repeat after me: {input}")]
            )
            model = AzureChatOpenAI()

            moderated_chain = moderate | prompt | model

            moderated_chain.invoke({"input": "Hey, How are you?"})

    Example Usage (with harmful content):
        .. code-block:: python

            from langchain_community.chains import AzureAIContentSafetyChain
            from langchain_core.prompts import ChatPromptTemplate
            from langchain_openai import AzureChatOpenAI

            moderate = AzureAIContentSafetyChain()
            prompt = ChatPromptTemplate.from_messages(
                [("system", "repeat after me: {input}")]
            )
            model = AzureChatOpenAI()

            moderated_chain = moderate | prompt | model

            try:
                response = moderated_chain.invoke({"input": "I hate you!"})
            except AzureHarmfulContentError as e:
                print(f'Harmful content: {e.input}')
                raise
    """

    client: Any = None  #: :meta private:
    error: bool = True
    """Whether or not to error if bad content was found."""
    input_key: str = "input"  #: :meta private:
    output_key: str = "output"  #: :meta private:
    content_safety_key: Optional[str] = None
    content_safety_endpoint: Optional[str] = None

    @property
    def input_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return output key.

        :meta private:
        """
        return [self.output_key]

    @model_validator(mode="before")
    @classmethod
    def validate_environment(cls, values: Dict) -> Any:
        """Validate that the api key and python package exist in the environment."""
        content_safety_key = get_from_dict_or_env(
            values, "content_safety_key", "CONTENT_SAFETY_API_KEY"
        )
        content_safety_endpoint = get_from_dict_or_env(
            values, "content_safety_endpoint", "CONTENT_SAFETY_ENDPOINT"
        )
        try:
            import azure.ai.contentsafety as sdk
            from azure.core.credentials import AzureKeyCredential

            values["client"] = sdk.ContentSafetyClient(
                endpoint=content_safety_endpoint,
                credential=AzureKeyCredential(content_safety_key),
            )

        except ImportError:
            raise ImportError(
                "azure-ai-contentsafety is not installed. "
                "Run `pip install azure-ai-contentsafety` to install."
            )
        return values

    def _detect_harmful_content(self, text: str, results: Any) -> str:
        contains_harmful_content = False

        for category in results:
            if category["severity"] > 0:
                contains_harmful_content = True

        if contains_harmful_content:
            error_str = (
                "The input text contains harmful content "
                "according to Azure's content safety policy"
            )
            if self.error:
                raise AzureHarmfulContentError(input=text)
            else:
                return error_str

        return text

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        text = inputs[self.input_key]

        from azure.ai.contentsafety.models import AnalyzeTextOptions

        request = AnalyzeTextOptions(text=text)
        response = self.client.analyze_text(request)

        result = response.categories_analysis
        output = self._detect_harmful_content(text, result)

        return {self.input_key: output, self.output_key: output}
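
For reference, the harmfulness rule above treats the input as harmful when any category in the service's `categories_analysis` result has a severity greater than zero. Below is a minimal, self-contained sketch of that rule in isolation, run against hand-built dict stand-ins rather than a live `ContentSafetyClient`; the stand-ins are hypothetical, but they support the same `category["severity"]` access the chain relies on.

    # Sketch of the severity rule in _detect_harmful_content, using
    # hypothetical dict stand-ins for response.categories_analysis items.
    from typing import Any, List


    def contains_harmful_content(results: List[Any]) -> bool:
        """Mirror the chain's rule: harmful iff any category severity > 0."""
        return any(category["severity"] > 0 for category in results)


    safe_results = [
        {"category": "Hate", "severity": 0},
        {"category": "Violence", "severity": 0},
    ]
    harmful_results = [
        {"category": "Hate", "severity": 2},  # any non-zero severity trips the rule
        {"category": "Violence", "severity": 0},
    ]

    assert contains_harmful_content(safe_results) is False
    assert contains_harmful_content(harmful_results) is True

Because `_call` returns the (possibly substituted) text under both the `input` and `output` keys, the chain can sit first in a `RunnableSequence` and feed the downstream prompt unchanged when the content is safe.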