Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enhancement: find().count() #313

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ ENV/
env.bak/
venv.bak/

# Jetbrains IDE
.idea/

# Spyder project settings
.spyderproject
.spyproject
Expand Down
30 changes: 29 additions & 1 deletion aredis_om/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,6 +760,31 @@ async def all(self, batch_size=10):
return await query.execute()
return await self.execute()

async def count(self) -> int:
args = [
"FT.AGGREGATE",
self.model.Meta.index_name,
self.query,
"APPLY",
"matched_terms()",
"AS",
"countable",
"GROUPBY",
"1",
"@countable",
"REDUCE",
"COUNT",
"0",
]
raw_result = await self.model.db().execute_command(*args)
print(raw_result, args)
try:
return sum(
[int(result[3].decode("utf-8", "ignore")) for result in raw_result[1:]]
)
except IndexError:
return 0

def sort_by(self, *fields: str):
if not fields:
return self
Expand Down Expand Up @@ -792,7 +817,10 @@ async def update(self, use_transaction=True, **field_values):
async def delete(self):
"""Delete all matching records in this query."""
# TODO: Better response type, error detection
return await self.model.db().delete(*[m.key() for m in await self.all()])
keys_to_delete = [m.key() for m in await self.all()]
if not keys_to_delete:
return 0
return await self.model.db().delete(*keys_to_delete)

async def __aiter__(self):
if self._model_cache:
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def key_prefix(request, redis):
def cleanup_keys(request):
# Always use the sync Redis connection with finalizer. Setting up an
# async finalizer should work, but I'm not suer how yet!
from redis_om.connections import get_redis_connection as get_sync_redis
from aredis_om.connections import get_redis_connection as get_sync_redis

# Increment for every pytest-xdist worker
conn = get_sync_redis()
Expand Down
32 changes: 24 additions & 8 deletions tests/test_hash_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

# We need to run this check as sync code (during tests) even in async mode
# because we call it in the top-level module scope.
from redis_om import has_redisearch
from aredis_om import has_redisearch
from tests.conftest import py_test_mark_asyncio

if not has_redisearch():
Expand Down Expand Up @@ -96,6 +96,16 @@ async def members(m):
yield member1, member2, member3


@py_test_mark_asyncio
async def test_all_query(members, m):

actual = await m.Member.find().all()
assert all([member in actual for member in members])

actual_count = await m.Member.find().count()
assert actual_count == len(members)


@py_test_mark_asyncio
async def test_exact_match_queries(members, m):
member1, member2, member3 = members
Expand Down Expand Up @@ -129,6 +139,11 @@ async def test_exact_match_queries(members, m):
).all()
assert actual == [member2]

actual_count = await m.Member.find(
m.Member.first_name == "Kim", m.Member.last_name == "Brookins"
).count()
assert actual_count == 1


@py_test_mark_asyncio
async def test_full_text_search_queries(members, m):
Expand Down Expand Up @@ -162,16 +177,17 @@ async def test_recursive_query_resolution(members, m):
async def test_tag_queries_boolean_logic(members, m):
member1, member2, member3 = members

actual = await (
m.Member.find(
(m.Member.first_name == "Andrew") & (m.Member.last_name == "Brookins")
| (m.Member.last_name == "Smith")
)
.sort_by("age")
.all()
find_query = m.Member.find(
(m.Member.first_name == "Andrew") & (m.Member.last_name == "Brookins")
| (m.Member.last_name == "Smith")
)

actual = await find_query.sort_by("age").all()
assert actual == [member1, member3]

actual_count = await find_query.count()
assert actual_count == 2


@py_test_mark_asyncio
async def test_tag_queries_punctuation(m):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_json_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

# We need to run this check as sync code (during tests) even in async mode
# because we call it in the top-level module scope.
from redis_om import has_redis_json
from aredis_om import has_redis_json
from tests.conftest import py_test_mark_asyncio

if not has_redis_json():
Expand Down Expand Up @@ -291,7 +291,7 @@ async def test_saves_many_explicit_transaction(address, m):
async with m.Member.db().pipeline(transaction=True) as pipeline:
await m.Member.add(members, pipeline=pipeline)
assert result == [member1, member2]
assert await pipeline.execute() == ["OK", "OK"]
assert await pipeline.execute() == [b"OK", b"OK"]

assert await m.Member.get(pk=member1.pk) == member1
assert await m.Member.get(pk=member2.pk) == member2
Expand Down