Skip to content

Commit e9d5e88

Browse files
committed
Handle the default django list field and test the async execution of the fields
1 parent c10753d commit e9d5e88

File tree

3 files changed

+184
-13
lines changed

3 files changed

+184
-13
lines changed

graphene_django/fields.py

+39-13
Original file line numberDiff line numberDiff line change
@@ -53,20 +53,41 @@ def get_manager(self):
5353
def list_resolver(
5454
django_object_type, resolver, default_manager, root, info, **args
5555
):
56-
queryset = maybe_queryset(resolver(root, info, **args))
56+
iterable = resolver(root, info, **args)
57+
58+
if info.is_awaitable(iterable):
59+
60+
async def resolve_list_async(iterable):
61+
queryset = maybe_queryset(await iterable)
62+
if queryset is None:
63+
queryset = maybe_queryset(default_manager)
64+
65+
if isinstance(queryset, QuerySet):
66+
# Pass queryset to the DjangoObjectType get_queryset method
67+
queryset = maybe_queryset(
68+
await sync_to_async(django_object_type.get_queryset)(
69+
queryset, info
70+
)
71+
)
72+
73+
return await sync_to_async(list)(queryset)
74+
75+
return resolve_list_async(iterable)
76+
77+
queryset = maybe_queryset(iterable)
5778
if queryset is None:
5879
queryset = maybe_queryset(default_manager)
5980

6081
if isinstance(queryset, QuerySet):
6182
# Pass queryset to the DjangoObjectType get_queryset method
6283
queryset = maybe_queryset(django_object_type.get_queryset(queryset, info))
6384

64-
try:
85+
try:
6586
get_running_loop()
6687
except RuntimeError:
67-
pass
88+
pass
6889
else:
69-
return queryset.aiterator()
90+
return sync_to_async(list)(queryset)
7091

7192
return queryset
7293

@@ -238,34 +259,39 @@ def connection_resolver(
238259
# or a resolve_foo (does not accept queryset)
239260

240261
iterable = resolver(root, info, **args)
241-
262+
242263
if info.is_awaitable(iterable):
264+
243265
async def resolve_connection_async(iterable):
244266
iterable = await iterable
245267
if iterable is None:
246268
iterable = default_manager
247269
## This could also be async
248270
iterable = queryset_resolver(connection, iterable, info, args)
249-
271+
250272
if info.is_awaitable(iterable):
251273
iterable = await iterable
252-
253-
return await sync_to_async(cls.resolve_connection)(connection, args, iterable, max_limit=max_limit)
274+
275+
return await sync_to_async(cls.resolve_connection)(
276+
connection, args, iterable, max_limit=max_limit
277+
)
278+
254279
return resolve_connection_async(iterable)
255-
280+
256281
if iterable is None:
257282
iterable = default_manager
258283
# thus the iterable gets refiltered by resolve_queryset
259284
# but iterable might be promise
260285
iterable = queryset_resolver(connection, iterable, info, args)
261286

262-
try:
287+
try:
263288
get_running_loop()
264289
except RuntimeError:
265-
pass
290+
pass
266291
else:
267-
return sync_to_async(cls.resolve_connection)(connection, args, iterable, max_limit=max_limit)
268-
292+
return sync_to_async(cls.resolve_connection)(
293+
connection, args, iterable, max_limit=max_limit
294+
)
269295

270296
return cls.resolve_connection(connection, args, iterable, max_limit=max_limit)
271297

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from asgiref.sync import async_to_sync
2+
3+
4+
def assert_async_result_equal(schema, query, result):
5+
async_result = async_to_sync(schema.execute_async)(query)
6+
assert async_result == result

graphene_django/tests/test_fields.py

+139
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import datetime
22
import re
33
from django.db.models import Count, Prefetch
4+
from asgiref.sync import sync_to_async, async_to_sync
45

56
import pytest
67

@@ -14,6 +15,7 @@
1415
FilmDetails as FilmDetailsModel,
1516
Reporter as ReporterModel,
1617
)
18+
from .async_test_helper import assert_async_result_equal
1719

