diff --git a/libs/vertexai/langchain_google_vertexai/embeddings.py b/libs/vertexai/langchain_google_vertexai/embeddings.py index 61682b87..a849462d 100644 --- a/libs/vertexai/langchain_google_vertexai/embeddings.py +++ b/libs/vertexai/langchain_google_vertexai/embeddings.py @@ -3,6 +3,7 @@ import string import threading from concurrent.futures import ThreadPoolExecutor, wait +from enum import Enum, auto from typing import Any, Dict, List, Literal, Optional, Tuple, Type from google.api_core.exceptions import ( @@ -34,6 +35,19 @@ _MIN_BATCH_SIZE = 5 +class GoogleEmbeddingModelType(str, Enum): + TEXT = auto() + MULTIMODAL = auto() + + @classmethod + def _missing_(cls, value: Any) -> Optional["GoogleEmbeddingModelType"]: + if "textembedding-gecko" in value.lower(): + return GoogleEmbeddingModelType.TEXT + elif "multimodalembedding" in value.lower(): + return GoogleEmbeddingModelType.MULTIMODAL + return None + + class VertexAIEmbeddings(_VertexAICommon, Embeddings): """Google Cloud VertexAI embedding models.""" @@ -51,7 +65,10 @@ def validate_environment(cls, values: Dict) -> Dict: "textembedding-gecko@001" ) values["model_name"] = "textembedding-gecko@001" - if cls._is_multimodal_model(values["model_name"]): + if ( + GoogleEmbeddingModelType(values["model_name"]) + == GoogleEmbeddingModelType.MULTIMODAL + ): values["client"] = MultiModalEmbeddingModel.from_pretrained( values["model_name"] ) @@ -92,6 +109,7 @@ def __init__( self.instance[ "embeddings_task_type_supported" ] = not self.client._endpoint_name.endswith("/textembedding-gecko@001") + retry_errors: List[Type[BaseException]] = [ ResourceExhausted, ServiceUnavailable, @@ -102,6 +120,10 @@ def __init__( error_types=retry_errors, max_retries=self.max_retries ) + @property + def model_type(self) -> str: + return GoogleEmbeddingModelType(self.model_name) + @staticmethod def _split_by_punctuation(text: str) -> List[str]: """Splits a string by punctuation and whitespace characters.""" @@ -340,7 +362,7 @@ def embed_documents( Returns: List of embeddings, one for each text. """ - if self._is_multimodal_model(self.model_name): + if self.model_type != GoogleEmbeddingModelType.TEXT: raise NotImplementedError("Not supported for multimodal models") return self.embed(texts, batch_size, "RETRIEVAL_DOCUMENT") @@ -353,7 +375,7 @@ def embed_query(self, text: str) -> List[float]: Returns: Embedding for the text. """ - if self._is_multimodal_model(self.model_name): + if self.model_type != GoogleEmbeddingModelType.TEXT: raise NotImplementedError("Not supported for multimodal models") embeddings = self.embed([text], 1, "RETRIEVAL_QUERY") return embeddings[0] @@ -368,23 +390,10 @@ def embed_image(self, image_path: str) -> List[float]: Returns: Embedding for the image. """ - if not self._is_multimodal_model(self.model_name): + if self.model_type != GoogleEmbeddingModelType.MULTIMODAL: raise NotImplementedError("Only supported for multimodal models") embed_with_retry = self.instance["retry_decorator"](self.client.get_embeddings) image = Image.load_from_file(image_path) result: MultiModalEmbeddingResponse = embed_with_retry(image=image) return result.image_embedding - - @staticmethod - def _is_multimodal_model(model_name: str) -> bool: - """ - Check if the embeddings model is multimodal or not. - - Args: - model_name: The embeddings model name. - - Returns: - A boolean, True if the model is multimodal. - """ - return "multimodalembedding" in model_name diff --git a/libs/vertexai/tests/integration_tests/test_embeddings.py b/libs/vertexai/tests/integration_tests/test_embeddings.py index 08e9abd2..ed1abc79 100644 --- a/libs/vertexai/tests/integration_tests/test_embeddings.py +++ b/libs/vertexai/tests/integration_tests/test_embeddings.py @@ -7,7 +7,10 @@ from vertexai.language_models import TextEmbeddingModel from vertexai.vision_models import MultiModalEmbeddingModel -from langchain_google_vertexai.embeddings import VertexAIEmbeddings +from langchain_google_vertexai.embeddings import ( + VertexAIEmbeddings, + GoogleEmbeddingModelType, +) @pytest.mark.release @@ -87,10 +90,10 @@ def test_langchain_google_vertexai_image_embeddings(tmp_image) -> None: def test_langchain_google_vertexai_text_model() -> None: embeddings_model = VertexAIEmbeddings(model_name="textembedding-gecko@001") assert isinstance(embeddings_model.client, TextEmbeddingModel) - assert not embeddings_model._is_multimodal_model(embeddings_model.model_name) + assert embeddings_model.model_type == GoogleEmbeddingModelType.TEXT def test_langchain_google_vertexai_multimodal_model() -> None: embeddings_model = VertexAIEmbeddings(model_name="multimodalembedding@001") assert isinstance(embeddings_model.client, MultiModalEmbeddingModel) - assert embeddings_model._is_multimodal_model(embeddings_model.model_name) + assert embeddings_model.model_type == GoogleEmbeddingModelType.MULTIMODAL diff --git a/libs/vertexai/tests/unit_tests/test_embeddings.py b/libs/vertexai/tests/unit_tests/test_embeddings.py index e3581ed8..664e103a 100644 --- a/libs/vertexai/tests/unit_tests/test_embeddings.py +++ b/libs/vertexai/tests/unit_tests/test_embeddings.py @@ -4,10 +4,12 @@ import pytest from langchain_google_vertexai import VertexAIEmbeddings +from langchain_google_vertexai.embeddings import GoogleEmbeddingModelType def test_langchain_google_vertexai_embed_image_multimodal_only() -> None: mock_embeddings = MockVertexAIEmbeddings("textembedding-gecko@001") + assert mock_embeddings.model_type == GoogleEmbeddingModelType.TEXT with pytest.raises(NotImplementedError) as e: mock_embeddings.embed_image("test") assert e.value == "Only supported for multimodal models" @@ -15,6 +17,7 @@ def test_langchain_google_vertexai_embed_image_multimodal_only() -> None: def test_langchain_google_vertexai_embed_documents_text_only() -> None: mock_embeddings = MockVertexAIEmbeddings("multimodalembedding@001") + assert mock_embeddings.model_type == GoogleEmbeddingModelType.MULTIMODAL with pytest.raises(NotImplementedError) as e: mock_embeddings.embed_documents(["test"]) assert e.value == "Not supported for multimodal models" @@ -22,6 +25,7 @@ def test_langchain_google_vertexai_embed_documents_text_only() -> None: def test_langchain_google_vertexai_embed_query_text_only() -> None: mock_embeddings = MockVertexAIEmbeddings("multimodalembedding@001") + assert mock_embeddings.model_type == GoogleEmbeddingModelType.MULTIMODAL with pytest.raises(NotImplementedError) as e: mock_embeddings.embed_query("test") assert e.value == "Not supported for multimodal models"