Skip to content

Commit

Permalink
Make visual features configurable by the user
Browse files Browse the repository at this point in the history
  • Loading branch information
Sheepsta300 committed Oct 7, 2024
1 parent 7a07196 commit f458705
Showing 1 changed file with 16 additions and 15 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import logging
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional

from langchain_core.callbacks import CallbackManagerForToolRun
from langchain_core.tools import BaseTool
Expand All @@ -22,16 +22,17 @@ class AzureAiServicesImageAnalysisTool(BaseTool):
https://learn.microsoft.com/en-us/azure/ai-services/computer-vision/quickstarts-sdk/image-analysis-client-library-40
"""

azure_ai_services_key: str = "" #: :meta private:
azure_ai_services_endpoint: str = "" #: :meta private:
azure_ai_services_key: Optional[str] = None #: :meta private:
azure_ai_services_endpoint: Optional[str] = None #: :meta private:
visual_features: Optional[List[str]] = None
image_analysis_client: Any #: :meta private:
visual_features: Any #: :meta private:
formatted_features: Any #: :meta private:

name: str = "azure_ai_services_image_analysis"
description: str = (
"A wrapper around Azure AI Services Image Analysis. "
"Useful for when you need to analyze images. "
"Input should be a url to an image."
"Input must be a url string or path string to an image."
)

@model_validator(mode="before")
Expand Down Expand Up @@ -67,36 +68,36 @@ def validate_environment(cls, values: Dict) -> Any:
raise RuntimeError(
f"Initialization of Azure AI Vision Image Analysis client failed: {e}"
)

values["visual_features"] = [
VisualFeatures.TAGS,
VisualFeatures.OBJECTS,
VisualFeatures.CAPTION,
VisualFeatures.READ,
]

visual_features: List[str] = values.get("visual_features", ["TAGS"])
values["formatted_features"] = [
VisualFeatures[feat.upper()] for feat in visual_features
]

return values

def _image_analysis(self, image_path: str) -> Dict:
try:
from azure.ai.vision.imageanalysis import ImageAnalysisClient
from azure.ai.vision.imageanalysis.models import VisualFeatures
except ImportError:
pass

self.image_analysis_client: ImageAnalysisClient

self.formatted_features: List[VisualFeatures]

image_src_type = detect_file_src_type(image_path)
if image_src_type == "local":
with open(image_path, "rb") as image_file:
image_data = image_file.read()
result = self.image_analysis_client.analyze(
image_data=image_data,
visual_features=self.visual_features,
visual_features=self.formatted_features,
)
elif image_src_type == "remote":
result = self.image_analysis_client.analyze_from_url(
image_url=image_path,
visual_features=self.visual_features,
visual_features=self.formatted_features,
)
else:
raise ValueError(f"Invalid image path: {image_path}")
Expand Down

0 comments on commit f458705

Please sign in to comment.