Skip to content

Commit

Permalink
feat: add support for multimodal model and add embed_image (#44)
Browse files Browse the repository at this point in the history
* feat: add support for multimodal model and add embed_image
  • Loading branch information
svidiella authored Feb 29, 2024
1 parent 75d0d8c commit b72e14c
Show file tree
Hide file tree
Showing 6 changed files with 182 additions and 27 deletions.
1 change: 1 addition & 0 deletions libs/vertexai/.gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
__pycache__
.mypy_cache_test
65 changes: 64 additions & 1 deletion libs/vertexai/langchain_google_vertexai/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import string
import threading
from concurrent.futures import ThreadPoolExecutor, wait
from enum import Enum, auto
from typing import Any, Dict, List, Literal, Optional, Tuple, Type

from google.api_core.exceptions import (
Expand All @@ -19,6 +20,11 @@
TextEmbeddingInput,
TextEmbeddingModel,
)
from vertexai.vision_models import ( # type: ignore
Image,
MultiModalEmbeddingModel,
MultiModalEmbeddingResponse,
)

from langchain_google_vertexai._base import _VertexAICommon

Expand All @@ -29,6 +35,19 @@
_MIN_BATCH_SIZE = 5


class GoogleEmbeddingModelType(str, Enum):
TEXT = auto()
MULTIMODAL = auto()

@classmethod
def _missing_(cls, value: Any) -> Optional["GoogleEmbeddingModelType"]:
if "textembedding-gecko" in value.lower():
return GoogleEmbeddingModelType.TEXT
elif "multimodalembedding" in value.lower():
return GoogleEmbeddingModelType.MULTIMODAL
return None


class VertexAIEmbeddings(_VertexAICommon, Embeddings):
"""Google Cloud VertexAI embedding models."""

Expand All @@ -46,7 +65,15 @@ def validate_environment(cls, values: Dict) -> Dict:
"textembedding-gecko@001"
)
values["model_name"] = "textembedding-gecko@001"
values["client"] = TextEmbeddingModel.from_pretrained(values["model_name"])
if (
GoogleEmbeddingModelType(values["model_name"])
== GoogleEmbeddingModelType.MULTIMODAL
):
values["client"] = MultiModalEmbeddingModel.from_pretrained(
values["model_name"]
)
else:
values["client"] = TextEmbeddingModel.from_pretrained(values["model_name"])
return values

