From 2c39e6f13518ff6b7603360c270d7eb925e177d8 Mon Sep 17 00:00:00 2001 From: Yuval Herziger Date: Fri, 31 Dec 2021 18:32:17 +0200 Subject: [PATCH] Additional fields --- README.md | 72 +++++++++++++++++++++++++++---------- aiohttp_catcher/catcher.py | 27 ++++++++------ aiohttp_catcher/scenario.py | 28 ++++++++++++--- pyproject.toml | 5 +-- tests/conftest.py | 2 +- tests/test_catcher.py | 54 +++++++++++++++++++++++++++- 6 files changed, 149 insertions(+), 39 deletions(-) diff --git a/README.md b/README.md index faa83e4..a287899 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,8 @@ will be handled however you want them to. * [Callables and Awaitables](#callables-and-awaitables) * [Handle Several Exceptions Similarly](#handle-several-exceptions-similarly) * [Scenarios as Dictionaries](#scenarios-as-dictionaries) + * [Additional Fields](#additional-fields) + * [Default for Unhandled Exceptions](#default-for-unhandled-exceptions) - [Development](#development) *** @@ -29,9 +31,9 @@ will be handled however you want them to. from aiohttp import web from aiohttp_catcher import catch, Catcher -async def hello(request): - division = 1 / 0 - return web.Response(text=f"1 / 0 = {division}") +async def divide(request): + quotient = 1 / 0 + return web.Response(text=f"1 / 0 = {quotient}") async def main(): @@ -45,7 +47,7 @@ async def main(): # Register your catcher as an aiohttp middleware: app = web.Application(middlewares=[catcher.middleware]) - app.add_routes([web.get("/divide-by-zero", hello)]) + app.add_routes([web.get("/divide-by-zero", divide)]) web.run_app(app) ``` @@ -61,7 +63,7 @@ Making a request to `/divide-by-zero` will return a 400 status code with the fol ### Return a Constant In case you want some exceptions to return a constant message across your application, you can do -so by using the `.and_return("some value")` method: +so by using the `and_return("some value")` method: ```python await catcher.add_scenario( @@ -76,7 +78,7 @@ await catcher.add_scenario( In some cases, you would want to return a stringified version of your exception, should it entail user-friendly information. -``` +```python class EntityNotFound(Exception): def __init__(self, entity_id, *args, **kwargs): super(EntityNotFound, self).__init__(*args, **kwargs) @@ -164,6 +166,38 @@ await catcher.add_scenarios( *** +### Additional Fields + +You can enrich your error responses with additional fields. You can provide additional fields using +literal dictionaries or with callables. Callables will be called with the exception object as their +only argument. + + +```python +# Using a literal dictionary: +await catcher.add_scenario( + catch(EntityNotFound).with_status_code(404).and_stringify().with_additional_fields({"error_code": "ENTITY_NOT_FOUND"}) +) + +# Using a function (or an async function): +await catcher.add_scenario( + catch(EntityNotFound).with_status_code(404).and_stringify().with_additional_fields(lambda e: {"error_code": e.error_code}) +) +``` + +*** + +### Default for Unhandled Exceptions + +Exceptions that aren't registered with scenarios in your `Catcher` will default to 500, with a payload similar to +the following: + +```json +{"code": 500, "message": "Internal server error"} +``` + +*** + ## Development Contributions are warmly welcomed. Before submitting your PR, please run the tests using the following Make target: @@ -172,22 +206,22 @@ Contributions are warmly welcomed. Before submitting your PR, please run the te make ci ``` -Alternatively, you can run each test suite separately: +Alternatively, you can run each test separately: -1. Unit tests: +Unit tests: - ```bash - make test/py - ``` +```bash +make test/py +``` -2. Linting with pylint: +Linting with pylint: - ```bash - make pylint - ``` +```bash +make pylint +``` -3. Static security checks with bandit: +Static security checks with bandit: - ```bash - make pybandit - ``` +```bash +make pybandit +``` diff --git a/aiohttp_catcher/catcher.py b/aiohttp_catcher/catcher.py index 6cae4e4..5f82cab 100644 --- a/aiohttp_catcher/catcher.py +++ b/aiohttp_catcher/catcher.py @@ -50,16 +50,21 @@ async def catcher_middleware(request: Request, handler: Handler): return await handler(request) except Exception as exc: exc_module = await get_full_class_name(exc.__class__) + scenario: Scenario if exc_module in self.scenario_map: - scenario: Scenario = self.scenario_map[exc_module] - data = { - self.envelope: await scenario.get_response_message(exc), - self.code: scenario.status_code, - } - return json_response( - data=data, - status=scenario.status_code, - dumps=self.encoder - ) - raise + scenario = self.scenario_map[exc_module] + else: + LOGGER.exception("aiohttp-catcher caught an unhandled exception") + scenario = Scenario(exceptions=[type(exc)]) + additional_fields: Dict = await scenario.get_additional_fields(exc) + data = { + self.envelope: await scenario.get_response_message(exc), + self.code: scenario.status_code, + **additional_fields + } + return json_response( + data=data, + status=scenario.status_code, + dumps=self.encoder + ) return catcher_middleware diff --git a/aiohttp_catcher/scenario.py b/aiohttp_catcher/scenario.py index ef37b7b..b95c238 100644 --- a/aiohttp_catcher/scenario.py +++ b/aiohttp_catcher/scenario.py @@ -1,19 +1,25 @@ -from typing import Any, Awaitable, Callable, List, Union +from typing import Any, Awaitable, Callable, Dict, List, Union from inspect import isawaitable, iscoroutine, iscoroutinefunction +def is_async(f): + return isawaitable(f) or iscoroutine(f) or iscoroutinefunction(f) + class Scenario: status_code: int = 500 is_callable: bool = False stringify_exception: bool = False + additional_fields: Union[Dict, Callable, Awaitable] = None - def __init__(self, exceptions: List[Exception], func: Union[Callable, Awaitable] = None, constant: Any = None, - stringify_exception: bool = False, status_code: int = 500): + def __init__(self, exceptions: List[Exception], func: Union[Callable, Awaitable] = None, constant: Any = "Internal server error", + stringify_exception: bool = False, status_code: int = 500, + additional_fields: Union[Dict, Callable, Awaitable] = None): self.exceptions = exceptions self.stringify_exception = stringify_exception self.func = func self.constant = constant + self.additional_fields = additional_fields if not stringify_exception: if func and hasattr(func, "__call__"): self.is_callable = True @@ -21,18 +27,30 @@ def __init__(self, exceptions: List[Exception], func: Union[Callable, Awaitable] async def get_response_message(self, exc: Exception) -> Any: if self.is_callable: - awaitable = isawaitable(self.func) or iscoroutine(self.func) or iscoroutinefunction(self.func) - if awaitable: + if is_async(self.func): return await self.func(exc) return self.func(exc) if self.stringify_exception: return str(exc) return self.constant + async def get_additional_fields(self, exc: Exception) -> Dict: + if not self.additional_fields: + return {} + if isinstance(self.additional_fields, Dict): + return self.additional_fields + if is_async(self.additional_fields): + return await self.additional_fields(exc) + return self.additional_fields(exc) + def with_status_code(self, status_code) -> "Scenario": self.status_code = status_code return self + def with_additional_fields(self, additional_fields: Union[Dict, Callable, Awaitable]) -> "Scenario": + self.additional_fields = additional_fields + return self + def and_stringify(self) -> "Scenario": self.stringify_exception = True return self diff --git a/pyproject.toml b/pyproject.toml index 2a55c31..dfdd0d6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,7 +2,7 @@ name = "aiohttp-catcher" version = "0.1.0" description = "A centralized error handler for aiohttp servers" -authors = ["Yuvi Herziger "] +authors = ["Yuval Herziger "] license = "MIT" readme = "README.md" homepage = "https://github.com/yuvalherziger/aiohttp-catcher" @@ -29,11 +29,12 @@ build-backend = "poetry.core.masonry.api" [tool.pylint.BASIC] good-names=[ "e", + "f" ] [tool.pylint.FORMAT] max-line-length=120 -max-args=6 +max-args=7 [tool.pylint.'MESSAGES CONTROL'] disable=[ diff --git a/tests/conftest.py b/tests/conftest.py index 6b23d58..8421158 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,7 +7,7 @@ class AppClientError(Exception): class EntityNotFound(AppClientError): - pass + error_code = "ENTITY_NOT_FOUND" class Forbidden(AppClientError): diff --git a/tests/test_catcher.py b/tests/test_catcher.py index 471df46..bf7fd76 100644 --- a/tests/test_catcher.py +++ b/tests/test_catcher.py @@ -203,7 +203,7 @@ async def test_scenarios_as_dict(aiohttp_client, routes, loop): assert "Out of bound: range object index out of range" == (await resp.json()).get("message") @staticmethod - async def test_return_constant_message(aiohttp_client, routes, loop): + async def test_default_to_500(aiohttp_client, routes, loop): catcher = Catcher() app = web.Application(middlewares=[catcher.middleware]) app.add_routes(routes) @@ -215,3 +215,55 @@ async def test_return_constant_message(aiohttp_client, routes, loop): resp = await client.get("/divide?a=Foo") assert 500 == resp.status + assert "Internal server error" == (await resp.json()).get("message") + + @staticmethod + async def test_additional_fields_from_dictionary(aiohttp_client, routes, loop): + catcher = Catcher() + await catcher.add_scenario( + catch(EntityNotFound).with_status_code(404).and_stringify().with_additional_fields({"error_code": "ENTITY_NOT_FOUND"}) + ) + app = web.Application(middlewares=[catcher.middleware]) + app.add_routes(routes) + + client = await aiohttp_client(app) + resp = await client.get("/user/1009") + assert 404 == resp.status + assert "User ID 1009 could not be found" == (await resp.json()).get("message") + assert "ENTITY_NOT_FOUND" == (await resp.json()).get("error_code") + + @staticmethod + async def test_additional_fields_from_callable(aiohttp_client, routes, loop): + catcher = Catcher() + await catcher.add_scenario( + catch(EntityNotFound).with_status_code(404).and_stringify().with_additional_fields(lambda e: {"error_code": e.error_code}) + ) + app = web.Application(middlewares=[catcher.middleware]) + app.add_routes(routes) + + client = await aiohttp_client(app) + resp = await client.get("/user/1009") + assert 404 == resp.status + assert "User ID 1009 could not be found" == (await resp.json()).get("message") + assert "ENTITY_NOT_FOUND" == (await resp.json()).get("error_code") + + @staticmethod + async def test_additional_fields_from_awaitable(aiohttp_client, routes, loop): + catcher = Catcher() + + async def get_additional_fields(e: Exception): + return {"error_code": e.error_code, "foo": "bar"} + + await catcher.add_scenario( + catch(EntityNotFound).with_status_code(404).and_stringify().with_additional_fields(get_additional_fields) + ) + app = web.Application(middlewares=[catcher.middleware]) + app.add_routes(routes) + + client = await aiohttp_client(app) + resp = await client.get("/user/1009") + assert 404 == resp.status + assert "User ID 1009 could not be found" == (await resp.json()).get("message") + body = await resp.json() + assert "ENTITY_NOT_FOUND" == body.get("error_code") + assert "bar" == body.get("foo") \ No newline at end of file