1820

1921
class TestDjangoListField:
@@ -75,6 +77,7 @@ class Query(ObjectType):
7577

7678
result = schema.execute(query)
7779

80+
assert_async_result_equal(schema, query, result)
7881
assert not result.errors
7982
assert result.data == {
8083
"reporters": [{"firstName": "Tara"}, {"firstName": "Debra"}]
@@ -102,6 +105,7 @@ class Query(ObjectType):
102105
result = schema.execute(query)
103106
assert not result.errors
104107
assert result.data == {"reporters": []}
108+
assert_async_result_equal(schema, query, result)
105109

106110
ReporterModel.objects.create(first_name="Tara", last_name="West")
107111
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
@@ -112,6 +116,7 @@ class Query(ObjectType):
112116
assert result.data == {
113117
"reporters": [{"firstName": "Tara"}, {"firstName": "Debra"}]
114118
}
119+
assert_async_result_equal(schema, query, result)
115120

116121
def test_override_resolver(self):
117122
class Reporter(DjangoObjectType):
@@ -139,6 +144,37 @@ def resolve_reporters(_, info):
139144
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
140145

141146
result = schema.execute(query)
147+
assert not result.errors
148+
assert result.data == {"reporters": [{"firstName": "Tara"}]}
149+
150+
def test_override_resolver_async_execution(self):
151+
class Reporter(DjangoObjectType):
152+
class Meta:
153+
model = ReporterModel
154+
fields = ("first_name",)
155+
156+
class Query(ObjectType):
157+
reporters = DjangoListField(Reporter)
158+
159+
@staticmethod
160+
@sync_to_async
161+
def resolve_reporters(_, info):
162+
return ReporterModel.objects.filter(first_name="Tara")
163+
164+
schema = Schema(query=Query)
165+
166+
query = """
167+
query {
168+
reporters {
169+
firstName
170+
}
171+
}
172+
"""
173+
174+
ReporterModel.objects.create(first_name="Tara", last_name="West")
175+
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
176+
177+
result = async_to_sync(schema.execute_async)(query)
142178

143179
assert not result.errors
144180
assert result.data == {"reporters": [{"firstName": "Tara"}]}
@@ -203,6 +239,7 @@ class Query(ObjectType):
203239
{"firstName": "Debra", "articles": []},
204240
]
205241
}
242+
assert_async_result_equal(schema, query, result)
206243

207244
def test_override_resolver_nested_list_field(self):
208245
class Article(DjangoObjectType):
@@ -261,6 +298,7 @@ class Query(ObjectType):
261298
{"firstName": "Debra", "articles": []},
262299
]
263300
}
301+
assert_async_result_equal(schema, query, result)
264302

265303
def test_get_queryset_filter(self):
266304
class Reporter(DjangoObjectType):
@@ -306,6 +344,7 @@ def resolve_reporters(_, info):
306344

307345
assert not result.errors
308346
assert result.data == {"reporters": [{"firstName": "Tara"}]}
347+
assert_async_result_equal(schema, query, result)
309348

310349
def test_resolve_list(self):
311350
"""Resolving a plain list should work (and not call get_queryset)"""
@@ -354,6 +393,55 @@ def resolve_reporters(_, info):
354393
assert not result.errors
355394
assert result.data == {"reporters": [{"firstName": "Debra"}]}
356395

