Skip to content

Commit

Permalink
Merge branch 'main' of github.com:yuvalherziger/aiohttp-catcher
Browse files Browse the repository at this point in the history
  • Loading branch information
yuvalherziger committed Dec 31, 2021
2 parents 0ae7a2b + 69e487f commit c1c8d51
Show file tree
Hide file tree
Showing 6 changed files with 149 additions and 39 deletions.
72 changes: 53 additions & 19 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ will be handled however you want them to.
* [Callables and Awaitables](#callables-and-awaitables)
* [Handle Several Exceptions Similarly](#handle-several-exceptions-similarly)
* [Scenarios as Dictionaries](#scenarios-as-dictionaries)
* [Additional Fields](#additional-fields)
* [Default for Unhandled Exceptions](#default-for-unhandled-exceptions)
- [Development](#development)

***
Expand All @@ -29,9 +31,9 @@ will be handled however you want them to.
from aiohttp import web
from aiohttp_catcher import catch, Catcher

async def hello(request):
division = 1 / 0
return web.Response(text=f"1 / 0 = {division}")
async def divide(request):
quotient = 1 / 0
return web.Response(text=f"1 / 0 = {quotient}")


async def main():
Expand All @@ -45,7 +47,7 @@ async def main():

# Register your catcher as an aiohttp middleware:
app = web.Application(middlewares=[catcher.middleware])
app.add_routes([web.get("/divide-by-zero", hello)])
app.add_routes([web.get("/divide-by-zero", divide)])
web.run_app(app)
```

Expand All @@ -61,7 +63,7 @@ Making a request to `/divide-by-zero` will return a 400 status code with the fol
### Return a Constant

In case you want some exceptions to return a constant message across your application, you can do
so by using the `.and_return("some value")` method:
so by using the `and_return("some value")` method:

```python
await catcher.add_scenario(
Expand All @@ -76,7 +78,7 @@ await catcher.add_scenario(
In some cases, you would want to return a stringified version of your exception, should it entail
user-friendly information.

```
```python
class EntityNotFound(Exception):
def __init__(self, entity_id, *args, **kwargs):
super(EntityNotFound, self).__init__(*args, **kwargs)
Expand Down Expand Up @@ -164,6 +166,38 @@ await catcher.add_scenarios(

***

### Additional Fields

You can enrich your error responses with additional fields. You can provide additional fields using
literal dictionaries or with callables. Callables will be called with the exception object as their
only argument.


```python
# Using a literal dictionary:
await catcher.add_scenario(
catch(EntityNotFound).with_status_code(404).and_stringify().with_additional_fields({"error_code": "ENTITY_NOT_FOUND"})
)

# Using a function (or an async function):
await catcher.add_scenario(
catch(EntityNotFound).with_status_code(404).and_stringify().with_additional_fields(lambda e: {"error_code": e.error_code})
)
```

***

### Default for Unhandled Exceptions

Exceptions that aren't registered with scenarios in your `Catcher` will default to 500, with a payload similar to
the following:

```json
{"code": 500, "message": "Internal server error"}
```

***

## Development

Contributions are warmly welcomed. Before submitting your PR, please run the tests using the following Make target:
Expand All @@ -172,22 +206,22 @@ Contributions are warmly welcomed. Before submitting your PR, please run the te
make ci
```

Alternatively, you can run each test suite separately:
Alternatively, you can run each test separately:

1. Unit tests:
Unit tests:

```bash
make test/py
```
```bash
make test/py
```

2. Linting with pylint:
Linting with pylint:

```bash
make pylint
```
```bash
make pylint
```

3. Static security checks with bandit:
Static security checks with bandit:

```bash
make pybandit
```
```bash
make pybandit
```
27 changes: 16 additions & 11 deletions aiohttp_catcher/catcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,21 @@ async def catcher_middleware(request: Request, handler: Handler):
return await handler(request)
except Exception as exc:
exc_module = await get_full_class_name(exc.__class__)
scenario: Scenario
if exc_module in self.scenario_map:
scenario: Scenario = self.scenario_map[exc_module]
data = {
self.envelope: await scenario.get_response_message(exc),
self.code: scenario.status_code,
}
return json_response(
data=data,
status=scenario.status_code,
dumps=self.encoder
)
raise
scenario = self.scenario_map[exc_module]
else:
LOGGER.exception("aiohttp-catcher caught an unhandled exception")
scenario = Scenario(exceptions=[type(exc)])
additional_fields: Dict = await scenario.get_additional_fields(exc)
data = {
self.envelope: await scenario.get_response_message(exc),
self.code: scenario.status_code,
**additional_fields
}
return json_response(
data=data,
status=scenario.status_code,
dumps=self.encoder
)
return catcher_middleware
28 changes: 23 additions & 5 deletions aiohttp_catcher/scenario.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,56 @@
from typing import Any, Awaitable, Callable, List, Union
from typing import Any, Awaitable, Callable, Dict, List, Union
from inspect import isawaitable, iscoroutine, iscoroutinefunction


def is_async(f):
return isawaitable(f) or iscoroutine(f) or iscoroutinefunction(f)


class Scenario:
status_code: int = 500
is_callable: bool = False
stringify_exception: bool = False
additional_fields: Union[Dict, Callable, Awaitable] = None

def __init__(self, exceptions: List[Exception], func: Union[Callable, Awaitable] = None, constant: Any = None,
stringify_exception: bool = False, status_code: int = 500):
def __init__(self, exceptions: List[Exception], func: Union[Callable, Awaitable] = None, constant: Any = "Internal server error",
stringify_exception: bool = False, status_code: int = 500,
additional_fields: Union[Dict, Callable, Awaitable] = None):
self.exceptions = exceptions
self.stringify_exception = stringify_exception
self.func = func
self.constant = constant
self.additional_fields = additional_fields
if not stringify_exception:
if func and hasattr(func, "__call__"):
self.is_callable = True
self.status_code = status_code

async def get_response_message(self, exc: Exception) -> Any:
if self.is_callable:
awaitable = isawaitable(self.func) or iscoroutine(self.func) or iscoroutinefunction(self.func)
if awaitable:
if is_async(self.func):
return await self.func(exc)
return self.func(exc)
if self.stringify_exception:
return str(exc)
return self.constant

async def get_additional_fields(self, exc: Exception) -> Dict:
if not self.additional_fields:
return {}
if isinstance(self.additional_fields, Dict):
return self.additional_fields
if is_async(self.additional_fields):
return await self.additional_fields(exc)
return self.additional_fields(exc)

def with_status_code(self, status_code) -> "Scenario":
self.status_code = status_code
return self

def with_additional_fields(self, additional_fields: Union[Dict, Callable, Awaitable]) -> "Scenario":
self.additional_fields = additional_fields
return self

def and_stringify(self) -> "Scenario":
self.stringify_exception = True
return self
Expand Down
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
name = "aiohttp-catcher"
version = "0.2.0"
description = "A centralized error handler for aiohttp servers"
authors = ["Yuvi Herziger <yherziger@immuta.com>"]
authors = ["Yuval Herziger <yuvalhrz@gmail.com>"]
license = "MIT"
readme = "README.md"
homepage = "https://github.com/yuvalherziger/aiohttp-catcher"
Expand All @@ -29,11 +29,12 @@ build-backend = "poetry.core.masonry.api"
[tool.pylint.BASIC]
good-names=[
"e",
"f"
]

[tool.pylint.FORMAT]
max-line-length=120
max-args=6
max-args=7

[tool.pylint.'MESSAGES CONTROL']
disable=[
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ class AppClientError(Exception):


class EntityNotFound(AppClientError):
pass
error_code = "ENTITY_NOT_FOUND"


class Forbidden(AppClientError):
Expand Down
54 changes: 53 additions & 1 deletion tests/test_catcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ async def test_scenarios_as_dict(aiohttp_client, routes, loop):
assert "Out of bound: range object index out of range" == (await resp.json()).get("message")

@staticmethod
async def test_return_constant_message(aiohttp_client, routes, loop):
async def test_default_to_500(aiohttp_client, routes, loop):
catcher = Catcher()
app = web.Application(middlewares=[catcher.middleware])
app.add_routes(routes)
Expand All @@ -215,3 +215,55 @@ async def test_return_constant_message(aiohttp_client, routes, loop):

resp = await client.get("/divide?a=Foo")
assert 500 == resp.status
assert "Internal server error" == (await resp.json()).get("message")

@staticmethod
async def test_additional_fields_from_dictionary(aiohttp_client, routes, loop):
catcher = Catcher()
await catcher.add_scenario(
catch(EntityNotFound).with_status_code(404).and_stringify().with_additional_fields({"error_code": "ENTITY_NOT_FOUND"})
)
app = web.Application(middlewares=[catcher.middleware])
app.add_routes(routes)

client = await aiohttp_client(app)
resp = await client.get("/user/1009")
assert 404 == resp.status
assert "User ID 1009 could not be found" == (await resp.json()).get("message")
assert "ENTITY_NOT_FOUND" == (await resp.json()).get("error_code")

@staticmethod
async def test_additional_fields_from_callable(aiohttp_client, routes, loop):
catcher = Catcher()
await catcher.add_scenario(
catch(EntityNotFound).with_status_code(404).and_stringify().with_additional_fields(lambda e: {"error_code": e.error_code})
)
app = web.Application(middlewares=[catcher.middleware])
app.add_routes(routes)

client = await aiohttp_client(app)
resp = await client.get("/user/1009")
assert 404 == resp.status
assert "User ID 1009 could not be found" == (await resp.json()).get("message")
assert "ENTITY_NOT_FOUND" == (await resp.json()).get("error_code")

@staticmethod
async def test_additional_fields_from_awaitable(aiohttp_client, routes, loop):
catcher = Catcher()

async def get_additional_fields(e: Exception):
return {"error_code": e.error_code, "foo": "bar"}

await catcher.add_scenario(
catch(EntityNotFound).with_status_code(404).and_stringify().with_additional_fields(get_additional_fields)
)
app = web.Application(middlewares=[catcher.middleware])
app.add_routes(routes)

client = await aiohttp_client(app)
resp = await client.get("/user/1009")
assert 404 == resp.status
assert "User ID 1009 could not be found" == (await resp.json()).get("message")
body = await resp.json()
assert "ENTITY_NOT_FOUND" == body.get("error_code")
assert "bar" == body.get("foo")

0 comments on commit c1c8d51

Please sign in to comment.