Skip to content

Commit

Permalink
HH-228923 make excpetion handling async
Browse files Browse the repository at this point in the history
  • Loading branch information
712u3 committed Sep 4, 2024
1 parent 39c9dcd commit 70e4b70
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 99 deletions.
177 changes: 83 additions & 94 deletions frontik/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from asyncio.futures import Future
from functools import partial, wraps
from http import HTTPStatus
from typing import TYPE_CHECKING, Any, NoReturn, Optional, Type, TypeVar, Union, overload
from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar, Union, overload

import tornado.web
from fastapi import Depends, Request
Expand All @@ -20,7 +20,7 @@
from pydantic import BaseModel, ValidationError
from tornado import httputil
from tornado.ioloop import IOLoop
from tornado.web import Finish, RequestHandler
from tornado.web import Finish, HTTPError, RequestHandler

import frontik.auth
import frontik.producers.json_producer
Expand All @@ -30,7 +30,7 @@
from frontik.auth import DEBUG_AUTH_HEADER_NAME
from frontik.debug import DEBUG_HEADER_NAME, DebugMode
from frontik.futures import AbortAsyncGroup, AsyncGroup
from frontik.http_status import ALLOWED_STATUSES, NON_CRITICAL_BAD_GATEWAY
from frontik.http_status import ALLOWED_STATUSES, CLIENT_CLOSED_REQUEST, NON_CRITICAL_BAD_GATEWAY
from frontik.json_builder import FrontikJsonDecodeError, json_decode
from frontik.loggers import CUSTOM_JSON_EXTRA, JSON_REQUESTS_LOGGER
from frontik.loggers.stages import StagesLogger
Expand All @@ -56,10 +56,6 @@ def __init__(self, wait_finish_group: bool = False) -> None:
self.wait_finish_group = wait_finish_group


class RedirectSignal(Exception):
pass


class FinishSignal(Exception):
pass

