Skip to content

Commit

Permalink
Refactor matching_engine (migration from lc monorepo)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jorge committed Feb 19, 2024
1 parent 7938fe3 commit e443604
Show file tree
Hide file tree
Showing 6 changed files with 913 additions and 0 deletions.
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, List, Optional, Union

from google.cloud import storage

if TYPE_CHECKING:
from google.cloud import datastore


class DocumentStorage(ABC):
"""Abstract interface of a key, text storage for retrieving documents."""

@abstractmethod
def get_by_id(self, document_id: str) -> Union[str, None]:
"""Gets the text of a document by its id. If not found, returns None.
Args:
document_id: Id of the document to get from the storage.
Returns:
Text of the document if found, otherwise None.
"""
raise NotImplementedError()

@abstractmethod
def store_by_id(self, document_id: str, text: str):
"""Stores a document text associated to a document_id.
Args:
document_id: Id of the document to be stored.
text: Text of the document to be stored.
"""
raise NotImplementedError()

def batch_store_by_id(self, ids: List[str], texts: List[str]) -> None:
"""Stores a list of ids and documents in batch.
The default implementation only loops to the individual `store_by_id`.
Subclasses that have faster ways to store data via batch uploading should
implement the proper way.
Args:
ids: List of ids for the text.
texts: List of texts.
"""
for id_, text in zip(ids, texts):
self.store_by_id(id_, text)

def batch_get_by_id(self, ids: List[str]) -> List[Union[str, None]]:
"""Gets a batch of documents by id.
The default implementation only loops `get_by_id`.
Subclasses that have faster ways to retrieve data by batch should implement
this method.
Args:
ids: List of ids for the text.
Returns:
List of texts. If the key id is not found for any id record returns a None
instead.
"""
return [self.get_by_id(id_) for id_ in ids]


class GCSDocumentStorage(DocumentStorage):
"""Stores documents in Google Cloud Storage.
For each pair id, document_text the name of the blob will be {prefix}/{id} stored
in plain text format.
"""

def __init__(
self, bucket: "storage.Bucket", prefix: Optional[str] = "documents"
) -> None:
"""Constructor.
Args:
bucket: Bucket where the documents will be stored.
prefix: Prefix that is prepended to all document names.
"""
super().__init__()
self._bucket = bucket
self._prefix = prefix

def get_by_id(self, document_id: str) -> Union[str, None]:
"""Gets the text of a document by its id. If not found, returns None.
Args:
document_id: Id of the document to get from the storage.
Returns:
Text of the document if found, otherwise None.
"""

blob_name = self._get_blob_name(document_id)
existing_blob = self._bucket.get_blob(blob_name)

if existing_blob is None:
return None

return existing_blob.download_as_text()

def store_by_id(self, document_id: str, text: str) -> None:
"""Stores a document text associated to a document_id.
Args:
document_id: Id of the document to be stored.
text: Text of the document to be stored.
"""
blob_name = self._get_blob_name(document_id)
new_blow = self._bucket.blob(blob_name)
new_blow.upload_from_string(text)

def _get_blob_name(self, document_id: str) -> str:
"""Builds a blob name using the prefix and the document_id.
Args:
document_id: Id of the document.
Returns:
Name of the blob that the document will be/is stored in
"""
return f"{self._prefix}/{document_id}"


class DataStoreDocumentStorage(DocumentStorage):
"""Stores documents in Google Cloud DataStore."""

def __init__(
self,
datastore_client: "datastore.Client",
kind: str = "document_id",
text_property_name: str = "text",
) -> None:
"""Constructor.
Args:
bucket: Bucket where the documents will be stored.
prefix: Prefix that is prepended to all document names.
"""
super().__init__()
self._client = datastore_client
self._text_property_name = text_property_name
self._kind = kind

def get_by_id(self, document_id: str) -> Union[str, None]:
"""Gets the text of a document by its id. If not found, returns None.
Args:
document_id: Id of the document to get from the storage.
Returns:
Text of the document if found, otherwise None.
"""
key = self._client.key(self._kind, document_id)
entity = self._client.get(key)
return entity[self._text_property_name]

def store_by_id(self, document_id: str, text: str) -> None:
"""Stores a document text associated to a document_id.
Args:
document_id: Id of the document to be stored.
text: Text of the document to be stored.
"""
with self._client.transaction():
key = self._client.key(self._kind, document_id)
entity = self._client.entity(key=key)
entity[self._text_property_name] = text
self._client.put(entity)

