From 89209dec446a03b0ed3c619c9aef004d9d3c1356 Mon Sep 17 00:00:00 2001 From: Stefano Lottini Date: Tue, 5 Mar 2024 02:26:38 +0100 Subject: [PATCH] Cursor/AsyncCursor, find and distinct (#234) * insert_many(sync)+test, find(sync)+test * tests for find(sync); distinct in cursor(sync)+tests; distinct in collection(sync)+tests * insert_many (sync/async) + tests * refactor into base cursor and cursor * async cursors * distinct(s/a + tests), async cursors(+ tests) --- astrapy/db.py | 27 +- astrapy/idiomatic/collection.py | 175 +++++++- astrapy/idiomatic/cursors.py | 413 ++++++++++++++++++ astrapy/idiomatic/results.py | 8 +- astrapy/idiomatic/types.py | 21 + tests/idiomatic/integration/test_dml_async.py | 288 ++++++++++++ tests/idiomatic/integration/test_dml_sync.py | 277 ++++++++++++ .../idiomatic/unit/test_collections_async.py | 2 - tests/idiomatic/unit/test_collections_sync.py | 2 - 9 files changed, 1186 insertions(+), 27 deletions(-) create mode 100644 astrapy/idiomatic/cursors.py create mode 100644 astrapy/idiomatic/types.py diff --git a/astrapy/db.py b/astrapy/db.py index e2733148..9250d861 100644 --- a/astrapy/db.py +++ b/astrapy/db.py @@ -20,6 +20,12 @@ import json import threading +from collections.abc import ( + AsyncGenerator, + AsyncIterator, + Generator, + Iterator, +) from concurrent.futures import ThreadPoolExecutor from functools import partial from queue import Queue @@ -28,14 +34,11 @@ Any, cast, Dict, - Iterable, List, Optional, Tuple, Union, Type, - AsyncIterable, - AsyncGenerator, ) from astrapy import __version__ @@ -119,7 +122,7 @@ def __init__( self.caller_name = self.astra_db.caller_name self.caller_version = self.astra_db.caller_version self.collection_name = collection_name - self.base_path = f"{self.astra_db.base_path}/{self.collection_name}" + self.base_path: str = f"{self.astra_db.base_path}/{self.collection_name}" def __repr__(self) -> str: return f'AstraDBCollection[astra_db="{self.astra_db}", collection_name="{self.collection_name}"]' @@ 
-280,7 +283,7 @@ def find( self, filter: Optional[Dict[str, Any]] = None, projection: Optional[Dict[str, Any]] = None, - sort: Optional[Dict[str, Any]] = {}, + sort: Optional[Dict[str, Any]] = None, options: Optional[Dict[str, Any]] = None, ) -> API_RESPONSE: """ @@ -356,7 +359,7 @@ def paginate( request_method: PaginableRequestMethod, options: Optional[Dict[str, Any]], prefetched: Optional[int] = None, - ) -> Iterable[API_DOC]: + ) -> Generator[API_DOC, None, None]: """ Generate paginated results for a given database query method. Args: @@ -415,7 +418,7 @@ def paginated_find( sort: Optional[Dict[str, Any]] = None, options: Optional[Dict[str, Any]] = None, prefetched: Optional[int] = None, - ) -> Iterable[API_DOC]: + ) -> Iterator[API_DOC]: """ Perform a paginated search in the collection. Args: @@ -1156,7 +1159,7 @@ def __init__( self.caller_version = self.astra_db.caller_version self.client = astra_db.client self.collection_name = collection_name - self.base_path = f"{self.astra_db.base_path}/{self.collection_name}" + self.base_path: str = f"{self.astra_db.base_path}/{self.collection_name}" def __repr__(self) -> str: return f'AsyncAstraDBCollection[astra_db="{self.astra_db}", collection_name="{self.collection_name}"]' @@ -1318,7 +1321,7 @@ async def find( self, filter: Optional[Dict[str, Any]] = None, projection: Optional[Dict[str, Any]] = None, - sort: Optional[Dict[str, Any]] = {}, + sort: Optional[Dict[str, Any]] = None, options: Optional[Dict[str, Any]] = None, ) -> API_RESPONSE: """ @@ -1449,7 +1452,7 @@ def paginated_find( sort: Optional[Dict[str, Any]] = None, options: Optional[Dict[str, Any]] = None, prefetched: Optional[int] = None, - ) -> AsyncIterable[API_DOC]: + ) -> AsyncIterator[API_DOC]: """ Perform a paginated search in the collection. 
Args: @@ -2141,7 +2144,7 @@ def __init__( self.namespace = namespace # Finally, construct the full base path - self.base_path = f"/{self.api_path}/{self.api_version}/{self.namespace}" + self.base_path: str = f"/{self.api_path}/{self.api_version}/{self.namespace}" def __repr__(self) -> str: return f'AstraDB[endpoint="{self.base_url}", keyspace="{self.namespace}"]' @@ -2428,7 +2431,7 @@ def __init__( self.namespace = namespace # Finally, construct the full base path - self.base_path = f"/{self.api_path}/{self.api_version}/{self.namespace}" + self.base_path: str = f"/{self.api_path}/{self.api_version}/{self.namespace}" def __repr__(self) -> str: return f'AsyncAstraDB[endpoint="{self.base_url}", keyspace="{self.namespace}"]' diff --git a/astrapy/idiomatic/collection.py b/astrapy/idiomatic/collection.py index 4db416f4..5c7319e2 100644 --- a/astrapy/idiomatic/collection.py +++ b/astrapy/idiomatic/collection.py @@ -15,12 +15,16 @@ from __future__ import annotations import json -from typing import Any, Dict, Optional +from typing import Any, Dict, Iterable, List, Optional from astrapy.db import AstraDBCollection, AsyncAstraDBCollection +from astrapy.idiomatic.types import DocumentType, ProjectionType from astrapy.idiomatic.utils import raise_unsupported_parameter, unsupported from astrapy.idiomatic.database import AsyncDatabase, Database -from astrapy.idiomatic.results import DeleteResult, InsertOneResult +from astrapy.idiomatic.results import DeleteResult, InsertManyResult, InsertOneResult +from astrapy.idiomatic.cursors import AsyncCursor, Cursor + +INSERT_MANY_CONCURRENCY = 20 class Collection: @@ -110,7 +114,7 @@ def set_caller( def insert_one( self, - document: Dict[str, Any], + document: DocumentType, *, bypass_document_validation: Optional[bool] = None, ) -> InsertOneResult: @@ -138,6 +142,84 @@ def insert_one( f"(gotten '${json.dumps(io_response)}')" ) + def insert_many( + self, + documents: Iterable[DocumentType], + *, + ordered: bool = True, + 
bypass_document_validation: Optional[bool] = None, + ) -> InsertManyResult: + if bypass_document_validation: + raise_unsupported_parameter( + class_name=self.__class__.__name__, + method_name="insert_many", + parameter_name="bypass_document_validation", + ) + if ordered: + cim_responses = self._astra_db_collection.chunked_insert_many( + documents=list(documents), + options={"ordered": True}, + partial_failures_allowed=False, + concurrency=1, + ) + else: + # unordered insertion: can do chunks concurrently + cim_responses = self._astra_db_collection.chunked_insert_many( + documents=list(documents), + options={"ordered": False}, + partial_failures_allowed=True, + concurrency=INSERT_MANY_CONCURRENCY, + ) + _exceptions = [cim_r for cim_r in cim_responses if isinstance(cim_r, Exception)] + _errors_in_response = [ + err + for response in cim_responses + if isinstance(response, dict) + for err in (response.get("errors") or []) + ] + if _exceptions: + raise _exceptions[0] + elif _errors_in_response: + raise ValueError(str(_errors_in_response[0])) + else: + inserted_ids = [ + ins_id + for response in cim_responses + if isinstance(response, dict) + for ins_id in (response.get("status") or {}).get("insertedIds", []) + ] + return InsertManyResult(inserted_ids=inserted_ids) + + def find( + self, + filter: Optional[Dict[str, Any]] = None, + *, + projection: Optional[ProjectionType] = None, + skip: Optional[int] = None, + limit: Optional[int] = None, + sort: Optional[Dict[str, Any]] = None, + ) -> Cursor: + return ( + Cursor( + collection=self, + filter=filter, + projection=projection, + ) + .skip(skip) + .limit(limit) + .sort(sort) + ) + + def distinct( + self, + key: str, + filter: Optional[Dict[str, Any]] = None, + ) -> List[Any]: + return self.find( + filter=filter, + projection={key: True}, + ).distinct(key) + def count_documents( self, filter: Dict[str, Any], @@ -278,9 +360,6 @@ def list_search_indexes(*pargs: Any, **kwargs: Any) -> Any: ... 
@unsupported def update_search_index(*pargs: Any, **kwargs: Any) -> Any: ... - @unsupported - def distinct(*pargs: Any, **kwargs: Any) -> Any: ... - class AsyncCollection: def __init__( @@ -369,7 +448,7 @@ def set_caller( async def insert_one( self, - document: Dict[str, Any], + document: DocumentType, *, bypass_document_validation: Optional[bool] = None, ) -> InsertOneResult: @@ -397,6 +476,85 @@ async def insert_one( f"(gotten '${json.dumps(io_response)}')" ) + async def insert_many( + self, + documents: Iterable[DocumentType], + *, + ordered: bool = True, + bypass_document_validation: Optional[bool] = None, + ) -> InsertManyResult: + if bypass_document_validation: + raise_unsupported_parameter( + class_name=self.__class__.__name__, + method_name="insert_many", + parameter_name="bypass_document_validation", + ) + if ordered: + cim_responses = await self._astra_db_collection.chunked_insert_many( + documents=list(documents), + options={"ordered": True}, + partial_failures_allowed=False, + concurrency=1, + ) + else: + # unordered insertion: can do chunks concurrently + cim_responses = await self._astra_db_collection.chunked_insert_many( + documents=list(documents), + options={"ordered": False}, + partial_failures_allowed=True, + concurrency=INSERT_MANY_CONCURRENCY, + ) + _exceptions = [cim_r for cim_r in cim_responses if isinstance(cim_r, Exception)] + _errors_in_response = [ + err + for response in cim_responses + if isinstance(response, dict) + for err in (response.get("errors") or []) + ] + if _exceptions: + raise _exceptions[0] + elif _errors_in_response: + raise ValueError(str(_errors_in_response[0])) + else: + inserted_ids = [ + ins_id + for response in cim_responses + if isinstance(response, dict) + for ins_id in (response.get("status") or {}).get("insertedIds", []) + ] + return InsertManyResult(inserted_ids=inserted_ids) + + def find( + self, + filter: Optional[Dict[str, Any]] = None, + *, + projection: Optional[ProjectionType] = None, + skip: Optional[int] 
= None, + limit: Optional[int] = None, + sort: Optional[Dict[str, Any]] = None, + ) -> AsyncCursor: + return ( + AsyncCursor( + collection=self, + filter=filter, + projection=projection, + ) + .skip(skip) + .limit(limit) + .sort(sort) + ) + + async def distinct( + self, + key: str, + filter: Optional[Dict[str, Any]] = None, + ) -> List[Any]: + cursor = self.find( + filter=filter, + projection={key: True}, + ) + return await cursor.distinct(key) + async def count_documents( self, filter: Dict[str, Any], @@ -538,6 +696,3 @@ async def list_search_indexes(*pargs: Any, **kwargs: Any) -> Any: ... @unsupported async def update_search_index(*pargs: Any, **kwargs: Any) -> Any: ... - - @unsupported - async def distinct(*pargs: Any, **kwargs: Any) -> Any: ... diff --git a/astrapy/idiomatic/cursors.py b/astrapy/idiomatic/cursors.py new file mode 100644 index 00000000..1d48fdd0 --- /dev/null +++ b/astrapy/idiomatic/cursors.py @@ -0,0 +1,413 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from collections.abc import Iterator, AsyncIterator +from typing import ( + Any, + Dict, + List, + Optional, + TypeVar, + Union, + TYPE_CHECKING, +) + +from astrapy.idiomatic.types import DocumentType, ProjectionType + +if TYPE_CHECKING: + from astrapy.idiomatic.collection import AsyncCollection, Collection + + +BC = TypeVar("BC", bound="BaseCursor") + +FIND_PREFETCH = 20 + + +class BaseCursor: + _collection: Union[Collection, AsyncCollection] + _filter: Optional[Dict[str, Any]] + _projection: Optional[ProjectionType] + _limit: Optional[int] + _skip: Optional[int] + _sort: Optional[Dict[str, Any]] + _started: bool + _retrieved: int + _alive: bool + _iterator: Optional[Union[Iterator[DocumentType], AsyncIterator[DocumentType]]] = ( + None + ) + + def __init__( + self, + collection: Union[Collection, AsyncCollection], + filter: Optional[Dict[str, Any]], + projection: Optional[ProjectionType], + ) -> None: + raise NotImplementedError + + def __getitem__(self: BC, index: Union[int, slice]) -> Union[BC, DocumentType]: + self._ensure_not_started() + self._ensure_alive() + if isinstance(index, int): + # In this case, a separate cursor is run, not touching self + return self._item_at_index(index) + elif isinstance(index, slice): + start = index.start + stop = index.stop + step = index.step + if step is not None and step != 1: + raise ValueError("Cursor slicing cannot have arbitrary step") + _skip = start + _limit = stop - start + return self.limit(_limit).skip(_skip) + else: + raise TypeError( + f"cursor indices must be integers or slices, not {type(index).__name__}" + ) + + def __repr__(self) -> str: + _state_desc: str + if self._started: + if self._alive: + _state_desc = "running" + else: + _state_desc = "exhausted" + else: + _state_desc = "new" + return ( + f'{self.__class__.__name__}("{self._collection.name}", ' + f"{_state_desc}, " + f"retrieved: {self.retrieved})" + ) + + def _item_at_index(self, index: int) -> DocumentType: + 
# subclasses must implement this + raise NotImplementedError + + def _ensure_alive(self) -> None: + if not self._alive: + raise ValueError("Cursor is closed.") + + def _ensure_not_started(self) -> None: + if self._started: + raise ValueError("Cursor has already been used") + + def _copy( + self: BC, + *, + limit: Optional[int] = None, + skip: Optional[int] = None, + started: Optional[bool] = None, + sort: Optional[Dict[str, Any]] = None, + ) -> BC: + new_cursor = self.__class__( + collection=self._collection, + filter=self._filter, + projection=self._projection, + ) + # Cursor treated as mutable within this function scope: + new_cursor._limit = limit if limit is not None else self._limit + new_cursor._skip = skip if skip is not None else self._skip + new_cursor._started = started if started is not None else self._started + new_cursor._sort = sort if sort is not None else self._sort + if started is False: + new_cursor._retrieved = 0 + new_cursor._alive = True + else: + new_cursor._retrieved = self._retrieved + new_cursor._alive = self._alive + return new_cursor + + @property + def address(self) -> str: + """Return the api_endpoint used by this cursor.""" + return self._collection._astra_db_collection.base_path + + @property + def alive(self) -> bool: + return self._alive + + def clone(self: BC) -> BC: + return self._copy(started=False) + + def close(self) -> None: + self._alive = False + + @property + def cursor_id(self) -> int: + return id(self) + + def limit(self: BC, limit: Optional[int]) -> BC: + self._ensure_not_started() + self._ensure_alive() + self._limit = limit if limit != 0 else None + return self + + @property + def retrieved(self) -> int: + return self._retrieved + + def rewind(self: BC) -> BC: + self._started = False + self._retrieved = 0 + self._alive = True + self._iterator = None + return self + + def skip(self: BC, skip: Optional[int]) -> BC: + self._ensure_not_started() + self._ensure_alive() + self._skip = skip + return self + + def sort( + self: 
BC, + sort: Optional[Dict[str, Any]], + ) -> BC: + self._ensure_not_started() + self._ensure_alive() + self._sort = sort + return self + + +class Cursor(BaseCursor): + def __init__( + self, + collection: Collection, + filter: Optional[Dict[str, Any]], + projection: Optional[ProjectionType], + ) -> None: + self._collection: Collection = collection + self._filter = filter + self._projection = projection + self._limit: Optional[int] = None + self._skip: Optional[int] = None + self._sort: Optional[Dict[str, Any]] = None + self._started = False + self._retrieved = 0 + self._alive = True + # + self._iterator: Optional[Iterator[DocumentType]] = None + + def __iter__(self) -> Cursor: + self._ensure_alive() + if self._iterator is None: + self._iterator = self._create_iterator() + self._started = True + return self + + def __next__(self) -> DocumentType: + if not self.alive: + # keep raising once exhausted: + raise StopIteration + if self._iterator is None: + self._iterator = self._create_iterator() + self._started = True + try: + next_item = self._iterator.__next__() + self._retrieved = self._retrieved + 1 + return next_item + except StopIteration: + self._alive = False + raise + + def _item_at_index(self, index: int) -> DocumentType: + finder_cursor = self._copy().skip(index).limit(1) + items = list(finder_cursor) + if items: + return items[0] + else: + raise IndexError("no such item for Cursor instance") + + def _create_iterator(self) -> Iterator[DocumentType]: + self._ensure_not_started() + self._ensure_alive() + _options = { + k: v + for k, v in { + "limit": self._limit, + "skip": self._skip, + }.items() + if v is not None + } + + # recast parameters for paginated_find call + pf_projection: Optional[Dict[str, bool]] + if self._projection: + if isinstance(self._projection, dict): + pf_projection = self._projection + else: + # an iterable over strings + pf_projection = {field: True for field in self._projection} + else: + pf_projection = None + pf_sort: Optional[Dict[str, 
int]] + if self._sort: + pf_sort = dict(self._sort) + else: + pf_sort = None + + iterator = self._collection._astra_db_collection.paginated_find( + filter=self._filter, + projection=pf_projection, + sort=pf_sort, + options=_options, + prefetched=FIND_PREFETCH, + ) + return iterator + + @property + def collection(self) -> Collection: + return self._collection + + def distinct(self, key: str) -> List[Any]: + """ + This works on a fresh pristine copy of the cursor + and never touches self in any way. + """ + return list( + {document[key] for document in self._copy(started=False) if key in document} + ) + + +class AsyncCursor(BaseCursor): + def __init__( + self, + collection: AsyncCollection, + filter: Optional[Dict[str, Any]], + projection: Optional[ProjectionType], + ) -> None: + self._collection: AsyncCollection = collection + self._filter = filter + self._projection = projection + self._limit: Optional[int] = None + self._skip: Optional[int] = None + self._sort: Optional[Dict[str, Any]] = None + self._started = False + self._retrieved = 0 + self._alive = True + # + self._iterator: Optional[AsyncIterator[DocumentType]] = None + + def __aiter__(self) -> AsyncCursor: + self._ensure_alive() + if self._iterator is None: + self._iterator = self._create_iterator() + self._started = True + return self + + async def __anext__(self) -> DocumentType: + if not self.alive: + # keep raising once exhausted: + raise StopAsyncIteration + if self._iterator is None: + self._iterator = self._create_iterator() + self._started = True + try: + next_item = await self._iterator.__anext__() + self._retrieved = self._retrieved + 1 + return next_item + except StopAsyncIteration: + self._alive = False + raise + + def _item_at_index(self, index: int) -> DocumentType: + finder_cursor = self._to_sync().skip(index).limit(1) + items = list(finder_cursor) + if items: + return items[0] + else: + raise IndexError("no such item for AsyncCursor instance") + + def _create_iterator(self) -> 
AsyncIterator[DocumentType]: + self._ensure_not_started() + self._ensure_alive() + _options = { + k: v + for k, v in { + "limit": self._limit, + "skip": self._skip, + }.items() + if v is not None + } + + # recast parameters for paginated_find call + pf_projection: Optional[Dict[str, bool]] + if self._projection: + if isinstance(self._projection, dict): + pf_projection = self._projection + else: + # an iterable over strings + pf_projection = {field: True for field in self._projection} + else: + pf_projection = None + pf_sort: Optional[Dict[str, int]] + if self._sort: + pf_sort = dict(self._sort) + else: + pf_sort = None + + iterator = self._collection._astra_db_collection.paginated_find( + filter=self._filter, + projection=pf_projection, + sort=pf_sort, + options=_options, + prefetched=FIND_PREFETCH, + ) + return iterator + + def _to_sync( + self: AsyncCursor, + *, + limit: Optional[int] = None, + skip: Optional[int] = None, + started: Optional[bool] = None, + sort: Optional[Dict[str, Any]] = None, + ) -> Cursor: + new_cursor = Cursor( + collection=self._collection.to_sync(), + filter=self._filter, + projection=self._projection, + ) + # Cursor treated as mutable within this function scope: + new_cursor._limit = limit if limit is not None else self._limit + new_cursor._skip = skip if skip is not None else self._skip + new_cursor._started = started if started is not None else self._started + new_cursor._sort = sort if sort is not None else self._sort + if started is False: + new_cursor._retrieved = 0 + new_cursor._alive = True + else: + new_cursor._retrieved = self._retrieved + new_cursor._alive = self._alive + return new_cursor + + @property + def collection(self) -> AsyncCollection: + return self._collection + + async def distinct(self, key: str) -> List[Any]: + """ + This works on a fresh pristine copy of the cursor + and never touches self in any way. 
+ """ + return list( + { + document[key] + async for document in self._copy(started=False) + if key in document + } + ) diff --git a/astrapy/idiomatic/results.py b/astrapy/idiomatic/results.py index 70fa8dd2..c83aca4f 100644 --- a/astrapy/idiomatic/results.py +++ b/astrapy/idiomatic/results.py @@ -15,7 +15,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional @dataclass @@ -29,3 +29,9 @@ class DeleteResult: class InsertOneResult: inserted_id: Any acknowledged: bool = True + + +@dataclass +class InsertManyResult: + inserted_ids: List[Any] + acknowledged: bool = True diff --git a/astrapy/idiomatic/types.py b/astrapy/idiomatic/types.py new file mode 100644 index 00000000..5a18a32b --- /dev/null +++ b/astrapy/idiomatic/types.py @@ -0,0 +1,21 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Any, Dict, Iterable, Union + + +DocumentType = Dict[str, Any] +ProjectionType = Union[Iterable[str], Dict[str, bool]] diff --git a/tests/idiomatic/integration/test_dml_async.py b/tests/idiomatic/integration/test_dml_async.py index b4f76c82..1b5e1ca6 100644 --- a/tests/idiomatic/integration/test_dml_async.py +++ b/tests/idiomatic/integration/test_dml_async.py @@ -12,10 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import List + import pytest from astrapy import AsyncCollection from astrapy.results import DeleteResult, InsertOneResult +from astrapy.api import APIRequestError +from astrapy.idiomatic.types import DocumentType +from astrapy.idiomatic.cursors import AsyncCursor class TestDMLAsync: @@ -74,3 +79,286 @@ async def test_collection_delete_many_async( assert do_result1.acknowledged is True assert do_result1.deleted_count == 2 assert await async_empty_collection.count_documents(filter={}) == 1 + + @pytest.mark.describe("test of collection find, async") + async def test_collection_find_async( + self, + async_empty_collection: AsyncCollection, + ) -> None: + await async_empty_collection.insert_many([{"seq": i} for i in range(30)]) + Nski = 1 + Nlim = 28 + Nsor = {"seq": -1} + Nfil = {"seq": {"$exists": True}} + + async def _alist(acursor: AsyncCursor) -> List[DocumentType]: + return [doc async for doc in acursor] + + # case 0000 of find-pattern matrix + assert ( + len( + await _alist( + async_empty_collection.find( + skip=None, limit=None, sort=None, filter=None + ) + ) + ) + == 30 + ) + + # case 0001 + assert ( + len( + await _alist( + async_empty_collection.find( + skip=None, limit=None, sort=None, filter=Nfil + ) + ) + ) + == 30 + ) + + # case 0010 + assert ( + len( + await _alist( + async_empty_collection.find( + skip=None, limit=None, sort=Nsor, filter=None + ) + ) + ) + == 20 + ) # NONPAGINATED + + # case 0011 + assert ( + len( + await _alist( + async_empty_collection.find( + skip=None, limit=None, sort=Nsor, filter=Nfil + ) + ) + ) + == 20 + ) # NONPAGINATED + + # case 0100 + assert ( + len( + await _alist( + async_empty_collection.find( + skip=None, limit=Nlim, sort=None, filter=None + ) + ) + ) + == 28 + ) + + # case 0101 + assert ( + len( + await _alist( + async_empty_collection.find( + skip=None, limit=Nlim, sort=None, filter=Nfil + ) + ) + ) + == 28 + ) + + # case 0110 + assert ( + len( + await _alist( + async_empty_collection.find( + skip=None, 
limit=Nlim, sort=Nsor, filter=None + ) + ) + ) + == 20 + ) # NONPAGINATED + + # case 0111 + assert ( + len( + await _alist( + async_empty_collection.find( + skip=None, limit=Nlim, sort=Nsor, filter=Nfil + ) + ) + ) + == 20 + ) # NONPAGINATED + + # case 1000 + # len(list(async_empty_collection.find(skip=Nski, limit=None, sort=None, filter=None))) + + # case 1001 + # len(list(async_empty_collection.find(skip=Nski, limit=None, sort=None, filter=Nfil))) + + # case 1010 + assert ( + len( + await _alist( + async_empty_collection.find( + skip=Nski, limit=None, sort=Nsor, filter=None + ) + ) + ) + == 20 + ) # NONPAGINATED + + # case 1011 + assert ( + len( + await _alist( + async_empty_collection.find( + skip=Nski, limit=None, sort=Nsor, filter=Nfil + ) + ) + ) + == 20 + ) # NONPAGINATED + + # case 1100 + # len(list(async_empty_collection.find(skip=Nski, limit=Nlim, sort=None, filter=None))) + + # case 1101 + # len(list(async_empty_collection.find(skip=Nski, limit=Nlim, sort=None, filter=Nfil))) + + # case 1110 + assert ( + len( + await _alist( + async_empty_collection.find( + skip=Nski, limit=Nlim, sort=Nsor, filter=None + ) + ) + ) + == 20 + ) # NONPAGINATED + + # case 1111 + assert ( + len( + await _alist( + async_empty_collection.find( + skip=Nski, limit=Nlim, sort=Nsor, filter=Nfil + ) + ) + ) + == 20 + ) # NONPAGINATED + + @pytest.mark.describe("test of cursors from collection.find, async") + async def test_collection_cursors_async( + self, + async_empty_collection: AsyncCollection, + ) -> None: + """ + Functionalities of cursors from find, other than the various + combinations of skip/limit/sort/filter specified above. 
+ """ + await async_empty_collection.insert_many( + [{"seq": i, "ternary": (i % 3)} for i in range(10)] + ) + + # projection + cursor0 = async_empty_collection.find(projection={"ternary": False}) + document0 = await cursor0.__anext__() + assert "ternary" not in document0 + cursor0b = async_empty_collection.find(projection={"ternary": True}) + document0b = await cursor0b.__anext__() + assert "ternary" in document0b + + async def _alist(acursor: AsyncCursor) -> List[DocumentType]: + return [doc async for doc in acursor] + + # rewinding, slicing and retrieved + cursor1 = async_empty_collection.find(sort={"seq": 1}) + await cursor1.__anext__() + await cursor1.__anext__() + items1 = (await _alist(cursor1))[:2] + assert await _alist(cursor1.rewind()) == await _alist( + async_empty_collection.find(sort={"seq": 1}) + ) + cursor1.rewind() + assert items1 == await _alist(cursor1[2:4]) # type: ignore[arg-type] + assert cursor1.retrieved == 2 + + # address, cursor_id, collection + assert cursor1.address == async_empty_collection._astra_db_collection.base_path + assert isinstance(cursor1.cursor_id, int) + assert cursor1.collection == async_empty_collection + + # clone, alive + cursor2 = async_empty_collection.find() + assert cursor2.alive is True + for _ in range(8): + await cursor2.__anext__() + assert cursor2.alive is True + cursor3 = cursor2.clone() + assert len(await _alist(cursor2)) == 2 + assert len(await _alist(cursor3)) == 10 + assert cursor2.alive is False + + # close + cursor4 = async_empty_collection.find() + for _ in range(8): + await cursor4.__anext__() + cursor4.close() + assert cursor4.alive is False + with pytest.raises(StopAsyncIteration): + await cursor4.__anext__() + + # distinct + cursor5 = async_empty_collection.find() + dist5 = await cursor5.distinct("ternary") + assert (len(await _alist(cursor5))) == 10 + assert set(dist5) == {0, 1, 2} + cursor6 = async_empty_collection.find() + for _ in range(9): + await cursor6.__anext__() + dist6 = await 
cursor6.distinct("ternary") + assert (len(await _alist(cursor6))) == 1 + assert set(dist6) == {0, 1, 2} + + # distinct from collections + assert set(await async_empty_collection.distinct("ternary")) == {0, 1, 2} + assert set(await async_empty_collection.distinct("nonfield")) == set() + + # indexing by integer + cursor7 = async_empty_collection.find(sort={"seq": 1}) + assert cursor7[5]["seq"] == 5 + + # indexing by wrong type + with pytest.raises(TypeError): + cursor7.rewind() + cursor7["wrong"] + + @pytest.mark.describe("test of collection insert_many, async") + async def test_collection_insert_many_async( + self, + async_empty_collection: AsyncCollection, + ) -> None: + acol = async_empty_collection + col = acol.to_sync() # TODO: replace with async find once implemented + + ins_result1 = await acol.insert_many([{"_id": "a"}, {"_id": "b"}]) + assert set(ins_result1.inserted_ids) == {"a", "b"} + assert {doc["_id"] for doc in col.find()} == {"a", "b"} + + with pytest.raises(APIRequestError): + await acol.insert_many([{"_id": "a"}, {"_id": "c"}]) + assert {doc["_id"] for doc in col.find()} == {"a", "b"} + + with pytest.raises(APIRequestError): + await acol.insert_many([{"_id": "c"}, {"_id": "a"}, {"_id": "d"}]) + assert {doc["_id"] for doc in col.find()} == {"a", "b", "c"} + + with pytest.raises(ValueError): + await acol.insert_many( + [{"_id": "c"}, {"_id": "d"}, {"_id": "e"}], + ordered=False, + ) + assert {doc["_id"] for doc in col.find()} == {"a", "b", "c", "d", "e"} diff --git a/tests/idiomatic/integration/test_dml_sync.py b/tests/idiomatic/integration/test_dml_sync.py index 7a286391..1933f5af 100644 --- a/tests/idiomatic/integration/test_dml_sync.py +++ b/tests/idiomatic/integration/test_dml_sync.py @@ -16,6 +16,7 @@ from astrapy import Collection from astrapy.results import DeleteResult, InsertOneResult +from astrapy.api import APIRequestError class TestDMLSync: @@ -74,3 +75,279 @@ def test_collection_delete_many_sync( assert do_result1.acknowledged is True 
assert do_result1.deleted_count == 2
         assert sync_empty_collection.count_documents(filter={}) == 1
+
+    @pytest.mark.describe("test of collection find, sync")
+    def test_collection_find_sync(
+        self,
+        sync_empty_collection: Collection,
+    ) -> None:
+        sync_empty_collection.insert_many([{"seq": i} for i in range(30)])
+        Nski = 1
+        Nlim = 28
+        Nsor = {"seq": -1}
+        Nfil = {"seq": {"$exists": True}}
+
+        # case 0000 of find-pattern matrix (digits: skip, limit, sort, filter; 1 = set)
+        assert (
+            len(
+                list(
+                    sync_empty_collection.find(
+                        skip=None, limit=None, sort=None, filter=None
+                    )
+                )
+            )
+            == 30
+        )
+
+        # case 0001
+        assert (
+            len(
+                list(
+                    sync_empty_collection.find(
+                        skip=None, limit=None, sort=None, filter=Nfil
+                    )
+                )
+            )
+            == 30
+        )
+
+        # case 0010
+        assert (
+            len(
+                list(
+                    sync_empty_collection.find(
+                        skip=None, limit=None, sort=Nsor, filter=None
+                    )
+                )
+            )
+            == 20
+        )  # NONPAGINATED: a sorted find yields 20 of the 30 documents
+
+        # case 0011
+        assert (
+            len(
+                list(
+                    sync_empty_collection.find(
+                        skip=None, limit=None, sort=Nsor, filter=Nfil
+                    )
+                )
+            )
+            == 20
+        )  # NONPAGINATED
+
+        # case 0100
+        assert (
+            len(
+                list(
+                    sync_empty_collection.find(
+                        skip=None, limit=Nlim, sort=None, filter=None
+                    )
+                )
+            )
+            == 28
+        )
+
+        # case 0101
+        assert (
+            len(
+                list(
+                    sync_empty_collection.find(
+                        skip=None, limit=Nlim, sort=None, filter=Nfil
+                    )
+                )
+            )
+            == 28
+        )
+
+        # case 0110
+        assert (
+            len(
+                list(
+                    sync_empty_collection.find(
+                        skip=None, limit=Nlim, sort=Nsor, filter=None
+                    )
+                )
+            )
+            == 20
+        )  # NONPAGINATED (20 although limit is 28)
+
+        # case 0111
+        assert (
+            len(
+                list(
+                    sync_empty_collection.find(
+                        skip=None, limit=Nlim, sort=Nsor, filter=Nfil
+                    )
+                )
+            )
+            == 20
+        )  # NONPAGINATED
+
+        # case 1000 -- NOTE(review): skip-without-sort cases are all disabled; why?
+        # len(list(sync_empty_collection.find(skip=Nski, limit=None, sort=None, filter=None)))
+
+        # case 1001
+        # len(list(sync_empty_collection.find(skip=Nski, limit=None, sort=None, filter=Nfil)))
+
+        # case 1010
+        assert (
+            len(
+                list(
+                    sync_empty_collection.find(
+                        skip=Nski, limit=None, sort=Nsor, filter=None
+                    )
+                )
+            )
+            == 20
+        )  # NONPAGINATED (20 also with skip=1)
+
+        # case 1011
+        assert (
+            len(
+                list(
+                    sync_empty_collection.find(
+                        skip=Nski, limit=None, sort=Nsor, filter=Nfil
+                    )
+                )
+            )
+            == 20
+        )  # NONPAGINATED
+
+        # case 1100
+        # len(list(sync_empty_collection.find(skip=Nski, limit=Nlim, sort=None, filter=None)))
+
+        # case 1101
+        # len(list(sync_empty_collection.find(skip=Nski, limit=Nlim, sort=None, filter=Nfil)))
+
+        # case 1110
+        assert (
+            len(
+                list(
+                    sync_empty_collection.find(
+                        skip=Nski, limit=Nlim, sort=Nsor, filter=None
+                    )
+                )
+            )
+            == 20
+        )  # NONPAGINATED
+
+        # case 1111
+        assert (
+            len(
+                list(
+                    sync_empty_collection.find(
+                        skip=Nski, limit=Nlim, sort=Nsor, filter=Nfil
+                    )
+                )
+            )
+            == 20
+        )  # NONPAGINATED
+
+    @pytest.mark.describe("test of cursors from collection.find, sync")
+    def test_collection_cursors_sync(
+        self,
+        sync_empty_collection: Collection,
+    ) -> None:
+        """
+        Functionalities of cursors from find, other than the various
+        combinations of skip/limit/sort/filter specified above.
+        """
+        sync_empty_collection.insert_many(
+            [{"seq": i, "ternary": (i % 3)} for i in range(10)]
+        )
+
+        # projection (False drops the field from the document, True keeps it)
+        cursor0 = sync_empty_collection.find(projection={"ternary": False})
+        document0 = cursor0.__next__()
+        assert "ternary" not in document0
+        cursor0b = sync_empty_collection.find(projection={"ternary": True})
+        document0b = cursor0b.__next__()
+        assert "ternary" in document0b
+
+        # rewinding, slicing and retrieved (items1: the documents at positions 2, 3)
+        cursor1 = sync_empty_collection.find(sort={"seq": 1})
+        cursor1.__next__()
+        cursor1.__next__()
+        items1 = list(cursor1)[:2]
+        assert list(cursor1.rewind()) == list(
+            sync_empty_collection.find(sort={"seq": 1})
+        )
+        cursor1.rewind()
+        assert items1 == list(cursor1[2:4])
+        assert cursor1.retrieved == 2
+
+        # address, cursor_id, collection
+        assert cursor1.address == sync_empty_collection._astra_db_collection.base_path
+        assert isinstance(cursor1.cursor_id, int)
+        assert cursor1.collection == sync_empty_collection
+
+        # clone, alive (the clone starts over; an exhausted cursor is not alive)
+        cursor2 = sync_empty_collection.find()
+        assert cursor2.alive is True
+        for _ in range(8):
+            cursor2.__next__()
+        assert cursor2.alive is True
+        cursor3 = cursor2.clone()
+        assert len(list(cursor2)) == 2
+        assert len(list(cursor3)) == 10
+        assert cursor2.alive is False
+
+        # close (a closed cursor is not alive and yields no further documents)
+        cursor4 = sync_empty_collection.find()
+        for _ in range(8):
+            cursor4.__next__()
+        cursor4.close()
+        assert cursor4.alive is False
+        with pytest.raises(StopIteration):
+            cursor4.__next__()
+
+        # distinct (spans the whole result set, leaves the cursor position untouched)
+        cursor5 = sync_empty_collection.find()
+        dist5 = cursor5.distinct("ternary")
+        assert (len(list(cursor5))) == 10
+        assert set(dist5) == {0, 1, 2}
+        cursor6 = sync_empty_collection.find()
+        for _ in range(9):
+            cursor6.__next__()
+        dist6 = cursor6.distinct("ternary")
+        assert (len(list(cursor6))) == 1
+        assert set(dist6) == {0, 1, 2}
+
+        # distinct from collections (a field absent from all documents: no values)
+        assert set(sync_empty_collection.distinct("ternary")) == {0, 1, 2}
+        assert set(sync_empty_collection.distinct("nonfield")) == set()
+
+        # indexing by integer
+        cursor7 = sync_empty_collection.find(sort={"seq": 1})
+        assert cursor7[5]["seq"] == 5
+
+        # indexing by wrong type
+        with pytest.raises(TypeError):
+            cursor7.rewind()
+            cursor7["wrong"]
+
+    @pytest.mark.describe("test of collection insert_many, sync")
+    def test_collection_insert_many_sync(
+        self,
+        sync_empty_collection: Collection,
+    ) -> None:
+        col = sync_empty_collection
+
+        ins_result1 = col.insert_many([{"_id": "a"}, {"_id": "b"}])
+        assert set(ins_result1.inserted_ids) == {"a", "b"}
+        assert {doc["_id"] for doc in col.find()} == {"a", "b"}
+
+        with pytest.raises(APIRequestError):  # "a" is already stored
+            col.insert_many([{"_id": "a"}, {"_id": "c"}])
+        assert {doc["_id"] for doc in col.find()} == {"a", "b"}  # "c" not added
+
+        with pytest.raises(APIRequestError):  # "a" sits between "c" and "d"
+            col.insert_many([{"_id": "c"}, {"_id": "a"}, {"_id": "d"}])
+        assert {doc["_id"] for doc in col.find()} == {"a", "b", "c"}  # no "d"
+
+        with pytest.raises(ValueError):  # ordered=False: "d", "e" land despite "c"
+            col.insert_many(
+                [{"_id": "c"}, {"_id": "d"}, {"_id": "e"}],
+                ordered=False,
+            )
+        assert {doc["_id"] for doc in col.find()} == {"a", "b", "c", "d", "e"}
diff --git a/tests/idiomatic/unit/test_collections_async.py b/tests/idiomatic/unit/test_collections_async.py
index a1f2e629..c6019915 100644
--- a/tests/idiomatic/unit/test_collections_async.py
+++ b/tests/idiomatic/unit/test_collections_async.py
@@ -210,8 +210,6 @@ async def test_collection_unsupported_methods_async(
             await async_collection_instance.list_search_indexes(1, "x")
         with pytest.raises(TypeError):
             await async_collection_instance.update_search_index(1, "x")
-        with pytest.raises(TypeError):
-            await async_collection_instance.distinct(1, "x")
 
     @pytest.mark.describe("test collection conversions with caller mutableness, async")
     async def test_collection_conversions_caller_mutableness_async(
diff --git a/tests/idiomatic/unit/test_collections_sync.py b/tests/idiomatic/unit/test_collections_sync.py
index 5977e7b8..fe8069fb 100644
--- a/tests/idiomatic/unit/test_collections_sync.py
+++ b/tests/idiomatic/unit/test_collections_sync.py
@@ -210,8 +210,6 @@ def test_collection_unsupported_methods_sync(
             sync_collection_instance.list_search_indexes(1, "x")
         with pytest.raises(TypeError):
             sync_collection_instance.update_search_index(1, "x")
-        with pytest.raises(TypeError):
-            sync_collection_instance.distinct(1, "x")
 
     @pytest.mark.describe("test collection conversions with caller mutableness, sync")
     def test_collection_conversions_caller_mutableness_sync(