Expand Down Expand Up @@ -130,8 +126,6 @@ def __init__(
self.request = request
self._headers_written = False
self._finished = False
self._auto_finish = True
self._prepared_future = None

self.clear()
self.initialize()
Expand Down Expand Up @@ -344,7 +338,6 @@ def redirect(self, url: str, *args: Any, allow_protocol_relative: bool = False,
raise tornado.web.HTTPError(403, 'cannot redirect path with two initial slashes')
self.log.info('redirecting to: %s', url)
super().redirect(url, *args, **kwargs)
raise RedirectSignal()

@property
def json_body(self):
Expand Down Expand Up @@ -398,23 +391,20 @@ async def execute(self) -> tuple[int, str, HTTPHeaders, bytes]:
self.prepare()

await self._execute_page()
if self._auto_finish and not self._finished:
self.finish()
self.finish()
except Exception as e:
try:
self._handle_request_exception(e)
except Exception:
await self._handle_exception(e)
except Exception as exc:
self.log.exception('Exception in exception handler')
self.send_error()
return self.handler_result_future.result()
if self._prepared_future is not None and not self._prepared_future.done():
self._prepared_future.set_result(None)
await self._send_error(exception=exc)
return self.__get_result()

done, pending = await asyncio.wait((self.handler_result_future,), timeout=5.0)
if not done:
self.log.error('handler was never finished')
self.send_error()
return self.handler_result_future.result()
await self._send_error(exception=RuntimeError('handler was never finished'))
return self.__get_result()

async def get(self, *args, **kwargs):
await self._execute_page()
Expand Down Expand Up @@ -464,27 +454,32 @@ async def _execute_page(self) -> None:
if render_result is not None:
self.write(render_result)

def get_page_fail_fast(self, request_result: RequestResult) -> None:
self.__return_error(request_result.status_code, error_info={'is_fail_fast': True})
async def get_page_fail_fast(self, request_result: RequestResult) -> None:
await self.__return_error(request_result.status_code, error_info={'is_fail_fast': True})

def post_page_fail_fast(self, request_result: RequestResult) -> None:
self.__return_error(request_result.status_code, error_info={'is_fail_fast': True})
async def post_page_fail_fast(self, request_result: RequestResult) -> None:
await self.__return_error(request_result.status_code, error_info={'is_fail_fast': True})

def put_page_fail_fast(self, request_result: RequestResult) -> None:
self.__return_error(request_result.status_code, error_info={'is_fail_fast': True})
async def put_page_fail_fast(self, request_result: RequestResult) -> None:
await self.__return_error(request_result.status_code, error_info={'is_fail_fast': True})

def delete_page_fail_fast(self, request_result: RequestResult) -> None:
self.__return_error(request_result.status_code, error_info={'is_fail_fast': True})
async def delete_page_fail_fast(self, request_result: RequestResult) -> None:
await self.__return_error(request_result.status_code, error_info={'is_fail_fast': True})

def __return_error(self, response_code: int, **kwargs: Any) -> None:
async def __return_error(self, response_code: int, **kwargs: Any) -> None:
if not (300 <= response_code < 500 or response_code == NON_CRITICAL_BAD_GATEWAY):
response_code = HTTPStatus.BAD_GATEWAY
self.send_error(response_code, **kwargs)
await self._send_error(response_code, None, **kwargs)

# Finish page

def is_finished(self) -> bool:
return self._finished or getattr(self.request, 'canceled', False)
return self.handler_result_future.done() or getattr(self.request, 'canceled', False)

def __get_result(self) -> tuple[int, str, HTTPHeaders, bytes]:
if getattr(self.request, 'canceled', False):
return CLIENT_CLOSED_REQUEST, 'Client closed the connection: aborting request', self._headers, b''
return self.handler_result_future.result()

def check_finished(self, callback: Callable) -> Callable:
@wraps(callback)
Expand All @@ -496,28 +491,27 @@ def wrapper(*args, **kwargs):

return wrapper

def finish_with_postprocessors(self) -> None:
async def finish_with_postprocessors(self) -> None:
if not self.finish_group.get_finish_future().done():
self.finish_group.abort()

def _cb(future: Future) -> None:
if (ex := future.exception()) is not None:
self.log.error('postprocess failed %s', ex)
self.set_status(500)
self.finish()
if future.result() is not None:
self.finish(future.result())

asyncio.create_task(self._postprocess()).add_done_callback(_cb)
raise FinishSignal()
try:
result = await self._postprocess()
self.finish(result)
except FinishSignal:
return
except Exception as ex:
self.log.error('postprocess failed %s', ex)
self.set_status(500)
self.finish()

def run_task(self: PageHandler, coro: Coroutine) -> Task:
task = asyncio.create_task(coro)
self.finish_group.add_future(task)
return task

async def _postprocess(self) -> Any:
if self._finished or getattr(self.request, 'canceled', False):
if self.is_finished():
self.log.info('page was already finished, skipping postprocessors')
return

Expand Down Expand Up @@ -550,7 +544,7 @@ def on_finish(self) -> None:
self.stages_logger.commit_stage('flush')
self.stages_logger.flush_stages(self.get_status())

def _handle_request_exception(self, e: BaseException) -> None:
async def _handle_exception(self, e: BaseException) -> None:
if isinstance(e, AbortAsyncGroup):
self.log.info('page was aborted, skipping postprocessing')
return
Expand All @@ -559,17 +553,15 @@ def _handle_request_exception(self, e: BaseException) -> None:
try:
if e.wait_finish_group:
self._handler_finished_notification()
self.add_future(self.finish_group.get_finish_future(), lambda _: self.finish_with_postprocessors())
else:
self.finish_with_postprocessors()
await self.finish_group.get_finish_future()

await self.finish_with_postprocessors()
except FinishSignal:
return
except Exception as exc:
super()._handle_request_exception(exc)

return

if self._finished and not isinstance(e, Finish):
if self.is_finished() and not isinstance(e, Finish):
return

if isinstance(e, FinishSignal):
Expand Down Expand Up @@ -598,64 +590,63 @@ def _handle_request_exception(self, e: BaseException) -> None:
error_method_name = f'{self.request.method.lower()}_page_fail_fast' # type: ignore
method = getattr(self, error_method_name, None)
if callable(method):
method(e.failed_result)
await method(e.failed_result)
self.finish()
else:
self.__return_error(e.failed_result.status_code, error_info={'is_fail_fast': True})
await self.__return_error(e.failed_result.status_code, error_info={'is_fail_fast': True})
return
except FinishSignal:
return
except Exception as exc:
super()._handle_request_exception(exc)

else:
super()._handle_request_exception(e)
await self._send_error(exception=e)

def send_error(self, status_code: int = 500, **kwargs: Any) -> None:
async def _send_error(self, status_code: int = 500, exception: Any = None, **kwargs: Any) -> None:
"""`send_error` is adapted to support `write_error` that can call
`finish` asynchronously.
"""
self.stages_logger.commit_stage('page')

if self._headers_written:
super().send_error(status_code, **kwargs)
return

reason = kwargs.get('reason')
if 'exc_info' in kwargs:
exception = kwargs['exc_info'][1]
if isinstance(exception, tornado.web.HTTPError) and exception.reason:
reason = exception.reason
else:
exception = None
if exception is not None:
exc_info = type(exception), exception, exception.__traceback__
kwargs['exc_info'] = exc_info
self.log_exception(*exc_info) # сентри этот метод манкипатчит

if not isinstance(exception, HTTPErrorWithPostprocessors):
self.clear()

self.set_status(status_code, reason=reason)
if isinstance(exception, HTTPError):
status_code = exception.status_code
self.set_status(status_code)

try:
self.write_error(status_code, **kwargs)
await self._write_error(status_code, **kwargs)
except FinishSignal:
return
except Exception:
self.log.exception('Uncaught exception in write_error')
if not self._finished:
self.finish()
if not self.is_finished():
self._finish(interrupt_execution=False)

def write_error(self, status_code: int = 500, **kwargs: Any) -> None:
async def _write_error(self, status_code: int = 500, **kwargs: Any) -> None:
"""
`write_error` can call `finish` asynchronously if HTTPErrorWithPostprocessors is raised.
"""
exception = kwargs['exc_info'][1] if 'exc_info' in kwargs else None

if isinstance(exception, HTTPErrorWithPostprocessors):
self.finish_with_postprocessors()
await self.finish_with_postprocessors()
return

self.set_header('Content-Type', media_types.TEXT_HTML)
super().write_error(status_code, **kwargs)

def finish(self, chunk: Optional[Union[str, bytes, dict]] = None) -> NoReturn:
def finish(self, chunk: Optional[Union[str, bytes, dict]] = None) -> Future[None]:
self._finish(chunk)
future: Future = Future()
future.set_result(None)
return future

def _finish(self, chunk: Optional[Union[str, bytes, dict]] = None, interrupt_execution: bool = True) -> None:
self.stages_logger.commit_stage('postprocess')
for name, value in self._mandatory_headers.items():
self.set_header(name, value)
Expand All @@ -670,29 +661,29 @@ def finish(self, chunk: Optional[Union[str, bytes, dict]] = None) -> NoReturn:
self._write_buffer = []
chunk = None

if self._finished or getattr(self.request, 'canceled', False):
if self.is_finished():
raise RuntimeError('finish() called twice')

if chunk is not None:
self.write(chunk)

if not self._headers_written:
if self._status_code == 200 and self.request.method in ('GET', 'HEAD') and 'Etag' not in self._headers:
self.set_etag_header()
if self.check_etag_header():
self._write_buffer = []
self.set_status(304)
if self._status_code in (204, 304) or (100 <= self._status_code < 200):
assert not self._write_buffer, 'Cannot send body with %s' % self._status_code
self._clear_representation_headers()
elif 'Content-Length' not in self._headers:
content_length = sum(len(part) for part in self._write_buffer)
self.set_header('Content-Length', content_length)
if self._status_code == 200 and self.request.method in ('GET', 'HEAD') and 'Etag' not in self._headers:
self.set_etag_header()
if self.check_etag_header():
self._write_buffer = []
self.set_status(304)
if self._status_code in (204, 304) or (100 <= self._status_code < 200):
assert not self._write_buffer, 'Cannot send body with %s' % self._status_code
self._clear_representation_headers()
elif 'Content-Length' not in self._headers:
content_length = sum(len(part) for part in self._write_buffer)
self.set_header('Content-Length', content_length)

self._flush()
self._finished = True
self.on_finish()
raise FinishSignal()
if interrupt_execution:
raise FinishSignal()

def _flush(self) -> None:
assert self.request.connection is not None
Expand Down Expand Up @@ -743,7 +734,7 @@ async def _run_postprocessors(self, postprocessors: list) -> bool:
else:
p(self)

if self._finished or getattr(self.request, 'canceled', False):
if self.is_finished():
self.log.warning('page was already finished, breaking postprocessors chain')
return False

Expand All @@ -756,7 +747,7 @@ async def _run_template_postprocessors(self, postprocessors: list, rendered_temp
else:
rendered_template = p(self, rendered_template, meta_info)

if self._finished or getattr(self.request, 'canceled', False):
if self.is_finished():
self.log.warning('page was already finished, breaking postprocessors chain')
return None

Expand Down Expand Up @@ -1029,9 +1020,7 @@ def _execute_http_client_method(
client_method: Callable,
waited: bool,
) -> Future[RequestResult]:
if waited and (
self.is_finished() or self.finish_group.is_finished() or getattr(self.request, 'canceled', False)
):
if waited and (self.is_finished() or self.finish_group.is_finished()):
handler_logger.info(
'attempted to make waited http request to %s %s in finished handler, '
'ignoring. change "waited" method parameter to send it',
Expand Down
4 changes: 2 additions & 2 deletions tests/projects/test_app/pages/fail_fast/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@ async def get_page_preprocessor(handler: PageHandler = get_current_handler()) ->


class Page(PageHandler):
def get_page_fail_fast(self, failed_future):
async def get_page_fail_fast(self, failed_future):
if self.get_argument('exception_in_fail_fast', 'false') == 'true':
msg = 'Exception in fail_fast'
raise Exception(msg)

self.json.replace({'fail_fast': True})
self.set_status(403)
self.finish_with_postprocessors()
await self.finish_with_postprocessors()


@plain_router.get('/fail_fast', cls=Page, dependencies=[Depends(get_page_preprocessor)])
Expand Down
2 changes: 1 addition & 1 deletion tests/projects/test_app/pages/http_client/raise_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@


class Page(PageHandler):
def send_error(self, status_code=500, exc_info=None, **kwargs):
def _send_error(self, status_code=500, exc_info=None, **kwargs):
if isinstance(exc_info[1], UnicodeEncodeError):
self.finish('UnicodeEncodeError')

Expand Down
Loading

0 comments on commit 70e4b70

Please sign in to comment.