diff --git a/.gitignore b/.gitignore index f5a8e6ae..7862503e 100644 --- a/.gitignore +++ b/.gitignore @@ -78,4 +78,5 @@ man/ .pytest_cache #vscode -.vscode/ \ No newline at end of file +.vscode/ +pip-wheel-metadata diff --git a/flask_rebar/rebar.py b/flask_rebar/rebar.py index fc5345b4..72167790 100644 --- a/flask_rebar/rebar.py +++ b/flask_rebar/rebar.py @@ -32,10 +32,16 @@ from flask_rebar.utils.request_utils import get_header_params_or_400 from flask_rebar.utils.request_utils import get_json_body_params_or_400 from flask_rebar.utils.request_utils import get_query_string_params_or_400 +from flask_rebar.utils.request_utils import normalize_schema from flask_rebar.utils.deprecation import deprecated, deprecated_parameters from flask_rebar.swagger_generation import SwaggerV2Generator from flask_rebar.swagger_ui import create_swagger_ui_blueprint +# Deal with maintaining (for now at least) support for 2.7+: +try: + from collections.abc import Mapping # 3.3+ +except ImportError: + from collections import Mapping # 2.7+ # To catch redirection exceptions, app.errorhandler expects 301 in versions # below 0.11.0 but the exception itself in versions greater than 0.11.0. @@ -459,8 +465,19 @@ def add_handler( :param Type[USE_DEFAULT]|None|str mimetype: Content-Type header to add to the response schema """ - if isinstance(response_body_schema, marshmallow.Schema): - response_body_schema = {200: response_body_schema} + # Fix #115: if we were passed bare classes we'll go ahead and instantiate + headers_schema = normalize_schema(headers_schema) + request_body_schema = normalize_schema(request_body_schema) + query_string_schema = normalize_schema(query_string_schema) + if response_body_schema: + # Ensure we wrap in appropriate default (200) dict if we were passed a single Schema or class: + if not isinstance(response_body_schema, Mapping): + response_body_schema = {200: response_body_schema} + # use normalize_schema to convert any class reference(s) to instantiated schema(s): + response_body_schema = { + code: normalize_schema(schema) + for (code, schema) in response_body_schema.items() + } # authenticators can be a list of Authenticators, a single Authenticator, USE_DEFAULT, or None if isinstance(authenticators, Authenticator) or authenticators is USE_DEFAULT: diff --git a/flask_rebar/utils/request_utils.py b/flask_rebar/utils/request_utils.py index d0c41877..a27f38ec 100644 --- a/flask_rebar/utils/request_utils.py +++ b/flask_rebar/utils/request_utils.py @@ -21,6 +21,7 @@ from flask_rebar import compat from flask_rebar import errors from flask_rebar import messages +from flask_rebar.utils.defaults import USE_DEFAULT class HeadersProxy(compat.Mapping): @@ -92,7 +93,7 @@ def normalize_schema(schema): This allows for either an instance of a marshmallow.Schema or the class itself to be passed to functions. """ - if not isinstance(schema, marshmallow.Schema): + if schema not in (None, USE_DEFAULT) and not isinstance(schema, marshmallow.Schema): schema = schema() return schema diff --git a/tests/test_rebar.py b/tests/test_rebar.py index 28af1f95..8fd503d1 100644 --- a/tests/test_rebar.py +++ b/tests/test_rebar.py @@ -12,7 +12,6 @@ import marshmallow as m from flask import Flask -from werkzeug.routing import RequestRedirect from flask_rebar import messages from flask_rebar import HeaderApiKeyAuthenticator, SwaggerV3Generator @@ -677,3 +676,80 @@ def test_redirects_for_missing_trailing_slash(self): resp = app.test_client().get(path="/with_trailing_slash") self.assertIn(resp.status_code, (301, 308)) self.assertTrue(resp.headers["Location"].endswith("/with_trailing_slash/")) + + def test_bare_class_schemas_handled(self): + rebar = Rebar() + registry = rebar.create_handler_registry() + + expected_foo = FooSchema().load({"uid": "some_uid", "name": "Namey McNamerton"}) + expected_headers = {"x-name": "Header Name"} + + def get_foo(*args, **kwargs): + return expected_foo + + def post_foo(*args, **kwargs): + return expected_foo + + register_endpoint( + registry=registry, + method="GET", + path="/my_get_endpoint", + headers_schema=HeadersSchema, + response_body_schema={200: FooSchema}, + query_string_schema=FooListSchema, + func=get_foo, + ) + + register_endpoint( + registry=registry, + method="POST", + path="/my_post_endpoint", + request_body_schema=FooListSchema, + response_body_schema=FooSchema, + func=post_foo, + ) + + app = create_rebar_app(rebar) + # violate headers schema: + resp = app.test_client().get(path="/my_get_endpoint?name=QuerystringName") + self.assertEqual(resp.status_code, 400) + self.assertEqual( + get_json_from_resp(resp)["message"], messages.header_validation_failed + ) + # violate querystring schema: + resp = app.test_client().get(path="/my_get_endpoint", headers=expected_headers) + self.assertEqual(resp.status_code, 400) + self.assertEqual( + get_json_from_resp(resp)["message"], messages.query_string_validation_failed + ) + # valid request: + resp = app.test_client().get( + path="/my_get_endpoint?name=QuerystringName", headers=expected_headers + ) + self.assertEqual(resp.status_code, 200) + self.assertEqual(get_json_from_resp(resp), expected_foo.data) + + resp = app.test_client().post( + path="/my_post_endpoint", + data='{"wrong": "Posted Name"}', + content_type="application/json", + ) + self.assertEqual(resp.status_code, 400) + self.assertEqual( + get_json_from_resp(resp)["message"], messages.body_validation_failed + ) + + resp = app.test_client().post( + path="/my_post_endpoint", + data='{"name": "Posted Name"}', + content_type="application/json", + ) + self.assertEqual(resp.status_code, 200) + + # ensure Swagger generation doesn't break (Issue #115) + from flask_rebar import SwaggerV2Generator, SwaggerV3Generator + + swagger = SwaggerV2Generator().generate(registry) + self.assertIsNotNone(swagger) # really only care that it didn't barf + swagger = SwaggerV3Generator().generate(registry) + self.assertIsNotNone(swagger)