-
Notifications
You must be signed in to change notification settings - Fork 177
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor matching_engine (migration from lc monorepo)
- Loading branch information
Jorge
committed
Feb 19, 2024
1 parent
7938fe3
commit e443604
Showing
6 changed files
with
913 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
185 changes: 185 additions & 0 deletions
185
libs/vertexai/langchain_google_vertexai/vectorstores/_document_storage.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,185 @@ | ||
from abc import ABC, abstractmethod | ||
from typing import TYPE_CHECKING, List, Optional, Union | ||
|
||
from google.cloud import storage | ||
|
||
if TYPE_CHECKING: | ||
from google.cloud import datastore | ||
|
||
|
||
class DocumentStorage(ABC):
    """Abstract interface of a key, text storage for retrieving documents.

    Subclasses must implement `get_by_id` and `store_by_id`. The batch methods
    have default implementations that loop over the single-document methods
    and may be overridden when the backend offers a faster batch API.
    """

    @abstractmethod
    def get_by_id(self, document_id: str) -> Union[str, None]:
        """Gets the text of a document by its id. If not found, returns None.

        Args:
            document_id: Id of the document to get from the storage.
        Returns:
            Text of the document if found, otherwise None.
        """
        raise NotImplementedError()

    @abstractmethod
    def store_by_id(self, document_id: str, text: str) -> None:
        """Stores a document text associated to a document_id.

        Args:
            document_id: Id of the document to be stored.
            text: Text of the document to be stored.
        """
        raise NotImplementedError()

    def batch_store_by_id(self, ids: List[str], texts: List[str]) -> None:
        """Stores a list of ids and documents in batch.

        The default implementation only loops to the individual `store_by_id`.
        Subclasses that have faster ways to store data via batch uploading
        should implement the proper way.

        Args:
            ids: List of ids for the text.
            texts: List of texts. Must have the same length as `ids`.
        Raises:
            ValueError: If `ids` and `texts` have different lengths.
        """
        # zip() would silently drop the unpaired tail, losing documents
        # without any error, so the mismatch is rejected before storing.
        if len(ids) != len(texts):
            raise ValueError(
                "ids and texts must have the same length, got "
                f"{len(ids)} ids and {len(texts)} texts."
            )

        for id_, text in zip(ids, texts):
            self.store_by_id(id_, text)

    def batch_get_by_id(self, ids: List[str]) -> List[Union[str, None]]:
        """Gets a batch of documents by id.

        The default implementation only loops `get_by_id`.
        Subclasses that have faster ways to retrieve data by batch should
        implement this method.

        Args:
            ids: List of ids for the text.
        Returns:
            List of texts, in the same order as `ids`. If the key id is not
            found for any id record returns a None instead.
        """
        return [self.get_by_id(id_) for id_ in ids]
|
||
|
||
class GCSDocumentStorage(DocumentStorage):
    """Stores documents in Google Cloud Storage.

    For each pair id, document_text the name of the blob will be {prefix}/{id}
    stored in plain text format. When no prefix is configured the blob name is
    just {id}.
    """

    def __init__(
        self, bucket: "storage.Bucket", prefix: Optional[str] = "documents"
    ) -> None:
        """Constructor.

        Args:
            bucket: Bucket where the documents will be stored.
            prefix: Prefix that is prepended to all document names. Pass None
                to store the blobs without any prefix.
        """
        super().__init__()
        self._bucket = bucket
        self._prefix = prefix

    def get_by_id(self, document_id: str) -> Union[str, None]:
        """Gets the text of a document by its id. If not found, returns None.

        Args:
            document_id: Id of the document to get from the storage.
        Returns:
            Text of the document if found, otherwise None.
        """
        blob_name = self._get_blob_name(document_id)
        existing_blob = self._bucket.get_blob(blob_name)

        if existing_blob is None:
            return None

        return existing_blob.download_as_text()

    def store_by_id(self, document_id: str, text: str) -> None:
        """Stores a document text associated to a document_id.

        Args:
            document_id: Id of the document to be stored.
            text: Text of the document to be stored.
        """
        blob_name = self._get_blob_name(document_id)
        blob = self._bucket.blob(blob_name)
        blob.upload_from_string(text)

    def _get_blob_name(self, document_id: str) -> str:
        """Builds a blob name using the prefix and the document_id.

        Args:
            document_id: Id of the document.
        Returns:
            Name of the blob that the document will be/is stored in.
        """
        # The signature allows prefix=None; formatting it blindly would
        # produce blob names starting with the literal text "None/".
        if self._prefix is None:
            return document_id
        return f"{self._prefix}/{document_id}"
