Skip to content

Commit

Permalink
Feat: Ability to use @Skip @include graphql directives to exclude fie…
Browse files Browse the repository at this point in the history
…lds (#231)

* fix: remove unnecessary async

* fix: handle @include , @Skip directives when checking user queried fields

* fix: pagination errors

hasNextPage didn't become false when using after and first together

Fixed reverse Querying using last and before, in compliance to graphql relay spec

https://relay.dev/graphql/connections.htm#sec-Backward-pagination-arguments
https://relay.dev/graphql/connections.htm#sec-undefined.PageInfo

* bump: version 0.4.2
  • Loading branch information
mak626 authored Feb 24, 2024
1 parent a084895 commit 8f46790
Show file tree
Hide file tree
Showing 6 changed files with 199 additions and 220 deletions.
5 changes: 2 additions & 3 deletions graphene_mongo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from .fields import MongoengineConnectionField
from .fields_async import AsyncMongoengineConnectionField

from .types import MongoengineObjectType, MongoengineInputType, MongoengineInterfaceType
from .types import MongoengineInputType, MongoengineInterfaceType, MongoengineObjectType
from .types_async import AsyncMongoengineObjectType

__version__ = "0.1.1"
__version__ = "0.4.2"

__all__ = [
"__version__",
Expand Down
114 changes: 35 additions & 79 deletions graphene_mongo/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
find_skip_and_limit,
get_model_reference_fields,
get_query_fields,
has_page_info,
)

PYMONGO_VERSION = tuple(pymongo.version_tuple[:2])
Expand Down Expand Up @@ -276,7 +277,7 @@ def fields(self):
return self._type._meta.fields

def get_queryset(
self, model, info, required_fields=None, skip=None, limit=None, reversed=False, **args
self, model, info, required_fields=None, skip=None, limit=None, **args
) -> QuerySet:
if required_fields is None:
required_fields = list()
Expand Down Expand Up @@ -325,49 +326,22 @@ def get_queryset(
else:
args.update(queryset_or_filters)
if limit is not None:
if reversed:
if self.order_by:
order_by = self.order_by + ",-pk"
else:
order_by = "-pk"
return (
model.objects(**args)
.no_dereference()
.only(*required_fields)
.order_by(order_by)
.skip(skip if skip else 0)
.limit(limit)
)
else:
return (
model.objects(**args)
.no_dereference()
.only(*required_fields)
.order_by(self.order_by)
.skip(skip if skip else 0)
.limit(limit)
)
return (
model.objects(**args)
.no_dereference()
.only(*required_fields)
.order_by(self.order_by)
.skip(skip if skip else 0)
.limit(limit)
)
elif skip is not None:
if reversed:
if self.order_by:
order_by = self.order_by + ",-pk"
else:
order_by = "-pk"
return (
model.objects(**args)
.no_dereference()
.only(*required_fields)
.order_by(order_by)
.skip(skip)
)
else:
return (
model.objects(**args)
.no_dereference()
.only(*required_fields)
.order_by(self.order_by)
.skip(skip)
)
return (
model.objects(**args)
.no_dereference()
.only(*required_fields)
.order_by(self.order_by)
.skip(skip)
)
return model.objects(**args).no_dereference().only(*required_fields).order_by(self.order_by)

def default_resolver(self, _root, info, required_fields=None, resolved=None, **args):
Expand Down Expand Up @@ -401,7 +375,6 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a
skip = 0
count = 0
limit = None
reverse = False
first = args.pop("first", None)
after = args.pop("after", None)
if after:
Expand All @@ -410,14 +383,15 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a
before = args.pop("before", None)
if before:
before = cursor_to_offset(before)
requires_page_info = has_page_info(info)
has_next_page = False

if resolved is not None:
items = resolved

if isinstance(items, QuerySet):
try:
if last is not None and after is not None:
if last is not None:
count = items.count(with_limit_and_skip=False)
else:
count = None
Expand All @@ -426,29 +400,24 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a
else:
count = len(items)

skip, limit, reverse = find_skip_and_limit(
skip, limit = find_skip_and_limit(
first=first, last=last, after=after, before=before, count=count
)

if isinstance(items, QuerySet):
if limit:
_base_query: QuerySet = (
items.order_by("-pk").skip(skip) if reverse else items.skip(skip)
)
_base_query: QuerySet = items.skip(skip)
items = _base_query.limit(limit)
has_next_page = len(_base_query.skip(limit).only("id").limit(1)) != 0
has_next_page = len(_base_query.skip(skip + limit).only("id").limit(1)) != 0
elif skip:
items = items.skip(skip)
else:
if limit:
if reverse:
_base_query = items[::-1]
items = _base_query[skip : skip + limit]
has_next_page = (skip + limit) < len(_base_query)
else:
_base_query = items
items = items[skip : skip + limit]
has_next_page = (skip + limit) < len(_base_query)
_base_query = items
items = items[skip : skip + limit]
has_next_page = (
(skip + limit) < len(_base_query) if requires_page_info else False
)
elif skip:
items = items[skip:]
iterables = list(items)
Expand Down Expand Up @@ -503,11 +472,11 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a
else:
count = self.model.objects(args_copy).count()
if count != 0:
skip, limit, reverse = find_skip_and_limit(
skip, limit = find_skip_and_limit(
first=first, after=after, last=last, before=before, count=count
)
iterables = self.get_queryset(
self.model, info, required_fields, skip, limit, reverse, **args
self.model, info, required_fields, skip, limit, **args
)
list_length = len(iterables)
if isinstance(info, GraphQLResolveInfo):
Expand All @@ -519,14 +488,11 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a

elif "pk__in" in args and args["pk__in"]:
count = len(args["pk__in"])
skip, limit, reverse = find_skip_and_limit(
skip, limit = find_skip_and_limit(
first=first, last=last, after=after, before=before, count=count
)
if limit:
if reverse:
args["pk__in"] = args["pk__in"][::-1][skip : skip + limit]
else:
args["pk__in"] = args["pk__in"][skip : skip + limit]
args["pk__in"] = args["pk__in"][skip : skip + limit]
elif skip:
args["pk__in"] = args["pk__in"][skip:]
iterables = self.get_queryset(self.model, info, required_fields, **args)
Expand All @@ -542,18 +508,13 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a
field_name = to_snake_case(info.field_name)
items = getattr(_root, field_name, [])
count = len(items)
skip, limit, reverse = find_skip_and_limit(
skip, limit = find_skip_and_limit(
first=first, last=last, after=after, before=before, count=count
)
if limit:
if reverse:
_base_query = items[::-1]
items = _base_query[skip : skip + limit]
has_next_page = (skip + limit) < len(_base_query)
else:
_base_query = items
items = items[skip : skip + limit]
has_next_page = (skip + limit) < len(_base_query)
_base_query = items
items = items[skip : skip + limit]
has_next_page = (skip + limit) < len(_base_query) if requires_page_info else False
elif skip:
items = items[skip:]
iterables = items
Expand All @@ -567,11 +528,6 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a
)
has_previous_page = True if skip else False

if reverse:
iterables = list(iterables)
iterables.reverse()
skip = limit

connection = connection_from_iterables(
edges=iterables,
start_offset=skip,
Expand Down
49 changes: 14 additions & 35 deletions graphene_mongo/fields_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
connection_from_iterables,
find_skip_and_limit,
get_query_fields,
sync_to_async,
has_page_info,
sync_to_async,
)

PYMONGO_VERSION = tuple(pymongo.version_tuple[:2])
Expand Down Expand Up @@ -92,7 +92,6 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non
skip = 0
count = 0
limit = None
reverse = False
first = args.pop("first", None)
after = args.pop("after", None)
if after:
Expand All @@ -109,7 +108,7 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non

if isinstance(items, QuerySet):
try:
if last is not None and after is not None:
if last is not None:
count = await sync_to_async(items.count)(with_limit_and_skip=False)
else:
count = None
Expand All @@ -118,22 +117,18 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non
else:
count = len(items)

skip, limit, reverse = find_skip_and_limit(
skip, limit = find_skip_and_limit(
first=first, last=last, after=after, before=before, count=count
)

if isinstance(items, QuerySet):
if limit:
_base_query: QuerySet = (
await sync_to_async(items.order_by("-pk").skip)(skip)
if reverse
else await sync_to_async(items.skip)(skip)
)
_base_query: QuerySet = await sync_to_async(items.skip)(skip)
items = await sync_to_async(_base_query.limit)(limit)
has_next_page = (
(
await sync_to_async(len)(
await sync_to_async(_base_query.skip(limit).only("id").limit)(1)
_base_query.skip(skip + limit).only("id").limit(1)
)
!= 0
)
Expand All @@ -144,12 +139,8 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non
items = await sync_to_async(items.skip)(skip)
else:
if limit:
if reverse:
_base_query = items[::-1]
items = _base_query[skip : skip + limit]
else:
_base_query = items
items = items[skip : skip + limit]
_base_query = items
items = items[skip : skip + limit]
has_next_page = (
(skip + limit) < len(_base_query) if requires_page_info else False
)
Expand Down Expand Up @@ -200,11 +191,11 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non
else:
count = await sync_to_async(self.model.objects(args_copy).count)()
if count != 0:
skip, limit, reverse = find_skip_and_limit(
skip, limit = find_skip_and_limit(
first=first, after=after, last=last, before=before, count=count
)
iterables = self.get_queryset(
self.model, info, required_fields, skip, limit, reverse, **args
self.model, info, required_fields, skip, limit, **args
)
iterables = await sync_to_async(list)(iterables)
list_length = len(iterables)
Expand All @@ -217,14 +208,11 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non

elif "pk__in" in args and args["pk__in"]:
count = len(args["pk__in"])
skip, limit, reverse = find_skip_and_limit(
skip, limit = find_skip_and_limit(
first=first, last=last, after=after, before=before, count=count
)
if limit:
if reverse:
args["pk__in"] = args["pk__in"][::-1][skip : skip + limit]
else:
args["pk__in"] = args["pk__in"][skip : skip + limit]
args["pk__in"] = args["pk__in"][skip : skip + limit]
elif skip:
args["pk__in"] = args["pk__in"][skip:]
iterables = self.get_queryset(self.model, info, required_fields, **args)
Expand All @@ -241,16 +229,12 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non
field_name = to_snake_case(info.field_name)
items = getattr(_root, field_name, [])
count = len(items)
skip, limit, reverse = find_skip_and_limit(
skip, limit = find_skip_and_limit(
first=first, last=last, after=after, before=before, count=count
)
if limit:
if reverse:
_base_query = items[::-1]
items = _base_query[skip : skip + limit]
else:
_base_query = items
items = items[skip : skip + limit]
_base_query = items
items = items[skip : skip + limit]
has_next_page = (skip + limit) < len(_base_query) if requires_page_info else False
elif skip:
items = items[skip:]
Expand All @@ -266,11 +250,6 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non
)
has_previous_page = True if requires_page_info and skip else False

if reverse:
iterables = await sync_to_async(list)(iterables)
iterables.reverse()
skip = limit

connection = connection_from_iterables(
edges=iterables,
start_offset=skip,
Expand Down
Loading

0 comments on commit 8f46790

Please sign in to comment.