From cb601ac23044a8d513709bbd135e13d54fa643ac Mon Sep 17 00:00:00 2001 From: Na'aman Hirschfeld Date: Sat, 22 Jan 2022 18:28:50 +0100 Subject: [PATCH 1/7] rewrote path resolution logic --- starlite/app.py | 31 ++++++++++++++++- starlite/asgi.py | 42 +++++++++++++++++++---- starlite/routing.py | 64 ++++++++++++++++++----------------- tests/test_path_resolution.py | 20 +++++++++++ 4 files changed, 118 insertions(+), 39 deletions(-) diff --git a/starlite/app.py b/starlite/app.py index b6b3e1b2d0..2fd07a9934 100644 --- a/starlite/app.py +++ b/starlite/app.py @@ -1,4 +1,5 @@ -from typing import Dict, List, Optional, Union, cast +from typing import Dict, List, Optional, Union, cast, Set, Any +from urllib.parse import urlparse from openapi_schema_pydantic import OpenAPI, Schema from openapi_schema_pydantic.util import construct_open_api_with_schema_class @@ -53,6 +54,7 @@ class Starlite(Router): "openapi_schema", "plugins", "state", + "route_map" # the rest of __slots__ are defined in Router and should not be duplicated # see: https://stackoverflow.com/questions/472000/usage-of-slots ) @@ -86,6 +88,7 @@ def __init__( # pylint: disable=too-many-locals self.state = State() self.plugins = plugins or [] self.routes: List[Union[BaseRoute, Mount]] = [] # type: ignore + self.route_map: Dict[str, Any] = {} super().__init__( dependencies=dependencies, guards=guards, @@ -117,11 +120,36 @@ def __init__( # pylint: disable=too-many-locals static_files = StaticFiles(html=config.html_mode, check_dir=False) static_files.all_directories = config.directories # type: ignore self.routes.append(Mount(path, static_files)) + self.construct_route_map() async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: scope["app"] = self await self.middleware_stack(scope, receive, send) + def construct_route_map(self): + for route in self.routes: + path = route.path + for param_definition in route.path_parameters: + path = path.replace(param_definition["full"], "") + path = path.replace("{}", "*") + components = path.split("/") if path not in ["/", None, ""] else ["_root"] + cur = self.route_map + for component in components: + if "_components" not in cur: + cur["_components"] = set() + components_set = cast(Set[str], cur["_components"]) + components_set.add(component) + if component not in cur: + cur[component] = {} + cur = cast(Dict[str, Any], cur[component]) + if "_handlers" not in cur: + cur["_handlers"] = {} + handlers = cast(Dict[str, BaseRoute], cur["_handlers"]) + if isinstance(route, HTTPRoute): + handlers["http"] = route + else: + handlers["websocket"] = route + def register(self, value: ControllerRouterHandler) -> None: # type: ignore[override] """ Register a Controller, Route instance or RouteHandler on the app. @@ -131,6 +159,7 @@ def register(self, value: ControllerRouterHandler) -> None: # type: ignore[over handlers = super().register(value=value) for route_handler in handlers: self.create_handler_signature_model(route_handler=route_handler) + self.construct_route_map() def create_handler_signature_model(self, route_handler: BaseRouteHandler) -> None: """ diff --git a/starlite/asgi.py b/starlite/asgi.py index b8a60bc02f..0516f49e91 100644 --- a/starlite/asgi.py +++ b/starlite/asgi.py @@ -1,9 +1,12 @@ from inspect import getfullargspec, isawaitable -from typing import TYPE_CHECKING, Any, List +from typing import TYPE_CHECKING, Any, List, cast, Set, Dict, Union -from starlette.routing import Router as StarletteRouter +from starlette.routing import Router as StarletteRouter, WebSocketRoute +from starlette.types import Scope, Receive, Send +from starlite.exceptions import NotFoundException from starlite.types import LifeCycleHandler +from starlite.routing import HTTPRoute if TYPE_CHECKING: # pragma: no cover from starlite.app import Starlite @@ -24,13 +27,38 @@ def __init__( self.app = app super().__init__(redirect_slashes=redirect_slashes, on_startup=on_startup, on_shutdown=on_shutdown) - def __getattribute__(self, key: str) -> Any: + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: """ - We override attribute access to return the app routes + The main entry point to the Router class. """ - if key == "routes": - return self.app.routes - return super().__getattribute__(key) + + scope_type = scope["type"] + + if scope_type == "lifespan": + await self.lifespan(scope, receive, send) + return + path_params: List[str] = [] + path = cast(str, scope["path"]).strip() + if path != "/" and path.endswith("/"): + path = path.rstrip("/") + components = path.split("/") if path != "/" else ["_root"] + cur = self.app.route_map + for component in components: + components_set = cast(Set[str], cur["_components"]) + if component in components_set: + cur = cast(Dict[str, Any], cur[component]) + elif "*" in components_set: + path_params.append(component) + cur = cast(Dict[str, Any], cur["*"]) + else: + raise NotFoundException() + handlers = cast(Dict[str, Any], cur["_handlers"]) + try: + route = cast(Union[WebSocketRoute, HTTPRoute], handlers[scope_type]) + scope["path_params"] = route.parse_path_params(path_params) + await route.handle(scope=scope, receive=receive, send=send) + except KeyError: + raise NotFoundException() async def call_lifecycle_handler(self, handler: LifeCycleHandler) -> None: """ diff --git a/starlite/routing.py b/starlite/routing.py index feaf172db1..e2b490c385 100644 --- a/starlite/routing.py +++ b/starlite/routing.py @@ -2,16 +2,17 @@ from abc import ABC from inspect import isclass from typing import Any, Dict, ItemsView, List, Optional, Tuple, Union, cast +from uuid import UUID from pydantic import validate_arguments from pydantic.typing import AnyCallable -from starlette.routing import Match, compile_path, get_name +from starlette.routing import get_name from starlette.types import Receive, Scope, Send from typing_extensions import Type from starlite.controller import Controller from starlite.enums import HttpMethod, ScopeType -from starlite.exceptions import ImproperlyConfiguredException, MethodNotAllowedException +from starlite.exceptions import ImproperlyConfiguredException, MethodNotAllowedException, ValidationException from starlite.handlers import BaseRouteHandler, HTTPRouteHandler, WebsocketRouteHandler from starlite.provide import Provide from starlite.request import Request, WebSocket @@ -28,6 +29,13 @@ param_match_regex = re.compile(r"{(.*?)}") +param_type_map = { + "str": str, + "int": int, + "float": float, + "uuid": UUID +} + class BaseRoute(ABC): __slots__ = ( @@ -36,9 +44,7 @@ class BaseRoute(ABC): "methods", "param_convertors", "path", - "path_format", "path_parameters", - "path_regex", "scope_type", ) @@ -51,45 +57,41 @@ def __init__( scope_type: ScopeType, methods: Optional[List[Method]] = None, ): - if not path.startswith("/"): - raise ImproperlyConfiguredException("Routed paths must start with '/'") + self.path, self.path_parameters = self.parse_path(path) self.handler_names = handler_names - self.path = path self.scope_type = scope_type - self.path_regex, self.path_format, self.param_convertors = compile_path(path) - self.path_parameters: List[str] = param_match_regex.findall(self.path) - self.methods = methods or [] if "GET" in self.methods: self.methods.append("HEAD") - for parameter in self.path_parameters: - if ":" not in parameter or not parameter.split(":")[1]: + + def parse_path(self, path: str) -> Tuple[str, List[Dict[str, Any]]]: + path = normalize_path(path) + path_parameters = [] + + for param in param_match_regex.findall(path): + if ":" not in param: raise ImproperlyConfiguredException("path parameter must declare a type: '{parameter_name:type}'") + param_name, param_type = param.split(":") + path_parameters.append({"name": param_name, "type": param_type_map[param_type], "full": param}) + + return path, path_parameters @property def is_http_route(self) -> bool: """Determines whether the given route is an http or websocket route""" return self.scope_type == "http" - def matches(self, scope: Scope) -> Tuple[Match, Scope]: - """ - Try to match a given scope's path to self.path - - Note: The code in this method is adapted from starlette.routing - """ - if scope["type"] == self.scope_type.value: - match = self.path_regex.match(scope["path"]) - if match: - matched_params = match.groupdict() - for key, value in matched_params.items(): - matched_params[key] = self.param_convertors[key].convert(value) - path_params = dict(scope.get("path_params", {})) - path_params.update(matched_params) - child_scope = {"endpoint": self, "path_params": path_params} - if self.is_http_route and scope["method"] not in self.methods: - return Match.PARTIAL, child_scope - return Match.FULL, child_scope - return Match.NONE, {} + def parse_path_params(self, raw_params: List[str]) -> Dict[str, Any]: + try: + parsed_params: Dict[str, Any] = {} + for index, param_definition in enumerate(self.path_parameters): + raw_param = raw_params[index] + param_name = cast(str, param_definition["name"]) + param_type = cast(Type, param_definition["type"]) + parsed_params[param_name] = param_type(raw_param) + return parsed_params + except TypeError as e: + raise ValidationException from e class HTTPRoute(BaseRoute): diff --git a/tests/test_path_resolution.py b/tests/test_path_resolution.py index ba578cf4a6..7d23c0dd61 100644 --- a/tests/test_path_resolution.py +++ b/tests/test_path_resolution.py @@ -1,3 +1,5 @@ +from uuid import uuid4 + import pytest from starlette.status import HTTP_200_OK, HTTP_204_NO_CONTENT @@ -10,6 +12,24 @@ def root_delete_handler() -> None: return None +@pytest.mark.parametrize( + "request_path, router_path", + [ + [f"/path/1/2/sub/{str(uuid4())}", "/path/{first:int}/{second:str}/sub/{third:uuid}"], + [f"/path/1/2/sub/{str(uuid4())}/", "/path/{first:int}/{second:str}/sub/{third:uuid}/"], + ["/", "/"], + ["", ""] + ], +) +def test_path_parsing_and_matching(request_path, router_path): + @get(path=router_path) + def test_method() -> None: + return None + + with create_test_client(test_method) as client: + response = client.get(request_path) + assert response.status_code == HTTP_200_OK + @pytest.mark.parametrize( "decorator, test_path, decorator_path, delete_handler", [ From 76d0219dd664ee3d345412fa93255ffa8e2ae451 Mon Sep 17 00:00:00 2001 From: Na'aman Hirschfeld Date: Sat, 22 Jan 2022 18:50:52 +0100 Subject: [PATCH 2/7] optimized handle methods --- starlite/app.py | 9 ++++--- starlite/handlers.py | 8 +++++-- starlite/openapi/parameters.py | 21 ++++------------- starlite/routing.py | 40 ++++++++++++++++++-------------- tests/openapi/test_parameters.py | 9 +------ tests/test_path_resolution.py | 25 +++++++++++++++++++- tests/test_route.py | 14 +---------- 7 files changed, 65 insertions(+), 61 deletions(-) diff --git a/starlite/app.py b/starlite/app.py index 2fd07a9934..074749858c 100644 --- a/starlite/app.py +++ b/starlite/app.py @@ -1,5 +1,4 @@ -from typing import Dict, List, Optional, Union, cast, Set, Any -from urllib.parse import urlparse +from typing import Any, Dict, List, Optional, Set, Union, cast from openapi_schema_pydantic import OpenAPI, Schema from openapi_schema_pydantic.util import construct_open_api_with_schema_class @@ -23,7 +22,7 @@ from starlite.datastructures import State from starlite.enums import MediaType from starlite.exceptions import HTTPException -from starlite.handlers import BaseRouteHandler +from starlite.handlers import BaseRouteHandler, HTTPRouteHandler from starlite.openapi.path_item import create_path_item from starlite.plugins.base import PluginProtocol from starlite.provide import Provide @@ -172,6 +171,10 @@ def create_handler_signature_model(self, route_handler: BaseRouteHandler) -> Non for provider in list(route_handler.resolve_dependencies().values()): if not provider.signature_model: provider.signature_model = create_function_signature_model(fn=provider.dependency, plugins=self.plugins) + route_handler.resolve_guards() + if isinstance(route_handler, HTTPRouteHandler): + route_handler.resolve_before_request() + route_handler.resolve_after_request() def build_middleware_stack( self, diff --git a/starlite/handlers.py b/starlite/handlers.py index 47c9c42cd0..5abe2b914b 100644 --- a/starlite/handlers.py +++ b/starlite/handlers.py @@ -395,7 +395,8 @@ async def handle_request(self, request: Request) -> StarletteResponse: if not self.fn: raise ImproperlyConfiguredException("cannot call 'handle' without a decorated function") - await self.authorize_connection(connection=request) + if self.resolve_guards(): + await self.authorize_connection(connection=request) before_request_handler = self.resolve_before_request() data = None @@ -407,7 +408,10 @@ async def handle_request(self, request: Request) -> StarletteResponse: # if data has not been returned by the before request handler, we proceed with the request if data is None: - params = await self.get_parameters_from_connection(connection=request) + if self.signature_model.__fields__: + params = await self.get_parameters_from_connection(connection=request) + else: + params = {} if isinstance(self.owner, Controller): data = self.fn(self.owner, **params) else: diff --git a/starlite/openapi/parameters.py b/starlite/openapi/parameters.py index 5a693c1043..12116d3e74 100644 --- a/starlite/openapi/parameters.py +++ b/starlite/openapi/parameters.py @@ -1,27 +1,16 @@ from typing import Any, Dict, List -from uuid import UUID from openapi_schema_pydantic import Parameter, Schema from pydantic.fields import ModelField -from typing_extensions import Type from starlite.handlers import BaseRouteHandler from starlite.openapi.schema import create_schema -def create_path_parameter_schema(path_parameter: str, field: ModelField, generate_examples: bool) -> Schema: - """Create a path parameter from the given path_param string in the format param_name:type""" - param_type_map: Dict[str, Type[Any]] = { - "str": str, - "float": float, - "int": int, - "uuid": UUID, - } - parameter_type = path_parameter.split(":")[1] - if parameter_type not in param_type_map: - raise TypeError(f"Unsupported path param type {parameter_type}") +def create_path_parameter_schema(path_parameter: Dict[str, Any], field: ModelField, generate_examples: bool) -> Schema: + """Create a path parameter from the given path_param definition""" field.sub_fields = None - field.outer_type_ = param_type_map[parameter_type] + field.outer_type_ = path_parameter["type"] return create_schema(field=field, generate_examples=generate_examples) @@ -34,7 +23,7 @@ def create_parameters( """ Create a list of path/query/header Parameter models for the given PathHandler """ - path_parameter_names = [path_param.split(":")[0] for path_param in path_parameters] + path_parameter_names = [path_param["name"] for path_param in path_parameters] parameters: List[Parameter] = [] ignored_fields = [ "data", @@ -55,7 +44,7 @@ def create_parameters( param_in = "path" required = True schema = create_path_parameter_schema( - path_parameter=[p for p in path_parameters if f_name in p][0], + path_parameter=[p for p in path_parameters if f_name in p["name"]][0], field=field, generate_examples=generate_examples, ) diff --git a/starlite/routing.py b/starlite/routing.py index e2b490c385..7677b2f9a4 100644 --- a/starlite/routing.py +++ b/starlite/routing.py @@ -12,7 +12,11 @@ from starlite.controller import Controller from starlite.enums import HttpMethod, ScopeType -from starlite.exceptions import ImproperlyConfiguredException, MethodNotAllowedException, ValidationException +from starlite.exceptions import ( + ImproperlyConfiguredException, + MethodNotAllowedException, + ValidationException, +) from starlite.handlers import BaseRouteHandler, HTTPRouteHandler, WebsocketRouteHandler from starlite.provide import Provide from starlite.request import Request, WebSocket @@ -29,12 +33,7 @@ param_match_regex = re.compile(r"{(.*?)}") -param_type_map = { - "str": str, - "int": int, - "float": float, - "uuid": UUID -} +param_type_map = {"str": str, "int": int, "float": float, "uuid": UUID} class BaseRoute(ABC): @@ -44,6 +43,7 @@ class BaseRoute(ABC): "methods", "param_convertors", "path", + "path_format", "path_parameters", "scope_type", ) @@ -57,24 +57,26 @@ def __init__( scope_type: ScopeType, methods: Optional[List[Method]] = None, ): - self.path, self.path_parameters = self.parse_path(path) + self.path, self.path_format, self.path_parameters = self.parse_path(path) self.handler_names = handler_names self.scope_type = scope_type self.methods = methods or [] if "GET" in self.methods: self.methods.append("HEAD") - def parse_path(self, path: str) -> Tuple[str, List[Dict[str, Any]]]: + def parse_path(self, path: str) -> Tuple[str, str, List[Dict[str, Any]]]: path = normalize_path(path) + path_format = path path_parameters = [] for param in param_match_regex.findall(path): if ":" not in param: raise ImproperlyConfiguredException("path parameter must declare a type: '{parameter_name:type}'") param_name, param_type = param.split(":") + path_format = path_format.replace(param, param_name) path_parameters.append({"name": param_name, "type": param_type_map[param_type], "full": param}) - return path, path_parameters + return path, path_format, path_parameters @property def is_http_route(self) -> bool: @@ -82,16 +84,18 @@ def is_http_route(self) -> bool: return self.scope_type == "http" def parse_path_params(self, raw_params: List[str]) -> Dict[str, Any]: - try: - parsed_params: Dict[str, Any] = {} - for index, param_definition in enumerate(self.path_parameters): - raw_param = raw_params[index] - param_name = cast(str, param_definition["name"]) + parsed_params: Dict[str, Any] = {} + for index, param_definition in enumerate(self.path_parameters): + raw_param = raw_params[index] + param_name = cast(str, param_definition["name"]) + try: param_type = cast(Type, param_definition["type"]) parsed_params[param_name] = param_type(raw_param) - return parsed_params - except TypeError as e: - raise ValidationException from e + except (ValueError, TypeError) as e: + print(str(raw_params)) + print(str(param_definition)) + raise ValidationException(f"unable to parse path parameter {str(raw_param)}") from e + return parsed_params class HTTPRoute(BaseRoute): diff --git a/tests/openapi/test_parameters.py b/tests/openapi/test_parameters.py index 691064f5ba..660401a102 100644 --- a/tests/openapi/test_parameters.py +++ b/tests/openapi/test_parameters.py @@ -1,19 +1,12 @@ from typing import Callable, cast -import pytest - from starlite import Starlite from starlite.openapi.enums import OpenAPIType -from starlite.openapi.parameters import create_parameters, create_path_parameter_schema +from starlite.openapi.parameters import create_parameters from starlite.utils import create_function_signature_model, find_index from tests.openapi.utils import PersonController -def test_create_path_parameters_schema_raise_for_invalid_type(): - with pytest.raises(TypeError): - create_path_parameter_schema(path_parameter="string_int_id:strint", field=None, generate_examples=False) - - def test_create_parameters(): app = Starlite(route_handlers=[PersonController]) index = find_index(app.routes, lambda x: x.path_format == "/{service_id}/person") diff --git a/tests/test_path_resolution.py b/tests/test_path_resolution.py index 7d23c0dd61..c05f9f341a 100644 --- a/tests/test_path_resolution.py +++ b/tests/test_path_resolution.py @@ -18,7 +18,7 @@ def root_delete_handler() -> None: [f"/path/1/2/sub/{str(uuid4())}", "/path/{first:int}/{second:str}/sub/{third:uuid}"], [f"/path/1/2/sub/{str(uuid4())}/", "/path/{first:int}/{second:str}/sub/{third:uuid}/"], ["/", "/"], - ["", ""] + ["", ""], ], ) def test_path_parsing_and_matching(request_path, router_path): @@ -30,6 +30,29 @@ def test_method() -> None: response = client.get(request_path) assert response.status_code == HTTP_200_OK + +def test_path_parsing_with_ambigous_paths(): + @get(path="/{path_param:int}", media_type=MediaType.TEXT) + def path_param(path_param: int) -> str: + return str(path_param) + + @get(path="/query_param", media_type=MediaType.TEXT) + def query_param(value: int) -> str: + return str(value) + + @get(path="/mixed/{path_param:int}", media_type=MediaType.TEXT) + def mixed_params(path_param: int, value: int) -> str: + return str(path_param + value) + + with create_test_client([path_param, query_param, mixed_params]) as client: + response = client.get("/1") + assert response.status_code == HTTP_200_OK + response = client.get("/query_param?value=1") + assert response.status_code == HTTP_200_OK + response = client.get("/mixed/1/?value=1") + assert response.status_code == HTTP_200_OK + + @pytest.mark.parametrize( "decorator, test_path, decorator_path, delete_handler", [ diff --git a/tests/test_route.py b/tests/test_route.py index 691bc00916..62a8dcee97 100644 --- a/tests/test_route.py +++ b/tests/test_route.py @@ -1,8 +1,7 @@ import pytest -from starlette.routing import Match from starlite import get, post -from starlite.exceptions import ImproperlyConfiguredException, MethodNotAllowedException +from starlite.exceptions import MethodNotAllowedException from starlite.routing import HTTPRoute @@ -22,14 +21,3 @@ async def test_http_route_raises_for_unsupported_method(): with pytest.raises(MethodNotAllowedException): await route.handle(scope={"method": "DELETE"}, receive=lambda x: x, send=lambda x: x) - - -def test_match_partial(): - route = HTTPRoute(path="/", route_handlers=[my_get_handler, my_post_handler]) - match, _ = route.matches(scope={"path": "/", "method": "DELETE", "type": "http"}) - assert match == Match.PARTIAL - - -def test_http_route_raises_for_no_leading_slash(): - with pytest.raises(ImproperlyConfiguredException): - assert HTTPRoute(path="first", route_handlers=[]) From 65db8799bcf6c342d7668dee0a6188291d166e79 Mon Sep 17 00:00:00 2001 From: Na'aman Hirschfeld Date: Sat, 22 Jan 2022 23:47:58 +0100 Subject: [PATCH 3/7] added ASGI handler --- starlite/__init__.py | 28 +++++---- starlite/app.py | 40 +++++++++---- starlite/asgi.py | 32 +++++----- starlite/datastructures.py | 8 ++- starlite/enums.py | 1 + starlite/handlers.py | 67 +++++++++++++++++---- starlite/openapi/parameters.py | 2 +- starlite/plugins/base.py | 15 ++--- starlite/request.py | 88 +++++++++++++++++----------- starlite/routing.py | 75 ++++++++++++++++++------ tests/handlers/test_asgi_handlers.py | 70 ++++++++++++++++++++++ tests/test_path_resolution.py | 29 ++++++++- 12 files changed, 339 insertions(+), 116 deletions(-) create mode 100644 tests/handlers/test_asgi_handlers.py diff --git a/starlite/__init__.py b/starlite/__init__.py index d6abe9789f..63288930a9 100644 --- a/starlite/__init__.py +++ b/starlite/__init__.py @@ -24,9 +24,11 @@ StarLiteException, ) from .handlers import ( + ASGIRouteHandler, BaseRouteHandler, HTTPRouteHandler, WebsocketRouteHandler, + asgi, delete, get, patch, @@ -49,18 +51,24 @@ __all__ = [ "AbstractAuthenticationMiddleware", + "asgi", + "ASGIRouteHandler", "AuthenticationResult", "BaseRoute", "BaseRouteHandler", "Body", - "CORSConfig", "Controller", + "CORSConfig", + "create_test_client", + "create_test_request", + "delete", "DTOFactory", "File", + "get", "HTTPException", + "HttpMethod", "HTTPRoute", "HTTPRouteHandler", - "HttpMethod", "ImproperlyConfiguredException", "InternalServerException", "LoggingConfig", @@ -74,33 +82,29 @@ "OpenAPIMediaType", "Parameter", "Partial", + "patch", "PermissionDeniedException", "PluginProtocol", + "post", "Provide", + "put", "Redirect", "Request", "RequestEncodingType", "Response", "ResponseHeader", + "route", "Router", "ScopeType", "ServiceUnavailableException", - "StarLiteException", "Starlite", + "StarLiteException", "State", "StaticFilesConfig", "Stream", "TestClient", "WebSocket", + "websocket", "WebSocketRoute", "WebsocketRouteHandler", - "create_test_client", - "create_test_request", - "delete", - "get", - "patch", - "post", - "put", - "route", - "websocket", ] diff --git a/starlite/app.py b/starlite/app.py index 074749858c..7ea2465795 100644 --- a/starlite/app.py +++ b/starlite/app.py @@ -11,7 +11,6 @@ from starlette.middleware.cors import CORSMiddleware from starlette.middleware.errors import ServerErrorMiddleware from starlette.middleware.trustedhost import TrustedHostMiddleware -from starlette.routing import Mount from starlette.staticfiles import StaticFiles from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR from starlette.types import ASGIApp, Receive, Scope, Send @@ -22,13 +21,13 @@ from starlite.datastructures import State from starlite.enums import MediaType from starlite.exceptions import HTTPException -from starlite.handlers import BaseRouteHandler, HTTPRouteHandler +from starlite.handlers import BaseRouteHandler, HTTPRouteHandler, asgi from starlite.openapi.path_item import create_path_item from starlite.plugins.base import PluginProtocol from starlite.provide import Provide from starlite.request import Request from starlite.response import Response -from starlite.routing import BaseRoute, HTTPRoute, Router +from starlite.routing import BaseRoute, HTTPRoute, Router, WebSocketRoute from starlite.types import ( AfterRequestHandler, BeforeRequestHandler, @@ -53,7 +52,8 @@ class Starlite(Router): "openapi_schema", "plugins", "state", - "route_map" + "route_map", + "static_paths" # the rest of __slots__ are defined in Router and should not be duplicated # see: https://stackoverflow.com/questions/472000/usage-of-slots ) @@ -86,8 +86,9 @@ def __init__( # pylint: disable=too-many-locals self.debug = debug self.state = State() self.plugins = plugins or [] - self.routes: List[Union[BaseRoute, Mount]] = [] # type: ignore + self.routes: List[BaseRoute] = [] self.route_map: Dict[str, Any] = {} + self.static_paths = set() super().__init__( dependencies=dependencies, guards=guards, @@ -116,22 +117,29 @@ def __init__( # pylint: disable=too-many-locals if static_files_config: for config in static_files_config if isinstance(static_files_config, list) else [static_files_config]: path = normalize_path(config.path) + self.static_paths.add(path) static_files = StaticFiles(html=config.html_mode, check_dir=False) static_files.all_directories = config.directories # type: ignore - self.routes.append(Mount(path, static_files)) - self.construct_route_map() + self.register(asgi(path=path)(static_files)) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: scope["app"] = self - await self.middleware_stack(scope, receive, send) + if scope["type"] != "lifespan": + await self.middleware_stack(scope, receive, send) + else: + await self.asgi_router.lifespan(scope, receive, send) + + def construct_route_map(self) -> None: + """ + Create a map of the app's routes. This map is used in the asgi router to route requests. - def construct_route_map(self): + """ for route in self.routes: path = route.path for param_definition in route.path_parameters: path = path.replace(param_definition["full"], "") path = path.replace("{}", "*") - components = path.split("/") if path not in ["/", None, ""] else ["_root"] + components = path.split("/") if path not in ["", "/", None] else ["_root"] cur = self.route_map for component in components: if "_components" not in cur: @@ -139,15 +147,23 @@ def construct_route_map(self): components_set = cast(Set[str], cur["_components"]) components_set.add(component) if component not in cur: - cur[component] = {} + cur[component] = {"_components": set()} cur = cast(Dict[str, Any], cur[component]) if "_handlers" not in cur: cur["_handlers"] = {} + if "_handler_types" not in cur: + cur["_handler_types"] = set() + if path in self.static_paths: + cur["static_path"] = path + handler_type = cast(Set[str], cur["_handler_types"]) + handler_type.add(route.scope_type.value) handlers = cast(Dict[str, BaseRoute], cur["_handlers"]) if isinstance(route, HTTPRoute): handlers["http"] = route - else: + elif isinstance(route, WebSocketRoute): handlers["websocket"] = route + else: + handlers["asgi"] = route def register(self, value: ControllerRouterHandler) -> None: # type: ignore[override] """ diff --git a/starlite/asgi.py b/starlite/asgi.py index 0516f49e91..1b85844e5c 100644 --- a/starlite/asgi.py +++ b/starlite/asgi.py @@ -1,12 +1,13 @@ from inspect import getfullargspec, isawaitable -from typing import TYPE_CHECKING, Any, List, cast, Set, Dict, Union +from typing import TYPE_CHECKING, Any, Dict, List, Set, Union, cast -from starlette.routing import Router as StarletteRouter, WebSocketRoute -from starlette.types import Scope, Receive, Send +from starlette.routing import Router as StarletteRouter +from starlette.routing import WebSocketRoute +from starlette.types import Receive, Scope, Send from starlite.exceptions import NotFoundException +from starlite.routing import ASGIRoute, HTTPRoute from starlite.types import LifeCycleHandler -from starlite.routing import HTTPRoute if TYPE_CHECKING: # pragma: no cover from starlite.app import Starlite @@ -31,12 +32,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: """ The main entry point to the Router class. """ - scope_type = scope["type"] - - if scope_type == "lifespan": - await self.lifespan(scope, receive, send) - return path_params: List[str] = [] path = cast(str, scope["path"]).strip() if path != "/" and path.endswith("/"): @@ -50,15 +46,23 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: elif "*" in components_set: path_params.append(component) cur = cast(Dict[str, Any], cur["*"]) + elif cur.get("static_path"): # noqa: SIM106 + static_path = cast(str, cur["static_path"]) + scope["path"] = scope["path"].replace(static_path, "") + scope_type = "asgi" else: raise NotFoundException() - handlers = cast(Dict[str, Any], cur["_handlers"]) try: - route = cast(Union[WebSocketRoute, HTTPRoute], handlers[scope_type]) - scope["path_params"] = route.parse_path_params(path_params) + handlers = cast(Dict[str, Any], cur["_handlers"]) + handler_types = cast(Set[str], cur["_handler_types"]) + route = cast( + Union[WebSocketRoute, ASGIRoute, HTTPRoute], + handlers[scope_type if scope_type in handler_types else "asgi"], + ) + scope["path_params"] = route.parse_path_params(path_params) # type: ignore await route.handle(scope=scope, receive=receive, send=send) - except KeyError: - raise NotFoundException() + except KeyError as e: + raise NotFoundException() from e async def call_lifecycle_handler(self, handler: LifeCycleHandler) -> None: """ diff --git a/starlite/datastructures.py b/starlite/datastructures.py index 0567acd87d..faaa140981 100644 --- a/starlite/datastructures.py +++ b/starlite/datastructures.py @@ -19,10 +19,12 @@ def copy(self) -> "State": return copy(self) -class File(BaseModel): +class StarliteType(BaseModel): class Config: arbitrary_types_allowed = True + +class File(StarliteType): path: FilePath filename: str stat_result: Optional[os.stat_result] = None @@ -35,11 +37,11 @@ def validate_status_code( # pylint: disable=no-self-argument, no-self-use return value or os.stat(cast(str, values.get("path"))) -class Redirect(BaseModel): +class Redirect(StarliteType): path: str -class Stream(BaseModel): +class Stream(StarliteType): class Config: arbitrary_types_allowed = True diff --git a/starlite/enums.py b/starlite/enums.py index 54c31a881a..55a7638faf 100644 --- a/starlite/enums.py +++ b/starlite/enums.py @@ -45,3 +45,4 @@ class RequestEncodingType(str, Enum): class ScopeType(str, Enum): HTTP = "http" WEBSOCKET = "websocket" + ASGI = "asgi" diff --git a/starlite/handlers.py b/starlite/handlers.py index 5abe2b914b..5c9237d2a5 100644 --- a/starlite/handlers.py +++ b/starlite/handlers.py @@ -23,10 +23,11 @@ from starlette.responses import Response as StarletteResponse from starlette.responses import StreamingResponse from starlette.status import HTTP_200_OK, HTTP_201_CREATED, HTTP_204_NO_CONTENT +from starlette.types import Receive, Scope, Send from starlite.constants import REDIRECT_STATUS_CODES from starlite.controller import Controller -from starlite.datastructures import File, Redirect, Stream +from starlite.datastructures import File, Redirect, StarliteType, Stream from starlite.enums import HttpMethod, MediaType from starlite.exceptions import ( HTTPException, @@ -392,7 +393,7 @@ async def handle_request(self, request: Request) -> StarletteResponse: """ Handles a given Request in relation to self. """ - if not self.fn: + if not self.fn or not self.signature_model: raise ImproperlyConfiguredException("cannot call 'handle' without a decorated function") if self.resolve_guards(): @@ -428,16 +429,18 @@ async def to_response(self, request: Request, data: Any) -> StarletteResponse: after_request = self.resolve_after_request() media_type = self.media_type.value if isinstance(self.media_type, Enum) else self.media_type headers = {k: v.value for k, v in self.resolve_response_headers().items()} - if isinstance(data, StarletteResponse): - response = data - elif isinstance(data, Redirect): - response = RedirectResponse(headers=headers, status_code=self.status_code, url=data.path) - elif isinstance(data, File): - response = FileResponse(media_type=media_type, headers=headers, **data.dict()) - elif isinstance(data, Stream): - response = StreamingResponse( - content=data.iterator, status_code=self.status_code, media_type=media_type, headers=headers - ) + response: StarletteResponse + if isinstance(data, (StarletteResponse, StarliteType)): + if isinstance(data, Redirect): + response = RedirectResponse(headers=headers, status_code=self.status_code, url=data.path) + elif isinstance(data, File): + response = FileResponse(media_type=media_type, headers=headers, **data.dict()) + elif isinstance(data, Stream): + response = StreamingResponse( + content=data.iterator, status_code=self.status_code, media_type=media_type, headers=headers + ) + else: + response = cast(StarletteResponse, data) else: plugin = get_plugin_for_value(data, request.app.plugins) if plugin: @@ -749,3 +752,43 @@ async def handle_websocket(self, web_socket: WebSocket) -> None: websocket = WebsocketRouteHandler + + +class ASGIRouteHandler(BaseRouteHandler): + def __call__(self, fn: AnyCallable) -> "ASGIRouteHandler": + """ + Replaces a function with itself + """ + self.fn = fn + self.validate_handler_function() + return self + + def validate_handler_function(self) -> None: + """ + Validates the route handler function once it's set by inspecting its return annotations + """ + super().validate_handler_function() + signature = Signature.from_callable(cast(AnyCallable, self.fn)) + + if signature.return_annotation is not None: + raise ImproperlyConfiguredException("ASGI handler functions should return 'None'") + if any(key not in signature.parameters for key in ["scope", "send", "receive"]): + raise ImproperlyConfiguredException( + "ASGI handler functions should define 'scope', 'send' and 'receive' arguments" + ) + + async def handle_asgi(self, scope: Scope, send: Send, receive: Receive) -> None: + """ + Handles a given Websocket in relation to self. + """ + if not self.fn: # pragma: no cover + raise ImproperlyConfiguredException("cannot call a route handler without a decorated function") + connection = HTTPConnection(scope=scope, receive=receive) + await self.authorize_connection(connection=connection) + if isinstance(self.owner, Controller): + await self.fn(self.owner, scope=scope, receive=receive, send=send) + else: + await self.fn(scope=scope, receive=receive, send=send) + + +asgi = ASGIRouteHandler diff --git a/starlite/openapi/parameters.py b/starlite/openapi/parameters.py index 12116d3e74..119810d04b 100644 --- a/starlite/openapi/parameters.py +++ b/starlite/openapi/parameters.py @@ -17,7 +17,7 @@ def create_path_parameter_schema(path_parameter: Dict[str, Any], field: ModelFie def create_parameters( route_handler: BaseRouteHandler, handler_fields: Dict[str, ModelField], - path_parameters: List[str], + path_parameters: List[Dict[str, Any]], generate_examples: bool, ) -> List[Parameter]: """ diff --git a/starlite/plugins/base.py b/starlite/plugins/base.py index 148b789d0a..c63bb323c1 100644 --- a/starlite/plugins/base.py +++ b/starlite/plugins/base.py @@ -45,13 +45,14 @@ def from_dict(self, model_class: Type[T], **kwargs: Any) -> T: def get_plugin_for_value(value: Any, plugins: List[PluginProtocol]) -> Optional[PluginProtocol]: """Helper function to returns a plugin to handle a given value, if any plugin supports it""" - if value and isinstance(value, (list, tuple)): - value = value[0] - if get_args(value): - value = get_args(value)[0] - for plugin in plugins: - if plugin.is_plugin_supported_type(value): - return plugin + if plugins: + if value and isinstance(value, (list, tuple)): + value = value[0] + if get_args(value): + value = get_args(value)[0] + for plugin in plugins: + if plugin.is_plugin_supported_type(value): + return plugin return None diff --git a/starlite/request.py b/starlite/request.py index c2aa913ae4..ea826d5fe3 100644 --- a/starlite/request.py +++ b/starlite/request.py @@ -1,5 +1,7 @@ from contextlib import suppress -from typing import TYPE_CHECKING, Any, Dict, Generic, List, TypeVar, Union, cast +from functools import reduce +from typing import TYPE_CHECKING, Any, Dict, Generic, List, Tuple, TypeVar, Union, cast +from urllib.parse import parse_qsl from orjson import JSONDecodeError, loads from pydantic.fields import SHAPE_LIST, SHAPE_SINGLETON, ModelField, Undefined @@ -65,30 +67,44 @@ def auth(self) -> Auth: return cast(Auth, self.scope["auth"]) +_true_values = {"True", "true"} +_false_values = {"False", "false"} + + +def _query_param_reducer( + acc: Dict[str, Union[str, List[str]]], cur: Tuple[str, str] +) -> Dict[str, Union[str, List[str]]]: + key, value = cur + if value in _true_values: + value = True # type: ignore + elif value in _false_values: + value = False # type: ignore + param = acc.get(key) + if param: + if isinstance(param, str): + acc[key] = [param, value] + else: + acc[key] = [*cast(List[Any], param), value] + else: + acc[key] = value + return acc + + def parse_query_params(connection: HTTPConnection) -> Dict[str, Any]: """ Parses and normalize a given connection's query parameters into a regular dictionary Extends the Starlette query params handling by supporting lists """ - params: Dict[str, Union[str, List[str]]] = {} try: - for key, value in connection.query_params.multi_items(): - if value in ["True", "true"]: - value = True # type: ignore - elif value in ["False", "false"]: - value = False # type: ignore - param = params.get(key) - if param: - if isinstance(param, str): - params[key] = [param, value] - else: - params[key] = [*cast(List[Any], param), value] - else: - params[key] = value - return params + qs = cast(Union[str, bytes], connection.scope["query_string"]) + return reduce( + _query_param_reducer, + parse_qsl(qs if isinstance(qs, str) else qs.decode("latin-1"), keep_blank_values=True), + {}, + ) except KeyError: - return params + return {} def handle_multipart(media_type: RequestEncodingType, form_data: FormData, field: ModelField) -> Any: @@ -144,25 +160,27 @@ def get_connection_parameters( return query_params[field_name] extra = field.field_info.extra - parameter_name = None - source = None + extra_keys = set(extra.keys()) default = field.default if field.default is not Undefined else None - if extra.get("query"): - parameter_name = extra["query"] - source = query_params - if extra.get("header"): - parameter_name = extra["header"] - source = header_params - if extra.get("cookie"): - parameter_name = extra["cookie"] - source = connection.cookies - if parameter_name and source: - parameter_is_required = extra["required"] - try: - return source[parameter_name] - except KeyError as e: - if parameter_is_required and not default: - raise ValidationException(f"Missing required parameter {parameter_name}") from e + if extra_keys: + parameter_name = None + source = None + if "query" in extra_keys and extra["query"]: + parameter_name = extra["query"] + source = query_params + elif "header" in extra_keys and extra["header"]: + parameter_name = extra["header"] + source = header_params + elif "cookie" in extra_keys and extra["cookie"]: + parameter_name = extra["cookie"] + source = connection.cookies + if parameter_name and source: + parameter_is_required = extra["required"] + try: + return source[parameter_name] + except KeyError as e: + if parameter_is_required and not default: + raise ValidationException(f"Missing required parameter {parameter_name}") from e return default diff --git a/starlite/routing.py b/starlite/routing.py index 7677b2f9a4..aeb16f288f 100644 --- a/starlite/routing.py +++ b/starlite/routing.py @@ -17,7 +17,12 @@ MethodNotAllowedException, ValidationException, ) -from starlite.handlers import BaseRouteHandler, HTTPRouteHandler, WebsocketRouteHandler +from starlite.handlers import ( + ASGIRouteHandler, + BaseRouteHandler, + HTTPRouteHandler, + WebsocketRouteHandler, +) from starlite.provide import Provide from starlite.request import Request, WebSocket from starlite.response import Response @@ -64,7 +69,11 @@ def __init__( if "GET" in self.methods: self.methods.append("HEAD") - def parse_path(self, path: str) -> Tuple[str, str, List[Dict[str, Any]]]: + @staticmethod + def parse_path(path: str) -> Tuple[str, str, List[Dict[str, Any]]]: + """ + Normalizes and parses a path + """ path = normalize_path(path) path_format = path path_parameters = [] @@ -75,15 +84,12 @@ def parse_path(self, path: str) -> Tuple[str, str, List[Dict[str, Any]]]: param_name, param_type = param.split(":") path_format = path_format.replace(param, param_name) path_parameters.append({"name": param_name, "type": param_type_map[param_type], "full": param}) - return path, path_format, path_parameters - @property - def is_http_route(self) -> bool: - """Determines whether the given route is an http or websocket route""" - return self.scope_type == "http" - def parse_path_params(self, raw_params: List[str]) -> Dict[str, Any]: + """ + Parses raw path parameters by mapping them into a dictionary + """ parsed_params: Dict[str, Any] = {} for index, param_definition in enumerate(self.path_parameters): raw_param = raw_params[index] @@ -91,9 +97,7 @@ def parse_path_params(self, raw_params: List[str]) -> Dict[str, Any]: try: param_type = cast(Type, param_definition["type"]) parsed_params[param_name] = param_type(raw_param) - except (ValueError, TypeError) as e: - print(str(raw_params)) - print(str(param_definition)) + except (ValueError, TypeError) as e: # pragma: no cover raise ValidationException(f"unable to parse path parameter {str(raw_param)}") from e return parsed_params @@ -178,6 +182,34 @@ async def handle(self, scope: Scope, receive: Receive, send: Send) -> None: await self.route_handler.handle_websocket(web_socket=web_socket) +class ASGIRoute(BaseRoute): + __slots__ = ( + "route_handler", + # the rest of __slots__ are defined in BaseRoute and should not be duplicated + # see: https://stackoverflow.com/questions/472000/usage-of-slots + ) + + @validate_arguments(config={"arbitrary_types_allowed": True}) + def __init__( + self, + *, + path: str, + route_handler: ASGIRouteHandler, + ): + self.route_handler = route_handler + super().__init__( + path=path, + scope_type=ScopeType.ASGI, + handler_names=[get_name(cast(AnyCallable, route_handler.fn))], + ) + + async def handle(self, scope: Scope, receive: Receive, send: Send) -> None: + """ + ASGI app that creates a WebSocket from the passed in args, and then awaits the handler function + """ + await self.route_handler.handle_asgi(scope=scope, receive=receive, send=send) + + class Router: __slots__ = ( "after_request", @@ -236,16 +268,16 @@ def route_handler_method_map(self) -> Dict[str, Union[WebsocketRouteHandler, Dic @staticmethod def map_route_handlers( value: Union[Controller, BaseRouteHandler, "Router"], - ) -> ItemsView[str, Union[WebsocketRouteHandler, Dict[HttpMethod, HTTPRouteHandler]]]: + ) -> ItemsView[str, Union[WebsocketRouteHandler, ASGIRoute, Dict[HttpMethod, HTTPRouteHandler]]]: """ Maps route handlers to http methods """ - handlers_map: Dict[str, Union[WebsocketRouteHandler, Dict[HttpMethod, HTTPRouteHandler]]] = {} + handlers_map: Dict[str, Any] = {} if isinstance(value, BaseRouteHandler): for path in value.paths: if isinstance(value, HTTPRouteHandler): handlers_map[path] = {http_method: value for http_method in value.http_methods} - elif isinstance(value, WebsocketRouteHandler): + elif isinstance(value, (WebsocketRouteHandler, ASGIRouteHandler)): handlers_map[path] = value elif isinstance(value, Router): handlers_map = value.route_handler_method_map @@ -258,9 +290,9 @@ def map_route_handlers( if not isinstance(handlers_map.get(path), dict): handlers_map[path] = {} for http_method in route_handler.http_methods: - handlers_map[path][http_method] = route_handler # type: ignore + handlers_map[path][http_method] = route_handler else: - handlers_map[path] = cast(WebsocketRouteHandler, route_handler) + handlers_map[path] = cast(Union[WebsocketRouteHandler, ASGIRouteHandler], route_handler) return handlers_map.items() def validate_registration_value( @@ -285,7 +317,9 @@ def validate_registration_value( value.owner = self return cast(Union[Controller, BaseRouteHandler, "Router"], value) - def register(self, value: ControllerRouterHandler) -> List[Union[HTTPRouteHandler, WebsocketRouteHandler]]: + def register( + self, value: ControllerRouterHandler + ) -> List[Union[HTTPRouteHandler, WebsocketRouteHandler, ASGIRouteHandler]]: """ Register a Controller, Route instance or RouteHandler on the router @@ -293,14 +327,17 @@ def register(self, value: ControllerRouterHandler) -> List[Union[HTTPRouteHandle by any of the routing decorators (e.g. route, get, post...) exported from 'starlite.routing' """ validated_value = self.validate_registration_value(value) - handlers: List[Union[HTTPRouteHandler, WebsocketRouteHandler]] = [] + handlers: List[Union[HTTPRouteHandler, WebsocketRouteHandler, ASGIRouteHandler]] = [] for route_path, handler_or_method_map in self.map_route_handlers(value=validated_value): path = join_paths([self.path, route_path]) if isinstance(handler_or_method_map, WebsocketRouteHandler): handlers.append(handler_or_method_map) self.routes.append(WebSocketRoute(path=path, route_handler=handler_or_method_map)) + elif isinstance(handler_or_method_map, ASGIRouteHandler): + handlers.append(handler_or_method_map) + self.routes.append(ASGIRoute(path=path, route_handler=handler_or_method_map)) else: - route_handlers = list(handler_or_method_map.values()) + route_handlers = list(cast(Dict[HttpMethod, HTTPRouteHandler], handler_or_method_map).values()) handlers.extend(route_handlers) if self.route_handler_method_map.get(path): existing_route_index = find_index( diff --git a/tests/handlers/test_asgi_handlers.py b/tests/handlers/test_asgi_handlers.py new file mode 100644 index 0000000000..40c5db4d14 --- /dev/null +++ b/tests/handlers/test_asgi_handlers.py @@ -0,0 +1,70 @@ +import pytest +from starlette.status import HTTP_200_OK +from starlette.types import Receive, Scope, Send + +from starlite import ( + Controller, + ImproperlyConfiguredException, + MediaType, + Response, + asgi, + create_test_client, +) + + +def test_asgi_handler_validation(): + def fn_without_scope_arg(receive: Receive, send: Send) -> None: + pass + + with pytest.raises(ImproperlyConfiguredException): + asgi(path="/")(fn_without_scope_arg) + + def fn_without_receive_arg(scope: Scope, send: Send) -> None: + pass + + with pytest.raises(ImproperlyConfiguredException): + asgi(path="/")(fn_without_receive_arg) + + def fn_without_send_arg(scope: Scope, receive: Receive) -> None: + pass + + with pytest.raises(ImproperlyConfiguredException): + asgi(path="/")(fn_without_send_arg) + + def fn_with_return_annotation(scope: Scope, receive: Receive, send: Send) -> dict: + return dict() + + with pytest.raises(ImproperlyConfiguredException): + asgi(path="/")(fn_with_return_annotation) + + asgi_handler_with_no_fn = asgi(path="/") + + with pytest.raises(ImproperlyConfiguredException): + create_test_client(route_handlers=asgi_handler_with_no_fn) + + +def test_handle_asgi(): + @asgi(path="/") + async def root_asgi_handler(scope: Scope, receive: Receive, send: Send) -> None: + assert scope["type"] == "http" + assert scope["method"] == "GET" + response = Response("Hello World", media_type=MediaType.TEXT, status_code=HTTP_200_OK) + await response(scope, receive, send) + + class MyController(Controller): + path = "/asgi" + + @asgi() + async def root_asgi_handler(self, scope: Scope, receive: Receive, send: Send) -> None: + assert scope["type"] == "http" + assert scope["method"] == "GET" + response = Response("Hello World", media_type=MediaType.TEXT, status_code=HTTP_200_OK) + await response(scope, receive, send) + + with create_test_client([root_asgi_handler, MyController]) as client: + response = client.get("/") + assert response.status_code == HTTP_200_OK + assert response.text == "Hello World" + response = client.get("/asgi") + assert response.status_code == HTTP_200_OK + assert response.text == "Hello World" diff --git a/tests/test_path_resolution.py b/tests/test_path_resolution.py index c05f9f341a..234d7e7b9f 100644 --- a/tests/test_path_resolution.py +++ b/tests/test_path_resolution.py @@ -1,7 +1,12 @@ from uuid import uuid4 import pytest -from starlette.status import HTTP_200_OK, HTTP_204_NO_CONTENT +from starlette.status import ( + HTTP_200_OK, + HTTP_204_NO_CONTENT, + HTTP_400_BAD_REQUEST, + HTTP_404_NOT_FOUND, +) from starlite import Controller, MediaType, create_test_client, delete, get from tests import Person, PersonFactory @@ -108,3 +113,25 @@ def handler_fn(some_id: int = 1) -> str: fourth_response = client.get("/something/2") assert fourth_response.status_code == HTTP_200_OK assert fourth_response.text == "2" + + +@pytest.mark.parametrize( + "handler, handler_path, request_path, expected_status_code", + [ + (get, "/sub-path", "/", HTTP_404_NOT_FOUND), + (get, "/sub/path", "/sub-path", HTTP_404_NOT_FOUND), + (get, "/sub/path", "/sub", HTTP_404_NOT_FOUND), + (get, "/sub/path/{path_param:int}", "/sub/path", HTTP_404_NOT_FOUND), + (get, "/sub/path/{path_param:int}", "/sub/path/abcd", HTTP_400_BAD_REQUEST), + (get, "/sub/path/{path_param:uuid}", "/sub/path/100", HTTP_400_BAD_REQUEST), + (get, "/sub/path/{path_param:float}", "/sub/path/abcd", HTTP_400_BAD_REQUEST), + ], +) +def test_path_validation(handler, handler_path, request_path, expected_status_code): + @get(handler_path) + def handler(**kwargs) -> None: + ... + + with create_test_client(handler) as client: + response = client.get(request_path) + assert response.status_code == expected_status_code From 0374b0a586d55fbbac035e98450d8bbc056fc366 Mon Sep 17 00:00:00 2001 From: Na'aman Hirschfeld Date: Sun, 23 Jan 2022 17:28:29 +0100 Subject: [PATCH 4/7] updated dependencies --- poetry.lock | 72 ++++++++++++++++++++++++++--------------------------- 1 file changed, 36 insertions(+), 36 deletions(-) diff --git a/poetry.lock b/poetry.lock index 4f1728f546..8b8f6b4d36 100644 --- a/poetry.lock +++ b/poetry.lock @@ -18,11 +18,11 @@ trio = ["trio (>=0.16)"] [[package]] name = "asgiref" -version = "3.4.1" +version = "3.5.0" description = "ASGI specs, helper code, and adapters" category = "dev" optional = false -python-versions = ">=3.6" +python-versions = ">=3.7" [package.dependencies] typing-extensions = {version = "*", markers = "python_version < \"3.8\""} @@ -333,7 +333,7 @@ i18n = ["babel (>=2.9.0)"] [[package]] name = "mkdocs-material" -version = "8.1.7" +version = "8.1.8" description = "A Material Design theme for MkDocs" category = "dev" optional = false @@ -394,7 +394,7 @@ pydantic = ">=1.8.2" [[package]] name = "orjson" -version = "3.6.5" +version = "3.6.6" description = "Fast, correct Python JSON library supporting dataclasses, datetimes, and numpy" category = "main" optional = false @@ -713,7 +713,7 @@ typing-extensions = ">=3.7.4" [[package]] name = "starlette" -version = "0.17.1" +version = "0.18.0" description = "The little ASGI library that shines." category = "main" optional = false @@ -721,7 +721,7 @@ python-versions = ">=3.6" [package.dependencies] anyio = ">=3.0.0,<4" -typing-extensions = {version = "*", markers = "python_version < \"3.8\""} +typing-extensions = {version = "*", markers = "python_version < \"3.10\""} [package.extras] full = ["itsdangerous", "jinja2", "python-multipart", "pyyaml", "requests"] @@ -830,8 +830,8 @@ anyio = [ {file = "anyio-3.5.0.tar.gz", hash = "sha256:a0aeffe2fb1fdf374a8e4b471444f0f3ac4fb9f5a5b542b48824475e0042a5a6"}, ] asgiref = [ - {file = "asgiref-3.4.1-py3-none-any.whl", hash = "sha256:ffc141aa908e6f175673e7b1b3b7af4fdb0ecb738fc5c8b88f69f055c2415214"}, - {file = "asgiref-3.4.1.tar.gz", hash = "sha256:4ef1ab46b484e3c706329cedeff284a5d40824200638503f5768edb6de7d58e9"}, + {file = "asgiref-3.5.0-py3-none-any.whl", hash = "sha256:88d59c13d634dcffe0510be048210188edd79aeccb6a6c9028cdad6f31d730a9"}, + {file = "asgiref-3.5.0.tar.gz", hash = "sha256:2f8abc20f7248433085eda803936d98992f1343ddb022065779f37c5da0181d0"}, ] atomicwrites = [ {file = "atomicwrites-1.4.0-py2.py3-none-any.whl", hash = "sha256:6d1784dea7c0c8d4a5172b6c620f40b6e4cbfdf96d783691f2e1302a7b88e197"}, @@ -1090,8 +1090,8 @@ mkdocs = [ {file = "mkdocs-1.2.3.tar.gz", hash = "sha256:89f5a094764381cda656af4298727c9f53dc3e602983087e1fe96ea1df24f4c1"}, ] mkdocs-material = [ - {file = "mkdocs-material-8.1.7.tar.gz", hash = "sha256:16a50e3f08f1e41bdc3115a00045d174e7fd8219c26917d0d0b48b2cc9d5a18f"}, - {file = "mkdocs_material-8.1.7-py2.py3-none-any.whl", hash = "sha256:71bcac6795b22dcf8bab8b9ad3fe462242c4cd05d28398281902425401f23462"}, + {file = "mkdocs-material-8.1.8.tar.gz", hash = "sha256:7698b59e09640fb0ae47f4ec426f327f7cbe729386b1cdbbb65ac77bc2a95b45"}, + {file = "mkdocs_material-8.1.8-py2.py3-none-any.whl", hash = "sha256:e8adc408c620e5fa23286e7fbfde3ccd37187fbda907a64594c36dbcd119a9ce"}, ] mkdocs-material-extensions = [ {file = "mkdocs-material-extensions-1.0.3.tar.gz", hash = "sha256:bfd24dfdef7b41c312ede42648f9eb83476ea168ec163b613f9abd12bbfddba2"}, @@ -1128,30 +1128,30 @@ openapi-schema-pydantic = [ {file = "openapi_schema_pydantic-1.2.1-py3-none-any.whl", hash = "sha256:840b60bc2fcbe7d2dc2a63f66bdf71decabe59a422d4b8030cc6cb3213111a8e"}, ] orjson = [ - {file = "orjson-3.6.5-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:6c444edc073eb69cf85b28851a7a957807a41ce9bb3a9c14eefa8b33030cf050"}, - {file = "orjson-3.6.5-cp310-cp310-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:432c6da3d8d4630739f5303dcc45e8029d357b7ff8e70b7239be7bd047df6b19"}, - {file = "orjson-3.6.5-cp310-cp310-manylinux_2_24_aarch64.whl", hash = "sha256:0fa32319072fadf0732d2c1746152f868a1b0f83c8cce2cad4996f5f3ca4e979"}, - {file = "orjson-3.6.5-cp310-cp310-manylinux_2_24_x86_64.whl", hash = "sha256:0d65cc67f2e358712e33bc53810022ef5181c2378a7603249cd0898aa6cd28d4"}, - {file = "orjson-3.6.5-cp310-none-win_amd64.whl", hash = "sha256:fa8e3d0f0466b7d771a8f067bd8961bc17ca6ea4c89a91cd34d6648e6b1d1e47"}, - {file = "orjson-3.6.5-cp37-cp37m-macosx_10_7_x86_64.whl", hash = "sha256:470596fbe300a7350fd7bbcf94d2647156401ab6465decb672a00e201af1813a"}, - {file = "orjson-3.6.5-cp37-cp37m-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:d2680d9edc98171b0c59e52c1ed964619be5cb9661289c0dd2e667773fa87f15"}, - {file = "orjson-3.6.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:001962a334e1ab2162d2f695f2770d2383c7ffd2805cec6dbb63ea2ad96bf0ad"}, - {file = "orjson-3.6.5-cp37-cp37m-manylinux_2_24_aarch64.whl", hash = "sha256:522c088679c69e0dd2c72f43cd26a9e73df4ccf9ed725ac73c151bbe816fe51a"}, - {file = "orjson-3.6.5-cp37-cp37m-manylinux_2_24_x86_64.whl", hash = "sha256:d2b871a745a64f72631b633271577c99da628a9b63e10bd5c9c20706e19fe282"}, - {file = "orjson-3.6.5-cp37-none-win_amd64.whl", hash = "sha256:51ab01fed3b3e21561f21386a2f86a0415338541938883b6ca095001a3014a3e"}, - {file = "orjson-3.6.5-cp38-cp38-macosx_10_7_x86_64.whl", hash = "sha256:fc7e62edbc7ece95779a034d9e206d7ba9e2b638cc548fd3a82dc5225f656625"}, - {file = "orjson-3.6.5-cp38-cp38-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:0720d60db3fa25956011a573274a269eb37de98070f3bc186582af1222a2d084"}, - {file = "orjson-3.6.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e169a8876aed7a5bff413c53257ef1fa1d9b68c855eb05d658c4e73ed8dff508"}, - {file = "orjson-3.6.5-cp38-cp38-manylinux_2_24_aarch64.whl", hash = "sha256:331f9a3bdba30a6913ad1d149df08e4837581e3ce92bf614277d84efccaf796f"}, - {file = "orjson-3.6.5-cp38-cp38-manylinux_2_24_x86_64.whl", hash = "sha256:ece5dfe346b91b442590a41af7afe61df0af369195fed13a1b29b96b1ba82905"}, - {file = "orjson-3.6.5-cp38-none-win_amd64.whl", hash = "sha256:6a5e9eb031b44b7a429c705ca48820371d25b9467c9323b6ae7a712daf15fbef"}, - {file = "orjson-3.6.5-cp39-cp39-macosx_10_7_x86_64.whl", hash = "sha256:206237fa5e45164a678b12acc02aac7c5b50272f7f31116e1e08f8bcaf654f93"}, - {file = "orjson-3.6.5-cp39-cp39-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:d5aceeb226b060d11ccb5a84a4cfd760f8024289e3810ec446ef2993a85dbaca"}, - {file = "orjson-3.6.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80dba3dbc0563c49719e8cc7d1568a5cf738accfcd1aa6ca5e8222b57436e75e"}, - {file = "orjson-3.6.5-cp39-cp39-manylinux_2_24_aarch64.whl", hash = "sha256:443f39bc5e7966880142430ce091e502aea068b38cb9db5f1ffdcfee682bc2d4"}, - {file = "orjson-3.6.5-cp39-cp39-manylinux_2_24_x86_64.whl", hash = "sha256:a06f2dd88323a480ac1b14d5829fb6cdd9b0d72d505fabbfbd394da2e2e07f6f"}, - {file = "orjson-3.6.5-cp39-none-win_amd64.whl", hash = "sha256:82cb42dbd45a3856dbad0a22b54deb5e90b2567cdc2b8ea6708e0c4fe2e12be3"}, - {file = "orjson-3.6.5.tar.gz", hash = "sha256:eb3a7d92d783c89df26951ef3e5aca9d96c9c6f2284c752aa3382c736f950597"}, + {file = "orjson-3.6.6-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:e4a7cad6c63306318453980d302c7c0b74c0cc290dd1f433bbd7d31a5af90cf1"}, + {file = "orjson-3.6.6-cp310-cp310-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:e533941dca4a0530a876de32e54bf2fd3269cdec3751aebde7bfb5b5eba98e74"}, + {file = "orjson-3.6.6-cp310-cp310-manylinux_2_24_aarch64.whl", hash = "sha256:9adf63be386eaa34278967512b83ff8fc4bed036a246391ae236f68d23c47452"}, + {file = "orjson-3.6.6-cp310-cp310-manylinux_2_24_x86_64.whl", hash = "sha256:3b636753ae34d4619b11ea7d664a2f1e87e55e9738e5123e12bcce22acae9d13"}, + {file = "orjson-3.6.6-cp310-none-win_amd64.whl", hash = "sha256:78a10295ed048fd916c6584d6d27c232eae805a43e7c14be56e3745f784f0eb6"}, + {file = "orjson-3.6.6-cp37-cp37m-macosx_10_7_x86_64.whl", hash = "sha256:82b4f9fb2af7799b52932a62eac484083f930d5519560d6f64b24d66a368d03f"}, + {file = "orjson-3.6.6-cp37-cp37m-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:a0033d07309cc7d8b8c4bc5d42f0dd4422b53ceb91dee9f4086bb2afa70b7772"}, + {file = "orjson-3.6.6-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2b321f99473116ab7c7c028377372f7b4adba4029aaca19cd567e83898f55579"}, + {file = "orjson-3.6.6-cp37-cp37m-manylinux_2_24_aarch64.whl", hash = "sha256:b9c98ed94f1688cc11b5c61b8eea39d854a1a2f09f71d8a5af005461b14994ed"}, + {file = "orjson-3.6.6-cp37-cp37m-manylinux_2_24_x86_64.whl", hash = "sha256:00b333a41392bd07a8603c42670547dbedf9b291485d773f90c6470eff435608"}, + {file = "orjson-3.6.6-cp37-none-win_amd64.whl", hash = "sha256:8d4fd3bdee65a81f2b79c50937d4b3c054e1e6bfa3fc72ed018a97c0c7c3d521"}, + {file = "orjson-3.6.6-cp38-cp38-macosx_10_7_x86_64.whl", hash = "sha256:954c9f8547247cd7a8c91094ff39c9fe314b5eaeaec90b7bfb7384a4108f416f"}, + {file = "orjson-3.6.6-cp38-cp38-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:74e5aed657ed0b91ef05d44d6a26d3e3e12ce4d2d71f75df41a477b05878c4a9"}, + {file = "orjson-3.6.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4008a5130e6e9c33abaa95e939e0e755175da10745740aa6968461b2f16830e2"}, + {file = "orjson-3.6.6-cp38-cp38-manylinux_2_24_aarch64.whl", hash = "sha256:012761d5f3d186deb4f6238f15e9ea7c1aac6deebc8f5b741ba3b4fafe017460"}, + {file = "orjson-3.6.6-cp38-cp38-manylinux_2_24_x86_64.whl", hash = "sha256:b464546718a940b48d095a98df4c04808bfa6c8706fe751fc3f9390bc2f82643"}, + {file = "orjson-3.6.6-cp38-none-win_amd64.whl", hash = "sha256:f10a800f4e5a4aab52076d4628e9e4dab9370bdd9d8ea254ebfde846b653ab25"}, + {file = "orjson-3.6.6-cp39-cp39-macosx_10_7_x86_64.whl", hash = "sha256:8010d2610cfab721725ef14d578c7071e946bbdae63322d8f7b49061cf3fde8d"}, + {file = "orjson-3.6.6-cp39-cp39-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:8dca67a4855e1e0f9a2ea0386e8db892708522e1171dc0ddf456932288fbae63"}, + {file = "orjson-3.6.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:af065d60523139b99bd35b839c7a2d8c5da55df8a8c4402d2eb6cdc07fa7a624"}, + {file = "orjson-3.6.6-cp39-cp39-manylinux_2_24_aarch64.whl", hash = "sha256:fa1f389cc9f766ae0cf7ba3533d5089836b01a5ccb3f8d904297f1fcf3d9dc34"}, + {file = "orjson-3.6.6-cp39-cp39-manylinux_2_24_x86_64.whl", hash = "sha256:ec1221ad78f94d27b162a1d35672b62ef86f27f0e4c2b65051edb480cc86b286"}, + {file = "orjson-3.6.6-cp39-none-win_amd64.whl", hash = "sha256:afed2af55eeda1de6b3f1cbc93431981b19d380fcc04f6ed86e74c1913070304"}, + {file = "orjson-3.6.6.tar.gz", hash = "sha256:55dd988400fa7fbe0e31407c683f5aaab013b5bd967167b8fe058186773c4d6c"}, ] packaging = [ {file = "packaging-21.3-py3-none-any.whl", hash = "sha256:ef103e05f519cdc783ae24ea4e2e0f508a9c99b2d4969652eed6a2e1ea5bd522"}, @@ -1347,8 +1347,8 @@ sqlalchemy2-stubs = [ {file = "sqlalchemy2_stubs-0.0.2a19-py3-none-any.whl", hash = "sha256:aac7dca77a2c49e5f0934976421d5e25ae4dc5e27db48c01e055f81caa1e3ead"}, ] starlette = [ - {file = "starlette-0.17.1-py3-none-any.whl", hash = "sha256:26a18cbda5e6b651c964c12c88b36d9898481cd428ed6e063f5f29c418f73050"}, - {file = "starlette-0.17.1.tar.gz", hash = "sha256:57eab3cc975a28af62f6faec94d355a410634940f10b30d68d31cb5ec1b44ae8"}, + {file = "starlette-0.18.0-py3-none-any.whl", hash = "sha256:377d64737a0e03560cb8eaa57604afee143cea5a4996933242798a7820e64f53"}, + {file = "starlette-0.18.0.tar.gz", hash = "sha256:b45c6e9a617ecb5caf7e6446bd8d767b0084d6217e8e1b08187ca5191e10f097"}, ] text-unidecode = [ {file = "text-unidecode-1.3.tar.gz", hash = "sha256:bad6603bb14d279193107714b288be206cac565dfa49aa5b105294dd5c4aab93"}, From c11c246b37b537c03a3d2bc8ca01aaed3299ead2 Mon Sep 17 00:00:00 2001 From: Na'aman Hirschfeld Date: Sun, 23 Jan 2022 17:30:08 +0100 Subject: [PATCH 5/7] 0.7.0 --- CHANGELOG.md | 7 +++++++ pyproject.toml | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 16d7224356..45b20ed203 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -86,3 +86,10 @@ 1. supports generics 2. added `to_model_instance` and `from_model_instance` methods 3. added `field_definitions` kwarg, allowing for creating custom fields + + +[0.7.0] +- optimization: rewrote route resolution +- optimization: updated query parameters parsing +- optimization: updated request-response cycle handling +- added `@asgi` route handler decorator diff --git a/pyproject.toml b/pyproject.toml index fbbb5086f8..8d20be0087 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "starlite" -version = "0.6.0" +version = "0.7.0" description = "Light-weight and flexible ASGI API Framework" authors = ["Na'aman Hirschfeld "] maintainers = ["Na'aman Hirschfeld "] From 0abf065621c834bd265750dea9ed2ebee5fac8ea Mon Sep 17 00:00:00 2001 From: Na'aman Hirschfeld Date: Sun, 23 Jan 2022 17:34:08 +0100 Subject: [PATCH 6/7] added basic docs --- docs/usage/2-route-handlers.md | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/docs/usage/2-route-handlers.md b/docs/usage/2-route-handlers.md index 1642fc7c6a..b1f7c518fa 100644 --- a/docs/usage/2-route-handlers.md +++ b/docs/usage/2-route-handlers.md @@ -229,6 +229,26 @@ In all other regards websocket handlers function exactly like other route handle OpenAPI currently does not support websockets. As a result not schema will be generated for websocket route handlers, and you cannot configure any schema related parameters for these. + +## ASGI Route Handlers + +!!! info + This feature is available from v0.7.0 onwards + +You can write your own ASGI apps using the `asgi` route handler decorator: + +```python +from starlette.types import Scope, Receive, Send +from starlite import asgi + + +@asgi(path="/my-asgi-app") +async def my_asgi_app(scope: Scope, receive: Receive, send: Send) -> None: + ... +``` + +ASGI apps are currently not handled in OpenAPI generation - although this will change in the future. + ## Handler Function Kwargs Route handler functions or methods access various data by declaring these as annotated function kwargs. The annotated From a28bfb1f599601ccf1034211bf07b66637781044 Mon Sep 17 00:00:00 2001 From: Na'aman Hirschfeld Date: Sun, 23 Jan 2022 17:54:57 +0100 Subject: [PATCH 7/7] updated ASGI docs --- docs/usage/2-route-handlers.md | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/docs/usage/2-route-handlers.md b/docs/usage/2-route-handlers.md index b1f7c518fa..400d30df47 100644 --- a/docs/usage/2-route-handlers.md +++ b/docs/usage/2-route-handlers.md @@ -239,15 +239,26 @@ You can write your own ASGI apps using the `asgi` route handler decorator: ```python from starlette.types import Scope, Receive, Send -from starlite import asgi +from starlette.status import HTTP_200_OK, HTTP_400_BAD_REQUEST +from starlite import Response, asgi @asgi(path="/my-asgi-app") async def my_asgi_app(scope: Scope, receive: Receive, send: Send) -> None: - ... + if scope["type"] == "http": + method = scope["method"] + if method.lower() == "get": + response = Response({"hello": "world"}, status_code=HTTP_200_OK) + await response(scope=scope, receive=receive, send=send) + return + response = Response( + {"detail": "unsupported request"}, status_code=HTTP_400_BAD_REQUEST + ) + await response(scope=scope, receive=receive, send=send) ``` -ASGI apps are currently not handled in OpenAPI generation - although this will change in the future. +!!! note + ASGI apps are currently not handled in OpenAPI generation - although this might change in the future. ## Handler Function Kwargs