From 554a0f786b3d57c07e012f8a112e36d815914c93 Mon Sep 17 00:00:00 2001
From: M Aswin Kishore
Date: Tue, 21 Nov 2023 19:12:18 +0530
Subject: [PATCH 1/7] refactor: sync to async wrapper

---
 graphene_mongo/converter.py    | 27 +++++++++------------------
 graphene_mongo/fields_async.py | 29 +++++++++--------------------
 graphene_mongo/utils.py        | 33 +++++++++++++++++++++++++++++++--
 3 files changed, 49 insertions(+), 40 deletions(-)

diff --git a/graphene_mongo/converter.py b/graphene_mongo/converter.py
index 7a6beeb3..63424d85 100644
--- a/graphene_mongo/converter.py
+++ b/graphene_mongo/converter.py
@@ -3,13 +3,12 @@
 import graphene
 import mongoengine
-from asgiref.sync import sync_to_async
 from graphene.types.json import JSONString
 from graphene.utils.str_converters import to_snake_case, to_camel_case
 from mongoengine.base import get_document, LazyReference
 
 from . import advanced_types
-from .utils import import_single_dispatch, get_field_description, get_query_fields, ExecutorEnum
+from .utils import import_single_dispatch, get_field_description, get_query_fields, ExecutorEnum, sync_to_async
 from concurrent.futures import ThreadPoolExecutor, as_completed
 
 singledispatch = import_single_dispatch()
 
@@ -211,8 +210,7 @@ async def get_reference_objects_async(*args, **kwargs):
                 item = to_snake_case(each)
                 if item in document._fields_ordered + tuple(filter_args):
                     queried_fields.append(item)
-            return await sync_to_async(list, thread_sensitive=False, executor=ThreadPoolExecutor())(
-                document.objects().no_dereference().only(
+            return await sync_to_async(list)(document.objects().no_dereference().only(
                 *set(list(document_field_type._meta.required_fields) + queried_fields)).filter(
                 pk__in=args[1]))
 
@@ -398,10 +396,8 @@ async def reference_resolver_async(root, *args, **kwargs):
                     if item in document._fields_ordered + tuple(filter_args):
                         queried_fields.append(item)
                 return await sync_to_async(document.objects().no_dereference().only(*list(
-                    set(list(_type._meta.required_fields) + queried_fields))).get, thread_sensitive=False,
-                    executor=ThreadPoolExecutor())(pk=de_referenced["_ref"].id)
+                    set(list(_type._meta.required_fields) + queried_fields))).get)(pk=de_referenced["_ref"].id)
-            return await sync_to_async(document, thread_sensitive=False,
-                                       executor=ThreadPoolExecutor())()
+            return await sync_to_async(document)()
         return None
 
         async def lazy_reference_resolver_async(root, *args, **kwargs):
@@ -424,10 +420,8 @@ async def lazy_reference_resolver_async(root, *args, **kwargs):
                     queried_fields.append(item)
                 _type = registry.get_type_for_model(document.document_type, executor=executor)
                 return await sync_to_async(document.document_type.objects().no_dereference().only(
-                    *(set((list(_type._meta.required_fields) + queried_fields)))).get, thread_sensitive=False,
-                    executor=ThreadPoolExecutor())(pk=document.pk)
+                    *(set((list(_type._meta.required_fields) + queried_fields)))).get)(pk=document.pk)
-            return await sync_to_async(document.document_type, thread_sensitive=False,
-                                       executor=ThreadPoolExecutor())()
+            return await sync_to_async(document.document_type)()
         return None
 
     if isinstance(field, mongoengine.GenericLazyReferenceField):
@@ -520,8 +514,7 @@ async def reference_resolver_async(root, *args, **kwargs):
                 if item in field.document_type._fields_ordered + tuple(filter_args):
                     queried_fields.append(item)
             return await sync_to_async(field.document_type.objects().no_dereference().only(
-                *(set(list(_type._meta.required_fields) + queried_fields))).get, thread_sensitive=False,
-                executor=ThreadPoolExecutor())(pk=document.id)
+                *(set(list(_type._meta.required_fields) + queried_fields))).get)(pk=document.id)
         return None
 
         async def cached_reference_resolver_async(root, *args, **kwargs):
@@ -539,8 +532,7 @@ async def cached_reference_resolver_async(root, *args, **kwargs):
                     queried_fields.append(item)
             return await sync_to_async(field.document_type.objects().no_dereference().only(
                 *(set(
-                    list(_type._meta.required_fields) + queried_fields))).get, thread_sensitive=False,
-                executor=ThreadPoolExecutor())(
+                    list(_type._meta.required_fields) + queried_fields))).get)(
                 pk=getattr(root, field.name or field.db_name))
         return None
 
@@ -614,8 +606,7 @@ async def lazy_resolver_async(root, *args, **kwargs):
                 if item in document.document_type._fields_ordered + tuple(filter_args):
                     queried_fields.append(item)
             return await sync_to_async(document.document_type.objects().no_dereference().only(
-                *(set((list(_type._meta.required_fields) + queried_fields)))).get, thread_sensitive=False,
-                executor=ThreadPoolExecutor())(pk=document.pk)
+                *(set((list(_type._meta.required_fields) + queried_fields)))).get)(pk=document.pk)
         return None
 
     def dynamic_type():
diff --git a/graphene_mongo/fields_async.py b/graphene_mongo/fields_async.py
index 8886823d..d061968a 100644
--- a/graphene_mongo/fields_async.py
+++ b/graphene_mongo/fields_async.py
@@ -14,12 +14,11 @@
 from mongoengine import QuerySet
 from promise import Promise
 from pymongo.errors import OperationFailure
-from asgiref.sync import sync_to_async
 from concurrent.futures import ThreadPoolExecutor
 from .registry import get_global_async_registry
 from . import MongoengineConnectionField
 from .utils import get_query_fields, find_skip_and_limit, \
-    connection_from_iterables, ExecutorEnum
+    connection_from_iterables, ExecutorEnum, sync_to_async
 import pymongo
 
 PYMONGO_VERSION = tuple(pymongo.version_tuple[:2])
@@ -100,8 +99,7 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non
         if isinstance(items, QuerySet):
             try:
                 if last is not None and after is not None:
-                    count = await sync_to_async(items.count, thread_sensitive=False,
-                                                executor=ThreadPoolExecutor())(with_limit_and_skip=False)
+                    count = await sync_to_async(items.count)(with_limit_and_skip=False)
                 else:
                     count = None
             except OperationFailure:
@@ -130,7 +128,7 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non
                     items = items[skip:skip + limit]
                 elif skip:
                     items = items[skip:]
-                iterables = await sync_to_async(list, thread_sensitive=False, executor=ThreadPoolExecutor())(items)
+                iterables = await sync_to_async(list)(items)
                 list_length = len(iterables)
 
         elif callable(getattr(self.model, "objects", None)):
@@ -157,10 +155,7 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non
                         args_copy[key] = args_copy[key].value
 
                 if PYMONGO_VERSION >= (3, 7):
-                    count = await sync_to_async(
-                        (mongoengine.get_db()[self.model._get_collection_name()]).count_documents,
-                        thread_sensitive=False,
-                        executor=ThreadPoolExecutor())(args_copy)
+                    count = await sync_to_async((mongoengine.get_db()[self.model._get_collection_name()]).count_documents)(args_copy)
                 else:
                     count = await sync_to_async(self.model.objects(args_copy).count, thread_sensitive=False,
                                                 executor=ThreadPoolExecutor())()
@@ -168,8 +163,7 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non
                     skip, limit, reverse = find_skip_and_limit(first=first, after=after, last=last, before=before,
                                                                count=count)
                     iterables = self.get_queryset(self.model, info, required_fields, skip, limit, reverse, **args)
-                    iterables = await sync_to_async(list, thread_sensitive=False, executor=ThreadPoolExecutor())(
-                        iterables)
+                    iterables = await sync_to_async(list)(iterables)
                     list_length = len(iterables)
                     if isinstance(info, GraphQLResolveInfo):
                         if not info.context:
@@ -188,8 +182,7 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non
                 elif skip:
                     args["pk__in"] = args["pk__in"][skip:]
                 iterables = self.get_queryset(self.model, info, required_fields, **args)
-                iterables = await sync_to_async(list, thread_sensitive=False, executor=ThreadPoolExecutor())(
-                    iterables)
+                iterables = await sync_to_async(list)(iterables)
                 list_length = len(iterables)
                 if isinstance(info, GraphQLResolveInfo):
                     if not info.context:
@@ -210,23 +203,19 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non
             elif skip:
                 items = items[skip:]
             iterables = items
-            iterables = await sync_to_async(list, thread_sensitive=False, executor=ThreadPoolExecutor())(
-                iterables)
+            iterables = await sync_to_async(list)(iterables)
             list_length = len(iterables)
 
         if count:
            has_next_page = True if (0 if limit is None else limit) + (0 if skip is None else skip) < count else False
         else:
             if isinstance(queryset, QuerySet) and iterables:
-                has_next_page = bool(await sync_to_async(queryset(pk__gt=iterables[-1].pk).limit(1).first,
-                                                         thread_sensitive=False,
-                                                         executor=ThreadPoolExecutor())())
+                has_next_page = bool(await sync_to_async(queryset(pk__gt=iterables[-1].pk).limit(1).first)())
            else:
                 has_next_page = False
         has_previous_page = True if skip else False
         if reverse:
-            iterables = await sync_to_async(list, thread_sensitive=False, executor=ThreadPoolExecutor())(
-                iterables)
+            iterables = await sync_to_async(list)(iterables)
             iterables.reverse()
             skip = limit
         connection = connection_from_iterables(edges=iterables, start_offset=skip,
diff --git a/graphene_mongo/utils.py b/graphene_mongo/utils.py
index 1c680f0a..0d587b70 100644
--- a/graphene_mongo/utils.py
+++ b/graphene_mongo/utils.py
@@ -3,8 +3,11 @@
 import enum
 import inspect
 from collections import OrderedDict
+from concurrent.futures import ThreadPoolExecutor
+from typing import Any, Callable
 
 import mongoengine
+from asgiref.sync import SyncToAsync, sync_to_async as asgiref_sync_to_async
 from graphene import Node
 from graphene.utils.trim_docstring import trim_docstring
 from graphql import FieldNode
@@ -255,6 +258,32 @@ def connection_from_iterables(edges, start_offset, has_previous_page, has_next_p
         start_cursor=first_edge_cursor,
         end_cursor=last_edge_cursor,
         has_previous_page=has_previous_page,
-        has_next_page=has_next_page
-        )
-    )
+        has_next_page=has_next_page,
+        ),
+    )
+
+
+def sync_to_async(
+    func: Callable = None,
+    thread_sensitive: bool = False,
+    executor: Any = None,  # noqa
+) -> SyncToAsync | Callable[[Callable[..., Any]], SyncToAsync]:
+    """
+    Wrapper over sync_to_async from asgiref.sync.
+    Defaults to thread-insensitive execution with a fresh ThreadPoolExecutor.
+    Args:
+        func:
+            Function to be converted into a coroutine function
+        thread_sensitive:
+            Whether the operation is thread-sensitive and must run in the main thread
+        executor:
+            Thread pool executor, used only when thread_sensitive=False
+
+    Returns:
+        Coroutine version of func
+    """
+    if executor is None:
+        executor = ThreadPoolExecutor()
+    return asgiref_sync_to_async(
+        func=func, thread_sensitive=thread_sensitive, executor=executor
+    )

From 44c47aa5893136e228d3f6a5c697511596274b90 Mon Sep 17 00:00:00 2001
From: M Aswin Kishore
Date: Tue, 21 Nov 2023 19:31:46 +0530
Subject: [PATCH 2/7] chore: add ruff formatting

---
 docs/conf.py                              | 
20 +- examples/django_mongoengine/bike/urls.py | 4 +- .../bike_catalog/settings.py | 4 +- .../bike_catalog/settings_test.py | 4 +- examples/falcon_mongoengine/api.py | 4 +- examples/falcon_mongoengine/tests/tests.py | 22 +- examples/flask_mongoengine/app.py | 4 +- examples/flask_mongoengine/models.py | 4 - examples/flask_mongoengine/schema.py | 13 +- graphene_mongo/__init__.py | 2 +- graphene_mongo/advanced_types.py | 7 +- graphene_mongo/converter.py | 341 ++++++++++++------ graphene_mongo/fields.py | 329 +++++++++++------ graphene_mongo/fields_async.py | 184 ++++++---- graphene_mongo/registry.py | 16 +- graphene_mongo/tests/conftest.py | 29 +- graphene_mongo/tests/models.py | 11 +- graphene_mongo/tests/test_converter.py | 61 +--- graphene_mongo/tests/test_inputs.py | 2 - graphene_mongo/tests/test_query.py | 8 +- graphene_mongo/tests/test_relay_query.py | 79 ++-- .../tests/test_relay_query_async.py | 81 ++--- graphene_mongo/tests/test_types.py | 2 +- graphene_mongo/tests/test_utils.py | 35 +- graphene_mongo/types.py | 106 +++--- graphene_mongo/types_async.py | 87 ++--- graphene_mongo/utils.py | 45 +-- poetry.lock | 62 ++-- pyproject.toml | 4 + 29 files changed, 886 insertions(+), 684 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index dba8cd82..deb0f8af 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -60,18 +60,18 @@ master_doc = "index" # General information about the project. -project = u"Graphene Mongo" -copyright = u"Graphene 2018" -author = u"Abaw Chen" +project = "Graphene Mongo" +copyright = "Graphene 2018" +author = "Abaw Chen" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the # built documents. # # The short X.Y version. -version = u"0.1" +version = "0.1" # The full version, including alpha/beta/rc tags. -release = u"0.1.2" +release = "0.1.2" # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. @@ -275,9 +275,7 @@ # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). -latex_documents = [ - (master_doc, "Graphene.tex", u"Graphene Documentation", u"Syrus Akbary", "manual") -] +latex_documents = [(master_doc, "Graphene.tex", "Graphene Documentation", "Syrus Akbary", "manual")] # The name of an image file (relative to this directory) to place at the top of # the title page. @@ -316,9 +314,7 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [ - (master_doc, "graphene_django", u"Graphene Django Documentation", [author], 1) -] +man_pages = [(master_doc, "graphene_django", "Graphene Django Documentation", [author], 1)] # If true, show URL addresses after external links. 
# @@ -334,7 +330,7 @@ ( master_doc, "Graphene-Django", - u"Graphene Django Documentation", + "Graphene Django Documentation", author, "Graphene Django", "One line description of project.", diff --git a/examples/django_mongoengine/bike/urls.py b/examples/django_mongoengine/bike/urls.py index 55cfd515..b18e9883 100644 --- a/examples/django_mongoengine/bike/urls.py +++ b/examples/django_mongoengine/bike/urls.py @@ -3,7 +3,5 @@ from graphene_django.views import GraphQLView urlpatterns = [ - path( - "graphql", csrf_exempt(GraphQLView.as_view(graphiql=True)), name="graphql-query" - ) + path("graphql", csrf_exempt(GraphQLView.as_view(graphiql=True)), name="graphql-query") ] diff --git a/examples/django_mongoengine/bike_catalog/settings.py b/examples/django_mongoengine/bike_catalog/settings.py index f17c7f24..be86032f 100644 --- a/examples/django_mongoengine/bike_catalog/settings.py +++ b/examples/django_mongoengine/bike_catalog/settings.py @@ -85,9 +85,7 @@ # https://docs.djangoproject.com/en/2.2/ref/settings/#auth-password-validators AUTH_PASSWORD_VALIDATORS = [ - { - "NAME": "django.contrib.auth.password_validation.UserAttributeSimilarityValidator" - }, + {"NAME": "django.contrib.auth.password_validation.UserAttributeSimilarityValidator"}, {"NAME": "django.contrib.auth.password_validation.MinimumLengthValidator"}, {"NAME": "django.contrib.auth.password_validation.CommonPasswordValidator"}, {"NAME": "django.contrib.auth.password_validation.NumericPasswordValidator"}, diff --git a/examples/django_mongoengine/bike_catalog/settings_test.py b/examples/django_mongoengine/bike_catalog/settings_test.py index ddd3467c..0c608b02 100644 --- a/examples/django_mongoengine/bike_catalog/settings_test.py +++ b/examples/django_mongoengine/bike_catalog/settings_test.py @@ -1,5 +1,3 @@ from .settings import * # flake8: noqa -mongoengine.connect( - "graphene-mongo-test", host="mongomock://localhost", alias="default" -) +mongoengine.connect("graphene-mongo-test", host="mongomock://localhost", alias="default") diff --git a/examples/falcon_mongoengine/api.py b/examples/falcon_mongoengine/api.py index 59f3781a..c7f1810e 100644 --- a/examples/falcon_mongoengine/api.py +++ b/examples/falcon_mongoengine/api.py @@ -3,9 +3,7 @@ from .schema import schema -def set_graphql_allow_header( - req: falcon.Request, resp: falcon.Response, resource: object -): +def set_graphql_allow_header(req: falcon.Request, resp: falcon.Response, resource: object): resp.set_header("Allow", "GET, POST, OPTIONS") diff --git a/examples/falcon_mongoengine/tests/tests.py b/examples/falcon_mongoengine/tests/tests.py index 182f749c..80e0def7 100644 --- a/examples/falcon_mongoengine/tests/tests.py +++ b/examples/falcon_mongoengine/tests/tests.py @@ -1,11 +1,10 @@ import mongoengine from graphene.test import Client + from examples.falcon_mongoengine.schema import schema from .fixtures import fixtures_data -mongoengine.connect( - "graphene-mongo-test", host="mongomock://localhost", alias="default" -) +mongoengine.connect("graphene-mongo-test", host="mongomock://localhost", alias="default") def test_category_last_1_item_query(fixtures_data): @@ -23,7 +22,16 @@ def test_category_last_1_item_query(fixtures_data): expected = { "data": { - "categories": {"edges": [{"node": {"name": "Work", "color": "#1769ff"}}]} + "categories": { + "edges": [ + { + "node": { + "name": "Work", + "color": "#1769ff", + } + } + ] + } } } @@ -45,11 +53,7 @@ def test_category_filter_item_query(fixtures_data): } }""" - expected = { - "data": { - "categories": {"edges": [{"node": 
{"name": "Work", "color": "#1769ff"}}]} - } - } + expected = {"data": {"categories": {"edges": [{"node": {"name": "Work", "color": "#1769ff"}}]}}} client = Client(schema) result = client.execute(query) diff --git a/examples/flask_mongoengine/app.py b/examples/flask_mongoengine/app.py index 055cd319..d62f9b7d 100644 --- a/examples/flask_mongoengine/app.py +++ b/examples/flask_mongoengine/app.py @@ -42,9 +42,7 @@ } }""".strip() -app.add_url_rule( - "/graphql", view_func=GraphQLView.as_view("graphql", schema=schema, graphiql=True) -) +app.add_url_rule("/graphql", view_func=GraphQLView.as_view("graphql", schema=schema, graphiql=True)) if __name__ == "__main__": init_db() diff --git a/examples/flask_mongoengine/models.py b/examples/flask_mongoengine/models.py index 556282fa..734fc625 100644 --- a/examples/flask_mongoengine/models.py +++ b/examples/flask_mongoengine/models.py @@ -10,25 +10,21 @@ class Department(Document): - meta = {"collection": "department"} name = StringField() class Role(Document): - meta = {"collection": "role"} name = StringField() class Task(EmbeddedDocument): - name = StringField() deadline = DateTimeField(default=datetime.now) class Employee(Document): - meta = {"collection": "employee"} name = StringField() hired_on = DateTimeField(default=datetime.now) diff --git a/examples/flask_mongoengine/schema.py b/examples/flask_mongoengine/schema.py index 2205c67b..228eaca5 100644 --- a/examples/flask_mongoengine/schema.py +++ b/examples/flask_mongoengine/schema.py @@ -1,6 +1,5 @@ import graphene from graphene.relay import Node -from graphene_mongo.tests.nodes import PlayerNode, ReporterNode from graphene_mongo import MongoengineConnectionField, MongoengineObjectType from .models import Department as DepartmentModel @@ -20,7 +19,11 @@ class Meta: model = RoleModel interfaces = (Node,) filter_fields = { - 'name': ['exact', 'icontains', 'istartswith'] + "name": [ + "exact", + "icontains", + "istartswith", + ] } @@ -35,7 +38,11 @@ class Meta: model = EmployeeModel interfaces = (Node,) filter_fields = { - 'name': ['exact', 'icontains', 'istartswith'] + "name": [ + "exact", + "icontains", + "istartswith", + ] } diff --git a/graphene_mongo/__init__.py b/graphene_mongo/__init__.py index 2a39c5f7..e1e64801 100644 --- a/graphene_mongo/__init__.py +++ b/graphene_mongo/__init__.py @@ -13,5 +13,5 @@ "MongoengineInputType", "MongoengineInterfaceType", "MongoengineConnectionField", - "AsyncMongoengineConnectionField" + "AsyncMongoengineConnectionField", ] diff --git a/graphene_mongo/advanced_types.py b/graphene_mongo/advanced_types.py index 10e2c11e..cedec69e 100644 --- a/graphene_mongo/advanced_types.py +++ b/graphene_mongo/advanced_types.py @@ -1,4 +1,5 @@ import base64 + import graphene @@ -65,5 +66,9 @@ class PolygonFieldType(_CoordinatesTypeField): class MultiPolygonFieldType(_CoordinatesTypeField): coordinates = graphene.List( - graphene.List(graphene.List(graphene.List(graphene.Float))) + graphene.List( + graphene.List( + graphene.List(graphene.Float), + ) + ) ) diff --git a/graphene_mongo/converter.py b/graphene_mongo/converter.py index 63424d85..f4cce42c 100644 --- a/graphene_mongo/converter.py +++ b/graphene_mongo/converter.py @@ -8,7 +8,13 @@ from graphene.utils.str_converters import to_snake_case, to_camel_case from mongoengine.base import get_document, LazyReference from . 
import advanced_types -from .utils import import_single_dispatch, get_field_description, get_query_fields, ExecutorEnum, sync_to_async +from .utils import ( + import_single_dispatch, + get_field_description, + get_query_fields, + ExecutorEnum, + sync_to_async, +) from concurrent.futures import ThreadPoolExecutor, as_completed singledispatch = import_single_dispatch() @@ -21,8 +27,7 @@ class MongoEngineConversionError(Exception): @singledispatch def convert_mongoengine_field(field, registry=None, executor: ExecutorEnum = ExecutorEnum.SYNC): raise MongoEngineConversionError( - "Don't know how to convert the MongoEngine field %s (%s)" - % (field, field.__class__) + "Don't know how to convert the MongoEngine field %s (%s)" % (field, field.__class__) ) @@ -38,18 +43,14 @@ def convert_field_to_string(field, registry=None, executor: ExecutorEnum = Execu @convert_mongoengine_field.register(mongoengine.UUIDField) @convert_mongoengine_field.register(mongoengine.ObjectIdField) def convert_field_to_id(field, registry=None, executor: ExecutorEnum = ExecutorEnum.SYNC): - return graphene.ID( - description=get_field_description(field, registry), required=field.required - ) + return graphene.ID(description=get_field_description(field, registry), required=field.required) @convert_mongoengine_field.register(mongoengine.IntField) @convert_mongoengine_field.register(mongoengine.LongField) @convert_mongoengine_field.register(mongoengine.SequenceField) def convert_field_to_int(field, registry=None, executor: ExecutorEnum = ExecutorEnum.SYNC): - return graphene.Int( - description=get_field_description(field, registry), required=field.required - ) + return graphene.Int(description=get_field_description(field, registry), required=field.required) @convert_mongoengine_field.register(mongoengine.BooleanField) @@ -91,33 +92,43 @@ def convert_field_to_date(field, registry=None, executor: ExecutorEnum = Executo @convert_mongoengine_field.register(mongoengine.DictField) @convert_mongoengine_field.register(mongoengine.MapField) def convert_field_to_jsonstring(field, registry=None, executor: ExecutorEnum = ExecutorEnum.SYNC): - return JSONString( - description=get_field_description(field, registry), required=field.required - ) + return JSONString(description=get_field_description(field, registry), required=field.required) @convert_mongoengine_field.register(mongoengine.PointField) def convert_point_to_field(field, registry=None, executor: ExecutorEnum = ExecutorEnum.SYNC): - return graphene.Field(advanced_types.PointFieldType, description=get_field_description(field, registry), - required=field.required) + return graphene.Field( + advanced_types.PointFieldType, + description=get_field_description(field, registry), + required=field.required, + ) @convert_mongoengine_field.register(mongoengine.PolygonField) def convert_polygon_to_field(field, registry=None, executor: ExecutorEnum = ExecutorEnum.SYNC): - return graphene.Field(advanced_types.PolygonFieldType, description=get_field_description(field, registry), - required=field.required) + return graphene.Field( + advanced_types.PolygonFieldType, + description=get_field_description(field, registry), + required=field.required, + ) @convert_mongoengine_field.register(mongoengine.MultiPolygonField) def convert_multipolygon_to_field(field, registry=None, executor: ExecutorEnum = ExecutorEnum.SYNC): - return graphene.Field(advanced_types.MultiPolygonFieldType, description=get_field_description(field, registry), - required=field.required) + return graphene.Field( + 
advanced_types.MultiPolygonFieldType, + description=get_field_description(field, registry), + required=field.required, + ) @convert_mongoengine_field.register(mongoengine.FileField) def convert_file_to_field(field, registry=None, executor: ExecutorEnum = ExecutorEnum.SYNC): - return graphene.Field(advanced_types.FileFieldType, description=get_field_description(field, registry), - required=field.required) + return graphene.Field( + advanced_types.FileFieldType, + description=get_field_description(field, registry), + required=field.required, + ) @convert_mongoengine_field.register(mongoengine.ListField) @@ -127,6 +138,7 @@ def convert_field_to_list(field, registry=None, executor: ExecutorEnum = Executo base_type = convert_mongoengine_field(field.field, registry=registry, executor=executor) if isinstance(base_type, graphene.Field): if isinstance(field.field, mongoengine.GenericReferenceField): + def get_reference_objects(*args, **kwargs): document = get_document(args[0][0]) document_field = mongoengine.ReferenceField(document) @@ -142,8 +154,12 @@ def get_reference_objects(*args, **kwargs): item = to_snake_case(each) if item in document._fields_ordered + tuple(filter_args): queried_fields.append(item) - return document.objects().no_dereference().only( - *set(list(document_field_type._meta.required_fields) + queried_fields)).filter(pk__in=args[0][1]) + return ( + document.objects() + .no_dereference() + .only(*set(list(document_field_type._meta.required_fields) + queried_fields)) + .filter(pk__in=args[0][1]) + ) def get_non_querying_object(*args, **kwargs): model = get_document(args[0][0]) @@ -154,8 +170,8 @@ def reference_resolver(root, *args, **kwargs): if to_resolve: choice_to_resolve = dict() querying_union_types = list(get_query_fields(args[0]).keys()) - if '__typename' in querying_union_types: - querying_union_types.remove('__typename') + if "__typename" in querying_union_types: + querying_union_types.remove("__typename") to_resolve_models = list() for each in querying_union_types: if executor == ExecutorEnum.SYNC: @@ -172,17 +188,24 @@ def reference_resolver(root, *args, **kwargs): choice_to_resolve[model].append(each.pk) else: to_resolve_object_ids.append(each["_ref"].id) - if each['_cls'] not in choice_to_resolve: - choice_to_resolve[each['_cls']] = list() - choice_to_resolve[each['_cls']].append(each["_ref"].id) + if each["_cls"] not in choice_to_resolve: + choice_to_resolve[each["_cls"]] = list() + choice_to_resolve[each["_cls"]].append(each["_ref"].id) pool = ThreadPoolExecutor(5) futures = list() for model, object_id_list in choice_to_resolve.items(): if model in to_resolve_models: - futures.append(pool.submit(get_reference_objects, (model, object_id_list, registry, args))) + futures.append( + pool.submit( + get_reference_objects, (model, object_id_list, registry, args) + ) + ) else: futures.append( - pool.submit(get_non_querying_object, (model, object_id_list, registry, args))) + pool.submit( + get_non_querying_object, (model, object_id_list, registry, args) + ) + ) result = list() for x in as_completed(futures): result += x.result() @@ -198,7 +221,9 @@ def reference_resolver(root, *args, **kwargs): async def get_reference_objects_async(*args, **kwargs): document = get_document(args[0]) document_field = mongoengine.ReferenceField(document) - document_field = convert_mongoengine_field(document_field, registry, executor=ExecutorEnum.ASYNC) + document_field = convert_mongoengine_field( + document_field, registry, executor=ExecutorEnum.ASYNC + ) document_field_type = 
document_field.get_type().type queried_fields = list() filter_args = list() @@ -210,9 +235,12 @@ async def get_reference_objects_async(*args, **kwargs): item = to_snake_case(each) if item in document._fields_ordered + tuple(filter_args): queried_fields.append(item) - return await sync_to_async(list)(document.objects().no_dereference().only( - *set(list(document_field_type._meta.required_fields) + queried_fields)).filter( - pk__in=args[1])) + return await sync_to_async(list)( + document.objects() + .no_dereference() + .only(*set(list(document_field_type._meta.required_fields) + queried_fields)) + .filter(pk__in=args[1]) + ) async def get_non_querying_object_async(*args, **kwargs): model = get_document(args[0]) @@ -223,8 +251,8 @@ async def reference_resolver_async(root, *args, **kwargs): if to_resolve: choice_to_resolve = dict() querying_union_types = list(get_query_fields(args[0]).keys()) - if '__typename' in querying_union_types: - querying_union_types.remove('__typename') + if "__typename" in querying_union_types: + querying_union_types.remove("__typename") to_resolve_models = list() for each in querying_union_types: if executor == ExecutorEnum.SYNC: @@ -241,17 +269,20 @@ async def reference_resolver_async(root, *args, **kwargs): choice_to_resolve[model].append(each.pk) else: to_resolve_object_ids.append(each["_ref"].id) - if each['_cls'] not in choice_to_resolve: - choice_to_resolve[each['_cls']] = list() - choice_to_resolve[each['_cls']].append(each["_ref"].id) + if each["_cls"] not in choice_to_resolve: + choice_to_resolve[each["_cls"]] = list() + choice_to_resolve[each["_cls"]].append(each["_ref"].id) loop = asyncio.get_event_loop() tasks = [] for model, object_id_list in choice_to_resolve.items(): if model in to_resolve_models: - task = loop.create_task(get_reference_objects_async(model, object_id_list, registry, args)) + task = loop.create_task( + get_reference_objects_async(model, object_id_list, registry, args) + ) else: task = loop.create_task( - get_non_querying_object_async(model, object_id_list, registry, args)) + get_non_querying_object_async(model, object_id_list, registry, args) + ) tasks.append(task) result = await asyncio.gather(*tasks) result = [each[0] for each in result] @@ -268,12 +299,14 @@ async def reference_resolver_async(root, *args, **kwargs): base_type._type, description=get_field_description(field, registry), required=field.required, - resolver=reference_resolver if executor == ExecutorEnum.SYNC else reference_resolver_async + resolver=reference_resolver + if executor == ExecutorEnum.SYNC + else reference_resolver_async, ) return graphene.List( base_type._type, description=get_field_description(field, registry), - required=field.required + required=field.required, ) if isinstance(base_type, (graphene.Dynamic)): base_type = base_type.get_type() @@ -287,14 +320,12 @@ async def reference_resolver_async(root, *args, **kwargs): # Non-relationship field relations = (mongoengine.ReferenceField, mongoengine.EmbeddedDocumentField) if not isinstance(base_type, (graphene.List, graphene.NonNull)) and not isinstance( - field.field, relations + field.field, relations ): base_type = type(base_type) return graphene.List( - base_type, - description=get_field_description(field, registry), - required=field.required + base_type, description=get_field_description(field, registry), required=field.required ) @@ -319,10 +350,11 @@ def convert_field_to_union(field, registry=None, executor: ExecutorEnum = Execut if len(_types) == 0: return None - name = to_camel_case("{}_{}".format( - 
field._owner_document.__name__, - field.db_field - )) + "UnionType" if ExecutorEnum.SYNC else "AsyncUnionType" + name = ( + to_camel_case("{}_{}".format(field._owner_document.__name__, field.db_field)) + "UnionType" + if ExecutorEnum.SYNC + else "AsyncUnionType" + ) Meta = type("Meta", (object,), {"types": tuple(_types)}) _union = type(name, (graphene.Union,), {"Meta": Meta}) @@ -345,9 +377,12 @@ def reference_resolver(root, *args, **kwargs): item = to_snake_case(each) if item in document._fields_ordered + tuple(filter_args): queried_fields.append(item) - return document.objects().no_dereference().only(*list( - set(list(_type._meta.required_fields) + queried_fields))).get( - pk=de_referenced["_ref"].id) + return ( + document.objects() + .no_dereference() + .only(*list(set(list(_type._meta.required_fields) + queried_fields))) + .get(pk=de_referenced["_ref"].id) + ) return document() return None @@ -357,7 +392,9 @@ def lazy_reference_resolver(root, *args, **kwargs): if document._cached_doc: return document._cached_doc queried_fields = list() - document_field_type = registry.get_type_for_model(document.document_type, executor=executor) + document_field_type = registry.get_type_for_model( + document.document_type, executor=executor + ) querying_types = list(get_query_fields(args[0]).keys()) filter_args = list() if document_field_type._meta.filter_fields: @@ -370,9 +407,12 @@ def lazy_reference_resolver(root, *args, **kwargs): if item in document.document_type._fields_ordered + tuple(filter_args): queried_fields.append(item) _type = registry.get_type_for_model(document.document_type, executor=executor) - return document.document_type.objects().no_dereference().only( - *(set((list(_type._meta.required_fields) + queried_fields)))).get( - pk=document.pk) + return ( + document.document_type.objects() + .no_dereference() + .only(*(set((list(_type._meta.required_fields) + queried_fields)))) + .get(pk=document.pk) + ) return document.document_type() return None @@ -381,7 +421,9 @@ async def reference_resolver_async(root, *args, **kwargs): if de_referenced: document = get_document(de_referenced["_cls"]) document_field = mongoengine.ReferenceField(document) - document_field = convert_mongoengine_field(document_field, registry, executor=ExecutorEnum.ASYNC) + document_field = convert_mongoengine_field( + document_field, registry, executor=ExecutorEnum.ASYNC + ) _type = document_field.get_type().type filter_args = list() if _type._meta.filter_fields: @@ -395,8 +437,12 @@ async def reference_resolver_async(root, *args, **kwargs): item = to_snake_case(each) if item in document._fields_ordered + tuple(filter_args): queried_fields.append(item) - return await sync_to_async(document.objects().no_dereference().only(*list( - set(list(_type._meta.required_fields) + queried_fields))).get)(pk=de_referenced["_ref"].id) + return await sync_to_async( + document.objects() + .no_dereference() + .only(*list(set(list(_type._meta.required_fields) + queried_fields))) + .get + )(pk=de_referenced["_ref"].id) return await sync_to_async(document)() return None @@ -406,7 +452,9 @@ async def lazy_reference_resolver_async(root, *args, **kwargs): if document._cached_doc: return document._cached_doc queried_fields = list() - document_field_type = registry.get_type_for_model(document.document_type, executor=executor) + document_field_type = registry.get_type_for_model( + document.document_type, executor=executor + ) querying_types = list(get_query_fields(args[0]).keys()) filter_args = list() if document_field_type._meta.filter_fields: 
@@ -419,8 +467,12 @@ async def lazy_reference_resolver_async(root, *args, **kwargs): if item in document.document_type._fields_ordered + tuple(filter_args): queried_fields.append(item) _type = registry.get_type_for_model(document.document_type, executor=executor) - return await sync_to_async(document.document_type.objects().no_dereference().only( - *(set((list(_type._meta.required_fields) + queried_fields)))).get)(pk=document.pk) + return await sync_to_async( + document.document_type.objects() + .no_dereference() + .only(*(set((list(_type._meta.required_fields) + queried_fields)))) + .get + )(pk=document.pk) return await sync_to_async(document.document_type)() return None @@ -429,28 +481,48 @@ async def lazy_reference_resolver_async(root, *args, **kwargs): required = False if field.db_field is not None: required = field.required - resolver_function = getattr(registry.get_type_for_model(field.owner_document, executor=executor), - "resolve_" + field.db_field, - None) + resolver_function = getattr( + registry.get_type_for_model(field.owner_document, executor=executor), + "resolve_" + field.db_field, + None, + ) if resolver_function and callable(resolver_function): field_resolver = resolver_function - return graphene.Field(_union, resolver=field_resolver if field_resolver else ( - lazy_reference_resolver if executor == ExecutorEnum.SYNC else lazy_reference_resolver_async), - description=get_field_description(field, registry), required=required) + return graphene.Field( + _union, + resolver=field_resolver + if field_resolver + else ( + lazy_reference_resolver + if executor == ExecutorEnum.SYNC + else lazy_reference_resolver_async + ), + description=get_field_description(field, registry), + required=required, + ) elif isinstance(field, mongoengine.GenericReferenceField): field_resolver = None required = False if field.db_field is not None: required = field.required - resolver_function = getattr(registry.get_type_for_model(field.owner_document, executor=executor), - "resolve_" + field.db_field, - None) + resolver_function = getattr( + registry.get_type_for_model(field.owner_document, executor=executor), + "resolve_" + field.db_field, + None, + ) if resolver_function and callable(resolver_function): field_resolver = resolver_function - return graphene.Field(_union, resolver=field_resolver if field_resolver else ( - reference_resolver if executor == ExecutorEnum.SYNC else reference_resolver_async), - description=get_field_description(field, registry), required=required) + return graphene.Field( + _union, + resolver=field_resolver + if field_resolver + else ( + reference_resolver if executor == ExecutorEnum.SYNC else reference_resolver_async + ), + description=get_field_description(field, registry), + required=required, + ) return graphene.Field(_union) @@ -475,9 +547,12 @@ def reference_resolver(root, *args, **kwargs): item = to_snake_case(each) if item in field.document_type._fields_ordered + tuple(filter_args): queried_fields.append(item) - return field.document_type.objects().no_dereference().only( - *(set(list(_type._meta.required_fields) + queried_fields))).get( - pk=document.id) + return ( + field.document_type.objects() + .no_dereference() + .only(*(set(list(_type._meta.required_fields) + queried_fields))) + .get(pk=document.id) + ) return None def cached_reference_resolver(root, *args, **kwargs): @@ -493,10 +568,12 @@ def cached_reference_resolver(root, *args, **kwargs): item = to_snake_case(each) if item in field.document_type._fields_ordered + tuple(filter_args): 
queried_fields.append(item) - return field.document_type.objects().no_dereference().only( - *(set( - list(_type._meta.required_fields) + queried_fields))).get( - pk=getattr(root, field.name or field.db_name)) + return ( + field.document_type.objects() + .no_dereference() + .only(*(set(list(_type._meta.required_fields) + queried_fields))) + .get(pk=getattr(root, field.name or field.db_name)) + ) return None async def reference_resolver_async(root, *args, **kwargs): @@ -513,8 +590,12 @@ async def reference_resolver_async(root, *args, **kwargs): item = to_snake_case(each) if item in field.document_type._fields_ordered + tuple(filter_args): queried_fields.append(item) - return await sync_to_async(field.document_type.objects().no_dereference().only( - *(set(list(_type._meta.required_fields) + queried_fields))).get)(pk=document.id) + return await sync_to_async( + field.document_type.objects() + .no_dereference() + .only(*(set(list(_type._meta.required_fields) + queried_fields))) + .get + )(pk=document.id) return None async def cached_reference_resolver_async(root, *args, **kwargs): @@ -530,10 +611,12 @@ async def cached_reference_resolver_async(root, *args, **kwargs): item = to_snake_case(each) if item in field.document_type._fields_ordered + tuple(filter_args): queried_fields.append(item) - return await sync_to_async(field.document_type.objects().no_dereference().only( - *(set( - list(_type._meta.required_fields) + queried_fields))).get)( - pk=getattr(root, field.name or field.db_name)) + return await sync_to_async( + field.document_type.objects() + .no_dereference() + .only(*(set(list(_type._meta.required_fields) + queried_fields))) + .get + )(pk=getattr(root, field.name or field.db_name)) return None def dynamic_type(): @@ -541,25 +624,46 @@ def dynamic_type(): if not _type: return None if isinstance(field, mongoengine.EmbeddedDocumentField): - return graphene.Field(_type, - description=get_field_description(field, registry), required=field.required) + return graphene.Field( + _type, description=get_field_description(field, registry), required=field.required + ) field_resolver = None required = False if field.db_field is not None: required = field.required - resolver_function = getattr(registry.get_type_for_model(field.owner_document, executor=executor), - "resolve_" + field.db_field, - None) + resolver_function = getattr( + registry.get_type_for_model(field.owner_document, executor=executor), + "resolve_" + field.db_field, + None, + ) if resolver_function and callable(resolver_function): field_resolver = resolver_function if isinstance(field, mongoengine.ReferenceField): - return graphene.Field(_type, resolver=field_resolver if field_resolver else ( - reference_resolver if executor == ExecutorEnum.SYNC else reference_resolver_async), - description=get_field_description(field, registry), required=required) + return graphene.Field( + _type, + resolver=field_resolver + if field_resolver + else ( + reference_resolver + if executor == ExecutorEnum.SYNC + else reference_resolver_async + ), + description=get_field_description(field, registry), + required=required, + ) else: - return graphene.Field(_type, resolver=field_resolver if field_resolver else ( - cached_reference_resolver if executor == ExecutorEnum.SYNC else cached_reference_resolver_async), - description=get_field_description(field, registry), required=required) + return graphene.Field( + _type, + resolver=field_resolver + if field_resolver + else ( + cached_reference_resolver + if executor == ExecutorEnum.SYNC + else 
cached_reference_resolver_async + ), + description=get_field_description(field, registry), + required=required, + ) return graphene.Dynamic(dynamic_type) @@ -584,9 +688,12 @@ def lazy_resolver(root, *args, **kwargs): item = to_snake_case(each) if item in document.document_type._fields_ordered + tuple(filter_args): queried_fields.append(item) - return document.document_type.objects().no_dereference().only( - *(set((list(_type._meta.required_fields) + queried_fields)))).get( - pk=document.pk) + return ( + document.document_type.objects() + .no_dereference() + .only(*(set((list(_type._meta.required_fields) + queried_fields)))) + .get(pk=document.pk) + ) return None async def lazy_resolver_async(root, *args, **kwargs): @@ -605,8 +712,12 @@ async def lazy_resolver_async(root, *args, **kwargs): item = to_snake_case(each) if item in document.document_type._fields_ordered + tuple(filter_args): queried_fields.append(item) - return await sync_to_async(document.document_type.objects().no_dereference().only( - *(set((list(_type._meta.required_fields) + queried_fields)))).get)(pk=document.pk) + return await sync_to_async( + document.document_type.objects() + .no_dereference() + .only(*(set((list(_type._meta.required_fields) + queried_fields)))) + .get + )(pk=document.pk) return None def dynamic_type(): @@ -617,26 +728,32 @@ def dynamic_type(): required = False if field.db_field is not None: required = field.required - resolver_function = getattr(registry.get_type_for_model(field.owner_document, executor=executor), - "resolve_" + field.db_field, - None) + resolver_function = getattr( + registry.get_type_for_model(field.owner_document, executor=executor), + "resolve_" + field.db_field, + None, + ) if resolver_function and callable(resolver_function): field_resolver = resolver_function return graphene.Field( _type, - resolver=field_resolver if field_resolver else ( - lazy_resolver if executor == ExecutorEnum.SYNC else lazy_resolver_async), - description=get_field_description(field, registry), required=required, + resolver=field_resolver + if field_resolver + else (lazy_resolver if executor == ExecutorEnum.SYNC else lazy_resolver_async), + description=get_field_description(field, registry), + required=required, ) return graphene.Dynamic(dynamic_type) if sys.version_info >= (3, 6): + @convert_mongoengine_field.register(mongoengine.EnumField) def convert_field_to_enum(field, registry=None, executor: ExecutorEnum = ExecutorEnum.SYNC): if not registry.check_enum_already_exist(field._enum_cls): registry.register_enum(field._enum_cls) _type = registry.get_type_for_enum(field._enum_cls) - return graphene.Field(_type, - description=get_field_description(field, registry), required=field.required) + return graphene.Field( + _type, description=get_field_description(field, registry), required=field.required + ) diff --git a/graphene_mongo/fields.py b/graphene_mongo/fields.py index 3cef0f44..7fb9fed9 100644 --- a/graphene_mongo/fields.py +++ b/graphene_mongo/fields.py @@ -26,12 +26,18 @@ FileFieldType, PointFieldType, MultiPolygonFieldType, - PolygonFieldType, PointFieldInputType, + PolygonFieldType, + PointFieldInputType, ) from .converter import convert_mongoengine_field, MongoEngineConversionError from .registry import get_global_registry -from .utils import get_model_reference_fields, get_query_fields, find_skip_and_limit, \ - connection_from_iterables, ExecutorEnum +from .utils import ( + get_model_reference_fields, + get_query_fields, + find_skip_and_limit, + connection_from_iterables, + ExecutorEnum, +) import 
pymongo PYMONGO_VERSION = tuple(pymongo.version_tuple[:2]) @@ -100,14 +106,13 @@ def args(self): _filter_args.pop(_field) if _field in _extended_args: _filter_args.pop(_field) - extra_args = dict(dict(dict(_field_args, **_advance_args), **_filter_args), **_extended_args) + extra_args = dict( + dict(dict(_field_args, **_advance_args), **_filter_args), **_extended_args + ) for key in list(self._base_args.keys()): extra_args.pop(key, None) - return to_arguments( - self._base_args or OrderedDict(), - extra_args - ) + return to_arguments(self._base_args or OrderedDict(), extra_args) @args.setter def args(self, args): @@ -123,7 +128,7 @@ def is_filterable(k): Returns: bool """ - if hasattr(self.fields[k].type, '_sdl'): + if hasattr(self.fields[k].type, "_sdl"): return False if not hasattr(self.model, k): return False @@ -146,25 +151,32 @@ def is_filterable(k): if isinstance(converted, (ConnectionField, Dynamic)): return False if callable(getattr(converted, "type", None)) and isinstance( - converted.type(), - ( - FileFieldType, - PointFieldType, - MultiPolygonFieldType, - graphene.Union, - PolygonFieldType, - ), + converted.type(), + ( + FileFieldType, + PointFieldType, + MultiPolygonFieldType, + graphene.Union, + PolygonFieldType, + ), ): return False - if getattr(converted, "type", None) and getattr(converted.type, "_of_type", None) and issubclass( - (get_type(converted.type.of_type)), graphene.Union): + if ( + getattr(converted, "type", None) + and getattr(converted.type, "_of_type", None) + and issubclass((get_type(converted.type.of_type)), graphene.Union) + ): return False if isinstance(converted, (graphene.List)) and issubclass( - getattr(converted, "_of_type", None), graphene.Union + getattr(converted, "_of_type", None), graphene.Union ): return False # below if condition: workaround for DB filterable field redefined as custom graphene type - if hasattr(field_, 'type') and hasattr(converted, 'type') and converted.type != field_.type: + if ( + hasattr(field_, "type") + and hasattr(converted, "type") + and converted.type != field_.type + ): return False return True @@ -188,8 +200,11 @@ def filter_args(self): if self._type._meta.filter_fields: for field, filter_collection in self._type._meta.filter_fields.items(): for each in filter_collection: - if str(self._type._meta.fields[field].type) in ('PointFieldType', 'PointFieldType!'): - if each == 'max_distance': + if str(self._type._meta.fields[field].type) in ( + "PointFieldType", + "PointFieldType!", + ): + if each == "max_distance": filter_type = graphene.Int else: filter_type = PointFieldInputType @@ -205,9 +220,7 @@ def filter_args(self): "all": graphene.List(filter_type), } filter_type = advanced_filter_types.get(each, filter_type) - filter_args[field + "__" + each] = graphene.Argument( - type_=filter_type - ) + filter_args[field + "__" + each] = graphene.Argument(type_=filter_type) return filter_args @property @@ -219,8 +232,12 @@ def get_advance_field(r, kv): r.update({kv[0]: graphene.Argument(PointFieldInputType)}) return r if isinstance( - mongo_field, - (mongoengine.LazyReferenceField, mongoengine.ReferenceField, mongoengine.GenericReferenceField), + mongo_field, + ( + mongoengine.LazyReferenceField, + mongoengine.ReferenceField, + mongoengine.GenericReferenceField, + ), ): r.update({kv[0]: graphene.ID()}) return r @@ -230,9 +247,13 @@ def get_advance_field(r, kv): if callable(getattr(field, "get_type", None)): _type = field.get_type() if _type: - node = _type.type._meta if hasattr(_type.type, "_meta") else _type.type._of_type._meta + 
node = ( + _type.type._meta + if hasattr(_type.type, "_meta") + else _type.type._of_type._meta + ) if "id" in node.fields and not issubclass( - node.model, (mongoengine.EmbeddedDocument,) + node.model, (mongoengine.EmbeddedDocument,) ): r.update({kv[0]: node.fields["id"]._type.of_type()}) @@ -244,7 +265,7 @@ def get_advance_field(r, kv): def extended_args(self): args = OrderedDict() for k, each in self.fields.items(): - if hasattr(each.type, '_sdl'): + if hasattr(each.type, "_sdl"): args.update({k: graphene.ID()}) return args @@ -253,7 +274,9 @@ def fields(self): self._type = get_type(self._type) return self._type._meta.fields - def get_queryset(self, model, info, required_fields=None, skip=None, limit=None, reversed=False, **args): + def get_queryset( + self, model, info, required_fields=None, skip=None, limit=None, reversed=False, **args + ): if required_fields is None: required_fields = list() @@ -261,28 +284,33 @@ def get_queryset(self, model, info, required_fields=None, skip=None, limit=None, reference_fields = get_model_reference_fields(self.model) hydrated_references = {} for arg_name, arg in args.copy().items(): - if arg_name in reference_fields and not isinstance(arg, - mongoengine.base.metaclasses.TopLevelDocumentMetaclass): + if arg_name in reference_fields and not isinstance( + arg, mongoengine.base.metaclasses.TopLevelDocumentMetaclass + ): try: - reference_obj = reference_fields[arg_name].document_type(pk=from_global_id(arg)[1]) + reference_obj = reference_fields[arg_name].document_type( + pk=from_global_id(arg)[1] + ) except TypeError: reference_obj = reference_fields[arg_name].document_type(pk=arg) hydrated_references[arg_name] = reference_obj - elif arg_name in self.model._fields_ordered and isinstance(getattr(self.model, arg_name), - mongoengine.fields.GenericReferenceField): + elif arg_name in self.model._fields_ordered and isinstance( + getattr(self.model, arg_name), mongoengine.fields.GenericReferenceField + ): try: - reference_obj = get_document(self.registry._registry_string_map[from_global_id(arg)[0]])( - pk=from_global_id(arg)[1]) + reference_obj = get_document( + self.registry._registry_string_map[from_global_id(arg)[0]] + )(pk=from_global_id(arg)[1]) except TypeError: - reference_obj = get_document(arg["_cls"])( - pk=arg["_ref"].id) + reference_obj = get_document(arg["_cls"])(pk=arg["_ref"].id) hydrated_references[arg_name] = reference_obj - elif '__near' in arg_name and isinstance(getattr(self.model, arg_name.split('__')[0]), - mongoengine.fields.PointField): + elif "__near" in arg_name and isinstance( + getattr(self.model, arg_name.split("__")[0]), mongoengine.fields.PointField + ): location = args.pop(arg_name, None) hydrated_references[arg_name] = location["coordinates"] - if (arg_name.split('__')[0] + "__max_distance") not in args: - hydrated_references[arg_name.split('__')[0] + "__max_distance"] = 10000 + if (arg_name.split("__")[0] + "__max_distance") not in args: + hydrated_references[arg_name.split("__")[0] + "__max_distance"] = 10000 elif arg_name == "id": hydrated_references["id"] = from_global_id(args.pop("id", None))[1] args.update(hydrated_references) @@ -299,22 +327,44 @@ def get_queryset(self, model, info, required_fields=None, skip=None, limit=None, order_by = self.order_by + ",-pk" else: order_by = "-pk" - return model.objects(**args).no_dereference().only(*required_fields).order_by(order_by).skip( - skip if skip else 0).limit(limit) + return ( + model.objects(**args) + .no_dereference() + .only(*required_fields) + .order_by(order_by) + 
.skip(skip if skip else 0) + .limit(limit) + ) else: - return model.objects(**args).no_dereference().only(*required_fields).order_by(self.order_by).skip( - skip if skip else 0).limit(limit) + return ( + model.objects(**args) + .no_dereference() + .only(*required_fields) + .order_by(self.order_by) + .skip(skip if skip else 0) + .limit(limit) + ) elif skip is not None: if reversed: if self.order_by: order_by = self.order_by + ",-pk" else: order_by = "-pk" - return model.objects(**args).no_dereference().only(*required_fields).order_by(order_by).skip( - skip) + return ( + model.objects(**args) + .no_dereference() + .only(*required_fields) + .order_by(order_by) + .skip(skip) + ) else: - return model.objects(**args).no_dereference().only(*required_fields).order_by(self.order_by).skip( - skip) + return ( + model.objects(**args) + .no_dereference() + .only(*required_fields) + .order_by(self.order_by) + .skip(skip) + ) return model.objects(**args).no_dereference().only(*required_fields).order_by(self.order_by) def default_resolver(self, _root, info, required_fields=None, resolved=None, **args): @@ -329,17 +379,19 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a if not hasattr(_root, "_fields_ordered"): if isinstance(getattr(_root, field_name, []), list): args["pk__in"] = [r.id for r in getattr(_root, field_name, [])] - elif field_name in _root._fields_ordered and not (isinstance(_root._fields[field_name].field, - mongoengine.EmbeddedDocumentField) or - isinstance(_root._fields[field_name].field, - mongoengine.GenericEmbeddedDocumentField)): + elif field_name in _root._fields_ordered and not ( + isinstance(_root._fields[field_name].field, mongoengine.EmbeddedDocumentField) + or isinstance( + _root._fields[field_name].field, mongoengine.GenericEmbeddedDocumentField + ) + ): if getattr(_root, field_name, []) is not None: args["pk__in"] = [r.id for r in getattr(_root, field_name, [])] - _id = args.pop('id', None) + _id = args.pop("id", None) if _id is not None: - args['pk'] = from_global_id(_id)[-1] + args["pk"] = from_global_id(_id)[-1] iterables = [] list_length = 0 skip = 0 @@ -373,8 +425,9 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a queryset = None count = len(items) - skip, limit, reverse = find_skip_and_limit(first=first, last=last, after=after, before=before, - count=count) + skip, limit, reverse = find_skip_and_limit( + first=first, last=last, after=after, before=before, count=count + ) if isinstance(items, QuerySet): if limit: @@ -387,64 +440,84 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a else: if limit: if reverse: - items = items[::-1][skip:skip + limit] + items = items[::-1][skip : skip + limit] else: - items = items[skip:skip + limit] + items = items[skip : skip + limit] elif skip: items = items[skip:] iterables = list(items) list_length = len(iterables) elif callable(getattr(self.model, "objects", None)): - if _root is None or args or isinstance(getattr(_root, field_name, []), MongoengineConnectionField): + if ( + _root is None + or args + or isinstance(getattr(_root, field_name, []), MongoengineConnectionField) + ): args_copy = args.copy() for key in args.copy(): if key not in self.model._fields_ordered: args_copy.pop(key) - elif isinstance(getattr(self.model, key), - mongoengine.fields.ReferenceField) or isinstance(getattr(self.model, key), - mongoengine.fields.GenericReferenceField) or isinstance( - getattr(self.model, key), - mongoengine.fields.LazyReferenceField) or 
isinstance(getattr(self.model, key), - mongoengine.fields.CachedReferenceField): + elif ( + isinstance(getattr(self.model, key), mongoengine.fields.ReferenceField) + or isinstance( + getattr(self.model, key), mongoengine.fields.GenericReferenceField + ) + or isinstance( + getattr(self.model, key), mongoengine.fields.LazyReferenceField + ) + or isinstance( + getattr(self.model, key), mongoengine.fields.CachedReferenceField + ) + ): if not isinstance(args_copy[key], ObjectId): _from_global_id = from_global_id(args_copy[key])[1] if bson.objectid.ObjectId.is_valid(_from_global_id): args_copy[key] = ObjectId(_from_global_id) else: args_copy[key] = _from_global_id - elif isinstance(getattr(self.model, key), - mongoengine.fields.EnumField): + elif isinstance(getattr(self.model, key), mongoengine.fields.EnumField): if getattr(args_copy[key], "value", None): args_copy[key] = args_copy[key].value if PYMONGO_VERSION >= (3, 7): - if hasattr(self.model, '_meta') and 'db_alias' in self.model._meta: - count = (mongoengine.get_db(self.model._meta['db_alias'])[ - self.model._get_collection_name()]).count_documents(args_copy) + if hasattr(self.model, "_meta") and "db_alias" in self.model._meta: + count = ( + mongoengine.get_db(self.model._meta["db_alias"])[ + self.model._get_collection_name() + ] + ).count_documents(args_copy) else: - count = (mongoengine.get_db()[self.model._get_collection_name()]).count_documents(args_copy) + count = ( + mongoengine.get_db()[self.model._get_collection_name()] + ).count_documents(args_copy) else: count = self.model.objects(args_copy).count() if count != 0: - skip, limit, reverse = find_skip_and_limit(first=first, after=after, last=last, before=before, - count=count) - iterables = self.get_queryset(self.model, info, required_fields, skip, limit, reverse, **args) + skip, limit, reverse = find_skip_and_limit( + first=first, after=after, last=last, before=before, count=count + ) + iterables = self.get_queryset( + self.model, info, required_fields, skip, limit, reverse, **args + ) list_length = len(iterables) if isinstance(info, GraphQLResolveInfo): if not info.context: info = info._replace(context=Context()) - info.context.queryset = self.get_queryset(self.model, info, required_fields, **args) + info.context.queryset = self.get_queryset( + self.model, info, required_fields, **args + ) elif "pk__in" in args and args["pk__in"]: count = len(args["pk__in"]) - skip, limit, reverse = find_skip_and_limit(first=first, last=last, after=after, before=before, - count=count) + skip, limit, reverse = find_skip_and_limit( + first=first, last=last, after=after, before=before, count=count + ) if limit: if reverse: - args["pk__in"] = args["pk__in"][::-1][skip:skip + limit] + args["pk__in"] = args["pk__in"][::-1][skip : skip + limit] else: - args["pk__in"] = args["pk__in"][skip:skip + limit] + args["pk__in"] = args["pk__in"][skip : skip + limit] elif skip: args["pk__in"] = args["pk__in"][skip:] iterables = self.get_queryset(self.model, info, required_fields, **args) @@ -452,26 +525,33 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a if isinstance(info, GraphQLResolveInfo): if not info.context: info = info._replace(context=Context()) - info.context.queryset = self.get_queryset(self.model, info, required_fields, **args) + info.context.queryset = self.get_queryset( + self.model, info, required_fields, **args + ) elif _root is not None: field_name = to_snake_case(info.field_name) items = getattr(_root, field_name, []) count = len(items) - skip, limit, reverse = 
find_skip_and_limit(first=first, last=last, after=after, before=before, - count=count) + skip, limit, reverse = find_skip_and_limit( + first=first, last=last, after=after, before=before, count=count + ) if limit: if reverse: - items = items[::-1][skip:skip + limit] + items = items[::-1][skip : skip + limit] else: - items = items[skip:skip + limit] + items = items[skip : skip + limit] elif skip: items = items[skip:] iterables = items list_length = len(iterables) if count: - has_next_page = True if (0 if limit is None else limit) + (0 if skip is None else skip) < count else False + has_next_page = ( + True + if (0 if limit is None else limit) + (0 if skip is None else skip) < count + else False + ) else: if isinstance(queryset, QuerySet) and iterables: has_next_page = bool(queryset(pk__gt=iterables[-1].pk).limit(1).first()) @@ -482,19 +562,21 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a iterables = list(iterables) iterables.reverse() skip = limit - connection = connection_from_iterables(edges=iterables, start_offset=skip, - has_previous_page=has_previous_page, - has_next_page=has_next_page, - connection_type=self.type, - edge_type=self.type.Edge, - pageinfo_type=graphene.PageInfo) + connection = connection_from_iterables( + edges=iterables, + start_offset=skip, + has_previous_page=has_previous_page, + has_next_page=has_next_page, + connection_type=self.type, + edge_type=self.type.Edge, + pageinfo_type=graphene.PageInfo, + ) connection.iterable = iterables connection.list_length = list_length return connection def chained_resolver(self, resolver, is_partial, root, info, **args): - for key, value in dict(args).items(): if value is None: del args[key] @@ -512,20 +594,28 @@ def chained_resolver(self, resolver, is_partial, root, info, **args): args_copy = args.copy() if not bool(args) or not is_partial: - if isinstance(self.model, mongoengine.Document) or isinstance(self.model, - mongoengine.base.metaclasses.TopLevelDocumentMetaclass): - + if isinstance(self.model, mongoengine.Document) or isinstance( + self.model, mongoengine.base.metaclasses.TopLevelDocumentMetaclass + ): from itertools import filterfalse - connection_fields = [field for field in self.fields if - type(self.fields[field]) == MongoengineConnectionField] - filterable_args = tuple(filterfalse(connection_fields.__contains__, list(self.model._fields_ordered))) + + connection_fields = [ + field + for field in self.fields + if type(self.fields[field]) == MongoengineConnectionField + ] + filterable_args = tuple( + filterfalse(connection_fields.__contains__, list(self.model._fields_ordered)) + ) for arg_name, arg in args.copy().items(): if arg_name not in filterable_args + tuple(self.filter_args.keys()): args_copy.pop(arg_name) if isinstance(info, GraphQLResolveInfo): if not info.context: info = info._replace(context=Context()) - info.context.queryset = self.get_queryset(self.model, info, required_fields, **args_copy) + info.context.queryset = self.get_queryset( + self.model, info, required_fields, **args_copy + ) # XXX: Filter nested args resolved = resolver(root, info, **args) @@ -542,28 +632,35 @@ def chained_resolver(self, resolver, is_partial, root, info, **args): args.update(resolved._query) args_copy = args.copy() for arg_name, arg in args.copy().items(): - if "." in arg_name or arg_name not in self.model._fields_ordered \ - + ('first', 'last', 'before', 'after') + tuple(self.filter_args.keys()): + if "." 
in arg_name or arg_name not in self.model._fields_ordered + ( + "first", + "last", + "before", + "after", + ) + tuple(self.filter_args.keys()): args_copy.pop(arg_name) - if arg_name == '_id' and isinstance(arg, dict): + if arg_name == "_id" and isinstance(arg, dict): operation = list(arg.keys())[0] - args_copy['pk' + operation.replace('$', '__')] = arg[operation] - if not isinstance(arg, ObjectId) and '.' in arg_name: + args_copy["pk" + operation.replace("$", "__")] = arg[operation] + if not isinstance(arg, ObjectId) and "." in arg_name: if type(arg) == dict: operation = list(arg.keys())[0] - args_copy[arg_name.replace('.', '__') + operation.replace('$', '__')] = arg[ - operation] + args_copy[ + arg_name.replace(".", "__") + operation.replace("$", "__") + ] = arg[operation] else: - args_copy[arg_name.replace('.', '__')] = arg - elif '.' in arg_name and isinstance(arg, ObjectId): - args_copy[arg_name.replace('.', '__')] = arg + args_copy[arg_name.replace(".", "__")] = arg + elif "." in arg_name and isinstance(arg, ObjectId): + args_copy[arg_name.replace(".", "__")] = arg else: operations = ["$lte", "$gte", "$ne", "$in"] if isinstance(arg, dict) and any(op in arg for op in operations): operation = list(arg.keys())[0] - args_copy[arg_name + operation.replace('$', '__')] = arg[operation] + args_copy[arg_name + operation.replace("$", "__")] = arg[operation] del args_copy[arg_name] - return self.default_resolver(root, info, required_fields, resolved=resolved, **args_copy) + return self.default_resolver( + root, info, required_fields, resolved=resolved, **args_copy + ) elif isinstance(resolved, Promise): return resolved.value else: diff --git a/graphene_mongo/fields_async.py b/graphene_mongo/fields_async.py index d061968a..7f077677 100644 --- a/graphene_mongo/fields_async.py +++ b/graphene_mongo/fields_async.py @@ -17,8 +17,13 @@ from concurrent.futures import ThreadPoolExecutor from .registry import get_global_async_registry from . 
import MongoengineConnectionField -from .utils import get_query_fields, find_skip_and_limit, \ - connection_from_iterables, ExecutorEnum, sync_to_async +from .utils import ( + get_query_fields, + find_skip_and_limit, + connection_from_iterables, + ExecutorEnum, + sync_to_async, +) import pymongo PYMONGO_VERSION = tuple(pymongo.version_tuple[:2]) @@ -65,17 +70,19 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non if not hasattr(_root, "_fields_ordered"): if isinstance(getattr(_root, field_name, []), list): args["pk__in"] = [r.id for r in getattr(_root, field_name, [])] - elif field_name in _root._fields_ordered and not (isinstance(_root._fields[field_name].field, - mongoengine.EmbeddedDocumentField) or - isinstance(_root._fields[field_name].field, - mongoengine.GenericEmbeddedDocumentField)): + elif field_name in _root._fields_ordered and not ( + isinstance(_root._fields[field_name].field, mongoengine.EmbeddedDocumentField) + or isinstance( + _root._fields[field_name].field, mongoengine.GenericEmbeddedDocumentField + ) + ): if getattr(_root, field_name, []) is not None: args["pk__in"] = [r.id for r in getattr(_root, field_name, [])] - _id = args.pop('id', None) + _id = args.pop("id", None) if _id is not None: - args['pk'] = from_global_id(_id)[-1] + args["pk"] = from_global_id(_id)[-1] iterables = [] list_length = 0 skip = 0 @@ -107,8 +114,9 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non else: count = len(items) - skip, limit, reverse = find_skip_and_limit(first=first, last=last, after=after, before=before, - count=count) + skip, limit, reverse = find_skip_and_limit( + first=first, last=last, after=after, before=before, count=count + ) if isinstance(items, QuerySet): queryset = items.clone() @@ -123,62 +131,82 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non queryset = None if limit: if reverse: - items = items[::-1][skip:skip + limit] + items = items[::-1][skip : skip + limit] else: - items = items[skip:skip + limit] + items = items[skip : skip + limit] elif skip: items = items[skip:] iterables = await sync_to_async(list)(items) list_length = len(iterables) elif callable(getattr(self.model, "objects", None)): - if _root is None or args or isinstance(getattr(_root, field_name, []), AsyncMongoengineConnectionField): + if ( + _root is None + or args + or isinstance(getattr(_root, field_name, []), AsyncMongoengineConnectionField) + ): args_copy = args.copy() for key in args.copy(): if key not in self.model._fields_ordered: args_copy.pop(key) - elif isinstance(getattr(self.model, key), - mongoengine.fields.ReferenceField) or isinstance(getattr(self.model, key), - mongoengine.fields.GenericReferenceField) or isinstance( - getattr(self.model, key), - mongoengine.fields.LazyReferenceField) or isinstance(getattr(self.model, key), - mongoengine.fields.CachedReferenceField): + elif ( + isinstance(getattr(self.model, key), mongoengine.fields.ReferenceField) + or isinstance( + getattr(self.model, key), mongoengine.fields.GenericReferenceField + ) + or isinstance( + getattr(self.model, key), mongoengine.fields.LazyReferenceField + ) + or isinstance( + getattr(self.model, key), mongoengine.fields.CachedReferenceField + ) + ): if not isinstance(args_copy[key], ObjectId): _from_global_id = from_global_id(args_copy[key])[1] if bson.objectid.ObjectId.is_valid(_from_global_id): args_copy[key] = ObjectId(_from_global_id) else: args_copy[key] = _from_global_id - elif isinstance(getattr(self.model, key), - 
mongoengine.fields.EnumField): + elif isinstance(getattr(self.model, key), mongoengine.fields.EnumField): if getattr(args_copy[key], "value", None): args_copy[key] = args_copy[key].value if PYMONGO_VERSION >= (3, 7): - count = await sync_to_async((mongoengine.get_db()[self.model._get_collection_name()]).count_documents)(args_copy) + count = await sync_to_async( + (mongoengine.get_db()[self.model._get_collection_name()]).count_documents + )(args_copy) else: - count = await sync_to_async(self.model.objects(args_copy).count, thread_sensitive=False, - executor=ThreadPoolExecutor())() + count = await sync_to_async( + self.model.objects(args_copy).count, + thread_sensitive=False, + executor=ThreadPoolExecutor(), + )() if count != 0: - skip, limit, reverse = find_skip_and_limit(first=first, after=after, last=last, before=before, - count=count) - iterables = self.get_queryset(self.model, info, required_fields, skip, limit, reverse, **args) + skip, limit, reverse = find_skip_and_limit( + first=first, after=after, last=last, before=before, count=count + ) + iterables = self.get_queryset( + self.model, info, required_fields, skip, limit, reverse, **args + ) iterables = await sync_to_async(list)(iterables) list_length = len(iterables) if isinstance(info, GraphQLResolveInfo): if not info.context: info = info._replace(context=Context()) - info.context.queryset = self.get_queryset(self.model, info, required_fields, **args) + info.context.queryset = self.get_queryset( + self.model, info, required_fields, **args + ) elif "pk__in" in args and args["pk__in"]: count = len(args["pk__in"]) - skip, limit, reverse = find_skip_and_limit(first=first, last=last, after=after, before=before, - count=count) + skip, limit, reverse = find_skip_and_limit( + first=first, last=last, after=after, before=before, count=count + ) if limit: if reverse: - args["pk__in"] = args["pk__in"][::-1][skip:skip + limit] + args["pk__in"] = args["pk__in"][::-1][skip : skip + limit] else: - args["pk__in"] = args["pk__in"][skip:skip + limit] + args["pk__in"] = args["pk__in"][skip : skip + limit] elif skip: args["pk__in"] = args["pk__in"][skip:] iterables = self.get_queryset(self.model, info, required_fields, **args) @@ -187,30 +215,39 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non if isinstance(info, GraphQLResolveInfo): if not info.context: info = info._replace(context=Context()) - info.context.queryset = self.get_queryset(self.model, info, required_fields, **args) + info.context.queryset = self.get_queryset( + self.model, info, required_fields, **args + ) elif _root is not None: field_name = to_snake_case(info.field_name) items = getattr(_root, field_name, []) count = len(items) - skip, limit, reverse = find_skip_and_limit(first=first, last=last, after=after, before=before, - count=count) + skip, limit, reverse = find_skip_and_limit( + first=first, last=last, after=after, before=before, count=count + ) if limit: if reverse: - items = items[::-1][skip:skip + limit] + items = items[::-1][skip : skip + limit] else: - items = items[skip:skip + limit] + items = items[skip : skip + limit] elif skip: items = items[skip:] iterables = items - iterables = await sync_to_async(list)( iterables) + iterables = await sync_to_async(list)(iterables) list_length = len(iterables) if count: - has_next_page = True if (0 if limit is None else limit) + (0 if skip is None else skip) < count else False + has_next_page = ( + True + if (0 if limit is None else limit) + (0 if skip is None else skip) < count + else False + ) else: 
if isinstance(queryset, QuerySet) and iterables: - has_next_page = bool(await sync_to_async(queryset(pk__gt=iterables[-1].pk).limit(1).first)()) + has_next_page = bool( + await sync_to_async(queryset(pk__gt=iterables[-1].pk).limit(1).first)() + ) else: has_next_page = False has_previous_page = True if skip else False @@ -218,19 +255,21 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non iterables = await sync_to_async(list)(iterables) iterables.reverse() skip = limit - connection = connection_from_iterables(edges=iterables, start_offset=skip, - has_previous_page=has_previous_page, - has_next_page=has_next_page, - connection_type=self.type, - edge_type=self.type.Edge, - pageinfo_type=graphene.PageInfo) + connection = connection_from_iterables( + edges=iterables, + start_offset=skip, + has_previous_page=has_previous_page, + has_next_page=has_next_page, + connection_type=self.type, + edge_type=self.type.Edge, + pageinfo_type=graphene.PageInfo, + ) connection.iterable = iterables connection.list_length = list_length return connection async def chained_resolver(self, resolver, is_partial, root, info, **args): - for key, value in dict(args).items(): if value is None: del args[key] @@ -248,20 +287,28 @@ async def chained_resolver(self, resolver, is_partial, root, info, **args): args_copy = args.copy() if not bool(args) or not is_partial: - if isinstance(self.model, mongoengine.Document) or isinstance(self.model, - mongoengine.base.metaclasses.TopLevelDocumentMetaclass): - + if isinstance(self.model, mongoengine.Document) or isinstance( + self.model, mongoengine.base.metaclasses.TopLevelDocumentMetaclass + ): from itertools import filterfalse - connection_fields = [field for field in self.fields if - type(self.fields[field]) == AsyncMongoengineConnectionField] - filterable_args = tuple(filterfalse(connection_fields.__contains__, list(self.model._fields_ordered))) + + connection_fields = [ + field + for field in self.fields + if type(self.fields[field]) == AsyncMongoengineConnectionField + ] + filterable_args = tuple( + filterfalse(connection_fields.__contains__, list(self.model._fields_ordered)) + ) for arg_name, arg in args.copy().items(): if arg_name not in filterable_args + tuple(self.filter_args.keys()): args_copy.pop(arg_name) if isinstance(info, GraphQLResolveInfo): if not info.context: info = info._replace(context=Context()) - info.context.queryset = self.get_queryset(self.model, info, required_fields, **args_copy) + info.context.queryset = self.get_queryset( + self.model, info, required_fields, **args_copy + ) # XXX: Filter nested args resolved = resolver(root, info, **args) @@ -282,28 +329,35 @@ async def chained_resolver(self, resolver, is_partial, root, info, **args): args_copy = args.copy() for arg_name, arg in args.copy().items(): if "." in arg_name or arg_name not in self.model._fields_ordered + ( - 'first', 'last', 'before', 'after') + tuple(self.filter_args.keys()): + "first", + "last", + "before", + "after", + ) + tuple(self.filter_args.keys()): args_copy.pop(arg_name) - if arg_name == '_id' and isinstance(arg, dict): + if arg_name == "_id" and isinstance(arg, dict): operation = list(arg.keys())[0] - args_copy['pk' + operation.replace('$', '__')] = arg[operation] - if not isinstance(arg, ObjectId) and '.' in arg_name: + args_copy["pk" + operation.replace("$", "__")] = arg[operation] + if not isinstance(arg, ObjectId) and "." 
in arg_name: if type(arg) == dict: operation = list(arg.keys())[0] - args_copy[arg_name.replace('.', '__') + operation.replace('$', '__')] = arg[ - operation] + args_copy[ + arg_name.replace(".", "__") + operation.replace("$", "__") + ] = arg[operation] else: - args_copy[arg_name.replace('.', '__')] = arg - elif '.' in arg_name and isinstance(arg, ObjectId): - args_copy[arg_name.replace('.', '__')] = arg + args_copy[arg_name.replace(".", "__")] = arg + elif "." in arg_name and isinstance(arg, ObjectId): + args_copy[arg_name.replace(".", "__")] = arg else: operations = ["$lte", "$gte", "$ne", "$in"] if isinstance(arg, dict) and any(op in arg for op in operations): operation = list(arg.keys())[0] - args_copy[arg_name + operation.replace('$', '__')] = arg[operation] + args_copy[arg_name + operation.replace("$", "__")] = arg[operation] del args_copy[arg_name] - return await self.default_resolver(root, info, required_fields, resolved=resolved, **args_copy) + return await self.default_resolver( + root, info, required_fields, resolved=resolved, **args_copy + ) elif isinstance(resolved, Promise): return resolved.value else: diff --git a/graphene_mongo/registry.py b/graphene_mongo/registry.py index 62fd0cea..70e88480 100644 --- a/graphene_mongo/registry.py +++ b/graphene_mongo/registry.py @@ -15,13 +15,10 @@ def register(self, cls): from .types import GrapheneMongoengineObjectTypes from .types_async import AsyncGrapheneMongoengineObjectTypes - assert (issubclass( - cls, - GrapheneMongoengineObjectTypes - ) or issubclass( - cls, - AsyncGrapheneMongoengineObjectTypes - )), 'Only Mongoengine/Async Mongoengine object types can be registered, received "{}"'.format( + assert ( + issubclass(cls, GrapheneMongoengineObjectTypes) + or issubclass(cls, AsyncGrapheneMongoengineObjectTypes) + ), 'Only Mongoengine/Async Mongoengine object types can be registered, received "{}"'.format( cls.__name__ ) assert cls._meta.registry == self, "Registry for a Model have to match." 
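
A minimal, self-contained sketch (plain dicts and illustrative names, not the library's API) of the argument translation the chained_resolver hunks above perform: Mongo-style operator dicts such as {"age": {"$gte": 18}} become mongoengine keyword filters like age__gte=18, and dotted paths become double-underscore lookups.

    def translate_args(args):
        """Rewrite raw filter args into mongoengine-style keyword filters."""
        operations = ["$lte", "$gte", "$ne", "$in"]
        translated = {}
        for name, value in args.items():
            if isinstance(value, dict) and any(op in value for op in operations):
                op = next(iter(value))
                # "$gte" -> "__gte", "address.city" -> "address__city"
                translated[name.replace(".", "__") + op.replace("$", "__")] = value[op]
            else:
                translated[name.replace(".", "__")] = value
        return translated

    # translate_args({"age": {"$gte": 18}, "address.city": "Oslo"})
    # -> {"age__gte": 18, "address__city": "Oslo"}
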
@@ -38,11 +35,12 @@ def register(self, cls): def register_enum(self, cls): from enum import EnumMeta + assert isinstance( cls, EnumMeta ), f'Only EnumMeta can be registered, received "{cls.__name__}"' - if not cls.__name__.endswith('Enum'): - name = cls.__name__ + 'Enum' + if not cls.__name__.endswith("Enum"): + name = cls.__name__ + "Enum" else: name = cls.__name__ cls.__name__ = name diff --git a/graphene_mongo/tests/conftest.py b/graphene_mongo/tests/conftest.py index 9d197285..479f29f7 100644 --- a/graphene_mongo/tests/conftest.py +++ b/graphene_mongo/tests/conftest.py @@ -1,22 +1,23 @@ import os +from datetime import datetime + import pytest -from datetime import datetime from .models import ( + AnotherChild, Article, + CellTower, + Child, + ChildRegisteredAfter, + ChildRegisteredBefore, Editor, EmbeddedArticle, + ParentWithRelationship, Player, - Reporter, - Child, - AnotherChild, ProfessorMetadata, ProfessorVector, - ChildRegisteredBefore, - ChildRegisteredAfter, - ParentWithRelationship, - CellTower, Publisher, + Reporter, ) current_dirname = os.path.dirname(os.path.abspath(__file__)) @@ -68,7 +69,7 @@ def fixtures(): last_name="Iverson", email="ai@gmail.com", awards=["2010-mvp"], - generic_references=[article1] + generic_references=[article1], ) reporter1.articles = [article1, article2] embedded_article1 = EmbeddedArticle(headline="Real", editor=editor1) @@ -82,13 +83,15 @@ def fixtures(): player1 = Player( first_name="Michael", last_name="Jordan", - articles=[article1, article2]) + articles=[article1, article2], + ) player1.save() player2 = Player( first_name="Magic", last_name="Johnson", opponent=player1, - articles=[article3]) + articles=[article3], + ) player2.save() player3 = Player(first_name="Larry", last_name="Bird", players=[player1, player2]) player3.save() @@ -165,7 +168,9 @@ def fixtures(): child4.save() parent = ParentWithRelationship( - name="Yui", before_child=[child3], after_child=[child4] + name="Yui", + before_child=[child3], + after_child=[child4], ) parent.save() diff --git a/graphene_mongo/tests/models.py b/graphene_mongo/tests/models.py index f9c2cae6..871b5c9e 100644 --- a/graphene_mongo/tests/models.py +++ b/graphene_mongo/tests/models.py @@ -1,5 +1,6 @@ -import mongoengine from datetime import datetime + +import mongoengine import mongomock from mongomock import gridfs @@ -80,7 +81,7 @@ class Reporter(mongoengine.Document): awards = mongoengine.ListField(mongoengine.StringField()) articles = mongoengine.ListField(mongoengine.ReferenceField(Article)) embedded_articles = mongoengine.ListField( - mongoengine.EmbeddedDocumentField(EmbeddedArticle) + mongoengine.EmbeddedDocumentField(EmbeddedArticle), ) embedded_list_articles = mongoengine.EmbeddedDocumentListField(EmbeddedArticle) generic_reference = mongoengine.GenericReferenceField(choices=[Article, Editor], required=True) @@ -116,14 +117,12 @@ class CellTower(mongoengine.Document): class Child(Parent): - meta = {"collection": "test_parent"} baz = mongoengine.StringField() loc = mongoengine.PointField() class AnotherChild(Parent): - meta = {"collection": "test_parent"} qux = mongoengine.StringField() loc = mongoengine.PointField() @@ -146,10 +145,10 @@ class ProfessorVector(mongoengine.Document): class ParentWithRelationship(mongoengine.Document): meta = {"collection": "test_parent_reference"} before_child = mongoengine.ListField( - mongoengine.ReferenceField("ChildRegisteredBefore") + mongoengine.ReferenceField("ChildRegisteredBefore"), ) after_child = mongoengine.ListField( - 
mongoengine.ReferenceField("ChildRegisteredAfter") + mongoengine.ReferenceField("ChildRegisteredAfter"), ) name = mongoengine.StringField() diff --git a/graphene_mongo/tests/test_converter.py b/graphene_mongo/tests/test_converter.py index 7109bb8f..b111e4a6 100644 --- a/graphene_mongo/tests/test_converter.py +++ b/graphene_mongo/tests/test_converter.py @@ -91,9 +91,7 @@ def test_should_dict_convert_json(): def test_should_map_convert_json(): - assert_conversion( - mongoengine.MapField, graphene.JSONString, field=mongoengine.StringField() - ) + assert_conversion(mongoengine.MapField, graphene.JSONString, field=mongoengine.StringField()) def test_should_point_convert_field(): @@ -127,15 +125,11 @@ def test_should_file_convert_field(): def test_should_field_convert_list(): - assert_conversion( - mongoengine.ListField, graphene.List, field=mongoengine.StringField() - ) + assert_conversion(mongoengine.ListField, graphene.List, field=mongoengine.StringField()) def test_should_geo_convert_list(): - assert_conversion( - mongoengine.GeoPointField, graphene.List, field=mongoengine.FloatField() - ) + assert_conversion(mongoengine.GeoPointField, graphene.List, field=mongoengine.FloatField()) def test_should_reference_convert_dynamic(): @@ -144,9 +138,7 @@ class Meta: model = Editor interfaces = (graphene.Node,) - dynamic_field = convert_mongoengine_field( - EmbeddedArticle._fields["editor"], E._meta.registry - ) + dynamic_field = convert_mongoengine_field(EmbeddedArticle._fields["editor"], E._meta.registry) assert isinstance(dynamic_field, graphene.Dynamic) graphene_type = dynamic_field.get_type() assert isinstance(graphene_type, graphene.Field) @@ -159,9 +151,7 @@ class Meta: model = Publisher interfaces = (graphene.Node,) - dynamic_field = convert_mongoengine_field( - Editor._fields["company"], P._meta.registry - ) + dynamic_field = convert_mongoengine_field(Editor._fields["company"], P._meta.registry) assert isinstance(dynamic_field, graphene.Dynamic) graphene_type = dynamic_field.get_type() @@ -209,9 +199,7 @@ class A(MongoengineObjectType): class Meta: model = Article - graphene_field = convert_mongoengine_field( - Reporter._fields["articles"], A._meta.registry - ) + graphene_field = convert_mongoengine_field(Reporter._fields["articles"], A._meta.registry) assert isinstance(graphene_field, graphene.List) dynamic_field = graphene_field.get_type() assert dynamic_field._of_type == A @@ -270,17 +258,13 @@ class Meta: model = Player interfaces = (graphene.Node,) - dynamic_field = convert_mongoengine_field( - Player._fields["opponent"], P._meta.registry - ) + dynamic_field = convert_mongoengine_field(Player._fields["opponent"], P._meta.registry) assert isinstance(dynamic_field, graphene.Dynamic) graphene_type = dynamic_field.get_type() assert isinstance(graphene_type, graphene.Field) assert graphene_type.type == P - graphene_field = convert_mongoengine_field( - Player._fields["players"], P._meta.registry - ) + graphene_field = convert_mongoengine_field(Player._fields["players"], P._meta.registry) assert isinstance(graphene_field, MongoengineConnectionField) @@ -293,9 +277,7 @@ class P(MongoengineObjectType): class Meta: model = Player - graphene_field = convert_mongoengine_field( - Player._fields["players"], P._meta.registry - ) + graphene_field = convert_mongoengine_field(Player._fields["players"], P._meta.registry) assert isinstance(graphene_field, graphene.List) dynamic_field = graphene_field.get_type() assert dynamic_field._of_type == P @@ -306,27 +288,16 @@ class A(MongoengineObjectType): class 
Meta: model = Article - headline_field = convert_mongoengine_field( - Article._fields["headline"], A._meta.registry - ) + headline_field = convert_mongoengine_field(Article._fields["headline"], A._meta.registry) assert headline_field.kwargs["description"] == "The article headline." - pubDate_field = convert_mongoengine_field( - Article._fields["pub_date"], A._meta.registry - ) - assert ( - pubDate_field.kwargs["description"] - == "Publication Date\nThe date of first press." - ) + pubDate_field = convert_mongoengine_field(Article._fields["pub_date"], A._meta.registry) + assert pubDate_field.kwargs["description"] == "Publication Date\nThe date of first press." - firstName_field = convert_mongoengine_field( - Editor._fields["first_name"], A._meta.registry - ) + firstName_field = convert_mongoengine_field(Editor._fields["first_name"], A._meta.registry) assert firstName_field.kwargs["description"] == "Editor's first name.\n(fname)" - metadata_field = convert_mongoengine_field( - Editor._fields["metadata"], A._meta.registry - ) + metadata_field = convert_mongoengine_field(Editor._fields["metadata"], A._meta.registry) assert metadata_field.kwargs["description"] == "Arbitrary metadata." @@ -339,9 +310,7 @@ class E(MongoengineObjectType): class Meta: model = Editor - editor_field = convert_mongoengine_field( - Article._fields["editor"], A._meta.registry - ).get_type() + editor_field = convert_mongoengine_field(Article._fields["editor"], A._meta.registry).get_type() assert editor_field.description == "An Editor of a publication." diff --git a/graphene_mongo/tests/test_inputs.py b/graphene_mongo/tests/test_inputs.py index 2006b1d6..9f792ef6 100644 --- a/graphene_mongo/tests/test_inputs.py +++ b/graphene_mongo/tests/test_inputs.py @@ -64,11 +64,9 @@ async def mutate(self, info, id, editor): return UpdateEditor(editor=editor_to_update) class Query(graphene.ObjectType): - node = Node.Field() class Mutation(graphene.ObjectType): - update_editor = UpdateEditor.Field() query = """ diff --git a/graphene_mongo/tests/test_query.py b/graphene_mongo/tests/test_query.py index 1a1f302f..85a15056 100644 --- a/graphene_mongo/tests/test_query.py +++ b/graphene_mongo/tests/test_query.py @@ -199,9 +199,7 @@ async def resolve_all_players(self, *args, **kwargs): @pytest.mark.asyncio async def test_should_query_with_embedded_document(fixtures): class Query(graphene.ObjectType): - professor_vector = graphene.Field( - types.ProfessorVectorType, id=graphene.String() - ) + professor_vector = graphene.Field(types.ProfessorVectorType, id=graphene.String()) async def resolve_professor_vector(self, info, id): return models.ProfessorVector.objects(metadata__id=id).first() @@ -217,9 +215,7 @@ async def resolve_professor_vector(self, info, id): } """ - expected = { - "professorVector": {"vec": [1.0, 2.3], "metadata": {"firstName": "Steven"}} - } + expected = {"professorVector": {"vec": [1.0, 2.3], "metadata": {"firstName": "Steven"}}} schema = graphene.Schema(query=Query, types=[types.ProfessorVectorType]) result = await schema.execute_async(query) assert not result.errors diff --git a/graphene_mongo/tests/test_relay_query.py b/graphene_mongo/tests/test_relay_query.py index dcc1897d..dc09b4e3 100644 --- a/graphene_mongo/tests/test_relay_query.py +++ b/graphene_mongo/tests/test_relay_query.py @@ -1,10 +1,9 @@ -import os -import json import base64 +import json +import os import graphene import pytest - from graphene.relay import Node from graphql_relay.node.node import to_global_id @@ -216,10 +215,7 @@ class 
ArticleLoader(DataLoader): def batch_load_fn(self, instances): queryset = models.Article.objects(editor__in=instances) return Promise.resolve( - [ - [a for a in queryset if a.editor.id == instance.id] - for instance in instances - ] + [[a for a in queryset if a.editor.id == instance.id] for instance in instances] ) article_loader = ArticleLoader() @@ -370,9 +366,7 @@ class Query(graphene.ObjectType): } """ expected = { - "articles": { - "edges": [{"node": {"headline": "Hello", "editor": {"firstName": "Penny"}}}] - } + "articles": {"edges": [{"node": {"headline": "Hello", "editor": {"firstName": "Penny"}}}]} } schema = graphene.Schema(query=Query) result = await schema.execute_async(query) @@ -457,9 +451,9 @@ class Query(graphene.ObjectType): "genericReferences": [ { "__typename": "ArticleNode", - "headline": "Hello" + "headline": "Hello", } - ] + ], } } ] @@ -759,14 +753,10 @@ class Query(graphene.ObjectType): { "node": { "beforeChild": { - "edges": [ - {"node": {"name": "Akari", "parent": {"name": "Yui"}}} - ] + "edges": [{"node": {"name": "Akari", "parent": {"name": "Yui"}}}] }, "afterChild": { - "edges": [ - {"node": {"name": "Kyouko", "parent": {"name": "Yui"}}} - ] + "edges": [{"node": {"name": "Kyouko", "parent": {"name": "Yui"}}}] }, } } @@ -800,9 +790,7 @@ class Query(graphene.ObjectType): """ expected = { "professors": { - "edges": [ - {"node": {"vec": [1.0, 2.3], "metadata": {"firstName": "Steven"}}} - ] + "edges": [{"node": {"vec": [1.0, 2.3], "metadata": {"firstName": "Steven"}}}] } } schema = graphene.Schema(query=Query) @@ -860,9 +848,7 @@ def get_queryset(model, info, **args): class Query(graphene.ObjectType): node = Node.Field() - articles = MongoengineConnectionField( - nodes.ArticleNode, get_queryset=get_queryset - ) + articles = MongoengineConnectionField(nodes.ArticleNode, get_queryset=get_queryset) query = """ query ArticlesQuery { @@ -926,9 +912,7 @@ class Query(graphene.ObjectType): result = await schema.execute_async(query) assert not result.errors - assert json.dumps(result.data, sort_keys=True) == json.dumps( - expected, sort_keys=True - ) + assert json.dumps(result.data, sort_keys=True) == json.dumps(expected, sort_keys=True) @pytest.mark.asyncio @@ -990,9 +974,7 @@ class Query(graphene.ObjectType): result = await schema.execute_async(query) assert not result.errors - assert json.dumps(result.data, sort_keys=True) == json.dumps( - expected, sort_keys=True - ) + assert json.dumps(result.data, sort_keys=True) == json.dumps(expected, sort_keys=True) @pytest.mark.asyncio @@ -1020,22 +1002,27 @@ class Query(graphene.ObjectType): """ expected = { "players": { - "edges": [{ - "node": { - "firstName": "Michael", - "articles": { - "edges": [{ - "node": { - "headline": "Hello" - } - }, { - "node": { - "headline": "World" - } - }] + "edges": [ + { + "node": { + "firstName": "Michael", + "articles": { + "edges": [ + { + "node": { + "headline": "Hello", + } + }, + { + "node": { + "headline": "World", + } + }, + ] + }, } } - }] + ] } } schema = graphene.Schema(query=Query) @@ -1071,8 +1058,8 @@ class Query(graphene.ObjectType): """.format(larry_relay_id=larry_relay_id) expected = { - 'players': { - 'edges': [] + "players": { + "edges": [], } } schema = graphene.Schema(query=Query) diff --git a/graphene_mongo/tests/test_relay_query_async.py b/graphene_mongo/tests/test_relay_query_async.py index 83b62ba4..3b9a542a 100644 --- a/graphene_mongo/tests/test_relay_query_async.py +++ b/graphene_mongo/tests/test_relay_query_async.py @@ -1,9 +1,9 @@ -import os -import json import base64 
+import json +import os + import graphene import pytest - from graphene.relay import Node from graphql_relay.node.node import to_global_id @@ -214,10 +214,7 @@ class ArticleLoader(DataLoader): def batch_load_fn(self, instances): queryset = models.Article.objects(editor__in=instances) return Promise.resolve( - [ - [a for a in queryset if a.editor.id == instance.id] - for instance in instances - ] + [[a for a in queryset if a.editor.id == instance.id] for instance in instances] ) article_loader = ArticleLoader() @@ -368,9 +365,7 @@ class Query(graphene.ObjectType): } """ expected = { - "articles": { - "edges": [{"node": {"headline": "Hello", "editor": {"firstName": "Penny"}}}] - } + "articles": {"edges": [{"node": {"headline": "Hello", "editor": {"firstName": "Penny"}}}]} } schema = graphene.Schema(query=Query) result = await schema.execute_async(query) @@ -453,11 +448,8 @@ class Query(graphene.ObjectType): "firstName": "Allen", "awards": ["2010-mvp"], "genericReferences": [ - { - "__typename": "ArticleAsyncNode", - "headline": "Hello" - } - ] + {"__typename": "ArticleAsyncNode", "headline": "Hello"} + ], } } ] @@ -758,14 +750,10 @@ class Query(graphene.ObjectType): { "node": { "beforeChild": { - "edges": [ - {"node": {"name": "Akari", "parent": {"name": "Yui"}}} - ] + "edges": [{"node": {"name": "Akari", "parent": {"name": "Yui"}}}] }, "afterChild": { - "edges": [ - {"node": {"name": "Kyouko", "parent": {"name": "Yui"}}} - ] + "edges": [{"node": {"name": "Kyouko", "parent": {"name": "Yui"}}}] }, } } @@ -799,9 +787,7 @@ class Query(graphene.ObjectType): """ expected = { "professors": { - "edges": [ - {"node": {"vec": [1.0, 2.3], "metadata": {"firstName": "Steven"}}} - ] + "edges": [{"node": {"vec": [1.0, 2.3], "metadata": {"firstName": "Steven"}}}] } } schema = graphene.Schema(query=Query) @@ -925,9 +911,7 @@ class Query(graphene.ObjectType): result = await schema.execute_async(query) assert not result.errors - assert json.dumps(result.data, sort_keys=True) == json.dumps( - expected, sort_keys=True - ) + assert json.dumps(result.data, sort_keys=True) == json.dumps(expected, sort_keys=True) @pytest.mark.asyncio @@ -989,9 +973,7 @@ class Query(graphene.ObjectType): result = await schema.execute_async(query) assert not result.errors - assert json.dumps(result.data, sort_keys=True) == json.dumps( - expected, sort_keys=True - ) + assert json.dumps(result.data, sort_keys=True) == json.dumps(expected, sort_keys=True) @pytest.mark.asyncio @@ -1019,22 +1001,27 @@ class Query(graphene.ObjectType): """ expected = { "players": { - "edges": [{ - "node": { - "firstName": "Michael", - "articles": { - "edges": [{ - "node": { - "headline": "Hello" - } - }, { - "node": { - "headline": "World" - } - }] + "edges": [ + { + "node": { + "firstName": "Michael", + "articles": { + "edges": [ + { + "node": { + "headline": "Hello", + } + }, + { + "node": { + "headline": "World", + } + }, + ] + }, } } - }] + ] } } schema = graphene.Schema(query=Query) @@ -1069,11 +1056,7 @@ class Query(graphene.ObjectType): }} """.format(larry_relay_id=larry_relay_id) - expected = { - 'players': { - 'edges': [] - } - } + expected = {"players": {"edges": []}} schema = graphene.Schema(query=Query) result = await schema.execute_async(query) diff --git a/graphene_mongo/tests/test_types.py b/graphene_mongo/tests/test_types.py index 3e9ecbd1..c1217723 100644 --- a/graphene_mongo/tests/test_types.py +++ b/graphene_mongo/tests/test_types.py @@ -13,7 +13,6 @@ class Human(MongoengineObjectType): - pub_date = Int() class Meta: @@ -135,6 +134,7 @@ 
class A(MongoengineObjectType): class Meta: model = Article order_by = "some_order_by_statement" + assert "some_order_by_statement" not in list(A._meta.fields.keys()) diff --git a/graphene_mongo/tests/test_utils.py b/graphene_mongo/tests/test_utils.py index 3aef4848..e8b6c5bf 100644 --- a/graphene_mongo/tests/test_utils.py +++ b/graphene_mongo/tests/test_utils.py @@ -1,8 +1,9 @@ -from ..utils import get_model_fields, is_valid_mongoengine_model, get_query_fields -from .models import Article, Reporter, Child -from . import types import graphene +from . import types +from .models import Article, Child, Reporter +from ..utils import get_model_fields, get_query_fields, is_valid_mongoengine_model + def test_get_model_fields_no_duplication(): reporter_fields = get_model_fields(Reporter) @@ -82,22 +83,22 @@ def resolve_children(self, info, *args, **kwargs): schema.execute(query) assert get_query_fields(test_get_query_fields.child_info) == { - 'bar': {}, - 'loc': { - 'type': {}, - 'coordinates': {} - } + "bar": {}, + "loc": { + "type": {}, + "coordinates": {}, + }, } assert get_query_fields(test_get_query_fields.children_info) == { - 'ChildType': { - 'baz': {}, - 'loc': { - 'type': {}, - 'coordinates': {} - } + "ChildType": { + "baz": {}, + "loc": { + "type": {}, + "coordinates": {}, + }, + }, + "AnotherChildType": { + "qux": {}, }, - 'AnotherChildType': { - 'qux': {} - } } diff --git a/graphene_mongo/types.py b/graphene_mongo/types.py index 34d364cf..4eeda046 100644 --- a/graphene_mongo/types.py +++ b/graphene_mongo/types.py @@ -17,8 +17,14 @@ from .utils import get_model_fields, is_valid_mongoengine_model, get_query_fields, ExecutorEnum -def construct_fields(model, registry, only_fields, exclude_fields, non_required_fields, - executor: ExecutorEnum = ExecutorEnum.SYNC): +def construct_fields( + model, + registry, + only_fields, + exclude_fields, + non_required_fields, + executor: ExecutorEnum = ExecutorEnum.SYNC, +): """ Args: model (mongoengine.Document): @@ -47,9 +53,9 @@ def construct_fields(model, registry, only_fields, exclude_fields, non_required_ # Take care of list of self-reference. 
document_type_obj = field.field.__dict__.get("document_type_obj", None) if ( - document_type_obj == model._class_name - or isinstance(document_type_obj, model) - or document_type_obj == model + document_type_obj == model._class_name + or isinstance(document_type_obj, model) + or document_type_obj == model ): self_referenced[name] = field continue @@ -57,8 +63,8 @@ def construct_fields(model, registry, only_fields, exclude_fields, non_required_ if not converted: continue else: - if name in non_required_fields and 'required' in converted.kwargs: - converted.kwargs['required'] = False + if name in non_required_fields and "required" in converted.kwargs: + converted.kwargs["required"] = False fields[name] = converted return fields, self_referenced @@ -77,7 +83,6 @@ def construct_self_referenced_fields(self_referenced, registry, executor=Executo def create_graphene_generic_class(object_type, option_type): class MongoengineGenericObjectTypeOptions(option_type): - model = None registry = None # type: Registry connection = None @@ -88,26 +93,25 @@ class MongoengineGenericObjectTypeOptions(option_type): class GrapheneMongoengineGenericType(object_type): @classmethod def __init_subclass_with_meta__( - cls, - model=None, - registry=None, - skip_registry=False, - only_fields=(), - required_fields=(), - exclude_fields=(), - non_required_fields=(), - filter_fields=None, - non_filter_fields=(), - connection=None, - connection_class=None, - use_connection=None, - connection_field_class=None, - interfaces=(), - _meta=None, - order_by=None, - **options + cls, + model=None, + registry=None, + skip_registry=False, + only_fields=(), + required_fields=(), + exclude_fields=(), + non_required_fields=(), + filter_fields=None, + non_filter_fields=(), + connection=None, + connection_class=None, + use_connection=None, + connection_field_class=None, + interfaces=(), + _meta=None, + order_by=None, + **options, ): - assert is_valid_mongoengine_model(model), ( "The attribute model in {}.Meta must be a valid Mongoengine Model. " 'Received "{}" instead.' @@ -127,13 +131,9 @@ def __init_subclass_with_meta__( converted_fields, self_referenced = construct_fields( model, registry, only_fields, exclude_fields, non_required_fields ) - mongoengine_fields = yank_fields_from_attrs( - converted_fields, _as=graphene.Field - ) + mongoengine_fields = yank_fields_from_attrs(converted_fields, _as=graphene.Field) if use_connection is None and interfaces: - use_connection = any( - (issubclass(interface, Node) for interface in interfaces) - ) + use_connection = any((issubclass(interface, Node) for interface in interfaces)) if use_connection and not connection: # We create the connection automatically @@ -141,7 +141,7 @@ def __init_subclass_with_meta__( connection_class = Connection connection = connection_class.create_type( - "{}Connection".format(options.get('name') or cls.__name__), node=cls + "{}Connection".format(options.get("name") or cls.__name__), node=cls ) if connection is not None: @@ -187,9 +187,7 @@ def __init_subclass_with_meta__( if not skip_registry: registry.register(cls) # Notes: Take care list of self-reference fields. 
- converted_fields = construct_self_referenced_fields( - self_referenced, registry - ) + converted_fields = construct_self_referenced_fields(self_referenced, registry) if converted_fields: mongoengine_fields = yank_fields_from_attrs( converted_fields, _as=graphene.Field @@ -206,12 +204,10 @@ def rescan_fields(cls): cls._meta.registry, cls._meta.only_fields, cls._meta.exclude_fields, - cls._meta.non_required_fields + cls._meta.non_required_fields, ) - mongoengine_fields = yank_fields_from_attrs( - converted_fields, _as=graphene.Field - ) + mongoengine_fields = yank_fields_from_attrs(converted_fields, _as=graphene.Field) # The initial scan should take precedence for field in mongoengine_fields: @@ -243,8 +239,11 @@ async def get_node(cls, info, id): if to_snake_case(field) in cls._meta.model._fields_ordered: required_fields.append(to_snake_case(field)) required_fields = list(set(required_fields)) - return await sync_to_async(cls._meta.model.objects.no_dereference().only(*required_fields).get, - thread_sensitive=False, executor=ThreadPoolExecutor())(pk=id) + return await sync_to_async( + cls._meta.model.objects.no_dereference().only(*required_fields).get, + thread_sensitive=False, + executor=ThreadPoolExecutor(), + )(pk=id) def resolve_id(self, info): return str(self.id) @@ -252,9 +251,18 @@ def resolve_id(self, info): return GrapheneMongoengineGenericType, MongoengineGenericObjectTypeOptions -MongoengineObjectType, MongoengineObjectTypeOptions = create_graphene_generic_class(ObjectType, ObjectTypeOptions) -MongoengineInterfaceType, MongoengineInterfaceTypeOptions = create_graphene_generic_class(Interface, InterfaceOptions) -MongoengineInputType, MongoengineInputTypeOptions = create_graphene_generic_class(InputObjectType, - InputObjectTypeOptions) - -GrapheneMongoengineObjectTypes = (MongoengineObjectType, MongoengineInputType, MongoengineInterfaceType) +MongoengineObjectType, MongoengineObjectTypeOptions = create_graphene_generic_class( + ObjectType, ObjectTypeOptions +) +MongoengineInterfaceType, MongoengineInterfaceTypeOptions = create_graphene_generic_class( + Interface, InterfaceOptions +) +MongoengineInputType, MongoengineInputTypeOptions = create_graphene_generic_class( + InputObjectType, InputObjectTypeOptions +) + +GrapheneMongoengineObjectTypes = ( + MongoengineObjectType, + MongoengineInputType, + MongoengineInterfaceType, +) diff --git a/graphene_mongo/types_async.py b/graphene_mongo/types_async.py index 1fcb6b51..033cd9e1 100644 --- a/graphene_mongo/types_async.py +++ b/graphene_mongo/types_async.py @@ -3,21 +3,19 @@ from asgiref.sync import sync_to_async from graphene import InputObjectType from graphene.relay import Connection, Node -from graphene.types.objecttype import ObjectType, ObjectTypeOptions from graphene.types.interface import Interface, InterfaceOptions +from graphene.types.objecttype import ObjectType, ObjectTypeOptions from graphene.types.utils import yank_fields_from_attrs from graphene.utils.str_converters import to_snake_case -from graphene_mongo import AsyncMongoengineConnectionField -from .registry import Registry, get_global_async_registry, \ - get_inputs_async_registry +from graphene_mongo import AsyncMongoengineConnectionField +from .registry import Registry, get_global_async_registry, get_inputs_async_registry from .types import construct_fields, construct_self_referenced_fields -from .utils import is_valid_mongoengine_model, get_query_fields, ExecutorEnum +from .utils import ExecutorEnum, get_query_fields, is_valid_mongoengine_model def 
create_graphene_generic_class_async(object_type, option_type): class AsyncMongoengineGenericObjectTypeOptions(option_type): - model = None registry = None # type: Registry connection = None @@ -28,26 +26,25 @@ class AsyncMongoengineGenericObjectTypeOptions(option_type): class AsyncGrapheneMongoengineGenericType(object_type): @classmethod def __init_subclass_with_meta__( - cls, - model=None, - registry=None, - skip_registry=False, - only_fields=(), - required_fields=(), - exclude_fields=(), - non_required_fields=(), - filter_fields=None, - non_filter_fields=(), - connection=None, - connection_class=None, - use_connection=None, - connection_field_class=None, - interfaces=(), - _meta=None, - order_by=None, - **options + cls, + model=None, + registry=None, + skip_registry=False, + only_fields=(), + required_fields=(), + exclude_fields=(), + non_required_fields=(), + filter_fields=None, + non_filter_fields=(), + connection=None, + connection_class=None, + use_connection=None, + connection_field_class=None, + interfaces=(), + _meta=None, + order_by=None, + **options, ): - assert is_valid_mongoengine_model(model), ( "The attribute model in {}.Meta must be a valid Mongoengine Model. " 'Received "{}" instead.' @@ -65,15 +62,16 @@ def __init_subclass_with_meta__( 'Registry({}), received "{}".' ).format(object_type, cls.__name__, registry) converted_fields, self_referenced = construct_fields( - model, registry, only_fields, exclude_fields, non_required_fields, ExecutorEnum.ASYNC - ) - mongoengine_fields = yank_fields_from_attrs( - converted_fields, _as=graphene.Field + model, + registry, + only_fields, + exclude_fields, + non_required_fields, + ExecutorEnum.ASYNC, ) + mongoengine_fields = yank_fields_from_attrs(converted_fields, _as=graphene.Field) if use_connection is None and interfaces: - use_connection = any( - (issubclass(interface, Node) for interface in interfaces) - ) + use_connection = any((issubclass(interface, Node) for interface in interfaces)) if use_connection and not connection: # We create the connection automatically @@ -81,7 +79,7 @@ def __init_subclass_with_meta__( connection_class = Connection connection = connection_class.create_type( - "{}Connection".format(options.get('name') or cls.__name__), node=cls + "{}Connection".format(options.get("name") or cls.__name__), node=cls ) if connection is not None: @@ -146,12 +144,11 @@ def rescan_fields(cls): cls._meta.registry, cls._meta.only_fields, cls._meta.exclude_fields, - cls._meta.non_required_fields, ExecutorEnum.ASYNC + cls._meta.non_required_fields, + ExecutorEnum.ASYNC, ) - mongoengine_fields = yank_fields_from_attrs( - converted_fields, _as=graphene.Field - ) + mongoengine_fields = yank_fields_from_attrs(converted_fields, _as=graphene.Field) # The initial scan should take precedence for field in mongoengine_fields: @@ -183,7 +180,9 @@ async def get_node(cls, info, id): if to_snake_case(field) in cls._meta.model._fields_ordered: required_fields.append(to_snake_case(field)) required_fields = list(set(required_fields)) - return await sync_to_async(cls._meta.model.objects.no_dereference().only(*required_fields).get)(pk=id) + return await sync_to_async( + cls._meta.model.objects.no_dereference().only(*required_fields).get + )(pk=id) def resolve_id(self, info): return str(self.id) @@ -191,9 +190,13 @@ def resolve_id(self, info): return AsyncGrapheneMongoengineGenericType, AsyncMongoengineGenericObjectTypeOptions -AsyncMongoengineObjectType, AsyncMongoengineObjectTypeOptions = create_graphene_generic_class_async(ObjectType, - 
ObjectTypeOptions) -AsyncMongoengineInterfaceType, MongoengineInterfaceTypeOptions = create_graphene_generic_class_async(Interface, - InterfaceOptions) +AsyncMongoengineObjectType, AsyncMongoengineObjectTypeOptions = create_graphene_generic_class_async( + ObjectType, ObjectTypeOptions +) + +( + AsyncMongoengineInterfaceType, + MongoengineInterfaceTypeOptions, +) = create_graphene_generic_class_async(Interface, InterfaceOptions) AsyncGrapheneMongoengineObjectTypes = (AsyncMongoengineObjectType, AsyncMongoengineInterfaceType) diff --git a/graphene_mongo/utils.py b/graphene_mongo/utils.py index 0d587b70..63438aba 100644 --- a/graphene_mongo/utils.py +++ b/graphene_mongo/utils.py @@ -34,8 +34,8 @@ def get_model_reference_fields(model, excluding=None): attributes = dict() for attr_name, attr in model._fields.items(): if attr_name in excluding or not isinstance( - attr, - (mongoengine.fields.ReferenceField, mongoengine.fields.LazyReferenceField), + attr, + (mongoengine.fields.ReferenceField, mongoengine.fields.LazyReferenceField), ): continue attributes[attr_name] = attr @@ -44,8 +44,7 @@ def get_model_reference_fields(model, excluding=None): def is_valid_mongoengine_model(model): return inspect.isclass(model) and ( - issubclass(model, mongoengine.Document) - or issubclass(model, mongoengine.EmbeddedDocument) + issubclass(model, mongoengine.Document) or issubclass(model, mongoengine.EmbeddedDocument) ) @@ -76,9 +75,7 @@ def import_single_dispatch(): def get_type_for_document(schema, document): types = schema.types.values() for _type in types: - type_document = hasattr(_type, "_meta") and getattr( - _type._meta, "document", None - ) + type_document = hasattr(_type, "_meta") and getattr(_type._meta, "document", None) if document == type_document: return _type @@ -137,22 +134,19 @@ def collect_query_fields(node, fragments): field = {} selection_set = None if type(node) == dict: - selection_set = node.get('selection_set') + selection_set = node.get("selection_set") else: selection_set = node.selection_set if selection_set: for leaf in selection_set.selections: - if leaf.kind == 'field': - field.update({ - leaf.name.value: collect_query_fields(leaf, fragments) - }) - elif leaf.kind == 'fragment_spread': - field.update(collect_query_fields(fragments[leaf.name.value], - fragments)) - elif leaf.kind == 'inline_fragment': - field.update({ - leaf.type_condition.name.value: collect_query_fields(leaf, fragments) - }) + if leaf.kind == "field": + field.update({leaf.name.value: collect_query_fields(leaf, fragments)}) + elif leaf.kind == "fragment_spread": + field.update(collect_query_fields(fragments[leaf.name.value], fragments)) + elif leaf.kind == "inline_fragment": + field.update( + {leaf.type_condition.name.value: collect_query_fields(leaf, fragments)} + ) return field @@ -238,13 +232,12 @@ def find_skip_and_limit(first, last, after, before, count=None): return skip, limit, reverse -def connection_from_iterables(edges, start_offset, has_previous_page, has_next_page, connection_type, - edge_type, - pageinfo_type): +def connection_from_iterables( + edges, start_offset, has_previous_page, has_next_page, connection_type, edge_type, pageinfo_type +): edges_items = [ edge_type( - node=node, - cursor=offset_to_cursor((0 if start_offset is None else start_offset) + i) + node=node, cursor=offset_to_cursor((0 if start_offset is None else start_offset) + i) ) for i, node in enumerate(edges) ] @@ -284,6 +277,4 @@ def sync_to_async( """ if executor is None: executor = ThreadPoolExecutor() - return 
asgiref_sync_to_async(
-        func=func, thread_sensitive=thread_sensitive, executor=executor
-    )
+    return asgiref_sync_to_async(func=func, thread_sensitive=thread_sensitive, executor=executor)
diff --git a/poetry.lock b/poetry.lock
index b3690cc7..f7d6dc03 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,10 +1,9 @@
-# This file is automatically @generated by Poetry 1.4.2 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand.
 
 [[package]]
 name = "aniso8601"
 version = "9.0.1"
 description = "A library for parsing ISO 8601 strings."
-category = "main"
 optional = false
 python-versions = "*"
 files = [
@@ -19,7 +18,6 @@ dev = ["black", "coverage", "isort", "pre-commit", "pyenchant", "pylint"]
 name = "asgiref"
 version = "3.6.0"
 description = "ASGI specs, helper code, and adapters"
-category = "main"
 optional = false
 python-versions = ">=3.7"
 files = [
@@ -37,7 +35,6 @@ tests = ["mypy (>=0.800)", "pytest", "pytest-asyncio"]
 name = "colorama"
 version = "0.4.6"
 description = "Cross-platform colored terminal text."
-category = "dev"
 optional = false
 python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
 files = [
@@ -49,7 +46,6 @@ files = [
 name = "coverage"
 version = "7.2.3"
 description = "Code coverage measurement for Python"
-category = "dev"
 optional = false
 python-versions = ">=3.7"
 files = [
@@ -116,7 +112,6 @@ toml = ["tomli"]
 name = "dnspython"
 version = "2.3.0"
 description = "DNS toolkit"
-category = "main"
 optional = false
 python-versions = ">=3.7,<4.0"
 files = [
@@ -137,7 +132,6 @@ wmi = ["wmi (>=1.5.1,<2.0.0)"]
 name = "exceptiongroup"
 version = "1.1.1"
 description = "Backport of PEP 654 (exception groups)"
-category = "dev"
 optional = false
 python-versions = ">=3.7"
 files = [
@@ -152,7 +146,6 @@ test = ["pytest (>=6)"]
 name = "flake8"
 version = "5.0.4"
 description = "the modular source code checker: pep8 pyflakes and co"
-category = "dev"
 optional = false
 python-versions = ">=3.6.1"
 files = [
@@ -170,7 +163,6 @@ pyflakes = ">=2.5.0,<2.6.0"
 name = "graphene"
 version = "3.2.2"
 description = "GraphQL Framework for Python"
-category = "main"
 optional = false
 python-versions = "*"
 files = [
@@ -191,7 +183,6 @@ test = ["coveralls (>=3.3,<4)", "iso8601 (>=1,<2)", "mock (>=4,<5)", "pytest (>=
 name = "graphql-core"
 version = "3.2.3"
 description = "GraphQL implementation for Python, a port of GraphQL.js, the JavaScript reference implementation for GraphQL."
-category = "main"
 optional = false
 python-versions = ">=3.6,<4"
 files = [
@@ -206,7 +197,6 @@ typing-extensions = {version = ">=4.2,<5", markers = "python_version < \"3.8\""}
 name = "graphql-relay"
 version = "3.2.0"
 description = "Relay library for graphql-core"
-category = "main"
 optional = false
 python-versions = ">=3.6,<4"
 files = [
@@ -222,7 +212,6 @@ typing-extensions = {version = ">=4.1,<5", markers = "python_version < \"3.8\""}
 name = "importlib-metadata"
 version = "4.2.0"
 description = "Read metadata from Python packages"
-category = "dev"
 optional = false
 python-versions = ">=3.6"
 files = [
@@ -242,7 +231,6 @@ testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pep517",
 name = "iniconfig"
 version = "2.0.0"
 description = "brain-dead simple config-ini parsing"
-category = "dev"
 optional = false
 python-versions = ">=3.7"
 files = [
@@ -254,7 +242,6 @@ files = [
 name = "iso8601"
 version = "1.1.0"
 description = "Simple module to parse ISO 8601 dates"
-category = "main"
 optional = false
 python-versions = ">=3.6.2,<4.0"
 files = [
@@ -266,7 +253,6 @@ files = [
 name = "mccabe"
 version = "0.7.0"
 description = "McCabe checker, plugin for flake8"
-category = "dev"
 optional = false
 python-versions = ">=3.6"
 files = [
@@ -278,7 +264,6 @@ files = [
 name = "mock"
 version = "5.0.1"
 description = "Rolling backport of unittest.mock for all Pythons"
-category = "dev"
 optional = false
 python-versions = ">=3.6"
 files = [
@@ -295,7 +280,6 @@ test = ["pytest", "pytest-cov"]
 name = "mongoengine"
 version = "0.27.0"
 description = "MongoEngine is a Python Object-Document Mapper for working with MongoDB."
-category = "main"
 optional = false
 python-versions = ">=3.7"
 files = [
@@ -310,7 +294,6 @@ pymongo = ">=3.4,<5.0"
 name = "mongomock"
 version = "4.1.2"
 description = "Fake pymongo stub for testing simple MongoDB-dependent code"
-category = "dev"
 optional = false
 python-versions = "*"
 files = [
@@ -326,7 +309,6 @@ sentinels = "*"
 name = "packaging"
 version = "23.0"
 description = "Core utilities for Python packages"
-category = "dev"
 optional = false
 python-versions = ">=3.7"
 files = [
@@ -338,7 +320,6 @@ files = [
 name = "pluggy"
 version = "1.0.0"
 description = "plugin and hook calling mechanisms for python"
-category = "dev"
 optional = false
 python-versions = ">=3.6"
 files = [
@@ -357,7 +338,6 @@ testing = ["pytest", "pytest-benchmark"]
 name = "promise"
 version = "2.3"
 description = "Promises/A+ implementation for Python"
-category = "main"
 optional = false
 python-versions = "*"
 files = [
@@ -374,7 +354,6 @@ test = ["coveralls", "futures", "mock", "pytest (>=2.7.3)", "pytest-benchmark",
 name = "pycodestyle"
 version = "2.9.1"
 description = "Python style guide checker"
-category = "dev"
 optional = false
 python-versions = ">=3.6"
 files = [
@@ -386,7 +365,6 @@ files = [
 name = "pyflakes"
 version = "2.5.0"
 description = "passive checker of Python programs"
-category = "dev"
 optional = false
 python-versions = ">=3.6"
 files = [
@@ -398,7 +376,6 @@ files = [
 name = "pymongo"
 version = "4.3.3"
 description = "Python driver for MongoDB "
-category = "main"
 optional = false
 python-versions = ">=3.7"
 files = [
@@ -493,7 +470,6 @@ zstd = ["zstandard"]
 name = "pytest"
 version = "7.3.0"
 description = "pytest: simple powerful testing with Python"
-category = "dev"
 optional = false
 python-versions = ">=3.7"
 files = [
@@ -517,7 +493,6 @@ testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "no
 name = "pytest-asyncio"
 version = "0.21.0"
 description = "Pytest support for asyncio"
-category = "dev"
 optional = false
 python-versions = ">=3.7"
 files = [
@@ -537,7 +512,6 @@ testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy
 name = "pytest-cov"
 version = "4.0.0"
 description = "Pytest plugin for measuring coverage."
-category = "dev"
 optional = false
 python-versions = ">=3.6"
 files = [
@@ -552,11 +526,36 @@ pytest = ">=4.6"
 
 [package.extras]
 testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtualenv"]
 
+[[package]]
+name = "ruff"
+version = "0.1.6"
+description = "An extremely fast Python linter and code formatter, written in Rust."
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "ruff-0.1.6-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:88b8cdf6abf98130991cbc9f6438f35f6e8d41a02622cc5ee130a02a0ed28703"},
+    {file = "ruff-0.1.6-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:5c549ed437680b6105a1299d2cd30e4964211606eeb48a0ff7a93ef70b902248"},
+    {file = "ruff-0.1.6-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1cf5f701062e294f2167e66d11b092bba7af6a057668ed618a9253e1e90cfd76"},
+    {file = "ruff-0.1.6-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:05991ee20d4ac4bb78385360c684e4b417edd971030ab12a4fbd075ff535050e"},
+    {file = "ruff-0.1.6-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:87455a0c1f739b3c069e2f4c43b66479a54dea0276dd5d4d67b091265f6fd1dc"},
+    {file = "ruff-0.1.6-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:683aa5bdda5a48cb8266fcde8eea2a6af4e5700a392c56ea5fb5f0d4bfdc0240"},
+    {file = "ruff-0.1.6-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:137852105586dcbf80c1717facb6781555c4e99f520c9c827bd414fac67ddfb6"},
+    {file = "ruff-0.1.6-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bd98138a98d48a1c36c394fd6b84cd943ac92a08278aa8ac8c0fdefcf7138f35"},
+    {file = "ruff-0.1.6-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a0cd909d25f227ac5c36d4e7e681577275fb74ba3b11d288aff7ec47e3ae745"},
+    {file = "ruff-0.1.6-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:e8fd1c62a47aa88a02707b5dd20c5ff20d035d634aa74826b42a1da77861b5ff"},
+    {file = "ruff-0.1.6-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:fd89b45d374935829134a082617954120d7a1470a9f0ec0e7f3ead983edc48cc"},
+    {file = "ruff-0.1.6-py3-none-musllinux_1_2_i686.whl", hash = "sha256:491262006e92f825b145cd1e52948073c56560243b55fb3b4ecb142f6f0e9543"},
+    {file = "ruff-0.1.6-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:ea284789861b8b5ca9d5443591a92a397ac183d4351882ab52f6296b4fdd5462"},
+    {file = "ruff-0.1.6-py3-none-win32.whl", hash = "sha256:1610e14750826dfc207ccbcdd7331b6bd285607d4181df9c1c6ae26646d6848a"},
+    {file = "ruff-0.1.6-py3-none-win_amd64.whl", hash = "sha256:4558b3e178145491e9bc3b2ee3c4b42f19d19384eaa5c59d10acf6e8f8b57e33"},
+    {file = "ruff-0.1.6-py3-none-win_arm64.whl", hash = "sha256:03910e81df0d8db0e30050725a5802441c2022ea3ae4fe0609b76081731accbc"},
+    {file = "ruff-0.1.6.tar.gz", hash = "sha256:1b09f29b16c6ead5ea6b097ef2764b42372aebe363722f1605ecbcd2b9207184"},
+]
+
 [[package]]
 name = "sentinels"
 version = "1.0.0"
 description = "Various objects to denote special meanings in python"
-category = "dev"
 optional = false
 python-versions = "*"
 files = [
@@ -567,7 +566,6 @@ files = [
 name = "singledispatch"
 version = "4.0.0"
 description = "Backport functools.singledispatch to older Pythons."
-category = "main"
 optional = false
 python-versions = ">=3.7"
 files = [
@@ -583,7 +581,6 @@ testing = ["flake8 (<5)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-chec
 name = "six"
 version = "1.16.0"
 description = "Python 2 and 3 compatibility utilities"
-category = "main"
 optional = false
 python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*"
 files = [
@@ -595,7 +592,6 @@ files = [
 name = "tomli"
 version = "2.0.1"
 description = "A lil' TOML parser"
-category = "dev"
 optional = false
 python-versions = ">=3.7"
 files = [
@@ -607,7 +603,6 @@ files = [
 name = "typing-extensions"
 version = "4.5.0"
 description = "Backported and Experimental Type Hints for Python 3.7+"
-category = "main"
 optional = false
 python-versions = ">=3.7"
 files = [
@@ -619,7 +614,6 @@ files = [
 name = "zipp"
 version = "3.15.0"
 description = "Backport of pathlib-compatible object wrapper for zip files"
-category = "dev"
 optional = false
 python-versions = ">=3.7"
 files = [
@@ -634,4 +628,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.7,<4"
-content-hash = "55982d91910b89a9be452b0ffbede03224bb8222fd0f82d37aa594e7eccb730f"
+content-hash = "c12bc64004fd15f27726b5ed4a6b30ba7ac1353a2170e2d75b39bdc80571677c"
diff --git a/pyproject.toml b/pyproject.toml
index bed4455d..f09f5efa 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -42,8 +42,12 @@ mock = ">=5.0.1"
 flake8 = "*"
 pytest-cov = "*"
 pytest-asyncio = "^0.21.0"
+ruff = "^0.1.6"
 
 [build-system]
 requires = ["poetry-core"]
 build-backend = "poetry.core.masonry.api"
+
+[tool.ruff]
+line-length = 100
\ No newline at end of file
From a09c174d0969edaded311b1ce1bb805677855306 Mon Sep 17 00:00:00 2001
From: M Aswin Kishore
Date: Tue, 21 Nov 2023 20:29:39 +0530
Subject: [PATCH 3/7] fix: inconsistent has_next_page when pk is not sorted

---
 graphene_mongo/fields.py       | 34 ++++++++++++-----------
 graphene_mongo/fields_async.py | 49 +++++++++++++++++-----------------
 2 files changed, 42 insertions(+), 41 deletions(-)

diff --git a/graphene_mongo/fields.py b/graphene_mongo/fields.py
index 7fb9fed9..5ff9aabe 100644
--- a/graphene_mongo/fields.py
+++ b/graphene_mongo/fields.py
@@ -406,14 +406,12 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a
         before = args.pop("before", None)
         if before:
             before = cursor_to_offset(before)
-
-        queryset = None
+        has_next_page = False
 
         if resolved is not None:
             items = resolved
 
             if isinstance(items, QuerySet):
-                queryset = items.clone()
                 try:
                     if last is not None and after is not None:
                         count = items.count(with_limit_and_skip=False)
@@ -422,7 +420,6 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a
                 except OperationFailure:
                     count = len(items)
             else:
-                queryset = None
                 count = len(items)
 
             skip, limit, reverse = find_skip_and_limit(
@@ -431,18 +428,23 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a
 
             if isinstance(items, QuerySet):
                 if limit:
-                    if reverse:
-                        items = items.order_by("-pk").skip(skip).limit(limit)
-                    else:
-                        items = items.skip(skip).limit(limit)
+                    _base_query: QuerySet = (
+                        items.order_by("-pk").skip(skip) if reverse else items.skip(skip)
+                    )
+                    items = _base_query.limit(limit)
+                    has_next_page = len(_base_query.skip(limit).only("id").limit(1)) != 0
                 elif skip:
                     items = items.skip(skip)
             else:
                 if limit:
                     if reverse:
-                        items = items[::-1][skip : skip + limit]
+                        _base_query = items[::-1]
+                        items = _base_query[skip : skip + limit]
+                        has_next_page = (skip + limit) < len(_base_query)
                     else:
+                        _base_query = items
                         items = items[skip : skip + limit]
+                        has_next_page = (skip + limit) < len(_base_query)
                 elif skip:
                     items = items[skip:]
             iterables = list(items)
@@ -538,9 +540,13 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a
             )
             if limit:
                 if reverse:
-                    items = items[::-1][skip : skip + limit]
+                    _base_query = items[::-1]
+                    items = _base_query[skip : skip + limit]
+                    has_next_page = (skip + limit) < len(_base_query)
                 else:
+                    _base_query = items
                     items = items[skip : skip + limit]
+                    has_next_page = (skip + limit) < len(_base_query)
             elif skip:
                 items = items[skip:]
             iterables = items
@@ -552,16 +558,13 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a
             if (0 if limit is None else limit) + (0 if skip is None else skip) < count
             else False
         )
-        else:
-            if isinstance(queryset, QuerySet) and iterables:
-                has_next_page = bool(queryset(pk__gt=iterables[-1].pk).limit(1).first())
-            else:
-                has_next_page = False
         has_previous_page = True if skip else False
+
         if reverse:
             iterables = list(iterables)
             iterables.reverse()
             skip = limit
+
         connection = connection_from_iterables(
             edges=iterables,
             start_offset=skip,
@@ -571,7 +574,6 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a
             edge_type=self.type.Edge,
             pageinfo_type=graphene.PageInfo,
         )
-        connection.iterable = iterables
         connection.list_length = list_length
         return connection
diff --git a/graphene_mongo/fields_async.py b/graphene_mongo/fields_async.py
index 7f077677..efc11a16 100644
--- a/graphene_mongo/fields_async.py
+++ b/graphene_mongo/fields_async.py
@@ -14,7 +14,6 @@
 from mongoengine import QuerySet
 from promise import Promise
 from pymongo.errors import OperationFailure
-from concurrent.futures import ThreadPoolExecutor
 from .registry import get_global_async_registry
 from . import MongoengineConnectionField
 from .utils import (
@@ -97,8 +96,7 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non
         before = args.pop("before", None)
         if before:
             before = cursor_to_offset(before)
-
-        queryset = None
+        has_next_page = False
 
         if resolved is not None:
             items = resolved
@@ -119,21 +117,28 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non
             )
 
             if isinstance(items, QuerySet):
-                queryset = items.clone()
                 if limit:
-                    if reverse:
-                        items = items.order_by("-pk").skip(skip).limit(limit)
-                    else:
-                        items = items.skip(skip).limit(limit)
+                    _base_query: QuerySet = (
+                        await sync_to_async(items.order_by("-pk").skip)(skip)
+                        if reverse
+                        else await sync_to_async(items.skip)(skip)
+                    )
+                    items = await sync_to_async(_base_query.limit)(limit)
+                    has_next_page = (
+                        len(await sync_to_async(_base_query.skip(limit).only("id").limit)(1)) != 0
+                    )
                 elif skip:
-                    items = items.skip(skip)
+                    items = await sync_to_async(items.skip)(skip)
             else:
-                queryset = None
                 if limit:
                     if reverse:
-                        items = items[::-1][skip : skip + limit]
+                        _base_query = items[::-1]
+                        items = _base_query[skip : skip + limit]
+                        has_next_page = (skip + limit) < len(_base_query)
                     else:
+                        _base_query = items
                         items = items[skip : skip + limit]
+                        has_next_page = (skip + limit) < len(_base_query)
                 elif skip:
                     items = items[skip:]
             iterables = await sync_to_async(list)(items)
@@ -176,11 +181,7 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non
                     (mongoengine.get_db()[self.model._get_collection_name()]).count_documents
                 )(args_copy)
             else:
-                count = await sync_to_async(
-                    self.model.objects(args_copy).count,
-                    thread_sensitive=False,
-                    executor=ThreadPoolExecutor(),
-                )()
+                count = await sync_to_async(self.model.objects(args_copy).count)()
             if count != 0:
                 skip, limit, reverse = find_skip_and_limit(
                     first=first, after=after, last=last, before=before, count=count
@@ -228,9 +229,13 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non
             )
             if limit:
                 if reverse:
-                    items = items[::-1][skip : skip + limit]
+                    _base_query = items[::-1]
+                    items = _base_query[skip : skip + limit]
+                    has_next_page = (skip + limit) < len(_base_query)
                 else:
+                    _base_query = items
                     items = items[skip : skip + limit]
+                    has_next_page = (skip + limit) < len(_base_query)
             elif skip:
                 items = items[skip:]
             iterables = items
@@ -243,18 +248,13 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non
             if (0 if limit is None else limit) + (0 if skip is None else skip) < count
             else False
         )
-        else:
-            if isinstance(queryset, QuerySet) and iterables:
-                has_next_page = bool(
-                    await sync_to_async(queryset(pk__gt=iterables[-1].pk).limit(1).first)()
-                )
-            else:
-                has_next_page = False
         has_previous_page = True if skip else False
+
         if reverse:
             iterables = await sync_to_async(list)(iterables)
             iterables.reverse()
             skip = limit
+
         connection = connection_from_iterables(
             edges=iterables,
             start_offset=skip,
@@ -264,7 +264,6 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non
             edge_type=self.type.Edge,
             pageinfo_type=graphene.PageInfo,
         )
-        connection.iterable = iterables
         connection.list_length = list_length
         return connection
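The QuerySet branches above drop the old clone-and-refilter check (queryset(pk__gt=iterables[-1].pk)), which misreports has_next_page whenever results are not sorted by pk, and instead peek one record past the current window on the same base query. A minimal sketch of the pattern, using a hypothetical Item document and local skip/limit values:

    from mongoengine import Document, StringField, connect

    connect(db="example")  # assumes a local MongoDB for the sketch

    class Item(Document):  # hypothetical document, for illustration only
        name = StringField()

    skip, limit, reverse = 0, 10, False
    base_query = Item.objects.order_by("-pk").skip(skip) if reverse else Item.objects.skip(skip)
    page = base_query.limit(limit)  # the page itself (MongoEngine clones on skip/limit)
    # Probe exactly one record beyond the page; .only("id") keeps the probe cheap.
    has_next_page = len(base_query.skip(limit).only("id").limit(1)) != 0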
From cad8b76c4ec3d30e28081f26d837e3243073303c Mon Sep 17 00:00:00 2001
From: M Aswin Kishore
Date: Tue, 21 Nov 2023 20:50:54 +0530
Subject: [PATCH 4/7] fix: self.get_queryset not being async

---
 graphene_mongo/fields_async.py | 107 +++++++++++++++++++++++++++++++--
 1 file changed, 102 insertions(+), 5 deletions(-)

diff --git a/graphene_mongo/fields_async.py b/graphene_mongo/fields_async.py
index efc11a16..a4bf80e5 100644
--- a/graphene_mongo/fields_async.py
+++ b/graphene_mongo/fields_async.py
@@ -12,6 +12,7 @@
 from graphql import GraphQLResolveInfo
 from graphql_relay import from_global_id, cursor_to_offset
 from mongoengine import QuerySet
+from mongoengine.base import get_document
 from promise import Promise
 from pymongo.errors import OperationFailure
 from .registry import get_global_async_registry
@@ -22,6 +23,7 @@
     connection_from_iterables,
     ExecutorEnum,
     sync_to_async,
+    get_model_reference_fields,
 )
 
 import pymongo
@@ -57,6 +59,101 @@ def fields(self):
     def registry(self):
         return getattr(self.node_type._meta, "registry", get_global_async_registry())
 
+    async def get_queryset(
+        self, model, info, required_fields=None, skip=None, limit=None, reversed=False, **args
+    ):
+        if required_fields is None:
+            required_fields = list()
+
+        if args:
+            reference_fields = get_model_reference_fields(self.model)
+            hydrated_references = {}
+            for arg_name, arg in args.copy().items():
+                if arg_name in reference_fields and not isinstance(
+                    arg, mongoengine.base.metaclasses.TopLevelDocumentMetaclass
+                ):
+                    try:
+                        reference_obj = reference_fields[arg_name].document_type(
+                            pk=from_global_id(arg)[1]
+                        )
+                    except TypeError:
+                        reference_obj = reference_fields[arg_name].document_type(pk=arg)
+                    hydrated_references[arg_name] = reference_obj
+                elif arg_name in self.model._fields_ordered and isinstance(
+                    getattr(self.model, arg_name), mongoengine.fields.GenericReferenceField
+                ):
+                    try:
+                        reference_obj = get_document(
+                            self.registry._registry_string_map[from_global_id(arg)[0]]
+                        )(pk=from_global_id(arg)[1])
+                    except TypeError:
+                        reference_obj = get_document(arg["_cls"])(pk=arg["_ref"].id)
+                    hydrated_references[arg_name] = reference_obj
+                elif "__near" in arg_name and isinstance(
+                    getattr(self.model, arg_name.split("__")[0]), mongoengine.fields.PointField
+                ):
+                    location = args.pop(arg_name, None)
+                    hydrated_references[arg_name] = location["coordinates"]
+                    if (arg_name.split("__")[0] + "__max_distance") not in args:
+                        hydrated_references[arg_name.split("__")[0] + "__max_distance"] = 10000
+                elif arg_name == "id":
+                    hydrated_references["id"] = from_global_id(args.pop("id", None))[1]
+            args.update(hydrated_references)
+
+        if self._get_queryset:
+            queryset_or_filters = self._get_queryset(model, info, **args)
+            if isinstance(queryset_or_filters, mongoengine.QuerySet):
+                return queryset_or_filters
+            else:
+                args.update(queryset_or_filters)
+        if limit is not None:
+            if reversed:
+                if self.order_by:
+                    order_by = self.order_by + ",-pk"
+                else:
+                    order_by = "-pk"
+                return await sync_to_async(
+                    model.objects(**args)
+                    .no_dereference()
+                    .only(*required_fields)
+                    .order_by(order_by)
+                    .skip(skip if skip else 0)
+                    .limit
+                )(limit)
+            else:
+                return await sync_to_async(
+                    model.objects(**args)
+                    .no_dereference()
+                    .only(*required_fields)
+                    .order_by(self.order_by)
+                    .skip(skip if skip else 0)
+                    .limit
+                )(limit)
+        elif skip is not None:
+            if reversed:
+                if self.order_by:
+                    order_by = self.order_by + ",-pk"
+                else:
+                    order_by = "-pk"
+                return await sync_to_async(
+                    model.objects(**args)
+                    .no_dereference()
+                    .only(*required_fields)
+                    .order_by(order_by)
+                    .skip
+                )(skip)
+            else:
+                return await sync_to_async(
+                    model.objects(**args)
+                    .no_dereference()
+                    .only(*required_fields)
+                    .order_by(self.order_by)
+                    .skip
+                )(skip)
+        return await sync_to_async(
+            model.objects(**args).no_dereference().only(*required_fields).order_by
+        )(self.order_by)
+
     async def default_resolver(self, _root, info, required_fields=None, resolved=None, **args):
         if required_fields is None:
             required_fields = list()
@@ -186,7 +283,7 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non
             skip, limit, reverse = find_skip_and_limit(
                 first=first, after=after, last=last, before=before, count=count
             )
-            iterables = self.get_queryset(
+            iterables = await self.get_queryset(
                 self.model, info, required_fields, skip, limit, reverse, **args
             )
             iterables = await sync_to_async(list)(iterables)
@@ -194,7 +291,7 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non
             if isinstance(info, GraphQLResolveInfo):
                 if not info.context:
                     info = info._replace(context=Context())
-                info.context.queryset = self.get_queryset(
+                info.context.queryset = await self.get_queryset(
                     self.model, info, required_fields, **args
                 )
@@ -210,13 +307,13 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non
                 args["pk__in"] = args["pk__in"][skip : skip + limit]
             elif skip:
                 args["pk__in"] = args["pk__in"][skip:]
-            iterables = self.get_queryset(self.model, info, required_fields, **args)
+            iterables = await self.get_queryset(self.model, info, required_fields, **args)
             iterables = await sync_to_async(list)(iterables)
             list_length = len(iterables)
             if isinstance(info, GraphQLResolveInfo):
                 if not info.context:
                     info = info._replace(context=Context())
-                info.context.queryset = self.get_queryset(
+                info.context.queryset = await self.get_queryset(
                     self.model, info, required_fields, **args
                 )
@@ -305,7 +402,7 @@ async def chained_resolver(self, resolver, is_partial, root, info, **args):
         if isinstance(info, GraphQLResolveInfo):
             if not info.context:
                 info = info._replace(context=Context())
-            info.context.queryset = await self.get_queryset(
                 self.model, info, required_fields, **args_copy
             )
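The async get_queryset keeps queryset construction lazy (pure bookkeeping in MongoEngine, no I/O) and pushes only the terminal bound method through sync_to_async; the cursor is still evaluated later, when the caller awaits sync_to_async(list)(...). A rough sketch of the pattern, with Model and filters as hypothetical stand-ins:

    from asgiref.sync import sync_to_async  # the project's wrapper delegates here

    async def fetch_page(Model, filters, required_fields, skip, limit):
        # Composing the queryset performs no database I/O yet.
        qs = Model.objects(**filters).no_dereference().only(*required_fields).order_by("-pk")
        # Only the final bound call (.limit) crosses the sync/async boundary here;
        # actual cursor evaluation happens later, e.g. `await sync_to_async(list)(page)`.
        page = await sync_to_async(qs.skip(skip or 0).limit)(limit)
        return page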
From e5fd8dfd2e995c875191c4e203e57223d6461da3 Mon Sep 17 00:00:00 2001
From: M Aswin Kishore
Date: Tue, 21 Nov 2023 21:03:32 +0530
Subject: [PATCH 5/7] fix[pipeline]: add ruff remove flake8

---
 .github/workflows/lint.yml     |  4 +--
 Makefile                       |  3 +-
 graphene_mongo/fields.py       |  2 +-
 graphene_mongo/fields_async.py |  2 +-
 graphene_mongo/utils.py        |  2 +-
 poetry.lock                    | 52 +---------------------------------
 pyproject.toml                 |  1 -
 7 files changed, 8 insertions(+), 58 deletions(-)

diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index eb4dd686..305d8d83 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -22,6 +22,6 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          pip install flake8
-      - name: Lint with flake8
+          pip install ruff
+      - name: Lint with ruff
         run: make lint
diff --git a/Makefile b/Makefile
index b983aa0d..07c3c5d3 100644
--- a/Makefile
+++ b/Makefile
@@ -11,7 +11,8 @@ clean:
 	@find . -name "__pycache__" -delete
 
 lint:
-	@flake8 graphene_mongo --count --show-source --statistics
+	@ruff check graphene_mongo
+	@ruff format . --check
 
 test: clean
 	pytest graphene_mongo/tests --cov=graphene_mongo --cov-report=html --cov-report=term
diff --git a/graphene_mongo/fields.py b/graphene_mongo/fields.py
index 5ff9aabe..69750764 100644
--- a/graphene_mongo/fields.py
+++ b/graphene_mongo/fields.py
@@ -645,7 +645,7 @@ def chained_resolver(self, resolver, is_partial, root, info, **args):
                     operation = list(arg.keys())[0]
                     args_copy["pk" + operation.replace("$", "__")] = arg[operation]
                 if not isinstance(arg, ObjectId) and "." in arg_name:
-                    if type(arg) == dict:
+                    if isinstance(arg, dict):
                         operation = list(arg.keys())[0]
                         args_copy[
                             arg_name.replace(".", "__") + operation.replace("$", "__")
diff --git a/graphene_mongo/fields_async.py b/graphene_mongo/fields_async.py
index a4bf80e5..f04274c2 100644
--- a/graphene_mongo/fields_async.py
+++ b/graphene_mongo/fields_async.py
@@ -435,7 +435,7 @@ async def chained_resolver(self, resolver, is_partial, root, info, **args):
                     operation = list(arg.keys())[0]
                     args_copy["pk" + operation.replace("$", "__")] = arg[operation]
                 if not isinstance(arg, ObjectId) and "." in arg_name:
-                    if type(arg) == dict:
+                    if isinstance(arg, dict):
                         operation = list(arg.keys())[0]
                         args_copy[
                             arg_name.replace(".", "__") + operation.replace("$", "__")
diff --git a/graphene_mongo/utils.py b/graphene_mongo/utils.py
index 63438aba..ffd01f45 100644
--- a/graphene_mongo/utils.py
+++ b/graphene_mongo/utils.py
@@ -133,7 +133,7 @@ def collect_query_fields(node, fragments):
     field = {}
     selection_set = None
 
-    if type(node) == dict:
+    if isinstance(node, dict):
         selection_set = node.get("selection_set")
     else:
         selection_set = node.selection_set
diff --git a/poetry.lock b/poetry.lock
index f7d6dc03..81376a6b 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -142,23 +142,6 @@ files = [
 
 [package.extras]
 test = ["pytest (>=6)"]
 
-[[package]]
-name = "flake8"
-version = "5.0.4"
-description = "the modular source code checker: pep8 pyflakes and co"
-optional = false
-python-versions = ">=3.6.1"
-files = [
-    {file = "flake8-5.0.4-py2.py3-none-any.whl", hash = "sha256:7a1cf6b73744f5806ab95e526f6f0d8c01c66d7bbe349562d22dfca20610b248"},
-    {file = "flake8-5.0.4.tar.gz", hash = "sha256:6fbe320aad8d6b95cec8b8e47bc933004678dc63095be98528b7bdd2a9f510db"},
-]
-
-[package.dependencies]
-importlib-metadata = {version = ">=1.1.0,<4.3", markers = "python_version < \"3.8\""}
-mccabe = ">=0.7.0,<0.8.0"
-pycodestyle = ">=2.9.0,<2.10.0"
-pyflakes = ">=2.5.0,<2.6.0"
-
 [[package]]
 name = "graphene"
 version = "3.2.2"
@@ -249,17 +232,6 @@ files = [
     {file = "iso8601-1.1.0.tar.gz", hash = "sha256:32811e7b81deee2063ea6d2e94f8819a86d1f3811e49d23623a41fa832bef03f"},
 ]
 
-[[package]]
-name = "mccabe"
-version = "0.7.0"
-description = "McCabe checker, plugin for flake8"
-optional = false
-python-versions = ">=3.6"
-files = [
-    {file = "mccabe-0.7.0-py2.py3-none-any.whl", hash = "sha256:6c2d30ab6be0e4a46919781807b4f0d834ebdd6c6e3dca0bda5a15f863427b6e"},
-    {file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"},
-]
-
 [[package]]
 name = "mock"
 version = "5.0.1"
@@ -350,28 +322,6 @@ six = "*"
 
 [package.extras]
 test = ["coveralls", "futures", "mock", "pytest (>=2.7.3)", "pytest-benchmark", "pytest-cov"]
 
-[[package]]
-name = "pycodestyle"
-version = "2.9.1"
-description = "Python style guide checker"
-optional = false
-python-versions = ">=3.6"
-files = [
-    {file = "pycodestyle-2.9.1-py2.py3-none-any.whl", hash = "sha256:d1735fc58b418fd7c5f658d28d943854f8a849b01a5d0a1e6f3f3fdd0166804b"},
-    {file = "pycodestyle-2.9.1.tar.gz", hash = "sha256:2c9607871d58c76354b697b42f5d57e1ada7d261c261efac224b664affdc5785"},
-]
-
-[[package]]
-name = "pyflakes"
-version = "2.5.0"
-description = "passive checker of Python programs"
-optional = false
-python-versions = ">=3.6"
-files = [
-    {file = "pyflakes-2.5.0-py2.py3-none-any.whl", hash = "sha256:4579f67d887f804e67edb544428f264b7b24f435b263c4614f384135cea553d2"},
-    {file = "pyflakes-2.5.0.tar.gz", hash = "sha256:491feb020dca48ccc562a8c0cbe8df07ee13078df59813b83959cbdada312ea3"},
-]
-
 [[package]]
 name = "pymongo"
 version = "4.3.3"
@@ -628,4 +578,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.7,<4"
-content-hash = "c12bc64004fd15f27726b5ed4a6b30ba7ac1353a2170e2d75b39bdc80571677c"
+content-hash = "3590e3b2214b5e391d5f8e104f9d4f010380acd5ceaa21e2772ba36edb2bffac"
diff --git a/pyproject.toml b/pyproject.toml
index f09f5efa..7e28d52d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -39,7 +39,6 @@ asgiref = "^3.6.0"
 pytest = "*"
 mongomock = ">=4.1.2"
 mock = ">=5.0.1"
-flake8 = "*"
 pytest-cov = "*"
 pytest-asyncio = "^0.21.0"
 ruff = "^0.1.6"
From b28191994139adc0ff61ca61956ad6eca6d52a46 Mon Sep 17 00:00:00 2001
From: M Aswin Kishore
Date: Tue, 21 Nov 2023 21:20:51 +0530
Subject: [PATCH 6/7] refact: filter connection

---
 graphene_mongo/fields.py       | 19 ++++++++++++-------
 graphene_mongo/fields_async.py | 18 ++++++++++++------
 2 files changed, 24 insertions(+), 13 deletions(-)

diff --git a/graphene_mongo/fields.py b/graphene_mongo/fields.py
index 4d21d631..59903f0e 100644
--- a/graphene_mongo/fields.py
+++ b/graphene_mongo/fields.py
@@ -3,6 +3,7 @@
 import logging
 from collections import OrderedDict
 from functools import partial, reduce
+from itertools import filterfalse
 
 import bson
 import graphene
@@ -579,7 +580,6 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a
         return connection
 
     def chained_resolver(self, resolver, is_partial, root, info, **args):
-
         for key, value in dict(args).items():
             if value is None:
                 del args[key]
@@ -600,18 +600,23 @@ def chained_resolver(self, resolver, is_partial, root, info, **args):
         if isinstance(self.model, mongoengine.Document) or isinstance(
             self.model, mongoengine.base.metaclasses.TopLevelDocumentMetaclass
         ):
-            from itertools import filterfalse
-
             connection_fields = [
                 field
                 for field in self.fields
                 if type(self.fields[field]) == MongoengineConnectionField
             ]
-            filter_connection = lambda x: (
-                connection_fields.__contains__(x)
-                or self._type._meta.non_filter_fields.__contains__(x)
+
+            def filter_connection(x):
+                return any(
+                    [
+                        connection_fields.__contains__(x),
+                        self._type._meta.non_filter_fields.__contains__(x),
+                    ]
+                )
+
+            filterable_args = tuple(
+                filterfalse(filter_connection, list(self.model._fields_ordered))
             )
-            filterable_args = tuple(filterfalse(filter_connection, list(self.model._fields_ordered)))
             for arg_name, arg in args.copy().items():
                 if arg_name not in filterable_args + tuple(self.filter_args.keys()):
                     args_copy.pop(arg_name)
diff --git a/graphene_mongo/fields_async.py b/graphene_mongo/fields_async.py
index a7321130..b930bcdf 100644
--- a/graphene_mongo/fields_async.py
+++ b/graphene_mongo/fields_async.py
@@ -1,6 +1,7 @@
 from __future__ import absolute_import
 from functools import partial
 from typing import Coroutine
+from itertools import filterfalse
 
 import bson
 import graphene
@@ -386,18 +387,23 @@ async def chained_resolver(self, resolver, is_partial, root, info, **args):
         if isinstance(self.model, mongoengine.Document) or isinstance(
             self.model, mongoengine.base.metaclasses.TopLevelDocumentMetaclass
         ):
-            from itertools import filterfalse
-
             connection_fields = [
                 field
                 for field in self.fields
                 if type(self.fields[field]) == AsyncMongoengineConnectionField
             ]
-            filter_connection = lambda x: (
-                connection_fields.__contains__(x)
-                or self._type._meta.non_filter_fields.__contains__(x)
+
+            def filter_connection(x):
+                return any(
+                    [
+                        connection_fields.__contains__(x),
+                        self._type._meta.non_filter_fields.__contains__(x),
+                    ]
+                )
+
+            filterable_args = tuple(
+                filterfalse(filter_connection, list(self.model._fields_ordered))
             )
-            filterable_args = tuple(filterfalse(filter_connection, list(self.model._fields_ordered)))
             for arg_name, arg in args.copy().items():
                 if arg_name not in filterable_args + tuple(self.filter_args.keys()):
                     args_copy.pop(arg_name)
From 50642e9f9c4db1440fd581fc981552a5dbc3bfd3 Mon Sep 17 00:00:00 2001
From: M Aswin Kishore
Date: Tue, 21 Nov 2023 21:32:59 +0530
Subject: [PATCH 7/7] refact: has_next_page logic

---
 graphene_mongo/fields.py       | 2 +-
 graphene_mongo/fields_async.py | 3 ++-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/graphene_mongo/fields.py b/graphene_mongo/fields.py
index 59903f0e..4798922a 100644
--- a/graphene_mongo/fields.py
+++ b/graphene_mongo/fields.py
@@ -433,7 +433,7 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a
                         items.order_by("-pk").skip(skip) if reverse else items.skip(skip)
                     )
                     items = _base_query.limit(limit)
-                    has_next_page = len(_base_query.skip(limit).only("id").limit(1)) != 0
+                    has_next_page = _base_query.skip(limit).only("id").limit(1).count() != 0
                 elif skip:
                     items = items.skip(skip)
             else:
diff --git a/graphene_mongo/fields_async.py b/graphene_mongo/fields_async.py
index b930bcdf..32319765 100644
--- a/graphene_mongo/fields_async.py
+++ b/graphene_mongo/fields_async.py
@@ -223,7 +223,8 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non
                     )
                     items = await sync_to_async(_base_query.limit)(limit)
                     has_next_page = (
-                        len(await sync_to_async(_base_query.skip(limit).only("id").limit)(1)) != 0
+                        await sync_to_async(_base_query.skip(limit).only("id").limit(1).count)()
+                        != 0
                    )
                 elif skip:
                     items = await sync_to_async(items.skip)(skip)
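The final refactor trades len() for count() in the peek-ahead probe. A plausible rationale, given MongoEngine semantics: len() evaluates the queryset and caches the fetched documents client-side, whereas count() issues a single server-side count command and fetches nothing. A sketch, continuing the earlier Item example and spelling out with_limit_and_skip, since MongoEngine's count() defaults it to False (which would ignore the skip/limit applied to the probe):

    probe = base_query.skip(limit).only("id").limit(1)
    # len(probe) would pull and cache the probe document; count() asks the
    # server how many match, transferring no documents.
    has_next_page = probe.count(with_limit_and_skip=True) != 0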