diff --git a/graphene_mongo/fields.py b/graphene_mongo/fields.py index fcf20015..d10820ba 100644 --- a/graphene_mongo/fields.py +++ b/graphene_mongo/fields.py @@ -5,6 +5,7 @@ import graphene import mongoengine +from promise import Promise from graphene.relay import ConnectionField from graphene.types.argument import to_arguments from graphene.types.dynamic import Dynamic @@ -126,7 +127,6 @@ def fields(self): return self._type._meta.fields def get_queryset(self, model, info, **args): - if args: reference_fields = get_model_reference_fields(self.model) hydrated_references = {} @@ -157,36 +157,49 @@ def default_resolver(self, _root, info, **args): _id = args.pop('id', None) if _id is not None: - objs = [get_node_from_global_id(self.node_type, info, _id)] + iterables = [get_node_from_global_id(self.node_type, info, _id)] list_length = 1 elif callable(getattr(self.model, 'objects', None)): - objs = self.get_queryset(self.model, info, **args) - list_length = objs.count() + iterables = self.get_queryset(self.model, info, **args) + list_length = iterables.count() else: - objs = [] + iterables = [] list_length = 0 connection = connection_from_list_slice( - list_slice=objs, + list_slice=iterables, args=connection_args, list_length=list_length, connection_type=self.type, edge_type=self.type.Edge, pageinfo_type=graphene.PageInfo, ) - connection.iterable = objs + connection.iterable = iterables connection.list_length = list_length return connection - def chained_resolver(self, resolver, root, info, **args): - if not bool(args): + def chained_resolver(self, resolver, is_partial, root, info, **args): + if not bool(args) or not is_partial: # XXX: Filter nested args resolved = resolver(root, info, **args) if resolved is not None: return resolved return self.default_resolver(root, info, **args) + @classmethod + def connection_resolver(cls, resolver, connection_type, root, info, **args): + iterable = resolver(root, info, **args) + if isinstance(connection_type, graphene.NonNull): + connection_type = connection_type.of_type + + on_resolve = partial(cls.resolve_connection, connection_type, args) + if Promise.is_thenable(iterable): + return Promise.resolve(iterable).then(on_resolve) + + return on_resolve(iterable) + def get_resolver(self, parent_resolver): super_resolver = self.resolver or parent_resolver - resolver = partial(self.chained_resolver, super_resolver) + resolver = partial( + self.chained_resolver, super_resolver, isinstance(super_resolver, partial)) return partial(self.connection_resolver, resolver, self.type) diff --git a/graphene_mongo/tests/test_query.py b/graphene_mongo/tests/test_query.py index a48fcbf7..c89b501c 100644 --- a/graphene_mongo/tests/test_query.py +++ b/graphene_mongo/tests/test_query.py @@ -3,27 +3,23 @@ import json import graphene +from . import models +from . import types from .setup import fixtures, fixtures_dirname -from .models import ( - Child, Editor, Player, Reporter, ProfessorVector, Parent, CellTower -) -from .types import ( - ChildType, EditorType, PlayerType, ReporterType, ProfessorVectorType, ParentType, CellTowerType -) def test_should_query_editor(fixtures, fixtures_dirname): class Query(graphene.ObjectType): - editor = graphene.Field(EditorType) - editors = graphene.List(EditorType) + editor = graphene.Field(types.EditorType) + editors = graphene.List(types.EditorType) def resolve_editor(self, *args, **kwargs): - return Editor.objects.first() + return models.Editor.objects.first() def resolve_editors(self, *args, **kwargs): - return list(Editor.objects.all()) + return list(models.Editor.objects.all()) query = ''' query EditorQuery { @@ -91,10 +87,10 @@ def resolve_editors(self, *args, **kwargs): def test_should_query_reporter(fixtures): class Query(graphene.ObjectType): - reporter = graphene.Field(ReporterType) + reporter = graphene.Field(types.ReporterType) def resolve_reporter(self, *args, **kwargs): - return Reporter.objects.first() + return models.Reporter.objects.first() query = ''' query ReporterQuery { @@ -154,10 +150,10 @@ def test_should_custom_kwargs(fixtures): class Query(graphene.ObjectType): - editors = graphene.List(EditorType, first=graphene.Int()) + editors = graphene.List(types.EditorType, first=graphene.Int()) def resolve_editors(self, *args, **kwargs): - editors = Editor.objects() + editors = models.Editor.objects() if 'first' in kwargs: editors = editors[:kwargs['first']] return list(editors) @@ -192,10 +188,10 @@ def test_should_self_reference(fixtures): class Query(graphene.ObjectType): - all_players = graphene.List(PlayerType) + all_players = graphene.List(types.PlayerType) def resolve_all_players(self, *args, **kwargs): - return Player.objects.all() + return models.Player.objects.all() query = ''' query PlayersQuery { @@ -260,10 +256,10 @@ def resolve_all_players(self, *args, **kwargs): def test_should_query_with_embedded_document(fixtures): class Query(graphene.ObjectType): - professor_vector = graphene.Field(ProfessorVectorType, id=graphene.String()) + professor_vector = graphene.Field(types.ProfessorVectorType, id=graphene.String()) def resolve_professor_vector(self, info, id): - return ProfessorVector.objects(metadata__id=id).first() + return models.ProfessorVector.objects(metadata__id=id).first() query = """ query { @@ -284,7 +280,8 @@ def resolve_professor_vector(self, info, id): } } } - schema = graphene.Schema(query=Query, types=[ProfessorVectorType]) + schema = graphene.Schema( + query=Query, types=[types.ProfessorVectorType]) result = schema.execute(query) assert not result.errors assert result.data == expected @@ -294,10 +291,10 @@ def test_should_query_child(fixtures): class Query(graphene.ObjectType): - children = graphene.List(ChildType) + children = graphene.List(types.ChildType) def resolve_children(self, *args, **kwargs): - return list(Child.objects.all()) + return list(models.Child.objects.all()) query = ''' query Query { @@ -338,10 +335,10 @@ def test_should_query_cell_tower(fixtures): class Query(graphene.ObjectType): - cell_towers = graphene.List(CellTowerType) + cell_towers = graphene.List(types.CellTowerType) def resolve_cell_towers(self, *args, **kwargs): - return list(CellTower.objects.all()) + return list(models.CellTower.objects.all()) query = ''' query Query { diff --git a/graphene_mongo/tests/test_relay_query.py b/graphene_mongo/tests/test_relay_query.py index cf6a8168..7e26f2fa 100644 --- a/graphene_mongo/tests/test_relay_query.py +++ b/graphene_mongo/tests/test_relay_query.py @@ -6,29 +6,21 @@ from graphene.relay import Node +from . import models +from . import nodes +from . import types from .setup import fixtures, fixtures_dirname -from .models import Article, Reporter -from .nodes import (ArticleNode, - EditorNode, - PlayerNode, - ReporterNode, - ChildNode, - ParentWithRelationshipNode, - ProfessorVectorNode,) from ..fields import MongoengineConnectionField - - -def get_nodes(data, key): - return map(lambda edge: edge['node'], data[key]['edges']) +from ..types import MongoengineObjectType def test_should_query_reporter(fixtures): class Query(graphene.ObjectType): - reporter = graphene.Field(ReporterNode) + reporter = graphene.Field(nodes.ReporterNode) def resolve_reporter(self, *args, **kwargs): - return Reporter.objects.first() + return models.Reporter.objects.first() query = ''' query ReporterQuery { @@ -125,13 +117,13 @@ def resolve_reporter(self, *args, **kwargs): schema = graphene.Schema(query=Query) result = schema.execute(query) assert not result.errors - assert result.data['reporter'] == expected['reporter'] + assert result.data == expected def test_should_query_reporters_with_nested_document(fixtures): class Query(graphene.ObjectType): - reporters = MongoengineConnectionField(ReporterNode) + reporters = MongoengineConnectionField(nodes.ReporterNode) query = ''' query ReporterQuery { @@ -179,30 +171,30 @@ class Query(graphene.ObjectType): schema = graphene.Schema(query=Query) result = schema.execute(query) assert not result.errors - assert result.data['reporters'] == expected['reporters'] + assert result.data == expected def test_should_query_all_editors(fixtures, fixtures_dirname): class Query(graphene.ObjectType): - editors = MongoengineConnectionField(EditorNode) + editors = MongoengineConnectionField(nodes.EditorNode) query = ''' query EditorQuery { - editors { - edges { - node { - id, - firstName, - lastName, - avatar { - contentType, - length, - data + editors { + edges { + node { + id, + firstName, + lastName, + avatar { + contentType, + length, + data + } } } } - } } ''' @@ -256,18 +248,91 @@ class Query(graphene.ObjectType): schema = graphene.Schema(query=Query) result = schema.execute(query) assert not result.errors - assert result.data['editors'] == expected['editors'] + assert result.data == expected + + +def test_should_query_editors_with_dataloader(fixtures): + from promise import Promise + from promise.dataloader import DataLoader + + class ArticleLoader(DataLoader): + + def batch_load_fn(self, instances): + queryset = models.Article.objects(editor__in=instances) + return Promise.resolve([ + [a for a in queryset if a.editor.id == instance.id] + for instance in instances + ]) + + article_loader = ArticleLoader() + + class _EditorNode(MongoengineObjectType): + + class Meta: + model = models.Editor + interfaces = (graphene.Node,) + + articles = MongoengineConnectionField(nodes.ArticleNode) + + def resolve_articles(self, *args, **kwargs): + return article_loader.load(self) + + class Query(graphene.ObjectType): + editors = MongoengineConnectionField(_EditorNode) + + query = ''' + query EditorPromiseConnectionQuery { + editors(first: 1) { + edges { + node { + firstName, + articles(first: 1) { + edges { + node { + headline + } + } + } + } + } + } + } + ''' + + expected = { + 'editors': { + 'edges': [ + { + 'node': { + 'firstName': 'Penny', + 'articles': { + 'edges': [ + { + 'node': { + 'headline': 'Hello' + } + } + ] + } + } + } + ] + } + } + schema = graphene.Schema(query=Query) + result = schema.execute(query) + assert not result.errors + assert result.data == expected def test_should_filter_editors_by_id(fixtures): class Query(graphene.ObjectType): - node = Node.Field() - all_editors = MongoengineConnectionField(EditorNode) + editors = MongoengineConnectionField(nodes.EditorNode) query = ''' query EditorQuery { - allEditors(id: "RWRpdG9yTm9kZToy") { + editors(id: "RWRpdG9yTm9kZToy") { edges { node { id, @@ -279,7 +344,7 @@ class Query(graphene.ObjectType): } ''' expected = { - 'allEditors': { + 'editors': { 'edges': [ { 'node': { @@ -287,7 +352,6 @@ class Query(graphene.ObjectType): 'firstName': 'Grant', 'lastName': 'Hill' } - } ] } @@ -295,14 +359,13 @@ class Query(graphene.ObjectType): schema = graphene.Schema(query=Query) result = schema.execute(query) assert not result.errors - assert dict(result.data['allEditors']) == expected['allEditors'] + assert result.data == expected def test_should_filter(fixtures): class Query(graphene.ObjectType): - node = Node.Field() - articles = MongoengineConnectionField(ArticleNode) + articles = MongoengineConnectionField(nodes.ArticleNode) query = ''' query ArticlesQuery { @@ -343,8 +406,7 @@ class Query(graphene.ObjectType): def test_should_filter_by_reference_field(fixtures): class Query(graphene.ObjectType): - node = Node.Field() - articles = MongoengineConnectionField(ArticleNode) + articles = MongoengineConnectionField(nodes.ArticleNode) query = ''' query ArticlesQuery { @@ -384,7 +446,7 @@ def test_should_filter_through_inheritance(fixtures): class Query(graphene.ObjectType): node = Node.Field() - children = MongoengineConnectionField(ChildNode) + children = MongoengineConnectionField(nodes.ChildNode) query = ''' query ChildrenQuery { @@ -427,7 +489,7 @@ class Query(graphene.ObjectType): def test_should_filter_by_list_contains(fixtures): # Notes: https://goo.gl/hMNRgs class Query(graphene.ObjectType): - reporters = MongoengineConnectionField(ReporterNode) + reporters = MongoengineConnectionField(nodes.ReporterNode) query = ''' query ReportersQuery { @@ -464,7 +526,7 @@ class Query(graphene.ObjectType): def test_should_filter_by_id(fixtures): # Notes: https://goo.gl/hMNRgs class Query(graphene.ObjectType): - reporter = Node.Field(ReporterNode) + reporter = Node.Field(nodes.ReporterNode) query = ''' query ReporterQuery { @@ -492,7 +554,7 @@ def test_should_first_n(fixtures): class Query(graphene.ObjectType): - editors = MongoengineConnectionField(EditorNode) + editors = MongoengineConnectionField(nodes.EditorNode) query = ''' query EditorQuery { @@ -516,13 +578,13 @@ class Query(graphene.ObjectType): 'editors': { 'edges': [ { - 'cursor': 'xxx', + 'cursor': 'YXJyYXljb25uZWN0aW9uOjA=', 'node': { 'firstName': 'Penny' } }, { - 'cursor': 'xxx', + 'cursor': 'YXJyYXljb25uZWN0aW9uOjE=', 'node': { 'firstName': 'Grant' } @@ -531,8 +593,8 @@ class Query(graphene.ObjectType): 'pageInfo': { 'hasNextPage': True, 'hasPreviousPage': False, - 'startCursor': 'xxx', - 'endCursor': 'xxx' + 'startCursor': 'YXJyYXljb25uZWN0aW9uOjA=', + 'endCursor': 'YXJyYXljb25uZWN0aW9uOjE=' } } } @@ -540,14 +602,13 @@ class Query(graphene.ObjectType): result = schema.execute(query) assert not result.errors - assert all(item in get_nodes(result.data, 'editors') - for item in get_nodes(expected, 'editors')) + assert result.data == expected def test_should_after(fixtures): class Query(graphene.ObjectType): - players = MongoengineConnectionField(PlayerNode) + players = MongoengineConnectionField(nodes.PlayerNode) query = ''' query EditorQuery { @@ -595,7 +656,7 @@ class Query(graphene.ObjectType): def test_should_before(fixtures): class Query(graphene.ObjectType): - players = MongoengineConnectionField(PlayerNode) + players = MongoengineConnectionField(nodes.PlayerNode) query = ''' query EditorQuery { @@ -636,7 +697,7 @@ class Query(graphene.ObjectType): def test_should_last_n(fixtures): class Query(graphene.ObjectType): - players = MongoengineConnectionField(PlayerNode) + players = MongoengineConnectionField(nodes.PlayerNode) query = ''' query PlayerQuery { @@ -679,11 +740,11 @@ def test_should_self_reference(fixtures): class Query(graphene.ObjectType): - all_players = MongoengineConnectionField(PlayerNode) + players = MongoengineConnectionField(nodes.PlayerNode) query = ''' query PlayersQuery { - allPlayers { + players { edges { node { firstName, @@ -707,7 +768,7 @@ class Query(graphene.ObjectType): } ''' expected = { - 'allPlayers': { + 'players': { 'edges': [ { 'node': { @@ -790,7 +851,7 @@ def test_should_lazy_reference(fixtures): class Query(graphene.ObjectType): node = Node.Field() - parents = MongoengineConnectionField(ParentWithRelationshipNode) + parents = MongoengineConnectionField(nodes.ParentWithRelationshipNode) schema = graphene.Schema(query=Query) @@ -855,11 +916,11 @@ def test_should_query_with_embedded_document(fixtures): class Query(graphene.ObjectType): - all_professors = MongoengineConnectionField(ProfessorVectorNode) + professors = MongoengineConnectionField(nodes.ProfessorVectorNode) query = ''' query { - allProfessors { + professors { edges { node { vec, @@ -872,7 +933,7 @@ class Query(graphene.ObjectType): } ''' expected = { - 'allProfessors': { + 'professors': { 'edges': [ { 'node': { @@ -889,14 +950,14 @@ class Query(graphene.ObjectType): schema = graphene.Schema(query=Query) result = schema.execute(query) assert not result.errors - assert dict(result.data['allProfessors']) == expected['allProfessors'] + assert result.data == expected def test_should_get_queryset_returns_dict_filters(fixtures): class Query(graphene.ObjectType): node = Node.Field() - articles = MongoengineConnectionField(ArticleNode, get_queryset=lambda *_, **__: {"headline": "World"}) + articles = MongoengineConnectionField(nodes.ArticleNode, get_queryset=lambda *_, **__: {"headline": "World"}) query = ''' query ArticlesQuery { @@ -941,7 +1002,7 @@ def get_queryset(model, info, **args): class Query(graphene.ObjectType): node = Node.Field() - articles = MongoengineConnectionField(ArticleNode, get_queryset=get_queryset) + articles = MongoengineConnectionField(nodes.ArticleNode, get_queryset=get_queryset) query = ''' query ArticlesQuery { diff --git a/setup.cfg b/setup.cfg index fef3f048..1a477837 100644 --- a/setup.cfg +++ b/setup.cfg @@ -19,5 +19,4 @@ known_first_party=graphene,graphene_mongo test=pytest [tool:pytest] -addopts=-vv python_files = graphene_mongo/tests/*.py