Skip to content

Commit

Permalink
fix: Merge issues
Browse files Browse the repository at this point in the history
  • Loading branch information
abhinand-c committed Dec 9, 2023
2 parents 4e91e7d + 8828d6b commit 2c2c1df
Show file tree
Hide file tree
Showing 8 changed files with 170 additions and 134 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ To create a GraphQL schema and async executor; for it you simply have to write t
import graphene

from graphene_mongo import AsyncMongoengineObjectType
from asgiref.sync import sync_to_async
from graphene_mongo.utils import sync_to_async
from concurrent.futures import ThreadPoolExecutor

from .models import User as UserModel
Expand Down
27 changes: 18 additions & 9 deletions graphene_mongo/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,13 +287,13 @@ async def reference_resolver_async(root, *args, **kwargs):
)
tasks.append(task)
result = await asyncio.gather(*tasks)
result = [each[0] for each in result]
result_object_ids = list()
for each in result:
result_object_ids.append(each.id)
result_object = {}
for items in result:
for item in items:
result_object[item.id] = item
ordered_result = list()
for each in to_resolve_object_ids:
ordered_result.append(result[result_object_ids.index(each)])
ordered_result.append(result_object[each])
return ordered_result
return None

Expand Down Expand Up @@ -354,10 +354,19 @@ def convert_field_to_union(field, registry=None, executor: ExecutorEnum = Execut
if len(_types) == 0:
return None

name = (
to_camel_case("{}_{}".format(field._owner_document.__name__, field.db_field)) + "UnionType"
if ExecutorEnum.SYNC
else "AsyncUnionType"
field_name = field.db_field
if field_name is None:
# Get db_field name from parent mongo_field
for db_field_name, _mongo_parent_field in field.owner_document._fields.items():
if hasattr(_mongo_parent_field, "field") and _mongo_parent_field.field == field:
field_name = db_field_name
break

name = to_camel_case(
"{}_{}_union_type".format(
field._owner_document.__name__,
field_name,
)
)
Meta = type("Meta", (object,), {"types": tuple(_types)})
_union = type(name, (graphene.Union,), {"Meta": Meta})
Expand Down
25 changes: 17 additions & 8 deletions graphene_mongo/fields_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
find_skip_and_limit,
get_query_fields,
sync_to_async,
has_page_info,
)

PYMONGO_VERSION = tuple(pymongo.version_tuple[:2])
Expand Down Expand Up @@ -100,6 +101,7 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non
before = args.pop("before", None)
if before:
before = cursor_to_offset(before)
requires_page_info = has_page_info(info)
has_next_page = False

if resolved is not None:
Expand All @@ -112,7 +114,7 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non
else:
count = None
except OperationFailure:
count = len(items)
count = await sync_to_async(len)(items)
else:
count = len(items)

Expand All @@ -129,7 +131,14 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non
)
items = await sync_to_async(_base_query.limit)(limit)
has_next_page = (
len(await sync_to_async(_base_query.skip(limit).only("id").limit)(1)) != 0
(
await sync_to_async(len)(
await sync_to_async(_base_query.skip(limit).only("id").limit)(1)
)
!= 0
)
if requires_page_info
else False
)
elif skip:
items = await sync_to_async(items.skip)(skip)
Expand All @@ -138,11 +147,12 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non
if reverse:
_base_query = items[::-1]
items = _base_query[skip : skip + limit]
has_next_page = (skip + limit) < len(_base_query)
else:
_base_query = items
items = items[skip : skip + limit]
has_next_page = (skip + limit) < len(_base_query)
has_next_page = (
(skip + limit) < len(_base_query) if requires_page_info else False
)
elif skip:
items = items[skip:]
iterables = await sync_to_async(list)(items)
Expand Down Expand Up @@ -238,24 +248,23 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non
if reverse:
_base_query = items[::-1]
items = _base_query[skip : skip + limit]
has_next_page = (skip + limit) < len(_base_query)
else:
_base_query = items
items = items[skip : skip + limit]
has_next_page = (skip + limit) < len(_base_query)
has_next_page = (skip + limit) < len(_base_query) if requires_page_info else False
elif skip:
items = items[skip:]
iterables = items
iterables = await sync_to_async(list)(iterables)
list_length = len(iterables)

