diff --git a/graphene_mongo/fields.py b/graphene_mongo/fields.py index fcf20015..d10820ba 100644 --- a/graphene_mongo/fields.py +++ b/graphene_mongo/fields.py @@ -5,6 +5,7 @@ import graphene import mongoengine +from promise import Promise from graphene.relay import ConnectionField from graphene.types.argument import to_arguments from graphene.types.dynamic import Dynamic @@ -126,7 +127,6 @@ def fields(self): return self._type._meta.fields def get_queryset(self, model, info, **args): - if args: reference_fields = get_model_reference_fields(self.model) hydrated_references = {} @@ -157,36 +157,49 @@ def default_resolver(self, _root, info, **args): _id = args.pop('id', None) if _id is not None: - objs = [get_node_from_global_id(self.node_type, info, _id)] + iterables = [get_node_from_global_id(self.node_type, info, _id)] list_length = 1 elif callable(getattr(self.model, 'objects', None)): - objs = self.get_queryset(self.model, info, **args) - list_length = objs.count() + iterables = self.get_queryset(self.model, info, **args) + list_length = iterables.count() else: - objs = [] + iterables = [] list_length = 0 connection = connection_from_list_slice( - list_slice=objs, + list_slice=iterables, args=connection_args, list_length=list_length, connection_type=self.type, edge_type=self.type.Edge, pageinfo_type=graphene.PageInfo, ) - connection.iterable = objs + connection.iterable = iterables connection.list_length = list_length return connection - def chained_resolver(self, resolver, root, info, **args): - if not bool(args): + def chained_resolver(self, resolver, is_partial, root, info, **args): + if not bool(args) or not is_partial: # XXX: Filter nested args resolved = resolver(root, info, **args) if resolved is not None: return resolved return self.default_resolver(root, info, **args) + @classmethod + def connection_resolver(cls, resolver, connection_type, root, info, **args): + iterable = resolver(root, info, **args) + if isinstance(connection_type, graphene.NonNull): + connection_type = connection_type.of_type + + on_resolve = partial(cls.resolve_connection, connection_type, args) + if Promise.is_thenable(iterable): + return Promise.resolve(iterable).then(on_resolve) + + return on_resolve(iterable) + def get_resolver(self, parent_resolver): super_resolver = self.resolver or parent_resolver - resolver = partial(self.chained_resolver, super_resolver) + resolver = partial( + self.chained_resolver, super_resolver, isinstance(super_resolver, partial)) return partial(self.connection_resolver, resolver, self.type) diff --git a/graphene_mongo/tests/test_relay_query.py b/graphene_mongo/tests/test_relay_query.py index 6ddc2c37..0f913724 100644 --- a/graphene_mongo/tests/test_relay_query.py +++ b/graphene_mongo/tests/test_relay_query.py @@ -11,6 +11,7 @@ from . import types from .setup import fixtures, fixtures_dirname from ..fields import MongoengineConnectionField +from ..types import MongoengineObjectType def test_should_query_reporter(fixtures): @@ -250,64 +251,78 @@ class Query(graphene.ObjectType): assert result.data == expected - def test_should_query_editors_with_dataloader(fixtures): from promise import Promise from promise.dataloader import DataLoader - class EditorLoader(DataLoader): + class ArticleLoader(DataLoader): - def batch_load_fn(self, keys): - print(keys) - queryset = models.Editor.objects(_id__in=keys) - return Promise.resolve( - [ - [e for e in queryset if e._id == _id] - for _id in keys - ] - ) + def batch_load_fn(self, instances): + queryset = models.Article.objects(editor__in=instances) + return Promise.resolve([ + [a for a in queryset if a.editor.id == instance.id] + for instance in instances + ]) - editor_loader = EditorLoader() + article_loader = ArticleLoader() - class Query(graphene.ObjectType): - # editors = MongoengineConnectionField(nodes.EditorNode) - editors = graphene.List(types.EditorType) + class _EditorNode(MongoengineObjectType): + + class Meta: + model = models.Editor + interfaces = (graphene.Node,) - def resolve_editors(self, info, *args, **kwargs): - print('hell') - # print(self.__dict__) - print(self) - print(info) - print(args) - print(kwargs) - return None + articles = MongoengineConnectionField(nodes.ArticleNode) + + def resolve_articles(self, *args, **kwargs): + return article_loader.load(self) + class Query(graphene.ObjectType): + editors = MongoengineConnectionField(_EditorNode) query = ''' - query EditorPromiseQuery { + query EditorsConnectionPromiseQuery { editors(first: 1) { - firstName - } - } - ''' - """ - query = ''' - query EditorPromiseQuery { - editors { edges { node { - firstName + firstName, + articles(first: 1) { + edges { + node { + headline + } + } + } } } } } ''' - """ + + expected = { + 'editors': { + 'edges': [ + { + 'node': { + 'firstName': 'Penny', + 'articles': { + 'edges': [ + { + 'node': { + 'headline': 'Hello' + } + } + ] + } + } + } + ] + } + } schema = graphene.Schema(query=Query) result = schema.execute(query) assert not result.errors - # print(result.errors) - print('ccccc' * 10, result.data) + assert result.data == expected def test_should_filter_editors_by_id(fixtures): @@ -337,7 +352,6 @@ class Query(graphene.ObjectType): 'firstName': 'Grant', 'lastName': 'Hill' } - } ] } diff --git a/setup.cfg b/setup.cfg index fef3f048..1a477837 100644 --- a/setup.cfg +++ b/setup.cfg @@ -19,5 +19,4 @@ known_first_party=graphene,graphene_mongo test=pytest [tool:pytest] -addopts=-vv python_files = graphene_mongo/tests/*.py