From 2315d3c8b4c23351e1df704ce3a938e7caded6f3 Mon Sep 17 00:00:00 2001 From: Leonid Kuligin Date: Tue, 12 Dec 2023 19:35:08 +0100 Subject: [PATCH] changed default for VertexAIEmbeddings --- .../langchain_community/embeddings/vertexai.py | 13 ++++++++++--- .../integration_tests/embeddings/test_vertexai.py | 15 +++++++++++++++ 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/libs/community/langchain_community/embeddings/vertexai.py b/libs/community/langchain_community/embeddings/vertexai.py index 700aa36af3188..821708a7825ba 100644 --- a/libs/community/langchain_community/embeddings/vertexai.py +++ b/libs/community/langchain_community/embeddings/vertexai.py @@ -29,21 +29,28 @@ class VertexAIEmbeddings(_VertexAICommon, Embeddings): def validate_environment(cls, values: Dict) -> Dict: """Validates that the python package exists in environment.""" cls._try_init_vertexai(values) + if values["model_name"] == "textembedding-gecko-default": + logger.warning( + "Model_name will become a required arg for VertexAIEmbeddings " + "starting from Feb-01-2024. Currently the default is set to " + "textembedding-gecko@001" + ) + values["model_name"] = "textembedding-gecko@001" try: from vertexai.language_models import TextEmbeddingModel - - values["client"] = TextEmbeddingModel.from_pretrained(values["model_name"]) except ImportError: raise_vertex_import_error() + values["client"] = TextEmbeddingModel.from_pretrained(values["model_name"]) return values def __init__( self, + # the default value would be removed after Feb-01-2024 + model_name: str = "textembedding-gecko-default", project: Optional[str] = None, location: str = "us-central1", request_parallelism: int = 5, max_retries: int = 6, - model_name: str = "textembedding-gecko", credentials: Optional[Any] = None, **kwargs: Any, ): diff --git a/libs/community/tests/integration_tests/embeddings/test_vertexai.py b/libs/community/tests/integration_tests/embeddings/test_vertexai.py index 98682baddfca7..008321a49288b 100644 --- a/libs/community/tests/integration_tests/embeddings/test_vertexai.py +++ b/libs/community/tests/integration_tests/embeddings/test_vertexai.py @@ -5,6 +5,8 @@ Your end-user credentials would be used to make the calls (make sure you've run `gcloud auth login` first). """ +import pytest + from langchain_community.embeddings import VertexAIEmbeddings @@ -15,6 +17,7 @@ def test_embedding_documents() -> None: assert len(output) == 1 assert len(output[0]) == 768 assert model.model_name == model.client._model_id + assert model.model_name == "textembedding-gecko@001" def test_embedding_query() -> None: @@ -50,3 +53,15 @@ def test_paginated_texts() -> None: assert len(output) == 8 assert len(output[0]) == 768 assert model.model_name == model.client._model_id + + +def test_warning(caplog: pytest.LogCaptureFixture) -> None: + _ = VertexAIEmbeddings() + assert len(caplog.records) == 1 + record = caplog.records[0] + assert record.levelname == "WARNING" + expected_message = ( + "Model_name will become a required arg for VertexAIEmbeddings starting from " + "Feb-01-2024. Currently the default is set to textembedding-gecko@001" + ) + assert record.message == expected_message