From 61a815f0f60782a29fcb3632f95cf79efde7d723 Mon Sep 17 00:00:00 2001
From: Sheepsta300 <128811766+Sheepsta300@users.noreply.github.com>
Date: Fri, 4 Oct 2024 14:16:20 +1300
Subject: [PATCH] Update class to use new v3 `pydantic` validation methods

---
 .../content_safety.py                         | 74 ++++++-------------
 1 file changed, 24 insertions(+), 50 deletions(-)
 rename libs/community/langchain_community/tools/{azure_cognitive_services => azure_ai_services}/content_safety.py (61%)

diff --git a/libs/community/langchain_community/tools/azure_cognitive_services/content_safety.py b/libs/community/langchain_community/tools/azure_ai_services/content_safety.py
similarity index 61%
rename from libs/community/langchain_community/tools/azure_cognitive_services/content_safety.py
rename to libs/community/langchain_community/tools/azure_ai_services/content_safety.py
index 9964760e7ccc5..5c60c3291977b 100644
--- a/libs/community/langchain_community/tools/azure_cognitive_services/content_safety.py
+++ b/libs/community/langchain_community/tools/azure_ai_services/content_safety.py
@@ -1,11 +1,12 @@
 from __future__ import annotations
 
 import logging
-import os
 from typing import Any, Dict, Optional
 
 from langchain_core.callbacks import CallbackManagerForToolRun
 from langchain_core.tools import BaseTool
+from langchain_core.utils import get_from_dict_or_env
+from pydantic import model_validator
 
 logger = logging.getLogger(__name__)
 
@@ -28,16 +29,6 @@ class AzureContentSafetyTextTool(BaseTool):
         content_safety_client (Any):
             An instance of the Azure Content Safety Client used for making
             API requests.
-
-    Methods:
-        _sentiment_analysis(text: str) -> Dict:
-            Analyzes the provided text to assess its sentiment and safety,
-            returning the analysis results.
-
-        _run(query: str,
-            run_manager: Optional[CallbackManagerForToolRun] = None) -> str:
-            Uses the tool to analyze the given query and returns the result.
-            Raises a RuntimeError if an exception occurs.
     """
 
     content_safety_key: str = ""  #: :meta private:
@@ -52,41 +43,20 @@ class AzureContentSafetyTextTool(BaseTool):
         "Input must be text (str)."
     )
 
-    def __init__(
-        self,
-        *,
-        content_safety_key: Optional[str] = None,
-        content_safety_endpoint: Optional[str] = None,
-    ) -> None:
-        """
-        Initialize the AzureContentSafetyTextTool with the given API key and endpoint.
-
-        If not provided, the API key and endpoint are fetched from environment
-        variables.
-
-        Args:
-            content_safety_key (Optional[str]):
-                The API key for Azure Content Safety API. If not provided, it will
-                be fetched from the environment variable 'CONTENT_SAFETY_API_KEY'.
-            content_safety_endpoint (Optional[str]):
-                The endpoint URL for Azure Content Safety API. If not provided, it
-                will be fetched from the environment variable
-                'CONTENT_SAFETY_ENDPOINT'.
-
-        Raises:
-            ImportError: If the 'azure-ai-contentsafety' package is not installed.
-            ValueError: If API key or endpoint is not provided and environment
-                variables are missing.
-        """
-        content_safety_key = content_safety_key or os.environ["CONTENT_SAFETY_API_KEY"]
-        content_safety_endpoint = (
-            content_safety_endpoint or os.environ["CONTENT_SAFETY_ENDPOINT"]
+    @model_validator(mode="before")
+    @classmethod
+    def validate_environment(cls, values: Dict) -> Any:
+        content_safety_key = get_from_dict_or_env(
+            values, "content_safety_key", "CONTENT_SAFETY_API_KEY"
+        )
+        content_safety_endpoint = get_from_dict_or_env(
+            values, "content_safety_endpoint", "CONTENT_SAFETY_ENDPOINT"
        )
         try:
             import azure.ai.contentsafety as sdk
             from azure.core.credentials import AzureKeyCredential
 
-            content_safety_client = sdk.ContentSafetyClient(
+            values["content_safety_client"] = sdk.ContentSafetyClient(
                 endpoint=content_safety_endpoint,
                 credential=AzureKeyCredential(content_safety_key),
             )
@@ -96,15 +66,12 @@ def __init__(
                 "azure-ai-contentsafety is not installed. "
                 "Run `pip install azure-ai-contentsafety` to install."
             )
-        super().__init__(
-            content_safety_key=content_safety_key,
-            content_safety_endpoint=content_safety_endpoint,
-            content_safety_client=content_safety_client,
-        )
 
-    def _sentiment_analysis(self, text: str) -> Dict:
+        return values
+
+    def _detect_harmful_content(self, text: str) -> list:
         """
-        Perform sentiment analysis on the provided text.
+        Detect harmful content in the provided text.
 
         This method uses the Azure Content Safety Client to analyze the text
         and determine its sentiment and safety categories.
@@ -122,11 +89,17 @@ def _sentiment_analysis(self, text: str) -> Dict:
         result = response.categories_analysis
         return result
 
+    def _format_response(self, result: list) -> str:
+        formatted_result = ""
+        for c in result:
+            formatted_result += f"{c.category}: {c.severity}\n"
+        return formatted_result
+
     def _run(
         self,
         query: str,
         run_manager: Optional[CallbackManagerForToolRun] = None,
-    ) -> Dict:
+    ) -> str:
         """
         Analyze the given query using the tool.
 
@@ -146,6 +119,7 @@ def _run(
             RuntimeError: If an error occurs while running the tool.
         """
         try:
-            return self._sentiment_analysis(query)
+            result = self._detect_harmful_content(query)
+            return self._format_response(result)
         except Exception as e:
             raise RuntimeError(f"Error while running AzureContentSafetyTextTool: {e}")