Skip to content

Commit

Permalink
Merge pull request #91 from graphql-python/feat-data-loader
Browse files Browse the repository at this point in the history
Feat data loader
  • Loading branch information
abawchen authored May 14, 2019
2 parents 98c2d76 + 9a35410 commit 008dbc2
Show file tree
Hide file tree
Showing 4 changed files with 167 additions and 97 deletions.
33 changes: 23 additions & 10 deletions graphene_mongo/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import graphene
import mongoengine
from promise import Promise
from graphene.relay import ConnectionField
from graphene.types.argument import to_arguments
from graphene.types.dynamic import Dynamic
Expand Down Expand Up @@ -126,7 +127,6 @@ def fields(self):
return self._type._meta.fields

def get_queryset(self, model, info, **args):

if args:
reference_fields = get_model_reference_fields(self.model)
hydrated_references = {}
Expand Down Expand Up @@ -157,36 +157,49 @@ def default_resolver(self, _root, info, **args):
_id = args.pop('id', None)

if _id is not None:
objs = [get_node_from_global_id(self.node_type, info, _id)]
iterables = [get_node_from_global_id(self.node_type, info, _id)]
list_length = 1
elif callable(getattr(self.model, 'objects', None)):
objs = self.get_queryset(self.model, info, **args)
list_length = objs.count()
iterables = self.get_queryset(self.model, info, **args)
list_length = iterables.count()
else:
objs = []
iterables = []
list_length = 0

connection = connection_from_list_slice(
list_slice=objs,
list_slice=iterables,
args=connection_args,
list_length=list_length,
connection_type=self.type,
edge_type=self.type.Edge,
pageinfo_type=graphene.PageInfo,
)
connection.iterable = objs
connection.iterable = iterables
connection.list_length = list_length
return connection

def chained_resolver(self, resolver, root, info, **args):
if not bool(args):
def chained_resolver(self, resolver, is_partial, root, info, **args):
if not bool(args) or not is_partial:
# XXX: Filter nested args
resolved = resolver(root, info, **args)
if resolved is not None:
return resolved
return self.default_resolver(root, info, **args)

@classmethod
def connection_resolver(cls, resolver, connection_type, root, info, **args):
iterable = resolver(root, info, **args)
if isinstance(connection_type, graphene.NonNull):
connection_type = connection_type.of_type

on_resolve = partial(cls.resolve_connection, connection_type, args)
if Promise.is_thenable(iterable):
return Promise.resolve(iterable).then(on_resolve)

return on_resolve(iterable)

def get_resolver(self, parent_resolver):
super_resolver = self.resolver or parent_resolver
resolver = partial(self.chained_resolver, super_resolver)
resolver = partial(
self.chained_resolver, super_resolver, isinstance(super_resolver, partial))
return partial(self.connection_resolver, resolver, self.type)
43 changes: 20 additions & 23 deletions graphene_mongo/tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,23 @@
import json
import graphene

from . import models
from . import types
from .setup import fixtures, fixtures_dirname
from .models import (
Child, Editor, Player, Reporter, ProfessorVector, Parent, CellTower
)
from .types import (
ChildType, EditorType, PlayerType, ReporterType, ProfessorVectorType, ParentType, CellTowerType
)


def test_should_query_editor(fixtures, fixtures_dirname):

class Query(graphene.ObjectType):

editor = graphene.Field(EditorType)
editors = graphene.List(EditorType)
editor = graphene.Field(types.EditorType)
editors = graphene.List(types.EditorType)

def resolve_editor(self, *args, **kwargs):
return Editor.objects.first()
return models.Editor.objects.first()

def resolve_editors(self, *args, **kwargs):
return list(Editor.objects.all())
return list(models.Editor.objects.all())

query = '''
query EditorQuery {
Expand Down Expand Up @@ -91,10 +87,10 @@ def resolve_editors(self, *args, **kwargs):
def test_should_query_reporter(fixtures):

class Query(graphene.ObjectType):
reporter = graphene.Field(ReporterType)
reporter = graphene.Field(types.ReporterType)

def resolve_reporter(self, *args, **kwargs):
return Reporter.objects.first()
return models.Reporter.objects.first()

query = '''
query ReporterQuery {
Expand Down Expand Up @@ -154,10 +150,10 @@ def test_should_custom_kwargs(fixtures):

class Query(graphene.ObjectType):

editors = graphene.List(EditorType, first=graphene.Int())
editors = graphene.List(types.EditorType, first=graphene.Int())

def resolve_editors(self, *args, **kwargs):
editors = Editor.objects()
editors = models.Editor.objects()
if 'first' in kwargs:
editors = editors[:kwargs['first']]
return list(editors)
Expand Down Expand Up @@ -192,10 +188,10 @@ def test_should_self_reference(fixtures):

class Query(graphene.ObjectType):

all_players = graphene.List(PlayerType)
all_players = graphene.List(types.PlayerType)

def resolve_all_players(self, *args, **kwargs):
return Player.objects.all()
return models.Player.objects.all()

query = '''
query PlayersQuery {
Expand Down Expand Up @@ -260,10 +256,10 @@ def resolve_all_players(self, *args, **kwargs):
def test_should_query_with_embedded_document(fixtures):

class Query(graphene.ObjectType):
professor_vector = graphene.Field(ProfessorVectorType, id=graphene.String())
professor_vector = graphene.Field(types.ProfessorVectorType, id=graphene.String())

def resolve_professor_vector(self, info, id):
return ProfessorVector.objects(metadata__id=id).first()
return models.ProfessorVector.objects(metadata__id=id).first()

query = """
query {
Expand All @@ -284,7 +280,8 @@ def resolve_professor_vector(self, info, id):
}
}
}
schema = graphene.Schema(query=Query, types=[ProfessorVectorType])
schema = graphene.Schema(
query=Query, types=[types.ProfessorVectorType])
result = schema.execute(query)
assert not result.errors
assert result.data == expected
Expand All @@ -294,10 +291,10 @@ def test_should_query_child(fixtures):

class Query(graphene.ObjectType):

children = graphene.List(ChildType)
children = graphene.List(types.ChildType)

def resolve_children(self, *args, **kwargs):
return list(Child.objects.all())
return list(models.Child.objects.all())

query = '''
query Query {
Expand Down Expand Up @@ -338,10 +335,10 @@ def test_should_query_cell_tower(fixtures):

class Query(graphene.ObjectType):

cell_towers = graphene.List(CellTowerType)
cell_towers = graphene.List(types.CellTowerType)

def resolve_cell_towers(self, *args, **kwargs):
return list(CellTower.objects.all())
return list(models.CellTower.objects.all())

query = '''
query Query {
Expand Down
Loading

0 comments on commit 008dbc2

Please sign in to comment.