if count:
if requires_page_info and count:
has_next_page = (
True
if (0 if limit is None else limit) + (0 if skip is None else skip) < count
else False
)
has_previous_page = True if skip else False
has_previous_page = True if requires_page_info and skip else False

if reverse:
iterables = await sync_to_async(list)(iterables)
Expand Down
20 changes: 8 additions & 12 deletions graphene_mongo/types.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,23 @@
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor

import graphene
import mongoengine
from asgiref.sync import sync_to_async
from graphene.relay import Connection, Node
from graphene.types.objecttype import ObjectType, ObjectTypeOptions
from graphene.types.inputobjecttype import InputObjectType, InputObjectTypeOptions
from graphene.types.interface import Interface, InterfaceOptions
from graphene.types.objecttype import ObjectType, ObjectTypeOptions
from graphene.types.utils import yank_fields_from_attrs
from graphene.utils.str_converters import to_snake_case
from graphene_mongo import MongoengineConnectionField

from graphene_mongo import MongoengineConnectionField
from .converter import convert_mongoengine_field
from .registry import Registry, get_global_registry, get_inputs_registry
from .utils import (
ExecutorEnum,
get_model_fields,
is_valid_mongoengine_model,
get_query_fields,
ExecutorEnum,
is_valid_mongoengine_model,
sync_to_async,
)


Expand Down Expand Up @@ -246,8 +245,6 @@ async def get_node(cls, info, id):
required_fields = list(set(required_fields))
return await sync_to_async(
cls._meta.model.objects.no_dereference().only(*required_fields).get,
thread_sensitive=False,
executor=ThreadPoolExecutor(),
)(pk=id)

def resolve_id(self, info):
Expand All @@ -259,10 +256,9 @@ def resolve_id(self, info):
MongoengineObjectType, MongoengineObjectTypeOptions = create_graphene_generic_class(
ObjectType, ObjectTypeOptions
)
(
MongoengineInterfaceType,
MongoengineInterfaceTypeOptions,
) = create_graphene_generic_class(Interface, InterfaceOptions)
MongoengineInterfaceType, MongoengineInterfaceTypeOptions = create_graphene_generic_class(
Interface, InterfaceOptions
)
MongoengineInputType, MongoengineInputTypeOptions = create_graphene_generic_class(
InputObjectType, InputObjectTypeOptions
)
Expand Down
3 changes: 1 addition & 2 deletions graphene_mongo/types_async.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import graphene
import mongoengine
from asgiref.sync import sync_to_async
from graphene import InputObjectType
from graphene.relay import Connection, Node
from graphene.types.interface import Interface, InterfaceOptions
Expand All @@ -11,7 +10,7 @@
from graphene_mongo import AsyncMongoengineConnectionField
from .registry import Registry, get_global_async_registry, get_inputs_async_registry
from .types import construct_fields, construct_self_referenced_fields
from .utils import ExecutorEnum, get_query_fields, is_valid_mongoengine_model
from .utils import ExecutorEnum, get_query_fields, is_valid_mongoengine_model, sync_to_async


def create_graphene_generic_class_async(object_type, option_type):
Expand Down
25 changes: 24 additions & 1 deletion graphene_mongo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from typing import Any, Callable, Union

import mongoengine
from asgiref.sync import SyncToAsync, sync_to_async as asgiref_sync_to_async
from asgiref.sync import sync_to_async as asgiref_sync_to_async
from asgiref.sync import SyncToAsync
from graphene import Node
from graphene.utils.trim_docstring import trim_docstring
from graphql import FieldNode
Expand Down Expand Up @@ -173,6 +174,28 @@ def get_query_fields(info):
return query


def has_page_info(info):
"""A convenience function to call collect_query_fields with info
for retrieving if page_info details are required
Args:
info (ResolveInfo)
Returns:
bool: True if it received pageinfo
"""

fragments = {}
if not info:
return True # Returning True if invalid info is provided
node = ast_to_dict(info.field_nodes[0])
for name, value in info.fragments.items():
fragments[name] = ast_to_dict(value)

query = collect_query_fields(node, fragments)
return next((True for x in query.keys() if x.lower() == "pageinfo"), False)


def ast_to_dict(node, include_loc=False):
if isinstance(node, FieldNode):
d = {"kind": node.__class__.__name__}
Expand Down
Loading

0 comments on commit 2c2c1df

Please sign in to comment.