Skip to content

Commit

Permalink
DB reads make their projection(+similarity) explicit where needed (#20)
Browse files Browse the repository at this point in the history
* all find invocations make their projection(+similarity) explicit where needed across all classes

* leave the _NOT_SET special object out of type hints for clarity to users
  • Loading branch information
hemidactylus authored Apr 30, 2024
1 parent f35247c commit d988159
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 7 deletions.
15 changes: 10 additions & 5 deletions libs/astradb/langchain_astradb/document_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

logger = logging.getLogger(__name__)

_NOT_SET = object()


class AstraDBLoader(BaseLoader):
def __init__(
Expand All @@ -35,7 +37,7 @@ def __init__(
async_astra_db_client: Optional[AsyncAstraDB] = None,
namespace: Optional[str] = None,
filter_criteria: Optional[Dict[str, Any]] = None,
projection: Optional[Dict[str, Any]] = None,
projection: Optional[Dict[str, Any]] = _NOT_SET, # type: ignore[assignment]
find_options: Optional[Dict[str, Any]] = None,
nb_prefetched: int = 1000,
page_content_mapper: Callable[[Dict], str] = json.dumps,
Expand All @@ -55,7 +57,8 @@ def __init__(
namespace: namespace (aka keyspace) where the
collection is. Defaults to the database's "default namespace".
filter_criteria: Criteria to filter documents.
projection: Specifies the fields to return.
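projection: Specifies the fields to return. If omitted, all fields are
requested (projection ``{"*": True}``). An explicit ``None`` is passed
through to the find call unchanged (presumably yielding the Data API
default projection).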
find_options: Additional options for the query.
nb_prefetched: Max number of documents to pre-fetch. Defaults to 1000.
page_content_mapper: Function applied to collection documents to create
Expand All @@ -72,7 +75,9 @@ def __init__(
)
self.astra_db_env = astra_db_env
self.filter = filter_criteria
self.projection = projection
self._projection: Optional[Dict[str, Any]] = (
projection if projection is not _NOT_SET else {"*": True}
)
self.find_options = find_options or {}
self.nb_prefetched = nb_prefetched
self.page_content_mapper = page_content_mapper
Expand All @@ -94,7 +99,7 @@ def lazy_load(self) -> Iterator[Document]:
for doc in self.astra_db_env.collection.paginated_find(
filter=self.filter,
options=self.find_options,
projection=self.projection,
projection=self._projection,
sort=None,
prefetched=self.nb_prefetched,
):
Expand All @@ -108,7 +113,7 @@ async def alazy_load(self) -> AsyncIterator[Document]:
async for doc in self.astra_db_env.async_collection.paginated_find(
filter=self.filter,
options=self.find_options,
projection=self.projection,
projection=self._projection,
sort=None,
prefetched=self.nb_prefetched,
):
Expand Down
8 changes: 6 additions & 2 deletions libs/astradb/langchain_astradb/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,19 @@ def encode_value(self, value: Optional[V]) -> Any:
def mget(self, keys: Sequence[str]) -> List[Optional[V]]:
self.astra_env.ensure_db_setup()
docs_dict = {}
for doc in self.collection.paginated_find(filter={"_id": {"$in": list(keys)}}):
for doc in self.collection.paginated_find(
filter={"_id": {"$in": list(keys)}},
projection={"*": True},
):
docs_dict[doc["_id"]] = doc.get("value")
return [self.decode_value(docs_dict.get(key)) for key in keys]

async def amget(self, keys: Sequence[str]) -> List[Optional[V]]:
await self.astra_env.aensure_db_setup()
docs_dict = {}
async for doc in self.async_collection.paginated_find(
filter={"_id": {"$in": list(keys)}}
filter={"_id": {"$in": list(keys)}},
projection={"*": True},
):
docs_dict[doc["_id"]] = doc.get("value")
return [self.decode_value(docs_dict.get(key)) for key in keys]
Expand Down

0 comments on commit d988159

Please sign in to comment.