From f458705aaacb8ec1eed9813ffc7fef70fccfaafc Mon Sep 17 00:00:00 2001 From: Sheepsta300 <128811766+Sheepsta300@users.noreply.github.com> Date: Mon, 7 Oct 2024 02:10:45 +0000 Subject: [PATCH] Make visual features configurable by the user --- .../tools/azure_ai_services/image_analysis.py | 31 ++++++++++--------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/libs/community/langchain_community/tools/azure_ai_services/image_analysis.py b/libs/community/langchain_community/tools/azure_ai_services/image_analysis.py index 0a1e206e45e7d..a357d6d7058cc 100644 --- a/libs/community/langchain_community/tools/azure_ai_services/image_analysis.py +++ b/libs/community/langchain_community/tools/azure_ai_services/image_analysis.py @@ -1,7 +1,7 @@ from __future__ import annotations import logging -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from langchain_core.callbacks import CallbackManagerForToolRun from langchain_core.tools import BaseTool @@ -22,16 +22,17 @@ class AzureAiServicesImageAnalysisTool(BaseTool): https://learn.microsoft.com/en-us/azure/ai-services/computer-vision/quickstarts-sdk/image-analysis-client-library-40 """ - azure_ai_services_key: str = "" #: :meta private: - azure_ai_services_endpoint: str = "" #: :meta private: + azure_ai_services_key: Optional[str] = None #: :meta private: + azure_ai_services_endpoint: Optional[str] = None #: :meta private: + visual_features: Optional[List[str]] = None image_analysis_client: Any #: :meta private: - visual_features: Any #: :meta private: + formatted_features: Any #: :meta private: name: str = "azure_ai_services_image_analysis" description: str = ( "A wrapper around Azure AI Services Image Analysis. " "Useful for when you need to analyze images. " - "Input should be a url to an image." + "Input must be a url string or path string to an image." ) @model_validator(mode="before") @@ -67,36 +68,36 @@ def validate_environment(cls, values: Dict) -> Any: raise RuntimeError( f"Initialization of Azure AI Vision Image Analysis client failed: {e}" ) - - values["visual_features"] = [ - VisualFeatures.TAGS, - VisualFeatures.OBJECTS, - VisualFeatures.CAPTION, - VisualFeatures.READ, - ] + + visual_features: List[str] = values.get("visual_features", ["TAGS"]) + values["formatted_features"] = [ + VisualFeatures[feat.upper()] for feat in visual_features + ] return values def _image_analysis(self, image_path: str) -> Dict: try: from azure.ai.vision.imageanalysis import ImageAnalysisClient + from azure.ai.vision.imageanalysis.models import VisualFeatures except ImportError: pass self.image_analysis_client: ImageAnalysisClient - + self.formatted_features: List[VisualFeatures] + image_src_type = detect_file_src_type(image_path) if image_src_type == "local": with open(image_path, "rb") as image_file: image_data = image_file.read() result = self.image_analysis_client.analyze( image_data=image_data, - visual_features=self.visual_features, + visual_features=self.formatted_features, ) elif image_src_type == "remote": result = self.image_analysis_client.analyze_from_url( image_url=image_path, - visual_features=self.visual_features, + visual_features=self.formatted_features, ) else: raise ValueError(f"Invalid image path: {image_path}")