Update class to use new v3 pydantic validation methods
Sheepsta300 committed Oct 4, 2024
1 parent 28fb0e7 commit 61a815f
Showing 1 changed file with 24 additions and 50 deletions.
@@ -1,11 +1,12 @@
 from __future__ import annotations
 
 import logging
 import os
 from typing import Any, Dict, Optional
 
 from langchain_core.callbacks import CallbackManagerForToolRun
 from langchain_core.tools import BaseTool
 from langchain_core.utils import get_from_dict_or_env
+from pydantic import model_validator
 
 logger = logging.getLogger(__name__)
 
@@ -28,16 +29,6 @@ class AzureContentSafetyTextTool(BaseTool):
         content_safety_client (Any):
             An instance of the Azure Content Safety Client used for making API
             requests.
-
-    Methods:
-        _sentiment_analysis(text: str) -> Dict:
-            Analyzes the provided text to assess its sentiment and safety,
-            returning the analysis results.
-        _run(query: str,
-            run_manager: Optional[CallbackManagerForToolRun] = None) -> str:
-            Uses the tool to analyze the given query and returns the result.
-            Raises a RuntimeError if an exception occurs.
-
     """
 
     content_safety_key: str = ""  #: :meta private:
@@ -52,41 +43,20 @@ class AzureContentSafetyTextTool(BaseTool):
         "Input must be text (str)."
     )
 
-    def __init__(
-        self,
-        *,
-        content_safety_key: Optional[str] = None,
-        content_safety_endpoint: Optional[str] = None,
-    ) -> None:
-        """
-        Initialize the AzureContentSafetyTextTool with the given API key and endpoint.
-
-        If not provided, the API key and endpoint are fetched from environment
-        variables.
-
-        Args:
-            content_safety_key (Optional[str]):
-                The API key for Azure Content Safety API. If not provided, it will
-                be fetched from the environment variable 'CONTENT_SAFETY_API_KEY'.
-            content_safety_endpoint (Optional[str]):
-                The endpoint URL for Azure Content Safety API. If not provided, it
-                will be fetched from the environment variable
-                'CONTENT_SAFETY_ENDPOINT'.
-
-        Raises:
-            ImportError: If the 'azure-ai-contentsafety' package is not installed.
-            ValueError: If API key or endpoint is not provided and environment
-                variables are missing.
-        """
-        content_safety_key = content_safety_key or os.environ["CONTENT_SAFETY_API_KEY"]
-        content_safety_endpoint = (
-            content_safety_endpoint or os.environ["CONTENT_SAFETY_ENDPOINT"]
+    @model_validator(mode="before")
+    @classmethod
+    def validate_environment(cls, values: Dict) -> Any:
+        content_safety_key = get_from_dict_or_env(
+            values, "content_safety_key", "CONTENT_SAFETY_API_KEY"
+        )
+        content_safety_endpoint = get_from_dict_or_env(
+            values, "content_safety_endpoint", "CONTENT_SAFETY_ENDPOINT"
         )
         try:
             import azure.ai.contentsafety as sdk
             from azure.core.credentials import AzureKeyCredential
 
-            content_safety_client = sdk.ContentSafetyClient(
+            values["content_safety_client"] = sdk.ContentSafetyClient(
                 endpoint=content_safety_endpoint,
                 credential=AzureKeyCredential(content_safety_key),
             )
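
The `model_validator(mode="before")` hook used above runs on the raw keyword
dict before field validation, which is how it can read caller- or
environment-supplied values and inject the constructed `content_safety_client`
into `values`. A minimal standalone sketch of the same pattern, using only
pydantic v2 (the class and field names here are hypothetical):

    from typing import Any, Dict

    from pydantic import BaseModel, model_validator

    class ApiClientHolder(BaseModel):
        api_key: str = ""
        client: Any = None

        @model_validator(mode="before")
        @classmethod
        def validate_environment(cls, values: Dict) -> Any:
            # Runs before field validation, so derived values can be filled
            # in here and then validated like any other input.
            values["api_key"] = values.get("api_key") or "key-from-env"
            values["client"] = object()  # stand-in for a real SDK client
            return values

    holder = ApiClientHolder()
    assert holder.client is not None
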
@@ -96,15 +66,12 @@ def __init__(
                 "azure-ai-contentsafety is not installed. "
                 "Run `pip install azure-ai-contentsafety` to install."
             )
-        super().__init__(
-            content_safety_key=content_safety_key,
-            content_safety_endpoint=content_safety_endpoint,
-            content_safety_client=content_safety_client,
-        )
+        return values
 
-    def _sentiment_analysis(self, text: str) -> Dict:
+    def _detect_harmful_content(self, text: str) -> list:
         """
-        Perform sentiment analysis on the provided text.
+        Detect harmful content in the provided text.
 
         This method uses the Azure Content Safety Client to analyze the text and
         determine its sentiment and safety categories.
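
The `response.categories_analysis` attribute referenced in the next hunk comes
from the `azure-ai-contentsafety` SDK's analyze-text call. A sketch of the
underlying request, assuming SDK 1.x and placeholder credentials:

    from azure.ai.contentsafety import ContentSafetyClient
    from azure.ai.contentsafety.models import AnalyzeTextOptions
    from azure.core.credentials import AzureKeyCredential

    client = ContentSafetyClient(
        endpoint="https://<resource>.cognitiveservices.azure.com/",
        credential=AzureKeyCredential("<your-key>"),
    )
    # Each entry in categories_analysis carries a category name and severity.
    response = client.analyze_text(AnalyzeTextOptions(text="text to screen"))
    for item in response.categories_analysis:
        print(item.category, item.severity)
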
@@ -122,11 +89,17 @@ def _sentiment_analysis(self, text: str) -> Dict:
         result = response.categories_analysis
         return result
 
+    def _format_response(self, result: list) -> str:
+        formatted_result = ""
+        for c in result:
+            formatted_result += f"{c.category}: {c.severity}\n"
+        return formatted_result
+
     def _run(
         self,
         query: str,
         run_manager: Optional[CallbackManagerForToolRun] = None,
-    ) -> Dict:
+    ) -> str:
         """
         Analyze the given query using the tool.
@@ -146,6 +119,7 @@ def _run(
             RuntimeError: If an error occurs while running the tool.
         """
         try:
-            return self._sentiment_analysis(query)
+            result = self._detect_harmful_content(query)
+            return self._format_response(result)
         except Exception as e:
             raise RuntimeError(f"Error while running AzureContentSafetyTextTool: {e}")
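
With the validator in place, the tool can be constructed with no arguments as
long as both environment variables are set. A hedged usage sketch (the import
path is an assumption and may differ in this repository; a real Content Safety
resource is required):

    import os

    # Hypothetical import path for the class shown in this diff.
    from langchain_community.tools.azure_ai_services import (
        AzureContentSafetyTextTool,
    )

    os.environ["CONTENT_SAFETY_API_KEY"] = "<your-key>"
    os.environ["CONTENT_SAFETY_ENDPOINT"] = (
        "https://<resource>.cognitiveservices.azure.com/"
    )

    tool = AzureContentSafetyTextTool()  # validate_environment builds the client
    print(tool.invoke("text to screen"))
    # Expected output shape, one "<Category>: <severity>" line per category:
    # Hate: 0
    # SelfHarm: 0
    # Sexual: 0
    # Violence: 0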
