Skip to content

Commit

Permalink
[vectorstores] Add new VectorSearch class with DataStore backend (#82)
Browse files Browse the repository at this point in the history
* Add new VectorSearch class with datastore backend

---------

Co-authored-by: Jorge <[email protected]>
  • Loading branch information
jzaldi and Jorge authored Mar 22, 2024
1 parent d7bc26a commit 51bf9ef
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 16 deletions.
11 changes: 11 additions & 0 deletions libs/vertexai/langchain_google_vertexai/vectorstores/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""Public entry points for the Vertex AI Vector Search vector stores.

Re-exports the GCS- and Datastore-backed ``VectorSearchVectorStore``
variants so callers can import them directly from
``langchain_google_vertexai.vectorstores``.
"""

from langchain_google_vertexai.vectorstores.vectorstores import (
    VectorSearchVectorStore,
    VectorSearchVectorStoreDatastore,
    VectorSearchVectorStoreGCS,
)

# Explicit public API of this package.
__all__ = [
    "VectorSearchVectorStore",
    "VectorSearchVectorStoreDatastore",
    "VectorSearchVectorStoreGCS",
]
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import uuid
import warnings
from typing import Any, Iterable, List, Optional, Tuple, Type, Union
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union

from google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint import (
Namespace,
Expand All @@ -11,6 +11,7 @@
from langchain_core.vectorstores import VectorStore

from langchain_google_vertexai.vectorstores._document_storage import (
DataStoreDocumentStorage,
DocumentStorage,
GCSDocumentStorage,
)
Expand Down Expand Up @@ -310,3 +311,90 @@ def from_components( # Implemented in order to keep the current API
),
embbedings=embedding,
)


class VectorSearchVectorStoreGCS(VectorSearchVectorStore):
    """Alias of :class:`VectorSearchVectorStore`.

    Exists purely so the GCS-backed store follows the same naming scheme as
    the vector stores built on other document storage backends; it adds no
    behavior of its own.
    """


class VectorSearchVectorStoreDatastore(_BaseVertexAIVectorStore):
    """VectorSearch with DataStore document storage."""

    @classmethod
    def from_components(
        cls: Type["VectorSearchVectorStoreDatastore"],
        project_id: str,
        region: str,
        index_id: str,
        endpoint_id: str,
        index_staging_bucket_name: Optional[str] = None,
        credentials_path: Optional[str] = None,
        embedding: Optional[Embeddings] = None,
        stream_update: bool = False,
        datastore_client_kwargs: Optional[Dict[str, Any]] = None,
        datastore_kind: str = "document_id",
        datastore_text_property_name: str = "text",
        datastore_metadata_property_name: str = "metadata",
        **kwargs: Any,
    ) -> "VectorSearchVectorStoreDatastore":
        """Takes the object creation out of the constructor.

        Args:
            project_id: The GCP project id.
            region: The default location making the API calls. It must have
                the same location as the GCS bucket and must be regional.
            index_id: The id of the created index.
            endpoint_id: The id of the created endpoint.
            index_staging_bucket_name: (Optional) If the index is updated by
                batch, bucket where the data will be staged before updating
                the index. Only required when updating the index.
            credentials_path: (Optional) The path of the Google credentials on
                the local file system.
            embedding: The :class:`Embeddings` that will be used for
                embedding the texts.
            stream_update: Whether to update with streaming or batching.
                VectorSearch index must be compatible with stream/batch
                updates.
            datastore_client_kwargs: (Optional) Keyword arguments forwarded to
                the Datastore client constructor.
            datastore_kind: Datastore kind (entity type) under which documents
                are stored.
            datastore_text_property_name: Entity property that holds the
                document text.
            datastore_metadata_property_name: Entity property that holds the
                document metadata.
            kwargs: Additional keyword arguments. NOTE(review): currently not
                forwarded to the constructor — confirm whether they should be
                passed through.
        """
        # Single SDK manager for all resource lookups below.
        # (Fixes a duplicated, back-to-back construction of this object.)
        sdk_manager = VectorSearchSDKManager(
            project_id=project_id, region=region, credentials_path=credentials_path
        )

        # The staging bucket is only needed for batch index updates.
        if index_staging_bucket_name is not None:
            bucket = sdk_manager.get_gcs_bucket(bucket_name=index_staging_bucket_name)
        else:
            bucket = None

        index = sdk_manager.get_index(index_id=index_id)
        endpoint = sdk_manager.get_endpoint(endpoint_id=endpoint_id)

        if datastore_client_kwargs is None:
            datastore_client_kwargs = {}

        datastore_client = sdk_manager.get_datastore_client(**datastore_client_kwargs)

        document_storage = DataStoreDocumentStorage(
            datastore_client=datastore_client,
            kind=datastore_kind,
            text_property_name=datastore_text_property_name,
            metadata_property_name=datastore_metadata_property_name,
        )

        return cls(
            document_storage=document_storage,
            searcher=VectorSearchSearcher(
                endpoint=endpoint,
                index=index,
                staging_bucket=bucket,
                stream_update=stream_update,
            ),
            # "embbedings" (sic) matches the keyword used for the base class
            # elsewhere in this module — keep the spelling consistent.
            embbedings=embedding,
        )
63 changes: 48 additions & 15 deletions libs/vertexai/tests/integration_tests/test_vectorstores.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@
from langchain_google_vertexai.vectorstores._searcher import (
VectorSearchSearcher,
)
from langchain_google_vertexai.vectorstores.vectorstores import VectorSearchVectorStore
from langchain_google_vertexai.vectorstores.vectorstores import (
VectorSearchVectorStore,
VectorSearchVectorStoreDatastore,
)


@pytest.fixture
Expand Down Expand Up @@ -77,6 +80,22 @@ def vector_store() -> VectorSearchVectorStore:
return vector_store


@pytest.fixture
def datastore_vector_store() -> VectorSearchVectorStoreDatastore:
    """Datastore-backed vector store wired to the streaming test index.

    Reads the GCP project/region and the Datastore stream index/endpoint ids
    from environment variables; embeddings use the default gecko model.
    """
    return VectorSearchVectorStoreDatastore.from_components(
        project_id=os.environ["PROJECT_ID"],
        region=os.environ["REGION"],
        index_id=os.environ["STREAM_INDEX_ID_DATASTORE"],
        endpoint_id=os.environ["STREAM_ENDPOINT_ID_DATASTORE"],
        embedding=VertexAIEmbeddings(model_name="textembedding-gecko-default"),
        stream_update=True,
    )


@pytest.mark.extended
def test_vector_search_sdk_manager(sdk_manager: VectorSearchSDKManager):
gcs_client = sdk_manager.get_gcs_client()
Expand Down Expand Up @@ -145,8 +164,11 @@ def test_public_endpoint_vector_searcher(sdk_manager: VectorSearchSDKManager):


@pytest.mark.extended
def test_vector_store(vector_store: VectorSearchVectorStore):
assert isinstance(vector_store, VectorSearchVectorStore)
@pytest.mark.parametrize(
"vector_store_class", ["vector_store", "datastore_vector_store"]
)
def test_vector_store(vector_store_class: str, request: pytest.FixtureRequest):
vector_store: VectorSearchVectorStore = request.getfixturevalue(vector_store_class)

query = "What are your favourite animals?"
docs_with_scores = vector_store.similarity_search_with_score(query, k=1)
Expand All @@ -162,7 +184,17 @@ def test_vector_store(vector_store: VectorSearchVectorStore):


@pytest.mark.extended
def test_vector_store_filtering(vector_store: VectorSearchVectorStore):
@pytest.mark.parametrize(
"vector_store_class",
[
"vector_store",
# "datastore_vector_store" Waiting for the bug to be fixed as its stream
],
)
def test_vector_store_filtering(
vector_store_class: str, request: pytest.FixtureRequest
):
vector_store: VectorSearchVectorStore = request.getfixturevalue(vector_store_class)
documents = vector_store.similarity_search(
"I want some pants",
filter=[Namespace(name="color", allow_tokens=["blue"])],
Expand All @@ -175,19 +207,20 @@ def test_vector_store_filtering(vector_store: VectorSearchVectorStore):


@pytest.mark.extended
def test_vector_store_update_index(sample_documents: List[Document]):
embeddings = VertexAIEmbeddings(model_name="textembedding-gecko-default")
def test_vector_store_update_index(
vector_store: VectorSearchVectorStore, sample_documents: List[Document]
):
vector_store.add_documents(documents=sample_documents, is_complete_overwrite=True)

vector_store = VectorSearchVectorStore.from_components(
project_id=os.environ["PROJECT_ID"],
region=os.environ["REGION"],
gcs_bucket_name=os.environ["GCS_BUCKET_NAME"],
index_id=os.environ["INDEX_ID"],
endpoint_id=os.environ["ENDPOINT_ID"],
embedding=embeddings,
)

vector_store.add_documents(documents=sample_documents, is_complete_overwrite=True)
@pytest.mark.extended
def test_vector_store_stream_update_index(
    datastore_vector_store: VectorSearchVectorStoreDatastore,
    sample_documents: List[Document],
) -> None:
    # Smoke test: a streaming upsert into the Datastore-backed store should
    # complete without raising. NOTE(review): no assertion on the stored
    # content here — presumably retrieval is exercised by the search tests;
    # confirm.
    datastore_vector_store.add_documents(
        documents=sample_documents, is_complete_overwrite=True
    )


@pytest.fixture
Expand Down

0 comments on commit 51bf9ef

Please sign in to comment.