Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for multimodal model and add embed_image #44

Merged
merged 5 commits into from
Feb 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions libs/vertexai/.gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
__pycache__
.mypy_cache_test
65 changes: 64 additions & 1 deletion libs/vertexai/langchain_google_vertexai/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import string
import threading
from concurrent.futures import ThreadPoolExecutor, wait
from enum import Enum, auto
from typing import Any, Dict, List, Literal, Optional, Tuple, Type

from google.api_core.exceptions import (
Expand All @@ -19,6 +20,11 @@
TextEmbeddingInput,
TextEmbeddingModel,
)
from vertexai.vision_models import ( # type: ignore
Image,
MultiModalEmbeddingModel,
MultiModalEmbeddingResponse,
)

from langchain_google_vertexai._base import _VertexAICommon

Expand All @@ -29,6 +35,19 @@
_MIN_BATCH_SIZE = 5


class GoogleEmbeddingModelType(str, Enum):
TEXT = auto()
MULTIMODAL = auto()

@classmethod
def _missing_(cls, value: Any) -> Optional["GoogleEmbeddingModelType"]:
if "textembedding-gecko" in value.lower():
return GoogleEmbeddingModelType.TEXT
elif "multimodalembedding" in value.lower():
return GoogleEmbeddingModelType.MULTIMODAL
return None


class VertexAIEmbeddings(_VertexAICommon, Embeddings):
"""Google Cloud VertexAI embedding models."""

Expand All @@ -46,7 +65,15 @@ def validate_environment(cls, values: Dict) -> Dict:
"textembedding-gecko@001"
)
values["model_name"] = "textembedding-gecko@001"
values["client"] = TextEmbeddingModel.from_pretrained(values["model_name"])
if (
GoogleEmbeddingModelType(values["model_name"])
== GoogleEmbeddingModelType.MULTIMODAL
):
values["client"] = MultiModalEmbeddingModel.from_pretrained(
values["model_name"]
)
else:
values["client"] = TextEmbeddingModel.from_pretrained(values["model_name"])
return values