|
||
|
||
class DataStoreDocumentStorage(DocumentStorage):
    """Stores documents in Google Cloud DataStore.

    Each document is one entity keyed by (kind, document_id) whose text lives
    in a single configurable property.
    """

    def __init__(
        self,
        datastore_client: "datastore.Client",
        kind: str = "document_id",
        text_property_name: str = "text",
    ) -> None:
        """Constructor.

        Args:
            datastore_client: Datastore client used for every read and write.
            kind: Kind used to build the key of each document entity.
            text_property_name: Name of the entity property that holds the
                document text.
        """
        super().__init__()
        self._client = datastore_client
        self._text_property_name = text_property_name
        self._kind = kind

    def get_by_id(self, document_id: str) -> Union[str, None]:
        """Gets the text of a document by its id. If not found, returns None.

        Args:
            document_id: Id of the document to get from the storage.
        Returns:
            Text of the document if found, otherwise None.
        """
        key = self._client.key(self._kind, document_id)
        entity = self._client.get(key)

        # The client returns None for a key that does not exist; indexing it
        # would raise TypeError instead of honouring the documented contract.
        if entity is None:
            return None

        return entity[self._text_property_name]

    def store_by_id(self, document_id: str, text: str) -> None:
        """Stores a document text associated to a document_id.

        Args:
            document_id: Id of the document to be stored.
            text: Text of the document to be stored.
        """
        with self._client.transaction():
            key = self._client.key(self._kind, document_id)
            entity = self._client.entity(key=key)
            entity[self._text_property_name] = text
            self._client.put(entity)

    def batch_get_by_id(self, ids: List[str]) -> List[Union[str, None]]:
        """Gets a batch of documents by id.

        Args:
            ids: List of ids for the text.
        Returns:
            List of texts, in the same order as `ids`. If the key id is not
            found for any id record returns a None instead.
        """
        keys = [self._client.key(self._kind, id_) for id_ in ids]
        entities = self._client.get_multi(keys)

        # get_multi only returns the entities that exist and does not promise
        # to keep the order of `keys`, so the result is re-aligned by id.
        text_by_id = {
            entity.key.id_or_name: entity[self._text_property_name]
            for entity in entities
        }

        return [text_by_id.get(id_) for id_ in ids]

    def batch_store_by_id(self, ids: List[str], texts: List[str]) -> None:
        """Stores a list of ids and documents in batch.

        Args:
            ids: List of ids for the text.
            texts: List of texts. Must have the same length as `ids`.
        Raises:
            ValueError: If `ids` and `texts` have different lengths.
        """
        # zip() would silently drop the unpaired tail, losing documents
        # without any error, so the mismatch is rejected before writing.
        if len(ids) != len(texts):
            raise ValueError(
                "ids and texts must have the same length, got "
                f"{len(ids)} ids and {len(texts)} texts."
            )

        with self._client.transaction():
            entities = []
            for id_, text in zip(ids, texts):
                entity = self._client.entity(key=self._client.key(self._kind, id_))
                entity[self._text_property_name] = text
                entities.append(entity)

            self._client.put_multi(entities)
