Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Vectorstores] Added support for metadata in document storage #55

Merged
merged 14 commits into from
Mar 14, 2024
Original file line number Diff line number Diff line change
@@ -1,58 +1,60 @@
from __future__ import annotations

import json
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from google.cloud import storage # type: ignore[attr-defined]
from langchain_core.documents import Document

if TYPE_CHECKING:
from google.cloud import datastore # type: ignore[attr-defined]
from google.cloud import datastore # type: ignore[attr-defined, unused-ignore]


class DocumentStorage(ABC):
    """Abstract interface of a key, document storage for retrieving documents."""

    @abstractmethod
    def get_by_id(self, document_id: str) -> Document | None:
        """Gets a document by its id. If not found, returns None.

        Args:
            document_id: Id of the document to get from the storage.
        Returns:
            Document if found, otherwise None.
        """
        raise NotImplementedError()

    @abstractmethod
    def store_by_id(self, document_id: str, document: Document) -> None:
        """Stores a document associated to a document_id.

        Args:
            document_id: Id of the document to be stored.
            document: Document to be stored.
        """
        raise NotImplementedError()

    def batch_store_by_id(self, ids: List[str], documents: List[Document]) -> None:
        """Stores a list of ids and documents in batch.

        The default implementation only loops to the individual `store_by_id`.
        Subclasses that have faster ways to store data via batch uploading should
        implement the proper way.

        Args:
            ids: List of ids for the documents.
            documents: List of documents.
        Raises:
            ValueError: If `ids` and `documents` do not have the same length.
        """
        # `zip` stops at the shortest input, so a length mismatch would
        # otherwise drop documents (or ids) without any signal to the caller.
        if len(ids) != len(documents):
            raise ValueError(
                "`ids` should be the same length as `documents` "
                f"{len(ids)} != {len(documents)}"
            )

        for id_, document in zip(ids, documents):
            self.store_by_id(id_, document)

    def batch_get_by_id(self, ids: List[str]) -> List[Document | None]:
        """Gets a batch of documents by id.

        The default implementation only loops `get_by_id`.
        Subclasses that have faster ways to retrieve data by batch should implement
        this method.

        Args:
            ids: List of ids for the documents.
        Returns:
            List of documents, in the same order as `ids`. If the key id is not
            found for any id record returns a None instead.
        """
        return [self.get_by_id(id_) for id_ in ids]