def __init__(
Expand Down Expand Up @@ -83,6 +110,20 @@ def __init__(
"embeddings_task_type_supported"
] = not self.client._endpoint_name.endswith("/textembedding-gecko@001")

retry_errors: List[Type[BaseException]] = [
ResourceExhausted,
ServiceUnavailable,
Aborted,
DeadlineExceeded,
]
self.instance["retry_decorator"] = create_base_retry_decorator(
error_types=retry_errors, max_retries=self.max_retries
)

@property
def model_type(self) -> str:
return GoogleEmbeddingModelType(self.model_name)

@staticmethod
def _split_by_punctuation(text: str) -> List[str]:
"""Splits a string by punctuation and whitespace characters."""
Expand Down Expand Up @@ -321,6 +362,8 @@ def embed_documents(
Returns:
List of embeddings, one for each text.
"""
if self.model_type != GoogleEmbeddingModelType.TEXT:
raise NotImplementedError("Not supported for multimodal models")
return self.embed(texts, batch_size, "RETRIEVAL_DOCUMENT")

def embed_query(self, text: str) -> List[float]:
Expand All @@ -332,5 +375,25 @@ def embed_query(self, text: str) -> List[float]:
Returns:
Embedding for the text.
"""
if self.model_type != GoogleEmbeddingModelType.TEXT:
raise NotImplementedError("Not supported for multimodal models")
embeddings = self.embed([text], 1, "RETRIEVAL_QUERY")
return embeddings[0]

def embed_image(self, image_path: str) -> List[float]:
"""Embed an image.

Args:
image_path: Path to image (local or Google Cloud Storage) to generate
embeddings for.

Returns:
Embedding for the image.
"""
if self.model_type != GoogleEmbeddingModelType.MULTIMODAL:
raise NotImplementedError("Only supported for multimodal models")

embed_with_retry = self.instance["retry_decorator"](self.client.get_embeddings)
image = Image.load_from_file(image_path)
result: MultiModalEmbeddingResponse = embed_with_retry(image=image)
return result.image_embedding
39 changes: 39 additions & 0 deletions libs/vertexai/tests/integration_tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import base64

import pytest
from _pytest.tmpdir import TempPathFactory
from vertexai.vision_models import Image # type: ignore


@pytest.fixture
def base64_image() -> str:
return (
""
"BHNCSVQICAgIfAhkiAAAAAlwSFlzAAAApgAAAKYB3X3/OAAAABl0RVh0U29mdHdhcmUAd3"
"d3Lmlua3NjYXBlLm9yZ5vuPBoAAANCSURBVEiJtZZPbBtFFMZ/M7ubXdtdb1xSFyeilBap"
"ySVU8h8OoFaooFSqiihIVIpQBKci6KEg9Q6H9kovIHoCIVQJJCKE1ENFjnAgcaSGC6rEnx"
"BwA04Tx43t2FnvDAfjkNibxgHxnWb2e/u992bee7tCa00YFsffekFY+nUzFtjW0LrvjRXr"
"CDIAaPLlW0nHL0SsZtVoaF98mLrx3pdhOqLtYPHChahZcYYO7KvPFxvRl5XPp1sN3adWiD"
"1ZAqD6XYK1b/dvE5IWryTt2udLFedwc1+9kLp+vbbpoDh+6TklxBeAi9TL0taeWpdmZzQD"
"ry0AcO+jQ12RyohqqoYoo8RDwJrU+qXkjWtfi8Xxt58BdQuwQs9qC/afLwCw8tnQbqYAPs"
"gxE1S6F3EAIXux2oQFKm0ihMsOF71dHYx+f3NND68ghCu1YIoePPQN1pGRABkJ6Bus96Cu"
"tRZMydTl+TvuiRW1m3n0eDl0vRPcEysqdXn+jsQPsrHMquGeXEaY4Yk4wxWcY5V/9scqOM"
"OVUFthatyTy8QyqwZ+kDURKoMWxNKr2EeqVKcTNOajqKoBgOE28U4tdQl5p5bwCw7BWqua"
"ZSzAPlwjlithJtp3pTImSqQRrb2Z8PHGigD4RZuNX6JYj6wj7O4TFLbCO/Mn/m8R+h6rYS"
"Ub3ekokRY6f/YukArN979jcW+V/S8g0eT/N3VN3kTqWbQ428m9/8k0P/1aIhF36PccEl6E"
"hOcAUCrXKZXXWS3XKd2vc/TRBG9O5ELC17MmWubD2nKhUKZa26Ba2+D3P+4/MNCFwg59oW"
"VeYhkzgN/JDR8deKBoD7Y+ljEjGZ0sosXVTvbc6RHirr2reNy1OXd6pJsQ+gqjk8VWFYmH"
"rwBzW/n+uMPFiRwHB2I7ih8ciHFxIkd/3Omk5tCDV1t+2nNu5sxxpDFNx+huNhVT3/zMDz"
"8usXC3ddaHBj1GHj/As08fwTS7Kt1HBTmyN29vdwAw+/wbwLVOJ3uAD1wi/dUH7Qei66Pf"
"yuRj4Ik9is+hglfbkbfR3cnZm7chlUWLdwmprtCohX4HUtlOcQjLYCu+fzGJH2QRKvP3UN"
"z8bWk1qMxjGTOMThZ3kvgLI5AzFfo379UAAAAASUVORK5CYII="
)


@pytest.fixture
def tmp_image(tmp_path_factory: TempPathFactory, base64_image) -> str:
img_data = base64.b64decode(base64_image.split(",")[1])
image = Image(image_bytes=img_data)
fn = tmp_path_factory.mktemp("data") / "img.png"
image.save(str(fn))
return str(fn)
28 changes: 27 additions & 1 deletion libs/vertexai/tests/integration_tests/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,13 @@
`gcloud auth login` first).
"""
import pytest
from vertexai.language_models import TextEmbeddingModel # type: ignore
from vertexai.vision_models import MultiModalEmbeddingModel # type: ignore

from langchain_google_vertexai.embeddings import VertexAIEmbeddings
from langchain_google_vertexai.embeddings import (
GoogleEmbeddingModelType,
VertexAIEmbeddings,
)


@pytest.mark.release
Expand Down Expand Up @@ -74,3 +79,24 @@ def test_warning(caplog: pytest.LogCaptureFixture) -> None:
"Feb-01-2024. Currently the default is set to textembedding-gecko@001"
)
assert record.message == expected_message


@pytest.mark.release
def test_langchain_google_vertexai_image_embeddings(tmp_image) -> None:
model = VertexAIEmbeddings(model_name="multimodalembedding")
lkuligin marked this conversation as resolved.
Show resolved Hide resolved
output = model.embed_image(tmp_image)
assert len(output) == 1408


@pytest.mark.release
def test_langchain_google_vertexai_text_model() -> None:
embeddings_model = VertexAIEmbeddings(model_name="textembedding-gecko@001")
assert isinstance(embeddings_model.client, TextEmbeddingModel)
assert embeddings_model.model_type == GoogleEmbeddingModelType.TEXT


@pytest.mark.release
def test_langchain_google_vertexai_multimodal_model() -> None:
embeddings_model = VertexAIEmbeddings(model_name="multimodalembedding@001")
assert isinstance(embeddings_model.client, MultiModalEmbeddingModel)
assert embeddings_model.model_type == GoogleEmbeddingModelType.MULTIMODAL
25 changes: 0 additions & 25 deletions libs/vertexai/tests/integration_tests/test_vision_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,28 +131,3 @@ def test_vertex_ai_image_generation_and_edition():

response = editor.invoke(messages)
assert isinstance(response, AIMessage)


@pytest.fixture
def base64_image() -> str:
return (
""
"BHNCSVQICAgIfAhkiAAAAAlwSFlzAAAApgAAAKYB3X3/OAAAABl0RVh0U29mdHdhcmUAd3"
"d3Lmlua3NjYXBlLm9yZ5vuPBoAAANCSURBVEiJtZZPbBtFFMZ/M7ubXdtdb1xSFyeilBap"
"ySVU8h8OoFaooFSqiihIVIpQBKci6KEg9Q6H9kovIHoCIVQJJCKE1ENFjnAgcaSGC6rEnx"
"BwA04Tx43t2FnvDAfjkNibxgHxnWb2e/u992bee7tCa00YFsffekFY+nUzFtjW0LrvjRXr"
"CDIAaPLlW0nHL0SsZtVoaF98mLrx3pdhOqLtYPHChahZcYYO7KvPFxvRl5XPp1sN3adWiD"
"1ZAqD6XYK1b/dvE5IWryTt2udLFedwc1+9kLp+vbbpoDh+6TklxBeAi9TL0taeWpdmZzQD"
"ry0AcO+jQ12RyohqqoYoo8RDwJrU+qXkjWtfi8Xxt58BdQuwQs9qC/afLwCw8tnQbqYAPs"
"gxE1S6F3EAIXux2oQFKm0ihMsOF71dHYx+f3NND68ghCu1YIoePPQN1pGRABkJ6Bus96Cu"
"tRZMydTl+TvuiRW1m3n0eDl0vRPcEysqdXn+jsQPsrHMquGeXEaY4Yk4wxWcY5V/9scqOM"
"OVUFthatyTy8QyqwZ+kDURKoMWxNKr2EeqVKcTNOajqKoBgOE28U4tdQl5p5bwCw7BWqua"
"ZSzAPlwjlithJtp3pTImSqQRrb2Z8PHGigD4RZuNX6JYj6wj7O4TFLbCO/Mn/m8R+h6rYS"
"Ub3ekokRY6f/YukArN979jcW+V/S8g0eT/N3VN3kTqWbQ428m9/8k0P/1aIhF36PccEl6E"
"hOcAUCrXKZXXWS3XKd2vc/TRBG9O5ELC17MmWubD2nKhUKZa26Ba2+D3P+4/MNCFwg59oW"
"VeYhkzgN/JDR8deKBoD7Y+ljEjGZ0sosXVTvbc6RHirr2reNy1OXd6pJsQ+gqjk8VWFYmH"
"rwBzW/n+uMPFiRwHB2I7ih8ciHFxIkd/3Omk5tCDV1t+2nNu5sxxpDFNx+huNhVT3/zMDz"
"8usXC3ddaHBj1GHj/As08fwTS7Kt1HBTmyN29vdwAw+/wbwLVOJ3uAD1wi/dUH7Qei66Pf"
"yuRj4Ik9is+hglfbkbfR3cnZm7chlUWLdwmprtCohX4HUtlOcQjLYCu+fzGJH2QRKvP3UN"
"z8bWk1qMxjGTOMThZ3kvgLI5AzFfo379UAAAAASUVORK5CYII="
)
51 changes: 51 additions & 0 deletions libs/vertexai/tests/unit_tests/test_embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from typing import Any, Dict
from unittest.mock import MagicMock

import pytest
from pydantic.v1 import root_validator

from langchain_google_vertexai import VertexAIEmbeddings
from langchain_google_vertexai.embeddings import GoogleEmbeddingModelType


def test_langchain_google_vertexai_embed_image_multimodal_only() -> None:
mock_embeddings = MockVertexAIEmbeddings("textembedding-gecko@001")
assert mock_embeddings.model_type == GoogleEmbeddingModelType.TEXT
with pytest.raises(NotImplementedError) as e:
mock_embeddings.embed_image("test")
assert e.value == "Only supported for multimodal models"


def test_langchain_google_vertexai_embed_documents_text_only() -> None:
mock_embeddings = MockVertexAIEmbeddings("multimodalembedding@001")
assert mock_embeddings.model_type == GoogleEmbeddingModelType.MULTIMODAL
with pytest.raises(NotImplementedError) as e:
mock_embeddings.embed_documents(["test"])
assert e.value == "Not supported for multimodal models"


def test_langchain_google_vertexai_embed_query_text_only() -> None:
mock_embeddings = MockVertexAIEmbeddings("multimodalembedding@001")
assert mock_embeddings.model_type == GoogleEmbeddingModelType.MULTIMODAL
with pytest.raises(NotImplementedError) as e:
mock_embeddings.embed_query("test")
assert e.value == "Not supported for multimodal models"


class MockVertexAIEmbeddings(VertexAIEmbeddings):
"""
A mock class for avoiding instantiating VertexAI and the EmbeddingModel client
instance during init
"""

def __init__(self, model_name, **kwargs: Any) -> None:
super().__init__(model_name, **kwargs)

@classmethod
def _init_vertexai(cls, values: Dict) -> None:
pass

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
values["client"] = MagicMock()
return values
Loading