diff --git a/docs/docs/integrations/chains/azure_ai_content_safety.ipynb b/docs/docs/integrations/chains/azure_ai_content_safety.ipynb new file mode 100644 index 0000000000000..99664c74a7494 --- /dev/null +++ b/docs/docs/integrations/chains/azure_ai_content_safety.ipynb @@ -0,0 +1,194 @@ +{ + "cells": [
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# `AzureAIContentSafetyChain`" + ] + },
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "> [Azure AI Content Safety Chain](https://learn.microsoft.com/python/api/overview/azure/ai-contentsafety-readme?view=azure-python) is a wrapper around\n", + "> the Azure AI Content Safety service, implemented in LangChain using the LangChain \n", + "> [Runnable](https://python.langchain.com/docs/how_to/lcel_cheatsheet/) base class to allow use in a `RunnableSequence`.\n", + "\n", + "The class can be used to stop or filter content based on the Azure AI Content Safety policy." + ] + },
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example Usage" + ] + },
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Get the required imports. Here we use a `ChatPromptTemplate` for convenience and `AzureChatOpenAI`; however, any LangChain-integrated model will work in a chain." + ] + },
+ { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "from langchain_community.chains.azure_content_safety_chain import (\n", + " AzureAIContentSafetyChain,\n", + " AzureHarmfulContentError,\n", + ")\n", + "from langchain_core.prompts import ChatPromptTemplate\n", + "from langchain_openai import AzureChatOpenAI" + ] + },
+ { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "moderate = AzureAIContentSafetyChain()" + ] + },
+ { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "model = AzureChatOpenAI(\n", + " openai_api_version=os.environ[\"OPENAI_API_VERSION\"],\n", + " azure_deployment=os.environ[\"COMPLETIONS_MODEL\"],\n", + " azure_endpoint=os.environ[\"AZURE_OPENAI_ENDPOINT\"],\n", + " api_key=os.environ[\"AZURE_OPENAI_API_KEY\"],\n", + ")" + ] + },
+ { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "prompt = ChatPromptTemplate.from_messages([(\"system\", \"repeat after me: {input}\")])" + ] + },
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Combine the objects to create a LangChain `RunnableSequence`." + ] + },
+ { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "moderated_chain = moderate | prompt | model" + ] + },
+ { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "response = moderated_chain.invoke({\"input\": \"I like you!\"})" + ] + },
+ { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'I like you!'" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "response.content" + ] + },
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "With harmful content:" + ] + },
+ { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Harmful content: I hate you!\n" + ] + }, + { + "ename": "AzureHarmfulContentError", + "evalue": "The input has
breached Azure's Content Safety Policy", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAzureHarmfulContentError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[17], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m----> 2\u001b[0m response \u001b[38;5;241m=\u001b[39m \u001b[43mmoderated_chain\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minvoke\u001b[49m\u001b[43m(\u001b[49m\u001b[43m{\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43minput\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mI hate you!\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m}\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m AzureHarmfulContentError \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mHarmful content: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00me\u001b[38;5;241m.\u001b[39minput\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m)\n", + "File \u001b[0;32m~/.local/lib/python3.11/site-packages/langchain_core/runnables/base.py:3020\u001b[0m, in \u001b[0;36mRunnableSequence.invoke\u001b[0;34m(self, input, config, **kwargs)\u001b[0m\n\u001b[1;32m 3018\u001b[0m context\u001b[38;5;241m.\u001b[39mrun(_set_config_context, config)\n\u001b[1;32m 3019\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m i \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[0;32m-> 3020\u001b[0m \u001b[38;5;28minput\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[43mcontext\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43mstep\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minvoke\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mconfig\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3021\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 3022\u001b[0m \u001b[38;5;28minput\u001b[39m \u001b[38;5;241m=\u001b[39m context\u001b[38;5;241m.\u001b[39mrun(step\u001b[38;5;241m.\u001b[39minvoke, \u001b[38;5;28minput\u001b[39m, config)\n", + "File \u001b[0;32m~/.local/lib/python3.11/site-packages/langchain/chains/base.py:170\u001b[0m, in \u001b[0;36mChain.invoke\u001b[0;34m(self, input, config, **kwargs)\u001b[0m\n\u001b[1;32m 168\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mBaseException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 169\u001b[0m run_manager\u001b[38;5;241m.\u001b[39mon_chain_error(e)\n\u001b[0;32m--> 170\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m e\n\u001b[1;32m 171\u001b[0m run_manager\u001b[38;5;241m.\u001b[39mon_chain_end(outputs)\n\u001b[1;32m 173\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m include_run_info:\n", + "File \u001b[0;32m~/.local/lib/python3.11/site-packages/langchain/chains/base.py:160\u001b[0m, in \u001b[0;36mChain.invoke\u001b[0;34m(self, input, config, **kwargs)\u001b[0m\n\u001b[1;32m 157\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 158\u001b[0m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_validate_inputs(inputs)\n\u001b[1;32m 159\u001b[0m outputs \u001b[38;5;241m=\u001b[39m (\n\u001b[0;32m--> 160\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrun_manager\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrun_manager\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 161\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m new_arg_supported\n\u001b[1;32m 162\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call(inputs)\n\u001b[1;32m 163\u001b[0m )\n\u001b[1;32m 165\u001b[0m final_outputs: Dict[\u001b[38;5;28mstr\u001b[39m, Any] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprep_outputs(\n\u001b[1;32m 166\u001b[0m inputs, outputs, return_only_outputs\n\u001b[1;32m 167\u001b[0m )\n\u001b[1;32m 168\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mBaseException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n", + "File \u001b[0;32m/workspaces/langchain/docs/docs/integrations/chains/../../../../libs/community/langchain_community/chains/azure_content_safety_chain.py:161\u001b[0m, in \u001b[0;36mAzureAIContentSafetyChain._call\u001b[0;34m(self, inputs, run_manager)\u001b[0m\n\u001b[1;32m 158\u001b[0m response \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mclient\u001b[38;5;241m.\u001b[39manalyze_text(request)\n\u001b[1;32m 160\u001b[0m result \u001b[38;5;241m=\u001b[39m response\u001b[38;5;241m.\u001b[39mcategories_analysis\n\u001b[0;32m--> 161\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_detect_harmful_content\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtext\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mresult\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 163\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m {\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minput_key: output, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput_key: output}\n", + "File \u001b[0;32m/workspaces/langchain/docs/docs/integrations/chains/../../../../libs/community/langchain_community/chains/azure_content_safety_chain.py:142\u001b[0m, in \u001b[0;36mAzureAIContentSafetyChain._detect_harmful_content\u001b[0;34m(self, text, results)\u001b[0m\n\u001b[1;32m 137\u001b[0m error_str \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 138\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mThe input text contains harmful content \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 139\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124maccording to Azure OpenAI\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124ms content policy\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 140\u001b[0m )\n\u001b[1;32m 141\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39merror:\n\u001b[0;32m--> 142\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m AzureHarmfulContentError(\u001b[38;5;28minput\u001b[39m\u001b[38;5;241m=\u001b[39mtext)\n\u001b[1;32m 143\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 144\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m error_str\n", + "\u001b[0;31mAzureHarmfulContentError\u001b[0m: The input has breached Azure's Content Safety Policy" + ] + } + ], + "source": [ + "try:\n", + " response = 
moderated_chain.invoke({\"input\": \"I hate you!\"})\n", + "except AzureHarmfulContentError as e:\n", + " print(f\"Harmful content: {e.input}\")\n", + " raise" + ] + } + ],
+ "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.4" + } + },
+ "nbformat": 4, + "nbformat_minor": 2 +}
diff --git a/libs/community/langchain_community/chains/azure_content_safety_chain.py b/libs/community/langchain_community/chains/azure_content_safety_chain.py new file mode 100644 index 0000000000000..b401d6bd669c7 --- /dev/null +++ b/libs/community/langchain_community/chains/azure_content_safety_chain.py @@ -0,0 +1,191 @@ +"""Pass input through an Azure AI Content Safety resource."""
+ +from typing import Any, Dict, List, Optional + +from langchain.chains.base import Chain +from langchain_core.callbacks import ( + CallbackManagerForChainRun, +) +from langchain_core.exceptions import LangChainException +from langchain_core.utils import get_from_dict_or_env +from pydantic import model_validator
+ + +class AzureHarmfulContentError(LangChainException): + """Exception raised when harmful content is detected + in the input to a model or chain, according to Azure's + content safety policy."""
+ + def __init__( + self, + input: str, + ): + """Constructor + + Args: + input (str): The input given by the user to the model. + """ + self.input = input + self.message = "The input has breached Azure's Content Safety Policy" + super().__init__(self.message)
+ + +class AzureAIContentSafetyChain(Chain): + """ + A wrapper for the Azure AI Content Safety API in a Runnable form. + Allows for harmful content detection and filtering before input is + provided to a model.
+ + **Note**: + This service will filter input that shows any sign of harmful content; + this is not configurable.
+ + Attributes: + error (bool): Whether to raise an error if harmful content is detected. + content_safety_key (Optional[str]): API key for Azure Content Safety. + content_safety_endpoint (Optional[str]): Endpoint URL for Azure Content Safety.
+ + Setup: + 1. Follow the instructions here to deploy Azure AI Content Safety: + https://learn.microsoft.com/azure/ai-services/content-safety/overview
+ + 2. Install ``langchain`` and ``langchain-community``, then set the following + environment variables:
+ + .. code-block:: bash + + pip install -U langchain langchain-community + + export CONTENT_SAFETY_API_KEY="your-api-key" + export CONTENT_SAFETY_ENDPOINT="https://your-endpoint.azure.com/"
+ + + Example Usage (with safe content): + .. code-block:: python + + from langchain_community.chains.azure_content_safety_chain import ( + AzureAIContentSafetyChain, + ) + from langchain_core.prompts import ChatPromptTemplate + from langchain_openai import AzureChatOpenAI + + moderate = AzureAIContentSafetyChain() + prompt = ChatPromptTemplate.from_messages([("system", + "repeat after me: {input}")]) + model = AzureChatOpenAI() + + moderated_chain = moderate | prompt | model + + moderated_chain.invoke({"input": "Hey, how are you?"})
+ + Example Usage (with harmful content): + ..
code-block:: python + + from langchain_community.chains.azure_content_safety_chain import ( + AzureAIContentSafetyChain, + AzureHarmfulContentError, + ) + from langchain_core.prompts import ChatPromptTemplate + from langchain_openai import AzureChatOpenAI + + moderate = AzureAIContentSafetyChain() + prompt = ChatPromptTemplate.from_messages([("system", + "repeat after me: {input}")]) + model = AzureChatOpenAI() + + moderated_chain = moderate | prompt | model + + try: + response = moderated_chain.invoke({"input": "I hate you!"}) + except AzureHarmfulContentError as e: + print(f'Harmful content: {e.input}') + raise + """
+ + client: Any = None #: :meta private: + error: bool = True + """Whether to raise an error if harmful content is detected.""" + input_key: str = "input" #: :meta private: + output_key: str = "output" #: :meta private: + content_safety_key: Optional[str] = None + content_safety_endpoint: Optional[str] = None
+ + @property + def input_keys(self) -> List[str]: + """Expect input key. + + :meta private: + """ + return [self.input_key]
+ + @property + def output_keys(self) -> List[str]: + """Return output key. + + :meta private: + """ + return [self.output_key]
+ + @model_validator(mode="before") + @classmethod + def validate_environment(cls, values: Dict) -> Any: + """Validate that the API key and Python package exist in the environment.""" + content_safety_key = get_from_dict_or_env( + values, "content_safety_key", "CONTENT_SAFETY_API_KEY" + ) + content_safety_endpoint = get_from_dict_or_env( + values, "content_safety_endpoint", "CONTENT_SAFETY_ENDPOINT" + ) + try: + import azure.ai.contentsafety as sdk + from azure.core.credentials import AzureKeyCredential + + values["client"] = sdk.ContentSafetyClient( + endpoint=content_safety_endpoint, + credential=AzureKeyCredential(content_safety_key), + ) + + except ImportError: + raise ImportError( + "azure-ai-contentsafety is not installed. " + "Run `pip install azure-ai-contentsafety` to install."
+ ) + return values
+ + def _detect_harmful_content(self, text: str, results: Any) -> str: + contains_harmful_content = False + + # Any category with a non-zero severity counts as harmful content. + for category in results: + if category["severity"] > 0: + contains_harmful_content = True + break
+ + if contains_harmful_content: + error_str = ( + "The input text contains harmful content " + "according to Azure's content safety policy" + ) + if self.error: + raise AzureHarmfulContentError(input=text) + else: + return error_str + + return text
+ + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: + text = inputs[self.input_key] + + from azure.ai.contentsafety.models import AnalyzeTextOptions + + request = AnalyzeTextOptions(text=text) + response = self.client.analyze_text(request) + + result = response.categories_analysis + output = self._detect_harmful_content(text, result) + + return {self.input_key: output, self.output_key: output}
diff --git a/libs/community/tests/unit_tests/chains/test_azure_content_safety.py b/libs/community/tests/unit_tests/chains/test_azure_content_safety.py new file mode 100644 index 0000000000000..fd83aefcc920c --- /dev/null +++ b/libs/community/tests/unit_tests/chains/test_azure_content_safety.py @@ -0,0 +1,72 @@ +"""Tests for the Azure AI Content Safety Chain."""
+ +from typing import Any + +import pytest + +from langchain_community.chains.azure_content_safety_chain import ( + AzureAIContentSafetyChain, + AzureHarmfulContentError, +)
+ + +@pytest.mark.requires("azure.ai.contentsafety") +def test_content_safety(mocker: Any) -> None: + mocker.patch("azure.ai.contentsafety.ContentSafetyClient", autospec=True) + mocker.patch("azure.core.credentials.AzureKeyCredential", autospec=True) + + key = "key" + endpoint = "endpoint" + + chain = AzureAIContentSafetyChain( + content_safety_key=key, content_safety_endpoint=endpoint + ) + assert chain.content_safety_key == key + assert chain.content_safety_endpoint == endpoint
+ + +@pytest.mark.requires("azure.ai.contentsafety") +def test_raise_error_when_harmful_content_detected(mocker: Any) -> None: + key = "key" + endpoint = "endpoint" + + mocker.patch("azure.core.credentials.AzureKeyCredential", autospec=True) + mocker.patch("azure.ai.contentsafety.ContentSafetyClient", autospec=True) + chain = AzureAIContentSafetyChain( + content_safety_key=key, content_safety_endpoint=endpoint, error=True + ) + + mock_content_client = mocker.Mock() + mock_content_client.analyze_text.return_value.categories_analysis = [ + {"Category": "Harm", "severity": 1} + ] + + chain.client = mock_content_client + + text = "This text contains harmful content" + with pytest.raises(AzureHarmfulContentError): + chain._call({chain.input_key: text})
+ + +@pytest.mark.requires("azure.ai.contentsafety") +def test_no_harmful_content_detected(mocker: Any) -> None: + key = "key" + endpoint = "endpoint" + + mocker.patch("azure.core.credentials.AzureKeyCredential", autospec=True) + mocker.patch("azure.ai.contentsafety.ContentSafetyClient", autospec=True) + chain = AzureAIContentSafetyChain( + content_safety_key=key, content_safety_endpoint=endpoint, error=True + ) + + mock_content_client = mocker.Mock() + mock_content_client.analyze_text.return_value.categories_analysis = [ + {"Category": "Harm", "severity": 0} + ] + + chain.client = mock_content_client + + text = "This text contains no harmful content" + output = chain._call({chain.input_key: text}) + + assert output[chain.output_key] == text
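A minimal usage sketch of the non-raising mode, not part of the diff above: with ``error=False``, ``_detect_harmful_content`` returns the warning string instead of raising ``AzureHarmfulContentError``, so the message becomes the chain's output. This assumes the module in the diff is importable and ``azure-ai-contentsafety`` is installed; the key and endpoint values are placeholders.

.. code-block:: python

    from langchain_community.chains.azure_content_safety_chain import (
        AzureAIContentSafetyChain,
    )

    # error=False: harmful input is replaced with a warning string instead of
    # raising AzureHarmfulContentError, so downstream Runnables still execute.
    moderate = AzureAIContentSafetyChain(
        error=False,
        content_safety_key="your-api-key",  # placeholder
        content_safety_endpoint="https://your-endpoint.azure.com/",  # placeholder
    )

    result = moderate.invoke({"input": "I hate you!"})
    # For flagged input, result["output"] holds the warning string rather than
    # the original text; safe input passes through unchanged.
    print(result["output"])

Note that in a full sequence such as ``moderate | prompt | model``, the warning string would be substituted for ``{input}`` in the prompt, which may or may not be desirable; the default ``error=True`` fails fast instead, as the notebook above demonstrates.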