116 changes: 116 additions & 0 deletions
116
libs/vertexai/langchain_google_vertexai/vectorstores/_sdk_manager.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
from typing import TYPE_CHECKING, Any, Union | ||
|
||
from google.cloud import aiplatform, storage | ||
from google.cloud.aiplatform.matching_engine import ( | ||
MatchingEngineIndex, | ||
MatchingEngineIndexEndpoint, | ||
) | ||
from google.oauth2.service_account import Credentials | ||
|
||
if TYPE_CHECKING: | ||
from google.cloud import datastore | ||
|
||
|
||
class VectorSearchSDKManager:
    """Factory for the Google Cloud SDK objects needed to build VectorStores.

    Keeps the project id, region and credentials in one place so that callers
    can request clients, buckets, indexes and endpoints without handling the
    authentication layer themselves.
    """

    def __init__(
        self,
        *,
        project_id: str,
        region: str,
        credentials: Union[Credentials, None] = None,
        credentials_path: Union[str, None] = None,
    ) -> None:
        """Constructor.

        Credential resolution order: an explicit `credentials` object wins;
        otherwise `credentials_path` is loaded as a service account file;
        otherwise None is kept, which falls back to default credentials.

        Args:
            project_id: Id of the project.
            region: Region of the project. E.g. 'us-central1'
            credentials: Google cloud Credentials object.
            credentials_path: Google Cloud Credentials json file path.
        """
        self._project_id = project_id
        self._region = region

        # Only read the file when no credentials object was handed in.
        if credentials is None and credentials_path is not None:
            credentials = Credentials.from_service_account_file(credentials_path)
        self._credentials = credentials

        self.initialize_aiplatform()

    def _aiplatform_kwargs(self) -> dict:
        """Returns the project/location/credentials keyword arguments."""
        return {
            "project": self._project_id,
            "location": self._region,
            "credentials": self._credentials,
        }

    def initialize_aiplatform(self) -> None:
        """Initializes aiplatform with the stored project, region and credentials."""
        aiplatform.init(**self._aiplatform_kwargs())

    def get_gcs_client(self) -> storage.Client:
        """Retrieves a Google Cloud Storage client.

        Returns:
            Google Cloud Storage client for the stored project and credentials.
        """
        gcs_client = storage.Client(
            project=self._project_id, credentials=self._credentials
        )
        return gcs_client

    def get_gcs_bucket(self, bucket_name: str) -> storage.Bucket:
        """Retrieves a Google Cloud Bucket by bucket name.

        Args:
            bucket_name: Name of the bucket to be retrieved.
        Returns:
            Google Cloud Bucket.
        """
        return self.get_gcs_client().get_bucket(bucket_name)

    def get_index(self, index_id: str) -> MatchingEngineIndex:
        """Retrieves a MatchingEngineIndex (VectorSearchIndex) by id.

        Args:
            index_id: Id of the index to be retrieved.
        Returns:
            MatchingEngineIndex instance.
        """
        index = MatchingEngineIndex(index_name=index_id, **self._aiplatform_kwargs())
        return index

    def get_endpoint(self, endpoint_id: str) -> MatchingEngineIndexEndpoint:
        """Retrieves a MatchingEngineIndexEndpoint (VectorSearchIndexEndpoint) by id.

        Args:
            endpoint_id: Id of the endpoint to be retrieved.
        Returns:
            MatchingEngineIndexEndpoint instance.
        """
        endpoint = MatchingEngineIndexEndpoint(
            index_endpoint_name=endpoint_id, **self._aiplatform_kwargs()
        )
        return endpoint

    def get_datastore_client(self, **kwargs: Any) -> "datastore.Client":
        """Gets a datastore Client.

        Args:
            **kwargs: Keyword arguments to pass to datastore.Client constructor.
        Returns:
            datastore Client.
        """
        # Imported here so datastore is only needed when this method is called.
        from google.cloud import datastore

        return datastore.Client(
            project=self._project_id, credentials=self._credentials, **kwargs
        )
Oops, something went wrong.