Skip to content

Commit

Permalink
Add unit tests and required dependencies
Browse files Browse the repository at this point in the history
  • Loading branch information
Sheepsta300 committed Oct 8, 2024
1 parent 61a815f commit 2afbff5
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 1 deletion.
1 change: 1 addition & 0 deletions libs/community/extended_testing_deps.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ anthropic>=0.3.11,<0.4
arxiv>=1.4,<2
assemblyai>=0.17.0,<0.18
atlassian-python-api>=3.36.0,<4
azure-ai-contentsafety>=1.0.0
azure-ai-documentintelligence>=1.0.0b1,<2
azure-identity>=1.15.0,<2
azure-search-documents==11.4.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def _detect_harmful_content(self, text: str) -> list:
def _format_response(self, result: list) -> str:
formatted_result = ""
for c in result:
formatted_result += f"{c.category}: {c.severity}\n"
formatted_result += f"{c['category']}: {c['severity']}\n"
return formatted_result

def _run(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""Tests for the Azure AI Content Safety Text Tool."""

from typing import Any

import pytest

from langchain_community.tools.azure_ai_services.content_safety import (
AzureContentSafetyTextTool,
)


@pytest.mark.requires("azure.ai.contentsafety")
def test_content_safety(mocker: Any) -> None:
mocker.patch("azure.ai.contentsafety.ContentSafetyClient", autospec=True)
mocker.patch("azure.core.credentials.AzureKeyCredential", autospec=True)

key = "key"
endpoint = "endpoint"

tool = AzureContentSafetyTextTool(
content_safety_key=key, content_safety_endpoint=endpoint
)
assert tool.content_safety_key == key
assert tool.content_safety_endpoint == endpoint


@pytest.mark.requires("azure.ai.contentsafety")
def test_harmful_content_detected(mocker: Any) -> None:
key = "key"
endpoint = "endpoint"

mocker.patch("azure.core.credentials.AzureKeyCredential", autospec=True)
mocker.patch("azure.ai.contentsafety.ContentSafetyClient", autospec=True)
tool = AzureContentSafetyTextTool(
content_safety_key=key, content_safety_endpoint=endpoint
)

mock_content_client = mocker.Mock()
mock_content_client.analyze_text.return_value.categories_analysis = [
{"category": "Harm", "severity": 1}
]

tool.content_safety_client = mock_content_client

input = "This text contains harmful content"
output = "Harm: 1\n"

result = tool._run(input)
assert result == output


@pytest.mark.requires("azure.ai.contentsafety")
def test_no_harmful_content_detected(mocker: Any) -> None:
key = "key"
endpoint = "endpoint"

tool = AzureContentSafetyTextTool(
content_safety_key=key, content_safety_endpoint=endpoint
)

mock_content_client = mocker.Mock()
mock_content_client.analyze_text.return_value.categories_analysis = [
{"category": "Harm", "severity": 0}
]

tool.content_safety_client = mock_content_client

input = "This text contains harmful content"
output = "Harm: 0\n"

result = tool._run(input)
assert result == output

0 comments on commit 2afbff5

Please sign in to comment.