def __init__(
Expand Down Expand Up @@ -83,6 +110,20 @@ def __init__(
"embeddings_task_type_supported"
] = not self.client._endpoint_name.endswith("/textembedding-gecko@001")

retry_errors: List[Type[BaseException]] = [
ResourceExhausted,
ServiceUnavailable,
Aborted,
DeadlineExceeded,
]
self.instance["retry_decorator"] = create_base_retry_decorator(
error_types=retry_errors, max_retries=self.max_retries
)

@property
def model_type(self) -> str:
return GoogleEmbeddingModelType(self.model_name)

@staticmethod
def _split_by_punctuation(text: str) -> List[str]:
"""Splits a string by punctuation and whitespace characters."""
Expand Down Expand Up @@ -321,6 +362,8 @@ def embed_documents(
Returns:
List of embeddings, one for each text.
"""
if self.model_type != GoogleEmbeddingModelType.TEXT:
raise NotImplementedError("Not supported for multimodal models")
return self.embed(texts, batch_size, "RETRIEVAL_DOCUMENT")

def embed_query(self, text: str) -> List[float]:
Expand All @@ -332,5 +375,25 @@ def embed_query(self, text: str) -> List[float]:
Returns:
Embedding for the text.
"""
if self.model_type != GoogleEmbeddingModelType.TEXT:
raise NotImplementedError("Not supported for multimodal models")
embeddings = self.embed([text], 1, "RETRIEVAL_QUERY")
return embeddings[0]

def embed_image(self, image_path: str) -> List[float]:
"""Embed an image.
Args:
image_path: Path to image (local or Google Cloud Storage) to generate
embeddings for.
Returns:
Embedding for the image.
"""
if self.model_type != GoogleEmbeddingModelType.MULTIMODAL:
raise NotImplementedError("Only supported for multimodal models")

embed_with_retry = self.instance["retry_decorator"](self.client.get_embeddings)
image = Image.load_from_file(image_path)
result: MultiModalEmbeddingResponse = embed_with_retry(image=image)
return result.image_embedding
39 changes: 39 additions & 0 deletions libs/vertexai/tests/integration_tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import base64

import pytest
from _pytest.tmpdir import TempPathFactory
from vertexai.vision_models import Image # type: ignore


@pytest.fixture
def base64_image() -> str:
return (
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABgAAAAYCAYAAADgdz34AAAA"
"BHNCSVQICAgIfAhkiAAAAAlwSFlzAAAApgAAAKYB3X3/OAAAABl0RVh0U29mdHdhcmUAd3"
"d3Lmlua3NjYXBlLm9yZ5vuPBoAAANCSURBVEiJtZZPbBtFFMZ/M7ubXdtdb1xSFyeilBap"
"ySVU8h8OoFaooFSqiihIVIpQBKci6KEg9Q6H9kovIHoCIVQJJCKE1ENFjnAgcaSGC6rEnx"
"BwA04Tx43t2FnvDAfjkNibxgHxnWb2e/u992bee7tCa00YFsffekFY+nUzFtjW0LrvjRXr"
"CDIAaPLlW0nHL0SsZtVoaF98mLrx3pdhOqLtYPHChahZcYYO7KvPFxvRl5XPp1sN3adWiD"
"1ZAqD6XYK1b/dvE5IWryTt2udLFedwc1+9kLp+vbbpoDh+6TklxBeAi9TL0taeWpdmZzQD"
"ry0AcO+jQ12RyohqqoYoo8RDwJrU+qXkjWtfi8Xxt58BdQuwQs9qC/afLwCw8tnQbqYAPs"
"gxE1S6F3EAIXux2oQFKm0ihMsOF71dHYx+f3NND68ghCu1YIoePPQN1pGRABkJ6Bus96Cu"
"tRZMydTl+TvuiRW1m3n0eDl0vRPcEysqdXn+jsQPsrHMquGeXEaY4Yk4wxWcY5V/9scqOM"
"OVUFthatyTy8QyqwZ+kDURKoMWxNKr2EeqVKcTNOajqKoBgOE28U4tdQl5p5bwCw7BWqua"
"ZSzAPlwjlithJtp3pTImSqQRrb2Z8PHGigD4RZuNX6JYj6wj7O4TFLbCO/Mn/m8R+h6rYS"
"Ub3ekokRY6f/YukArN979jcW+V/S8g0eT/N3VN3kTqWbQ428m9/8k0P/1aIhF36PccEl6E"
"hOcAUCrXKZXXWS3XKd2vc/TRBG9O5ELC17MmWubD2nKhUKZa26Ba2+D3P+4/MNCFwg59oW"
"VeYhkzgN/JDR8deKBoD7Y+ljEjGZ0sosXVTvbc6RHirr2reNy1OXd6pJsQ+gqjk8VWFYmH"
"rwBzW/n+uMPFiRwHB2I7ih8ciHFxIkd/3Omk5tCDV1t+2nNu5sxxpDFNx+huNhVT3/zMDz"
"8usXC3ddaHBj1GHj/As08fwTS7Kt1HBTmyN29vdwAw+/wbwLVOJ3uAD1wi/dUH7Qei66Pf"
"yuRj4Ik9is+hglfbkbfR3cnZm7chlUWLdwmprtCohX4HUtlOcQjLYCu+fzGJH2QRKvP3UN"
"z8bWk1qMxjGTOMThZ3kvgLI5AzFfo379UAAAAASUVORK5CYII="
)


@pytest.fixture
def tmp_image(tmp_path_factory: TempPathFactory, base64_image) -> str:
img_data = base64.b64decode(base64_image.split(",")[1])
image = Image(image_bytes=img_data)
fn = tmp_path_factory.mktemp("data") / "img.png"
image.save(str(fn))
return str(fn)
28 changes: 27 additions & 1 deletion libs/vertexai/tests/integration_tests/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,13 @@
`gcloud auth login` first).
"""
import pytest
from vertexai.language_models import TextEmbeddingModel # type: ignore
from vertexai.vision_models import MultiModalEmbeddingModel # type: ignore

from langchain_google_vertexai.embeddings import VertexAIEmbeddings
from langchain_google_vertexai.embeddings import (
GoogleEmbeddingModelType,
VertexAIEmbeddings,
)


@pytest.mark.release
Expand Down Expand Up @@ -74,3 +79,24 @@ def test_warning(caplog: pytest.LogCaptureFixture) -> None:
"Feb-01-2024. Currently the default is set to textembedding-gecko@001"
)
assert record.message == expected_message


@pytest.mark.release
def test_langchain_google_vertexai_image_embeddings(tmp_image) -> None:
model = VertexAIEmbeddings(model_name="multimodalembedding")
output = model.embed_image(tmp_image)
assert len(output) == 1408


@pytest.mark.release
def test_langchain_google_vertexai_text_model() -> None:
embeddings_model = VertexAIEmbeddings(model_name="textembedding-gecko@001")
assert isinstance(embeddings_model.client, TextEmbeddingModel)
assert embeddings_model.model_type == GoogleEmbeddingModelType.TEXT


@pytest.mark.release
def test_langchain_google_vertexai_multimodal_model() -> None:
embeddings_model = VertexAIEmbeddings(model_name="multimodalembedding@001")
assert isinstance(embeddings_model.client, MultiModalEmbeddingModel)
assert embeddings_model.model_type == GoogleEmbeddingModelType.MULTIMODAL
25 changes: 0 additions & 25 deletions libs/vertexai/tests/integration_tests/test_vision_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,28 +131,3 @@ def test_vertex_ai_image_generation_and_edition():

response = editor.invoke(messages)
assert isinstance(response, AIMessage)


@pytest.fixture
def base64_image() -> str:
return (
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABgAAAAYCAYAAADgdz34AAAA"
"BHNCSVQICAgIfAhkiAAAAAlwSFlzAAAApgAAAKYB3X3/OAAAABl0RVh0U29mdHdhcmUAd3"
"d3Lmlua3NjYXBlLm9yZ5vuPBoAAANCSURBVEiJtZZPbBtFFMZ/M7ubXdtdb1xSFyeilBap"
"ySVU8h8OoFaooFSqiihIVIpQBKci6KEg9Q6H9kovIHoCIVQJJCKE1ENFjnAgcaSGC6rEnx"
"BwA04Tx43t2FnvDAfjkNibxgHxnWb2e/u992bee7tCa00YFsffekFY+nUzFtjW0LrvjRXr"
"CDIAaPLlW0nHL0SsZtVoaF98mLrx3pdhOqLtYPHChahZcYYO7KvPFxvRl5XPp1sN3adWiD"
"1ZAqD6XYK1b/dvE5IWryTt2udLFedwc1+9kLp+vbbpoDh+6TklxBeAi9TL0taeWpdmZzQD"
"ry0AcO+jQ12RyohqqoYoo8RDwJrU+qXkjWtfi8Xxt58BdQuwQs9qC/afLwCw8tnQbqYAPs"
"gxE1S6F3EAIXux2oQFKm0ihMsOF71dHYx+f3NND68ghCu1YIoePPQN1pGRABkJ6Bus96Cu"
"tRZMydTl+TvuiRW1m3n0eDl0vRPcEysqdXn+jsQPsrHMquGeXEaY4Yk4wxWcY5V/9scqOM"
"OVUFthatyTy8QyqwZ+kDURKoMWxNKr2EeqVKcTNOajqKoBgOE28U4tdQl5p5bwCw7BWqua"
"ZSzAPlwjlithJtp3pTImSqQRrb2Z8PHGigD4RZuNX6JYj6wj7O4TFLbCO/Mn/m8R+h6rYS"
"Ub3ekokRY6f/YukArN979jcW+V/S8g0eT/N3VN3kTqWbQ428m9/8k0P/1aIhF36PccEl6E"
"hOcAUCrXKZXXWS3XKd2vc/TRBG9O5ELC17MmWubD2nKhUKZa26Ba2+D3P+4/MNCFwg59oW"
"VeYhkzgN/JDR8deKBoD7Y+ljEjGZ0sosXVTvbc6RHirr2reNy1OXd6pJsQ+gqjk8VWFYmH"
"rwBzW/n+uMPFiRwHB2I7ih8ciHFxIkd/3Omk5tCDV1t+2nNu5sxxpDFNx+huNhVT3/zMDz"
"8usXC3ddaHBj1GHj/As08fwTS7Kt1HBTmyN29vdwAw+/wbwLVOJ3uAD1wi/dUH7Qei66Pf"
"yuRj4Ik9is+hglfbkbfR3cnZm7chlUWLdwmprtCohX4HUtlOcQjLYCu+fzGJH2QRKvP3UN"
"z8bWk1qMxjGTOMThZ3kvgLI5AzFfo379UAAAAASUVORK5CYII="
)
51 changes: 51 additions & 0 deletions libs/vertexai/tests/unit_tests/test_embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from typing import Any, Dict
from unittest.mock import MagicMock

import pytest
from pydantic.v1 import root_validator

from langchain_google_vertexai import VertexAIEmbeddings
from langchain_google_vertexai.embeddings import GoogleEmbeddingModelType


def test_langchain_google_vertexai_embed_image_multimodal_only() -> None:
mock_embeddings = MockVertexAIEmbeddings("textembedding-gecko@001")
assert mock_embeddings.model_type == GoogleEmbeddingModelType.TEXT
with pytest.raises(NotImplementedError) as e:
mock_embeddings.embed_image("test")
assert e.value == "Only supported for multimodal models"


def test_langchain_google_vertexai_embed_documents_text_only() -> None:
mock_embeddings = MockVertexAIEmbeddings("multimodalembedding@001")
assert mock_embeddings.model_type == GoogleEmbeddingModelType.MULTIMODAL
with pytest.raises(NotImplementedError) as e:
mock_embeddings.embed_documents(["test"])
assert e.value == "Not supported for multimodal models"


def test_langchain_google_vertexai_embed_query_text_only() -> None:
mock_embeddings = MockVertexAIEmbeddings("multimodalembedding@001")
assert mock_embeddings.model_type == GoogleEmbeddingModelType.MULTIMODAL
with pytest.raises(NotImplementedError) as e:
mock_embeddings.embed_query("test")
assert e.value == "Not supported for multimodal models"


class MockVertexAIEmbeddings(VertexAIEmbeddings):
"""
A mock class for avoiding instantiating VertexAI and the EmbeddingModel client
instance during init
"""

def __init__(self, model_name, **kwargs: Any) -> None:
super().__init__(model_name, **kwargs)

@classmethod
def _init_vertexai(cls, values: Dict) -> None:
pass

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
values["client"] = MagicMock()
return values

0 comments on commit b72e14c

Please sign in to comment.