diff --git a/frontik/handler.py b/frontik/handler.py index fb4ae7954..1b8337125 100644 --- a/frontik/handler.py +++ b/frontik/handler.py @@ -10,7 +10,7 @@ from asyncio.futures import Future from functools import partial, wraps from http import HTTPStatus -from typing import TYPE_CHECKING, Any, NoReturn, Optional, Type, TypeVar, Union, overload +from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar, Union, overload import tornado.web from fastapi import Depends, Request @@ -20,7 +20,7 @@ from pydantic import BaseModel, ValidationError from tornado import httputil from tornado.ioloop import IOLoop -from tornado.web import Finish, RequestHandler +from tornado.web import Finish, HTTPError, RequestHandler import frontik.auth import frontik.producers.json_producer @@ -30,7 +30,7 @@ from frontik.auth import DEBUG_AUTH_HEADER_NAME from frontik.debug import DEBUG_HEADER_NAME, DebugMode from frontik.futures import AbortAsyncGroup, AsyncGroup -from frontik.http_status import ALLOWED_STATUSES, NON_CRITICAL_BAD_GATEWAY +from frontik.http_status import ALLOWED_STATUSES, CLIENT_CLOSED_REQUEST, NON_CRITICAL_BAD_GATEWAY from frontik.json_builder import FrontikJsonDecodeError, json_decode from frontik.loggers import CUSTOM_JSON_EXTRA, JSON_REQUESTS_LOGGER from frontik.loggers.stages import StagesLogger @@ -56,10 +56,6 @@ def __init__(self, wait_finish_group: bool = False) -> None: self.wait_finish_group = wait_finish_group -class RedirectSignal(Exception): - pass - - class FinishSignal(Exception): pass @@ -130,8 +126,6 @@ def __init__( self.request = request self._headers_written = False self._finished = False - self._auto_finish = True - self._prepared_future = None self.clear() self.initialize() @@ -344,7 +338,6 @@ def redirect(self, url: str, *args: Any, allow_protocol_relative: bool = False, raise tornado.web.HTTPError(403, 'cannot redirect path with two initial slashes') self.log.info('redirecting to: %s', url) super().redirect(url, *args, **kwargs) - raise RedirectSignal() @property def json_body(self): @@ -398,23 +391,20 @@ async def execute(self) -> tuple[int, str, HTTPHeaders, bytes]: self.prepare() await self._execute_page() - if self._auto_finish and not self._finished: - self.finish() + self.finish() except Exception as e: try: - self._handle_request_exception(e) - except Exception: + await self._handle_exception(e) + except Exception as exc: self.log.exception('Exception in exception handler') - self.send_error() - return self.handler_result_future.result() - if self._prepared_future is not None and not self._prepared_future.done(): - self._prepared_future.set_result(None) + await self._send_error(exception=exc) + return self.__get_result() done, pending = await asyncio.wait((self.handler_result_future,), timeout=5.0) if not done: self.log.error('handler was never finished') - self.send_error() - return self.handler_result_future.result() + await self._send_error(exception=RuntimeError('handler was never finished')) + return self.__get_result() async def get(self, *args, **kwargs): await self._execute_page() @@ -464,27 +454,32 @@ async def _execute_page(self) -> None: if render_result is not None: self.write(render_result) - def get_page_fail_fast(self, request_result: RequestResult) -> None: - self.__return_error(request_result.status_code, error_info={'is_fail_fast': True}) + async def get_page_fail_fast(self, request_result: RequestResult) -> None: + await self.__return_error(request_result.status_code, error_info={'is_fail_fast': True}) - def post_page_fail_fast(self, request_result: RequestResult) -> None: - self.__return_error(request_result.status_code, error_info={'is_fail_fast': True}) + async def post_page_fail_fast(self, request_result: RequestResult) -> None: + await self.__return_error(request_result.status_code, error_info={'is_fail_fast': True}) - def put_page_fail_fast(self, request_result: RequestResult) -> None: - self.__return_error(request_result.status_code, error_info={'is_fail_fast': True}) + async def put_page_fail_fast(self, request_result: RequestResult) -> None: + await self.__return_error(request_result.status_code, error_info={'is_fail_fast': True}) - def delete_page_fail_fast(self, request_result: RequestResult) -> None: - self.__return_error(request_result.status_code, error_info={'is_fail_fast': True}) + async def delete_page_fail_fast(self, request_result: RequestResult) -> None: + await self.__return_error(request_result.status_code, error_info={'is_fail_fast': True}) - def __return_error(self, response_code: int, **kwargs: Any) -> None: + async def __return_error(self, response_code: int, **kwargs: Any) -> None: if not (300 <= response_code < 500 or response_code == NON_CRITICAL_BAD_GATEWAY): response_code = HTTPStatus.BAD_GATEWAY - self.send_error(response_code, **kwargs) + await self._send_error(response_code, None, **kwargs) # Finish page def is_finished(self) -> bool: - return self._finished or getattr(self.request, 'canceled', False) + return self.handler_result_future.done() or getattr(self.request, 'canceled', False) + + def __get_result(self) -> tuple[int, str, HTTPHeaders, bytes]: + if getattr(self.request, 'canceled', False): + return CLIENT_CLOSED_REQUEST, 'Client closed the connection: aborting request', self._headers, b'' + return self.handler_result_future.result() def check_finished(self, callback: Callable) -> Callable: @wraps(callback) @@ -496,20 +491,19 @@ def wrapper(*args, **kwargs): return wrapper - def finish_with_postprocessors(self) -> None: + async def finish_with_postprocessors(self) -> None: if not self.finish_group.get_finish_future().done(): self.finish_group.abort() - def _cb(future: Future) -> None: - if (ex := future.exception()) is not None: - self.log.error('postprocess failed %s', ex) - self.set_status(500) - self.finish() - if future.result() is not None: - self.finish(future.result()) - - asyncio.create_task(self._postprocess()).add_done_callback(_cb) - raise FinishSignal() + try: + result = await self._postprocess() + self.finish(result) + except FinishSignal: + return + except Exception as ex: + self.log.error('postprocess failed %s', ex) + self.set_status(500) + self.finish() def run_task(self: PageHandler, coro: Coroutine) -> Task: task = asyncio.create_task(coro) @@ -517,7 +511,7 @@ def run_task(self: PageHandler, coro: Coroutine) -> Task: return task async def _postprocess(self) -> Any: - if self._finished or getattr(self.request, 'canceled', False): + if self.is_finished(): self.log.info('page was already finished, skipping postprocessors') return @@ -550,7 +544,7 @@ def on_finish(self) -> None: self.stages_logger.commit_stage('flush') self.stages_logger.flush_stages(self.get_status()) - def _handle_request_exception(self, e: BaseException) -> None: + async def _handle_exception(self, e: BaseException) -> None: if isinstance(e, AbortAsyncGroup): self.log.info('page was aborted, skipping postprocessing') return @@ -559,17 +553,15 @@ def _handle_request_exception(self, e: BaseException) -> None: try: if e.wait_finish_group: self._handler_finished_notification() - self.add_future(self.finish_group.get_finish_future(), lambda _: self.finish_with_postprocessors()) - else: - self.finish_with_postprocessors() + await self.finish_group.get_finish_future() + + await self.finish_with_postprocessors() except FinishSignal: return - except Exception as exc: - super()._handle_request_exception(exc) return - if self._finished and not isinstance(e, Finish): + if self.is_finished() and not isinstance(e, Finish): return if isinstance(e, FinishSignal): @@ -598,64 +590,63 @@ def _handle_request_exception(self, e: BaseException) -> None: error_method_name = f'{self.request.method.lower()}_page_fail_fast' # type: ignore method = getattr(self, error_method_name, None) if callable(method): - method(e.failed_result) + await method(e.failed_result) self.finish() else: - self.__return_error(e.failed_result.status_code, error_info={'is_fail_fast': True}) + await self.__return_error(e.failed_result.status_code, error_info={'is_fail_fast': True}) + return except FinishSignal: return - except Exception as exc: - super()._handle_request_exception(exc) else: - super()._handle_request_exception(e) + await self._send_error(exception=e) - def send_error(self, status_code: int = 500, **kwargs: Any) -> None: + async def _send_error(self, status_code: int = 500, exception: Any = None, **kwargs: Any) -> None: """`send_error` is adapted to support `write_error` that can call `finish` asynchronously. """ self.stages_logger.commit_stage('page') - - if self._headers_written: - super().send_error(status_code, **kwargs) - return - - reason = kwargs.get('reason') - if 'exc_info' in kwargs: - exception = kwargs['exc_info'][1] - if isinstance(exception, tornado.web.HTTPError) and exception.reason: - reason = exception.reason - else: - exception = None + if exception is not None: + exc_info = type(exception), exception, exception.__traceback__ + kwargs['exc_info'] = exc_info + self.log_exception(*exc_info) # сентри этот метод манкипатчит if not isinstance(exception, HTTPErrorWithPostprocessors): self.clear() - self.set_status(status_code, reason=reason) + if isinstance(exception, HTTPError): + status_code = exception.status_code + self.set_status(status_code) try: - self.write_error(status_code, **kwargs) + await self._write_error(status_code, **kwargs) except FinishSignal: return except Exception: self.log.exception('Uncaught exception in write_error') - if not self._finished: - self.finish() + if not self.is_finished(): + self._finish(interrupt_execution=False) - def write_error(self, status_code: int = 500, **kwargs: Any) -> None: + async def _write_error(self, status_code: int = 500, **kwargs: Any) -> None: """ `write_error` can call `finish` asynchronously if HTTPErrorWithPostprocessors is raised. """ exception = kwargs['exc_info'][1] if 'exc_info' in kwargs else None if isinstance(exception, HTTPErrorWithPostprocessors): - self.finish_with_postprocessors() + await self.finish_with_postprocessors() return self.set_header('Content-Type', media_types.TEXT_HTML) super().write_error(status_code, **kwargs) - def finish(self, chunk: Optional[Union[str, bytes, dict]] = None) -> NoReturn: + def finish(self, chunk: Optional[Union[str, bytes, dict]] = None) -> Future[None]: + self._finish(chunk) + future: Future = Future() + future.set_result(None) + return future + + def _finish(self, chunk: Optional[Union[str, bytes, dict]] = None, interrupt_execution: bool = True) -> None: self.stages_logger.commit_stage('postprocess') for name, value in self._mandatory_headers.items(): self.set_header(name, value) @@ -670,29 +661,29 @@ def finish(self, chunk: Optional[Union[str, bytes, dict]] = None) -> NoReturn: self._write_buffer = [] chunk = None - if self._finished or getattr(self.request, 'canceled', False): + if self.is_finished(): raise RuntimeError('finish() called twice') if chunk is not None: self.write(chunk) - if not self._headers_written: - if self._status_code == 200 and self.request.method in ('GET', 'HEAD') and 'Etag' not in self._headers: - self.set_etag_header() - if self.check_etag_header(): - self._write_buffer = [] - self.set_status(304) - if self._status_code in (204, 304) or (100 <= self._status_code < 200): - assert not self._write_buffer, 'Cannot send body with %s' % self._status_code - self._clear_representation_headers() - elif 'Content-Length' not in self._headers: - content_length = sum(len(part) for part in self._write_buffer) - self.set_header('Content-Length', content_length) + if self._status_code == 200 and self.request.method in ('GET', 'HEAD') and 'Etag' not in self._headers: + self.set_etag_header() + if self.check_etag_header(): + self._write_buffer = [] + self.set_status(304) + if self._status_code in (204, 304) or (100 <= self._status_code < 200): + assert not self._write_buffer, 'Cannot send body with %s' % self._status_code + self._clear_representation_headers() + elif 'Content-Length' not in self._headers: + content_length = sum(len(part) for part in self._write_buffer) + self.set_header('Content-Length', content_length) self._flush() self._finished = True self.on_finish() - raise FinishSignal() + if interrupt_execution: + raise FinishSignal() def _flush(self) -> None: assert self.request.connection is not None @@ -743,7 +734,7 @@ async def _run_postprocessors(self, postprocessors: list) -> bool: else: p(self) - if self._finished or getattr(self.request, 'canceled', False): + if self.is_finished(): self.log.warning('page was already finished, breaking postprocessors chain') return False @@ -756,7 +747,7 @@ async def _run_template_postprocessors(self, postprocessors: list, rendered_temp else: rendered_template = p(self, rendered_template, meta_info) - if self._finished or getattr(self.request, 'canceled', False): + if self.is_finished(): self.log.warning('page was already finished, breaking postprocessors chain') return None @@ -1029,9 +1020,7 @@ def _execute_http_client_method( client_method: Callable, waited: bool, ) -> Future[RequestResult]: - if waited and ( - self.is_finished() or self.finish_group.is_finished() or getattr(self.request, 'canceled', False) - ): + if waited and (self.is_finished() or self.finish_group.is_finished()): handler_logger.info( 'attempted to make waited http request to %s %s in finished handler, ' 'ignoring. change "waited" method parameter to send it', diff --git a/tests/projects/test_app/pages/fail_fast/__init__.py b/tests/projects/test_app/pages/fail_fast/__init__.py index 0fb90ef3a..12cadf40f 100644 --- a/tests/projects/test_app/pages/fail_fast/__init__.py +++ b/tests/projects/test_app/pages/fail_fast/__init__.py @@ -10,14 +10,14 @@ async def get_page_preprocessor(handler: PageHandler = get_current_handler()) -> class Page(PageHandler): - def get_page_fail_fast(self, failed_future): + async def get_page_fail_fast(self, failed_future): if self.get_argument('exception_in_fail_fast', 'false') == 'true': msg = 'Exception in fail_fast' raise Exception(msg) self.json.replace({'fail_fast': True}) self.set_status(403) - self.finish_with_postprocessors() + await self.finish_with_postprocessors() @plain_router.get('/fail_fast', cls=Page, dependencies=[Depends(get_page_preprocessor)]) diff --git a/tests/projects/test_app/pages/http_client/raise_error.py b/tests/projects/test_app/pages/http_client/raise_error.py index 5a206dee5..3e3ff1924 100644 --- a/tests/projects/test_app/pages/http_client/raise_error.py +++ b/tests/projects/test_app/pages/http_client/raise_error.py @@ -3,7 +3,7 @@ class Page(PageHandler): - def send_error(self, status_code=500, exc_info=None, **kwargs): + def _send_error(self, status_code=500, exc_info=None, **kwargs): if isinstance(exc_info[1], UnicodeEncodeError): self.finish('UnicodeEncodeError') diff --git a/tests/projects/test_app/pages/write_error.py b/tests/projects/test_app/pages/write_error.py index ab0bfaca8..53f4aa2f6 100644 --- a/tests/projects/test_app/pages/write_error.py +++ b/tests/projects/test_app/pages/write_error.py @@ -3,13 +3,13 @@ class Page(PageHandler): - def write_error(self, status_code=500, **kwargs): + async def _write_error(self, status_code=500, **kwargs): self.json.put({'write_error': True}) if self.get_argument('fail_write_error', 'false') == 'true': raise Exception('exception in write_error') - self.finish_with_postprocessors() + await self.finish_with_postprocessors() @plain_router.get('/write_error', cls=Page)