def batch_get_by_id(self, ids: List[str]) -> List[Union[str, None]]:
"""Gets a batch of documents by id.
Args:
ids: List of ids for the text.
Returns:
List of texts. If the key id is not found for any id record returns a None
instead.
"""
keys = [self._client.key(self._kind, id_) for id_ in ids]

# TODO: Handle when a key is not present
entities = self._client.get_multi(keys)

return [entity[self._text_property_name] for entity in entities]

def batch_store_by_id(self, ids: List[str], texts: List[str]) -> None:
"""Stores a list of ids and documents in batch.
Args:
ids: List of ids for the text.
texts: List of texts.
"""

with self._client.transaction():
keys = [self._client.key(self._kind, id_) for id_ in ids]

entities = []
for key, text in zip(keys, texts):
entity = self._client.entity(key=key)
entity[self._text_property_name] = text
entities.append(entity)

self._client.put_multi(entities)
116 changes: 116 additions & 0 deletions libs/vertexai/langchain_google_vertexai/vectorstores/_sdk_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
from typing import TYPE_CHECKING, Any, Union

from google.cloud import aiplatform, storage
from google.cloud.aiplatform.matching_engine import (
MatchingEngineIndex,
MatchingEngineIndexEndpoint,
)
from google.oauth2.service_account import Credentials

if TYPE_CHECKING:
from google.cloud import datastore


class VectorSearchSDKManager:
"""Class in charge of building all Google Cloud SDK Objects needed to build
VectorStores from project_id, credentials or other specifications. Abstracts
away the authentication layer.
"""

def __init__(
self,
*,
project_id: str,
region: str,
credentials: Union[Credentials, None] = None,
credentials_path: Union[str, None] = None,
) -> None:
"""Constructor.
If `credentials` is provided, those credentials are used. If not provided
`credentials_path` is used to retrieve credentials from a file. If also not
provided, falls back to default credentials.
Args:
project_id: Id of the project.
region: Region of the project. E.j. 'us-central1'
credentials: Google cloud Credentials object.
credentials_path: Google Cloud Credentials json file path.
"""
self._project_id = project_id
self._region = region

if credentials is not None:
self._credentials = credentials
elif credentials_path is not None:
self._credentials = Credentials.from_service_account_file(credentials_path)
else:
self._credentials = None

self.initialize_aiplatform()

def initialize_aiplatform(self) -> None:
"""Initializes aiplatform."""
aiplatform.init(
project=self._project_id,
location=self._region,
credentials=self._credentials,
)

def get_gcs_client(self) -> storage.Client:
"""Retrieves a Google Cloud Storage client.
Returns:
Google Cloud Storage Agent.
"""
return storage.Client(project=self._project_id, credentials=self._credentials)

def get_gcs_bucket(self, bucket_name: str) -> storage.Bucket:
"""Retrieves a Google Cloud Bucket by bucket name.
Args:
bucket_name: Name of the bucket to be retrieved.
Returns:
Google Cloud Bucket.
"""
client = self.get_gcs_client()
return client.get_bucket(bucket_name)

def get_index(self, index_id: str) -> MatchingEngineIndex:
"""Retrieves a MatchingEngineIndex (VectorSearchIndex) by id.
Args:
index_id: Id of the index to be retrieved.
Returns:
MatchingEngineIndex instance.
"""
return MatchingEngineIndex(
index_name=index_id,
project=self._project_id,
location=self._region,
credentials=self._credentials,
)

def get_endpoint(self, endpoint_id: str) -> MatchingEngineIndexEndpoint:
"""Retrieves a MatchingEngineIndexEndpoint (VectorSearchIndexEndpoint) by id.
Args:
endpoint_id: Id of the endpoint to be retrieved.
Returns:
MatchingEngineIndexEndpoint instance.
"""
return MatchingEngineIndexEndpoint(
index_endpoint_name=endpoint_id,
project=self._project_id,
location=self._region,
credentials=self._credentials,
)

def get_datastore_client(self, **kwargs: Any) -> "datastore.Client":
"""Gets a datastore Client.
Args:
**kwargs: Keyword arguments to pass to datatastore.Client constructor.
Returns:
datastore Client.
"""
from google.cloud import datastore

ds_client = datastore.Client(
project=self._project_id, credentials=self._credentials, **kwargs
)

return ds_client
Loading

0 comments on commit e443604

Please sign in to comment.