Expand All @@ -75,12 +77,12 @@ def __init__(
self._bucket = bucket
self._prefix = prefix

def get_by_id(self, document_id: str) -> str | None:
def get_by_id(self, document_id: str) -> Document | None:
"""Gets a document by its id. If not found, returns None.
Args:
document_id: Id of the document to get from the storage.
Returns:
Text of the document if found, otherwise None.
Document if found, otherwise None.
"""

blob_name = self._get_blob_name(document_id)
Expand All @@ -89,17 +91,22 @@ def get_by_id(self, document_id: str) -> str | None:
if existing_blob is None:
return None

return existing_blob.download_as_text()
document_str = existing_blob.download_as_text()
document_json: Dict[str, Any] = json.loads(document_str)
return Document(**document_json)

def store_by_id(self, document_id: str, document: Document) -> None:
    """Stores a document associated to a document_id.

    The dictionary returned by `document.dict()` is dumped to a JSON string
    and uploaded to the blob named after the document id.

    Args:
        document_id: Id of the document to be stored.
        document: Document to be stored.
    """
    blob_name = self._get_blob_name(document_id)
    new_blob = self._bucket.blob(blob_name)

    # Upload the whole serialized document rather than only its text.
    document_json = document.dict()
    document_text = json.dumps(document_json)
    new_blob.upload_from_string(document_text)

def _get_blob_name(self, document_id: str) -> str:
"""Builds a blob name using the prefix and the document_id.
Expand All @@ -119,6 +126,7 @@ def __init__(
datastore_client: datastore.Client,
kind: str = "document_id",
text_property_name: str = "text",
metadata_property_name: str = "metadata",
) -> None:
"""Constructor.
Args:
Expand All @@ -128,9 +136,10 @@ def __init__(
super().__init__()
self._client = datastore_client
self._text_property_name = text_property_name
self._metadata_property_name = metadata_property_name
self._kind = kind

def get_by_id(self, document_id: str) -> str | None:
def get_by_id(self, document_id: str) -> Document | None:
"""Gets a document by its id. If not found, returns None.
Args:
document_id: Id of the document to get from the storage.
Expand All @@ -139,21 +148,31 @@ def get_by_id(self, document_id: str) -> str | None:
"""
key = self._client.key(self._kind, document_id)
entity = self._client.get(key)
return entity[self._text_property_name]

def store_by_id(self, document_id: str, text: str) -> None:
if entity is None:
return None

return Document(
page_content=entity[self._text_property_name],
metadata=self._convert_entity_to_dict(entity[self._metadata_property_name]),
)

def store_by_id(self, document_id: str, document: Document) -> None:
    """Stores a document associated to a document_id.

    Args:
        document_id: Id of the document to be stored.
        document: Document to be stored.
    """
    with self._client.transaction():
        key = self._client.key(self._kind, document_id)

        entity = self._client.entity(key=key)
        # Page content and metadata are written as two separate properties,
        # named by the configurable text/metadata property names.
        entity[self._text_property_name] = document.page_content
        entity[self._metadata_property_name] = document.metadata

        self._client.put(entity)

def batch_get_by_id(self, ids: List[str]) -> List[str | None]:
def batch_get_by_id(self, ids: List[str]) -> List[Document | None]:
"""Gets a batch of documents by id.
Args:
ids: List of ids for the text.
Expand All @@ -166,9 +185,22 @@ def batch_get_by_id(self, ids: List[str]) -> List[str | None]:
# TODO: Handle when a key is not present
entities = self._client.get_multi(keys)

return [entity[self._text_property_name] for entity in entities]

def batch_store_by_id(self, ids: List[str], texts: List[str]) -> None:
# Entities are not sorted by key by default, the order is unclear. This orders
# the list by the id retrieved.
entity_id_lookup = {entity.key.id_or_name: entity for entity in entities}
entities = [entity_id_lookup[id_] for id_ in ids]

return [
Document(
page_content=entity[self._text_property_name],
metadata=self._convert_entity_to_dict(
entity[self._metadata_property_name]
),
)
for entity in entities
]

def batch_store_by_id(self, ids: List[str], documents: List[Document]) -> None:
"""Stores a list of ids and documents in batch.
Args:
ids: List of ids for the text.
Expand All @@ -179,9 +211,21 @@ def batch_store_by_id(self, ids: List[str], texts: List[str]) -> None:
keys = [self._client.key(self._kind, id_) for id_ in ids]

entities = []
for key, text in zip(keys, texts):
for key, document in zip(keys, documents):
entity = self._client.entity(key=key)
entity[self._text_property_name] = text
entity[self._text_property_name] = document.page_content
entity[self._metadata_property_name] = document.metadata
entities.append(entity)

self._client.put_multi(entities)

def _convert_entity_to_dict(self, entity: datastore.Entity) -> Dict[str, Any]:
    """Recursively transform an entity into a plain dictionary.

    Values that are themselves `datastore.Entity` instances are converted to
    plain dictionaries, including those found inside list values. Any other
    value is returned unchanged.

    Args:
        entity: Entity (or mapping) to convert.
    Returns:
        A new plain dictionary holding the converted values.
    """
    # Imported locally: at module level `datastore` is only imported for
    # type checking.
    from google.cloud import datastore  # type: ignore[attr-defined, unused-ignore]

    def _to_plain(value: Any) -> Any:
        """Converts nested entities, also inside lists, into plain dicts."""
        if isinstance(value, datastore.Entity):
            return {key: _to_plain(item) for key, item in dict(value).items()}
        if isinstance(value, list):
            return [_to_plain(item) for item in value]
        return value

    return {key: _to_plain(value) for key, value in dict(entity).items()}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from google.oauth2.service_account import Credentials # type: ignore

if TYPE_CHECKING:
from google.cloud import datastore # type: ignore[attr-defined]
from google.cloud import datastore # type: ignore[attr-defined, unused-ignore]


class VectorSearchSDKManager:
Expand Down Expand Up @@ -107,7 +107,7 @@ def get_datastore_client(self, **kwargs: Any) -> "datastore.Client":
Returns:
datastore Client.
"""
from google.cloud import datastore # type: ignore[attr-defined]
from google.cloud import datastore # type: ignore[attr-defined, unused-ignore]

ds_client = datastore.Client(
project=self._project_id, credentials=self._credentials, **kwargs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,14 +106,13 @@ def similarity_search_by_vector_with_score(
results = []

for neighbor_id, distance in neighbors_list[0]:
text = self._document_storage.get_by_id(neighbor_id)
document = self._document_storage.get_by_id(neighbor_id)

if text is None:
if document is None:
raise ValueError(
f"Document with id {neighbor_id} not found in document" "storage."
)
# TODO: Handle metadata
document = Document(page_content=text, metadata={})

results.append((document, distance))

return results
Expand Down Expand Up @@ -164,8 +163,24 @@ def add_texts(
texts = list(texts)
ids = self._generate_unique_ids(len(texts))

self._document_storage.batch_store_by_id(ids=ids, texts=texts)
if metadatas is None:
metadatas = [{}] * len(texts)

if len(metadatas) != len(texts):
raise ValueError(
"`metadatas` should be the same length as `texts` "
f"{len(metadatas)} != {len(texts)}"
)

documents = [
Document(page_content=text, metadata=metadata)
jzaldi marked this conversation as resolved.
Show resolved Hide resolved
for text, metadata in zip(texts, metadatas)
]

self._document_storage.batch_store_by_id(ids=ids, documents=documents)

embeddings = self._embeddings.embed_documents(texts)

self._searcher.add_to_index(ids, embeddings, metadatas, **kwargs)

return ids
Expand Down
Loading
Loading