Skip to content

Commit

Permalink
changed default for VertexAIEmbeddings
Browse files Browse the repository at this point in the history
  • Loading branch information
lkuligin committed Dec 18, 2023
1 parent 5fc2c57 commit 2315d3c
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 3 deletions.
13 changes: 10 additions & 3 deletions libs/community/langchain_community/embeddings/vertexai.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,28 @@ class VertexAIEmbeddings(_VertexAICommon, Embeddings):
def validate_environment(cls, values: Dict) -> Dict:
"""Validates that the python package exists in environment."""
cls._try_init_vertexai(values)
if values["model_name"] == "textembedding-gecko-default":
logger.warning(
"Model_name will become a required arg for VertexAIEmbeddings "
"starting from Feb-01-2024. Currently the default is set to "
"textembedding-gecko@001"
)
values["model_name"] = "textembedding-gecko@001"
try:
from vertexai.language_models import TextEmbeddingModel

values["client"] = TextEmbeddingModel.from_pretrained(values["model_name"])
except ImportError:
raise_vertex_import_error()
values["client"] = TextEmbeddingModel.from_pretrained(values["model_name"])
return values

def __init__(
self,
# the default value would be removed after Feb-01-2024
model_name: str = "textembedding-gecko-default",
project: Optional[str] = None,
location: str = "us-central1",
request_parallelism: int = 5,
max_retries: int = 6,
model_name: str = "textembedding-gecko",
credentials: Optional[Any] = None,
**kwargs: Any,
):
Expand Down
15 changes: 15 additions & 0 deletions libs/community/tests/integration_tests/embeddings/test_vertexai.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
Your end-user credentials would be used to make the calls (make sure you've run
`gcloud auth login` first).
"""
import pytest

from langchain_community.embeddings import VertexAIEmbeddings


Expand All @@ -15,6 +17,7 @@ def test_embedding_documents() -> None:
assert len(output) == 1
assert len(output[0]) == 768
assert model.model_name == model.client._model_id
assert model.model_name == "textembedding-gecko@001"


def test_embedding_query() -> None:
Expand Down Expand Up @@ -50,3 +53,15 @@ def test_paginated_texts() -> None:
assert len(output) == 8
assert len(output[0]) == 768
assert model.model_name == model.client._model_id


def test_warning(caplog: pytest.LogCaptureFixture) -> None:
_ = VertexAIEmbeddings()
assert len(caplog.records) == 1
record = caplog.records[0]
assert record.levelname == "WARNING"
expected_message = (
"Model_name will become a required arg for VertexAIEmbeddings starting from "
"Feb-01-2024. Currently the default is set to textembedding-gecko@001"
)
assert record.message == expected_message

0 comments on commit 2315d3c

Please sign in to comment.