From 42adc4d78711afcbafdbdf6a9dc002d095429d57 Mon Sep 17 00:00:00 2001 From: Jorge Date: Sat, 9 Mar 2024 02:32:36 +0100 Subject: [PATCH 01/14] Added support for metadata in document storage --- .../vectorstores/_document_storage.py | 106 +++++++++++++----- .../vectorstores/_sdk_manager.py | 4 +- .../vectorstores/vectorstores.py | 25 ++++- .../integration_tests/test_vectorstores.py | 92 +++++++-------- 4 files changed, 145 insertions(+), 82 deletions(-) diff --git a/libs/vertexai/langchain_google_vertexai/vectorstores/_document_storage.py b/libs/vertexai/langchain_google_vertexai/vectorstores/_document_storage.py index c15c035a..cbdc8bb2 100644 --- a/libs/vertexai/langchain_google_vertexai/vectorstores/_document_storage.py +++ b/libs/vertexai/langchain_google_vertexai/vectorstores/_document_storage.py @@ -1,49 +1,51 @@ from __future__ import annotations +import json from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional from google.cloud import storage # type: ignore[attr-defined] +from langchain_core.documents import Document if TYPE_CHECKING: - from google.cloud import datastore # type: ignore[attr-defined] + from google.cloud import datastore # type: ignore[attr-defined, unused-ignore] class DocumentStorage(ABC): """Abstract interface of a key, text storage for retrieving documents.""" @abstractmethod - def get_by_id(self, document_id: str) -> str | None: - """Gets the text of a document by its id. If not found, returns None. + def get_by_id(self, document_id: str) -> Document | None: + """Gets a document by its id. If not found, returns None. Args: document_id: Id of the document to get from the storage. Returns: - Text of the document if found, otherwise None. + Document if found, otherwise None. """ raise NotImplementedError() @abstractmethod - def store_by_id(self, document_id: str, text: str): - """Stores a document text associated to a document_id. + def store_by_id(self, document_id: str, document: Document): + """Stores a document associated to a document_id. Args: document_id: Id of the document to be stored. - text: Text of the document to be stored. + document: Document to be stored. """ raise NotImplementedError() - def batch_store_by_id(self, ids: List[str], texts: List[str]) -> None: + def batch_store_by_id(self, ids: List[str], documents: List[Document]) -> None: """Stores a list of ids and documents in batch. The default implementation only loops to the individual `store_by_id`. Subclasses that have faster ways to store data via batch uploading should implement the proper way. Args: ids: List of ids for the text. - texts: List of texts. + documents: List of documents. """ - for id_, text in zip(ids, texts): - self.store_by_id(id_, text) + for id_, document in zip(ids, documents): + self.store_by_id(id_, document) - def batch_get_by_id(self, ids: List[str]) -> List[str | None]: + def batch_get_by_id(self, ids: List[str]) -> List[Document | None]: """Gets a batch of documents by id. The default implementation only loops `get_by_id`. Subclasses that have faster ways to retrieve data by batch should implement @@ -51,8 +53,8 @@ def batch_get_by_id(self, ids: List[str]) -> List[str | None]: Args: ids: List of ids for the text. Returns: - List of texts. If the key id is not found for any id record returns a None - instead. + List of documents. If the key id is not found for any id record returns a + None instead. 
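The two abstract methods above are the whole contract a backend has to provide; the default batch_store_by_id/batch_get_by_id then come for free via the loops that follow. As a rough sketch of that contract (the in-memory class below is hypothetical and only for illustration, not part of this change):

```python
from typing import Dict, Optional

from langchain_core.documents import Document

from langchain_google_vertexai.vectorstores._document_storage import DocumentStorage


class InMemoryDocumentStorage(DocumentStorage):
    """Hypothetical dict-backed storage used only to illustrate the contract."""

    def __init__(self) -> None:
        self._store: Dict[str, Document] = {}

    def get_by_id(self, document_id: str) -> Optional[Document]:
        # Missing ids yield None rather than raising, as the interface requires.
        return self._store.get(document_id)

    def store_by_id(self, document_id: str, document: Document) -> None:
        self._store[document_id] = document
```

With just these two methods, the inherited batch operations work unchanged, since they simply loop over the per-id methods.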
""" return [self.get_by_id(id_) for id_ in ids] @@ -75,12 +77,12 @@ def __init__( self._bucket = bucket self._prefix = prefix - def get_by_id(self, document_id: str) -> str | None: + def get_by_id(self, document_id: str) -> Document | None: """Gets the text of a document by its id. If not found, returns None. Args: document_id: Id of the document to get from the storage. Returns: - Text of the document if found, otherwise None. + Document if found, otherwise None. """ blob_name = self._get_blob_name(document_id) @@ -89,17 +91,22 @@ def get_by_id(self, document_id: str) -> str | None: if existing_blob is None: return None - return existing_blob.download_as_text() + document_str = existing_blob.download_as_text() + document_json: Dict[str, Any] = json.loads(document_str) + return Document(**document_json) - def store_by_id(self, document_id: str, text: str) -> None: + def store_by_id(self, document_id: str, document: Document) -> None: """Stores a document text associated to a document_id. Args: document_id: Id of the document to be stored. - text: Text of the document to be stored. + dcoument: Document to be stored. """ blob_name = self._get_blob_name(document_id) new_blow = self._bucket.blob(blob_name) - new_blow.upload_from_string(text) + + document_json = document.dict() + document_text = json.dumps(document_json) + new_blow.upload_from_string(document_text) def _get_blob_name(self, document_id: str) -> str: """Builds a blob name using the prefix and the document_id. @@ -119,6 +126,7 @@ def __init__( datastore_client: datastore.Client, kind: str = "document_id", text_property_name: str = "text", + metadata_property_name: str = "metadata", ) -> None: """Constructor. Args: @@ -128,9 +136,10 @@ def __init__( super().__init__() self._client = datastore_client self._text_property_name = text_property_name + self._metadata_property_name = metadata_property_name self._kind = kind - def get_by_id(self, document_id: str) -> str | None: + def get_by_id(self, document_id: str) -> Document | None: """Gets the text of a document by its id. If not found, returns None. Args: document_id: Id of the document to get from the storage. @@ -139,9 +148,16 @@ def get_by_id(self, document_id: str) -> str | None: """ key = self._client.key(self._kind, document_id) entity = self._client.get(key) - return entity[self._text_property_name] - def store_by_id(self, document_id: str, text: str) -> None: + if entity is None: + return None + + return Document( + page_content=entity[self._text_property_name], + metadata=self._convert_entity_to_dict(entity[self._metadata_property_name]), + ) + + def store_by_id(self, document_id: str, document: Document) -> None: """Stores a document text associated to a document_id. Args: document_id: Id of the document to be stored. @@ -149,11 +165,14 @@ def store_by_id(self, document_id: str, text: str) -> None: """ with self._client.transaction(): key = self._client.key(self._kind, document_id) + entity = self._client.entity(key=key) - entity[self._text_property_name] = text + entity[self._text_property_name] = document.page_content + entity[self._metadata_property_name] = document.metadata + self._client.put(entity) - def batch_get_by_id(self, ids: List[str]) -> List[str | None]: + def batch_get_by_id(self, ids: List[str]) -> List[Document | None]: """Gets a batch of documents by id. Args: ids: List of ids for the text. 
@@ -166,9 +185,22 @@ def batch_get_by_id(self, ids: List[str]) -> List[str | None]: # TODO: Handle when a key is not present entities = self._client.get_multi(keys) - return [entity[self._text_property_name] for entity in entities] - - def batch_store_by_id(self, ids: List[str], texts: List[str]) -> None: + # Entities are not sorted by key by default, the order is unclear. This orders + # the list by the id retrieved. + entity_id_lookup = {entity.key.id_or_name: entity for entity in entities} + entities = [entity_id_lookup[id_] for id_ in ids] + + return [ + Document( + page_content=entity[self._text_property_name], + metadata=self._convert_entity_to_dict( + entity[self._metadata_property_name] + ), + ) + for entity in entities + ] + + def batch_store_by_id(self, ids: List[str], documents: List[Document]) -> None: """Stores a list of ids and documents in batch. Args: ids: List of ids for the text. @@ -179,9 +211,21 @@ def batch_store_by_id(self, ids: List[str], texts: List[str]) -> None: keys = [self._client.key(self._kind, id_) for id_ in ids] entities = [] - for key, text in zip(keys, texts): + for key, document in zip(keys, documents): entity = self._client.entity(key=key) - entity[self._text_property_name] = text + entity[self._text_property_name] = document.page_content + entity[self._metadata_property_name] = document.metadata entities.append(entity) self._client.put_multi(entities) + + def _convert_entity_to_dict(self, entity: datastore.Entity) -> Dict[str, Any]: + """Recursively transform an entity into a plain dictionary.""" + from google.cloud import datastore # type: ignore[attr-defined, unused-ignore] + + dict_entity = dict(entity) + for key in dict_entity: + value = dict_entity[key] + if isinstance(value, datastore.Entity): + dict_entity[key] = self._convert_entity_to_dict(value) + return dict_entity diff --git a/libs/vertexai/langchain_google_vertexai/vectorstores/_sdk_manager.py b/libs/vertexai/langchain_google_vertexai/vectorstores/_sdk_manager.py index 38ae4f76..4deccea0 100644 --- a/libs/vertexai/langchain_google_vertexai/vectorstores/_sdk_manager.py +++ b/libs/vertexai/langchain_google_vertexai/vectorstores/_sdk_manager.py @@ -8,7 +8,7 @@ from google.oauth2.service_account import Credentials # type: ignore if TYPE_CHECKING: - from google.cloud import datastore # type: ignore[attr-defined] + from google.cloud import datastore # type: ignore[attr-defined, unused-ignore] class VectorSearchSDKManager: @@ -107,7 +107,7 @@ def get_datastore_client(self, **kwargs: Any) -> "datastore.Client": Returns: datastore Client. """ - from google.cloud import datastore # type: ignore[attr-defined] + from google.cloud import datastore # type: ignore[attr-defined, unused-ignore] ds_client = datastore.Client( project=self._project_id, credentials=self._credentials, **kwargs diff --git a/libs/vertexai/langchain_google_vertexai/vectorstores/vectorstores.py b/libs/vertexai/langchain_google_vertexai/vectorstores/vectorstores.py index fa8f6379..3288ad9b 100644 --- a/libs/vertexai/langchain_google_vertexai/vectorstores/vectorstores.py +++ b/libs/vertexai/langchain_google_vertexai/vectorstores/vectorstores.py @@ -106,14 +106,13 @@ def similarity_search_by_vector_with_score( results = [] for neighbor_id, distance in neighbors_list[0]: - text = self._document_storage.get_by_id(neighbor_id) + document = self._document_storage.get_by_id(neighbor_id) - if text is None: + if document is None: raise ValueError( f"Document with id {neighbor_id} not found in document" "storage." 
) - # TODO: Handle metadata - document = Document(page_content=text, metadata={}) + results.append((document, distance)) return results @@ -164,8 +163,24 @@ def add_texts( texts = list(texts) ids = self._generate_unique_ids(len(texts)) - self._document_storage.batch_store_by_id(ids=ids, texts=texts) + if metadatas is None: + metadatas = [{}] * len(texts) + + if len(metadatas) != len(texts): + raise ValueError( + "`metadatas` should be the same length as `texts` " + f"{len(metadatas)} != {len(texts)}" + ) + + documents = [ + Document(page_content=text, metadata=metadata) + for text, metadata in zip(texts, metadatas) + ] + + self._document_storage.batch_store_by_id(ids=ids, documents=documents) + embeddings = self._embeddings.embed_documents(texts) + self._searcher.add_to_index(ids, embeddings, metadatas, **kwargs) return ids diff --git a/libs/vertexai/tests/integration_tests/test_vectorstores.py b/libs/vertexai/tests/integration_tests/test_vectorstores.py index e8fd5b39..a4da55fe 100644 --- a/libs/vertexai/tests/integration_tests/test_vectorstores.py +++ b/libs/vertexai/tests/integration_tests/test_vectorstores.py @@ -12,6 +12,7 @@ """ import os +from uuid import uuid4 import pytest from google.cloud import storage # type: ignore[attr-defined] @@ -24,6 +25,7 @@ from langchain_google_vertexai.embeddings import VertexAIEmbeddings from langchain_google_vertexai.vectorstores._document_storage import ( DataStoreDocumentStorage, + DocumentStorage, GCSDocumentStorage, ) from langchain_google_vertexai.vectorstores._sdk_manager import VectorSearchSDKManager @@ -41,6 +43,20 @@ def sdk_manager() -> VectorSearchSDKManager: return sdk_manager +@pytest.fixture +def gcs_document_storage(sdk_manager: VectorSearchSDKManager) -> GCSDocumentStorage: + bucket = sdk_manager.get_gcs_bucket(bucket_name=os.environ["GCS_BUCKET_NAME"]) + return GCSDocumentStorage(bucket=bucket, prefix="integration_tests") + + +@pytest.fixture +def datastore_document_storage( + sdk_manager: VectorSearchSDKManager, +) -> DataStoreDocumentStorage: + ds_client = sdk_manager.get_datastore_client(namespace="integration_tests") + return DataStoreDocumentStorage(datastore_client=ds_client) + + @pytest.mark.extended def test_vector_search_sdk_manager(sdk_manager: VectorSearchSDKManager): gcs_client = sdk_manager.get_gcs_client() @@ -57,50 +73,38 @@ def test_vector_search_sdk_manager(sdk_manager: VectorSearchSDKManager): @pytest.mark.extended -def test_gcs_document_storage(sdk_manager: VectorSearchSDKManager): - bucket = sdk_manager.get_gcs_bucket(os.environ["GCS_BUCKET_NAME"]) - prefix = "integration-test" - - storage = GCSDocumentStorage(bucket=bucket, prefix=prefix) - - id_ = "test-id" - text = "Test text" - - storage.store_by_id(id_, text) - - assert storage.get_by_id(id_) == text - - ids = [f"test-id_{i}" for i in range(5)] - texts = [f"Test Text {i}" for i in range(5)] - - storage.batch_store_by_id(ids, texts) - retrieved_texts = storage.batch_get_by_id(ids) - - for original_text, retrieved_text in zip(retrieved_texts, texts): - assert original_text == retrieved_text - - -@pytest.mark.extended -def test_datastore_document_storage(sdk_manager: VectorSearchSDKManager): - ds_client = sdk_manager.get_datastore_client(namespace="Foo") - - storage = DataStoreDocumentStorage(datastore_client=ds_client) - - id_ = "test-id" - text = "Test text" - - storage.store_by_id(id_, text) - - assert storage.get_by_id(id_) == text - - ids = [f"test-id_{i}" for i in range(5)] - texts = [f"Test Text {i}" for i in range(5)] - - storage.batch_store_by_id(ids, 
texts)
-    retrieved_texts = storage.batch_get_by_id(ids)
-
-    for original_text, retrieved_text in zip(retrieved_texts, texts):
-        assert original_text == retrieved_text
+@pytest.mark.parametrize(
+    "storage_class", ["gcs_document_storage", "datastore_document_storage"]
+)
+def test_document_storage(
+    sdk_manager: VectorSearchSDKManager,
+    storage_class: str,
+    request: pytest.FixtureRequest,
+):
+    document_storage: DocumentStorage = request.getfixturevalue(storage_class)
+
+    N = 10
+    documents = [
+        Document(
+            page_content=f"Text content of document {i}",
+            metadata={"index": i, "nested": {"a": i, "b": str(uuid4())}},
+        )
+        for i in range(N)
+    ]
+    ids = [str(uuid4()) for i in range(N)]
+
+    # Test individual retrieval
+    for id, document in zip(ids, documents):
+        document_storage.store_by_id(document_id=id, document=document)
+        retrieved = document_storage.get_by_id(document_id=id)
+        assert document == retrieved
+
+    # Test batch retrieval
+    document_storage.batch_store_by_id(ids, documents)
+    retrieved_documents = document_storage.batch_get_by_id(ids)
+
+    for og_document, retrieved_document in zip(documents, retrieved_documents):
+        assert og_document == retrieved_document


 @pytest.mark.extended

From a74038bfc966dfc7a0db84441950837b1ea73617 Mon Sep 17 00:00:00 2001
From: Jorge
Date: Sat, 9 Mar 2024 02:35:17 +0100
Subject: [PATCH 02/14] Fix spelling

---
 .../langchain_google_vertexai/vectorstores/_document_storage.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/libs/vertexai/langchain_google_vertexai/vectorstores/_document_storage.py b/libs/vertexai/langchain_google_vertexai/vectorstores/_document_storage.py
index cbdc8bb2..8fee51d5 100644
--- a/libs/vertexai/langchain_google_vertexai/vectorstores/_document_storage.py
+++ b/libs/vertexai/langchain_google_vertexai/vectorstores/_document_storage.py
@@ -99,7 +99,7 @@ def store_by_id(self, document_id: str, document: Document) -> None:
         """Stores a document text associated to a document_id.
         Args:
             document_id: Id of the document to be stored.
-            dcoument: Document to be stored.
+            document: Document to be stored.
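For context on the GCS backend this docstring belongs to: each Document round-trips through a blob as JSON. A minimal sketch of that round-trip, assuming (as the implementation does) that Document.dict() yields a JSON-serializable mapping:

```python
import json

from langchain_core.documents import Document

doc = Document(page_content="hello", metadata={"lang": "en"})

uploaded = json.dumps(doc.dict())  # what store_by_id uploads to the blob
restored = Document(**json.loads(uploaded))  # what get_by_id reconstructs

assert restored == doc
```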
""" blob_name = self._get_blob_name(document_id) new_blow = self._bucket.blob(blob_name) From a54c6b84d15915c2d368153492201da54c9ea3c6 Mon Sep 17 00:00:00 2001 From: Jorge Date: Tue, 12 Mar 2024 16:33:49 +0100 Subject: [PATCH 03/14] Add filtering support + stream updates --- .../vectorstores/_document_storage.py | 2 +- .../vectorstores/_searcher.py | 128 +++++++------- .../vectorstores/_utils.py | 157 ++++++++++++++++++ .../vectorstores/vectorstores.py | 9 +- .../integration_tests/test_image_utils.py | 2 +- .../integration_tests/test_vectorstores.py | 6 +- 6 files changed, 224 insertions(+), 80 deletions(-) create mode 100644 libs/vertexai/langchain_google_vertexai/vectorstores/_utils.py diff --git a/libs/vertexai/langchain_google_vertexai/vectorstores/_document_storage.py b/libs/vertexai/langchain_google_vertexai/vectorstores/_document_storage.py index 8fee51d5..eb6c4611 100644 --- a/libs/vertexai/langchain_google_vertexai/vectorstores/_document_storage.py +++ b/libs/vertexai/langchain_google_vertexai/vectorstores/_document_storage.py @@ -4,7 +4,7 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Dict, List, Optional -from google.cloud import storage # type: ignore[attr-defined] +from google.cloud import storage # type: ignore[attr-defined, unused-ignore] from langchain_core.documents import Document if TYPE_CHECKING: diff --git a/libs/vertexai/langchain_google_vertexai/vectorstores/_searcher.py b/libs/vertexai/langchain_google_vertexai/vectorstores/_searcher.py index be4ce5dc..df07a9db 100644 --- a/libs/vertexai/langchain_google_vertexai/vectorstores/_searcher.py +++ b/libs/vertexai/langchain_google_vertexai/vectorstores/_searcher.py @@ -1,10 +1,7 @@ -import json -import time -import uuid from abc import ABC, abstractmethod -from typing import Any, Dict, List, Tuple, Union +from typing import Any, List, Tuple, Union -from google.cloud import storage # type: ignore[attr-defined] +from google.cloud import storage # type: ignore[attr-defined, unused-ignore] from google.cloud.aiplatform.matching_engine import ( MatchingEngineIndex, MatchingEngineIndexEndpoint, @@ -14,6 +11,12 @@ Namespace, ) +from langchain_google_vertexai.vectorstores._utils import ( + batch_update_index, + stream_update_index, + to_data_points, +) + class Searcher(ABC): """Abstract implementation of a similarity searcher.""" @@ -41,9 +44,16 @@ def add_to_index( ids: List[str], embeddings: List[List[float]], metadatas: Union[List[dict], None] = None, + is_complete_overwrite: bool = False, **kwargs: Any, - ): - """ """ + ) -> None: + """Adds documents to the index. + + Args: + ids: List of unique ids. + embeddings: List of embedddings for each record. + metadatas: List of metadata of each record. + """ raise NotImplementedError() def _postprocess_response( @@ -63,13 +73,14 @@ def _postprocess_response( class VectorSearchSearcher(Searcher): - """ """ + """Class to interface with a VectorSearch index and endpoint.""" def __init__( self, endpoint: MatchingEngineIndexEndpoint, index: MatchingEngineIndex, staging_bucket: Union[storage.Bucket, None] = None, + stream_update: bool = False, ) -> None: """Constructor. 
Args: @@ -85,57 +96,41 @@ def __init__( self._index = index self._deployed_index_id = self._get_deployed_index_id() self._staging_bucket = staging_bucket + self._stream_update = stream_update def add_to_index( self, ids: List[str], embeddings: List[List[float]], metadatas: Union[List[dict], None] = None, + is_complete_overwrite: bool = False, **kwargs: Any, ) -> None: - """ """ - - if self._staging_bucket is None: - raise ValueError( - "In order to update a Vector Search index a staging bucket must" - " be defined." - ) + """Adds documents to the index. - record_list = [] - for i, (idx, embedding) in enumerate(zip(ids, embeddings)): - record: Dict[str, Any] = {"id": idx, "embedding": embedding} - if metadatas is not None: - record["metadata"] = metadatas[i] - record_list.append(record) - file_content = "\n".join([json.dumps(x) for x in record_list]) - - filename_prefix = f"indexes/{uuid.uuid4()}" - filename = f"{filename_prefix}/{time.time()}.json" - blob = self._staging_bucket.blob(filename) - blob.upload_from_string(data=file_content) - - self.index = self._index.update_embeddings( - contents_delta_uri=f"gs://{self._staging_bucket.name}/{filename_prefix}/" - ) - - def _get_deployed_index_id(self) -> str: - """Gets the deployed index id that matches with the provided index. - Raises: - ValueError if the index provided is not found in the endpoint. + Args: + ids: List of unique ids. + embeddings: List of embedddings for each record. + metadatas: List of metadata of each record. + is_complete_overwrite: Whether to overwrite everything. """ - for index in self._endpoint.deployed_indexes: - if index.index == self._index.resource_name: - return index.id - - raise ValueError( - f"No index with id {self._index.resource_name} " - f"deployed on endpoint " - f"{self._endpoint.display_name}." - ) - -class PublicEndpointVectorSearchSearcher(VectorSearchSearcher): - """ """ + data_points = to_data_points(ids, embeddings, metadatas) + + if self._stream_update: + stream_update_index(index=self._index, data_points=data_points) + else: + if self._staging_bucket is None: + raise ValueError( + "In order to update a Vector Search index a staging bucket must" + " be defined." + ) + batch_update_index( + index=self._index, + data_points=data_points, + staging_bucket=self._staging_bucket, + is_complete_overwrite=is_complete_overwrite, + ) def find_neighbors( self, @@ -152,6 +147,8 @@ def find_neighbors( List of lists of Tuples (id, distance) for each embedding vector. """ + # No need to implement other method for private VPC, find_neighbors now works + # with public and private. response = self._endpoint.find_neighbors( deployed_index_id=self._deployed_index_id, queries=embeddings, @@ -161,30 +158,17 @@ def find_neighbors( return self._postprocess_response(response) - -class VPCVertexVectorStore(VectorSearchSearcher): - """ """ - - def find_neighbors( - self, - embeddings: List[List[float]], - k: int = 4, - filter_: Union[List[Namespace], None] = None, - ) -> List[List[Tuple[str, float]]]: - """Finds the k closes neighbors of each instance of embeddings. - Args: - embedding: List of embeddings vectors. - k: Number of neighbors to be retrieved. - filter_: List of filters to apply. - Returns: - List of lists of Tuples (id, distance) for each embedding vector. + def _get_deployed_index_id(self) -> str: + """Gets the deployed index id that matches with the provided index. + Raises: + ValueError if the index provided is not found in the endpoint. 
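The stream_update flag added here selects between the two update paths inside add_to_index. A condensed sketch of that branch, assuming the target index was created with the matching update method (Vector Search indexes are stream-update or batch-update at creation time):

```python
from typing import Any, List

from google.cloud.aiplatform import MatchingEngineIndex
from google.cloud.storage import Bucket

from langchain_google_vertexai.vectorstores._utils import batch_update_index


def update_index(
    index: MatchingEngineIndex,
    data_points: List[Any],  # IndexDatapoint messages, e.g. from to_data_points
    staging_bucket: Bucket,
    stream_update: bool,
) -> None:
    if stream_update:
        # Streamed datapoints become queryable within a few seconds.
        index.upsert_datapoints(datapoints=data_points)
    else:
        # Batch mode stages JSON records in GCS and rebuilds from there.
        batch_update_index(
            index=index, data_points=data_points, staging_bucket=staging_bucket
        )
```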
""" + for index in self._endpoint.deployed_indexes: + if index.index == self._index.resource_name: + return index.id - response = self._endpoint.match( - deployed_index_id=self._deployed_index_id, - queries=embeddings, - num_neighbors=k, - filter=filter_, + raise ValueError( + f"No index with id {self._index.resource_name} " + f"deployed on endpoint " + f"{self._endpoint.display_name}." ) - - return self._postprocess_response(response) diff --git a/libs/vertexai/langchain_google_vertexai/vectorstores/_utils.py b/libs/vertexai/langchain_google_vertexai/vectorstores/_utils.py new file mode 100644 index 00000000..f79c9e36 --- /dev/null +++ b/libs/vertexai/langchain_google_vertexai/vectorstores/_utils.py @@ -0,0 +1,157 @@ +import json +import uuid +from typing import Any, Dict, List, Union + +from google.cloud.aiplatform import MatchingEngineIndex +from google.cloud.aiplatform.compat.types import ( # type: ignore[attr-defined, unused-ignore] + matching_engine_index as meidx_types, +) +from google.cloud.storage import Bucket # type: ignore[import-untyped, unused-ignore] + + +def stream_update_index( + index: MatchingEngineIndex, data_points: List["meidx_types.IndexDataPoint"] +) -> None: + """Updates an index using stream updating. + + Args: + index: Vector search index. + data_points: List of IndexDataPoint. + """ + index.upsert_datapoints(data_points) + + +def batch_update_index( + index: MatchingEngineIndex, + data_points: List["meidx_types.IndexDataPoint"], + *, + staging_bucket: Bucket, + prefix: Union[str, None] = None, + file_name: str = "documents.json", + is_complete_overwrite: bool = False, +) -> None: + """Updates an index using batch updating. + + Args: + index: Vector search index. + data_points: List of IndexDataPoint. + staging_bucket: Bucket where the staging data is stored. Must be in the same + region as the index. + prefix: Prefix for the blob name. If not provided an unique iid will be + generated. + file_name: File name of the staging embeddings. By default 'documents.json'. + is_complete_overwrite: Whether is an append or overwrite operation. + """ + + if prefix is None: + prefix = str(uuid.uuid4()) + + records = data_points_to_batch_update_records(data_points) + + file_content = "\n".join(json.dumps(record) for record in records) + + blob = staging_bucket.blob(f"{prefix}/{file_name}") + blob.upload_from_string(file_content) + + contents_delta_uri = f"gs://{staging_bucket.name}/{prefix}" + + index.update_embeddings( + contents_delta_uri=contents_delta_uri, + is_complete_overwrite=is_complete_overwrite, + ) + + +def to_data_points( + ids: List[str], + embeddings: List[List[float]], + metadatas: Union[List[Dict[str, Any]], None], +) -> List["meidx_types.IndexDataPoint"]: + """Converts triplets id, embedding, metadata into IndexDataPoints instances. + + Only metadata with values of type string, numeric or list of string will be + considered for the filtering. + + Args: + ids: List of unique ids. + embeddings: List of feature representatitons. + metadatas: List of metadatas. 
+ """ + + if metadatas is None: + metadatas = [{}] * len(ids) + + data_points = [] + + for id_, embedding, metadata in zip(ids, embeddings, metadatas): + restricts = [] + numeric_restricts = [] + + for namespace, value in metadata.items(): + if not isinstance(namespace, str): + raise ValueError("All metadata keys must be strings") + + if isinstance(value, str): + restriction = meidx_types.Restriction( + namespace=namespace, allow_list=[value] + ) + restricts.append(restriction) + elif isinstance(value, list) and all( + isinstance(item, str) for item in value + ): + restriction = meidx_types.Restriction( + namespace=namespace, allow_list=value + ) + restricts.append(restriction) + elif isinstance(value, (int, float)) and not isinstance(value, bool): + restriction = meidx_types.NumericRestriction( + namespace=namespace, value_float=value + ) + numeric_restricts.append(restriction) + + data_point = meidx_types.IndexDataPoint( + datapoint_id=id_, + feature_vector=embedding, + restricts=restricts, + numeric_restricts=numeric_restricts, + ) + + data_points.append(data_point) + + return data_points + + +def data_points_to_batch_update_records( + data_points: List["meidx_types.IndexDataPoint"], +) -> List[Dict[str, Any]]: + """Given a list of datapoints, generates a list of records in the input format + required to do a bactch update. + + Args: + data_points: List of IndexDataPoints. + + Returns: + List of records with the format needed to do a batch update. + """ + + records = [] + + for data_point in data_points: + record = { + "id": data_point.datapoint_id, + "embedding": [component for component in data_point.feature_vector], + "restricts": [ + { + "namespace": restrict.namespace, + "allow": [item for item in restrict.allow_list], + } + for restrict in data_point.restricts + ], + "numeric_restricts": [ + {"namespace": restrict.namespace, "value_float": restrict.value_float} + for restrict in data_point.numeric_restricts + ], + } + + records.append(record) + + return records diff --git a/libs/vertexai/langchain_google_vertexai/vectorstores/vectorstores.py b/libs/vertexai/langchain_google_vertexai/vectorstores/vectorstores.py index 3288ad9b..c9517ec3 100644 --- a/libs/vertexai/langchain_google_vertexai/vectorstores/vectorstores.py +++ b/libs/vertexai/langchain_google_vertexai/vectorstores/vectorstores.py @@ -15,8 +15,8 @@ ) from langchain_google_vertexai.vectorstores._sdk_manager import VectorSearchSDKManager from langchain_google_vertexai.vectorstores._searcher import ( - PublicEndpointVectorSearchSearcher, Searcher, + VectorSearchSearcher, ) @@ -147,6 +147,7 @@ def add_texts( self, texts: Iterable[str], metadatas: Union[List[dict], None] = None, + is_complete_overwrite: bool = False, **kwargs: Any, ) -> List[str]: """Run more texts through the embeddings and add to the vectorstore. 
@@ -181,7 +182,9 @@ def add_texts( embeddings = self._embeddings.embed_documents(texts) - self._searcher.add_to_index(ids, embeddings, metadatas, **kwargs) + self._searcher.add_to_index( + ids, embeddings, metadatas, is_complete_overwrite, **kwargs + ) return ids @@ -278,7 +281,7 @@ def from_components( # Implemented in order to keep the current API return cls( document_storage=GCSDocumentStorage(bucket=bucket), - searcher=PublicEndpointVectorSearchSearcher( + searcher=VectorSearchSearcher( endpoint=endpoint, index=index, staging_bucket=bucket ), embbedings=embedding, diff --git a/libs/vertexai/tests/integration_tests/test_image_utils.py b/libs/vertexai/tests/integration_tests/test_image_utils.py index 87f75a81..2a23eda0 100644 --- a/libs/vertexai/tests/integration_tests/test_image_utils.py +++ b/libs/vertexai/tests/integration_tests/test_image_utils.py @@ -1,5 +1,5 @@ import pytest -from google.cloud import storage # type: ignore[attr-defined] +from google.cloud import storage # type: ignore[attr-defined, unused-ignore] from google.cloud.exceptions import NotFound from langchain_google_vertexai._image_utils import ImageBytesLoader diff --git a/libs/vertexai/tests/integration_tests/test_vectorstores.py b/libs/vertexai/tests/integration_tests/test_vectorstores.py index a4da55fe..fc3df8b8 100644 --- a/libs/vertexai/tests/integration_tests/test_vectorstores.py +++ b/libs/vertexai/tests/integration_tests/test_vectorstores.py @@ -15,7 +15,7 @@ from uuid import uuid4 import pytest -from google.cloud import storage # type: ignore[attr-defined] +from google.cloud import storage # type: ignore[attr-defined, unused-ignore] from google.cloud.aiplatform.matching_engine import ( MatchingEngineIndex, MatchingEngineIndexEndpoint, @@ -30,7 +30,7 @@ ) from langchain_google_vertexai.vectorstores._sdk_manager import VectorSearchSDKManager from langchain_google_vertexai.vectorstores._searcher import ( - PublicEndpointVectorSearchSearcher, + VectorSearchSearcher, ) from langchain_google_vertexai.vectorstores.vectorstores import VectorSearchVectorStore @@ -113,7 +113,7 @@ def test_public_endpoint_vector_searcher(sdk_manager: VectorSearchSDKManager): endpoint = sdk_manager.get_endpoint(os.environ["ENDPOINT_ID"]) embeddings = VertexAIEmbeddings(model_name="textembedding-gecko-default") - searcher = PublicEndpointVectorSearchSearcher(endpoint=endpoint, index=index) + searcher = VectorSearchSearcher(endpoint=endpoint, index=index) texts = ["What's your favourite animal", "What's your favourite city"] From d5fe56949f61bd49985c471f8a8564003675a7b2 Mon Sep 17 00:00:00 2001 From: Jorge Date: Tue, 12 Mar 2024 16:38:06 +0100 Subject: [PATCH 04/14] Mypy complains --- libs/vertexai/langchain_google_vertexai/_utils.py | 2 +- .../langchain_google_vertexai/vectorstores/_sdk_manager.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/libs/vertexai/langchain_google_vertexai/_utils.py b/libs/vertexai/langchain_google_vertexai/_utils.py index e2487940..5ca1c35b 100644 --- a/libs/vertexai/langchain_google_vertexai/_utils.py +++ b/libs/vertexai/langchain_google_vertexai/_utils.py @@ -8,7 +8,7 @@ import google.api_core import proto # type: ignore[import-untyped] from google.api_core.gapic_v1.client_info import ClientInfo -from google.cloud import storage # type: ignore[attr-defined] +from google.cloud import storage # type: ignore[attr-defined, unused-ignore] from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, diff --git 
a/libs/vertexai/langchain_google_vertexai/vectorstores/_sdk_manager.py b/libs/vertexai/langchain_google_vertexai/vectorstores/_sdk_manager.py index 4deccea0..a12b3aa8 100644 --- a/libs/vertexai/langchain_google_vertexai/vectorstores/_sdk_manager.py +++ b/libs/vertexai/langchain_google_vertexai/vectorstores/_sdk_manager.py @@ -1,6 +1,9 @@ from typing import TYPE_CHECKING, Any, Union -from google.cloud import aiplatform, storage # type: ignore[attr-defined] +from google.cloud import ( # type: ignore[attr-defined, unused-ignore] + aiplatform, + storage, +) from google.cloud.aiplatform.matching_engine import ( MatchingEngineIndex, MatchingEngineIndexEndpoint, From 727a9ca678739502edc5511f6592bffb34b30309 Mon Sep 17 00:00:00 2001 From: Jorge Date: Tue, 12 Mar 2024 17:21:36 +0100 Subject: [PATCH 05/14] Add numeric filtering --- .../vectorstores/_searcher.py | 4 ++++ .../vectorstores/vectorstores.py | 23 ++++++++++++++++--- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/libs/vertexai/langchain_google_vertexai/vectorstores/_searcher.py b/libs/vertexai/langchain_google_vertexai/vectorstores/_searcher.py index df07a9db..3cb9b20b 100644 --- a/libs/vertexai/langchain_google_vertexai/vectorstores/_searcher.py +++ b/libs/vertexai/langchain_google_vertexai/vectorstores/_searcher.py @@ -9,6 +9,7 @@ from google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint import ( MatchNeighbor, Namespace, + NumericNamespace, ) from langchain_google_vertexai.vectorstores._utils import ( @@ -27,6 +28,7 @@ def find_neighbors( embeddings: List[List[float]], k: int = 4, filter_: Union[List[Namespace], None] = None, + numeric_filter: Union[List[NumericNamespace], None] = None ) -> List[List[Tuple[str, float]]]: """Finds the k closes neighbors of each instance of embeddings. Args: @@ -137,6 +139,7 @@ def find_neighbors( embeddings: List[List[float]], k: int = 4, filter_: Union[List[Namespace], None] = None, + numeric_filter: Union[List[Namespace], None] = None ) -> List[List[Tuple[str, float]]]: """Finds the k closes neighbors of each instance of embeddings. Args: @@ -154,6 +157,7 @@ def find_neighbors( queries=embeddings, num_neighbors=k, filter=filter_, + numeric_filter=numeric_filter ) return self._postprocess_response(response) diff --git a/libs/vertexai/langchain_google_vertexai/vectorstores/vectorstores.py b/libs/vertexai/langchain_google_vertexai/vectorstores/vectorstores.py index c9517ec3..88128b77 100644 --- a/libs/vertexai/langchain_google_vertexai/vectorstores/vectorstores.py +++ b/libs/vertexai/langchain_google_vertexai/vectorstores/vectorstores.py @@ -4,6 +4,7 @@ from google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint import ( Namespace, + NumericNamespace, ) from langchain_core.documents import Document from langchain_core.embeddings import Embeddings @@ -50,6 +51,7 @@ def similarity_search_with_score( query: str, k: int = 4, filter: Optional[List[Namespace]] = None, + numeric_filter: Optional[List[NumericNamespace]] = None ) -> List[Tuple[Document, float]]: """Return docs most similar to query and their cosine distance from the query. Args: @@ -63,6 +65,10 @@ def similarity_search_with_score( datapoints with "squared shape". Please refer to https://cloud.google.com/vertex-ai/docs/matching-engine/filtering#json for more detail. + numeric_filter: Optional. A list of NumericNamespaces for filterning + the matching results. Please refer to + https://cloud.google.com/vertex-ai/docs/matching-engine/filtering#json + for more detail. 
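Both filter kinds can be combined in one call, as the integration test added later in this series does. For example, assuming a vector_store already built via VectorSearchVectorStore.from_components(...):

```python
from google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint import (
    Namespace,
    NumericNamespace,
)

# `vector_store` is assumed to exist already (see from_components).
docs_and_scores = vector_store.similarity_search_with_score(
    "I want some pants",
    k=4,
    filter=[Namespace(name="color", allow_tokens=["blue"])],
    numeric_filter=[NumericNamespace(name="price", value_float=20.0, op="LESS")],
)
```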
Returns: List[Tuple[Document, float]]: List of documents most similar to the query text and cosine distance in float for each. @@ -72,7 +78,7 @@ def similarity_search_with_score( embbedings = self._embeddings.embed_query(query) return self.similarity_search_by_vector_with_score( - embedding=embbedings, k=k, filter=filter + embedding=embbedings, k=k, filter=filter, numeric_filter=numeric_filter ) def similarity_search_by_vector_with_score( @@ -80,6 +86,7 @@ def similarity_search_by_vector_with_score( embedding: List[float], k: int = 4, filter: Optional[List[Namespace]] = None, + numeric_filter: Optional[List[NumericNamespace]] = None ) -> List[Tuple[Document, float]]: """Return docs most similar to the embedding and their cosine distance. Args: @@ -93,6 +100,10 @@ def similarity_search_by_vector_with_score( datapoints with "squared shape". Please refer to https://cloud.google.com/vertex-ai/docs/matching-engine/filtering#json for more detail. + numeric_filter: Optional. A list of NumericNamespaces for filterning + the matching results. Please refer to + https://cloud.google.com/vertex-ai/docs/matching-engine/filtering#json + for more detail. Returns: List[Tuple[Document, float]]: List of documents most similar to the query text and cosine distance in float for each. @@ -100,7 +111,7 @@ def similarity_search_by_vector_with_score( """ neighbors_list = self._searcher.find_neighbors( - embeddings=[embedding], k=k, filter_=filter + embeddings=[embedding], k=k, filter_=filter, numeric_filter=numeric_filter ) results = [] @@ -122,6 +133,7 @@ def similarity_search( query: str, k: int = 4, filter: Optional[List[Namespace]] = None, + numeric_filter: Optional[List[NumericNamespace]] = None, **kwargs: Any, ) -> List[Document]: """Return docs most similar to query. @@ -135,12 +147,17 @@ def similarity_search( datapoints with "squared shape". Please refer to https://cloud.google.com/vertex-ai/docs/matching-engine/filtering#json for more detail. + numeric_filter: Optional. A list of NumericNamespaces for filterning + the matching results. Please refer to + https://cloud.google.com/vertex-ai/docs/matching-engine/filtering#json + for more detail. Returns: A list of k matching documents. """ return [ document - for document, _ in self.similarity_search_with_score(query, k, filter) + for document, _ in self.similarity_search_with_score( + query, k, filter, numeric_filter) ] def add_texts( From 7a159fdf7f3babfaced494c8bafe642796a9682b Mon Sep 17 00:00:00 2001 From: Jorge Date: Tue, 12 Mar 2024 17:22:54 +0100 Subject: [PATCH 06/14] Fix format and linting --- .../langchain_google_vertexai/vectorstores/_searcher.py | 6 +++--- .../langchain_google_vertexai/vectorstores/vectorstores.py | 7 ++++--- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/libs/vertexai/langchain_google_vertexai/vectorstores/_searcher.py b/libs/vertexai/langchain_google_vertexai/vectorstores/_searcher.py index 3cb9b20b..633b5d42 100644 --- a/libs/vertexai/langchain_google_vertexai/vectorstores/_searcher.py +++ b/libs/vertexai/langchain_google_vertexai/vectorstores/_searcher.py @@ -28,7 +28,7 @@ def find_neighbors( embeddings: List[List[float]], k: int = 4, filter_: Union[List[Namespace], None] = None, - numeric_filter: Union[List[NumericNamespace], None] = None + numeric_filter: Union[List[NumericNamespace], None] = None, ) -> List[List[Tuple[str, float]]]: """Finds the k closes neighbors of each instance of embeddings. 
Args: @@ -139,7 +139,7 @@ def find_neighbors( embeddings: List[List[float]], k: int = 4, filter_: Union[List[Namespace], None] = None, - numeric_filter: Union[List[Namespace], None] = None + numeric_filter: Union[List[NumericNamespace], None] = None, ) -> List[List[Tuple[str, float]]]: """Finds the k closes neighbors of each instance of embeddings. Args: @@ -157,7 +157,7 @@ def find_neighbors( queries=embeddings, num_neighbors=k, filter=filter_, - numeric_filter=numeric_filter + numeric_filter=numeric_filter, ) return self._postprocess_response(response) diff --git a/libs/vertexai/langchain_google_vertexai/vectorstores/vectorstores.py b/libs/vertexai/langchain_google_vertexai/vectorstores/vectorstores.py index 88128b77..77fa08f7 100644 --- a/libs/vertexai/langchain_google_vertexai/vectorstores/vectorstores.py +++ b/libs/vertexai/langchain_google_vertexai/vectorstores/vectorstores.py @@ -51,7 +51,7 @@ def similarity_search_with_score( query: str, k: int = 4, filter: Optional[List[Namespace]] = None, - numeric_filter: Optional[List[NumericNamespace]] = None + numeric_filter: Optional[List[NumericNamespace]] = None, ) -> List[Tuple[Document, float]]: """Return docs most similar to query and their cosine distance from the query. Args: @@ -86,7 +86,7 @@ def similarity_search_by_vector_with_score( embedding: List[float], k: int = 4, filter: Optional[List[Namespace]] = None, - numeric_filter: Optional[List[NumericNamespace]] = None + numeric_filter: Optional[List[NumericNamespace]] = None, ) -> List[Tuple[Document, float]]: """Return docs most similar to the embedding and their cosine distance. Args: @@ -157,7 +157,8 @@ def similarity_search( return [ document for document, _ in self.similarity_search_with_score( - query, k, filter, numeric_filter) + query, k, filter, numeric_filter + ) ] def add_texts( From 6bd5302d8c1d08eb1fe25f26bc7ac7290c3c59c5 Mon Sep 17 00:00:00 2001 From: Jorge Date: Tue, 12 Mar 2024 17:43:36 +0100 Subject: [PATCH 07/14] Added tests --- .../integration_tests/test_vectorstores.py | 233 ++++++++++++++++-- 1 file changed, 212 insertions(+), 21 deletions(-) diff --git a/libs/vertexai/tests/integration_tests/test_vectorstores.py b/libs/vertexai/tests/integration_tests/test_vectorstores.py index fc3df8b8..f7ff2e31 100644 --- a/libs/vertexai/tests/integration_tests/test_vectorstores.py +++ b/libs/vertexai/tests/integration_tests/test_vectorstores.py @@ -8,10 +8,10 @@ - GCS_BUCKET_NAME: Name of a Google Cloud Storage Bucket - INDEX_ID: Id of the Vector Search index. - ENDPOINT_ID: Id of the Vector Search endpoint. 
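A small guard like the following fails fast when any of the variables listed above is missing (the names come from this docstring; the snippet itself is only a suggestion):

```python
import os

required = ["PROJECT_ID", "REGION", "GCS_BUCKET_NAME", "INDEX_ID", "ENDPOINT_ID"]
missing = [name for name in required if not os.environ.get(name)]

if missing:
    raise RuntimeError(f"Set {', '.join(missing)} before running these tests")
```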
-If required to run slow tests, environment variable 'RUN_SLOW_TESTS' must be set """ import os +from typing import List from uuid import uuid4 import pytest @@ -20,6 +20,10 @@ MatchingEngineIndex, MatchingEngineIndexEndpoint, ) +from google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint import ( + Namespace, + NumericNamespace, +) from langchain_core.documents import Document from langchain_google_vertexai.embeddings import VertexAIEmbeddings @@ -57,6 +61,22 @@ def datastore_document_storage( return DataStoreDocumentStorage(datastore_client=ds_client) +@pytest.fixture +def vector_store() -> VectorSearchVectorStore: + embeddings = VertexAIEmbeddings(model_name="textembedding-gecko-default") + + vector_store = VectorSearchVectorStore.from_components( + project_id=os.environ["PROJECT_ID"], + region=os.environ["REGION"], + gcs_bucket_name=os.environ["GCS_BUCKET_NAME"], + index_id=os.environ["INDEX_ID"], + endpoint_id=os.environ["ENDPOINT_ID"], + embedding=embeddings, + ) + + return vector_store + + @pytest.mark.extended def test_vector_search_sdk_manager(sdk_manager: VectorSearchSDKManager): gcs_client = sdk_manager.get_gcs_client() @@ -125,18 +145,7 @@ def test_public_endpoint_vector_searcher(sdk_manager: VectorSearchSDKManager): @pytest.mark.extended -def test_vector_store(): - embeddings = VertexAIEmbeddings(model_name="textembedding-gecko-default") - - vector_store = VectorSearchVectorStore.from_components( - project_id=os.environ["PROJECT_ID"], - region=os.environ["REGION"], - gcs_bucket_name=os.environ["GCS_BUCKET_NAME"], - index_id=os.environ["INDEX_ID"], - endpoint_id=os.environ["ENDPOINT_ID"], - embedding=embeddings, - ) - +def test_vector_store(vector_store: VectorSearchVectorStore): assert isinstance(vector_store, VectorSearchVectorStore) query = "What are your favourite animals?" @@ -153,7 +162,20 @@ def test_vector_store(): @pytest.mark.extended -def test_vector_store_update_index(): +def test_vector_store_filtering(vector_store: VectorSearchVectorStore): + documents = vector_store.similarity_search( + "I want some pants", + filter=[Namespace(name="color", allow_tokens=["blue"])], + numeric_filter=[NumericNamespace(name="price", value_float=20.0, op="LESS")], + ) + + assert len(documents) > 0 + + assert all(document.metadata["price"] < 20.0 for document in documents) + + +@pytest.mark.extended +def test_vector_store_update_index(sample_documents: list[Document]): embeddings = VertexAIEmbeddings(model_name="textembedding-gecko-default") vector_store = VectorSearchVectorStore.from_components( @@ -165,10 +187,179 @@ def test_vector_store_update_index(): embedding=embeddings, ) - vector_store.add_texts( - texts=[ - "Lions are my favourite animals", - "There are two apples on the table", - "Today is raining a lot in Madrid", - ] - ) + vector_store.add_documents(documents=sample_documents, is_complete_overwrite=True) + + +@pytest.fixture +def sample_documents() -> List[Document]: + record_data = [ + { + "description": "A versatile pair of dark-wash denim jeans." + "Made from durable cotton with a classic straight-leg cut, these jeans" + " transition easily from casual days to dressier occasions.", + "price": 65.00, + "color": "blue", + "season": ["fall", "winter", "spring"], + }, + { + "description": "A lightweight linen button-down shirt in a crisp white." 
+ " Perfect for keeping cool with breathable fabric and a relaxed fit.", + "price": 34.99, + "color": "white", + "season": ["summer", "spring"], + }, + { + "description": "A soft, chunky knit sweater in a vibrant forest green. " + "The oversized fit and cozy wool blend make this ideal for staying warm " + "when the temperature drops.", + "price": 89.99, + "color": "green", + "season": ["fall", "winter"], + }, + { + "description": "A classic crewneck t-shirt in a soft, heathered blue. " + "Made from comfortable cotton jersey, this t-shirt is a wardrobe essential " + "that works for every season.", + "price": 19.99, + "color": "blue", + "season": ["fall", "winter", "summer", "spring"], + }, + { + "description": "A flowing midi-skirt in a delicate floral print. " + "Lightweight and airy, this skirt adds a touch of feminine style " + "to warmer days.", + "price": 45.00, + "color": "white", + "season": ["spring", "summer"], + }, + { + "description": "A pair of tailored black trousers in a comfortable stretch " + "fabric. Perfect for work or dressier events, these trousers provide a" + " sleek, polished look.", + "price": 59.99, + "color": "black", + "season": ["fall", "winter", "spring"], + }, + { + "description": "A cozy fleece hoodie in a neutral heather grey. " + "This relaxed sweatshirt is perfect for casual days or layering when the " + "weather turns chilly.", + "price": 39.99, + "color": "grey", + "season": ["fall", "winter", "spring"], + }, + { + "description": "A bright yellow raincoat with a playful polka dot pattern. " + "This waterproof jacket will keep you dry and add a touch of cheer to " + "rainy days.", + "price": 75.00, + "color": "yellow", + "season": ["spring", "fall"], + }, + { + "description": "A pair of comfortable khaki chino shorts. These versatile " + "shorts are a summer staple, perfect for outdoor adventures or relaxed" + " weekends.", + "price": 34.99, + "color": "khaki", + "season": ["summer"], + }, + { + "description": "A bold red cocktail dress with a flattering A-line " + "silhouette. This statement piece is made from a luxurious satin fabric, " + "ensuring a head-turning look.", + "price": 125.00, + "color": "red", + "season": ["fall", "winter", "summer", "spring"], + }, + { + "description": "A pair of classic white sneakers crafted from smooth " + "leather. These timeless shoes offer a clean and polished look, perfect " + "for everyday wear.", + "price": 79.99, + "color": "white", + "season": ["fall", "winter", "summer", "spring"], + }, + { + "description": "A chunky cable-knit scarf in a rich burgundy color. " + "Made from a soft wool blend, this scarf will provide warmth and a touch " + "of classic style to cold-weather looks.", + "price": 45.00, + "color": "burgundy", + "season": ["fall", "winter"], + }, + { + "description": "A lightweight puffer vest in a vibrant teal hue. " + "This versatile piece adds a layer of warmth without bulk, transitioning" + " perfectly between seasons.", + "price": 65.00, + "color": "teal", + "season": ["fall", "spring"], + }, + { + "description": "A pair of high-waisted leggings in a sleek black." + " Crafted from a moisture-wicking fabric with plenty of stretch, " + "these leggings are perfect for workouts or comfortable athleisure style.", + "price": 49.99, + "color": "black", + "season": ["fall", "winter", "summer", "spring"], + }, + { + "description": "A denim jacket with a faded wash and distressed details. 
" + "This wardrobe staple adds a touch of effortless cool to any outfit.", + "price": 79.99, + "color": "blue", + "season": ["fall", "spring", "summer"], + }, + { + "description": "A woven straw sunhat with a wide brim. This stylish " + "accessory provides protection from the sun while adding a touch of " + "summery elegance.", + "price": 32.00, + "color": "beige", + "season": ["summer"], + }, + { + "description": "A graphic tee featuring a vintage band logo. " + "Made from a soft cotton blend, this casual tee adds a touch of " + "personal style to everyday looks.", + "price": 24.99, + "color": "white", + "season": ["fall", "winter", "summer", "spring"], + }, + { + "description": "A pair of well-tailored dress pants in a neutral grey. " + "Made from a wrinkle-resistant blend, these pants look sharp and " + "professional for workwear or formal occasions.", + "price": 69.99, + "color": "grey", + "season": ["fall", "winter", "summer", "spring"], + }, + { + "description": "A pair of classic leather ankle boots in a rich brown hue." + " Featuring a subtle stacked heel and sleek design, these boots are perfect" + " for elevating outfits in cooler seasons.", + "price": 120.00, + "color": "brown", + "season": ["fall", "winter", "spring"], + }, + { + "description": "A vibrant swimsuit with a bold geometric pattern. This fun " + "and eye-catching piece is perfect for making a splash by the pool or at " + "the beach.", + "price": 55.00, + "color": "multicolor", + "season": ["summer"], + }, + ] + + documents = [] + for record in record_data: + record = record.copy() + page_content = record.pop("description") + if isinstance(page_content, str): + metadata = {**record} + document = Document(page_content=page_content, metadata=metadata) + documents.append(document) + + return documents From a06ac39d706fbf0ca6920001ecec1b1229fc562f Mon Sep 17 00:00:00 2001 From: Jorge Date: Tue, 12 Mar 2024 17:50:44 +0100 Subject: [PATCH 08/14] Fix typing import --- .../langchain_google_vertexai/vectorstores/_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/libs/vertexai/langchain_google_vertexai/vectorstores/_utils.py b/libs/vertexai/langchain_google_vertexai/vectorstores/_utils.py index f79c9e36..a92a41fb 100644 --- a/libs/vertexai/langchain_google_vertexai/vectorstores/_utils.py +++ b/libs/vertexai/langchain_google_vertexai/vectorstores/_utils.py @@ -91,19 +91,19 @@ def to_data_points( raise ValueError("All metadata keys must be strings") if isinstance(value, str): - restriction = meidx_types.Restriction( + restriction = meidx_types.Index.Restriction( namespace=namespace, allow_list=[value] ) restricts.append(restriction) elif isinstance(value, list) and all( isinstance(item, str) for item in value ): - restriction = meidx_types.Restriction( + restriction = meidx_types.Index.Restriction( namespace=namespace, allow_list=value ) restricts.append(restriction) elif isinstance(value, (int, float)) and not isinstance(value, bool): - restriction = meidx_types.NumericRestriction( + restriction = meidx_types.Index.NumericRestriction( namespace=namespace, value_float=value ) numeric_restricts.append(restriction) From c5201b1b59ba7c9e0957c50241cf67a21b786844 Mon Sep 17 00:00:00 2001 From: Jorge Date: Tue, 12 Mar 2024 17:52:03 +0100 Subject: [PATCH 09/14] Fix typing --- .../langchain_google_vertexai/vectorstores/_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/libs/vertexai/langchain_google_vertexai/vectorstores/_utils.py 
b/libs/vertexai/langchain_google_vertexai/vectorstores/_utils.py index a92a41fb..50141572 100644 --- a/libs/vertexai/langchain_google_vertexai/vectorstores/_utils.py +++ b/libs/vertexai/langchain_google_vertexai/vectorstores/_utils.py @@ -91,19 +91,19 @@ def to_data_points( raise ValueError("All metadata keys must be strings") if isinstance(value, str): - restriction = meidx_types.Index.Restriction( + restriction = meidx_types.IndexDataPoint.Restriction( namespace=namespace, allow_list=[value] ) restricts.append(restriction) elif isinstance(value, list) and all( isinstance(item, str) for item in value ): - restriction = meidx_types.Index.Restriction( + restriction = meidx_types.IndexDataPoint.Restriction( namespace=namespace, allow_list=value ) restricts.append(restriction) elif isinstance(value, (int, float)) and not isinstance(value, bool): - restriction = meidx_types.Index.NumericRestriction( + restriction = meidx_types.IndexDataPoint.NumericRestriction( namespace=namespace, value_float=value ) numeric_restricts.append(restriction) From 2d23bf62efeff333a395023187bcd320d9b2ab43 Mon Sep 17 00:00:00 2001 From: Jorge Date: Tue, 12 Mar 2024 17:53:47 +0100 Subject: [PATCH 10/14] Fix typo --- .../langchain_google_vertexai/vectorstores/_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/libs/vertexai/langchain_google_vertexai/vectorstores/_utils.py b/libs/vertexai/langchain_google_vertexai/vectorstores/_utils.py index 50141572..5538aaf1 100644 --- a/libs/vertexai/langchain_google_vertexai/vectorstores/_utils.py +++ b/libs/vertexai/langchain_google_vertexai/vectorstores/_utils.py @@ -91,24 +91,24 @@ def to_data_points( raise ValueError("All metadata keys must be strings") if isinstance(value, str): - restriction = meidx_types.IndexDataPoint.Restriction( + restriction = meidx_types.IndexDatapoint.Restriction( namespace=namespace, allow_list=[value] ) restricts.append(restriction) elif isinstance(value, list) and all( isinstance(item, str) for item in value ): - restriction = meidx_types.IndexDataPoint.Restriction( + restriction = meidx_types.IndexDatapoint.Restriction( namespace=namespace, allow_list=value ) restricts.append(restriction) elif isinstance(value, (int, float)) and not isinstance(value, bool): - restriction = meidx_types.IndexDataPoint.NumericRestriction( + restriction = meidx_types.IndexDatapoint.NumericRestriction( namespace=namespace, value_float=value ) numeric_restricts.append(restriction) - data_point = meidx_types.IndexDataPoint( + data_point = meidx_types.IndexDatapoint( datapoint_id=id_, feature_vector=embedding, restricts=restricts, From 104a6dfac81176346086083889c17e69dacef474 Mon Sep 17 00:00:00 2001 From: Jorge Date: Tue, 12 Mar 2024 18:03:02 +0100 Subject: [PATCH 11/14] Fix Py3.8 --- libs/vertexai/tests/integration_tests/test_vectorstores.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/vertexai/tests/integration_tests/test_vectorstores.py b/libs/vertexai/tests/integration_tests/test_vectorstores.py index f7ff2e31..d3e769ca 100644 --- a/libs/vertexai/tests/integration_tests/test_vectorstores.py +++ b/libs/vertexai/tests/integration_tests/test_vectorstores.py @@ -175,7 +175,7 @@ def test_vector_store_filtering(vector_store: VectorSearchVectorStore): @pytest.mark.extended -def test_vector_store_update_index(sample_documents: list[Document]): +def test_vector_store_update_index(sample_documents: List[Document]): embeddings = VertexAIEmbeddings(model_name="textembedding-gecko-default") vector_store = 
VectorSearchVectorStore.from_components( From c14239b1ba1ad29b77e67a792489c8717074b0b0 Mon Sep 17 00:00:00 2001 From: Jorge Date: Wed, 13 Mar 2024 12:31:41 +0100 Subject: [PATCH 12/14] Add stream updating --- .../langchain_google_vertexai/vectorstores/vectorstores.py | 6 +++++- libs/vertexai/tests/integration_tests/test_vectorstores.py | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/libs/vertexai/langchain_google_vertexai/vectorstores/vectorstores.py b/libs/vertexai/langchain_google_vertexai/vectorstores/vectorstores.py index 77fa08f7..582010a5 100644 --- a/libs/vertexai/langchain_google_vertexai/vectorstores/vectorstores.py +++ b/libs/vertexai/langchain_google_vertexai/vectorstores/vectorstores.py @@ -269,6 +269,7 @@ def from_components( # Implemented in order to keep the current API endpoint_id: str, credentials_path: Optional[str] = None, embedding: Optional[Embeddings] = None, + stream_update: bool = False, **kwargs: Any, ) -> "VectorSearchVectorStore": """Takes the object creation out of the constructor. @@ -284,6 +285,8 @@ def from_components( # Implemented in order to keep the current API the local file system. embedding: The :class:`Embeddings` that will be used for embedding the texts. + stream_update: Whether to update with streaming or batching. VectorSearch + index must be compatible with stream/batch updates. kwargs: Additional keyword arguments to pass to VertexAIVectorSearch.__init__(). Returns: @@ -300,7 +303,8 @@ def from_components( # Implemented in order to keep the current API return cls( document_storage=GCSDocumentStorage(bucket=bucket), searcher=VectorSearchSearcher( - endpoint=endpoint, index=index, staging_bucket=bucket + endpoint=endpoint, index=index, staging_bucket=bucket, + stream_update=stream_update ), embbedings=embedding, ) diff --git a/libs/vertexai/tests/integration_tests/test_vectorstores.py b/libs/vertexai/tests/integration_tests/test_vectorstores.py index d3e769ca..eada386b 100644 --- a/libs/vertexai/tests/integration_tests/test_vectorstores.py +++ b/libs/vertexai/tests/integration_tests/test_vectorstores.py @@ -170,7 +170,7 @@ def test_vector_store_filtering(vector_store: VectorSearchVectorStore): ) assert len(documents) > 0 - + assert all(document.metadata["color"] == "blue" for document in documents) assert all(document.metadata["price"] < 20.0 for document in documents) From 38400a30aaa3a6e3cf3e55d17cb4bd8770af9746 Mon Sep 17 00:00:00 2001 From: Jorge Date: Wed, 13 Mar 2024 12:32:16 +0100 Subject: [PATCH 13/14] Fix format --- .../langchain_google_vertexai/vectorstores/vectorstores.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/libs/vertexai/langchain_google_vertexai/vectorstores/vectorstores.py b/libs/vertexai/langchain_google_vertexai/vectorstores/vectorstores.py index 582010a5..11035bef 100644 --- a/libs/vertexai/langchain_google_vertexai/vectorstores/vectorstores.py +++ b/libs/vertexai/langchain_google_vertexai/vectorstores/vectorstores.py @@ -303,8 +303,10 @@ def from_components( # Implemented in order to keep the current API return cls( document_storage=GCSDocumentStorage(bucket=bucket), searcher=VectorSearchSearcher( - endpoint=endpoint, index=index, staging_bucket=bucket, - stream_update=stream_update + endpoint=endpoint, + index=index, + staging_bucket=bucket, + stream_update=stream_update, ), embbedings=embedding, ) From 8f784906e4daae382a6c92b6ffb0c16f5a90e166 Mon Sep 17 00:00:00 2001 From: Jorge Date: Wed, 13 Mar 2024 21:12:11 +0100 Subject: [PATCH 14/14] Added warning for unused fields in 
filtering. Added unit tests for to_data_points

---
 .../vectorstores/_utils.py                    | 11 +++++
 .../tests/unit_tests/test_vectorstores.py     | 49 +++++++++++++++++++
 2 files changed, 60 insertions(+)
 create mode 100644 libs/vertexai/tests/unit_tests/test_vectorstores.py

diff --git a/libs/vertexai/langchain_google_vertexai/vectorstores/_utils.py b/libs/vertexai/langchain_google_vertexai/vectorstores/_utils.py
index 5538aaf1..fc6f7b77 100644
--- a/libs/vertexai/langchain_google_vertexai/vectorstores/_utils.py
+++ b/libs/vertexai/langchain_google_vertexai/vectorstores/_utils.py
@@ -1,5 +1,6 @@
 import json
 import uuid
+import warnings
 from typing import Any, Dict, List, Union
 
 from google.cloud.aiplatform import MatchingEngineIndex
@@ -81,6 +82,7 @@ def to_data_points(
         metadatas = [{}] * len(ids)
 
     data_points = []
+    ignored_fields = set()
 
     for id_, embedding, metadata in zip(ids, embeddings, metadatas):
         restricts = []
@@ -107,6 +109,15 @@ def to_data_points(
                     namespace=namespace, value_float=value
                 )
                 numeric_restricts.append(restriction)
+            else:
+                ignored_fields.add(namespace)
+
+        if len(ignored_fields) > 0:
+            warnings.warn(
+                f"Some values in fields {', '.join(ignored_fields)} are not usable for"
+                f" restrictions. In order to be used they must be str, list[str] or"
+                f" numeric."
+            )
 
         data_point = meidx_types.IndexDatapoint(
             datapoint_id=id_,
diff --git a/libs/vertexai/tests/unit_tests/test_vectorstores.py b/libs/vertexai/tests/unit_tests/test_vectorstores.py
new file mode 100644
index 00000000..9b1c30bc
--- /dev/null
+++ b/libs/vertexai/tests/unit_tests/test_vectorstores.py
@@ -0,0 +1,49 @@
+import pytest
+
+from langchain_google_vertexai.vectorstores._utils import to_data_points
+
+
+def test_to_data_points():
+    ids = ["Id1"]
+    embeddings = [[0.0, 0.0]]
+    metadatas = [
+        {
+            "some_string": "string",
+            "some_number": 1.1,
+            "some_list": ["a", "b"],
+            "some_random_object": {"foo": 1, "bar": 2},
+        }
+    ]
+
+    with pytest.warns():
+        result = to_data_points(ids, embeddings, metadatas)
+
+    assert isinstance(result, list)
+    assert len(result) == 1
+
+    datapoint = result[0]
+    assert datapoint.datapoint_id == "Id1"
+    for component_emb, component_fv in zip(datapoint.feature_vector, embeddings[0]):
+        assert component_emb == pytest.approx(component_fv)
+
+    metadata = metadatas[0]
+
+    restriction_lookup = {
+        restriction.namespace: restriction for restriction in datapoint.restricts
+    }
+
+    restriction = restriction_lookup.pop("some_string")
+    assert restriction.allow_list == [metadata["some_string"]]
+
+    restriction = restriction_lookup.pop("some_list")
+    assert restriction.allow_list == metadata["some_list"]
+
+    assert len(restriction_lookup) == 0
+
+    num_restriction_lookup = {
+        restriction.namespace: restriction
+        for restriction in datapoint.numeric_restricts
+    }
+    restriction = num_restriction_lookup.pop("some_number")
+    assert round(restriction.value_float, 1) == pytest.approx(metadata["some_number"])
+    assert len(num_restriction_lookup) == 0
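To see the new warning in isolation, outside pytest, something like the sketch below works against the code above; the dict-valued "nested" field is the one that gets ignored:

```python
import warnings

from langchain_google_vertexai.vectorstores._utils import to_data_points

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    to_data_points(
        ids=["id1"],
        embeddings=[[0.0, 0.0]],
        metadatas=[{"kept": "string value", "nested": {"foo": 1}}],
    )

# The dict-valued "nested" field cannot be expressed as a restriction,
# so it is reported in the warning as ignored.
assert any("nested" in str(w.message) for w in caught)
```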