Skip to content

Commit

Permalink
feat: change is_smth to enum and property approach for embeddings
Browse files Browse the repository at this point in the history
  • Loading branch information
svidiella committed Feb 29, 2024
1 parent e3d9ef3 commit 85a1b40
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 20 deletions.
43 changes: 26 additions & 17 deletions libs/vertexai/langchain_google_vertexai/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import string
import threading
from concurrent.futures import ThreadPoolExecutor, wait
from enum import Enum, auto
from typing import Any, Dict, List, Literal, Optional, Tuple, Type

from google.api_core.exceptions import (
Expand Down Expand Up @@ -34,6 +35,19 @@
_MIN_BATCH_SIZE = 5


class GoogleEmbeddingModelType(str, Enum):
TEXT = auto()
MULTIMODAL = auto()

@classmethod
def _missing_(cls, value: Any) -> Optional["GoogleEmbeddingModelType"]:
if "textembedding-gecko" in value.lower():
return GoogleEmbeddingModelType.TEXT
elif "multimodalembedding" in value.lower():
return GoogleEmbeddingModelType.MULTIMODAL
return None


class VertexAIEmbeddings(_VertexAICommon, Embeddings):
"""Google Cloud VertexAI embedding models."""

Expand All @@ -51,7 +65,10 @@ def validate_environment(cls, values: Dict) -> Dict:
"textembedding-gecko@001"
)
values["model_name"] = "textembedding-gecko@001"
if cls._is_multimodal_model(values["model_name"]):
if (
GoogleEmbeddingModelType(values["model_name"])
== GoogleEmbeddingModelType.MULTIMODAL
):
values["client"] = MultiModalEmbeddingModel.from_pretrained(
values["model_name"]
)
Expand Down Expand Up @@ -92,6 +109,7 @@ def __init__(
self.instance[
"embeddings_task_type_supported"
] = not self.client._endpoint_name.endswith("/textembedding-gecko@001")

retry_errors: List[Type[BaseException]] = [
ResourceExhausted,
ServiceUnavailable,
Expand All @@ -102,6 +120,10 @@ def __init__(
error_types=retry_errors, max_retries=self.max_retries
)

@property
def model_type(self) -> str:
return GoogleEmbeddingModelType(self.model_name)

@staticmethod
def _split_by_punctuation(text: str) -> List[str]:
"""Splits a string by punctuation and whitespace characters."""
Expand Down Expand Up @@ -340,7 +362,7 @@ def embed_documents(
Returns:
List of embeddings, one for each text.
"""
if self._is_multimodal_model(self.model_name):
if self.model_type != GoogleEmbeddingModelType.TEXT:
raise NotImplementedError("Not supported for multimodal models")
return self.embed(texts, batch_size, "RETRIEVAL_DOCUMENT")

Expand All @@ -353,7 +375,7 @@ def embed_query(self, text: str) -> List[float]:
Returns:
Embedding for the text.
"""
if self._is_multimodal_model(self.model_name):
if self.model_type != GoogleEmbeddingModelType.TEXT:
raise NotImplementedError("Not supported for multimodal models")
embeddings = self.embed([text], 1, "RETRIEVAL_QUERY")
return embeddings[0]
Expand All @@ -368,23 +390,10 @@ def embed_image(self, image_path: str) -> List[float]:
Returns:
Embedding for the image.
"""
if not self._is_multimodal_model(self.model_name):
if self.model_type != GoogleEmbeddingModelType.MULTIMODAL:
raise NotImplementedError("Only supported for multimodal models")

embed_with_retry = self.instance["retry_decorator"](self.client.get_embeddings)
image = Image.load_from_file(image_path)
result: MultiModalEmbeddingResponse = embed_with_retry(image=image)
return result.image_embedding

@staticmethod
def _is_multimodal_model(model_name: str) -> bool:
"""
Check if the embeddings model is multimodal or not.
Args:
model_name: The embeddings model name.
Returns:
A boolean, True if the model is multimodal.
"""
return "multimodalembedding" in model_name
9 changes: 6 additions & 3 deletions libs/vertexai/tests/integration_tests/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
from vertexai.language_models import TextEmbeddingModel
from vertexai.vision_models import MultiModalEmbeddingModel

from langchain_google_vertexai.embeddings import VertexAIEmbeddings
from langchain_google_vertexai.embeddings import (
VertexAIEmbeddings,
GoogleEmbeddingModelType,
)


@pytest.mark.release
Expand Down Expand Up @@ -87,10 +90,10 @@ def test_langchain_google_vertexai_image_embeddings(tmp_image) -> None:
def test_langchain_google_vertexai_text_model() -> None:
embeddings_model = VertexAIEmbeddings(model_name="textembedding-gecko@001")
assert isinstance(embeddings_model.client, TextEmbeddingModel)
assert not embeddings_model._is_multimodal_model(embeddings_model.model_name)
assert embeddings_model.model_type == GoogleEmbeddingModelType.TEXT


def test_langchain_google_vertexai_multimodal_model() -> None:
embeddings_model = VertexAIEmbeddings(model_name="multimodalembedding@001")
assert isinstance(embeddings_model.client, MultiModalEmbeddingModel)
assert embeddings_model._is_multimodal_model(embeddings_model.model_name)
assert embeddings_model.model_type == GoogleEmbeddingModelType.MULTIMODAL
4 changes: 4 additions & 0 deletions libs/vertexai/tests/unit_tests/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,28 @@
import pytest

from langchain_google_vertexai import VertexAIEmbeddings
from langchain_google_vertexai.embeddings import GoogleEmbeddingModelType


def test_langchain_google_vertexai_embed_image_multimodal_only() -> None:
mock_embeddings = MockVertexAIEmbeddings("textembedding-gecko@001")
assert mock_embeddings.model_type == GoogleEmbeddingModelType.TEXT
with pytest.raises(NotImplementedError) as e:
mock_embeddings.embed_image("test")
assert e.value == "Only supported for multimodal models"


def test_langchain_google_vertexai_embed_documents_text_only() -> None:
mock_embeddings = MockVertexAIEmbeddings("multimodalembedding@001")
assert mock_embeddings.model_type == GoogleEmbeddingModelType.MULTIMODAL
with pytest.raises(NotImplementedError) as e:
mock_embeddings.embed_documents(["test"])
assert e.value == "Not supported for multimodal models"


def test_langchain_google_vertexai_embed_query_text_only() -> None:
mock_embeddings = MockVertexAIEmbeddings("multimodalembedding@001")
assert mock_embeddings.model_type == GoogleEmbeddingModelType.MULTIMODAL
with pytest.raises(NotImplementedError) as e:
mock_embeddings.embed_query("test")
assert e.value == "Not supported for multimodal models"
Expand Down

0 comments on commit 85a1b40

Please sign in to comment.