diff --git a/frontik/app_integrations/telemetry.py b/frontik/app_integrations/telemetry.py index fbbf980cc..fe3b7d788 100644 --- a/frontik/app_integrations/telemetry.py +++ b/frontik/app_integrations/telemetry.py @@ -162,8 +162,7 @@ def generate_trace_id(self) -> int: request_id = request_context.get_request_id() try: if request_id is None: - msg = 'bad request_id' - raise Exception(msg) + raise Exception('bad request_id') if len(request_id) < 32: log.debug('request_id = %s is less than 32 characters. Generating random trace_id', request_id) diff --git a/frontik/frontik_response.py b/frontik/frontik_response.py index 351891a2f..1b6a15e40 100644 --- a/frontik/frontik_response.py +++ b/frontik/frontik_response.py @@ -15,6 +15,7 @@ def __init__( status_code: int, headers: dict[str, str] | None | HTTPHeaders = None, body: bytes = b'', + reason: str | None = None, ): self.headers = HTTPHeaders(get_default_headers()) # type: ignore @@ -26,11 +27,12 @@ def __init__( self.status_code = status_code self.body = body + self._reason = reason self.data_written = False @property def reason(self) -> str: - return httputil.responses.get(self.status_code, 'Unknown') + return self._reason or httputil.responses.get(self.status_code, 'Unknown') def get_default_headers() -> Mapping[str, str | None]: diff --git a/frontik/handler.py b/frontik/handler.py index dd614ed5a..00d825ca1 100644 --- a/frontik/handler.py +++ b/frontik/handler.py @@ -30,6 +30,7 @@ from frontik import media_types, request_context from frontik.auth import DEBUG_AUTH_HEADER_NAME from frontik.debug import DEBUG_HEADER_NAME, DebugMode +from frontik.frontik_response import FrontikResponse from frontik.futures import AbortAsyncGroup, AsyncGroup from frontik.http_status import ALLOWED_STATUSES, NON_CRITICAL_BAD_GATEWAY from frontik.json_builder import FrontikJsonDecodeError, json_decode @@ -39,13 +40,12 @@ from frontik.timeout_tracking import get_timeout_checker from frontik.util import gather_dict, make_url from frontik.validator import BaseValidationModel, Validators -from frontik.version import version as frontik_version if TYPE_CHECKING: from collections.abc import Callable, Coroutine from http_client import HttpClient - from tornado.httputil import HTTPHeaders, HTTPServerRequest + from tornado.httputil import HTTPServerRequest from frontik.app import FrontikApplication from frontik.app_integrations.statsd import StatsDClient, StatsDClientStub @@ -111,7 +111,7 @@ def __init__( path_params: dict[str, str], ) -> None: self.name = self.__class__.__name__ - self.request_id: str = request_context.get_request_id() # type: ignore + self.request_id: str = request.request_id # type: ignore self.config = application.config self.log = handler_logger self.text: Any = None @@ -158,7 +158,7 @@ def __init__( request._start_time, ) - self.handler_result_future: Future[tuple[int, str, HTTPHeaders, bytes]] = Future() + self.handler_result_future: Future[FrontikResponse] = Future() def __repr__(self): return f'{self.__module__}.{self.__class__.__name__}' @@ -183,10 +183,7 @@ def prepare(self) -> None: super().prepare() def set_default_headers(self): - self._headers = httputil.HTTPHeaders({ - 'Server': f'Frontik/{frontik_version}', - 'X-Request-Id': self.request_id, - }) + self._headers = httputil.HTTPHeaders() @property def path(self) -> str: @@ -370,7 +367,7 @@ def add_future(cls, future: Future, callback: Callable) -> None: # Requests handling - async def execute(self) -> tuple[int, str, HTTPHeaders, bytes]: + async def execute(self) -> FrontikResponse: self._transforms = [] try: if self.request.method not in self.SUPPORTED_METHODS: @@ -701,7 +698,7 @@ def _flush(self) -> None: for cookie in self._new_cookie.values(): self.add_header('Set-Cookie', cookie.OutputString(None)) - self.handler_result_future.set_result((self._status_code, self._reason, self._headers, chunk)) + self.handler_result_future.set_result(FrontikResponse(self._status_code, self._headers, chunk, self._reason)) # postprocessors diff --git a/frontik/handler_asgi.py b/frontik/handler_asgi.py index 23d0e04f8..bfb4bc3f7 100644 --- a/frontik/handler_asgi.py +++ b/frontik/handler_asgi.py @@ -226,15 +226,17 @@ async def execute_tornado_page( route, page_cls, path_params = scope['route'], scope['page_cls'], scope['path_params'] request_context.set_handler_name(route) handler: PageHandler = page_cls(frontik_app, tornado_request, route, debug_mode, path_params) - status_code, _, headers, body = await handler.execute() - return FrontikResponse(status_code=status_code, headers=headers, body=body) + return await handler.execute() def _on_connection_close(tornado_request, process_request_task, integrations): - response = FrontikResponse(CLIENT_CLOSED_REQUEST) - for integration in integrations.values(): - integration.set_response(response) - - log_request(tornado_request, CLIENT_CLOSED_REQUEST) - setattr(tornado_request, 'canceled', False) - process_request_task.cancel() # serve_tornado_request will be interrupted with CanceledError + request_id = integrations.get('request_id', IntegrationDto()).get_value() + with request_context.request_context(request_id): + log.info('client has canceled request') + response = FrontikResponse(CLIENT_CLOSED_REQUEST) + for integration in integrations.values(): + integration.set_response(response) + + log_request(tornado_request, CLIENT_CLOSED_REQUEST) + setattr(tornado_request, 'canceled', False) + process_request_task.cancel() # serve_tornado_request will be interrupted with CanceledError diff --git a/frontik/loggers/__init__.py b/frontik/loggers/__init__.py index 7a6097b1a..01514a251 100644 --- a/frontik/loggers/__init__.py +++ b/frontik/loggers/__init__.py @@ -26,11 +26,9 @@ class Mdc: def __init__(self) -> None: - self.pid: int self.role: Union[str, None] = None def init(self, role: Union[str, None] = None) -> None: - self.pid = os.getpid() self.role = role @@ -40,8 +38,7 @@ def init(self, role: Union[str, None] = None) -> None: class ContextFilter(Filter): def filter(self, record): handler_name = request_context.get_handler_name() - request_id = request_context.get_request_id() - record.name = '.'.join(filter(None, [record.name, handler_name, request_id])) + record.name = '.'.join(filter(None, [record.name, handler_name])) return True @@ -94,7 +91,7 @@ def format(self, record): @staticmethod def get_mdc() -> dict: - mdc: dict = {'thread': MDC.pid} + mdc: dict = {'thread': os.getpid()} if MDC.role is not None: mdc['role'] = MDC.role @@ -126,15 +123,11 @@ def format_stack_trace(self, record: logging.LogRecord) -> str: return stack_trace -_JSON_FORMATTER = JSONFormatter() +JSON_FORMATTER = JSONFormatter() class StderrFormatter(LogFormatter): def format(self, record): - handler_name = request_context.get_handler_name() - request_id = request_context.get_request_id() - record.name = '.'.join(filter(None, [record.name, handler_name, request_id])) - if not record.msg: record.msg = ', '.join(f'{k}={v}' for k, v in getattr(record, CUSTOM_JSON_EXTRA, {}).items()) @@ -197,7 +190,7 @@ def _configure_file( if formatter is not None: file_handler.setFormatter(formatter) elif use_json_formatter: - file_handler.setFormatter(_JSON_FORMATTER) + file_handler.setFormatter(JSON_FORMATTER) else: file_handler.setFormatter(get_text_formatter()) file_handler.addFilter(_CONTEXT_FILTER) @@ -232,7 +225,7 @@ def _configure_syslog( if formatter is not None: syslog_handler.setFormatter(formatter) elif use_json_formatter: - syslog_handler.setFormatter(_JSON_FORMATTER) + syslog_handler.setFormatter(JSON_FORMATTER) else: syslog_handler.setFormatter(get_text_formatter()) syslog_handler.addFilter(_CONTEXT_FILTER) diff --git a/frontik/loggers/logleveloverride/http_log_level_override_extension.py b/frontik/loggers/logleveloverride/http_log_level_override_extension.py index 4d9bb4c3d..49fcbf21e 100644 --- a/frontik/loggers/logleveloverride/http_log_level_override_extension.py +++ b/frontik/loggers/logleveloverride/http_log_level_override_extension.py @@ -4,9 +4,9 @@ from fastapi import HTTPException from http_client import HttpClientFactory -from frontik import request_context from frontik.loggers.logleveloverride.log_level_override_extension import LogLevelOverride, LogLevelOverrideExtension from frontik.loggers.logleveloverride.logging_configurator_client import LOG_LEVEL_MAPPING +from frontik.util import generate_uniq_timestamp_request_id logger = logging.getLogger('http_log_level_override_extension') @@ -33,7 +33,7 @@ def __init__(self, host: str, uri: str, http_client_factory: HttpClientFactory) self.http_client_factory = http_client_factory async def load_log_level_overrides(self) -> list[LogLevelOverride]: - headers = {'X-Request-Id': request_context.get_request_id()} + headers = {'X-Request-Id': generate_uniq_timestamp_request_id()} result = await self.http_client_factory.get_http_client().get_url(self.host, self.uri, headers=headers) if result.failed: logger.error('some problem with fetching log level overrides: %s', result.failed) diff --git a/frontik/request_context.py b/frontik/request_context.py index 109434a2c..8802e31ce 100644 --- a/frontik/request_context.py +++ b/frontik/request_context.py @@ -1,11 +1,14 @@ from __future__ import annotations import contextvars +from contextlib import contextmanager from typing import TYPE_CHECKING, Optional from fastapi.routing import APIRoute if TYPE_CHECKING: + from collections.abc import Iterator + from frontik.debug import DebugBufferedHandler @@ -21,6 +24,15 @@ def __init__(self, request_id: Optional[str]) -> None: _context = contextvars.ContextVar('context', default=_Context(None)) +@contextmanager +def request_context(request_id: Optional[str]) -> Iterator: + token = _context.set(_Context(request_id)) + try: + yield + finally: + _context.reset(token) + + def get_request_id() -> Optional[str]: return _context.get().request_id diff --git a/frontik/request_integrations/request_id.py b/frontik/request_integrations/request_id.py index 0653b4a5f..6b6ae9309 100644 --- a/frontik/request_integrations/request_id.py +++ b/frontik/request_integrations/request_id.py @@ -13,8 +13,5 @@ def request_id_ctx(_, tornado_request): check_request_id(request_id) tornado_request.request_id = request_id - token = request_context._context.set(request_context._Context(request_id)) - try: - yield IntegrationDto() - finally: - request_context._context.reset(token) + with request_context.request_context(request_id): + yield IntegrationDto(request_id) diff --git a/frontik/request_integrations/telemetry.py b/frontik/request_integrations/telemetry.py index 2d331ccee..88cc6f7bf 100644 --- a/frontik/request_integrations/telemetry.py +++ b/frontik/request_integrations/telemetry.py @@ -17,6 +17,7 @@ from opentelemetry.util.http import normalise_response_header_name from opentelemetry.trace.status import Status, StatusCode from tornado import httputil +from frontik import request_context _traced_request_attrs = get_traced_request_attrs('TORNADO') _excluded_urls = ['/status'] @@ -86,6 +87,12 @@ def _finish_span(span, dto: IntegrationDto): if span.is_recording(): span.set_attribute(SpanAttributes.HTTP_STATUS_CODE, status_code) + if (handler_name := request_context.get_handler_name()) is not None: + method_path, method_name = handler_name.rsplit('.', 1) + span.update_name(f'{method_path}.{method_name}') + span.set_attribute(SpanAttributes.CODE_FUNCTION, method_name) + span.set_attribute(SpanAttributes.CODE_NAMESPACE, method_path) + otel_status_code = http_status_to_status_code(status_code, server_span=True) otel_status_description = None if otel_status_code is StatusCode.ERROR: diff --git a/frontik/testing.py b/frontik/testing.py index 30eba07f9..247977f51 100644 --- a/frontik/testing.py +++ b/frontik/testing.py @@ -30,6 +30,10 @@ def event_loop(self): yield loop loop.close() + @pytest.fixture(scope='class') + def frontik_app(self) -> FrontikApplication: + return FrontikApplication() + @pytest.fixture(scope='class', autouse=True) async def _run_server(self, frontik_app): options.stderr_log = True diff --git a/tests/test_cookies.py b/tests/test_cookies.py index 689944bc7..e24079475 100644 --- a/tests/test_cookies.py +++ b/tests/test_cookies.py @@ -1,7 +1,5 @@ -import pytest from fastapi import Response -from frontik.app import FrontikApplication from frontik.handler import PageHandler, get_current_handler from frontik.routing import plain_router from frontik.testing import FrontikTestBase @@ -19,11 +17,7 @@ async def asgi_cookies_page(response: Response) -> None: response.set_cookie('key2', 'val2') -class TestFrontikTesting(FrontikTestBase): - @pytest.fixture(scope='class') - def frontik_app(self) -> FrontikApplication: - return FrontikApplication() - +class TestCookies(FrontikTestBase): async def test_cookies(self): response = await self.fetch('/cookies') diff --git a/tests/test_logging.py b/tests/test_logging.py index 050811876..2d92f742a 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -113,7 +113,7 @@ def test_send_to_syslog(self): { 'priority': '10', 'message': r'\[\d+\] [\d-]+ [\d:,]+ CRITICAL ' - r'custom_logger\.tests\.projects\.test_app\.pages\.log\.get_page\.\w+: fatal', + r'custom_logger\.tests\.projects\.test_app\.pages\.log\.get_page: fatal', }, ] diff --git a/tests/test_request_id.py b/tests/test_request_id.py new file mode 100644 index 000000000..96e8cfcbf --- /dev/null +++ b/tests/test_request_id.py @@ -0,0 +1,87 @@ +import asyncio +import json + +from frontik.handler import PageHandler +from frontik.loggers import JSON_FORMATTER +from frontik.routing import plain_router +from frontik.testing import FrontikTestBase + +known_loggers = ['handler', 'stages'] + + +@plain_router.get('/request_id', cls=PageHandler) +async def request_id_page() -> None: + pass + + +@plain_router.get('/request_id_long', cls=PageHandler) +async def request_id_long_page() -> None: + await asyncio.sleep(2) + + +@plain_router.get('/asgi_request_id') +async def asgi_request_id_page() -> None: + pass + + +@plain_router.get('/asgi_request_id_long') +async def asgi_request_id_long_page() -> None: + await asyncio.sleep(2) + + +class TestRequestId(FrontikTestBase): + async def test_request_id(self): + response = await self.fetch('/request_id') + + assert response.status_code == 200 + assert len(response.headers.getall('X-Request-Id')) == 1 + + async def test_asgi_request_id(self): + response = await self.fetch('/asgi_request_id') + + assert response.status_code == 200 + assert len(response.headers.getall('X-Request-Id')) == 1 + + async def test_request_id_canceled_request(self, caplog): + caplog.handler.setFormatter(JSON_FORMATTER) + response = await self.fetch('/request_id_long', request_timeout=0.1) + await asyncio.sleep(1) + + assert response.status_code == 599 + assert 'client has canceled request' in caplog.text + + rid = None + for log_row in caplog.text.split('\n'): + if log_row == '': + continue + log_obj = json.loads(log_row) + assert log_obj.get('logger') in known_loggers + + mdc = log_obj.get('mdc') + assert mdc is not None + + assert mdc.get('rid') is not None + rid = rid or mdc.get('rid') + assert mdc.get('rid') == rid + + async def test_asgi_request_id_canceled_request(self, caplog): + caplog.handler.setFormatter(JSON_FORMATTER) + response = await self.fetch('/asgi_request_id_long', request_timeout=0.1) + await asyncio.sleep(1) + + assert response.status_code == 599 + assert 'client has canceled request' in caplog.text + + rid = None + for log_row in caplog.text.split('\n'): + if log_row == '': + continue + log_obj = json.loads(log_row) + assert log_obj.get('logger') in known_loggers + + mdc = log_obj.get('mdc') + assert mdc is not None + + assert mdc.get('rid') is not None + rid = rid or mdc.get('rid') + assert mdc.get('rid') == rid diff --git a/tests/test_telemetry.py b/tests/test_telemetry.py index 292d613f1..333f2c9ed 100644 --- a/tests/test_telemetry.py +++ b/tests/test_telemetry.py @@ -1,11 +1,11 @@ -from collections.abc import Iterator, Sequence -from contextlib import contextmanager +from collections.abc import Sequence from typing import Any, Optional import pytest from fastapi import Request from opentelemetry import trace from opentelemetry.sdk.trace.export import BatchSpanProcessor, ReadableSpan, SpanExporter, SpanExportResult +from opentelemetry.semconv.trace import SpanAttributes from frontik import request_context from frontik.app import FrontikApplication @@ -18,15 +18,6 @@ dummy_request = Request({'type': 'http'}) -@contextmanager -def request_id_context(request_id: str) -> Iterator: - token = request_context._context.set(request_context._Context(request_id)) - try: - yield - finally: - request_context._context.reset(token) - - class TestTelemetry: def setup_method(self) -> None: self.trace_id_generator = FrontikIdGenerator() @@ -36,33 +27,33 @@ def test_generate_trace_id_with_none_request_id(self) -> None: assert trace_id is not None def test_generate_trace_id_with_hex_request_id(self) -> None: - with request_id_context('163897206709842601f90a070699ac44'): + with request_context.request_context('163897206709842601f90a070699ac44'): trace_id = self.trace_id_generator.generate_trace_id() assert '0x163897206709842601f90a070699ac44' == hex(trace_id) def test_generate_trace_id_with_no_hex_request_id(self) -> None: - with request_id_context('non-hex-string-1234'): + with request_context.request_context('non-hex-string-1234'): trace_id = self.trace_id_generator.generate_trace_id() assert trace_id is not None def test_generate_trace_id_with_no_str_request_id(self) -> None: - with request_id_context(12345678910): # type: ignore + with request_context.request_context(12345678910): # type: ignore trace_id = self.trace_id_generator.generate_trace_id() assert trace_id is not None def test_generate_trace_id_with_hex_request_id_and_postfix(self) -> None: - with request_id_context('163897206709842601f90a070699ac44_some_postfix_string'): + with request_context.request_context('163897206709842601f90a070699ac44_some_postfix_string'): trace_id = self.trace_id_generator.generate_trace_id() assert '0x163897206709842601f90a070699ac44' == hex(trace_id) def test_generate_trace_id_with_no_hex_request_id_in_first_32_characters(self) -> None: - with request_id_context('16389720670_NOT_HEX_9842601f90a070699ac44_some_postfix_string'): + with request_context.request_context('16389720670_NOT_HEX_9842601f90a070699ac44_some_postfix_string'): trace_id = self.trace_id_generator.generate_trace_id() assert trace_id is not None assert '0x16389720670_NOT_HEX_9842601f90a0' != hex(trace_id) def test_generate_trace_id_with_request_id_len_less_32_characters(self) -> None: - with request_id_context('163897206'): + with request_context.request_context('163897206'): trace_id = self.trace_id_generator.generate_trace_id() assert trace_id is not None assert '0x163897206' != hex(trace_id) @@ -137,3 +128,7 @@ async def test_parent_span(self, frontik_app: FrontikApplication) -> None: assert client_a_span.parent is not None assert server_b_span is not None assert server_b_span.parent is not None + + assert server_b_span.attributes is not None + assert server_b_span.attributes.get(SpanAttributes.CODE_FUNCTION) == 'get_page_b' + assert server_b_span.attributes.get(SpanAttributes.CODE_NAMESPACE) == 'tests.test_telemetry'