Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

changes to tool as per review advice #12

Open
wants to merge 20 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
from __future__ import annotations

import logging
from typing import Any, Dict, Optional

from langchain_core.callbacks import CallbackManagerForToolRun
from langchain_core.tools import BaseTool
from langchain_core.utils import get_from_dict_or_env
from pydantic import model_validator

logger = logging.getLogger(__name__)


class AzureAITextTranslateTool(BaseTool):
"""
A tool that interacts with the Azure Translator API using the SDK.

Attributes:
text_translation_key (str):
The API key for the Azure Translator service.
text_translation_endpoint (str):
The endpoint for the Azure Translator service.
region (str): The Azure region where the Translator service is hosted.
translate_client (TextTranslationClient):
The Azure Translator client initialized with the credentials.
default_language (str):
The default language for translation (default is 'en').

This tool queries the Azure Translator API to translate text between languages.
Input must be text (str), and the 'to_language'
parameter must be a two-letter language code (e.g., 'es', 'it', 'de').
"""

text_translation_key: Optional[str] = None
text_translation_endpoint: Optional[str] = None
region: Optional[str] = None

translate_client: Any = None
default_language: str = "en"

name: str = "azure_translator_tool"
description: str = (
"A wrapper around Azure Translator API. Useful for translating text between "
"languages. Input must be text (str), and the 'to_language'"
" parameter must be a two-letter language code (str), e.g., 'es', 'it', 'de'."
)

@model_validator(mode="before")
@classmethod
def validate_environment(cls, values: Dict) -> Any:
"""
Validate that the required environment variables are set, and set up
the client.
"""

try:
from azure.ai.translation.text import TextTranslationClient
from azure.core.credentials import AzureKeyCredential
except ImportError:
raise ImportError(
"Azure Translator API packages are not installed. "
"Please install 'azure-ai-translation-text' and 'azure-core'."
)

text_translation_key = get_from_dict_or_env(
values, "text_translation_key", "AZURE_TRANSLATE_API_KEY"
)
text_translation_endpoint = get_from_dict_or_env(
values, "text_translation_endpoint", "AZURE_TRANSLATE_ENDPOINT"
)
region = get_from_dict_or_env(values, "region", "REGION")

values["translate_client"] = TextTranslationClient(
endpoint=text_translation_endpoint,
credential=AzureKeyCredential(text_translation_key),
region=region,
)

return values

def _translate_text(self, text: str, to_language: str = "en") -> str:
"""
Perform text translation using the Azure Translator API.
"""
if not text:
raise ValueError("Input text for translation is empty.")

body = [{"Text": text}]
try:
response = self.translate_client.translate(
body=body, to_language=[to_language]
)
return response[0].translations[0].text
except Exception as e:
logger.error("Translation failed: %s", e)
raise RuntimeError(f"Translation failed: {e}")

from azure.ai.translation.text import TextTranslationClient

self.translate_client: TextTranslationClient

body = [{"Text": text}]
try:
response = self.translate_client.translate(
body=body, to_language=[to_language]
)
return response[0].translations[0].text
except Exception:
raise

def _run(
self,
query: str,
run_manager: Optional[CallbackManagerForToolRun] = None,
to_language: str = "en",
) -> str:
"""
Run the tool to perform translation.

Args:
query (str): The text to be translated.
run_manager (Optional[CallbackManagerForToolRun]):
A callback manager for tracking the tool run.
to_language (str): The target language for translation.

Returns:
str: The translated text.
"""
# Ensure only the text (not the full query dictionary)
# is passed to the translation function
text_to_translate = query
return self._translate_text(text_to_translate, to_language)
2 changes: 1 addition & 1 deletion libs/community/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -165,4 +165,4 @@ develop = true

[tool.poetry.group.typing.dependencies.langchain]
path = "../langchain"
develop = true
develop = true
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -87,4 +87,4 @@ extend-exclude = [
# These files were failing the listed rules at the time ruff was adopted for notebooks.
# Don't require them to change at once, though we should look into them eventually.
"cookbook/gymnasium_agent_simulation.ipynb" = ["F821"]
"docs/docs/integrations/document_loaders/tensorflow_datasets.ipynb" = ["F821"]
"docs/docs/integrations/document_loaders/tensorflow_datasets.ipynb" = ["F821"]
70 changes: 70 additions & 0 deletions test_azure_translate_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import unittest
from unittest.mock import patch, MagicMock
from libs.community.langchain_community.tools.azure_ai_services.text_translate import (
AzureAITextTranslateTool,
)


class TestAzureTranslateTool(unittest.TestCase):
@patch(
"azure.ai.translation.text.TextTranslationClient"
) # Mock the translation client
def setUp(self, mock_translation_client):
# Set up a mock translation client
self.mock_translation_client = mock_translation_client
self.mock_translation_instance = mock_translation_client.return_value

# Create an instance of the AzureTranslateTool
self.tool = AzureAITextTranslateTool()

# Mock environment variables
self.tool.text_translation_key = "fake_api_key"
self.tool.text_translation_endpoint = "https://fake.endpoint.com"
self.tool.region = "fake_region"

def test_translate_success(self):
# Mock the translation API response
mock_response = MagicMock()
mock_response.translations = [MagicMock(text="Hola")]
self.mock_translation_instance.translate.return_value = [mock_response]

# Call the translate function
result = self.tool._translate_text("Hello", to_language="es")

# Assert the result is as expected
self.assertEqual(result, "Hola")
self.mock_translation_instance.translate.assert_called_once_with(
body=[{"Text": "Hello"}], to_language=["es"]
)

def test_empty_text(self):
# Test that an empty input raises a ValueError
with self.assertRaises(ValueError) as context:
self.tool._translate_text("", to_language="es")

self.assertEqual(str(context.exception), "Input text for translation is empty.")

def test_api_failure(self):
# Simulate an API failure
self.mock_translation_instance.translate.side_effect = Exception("API failure")

# Test that an exception is raised and handled properly
with self.assertRaises(RuntimeError) as context:
self.tool._translate_text("Hello", to_language="es")

self.assertIn("Translation failed", str(context.exception))

@patch("azure.ai.translation.text.TextTranslationClient")
def test_validate_environment(self, mock_translation_client):
# Test that the environment is validated and the client is initialized
self.tool.validate_environment()

# Ensure the translation client is initialized properly
mock_translation_client.assert_called_once_with(
endpoint=self.tool.text_translation_endpoint, credential=unittest.mock.ANY
)


# Run the unit tests
if __name__ == "__main__":
unittest.main()
Loading
Loading