diff --git a/libs/vertexai/langchain_google_vertexai/embeddings.py b/libs/vertexai/langchain_google_vertexai/embeddings.py index a849462d..78b3ec48 100644 --- a/libs/vertexai/langchain_google_vertexai/embeddings.py +++ b/libs/vertexai/langchain_google_vertexai/embeddings.py @@ -116,9 +116,12 @@ def __init__( Aborted, DeadlineExceeded, ] - self.instance["retry_decorator"] = create_base_retry_decorator( + retry_decorator = create_base_retry_decorator( error_types=retry_errors, max_retries=self.max_retries ) + self.instance["get_embeddings_with_retry"] = retry_decorator( + self.client.get_embeddings + ) @property def model_type(self) -> str: @@ -188,30 +191,41 @@ def _get_embeddings_with_retry( self, texts: List[str], embeddings_type: Optional[str] = None ) -> List[List[float]]: """Makes a Vertex AI model request with retry logic.""" - - errors: List[Type[BaseException]] = [ - ResourceExhausted, - ServiceUnavailable, - Aborted, - DeadlineExceeded, - ] - retry_decorator = create_base_retry_decorator( - error_types=errors, max_retries=self.max_retries + if self.model_type == GoogleEmbeddingModelType.MULTIMODAL: + return self._get_multimodal_embeddings_with_retry(texts) + return self._get_text_embeddings_with_retry( + texts, embeddings_type=embeddings_type ) - @retry_decorator - def _completion_with_retry(texts_to_process: List[str]) -> Any: - if embeddings_type and self.instance["embeddings_task_type_supported"]: - requests = [ - TextEmbeddingInput(text=t, task_type=embeddings_type) - for t in texts_to_process - ] - else: - requests = texts_to_process - embeddings = self.client.get_embeddings(requests) - return [embs.values for embs in embeddings] + def _get_multimodal_embeddings_with_retry( + self, texts: List[str] + ) -> List[List[float]]: + tasks = [] + for text in texts: + tasks.append( + self.instance["task_executor"].submit( + self.instance["get_embeddings_with_retry"], + contextual_text=text, + ) + ) + if len(tasks) > 0: + wait(tasks) + embeddings = [task.result().text_embedding for task in tasks] + return embeddings - return _completion_with_retry(texts) + def _get_text_embeddings_with_retry( + self, texts: List[str], embeddings_type: Optional[str] = None + ) -> List[List[float]]: + """Makes a Vertex AI model request with retry logic.""" + + if embeddings_type and self.instance["embeddings_task_type_supported"]: + requests = [ + TextEmbeddingInput(text=t, task_type=embeddings_type) for t in texts + ] + else: + requests = texts + embeddings = self.instance["get_embeddings_with_retry"](requests) + return [embedding.values for embedding in embeddings] def _prepare_and_validate_batches( self, texts: List[str], embeddings_type: Optional[str] = None @@ -240,7 +254,7 @@ def _prepare_and_validate_batches( return [], VertexAIEmbeddings._prepare_batches( texts, self.instance["batch_size"] ) - # Figure out largest possible batch size by trying to push + # Figure out the largest possible batch size by trying to push # batches and lowering their size in half after every failure. first_batch = batches[0] first_result = [] @@ -362,8 +376,6 @@ def embed_documents( Returns: List of embeddings, one for each text. """ - if self.model_type != GoogleEmbeddingModelType.TEXT: - raise NotImplementedError("Not supported for multimodal models") return self.embed(texts, batch_size, "RETRIEVAL_DOCUMENT") def embed_query(self, text: str) -> List[float]: @@ -375,10 +387,7 @@ def embed_query(self, text: str) -> List[float]: Returns: Embedding for the text. """ - if self.model_type != GoogleEmbeddingModelType.TEXT: - raise NotImplementedError("Not supported for multimodal models") - embeddings = self.embed([text], 1, "RETRIEVAL_QUERY") - return embeddings[0] + return self.embed([text], 1, "RETRIEVAL_QUERY")[0] def embed_image(self, image_path: str) -> List[float]: """Embed an image. @@ -393,7 +402,8 @@ def embed_image(self, image_path: str) -> List[float]: if self.model_type != GoogleEmbeddingModelType.MULTIMODAL: raise NotImplementedError("Only supported for multimodal models") - embed_with_retry = self.instance["retry_decorator"](self.client.get_embeddings) image = Image.load_from_file(image_path) - result: MultiModalEmbeddingResponse = embed_with_retry(image=image) + result: MultiModalEmbeddingResponse = self.instance[ + "get_embeddings_with_retry" + ](image=image) return result.image_embedding diff --git a/libs/vertexai/tests/integration_tests/test_embeddings.py b/libs/vertexai/tests/integration_tests/test_embeddings.py index 56fb03cf..600c09ac 100644 --- a/libs/vertexai/tests/integration_tests/test_embeddings.py +++ b/libs/vertexai/tests/integration_tests/test_embeddings.py @@ -20,22 +20,37 @@ def test_initialization() -> None: @pytest.mark.release -def test_langchain_google_vertexai_embedding_documents() -> None: - documents = ["foo bar"] - model = VertexAIEmbeddings() +@pytest.mark.parametrize( + "number_of_docs", + [1, 8], +) +@pytest.mark.parametrize( + "model_name, embeddings_dim", + [("textembedding-gecko@001", 768), ("multimodalembedding@001", 1408)], +) +def test_langchain_google_vertexai_embedding_documents( + number_of_docs: int, model_name: str, embeddings_dim: int +) -> None: + documents = ["foo bar"] * number_of_docs + model = VertexAIEmbeddings(model_name) output = model.embed_documents(documents) - assert len(output) == 1 - assert len(output[0]) == 768 + assert len(output) == number_of_docs + for embedding in output: + assert len(embedding) == embeddings_dim assert model.model_name == model.client._model_id - assert model.model_name == "textembedding-gecko@001" + assert model.model_name == model_name @pytest.mark.release -def test_langchain_google_vertexai_embedding_query() -> None: +@pytest.mark.parametrize( + "model_name, embeddings_dim", + [("textembedding-gecko@001", 768), ("multimodalembedding@001", 1408)], +) +def test_langchain_google_vertexai_embedding_query(model_name, embeddings_dim) -> None: document = "foo bar" - model = VertexAIEmbeddings() + model = VertexAIEmbeddings(model_name) output = model.embed_query(document) - assert len(output) == 768 + assert len(output) == embeddings_dim @pytest.mark.release @@ -49,25 +64,6 @@ def test_langchain_google_vertexai_large_batches() -> None: assert model_asianortheast1.instance["batch_size"] < 50 -@pytest.mark.release -def test_langchain_google_vertexai_paginated_texts() -> None: - documents = [ - "foo bar", - "foo baz", - "bar foo", - "baz foo", - "bar bar", - "foo foo", - "baz baz", - "baz bar", - ] - model = VertexAIEmbeddings() - output = model.embed_documents(documents) - assert len(output) == 8 - assert len(output[0]) == 768 - assert model.model_name == model.client._model_id - - @pytest.mark.release def test_warning(caplog: pytest.LogCaptureFixture) -> None: _ = VertexAIEmbeddings() diff --git a/libs/vertexai/tests/unit_tests/test_embeddings.py b/libs/vertexai/tests/unit_tests/test_embeddings.py index 8e1d86f4..2a4bd299 100644 --- a/libs/vertexai/tests/unit_tests/test_embeddings.py +++ b/libs/vertexai/tests/unit_tests/test_embeddings.py @@ -16,22 +16,6 @@ def test_langchain_google_vertexai_embed_image_multimodal_only() -> None: assert e.value == "Only supported for multimodal models" -def test_langchain_google_vertexai_embed_documents_text_only() -> None: - mock_embeddings = MockVertexAIEmbeddings("multimodalembedding@001") - assert mock_embeddings.model_type == GoogleEmbeddingModelType.MULTIMODAL - with pytest.raises(NotImplementedError) as e: - mock_embeddings.embed_documents(["test"]) - assert e.value == "Not supported for multimodal models" - - -def test_langchain_google_vertexai_embed_query_text_only() -> None: - mock_embeddings = MockVertexAIEmbeddings("multimodalembedding@001") - assert mock_embeddings.model_type == GoogleEmbeddingModelType.MULTIMODAL - with pytest.raises(NotImplementedError) as e: - mock_embeddings.embed_query("test") - assert e.value == "Not supported for multimodal models" - - class MockVertexAIEmbeddings(VertexAIEmbeddings): """ A mock class for avoiding instantiating VertexAI and the EmbeddingModel client