Skip to content

Commit

Permalink
Sl collection options + cap-aware count_documents (#244)
Browse files Browse the repository at this point in the history
* collection.options method

* provisionally raising an error if count_documents exceeds its max
  • Loading branch information
hemidactylus authored Mar 7, 2024
1 parent 9bfa169 commit ef6f3d9
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 6 deletions.
46 changes: 40 additions & 6 deletions astrapy/idiomatic/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,15 +75,17 @@ def __init__(
caller_version=caller_version,
)
# this comes after the above, lets AstraDBCollection resolve namespace
self._database = database.copy(namespace=self.namespace)
self._database = database.copy(
namespace=self._astra_db_collection.astra_db.namespace
)

@property
def database(self) -> Database:
return self._database

@property
def namespace(self) -> str:
return self._astra_db_collection.astra_db.namespace
return self.database.namespace

@property
def name(self) -> str:
Expand Down Expand Up @@ -163,6 +165,17 @@ def set_caller(
caller_version=caller_version,
)

def options(self) -> Dict[str, Any]:
self_dicts = [
coll_dict
for coll_dict in self.database.list_collections()
if coll_dict["name"] == self.name
]
if self_dicts:
return self_dicts[0]
else:
raise ValueError(f"Collection {self.namespace}.{self.name} not found.")

def insert_one(
self,
document: DocumentType,
Expand Down Expand Up @@ -290,7 +303,11 @@ def count_documents(
) -> int:
cd_response = self._astra_db_collection.count_documents(filter=filter)
if "count" in cd_response.get("status", {}):
return cd_response["status"]["count"] # type: ignore[no-any-return]
count: int = cd_response["status"]["count"]
if cd_response["status"].get("moreData", False):
raise ValueError(f"Document count exceeds {count}")
else:
return count
else:
raise ValueError(
"Could not complete a count_documents operation. "
Expand Down Expand Up @@ -568,15 +585,17 @@ def __init__(
caller_version=caller_version,
)
# this comes after the above, lets AstraDBCollection resolve namespace
self._database = database.copy(namespace=self.namespace)
self._database = database.copy(
namespace=self._astra_db_collection.astra_db.namespace
)

@property
def database(self) -> AsyncDatabase:
return self._database

@property
def namespace(self) -> str:
return self._astra_db_collection.astra_db.namespace
return self.database.namespace

@property
def name(self) -> str:
Expand Down Expand Up @@ -656,6 +675,17 @@ def set_caller(
caller_version=caller_version,
)

async def options(self) -> Dict[str, Any]:
self_dicts = [
coll_dict
async for coll_dict in self.database.list_collections()
if coll_dict["name"] == self.name
]
if self_dicts:
return self_dicts[0]
else:
raise ValueError(f"Collection {self.namespace}.{self.name} not found.")

async def insert_one(
self,
document: DocumentType,
Expand Down Expand Up @@ -784,7 +814,11 @@ async def count_documents(
) -> int:
cd_response = await self._astra_db_collection.count_documents(filter=filter)
if "count" in cd_response.get("status", {}):
return cd_response["status"]["count"] # type: ignore[no-any-return]
count: int = cd_response["status"]["count"]
if cd_response["status"].get("moreData", False):
raise ValueError(f"Document count exceeds {count}")
else:
return count
else:
raise ValueError(
"Could not complete a count_documents operation. "
Expand Down
8 changes: 8 additions & 0 deletions tests/idiomatic/integration/test_ddl_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,14 @@ async def test_database_list_collections_async(
) -> None:
assert TEST_COLLECTION_NAME in await async_database.list_collection_names()

@pytest.mark.describe("test of Collection options, async")
async def test_collection_options_async(
self,
async_collection: AsyncCollection,
) -> None:
options = await async_collection.options()
assert options["name"] == async_collection.name

@pytest.mark.skipif(
ASTRA_DB_SECONDARY_KEYSPACE is None, reason="No secondary keyspace provided"
)
Expand Down
8 changes: 8 additions & 0 deletions tests/idiomatic/integration/test_ddl_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,14 @@ def test_database_list_collections_sync(
) -> None:
assert TEST_COLLECTION_NAME in sync_database.list_collection_names()

@pytest.mark.describe("test of Collection options, sync")
def test_collection_options_sync(
self,
sync_collection: Collection,
) -> None:
options = sync_collection.options()
assert options["name"] == sync_collection.name

@pytest.mark.skipif(
ASTRA_DB_SECONDARY_KEYSPACE is None, reason="No secondary keyspace provided"
)
Expand Down
11 changes: 11 additions & 0 deletions tests/idiomatic/integration/test_dml_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,17 @@ async def test_collection_count_documents_async(
assert await async_empty_collection.count_documents(filter={}) == 3
assert await async_empty_collection.count_documents(filter={"group": "A"}) == 2

@pytest.mark.describe("test of overflowing collection count_documents, async")
async def test_collection_overflowing_count_documents_async(
self,
async_empty_collection: AsyncCollection,
) -> None:
await async_empty_collection.insert_many([{"a": i} for i in range(999)])
assert await async_empty_collection.count_documents(filter={}) == 999
await async_empty_collection.insert_many([{"b": i} for i in range(2)])
with pytest.raises(ValueError):
assert await async_empty_collection.count_documents(filter={})

@pytest.mark.describe("test of collection insert_one, async")
async def test_collection_insert_one_async(
self,
Expand Down
11 changes: 11 additions & 0 deletions tests/idiomatic/integration/test_dml_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,17 @@ def test_collection_count_documents_sync(
assert sync_empty_collection.count_documents(filter={}) == 3
assert sync_empty_collection.count_documents(filter={"group": "A"}) == 2

@pytest.mark.describe("test of overflowing collection count_documents, sync")
def test_collection_overflowing_count_documents_sync(
self,
sync_empty_collection: Collection,
) -> None:
sync_empty_collection.insert_many([{"a": i} for i in range(999)])
assert sync_empty_collection.count_documents(filter={}) == 999
sync_empty_collection.insert_many([{"b": i} for i in range(2)])
with pytest.raises(ValueError):
assert sync_empty_collection.count_documents(filter={})

@pytest.mark.describe("test of collection insert_one, sync")
def test_collection_insert_one_sync(
self,
Expand Down

0 comments on commit ef6f3d9

Please sign in to comment.