From b72e14cc442d77582bfa11ad611a5fafbaccec28 Mon Sep 17 00:00:00 2001 From: Sergio Vidiella Pinto <61970661+svidiella@users.noreply.github.com> Date: Thu, 29 Feb 2024 15:54:05 +0100 Subject: [PATCH] feat: add support for multimodal model and add embed_image (#44) * feat: add support for multimodal model and add embed_image --- libs/vertexai/.gitignore | 1 + .../langchain_google_vertexai/embeddings.py | 65 ++++++++++++++++++- .../tests/integration_tests/conftest.py | 39 +++++++++++ .../integration_tests/test_embeddings.py | 28 +++++++- .../integration_tests/test_vision_models.py | 25 ------- .../tests/unit_tests/test_embeddings.py | 51 +++++++++++++++ 6 files changed, 182 insertions(+), 27 deletions(-) create mode 100644 libs/vertexai/tests/integration_tests/conftest.py create mode 100644 libs/vertexai/tests/unit_tests/test_embeddings.py diff --git a/libs/vertexai/.gitignore b/libs/vertexai/.gitignore index bee8a64b..7dcfd0e2 100644 --- a/libs/vertexai/.gitignore +++ b/libs/vertexai/.gitignore @@ -1 +1,2 @@ __pycache__ +.mypy_cache_test diff --git a/libs/vertexai/langchain_google_vertexai/embeddings.py b/libs/vertexai/langchain_google_vertexai/embeddings.py index f7a9b97f..a849462d 100644 --- a/libs/vertexai/langchain_google_vertexai/embeddings.py +++ b/libs/vertexai/langchain_google_vertexai/embeddings.py @@ -3,6 +3,7 @@ import string import threading from concurrent.futures import ThreadPoolExecutor, wait +from enum import Enum, auto from typing import Any, Dict, List, Literal, Optional, Tuple, Type from google.api_core.exceptions import ( @@ -19,6 +20,11 @@ TextEmbeddingInput, TextEmbeddingModel, ) +from vertexai.vision_models import ( # type: ignore + Image, + MultiModalEmbeddingModel, + MultiModalEmbeddingResponse, +) from langchain_google_vertexai._base import _VertexAICommon @@ -29,6 +35,19 @@ _MIN_BATCH_SIZE = 5 +class GoogleEmbeddingModelType(str, Enum): + TEXT = auto() + MULTIMODAL = auto() + + @classmethod + def _missing_(cls, value: Any) -> Optional["GoogleEmbeddingModelType"]: + if "textembedding-gecko" in value.lower(): + return GoogleEmbeddingModelType.TEXT + elif "multimodalembedding" in value.lower(): + return GoogleEmbeddingModelType.MULTIMODAL + return None + + class VertexAIEmbeddings(_VertexAICommon, Embeddings): """Google Cloud VertexAI embedding models.""" @@ -46,7 +65,15 @@ def validate_environment(cls, values: Dict) -> Dict: "textembedding-gecko@001" ) values["model_name"] = "textembedding-gecko@001" - values["client"] = TextEmbeddingModel.from_pretrained(values["model_name"]) + if ( + GoogleEmbeddingModelType(values["model_name"]) + == GoogleEmbeddingModelType.MULTIMODAL + ): + values["client"] = MultiModalEmbeddingModel.from_pretrained( + values["model_name"] + ) + else: + values["client"] = TextEmbeddingModel.from_pretrained(values["model_name"]) return values def __init__( @@ -83,6 +110,20 @@ def __init__( "embeddings_task_type_supported" ] = not self.client._endpoint_name.endswith("/textembedding-gecko@001") + retry_errors: List[Type[BaseException]] = [ + ResourceExhausted, + ServiceUnavailable, + Aborted, + DeadlineExceeded, + ] + self.instance["retry_decorator"] = create_base_retry_decorator( + error_types=retry_errors, max_retries=self.max_retries + ) + + @property + def model_type(self) -> str: + return GoogleEmbeddingModelType(self.model_name) + @staticmethod def _split_by_punctuation(text: str) -> List[str]: """Splits a string by punctuation and whitespace characters.""" @@ -321,6 +362,8 @@ def embed_documents( Returns: List of embeddings, one for each text. """ + if self.model_type != GoogleEmbeddingModelType.TEXT: + raise NotImplementedError("Not supported for multimodal models") return self.embed(texts, batch_size, "RETRIEVAL_DOCUMENT") def embed_query(self, text: str) -> List[float]: @@ -332,5 +375,25 @@ def embed_query(self, text: str) -> List[float]: Returns: Embedding for the text. """ + if self.model_type != GoogleEmbeddingModelType.TEXT: + raise NotImplementedError("Not supported for multimodal models") embeddings = self.embed([text], 1, "RETRIEVAL_QUERY") return embeddings[0] + + def embed_image(self, image_path: str) -> List[float]: + """Embed an image. + + Args: + image_path: Path to image (local or Google Cloud Storage) to generate + embeddings for. + + Returns: + Embedding for the image. + """ + if self.model_type != GoogleEmbeddingModelType.MULTIMODAL: + raise NotImplementedError("Only supported for multimodal models") + + embed_with_retry = self.instance["retry_decorator"](self.client.get_embeddings) + image = Image.load_from_file(image_path) + result: MultiModalEmbeddingResponse = embed_with_retry(image=image) + return result.image_embedding diff --git a/libs/vertexai/tests/integration_tests/conftest.py b/libs/vertexai/tests/integration_tests/conftest.py new file mode 100644 index 00000000..b0903366 --- /dev/null +++ b/libs/vertexai/tests/integration_tests/conftest.py @@ -0,0 +1,39 @@ +import base64 + +import pytest +from _pytest.tmpdir import TempPathFactory +from vertexai.vision_models import Image # type: ignore + + +@pytest.fixture +def base64_image() -> str: + return ( + "" + "BHNCSVQICAgIfAhkiAAAAAlwSFlzAAAApgAAAKYB3X3/OAAAABl0RVh0U29mdHdhcmUAd3" + "d3Lmlua3NjYXBlLm9yZ5vuPBoAAANCSURBVEiJtZZPbBtFFMZ/M7ubXdtdb1xSFyeilBap" + "ySVU8h8OoFaooFSqiihIVIpQBKci6KEg9Q6H9kovIHoCIVQJJCKE1ENFjnAgcaSGC6rEnx" + "BwA04Tx43t2FnvDAfjkNibxgHxnWb2e/u992bee7tCa00YFsffekFY+nUzFtjW0LrvjRXr" + "CDIAaPLlW0nHL0SsZtVoaF98mLrx3pdhOqLtYPHChahZcYYO7KvPFxvRl5XPp1sN3adWiD" + "1ZAqD6XYK1b/dvE5IWryTt2udLFedwc1+9kLp+vbbpoDh+6TklxBeAi9TL0taeWpdmZzQD" + "ry0AcO+jQ12RyohqqoYoo8RDwJrU+qXkjWtfi8Xxt58BdQuwQs9qC/afLwCw8tnQbqYAPs" + "gxE1S6F3EAIXux2oQFKm0ihMsOF71dHYx+f3NND68ghCu1YIoePPQN1pGRABkJ6Bus96Cu" + "tRZMydTl+TvuiRW1m3n0eDl0vRPcEysqdXn+jsQPsrHMquGeXEaY4Yk4wxWcY5V/9scqOM" + "OVUFthatyTy8QyqwZ+kDURKoMWxNKr2EeqVKcTNOajqKoBgOE28U4tdQl5p5bwCw7BWqua" + "ZSzAPlwjlithJtp3pTImSqQRrb2Z8PHGigD4RZuNX6JYj6wj7O4TFLbCO/Mn/m8R+h6rYS" + "Ub3ekokRY6f/YukArN979jcW+V/S8g0eT/N3VN3kTqWbQ428m9/8k0P/1aIhF36PccEl6E" + "hOcAUCrXKZXXWS3XKd2vc/TRBG9O5ELC17MmWubD2nKhUKZa26Ba2+D3P+4/MNCFwg59oW" + "VeYhkzgN/JDR8deKBoD7Y+ljEjGZ0sosXVTvbc6RHirr2reNy1OXd6pJsQ+gqjk8VWFYmH" + "rwBzW/n+uMPFiRwHB2I7ih8ciHFxIkd/3Omk5tCDV1t+2nNu5sxxpDFNx+huNhVT3/zMDz" + "8usXC3ddaHBj1GHj/As08fwTS7Kt1HBTmyN29vdwAw+/wbwLVOJ3uAD1wi/dUH7Qei66Pf" + "yuRj4Ik9is+hglfbkbfR3cnZm7chlUWLdwmprtCohX4HUtlOcQjLYCu+fzGJH2QRKvP3UN" + "z8bWk1qMxjGTOMThZ3kvgLI5AzFfo379UAAAAASUVORK5CYII=" + ) + + +@pytest.fixture +def tmp_image(tmp_path_factory: TempPathFactory, base64_image) -> str: + img_data = base64.b64decode(base64_image.split(",")[1]) + image = Image(image_bytes=img_data) + fn = tmp_path_factory.mktemp("data") / "img.png" + image.save(str(fn)) + return str(fn) diff --git a/libs/vertexai/tests/integration_tests/test_embeddings.py b/libs/vertexai/tests/integration_tests/test_embeddings.py index 6772274a..56fb03cf 100644 --- a/libs/vertexai/tests/integration_tests/test_embeddings.py +++ b/libs/vertexai/tests/integration_tests/test_embeddings.py @@ -4,8 +4,13 @@ `gcloud auth login` first). """ import pytest +from vertexai.language_models import TextEmbeddingModel # type: ignore +from vertexai.vision_models import MultiModalEmbeddingModel # type: ignore -from langchain_google_vertexai.embeddings import VertexAIEmbeddings +from langchain_google_vertexai.embeddings import ( + GoogleEmbeddingModelType, + VertexAIEmbeddings, +) @pytest.mark.release @@ -74,3 +79,24 @@ def test_warning(caplog: pytest.LogCaptureFixture) -> None: "Feb-01-2024. Currently the default is set to textembedding-gecko@001" ) assert record.message == expected_message + + +@pytest.mark.release +def test_langchain_google_vertexai_image_embeddings(tmp_image) -> None: + model = VertexAIEmbeddings(model_name="multimodalembedding") + output = model.embed_image(tmp_image) + assert len(output) == 1408 + + +@pytest.mark.release +def test_langchain_google_vertexai_text_model() -> None: + embeddings_model = VertexAIEmbeddings(model_name="textembedding-gecko@001") + assert isinstance(embeddings_model.client, TextEmbeddingModel) + assert embeddings_model.model_type == GoogleEmbeddingModelType.TEXT + + +@pytest.mark.release +def test_langchain_google_vertexai_multimodal_model() -> None: + embeddings_model = VertexAIEmbeddings(model_name="multimodalembedding@001") + assert isinstance(embeddings_model.client, MultiModalEmbeddingModel) + assert embeddings_model.model_type == GoogleEmbeddingModelType.MULTIMODAL diff --git a/libs/vertexai/tests/integration_tests/test_vision_models.py b/libs/vertexai/tests/integration_tests/test_vision_models.py index 7714373a..1cbdfdda 100644 --- a/libs/vertexai/tests/integration_tests/test_vision_models.py +++ b/libs/vertexai/tests/integration_tests/test_vision_models.py @@ -131,28 +131,3 @@ def test_vertex_ai_image_generation_and_edition(): response = editor.invoke(messages) assert isinstance(response, AIMessage) - - -@pytest.fixture -def base64_image() -> str: - return ( - "" - "BHNCSVQICAgIfAhkiAAAAAlwSFlzAAAApgAAAKYB3X3/OAAAABl0RVh0U29mdHdhcmUAd3" - "d3Lmlua3NjYXBlLm9yZ5vuPBoAAANCSURBVEiJtZZPbBtFFMZ/M7ubXdtdb1xSFyeilBap" - "ySVU8h8OoFaooFSqiihIVIpQBKci6KEg9Q6H9kovIHoCIVQJJCKE1ENFjnAgcaSGC6rEnx" - "BwA04Tx43t2FnvDAfjkNibxgHxnWb2e/u992bee7tCa00YFsffekFY+nUzFtjW0LrvjRXr" - "CDIAaPLlW0nHL0SsZtVoaF98mLrx3pdhOqLtYPHChahZcYYO7KvPFxvRl5XPp1sN3adWiD" - "1ZAqD6XYK1b/dvE5IWryTt2udLFedwc1+9kLp+vbbpoDh+6TklxBeAi9TL0taeWpdmZzQD" - "ry0AcO+jQ12RyohqqoYoo8RDwJrU+qXkjWtfi8Xxt58BdQuwQs9qC/afLwCw8tnQbqYAPs" - "gxE1S6F3EAIXux2oQFKm0ihMsOF71dHYx+f3NND68ghCu1YIoePPQN1pGRABkJ6Bus96Cu" - "tRZMydTl+TvuiRW1m3n0eDl0vRPcEysqdXn+jsQPsrHMquGeXEaY4Yk4wxWcY5V/9scqOM" - "OVUFthatyTy8QyqwZ+kDURKoMWxNKr2EeqVKcTNOajqKoBgOE28U4tdQl5p5bwCw7BWqua" - "ZSzAPlwjlithJtp3pTImSqQRrb2Z8PHGigD4RZuNX6JYj6wj7O4TFLbCO/Mn/m8R+h6rYS" - "Ub3ekokRY6f/YukArN979jcW+V/S8g0eT/N3VN3kTqWbQ428m9/8k0P/1aIhF36PccEl6E" - "hOcAUCrXKZXXWS3XKd2vc/TRBG9O5ELC17MmWubD2nKhUKZa26Ba2+D3P+4/MNCFwg59oW" - "VeYhkzgN/JDR8deKBoD7Y+ljEjGZ0sosXVTvbc6RHirr2reNy1OXd6pJsQ+gqjk8VWFYmH" - "rwBzW/n+uMPFiRwHB2I7ih8ciHFxIkd/3Omk5tCDV1t+2nNu5sxxpDFNx+huNhVT3/zMDz" - "8usXC3ddaHBj1GHj/As08fwTS7Kt1HBTmyN29vdwAw+/wbwLVOJ3uAD1wi/dUH7Qei66Pf" - "yuRj4Ik9is+hglfbkbfR3cnZm7chlUWLdwmprtCohX4HUtlOcQjLYCu+fzGJH2QRKvP3UN" - "z8bWk1qMxjGTOMThZ3kvgLI5AzFfo379UAAAAASUVORK5CYII=" - ) diff --git a/libs/vertexai/tests/unit_tests/test_embeddings.py b/libs/vertexai/tests/unit_tests/test_embeddings.py new file mode 100644 index 00000000..8e1d86f4 --- /dev/null +++ b/libs/vertexai/tests/unit_tests/test_embeddings.py @@ -0,0 +1,51 @@ +from typing import Any, Dict +from unittest.mock import MagicMock + +import pytest +from pydantic.v1 import root_validator + +from langchain_google_vertexai import VertexAIEmbeddings +from langchain_google_vertexai.embeddings import GoogleEmbeddingModelType + + +def test_langchain_google_vertexai_embed_image_multimodal_only() -> None: + mock_embeddings = MockVertexAIEmbeddings("textembedding-gecko@001") + assert mock_embeddings.model_type == GoogleEmbeddingModelType.TEXT + with pytest.raises(NotImplementedError) as e: + mock_embeddings.embed_image("test") + assert e.value == "Only supported for multimodal models" + + +def test_langchain_google_vertexai_embed_documents_text_only() -> None: + mock_embeddings = MockVertexAIEmbeddings("multimodalembedding@001") + assert mock_embeddings.model_type == GoogleEmbeddingModelType.MULTIMODAL + with pytest.raises(NotImplementedError) as e: + mock_embeddings.embed_documents(["test"]) + assert e.value == "Not supported for multimodal models" + + +def test_langchain_google_vertexai_embed_query_text_only() -> None: + mock_embeddings = MockVertexAIEmbeddings("multimodalembedding@001") + assert mock_embeddings.model_type == GoogleEmbeddingModelType.MULTIMODAL + with pytest.raises(NotImplementedError) as e: + mock_embeddings.embed_query("test") + assert e.value == "Not supported for multimodal models" + + +class MockVertexAIEmbeddings(VertexAIEmbeddings): + """ + A mock class for avoiding instantiating VertexAI and the EmbeddingModel client + instance during init + """ + + def __init__(self, model_name, **kwargs: Any) -> None: + super().__init__(model_name, **kwargs) + + @classmethod + def _init_vertexai(cls, values: Dict) -> None: + pass + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + values["client"] = MagicMock() + return values