From ff57b2b3916e9d406523b0ac504dc7660f680592 Mon Sep 17 00:00:00 2001 From: Leonid Kuligin Date: Tue, 26 Mar 2024 15:22:16 +0100 Subject: [PATCH] removed default model_name for embeddings (#90) --- .../langchain_google_vertexai/embeddings.py | 10 +------- .../integration_tests/test_embeddings.py | 23 ++++++------------- 2 files changed, 8 insertions(+), 25 deletions(-) diff --git a/libs/vertexai/langchain_google_vertexai/embeddings.py b/libs/vertexai/langchain_google_vertexai/embeddings.py index 1d0ff524..04424324 100644 --- a/libs/vertexai/langchain_google_vertexai/embeddings.py +++ b/libs/vertexai/langchain_google_vertexai/embeddings.py @@ -61,13 +61,6 @@ class VertexAIEmbeddings(_VertexAICommon, Embeddings): def validate_environment(cls, values: Dict) -> Dict: """Validates that the python package exists in environment.""" cls._init_vertexai(values) - if values["model_name"] == "textembedding-gecko-default": - logger.warning( - "Model_name will become a required arg for VertexAIEmbeddings " - "starting from Feb-01-2024. Currently the default is set to " - "textembedding-gecko@001" - ) - values["model_name"] = "textembedding-gecko@001" _, user_agent = get_user_agent(f"{cls.__name__}_{values['model_name']}") # type: ignore with telemetry.tool_context_manager(user_agent): if ( @@ -85,8 +78,7 @@ def validate_environment(cls, values: Dict) -> Dict: def __init__( self, - # the default value would be removed after Feb-01-2024 - model_name: str = "textembedding-gecko-default", + model_name: str, project: Optional[str] = None, location: str = "us-central1", request_parallelism: int = 5, diff --git a/libs/vertexai/tests/integration_tests/test_embeddings.py b/libs/vertexai/tests/integration_tests/test_embeddings.py index 600c09ac..fdc2481f 100644 --- a/libs/vertexai/tests/integration_tests/test_embeddings.py +++ b/libs/vertexai/tests/integration_tests/test_embeddings.py @@ -16,7 +16,7 @@ @pytest.mark.release def test_initialization() -> None: """Test embedding model initialization.""" - VertexAIEmbeddings() + VertexAIEmbeddings(model_name="textembedding-gecko@001") @pytest.mark.release @@ -56,27 +56,18 @@ def test_langchain_google_vertexai_embedding_query(model_name, embeddings_dim) - @pytest.mark.release def test_langchain_google_vertexai_large_batches() -> None: documents = ["foo bar" for _ in range(0, 251)] - model_uscentral1 = VertexAIEmbeddings(location="us-central1") - model_asianortheast1 = VertexAIEmbeddings(location="asia-northeast1") + model_uscentral1 = VertexAIEmbeddings( + model_name="textembedding-gecko@001", location="us-central1" + ) + model_asianortheast1 = VertexAIEmbeddings( + model_name="textembedding-gecko@001", location="asia-northeast1" + ) model_uscentral1.embed_documents(documents) model_asianortheast1.embed_documents(documents) assert model_uscentral1.instance["batch_size"] >= 250 assert model_asianortheast1.instance["batch_size"] < 50 -@pytest.mark.release -def test_warning(caplog: pytest.LogCaptureFixture) -> None: - _ = VertexAIEmbeddings() - assert len(caplog.records) == 1 - record = caplog.records[0] - assert record.levelname == "WARNING" - expected_message = ( - "Model_name will become a required arg for VertexAIEmbeddings starting from " - "Feb-01-2024. Currently the default is set to textembedding-gecko@001" - ) - assert record.message == expected_message - - @pytest.mark.release def test_langchain_google_vertexai_image_embeddings(tmp_image) -> None: model = VertexAIEmbeddings(model_name="multimodalembedding")