Skip to content

Commit 38b1aa5

Browse files
ahopkinsBenJeau
andauthored
fix: injecting async dependencies (#121)
Co-authored-by: Benoit Jeaurond <[email protected]>
1 parent a7f63a0 commit 38b1aa5

File tree

3 files changed

+69
-3
lines changed

3 files changed

+69
-3
lines changed

sanic_ext/extensions/injection/constructor.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from inspect import isawaitable
3+
from inspect import iscoroutine
44
from typing import (
55
TYPE_CHECKING,
66
Any,
@@ -44,7 +44,7 @@ async def __call__(self, request, **kwargs):
4444
if self.pass_kwargs:
4545
args.update(kwargs)
4646
retval = self.func(request, **args)
47-
if isawaitable(retval):
47+
if iscoroutine(retval):
4848
retval = await retval
4949
return retval
5050
except TypeError as e:
@@ -136,6 +136,6 @@ async def do_cast(_type, constructor, request, **kwargs):
136136
args = [request] if constructor else []
137137

138138
retval = cast(*args, **kwargs)
139-
if isawaitable(retval):
139+
if iscoroutine(retval):
140140
retval = await retval
141141
return retval

tests/extensions/injection/test_add_dependency.py

+45
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,14 @@ async def create(cls, request, person_id: int) -> Person:
3838
return cls(person_id=PersonID(person_id), name="noname", age=111)
3939

4040

41+
@dataclass
42+
class AsyncName:
43+
name: str
44+
45+
def __await__(self):
46+
return self
47+
48+
4149
counter = count()
4250

4351

@@ -235,3 +243,40 @@ async def handler(request, ws, foo: A):
235243

236244
request, response = app.test_client.websocket("/foo")
237245
assert ev.is_set()
246+
247+
248+
def test_injection_of_awaitable_variable_in_do_cast(app):
249+
"""Test for do_cast() iscoroutine() check"""
250+
251+
@app.get("/person/<name:str>")
252+
def handler(request, name: AsyncName):
253+
request.ctx.name = name
254+
return text(name.name)
255+
256+
app.ext.add_dependency(AsyncName)
257+
258+
request, response = app.test_client.get("/person/george")
259+
260+
assert response.body == b"george"
261+
assert isinstance(request.ctx.name, AsyncName)
262+
assert request.ctx.name.name == "george"
263+
264+
265+
def test_injection_of_awaitable_variable_in_call(app):
266+
"""Test for __call__() iscoroutine() check"""
267+
268+
@app.get("/person/<name:str>")
269+
def handler(request, name: AsyncName):
270+
request.ctx.name = name
271+
return text(name.name)
272+
273+
def test():
274+
return AsyncName("george")
275+
276+
app.ext.dependency(test())
277+
278+
request, response = app.test_client.get("/person/george")
279+
280+
assert response.body == b"george"
281+
assert isinstance(request.ctx.name, AsyncName)
282+
assert request.ctx.name.name == "george"

tests/extensions/openapi/test_parameter.py

+21
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,20 @@ async def handler4(request: Request, val1: int):
7979
async def handler5(request: Request, val1: int):
8080
return text("ok")
8181

82+
@app.route("/test6/<val1:strorempty>")
83+
async def handler6(request: Request, val1: str):
84+
"""
85+
openapi:
86+
---
87+
operationId: get.test1
88+
parameters:
89+
- name: val1
90+
in: path
91+
description: val1 path param
92+
required: false
93+
"""
94+
return text("ok")
95+
8296
spec = get_spec(app)
8397
for i in range(1, 6):
8498
assert f"/test{i}/{{val1}}" in spec["paths"]
@@ -88,3 +102,10 @@ async def handler5(request: Request, val1: int):
88102
assert parameter["required"] is True
89103
assert parameter["schema"]["type"] == TYPE
90104
assert parameter["description"] == DESCRIPTION
105+
106+
assert "/test6/{val1}" in spec["paths"]
107+
parameter = spec["paths"]["/test6/{val1}"]["get"]["parameters"][0]
108+
assert parameter["name"] == NAME
109+
assert parameter["in"] == LOCATION
110+
assert parameter["required"] is False
111+
assert parameter["description"] == DESCRIPTION

0 commit comments

Comments
 (0)