Skip to content

Commit

Permalink
removed default model_name for embeddings (#90)
Browse files Browse the repository at this point in the history
  • Loading branch information
lkuligin authored Mar 26, 2024
1 parent 73f83ea commit ff57b2b
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 25 deletions.
10 changes: 1 addition & 9 deletions libs/vertexai/langchain_google_vertexai/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,6 @@ class VertexAIEmbeddings(_VertexAICommon, Embeddings):
def validate_environment(cls, values: Dict) -> Dict:
"""Validates that the python package exists in environment."""
cls._init_vertexai(values)
if values["model_name"] == "textembedding-gecko-default":
logger.warning(
"Model_name will become a required arg for VertexAIEmbeddings "
"starting from Feb-01-2024. Currently the default is set to "
"textembedding-gecko@001"
)
values["model_name"] = "textembedding-gecko@001"
_, user_agent = get_user_agent(f"{cls.__name__}_{values['model_name']}") # type: ignore
with telemetry.tool_context_manager(user_agent):
if (
Expand All @@ -85,8 +78,7 @@ def validate_environment(cls, values: Dict) -> Dict:

def __init__(
self,
# the default value would be removed after Feb-01-2024
model_name: str = "textembedding-gecko-default",
model_name: str,
project: Optional[str] = None,
location: str = "us-central1",
request_parallelism: int = 5,
Expand Down
23 changes: 7 additions & 16 deletions libs/vertexai/tests/integration_tests/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
@pytest.mark.release
def test_initialization() -> None:
"""Test embedding model initialization."""
VertexAIEmbeddings()
VertexAIEmbeddings(model_name="textembedding-gecko@001")


@pytest.mark.release
Expand Down Expand Up @@ -56,27 +56,18 @@ def test_langchain_google_vertexai_embedding_query(model_name, embeddings_dim) -
@pytest.mark.release
def test_langchain_google_vertexai_large_batches() -> None:
documents = ["foo bar" for _ in range(0, 251)]
model_uscentral1 = VertexAIEmbeddings(location="us-central1")
model_asianortheast1 = VertexAIEmbeddings(location="asia-northeast1")
model_uscentral1 = VertexAIEmbeddings(
model_name="textembedding-gecko@001", location="us-central1"
)
model_asianortheast1 = VertexAIEmbeddings(
model_name="textembedding-gecko@001", location="asia-northeast1"
)
model_uscentral1.embed_documents(documents)
model_asianortheast1.embed_documents(documents)
assert model_uscentral1.instance["batch_size"] >= 250
assert model_asianortheast1.instance["batch_size"] < 50


@pytest.mark.release
def test_warning(caplog: pytest.LogCaptureFixture) -> None:
_ = VertexAIEmbeddings()
assert len(caplog.records) == 1
record = caplog.records[0]
assert record.levelname == "WARNING"
expected_message = (
"Model_name will become a required arg for VertexAIEmbeddings starting from "
"Feb-01-2024. Currently the default is set to textembedding-gecko@001"
)
assert record.message == expected_message


@pytest.mark.release
def test_langchain_google_vertexai_image_embeddings(tmp_image) -> None:
model = VertexAIEmbeddings(model_name="multimodalembedding")
Expand Down

0 comments on commit ff57b2b

Please sign in to comment.