396+
def test_resolve_list_async(self):
397+
"""Resolving a plain list should work (and not call get_queryset) when running under async"""
398+
399+
class Reporter(DjangoObjectType):
400+
class Meta:
401+
model = ReporterModel
402+
fields = ("first_name", "articles")
403+
404+
@classmethod
405+
def get_queryset(cls, queryset, info):
406+
# Only get reporters with at least 1 article
407+
return queryset.annotate(article_count=Count("articles")).filter(
408+
article_count__gt=0
409+
)
410+
411+
class Query(ObjectType):
412+
reporters = DjangoListField(Reporter)
413+
414+
@staticmethod
415+
@sync_to_async
416+
def resolve_reporters(_, info):
417+
return [ReporterModel.objects.get(first_name="Debra")]
418+
419+
schema = Schema(query=Query)
420+
421+
query = """
422+
query {
423+
reporters {
424+
firstName
425+
}
426+
}
427+
"""
428+
429+
r1 = ReporterModel.objects.create(first_name="Tara", last_name="West")
430+
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
431+
432+
ArticleModel.objects.create(
433+
headline="Amazing news",
434+
reporter=r1,
435+
pub_date=datetime.date.today(),
436+
pub_date_time=datetime.datetime.now(),
437+
editor=r1,
438+
)
439+
440+
result = async_to_sync(schema.execute_async)(query)
441+
442+
assert not result.errors
443+
assert result.data == {"reporters": [{"firstName": "Debra"}]}
444+
357445
def test_get_queryset_foreign_key(self):
358446
class Article(DjangoObjectType):
359447
class Meta:
@@ -413,6 +501,7 @@ class Query(ObjectType):
413501
{"firstName": "Debra", "articles": []},
414502
]
415503
}
504+
assert_async_result_equal(schema, query, result)
416505

417506
def test_resolve_list_external_resolver(self):
418507
"""Resolving a plain list from external resolver should work (and not call get_queryset)"""
@@ -461,6 +550,54 @@ class Query(ObjectType):
461550
assert not result.errors
462551
assert result.data == {"reporters": [{"firstName": "Debra"}]}
463552

553+
def test_resolve_list_external_resolver_async(self):
554+
"""Resolving a plain list from external resolver should work (and not call get_queryset)"""
555+
556+
class Reporter(DjangoObjectType):
557+
class Meta:
558+
model = ReporterModel
559+
fields = ("first_name", "articles")
560+
561+
@classmethod
562+
def get_queryset(cls, queryset, info):
563+
# Only get reporters with at least 1 article
564+
return queryset.annotate(article_count=Count("articles")).filter(
565+
article_count__gt=0
566+
)
567+
568+
@sync_to_async
569+
def resolve_reporters(_, info):
570+
return [ReporterModel.objects.get(first_name="Debra")]
571+
572+
class Query(ObjectType):
573+
reporters = DjangoListField(Reporter, resolver=resolve_reporters)
574+
575+
schema = Schema(query=Query)
576+
577+
query = """
578+
query {
579+
reporters {
580+
firstName
581+
}
582+
}
583+
"""
584+
585+
r1 = ReporterModel.objects.create(first_name="Tara", last_name="West")
586+
ReporterModel.objects.create(first_name="Debra", last_name="Payne")
587+
588+
ArticleModel.objects.create(
589+
headline="Amazing news",
590+
reporter=r1,
591+
pub_date=datetime.date.today(),
592+
pub_date_time=datetime.datetime.now(),
593+
editor=r1,
594+
)
595+
596+
result = async_to_sync(schema.execute_async)(query)
597+
598+
assert not result.errors
599+
assert result.data == {"reporters": [{"firstName": "Debra"}]}
600+
464601
def test_get_queryset_filter_external_resolver(self):
465602
class Reporter(DjangoObjectType):
466603
class Meta:
@@ -505,6 +642,7 @@ class Query(ObjectType):
505642

506643
assert not result.errors
507644
assert result.data == {"reporters": [{"firstName": "Tara"}]}
645+
assert_async_result_equal(schema, query, result)
508646

509647
def test_select_related_and_prefetch_related_are_respected(
510648
self, django_assert_num_queries
@@ -647,3 +785,4 @@ def resolve_articles(root, info):
647785
r'SELECT .* FROM "tests_film" INNER JOIN "tests_film_reporters" .* LEFT OUTER JOIN "tests_filmdetails"',
648786
captured.captured_queries[1]["sql"],
649787
)
788+
assert_async_result_equal(schema, query, result)

0 commit comments

Comments
 (0)