diff --git a/libs/community/langchain_google_community/__init__.py b/libs/community/langchain_google_community/__init__.py
index 396d874a..146ddb08 100644
--- a/libs/community/langchain_google_community/__init__.py
+++ b/libs/community/langchain_google_community/__init__.py
@@ -23,10 +23,12 @@
     VertexAIMultiTurnSearchRetriever,
     VertexAISearchRetriever,
 )
+from langchain_google_community.vision import CloudVisionLoader
 
 __all__ = [
     "BigQueryLoader",
     "BigQueryVectorSearch",
+    "CloudVisionLoader",
     "DocAIParser",
     "DocAIParsingResults",
     "DocumentAIWarehouseRetriever",
diff --git a/libs/community/langchain_google_community/vision.py b/libs/community/langchain_google_community/vision.py
new file mode 100644
index 00000000..6d9f64ce
--- /dev/null
+++ b/libs/community/langchain_google_community/vision.py
@@ -0,0 +1,81 @@
+from typing import Iterator, Optional
+
+from langchain_core.document_loaders import BaseBlobParser
+from langchain_core.document_loaders.blob_loaders import Blob
+from langchain_core.documents import Document
+
+from langchain_google_community._utils import get_client_info
+
+
+class CloudVisionLoader(BaseBlobParser):
+    """Extract text from images stored in GCS with the Cloud Vision API.
+
+    Each image is sent to Cloud Vision text detection and converted into a
+    single ``Document`` whose ``page_content`` is the detected text.
+    """
+
+    def __init__(self, project: Optional[str] = None):
+        """Create the underlying Cloud Vision client.
+
+        Args:
+            project: Optional Google Cloud project used as the quota project
+                for Cloud Vision requests.
+
+        Raises:
+            ImportError: If ``google-cloud-vision`` is not installed.
+        """
+        try:
+            from google.cloud import vision  # type: ignore[attr-defined]
+        except ImportError as e:
+            raise ImportError(
+                "Cannot import google.cloud.vision, please install "
+                "`pip install google-cloud-vision`."
+            ) from e
+        client_options = {"quota_project_id": project} if project else None
+        self._client = vision.ImageAnnotatorClient(
+            client_options=client_options,
+            client_info=get_client_info(module="cloud-vision"),
+        )
+
+    def load(self, gcs_uri: str) -> Document:
+        """Load the text of an image stored in GCS into a Document.
+
+        Args:
+            gcs_uri: Location of the image, e.g. ``gs://bucket/image.png``.
+
+        Returns:
+            A Document holding the detected text (empty when the image
+            contains no text) with the image URI under ``metadata["source"]``.
+
+        Raises:
+            RuntimeError: If Cloud Vision reports an error for the image.
+        """
+        from google.cloud import vision  # type: ignore[attr-defined]
+
+        image = vision.Image(source=vision.ImageSource(gcs_image_uri=gcs_uri))
+        response = self._client.text_detection(image=image)
+        # Per-image failures (for example an unreadable URI) are reported on
+        # the response rather than raised, so surface them explicitly instead
+        # of returning an empty document.
+        if response.error.message:
+            raise RuntimeError(
+                f"Cloud Vision text detection failed for {gcs_uri}: "
+                f"{response.error.message}"
+            )
+        annotations = response.text_annotations
+        # The first annotation carries the full text of the image.
+        text = annotations[0].description if annotations else ""
+        return Document(page_content=text, metadata={"source": gcs_uri})
+
+    def lazy_parse(self, blob: Blob) -> Iterator[Document]:
+        """Yield a single Document with the text of the image behind ``blob``.
+
+        Args:
+            blob: Blob whose ``path`` is the GCS URI of the image.
+
+        Raises:
+            ValueError: If the blob has no path.
+        """
+        if blob.path is None:
+            raise ValueError("CloudVisionLoader requires a blob with a GCS path.")
+        yield self.load(str(blob.path))
diff --git a/libs/community/tests/integration_tests/test_vision.py b/libs/community/tests/integration_tests/test_vision.py
new file mode 100644
index 00000000..0a22ed2e
--- /dev/null
+++ b/libs/community/tests/integration_tests/test_vision.py
@@ -0,0 +1,19 @@
+import os
+
+import pytest
+from langchain_core.document_loaders.blob_loaders import Blob
+from langchain_core.documents import Document
+
+from langchain_google_community import CloudVisionLoader
+
+
+@pytest.mark.skip(reason="CI/CD not ready.")
+def test_parse_image() -> None:
+    gcs_path = os.environ["IMAGE_GCS_PATH"]
+    project = os.environ["PROJECT"]
+    blob = Blob(path=gcs_path, data="")  # type: ignore
+    loader = CloudVisionLoader(project=project)
+    documents = loader.parse(blob)
+    assert len(documents) == 1
+    assert isinstance(documents[0], Document)
+    assert len(documents[0].page_content) > 1