From d98815999eb10ac255ef57b67c8ab0b16077eb6e Mon Sep 17 00:00:00 2001 From: Stefano Lottini Date: Tue, 30 Apr 2024 17:47:18 +0200 Subject: [PATCH] DB reads make their projection(+similarity) explicit where needed (#20) * all find invocations make their projection(+similarity) explicit where needed across all classes * leave the _NOT_SET special object out of type hints for clarity to users --- .../astradb/langchain_astradb/document_loaders.py | 15 ++++++++++----- libs/astradb/langchain_astradb/storage.py | 8 ++++++-- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/libs/astradb/langchain_astradb/document_loaders.py b/libs/astradb/langchain_astradb/document_loaders.py index ec6c078..3019c76 100644 --- a/libs/astradb/langchain_astradb/document_loaders.py +++ b/libs/astradb/langchain_astradb/document_loaders.py @@ -23,6 +23,8 @@ logger = logging.getLogger(__name__) +_NOT_SET = object() + class AstraDBLoader(BaseLoader): def __init__( @@ -35,7 +37,7 @@ def __init__( async_astra_db_client: Optional[AsyncAstraDB] = None, namespace: Optional[str] = None, filter_criteria: Optional[Dict[str, Any]] = None, - projection: Optional[Dict[str, Any]] = None, + projection: Optional[Dict[str, Any]] = _NOT_SET, # type: ignore[assignment] find_options: Optional[Dict[str, Any]] = None, nb_prefetched: int = 1000, page_content_mapper: Callable[[Dict], str] = json.dumps, @@ -55,7 +57,8 @@ def __init__( namespace: namespace (aka keyspace) where the collection is. Defaults to the database's "default namespace". filter_criteria: Criteria to filter documents. - projection: Specifies the fields to return. + projection: Specifies the fields to return. If not provided, reads + default to the explicit projection `{"*": True}`. find_options: Additional options for the query. nb_prefetched: Max number of documents to pre-fetch. Defaults to 1000. 
page_content_mapper: Function applied to collection documents to create @@ -72,7 +75,9 @@ def __init__( ) self.astra_db_env = astra_db_env self.filter = filter_criteria - self.projection = projection + self._projection: Optional[Dict[str, Any]] = ( + projection if projection is not _NOT_SET else {"*": True} + ) self.find_options = find_options or {} self.nb_prefetched = nb_prefetched self.page_content_mapper = page_content_mapper @@ -94,7 +99,7 @@ def lazy_load(self) -> Iterator[Document]: for doc in self.astra_db_env.collection.paginated_find( filter=self.filter, options=self.find_options, - projection=self.projection, + projection=self._projection, sort=None, prefetched=self.nb_prefetched, ): @@ -108,7 +113,7 @@ async def alazy_load(self) -> AsyncIterator[Document]: async for doc in self.astra_db_env.async_collection.paginated_find( filter=self.filter, options=self.find_options, - projection=self.projection, + projection=self._projection, sort=None, prefetched=self.nb_prefetched, ): diff --git a/libs/astradb/langchain_astradb/storage.py b/libs/astradb/langchain_astradb/storage.py index 1e1ec9a..ab3f097 100644 --- a/libs/astradb/langchain_astradb/storage.py +++ b/libs/astradb/langchain_astradb/storage.py @@ -57,7 +57,10 @@ def encode_value(self, value: Optional[V]) -> Any: def mget(self, keys: Sequence[str]) -> List[Optional[V]]: self.astra_env.ensure_db_setup() docs_dict = {} - for doc in self.collection.paginated_find(filter={"_id": {"$in": list(keys)}}): + for doc in self.collection.paginated_find( + filter={"_id": {"$in": list(keys)}}, + projection={"*": True}, + ): docs_dict[doc["_id"]] = doc.get("value") return [self.decode_value(docs_dict.get(key)) for key in keys] @@ -65,7 +68,8 @@ async def amget(self, keys: Sequence[str]) -> List[Optional[V]]: await self.astra_env.aensure_db_setup() docs_dict = {} async for doc in self.async_collection.paginated_find( - filter={"_id": {"$in": list(keys)}} + filter={"_id": {"$in": list(keys)}}, + projection={"*": True}, 
): docs_dict[doc["_id"]] = doc.get("value") return [self.decode_value(docs_dict.get(key)) for key in keys]