diff --git a/src/prefect/client/base.py b/src/prefect/client/base.py
index 81bd6864b115..b6516295147e 100644
--- a/src/prefect/client/base.py
+++ b/src/prefect/client/base.py
@@ -3,7 +3,8 @@
 import threading
 from collections import defaultdict
 from contextlib import asynccontextmanager
-from typing import ContextManager, Dict, Tuple, Type
+from functools import partial
+from typing import Callable, ContextManager, Dict, Set, Tuple, Type
 
 import anyio
 import httpx
@@ -13,6 +14,7 @@
 from typing_extensions import Self
 
 from prefect.exceptions import PrefectHTTPStatusError
+from prefect.logging import get_logger
 
 # Datastores for lifespan management, keys should be a tuple of thread and app identities.
 APP_LIFESPANS: Dict[Tuple[int, int], LifespanManager] = {}
@@ -21,6 +23,9 @@
 APP_LIFESPANS_LOCKS: Dict[int, anyio.Lock] = defaultdict(anyio.Lock)
 
 
+logger = get_logger("client")
+
+
 @asynccontextmanager
 async def app_lifespan_context(app: FastAPI) -> ContextManager[None]:
     """
@@ -154,27 +159,104 @@ class PrefectHttpxClient(httpx.AsyncClient):
 
     RETRY_MAX = 5
 
-    async def send(self, *args, **kwargs) -> Response:
-        retry_count = 0
-        response = PrefectResponse.from_httpx_response(
-            await super().send(*args, **kwargs)
-        )
-        while (
-            response.status_code
-            in {status.HTTP_429_TOO_MANY_REQUESTS, status.HTTP_503_SERVICE_UNAVAILABLE}
-            and retry_count < self.RETRY_MAX
-        ):
-            retry_count += 1
-
-            # Respect the "Retry-After" header, falling back to an exponential back-off
-            retry_after = response.headers.get("Retry-After")
-            if retry_after:
-                retry_seconds = float(retry_after)
-            else:
-                retry_seconds = 2**retry_count
+    async def _send_with_retry(
+        self,
+        request: Callable,
+        retry_codes: Set[int] = set(),
+        retry_exceptions: Tuple[Type[Exception], ...] = tuple(),
+    ) -> Response:
+        """
+        Send a request and retry it if it fails.
+
+        Sends the provided request and retries it up to self.RETRY_MAX times if
+        the request either raises an exception listed in `retry_exceptions` or receives
+        a response with a status code listed in `retry_codes`.
+
+        Retries will be delayed based on either the retry header (preferred) or
+        exponential backoff if a retry header is not provided.
+
+        Once the final attempt has been made, the last response is returned or the
+        last exception is raised immediately, without logging or waiting for a retry.
+        """
+        try_count = 0
+
+        while try_count <= self.RETRY_MAX:
+            try_count += 1
+            retry_seconds = None
+            exc_info = None
+
+            try:
+                response = await request()
+            except retry_exceptions:
+                if try_count > self.RETRY_MAX:
+                    raise
+                # Otherwise, we will ignore this error but capture the info for logging
+                exc_info = sys.exc_info()
+            else:
+                # We got a response; return immediately if it is not retryable or if
+                # this was the final attempt. Returning here avoids announcing and
+                # sleeping for a retry that will never be made.
+                if (
+                    response.status_code not in retry_codes
+                    or try_count > self.RETRY_MAX
+                ):
+                    return response
+
+                if "Retry-After" in response.headers:
+                    retry_seconds = float(response.headers["Retry-After"])
+
+            # Use an exponential back-off if not set in a header
+            if retry_seconds is None:
+                retry_seconds = 2**try_count
+
+            logger.debug(
+                (
+                    "Encountered retryable exception during request. "
+                    if exc_info
+                    else "Received response with retryable status code. "
+                )
+                + (
+                    f"Another attempt will be made in {retry_seconds}s. "
+                    f"This is attempt {try_count}/{self.RETRY_MAX + 1}."
+                ),
+                exc_info=exc_info,
+            )
 
             await anyio.sleep(retry_seconds)
-            response = await super().send(*args, **kwargs)
+
+        # Every attempt above either returns or raises, so this is only reachable when
+        # no attempt was made at all (e.g. a negative `RETRY_MAX`)
+        raise RuntimeError("Retry handling ended without response or exception")
+
+    async def send(self, *args, **kwargs) -> Response:
+        """
+        Send a request, retrying transient failures.
+
+        Responses with a 429 or 503 status code and a set of transient `httpx`
+        transport errors are retried via `_send_with_retry`.
+        """
+        api_request = partial(super().send, *args, **kwargs)
+
+        response = await self._send_with_retry(
+            request=api_request,
+            retry_codes={
+                status.HTTP_429_TOO_MANY_REQUESTS,
+                status.HTTP_503_SERVICE_UNAVAILABLE,
+            },
+            retry_exceptions=(
+                httpx.ReadTimeout,
+                httpx.PoolTimeout,
+                # `ConnectionResetError` when reading socket raises as a `ReadError`
+                httpx.ReadError,
+                # Uvicorn bug, see https://github.com/PrefectHQ/prefect/issues/7512
+                httpx.RemoteProtocolError,
+                # HTTP2 bug, see https://github.com/PrefectHQ/prefect/issues/7442
+                httpx.LocalProtocolError,
+            ),
+        )
+
+        # Convert to a Prefect response, as was done before retries were refactored
+        response = PrefectResponse.from_httpx_response(response)
 
         # Always raise bad responses
         # NOTE: We may want to remove this and handle responses per route in the
diff --git a/tests/client/test_base_client.py b/tests/client/test_base_client.py
index faa3bb784933..52100db6bcd9 100644
--- a/tests/client/test_base_client.py
+++ b/tests/client/test_base_client.py
@@ -1,38 +1,51 @@
 from unittest.mock import call
 
+import httpx
 import pytest
 from fastapi import status
-from httpx import AsyncClient, HTTPStatusError, Request, Response
+from httpx import AsyncClient, Request, Response
 
 from prefect.client.base import PrefectHttpxClient
 from prefect.testing.utilities import AsyncMock
 
+RESPONSE_429_RETRY_AFTER_0 = Response(
+    status.HTTP_429_TOO_MANY_REQUESTS,
+    headers={"Retry-After": "0"},
+    request=Request("a test request", "fake.url/fake/route"),
+)
+
+RESPONSE_429_RETRY_AFTER_MISSING = Response(
+    status.HTTP_429_TOO_MANY_REQUESTS,
+    request=Request("a test request", "fake.url/fake/route"),
+)
+
+RESPONSE_200 = Response(
+    status.HTTP_200_OK,
+    request=Request("a test request", "fake.url/fake/route"),
+)
+
 
 class TestPrefectHttpxClient:
+    @pytest.mark.usefixtures("mock_anyio_sleep")
     @pytest.mark.parametrize(
         "error_code",
         [status.HTTP_429_TOO_MANY_REQUESTS, status.HTTP_503_SERVICE_UNAVAILABLE],
     )
     async def test_prefect_httpx_client_retries_on_designated_error_codes(
-        self, monkeypatch, error_code
+        self, monkeypatch, error_code, caplog
     ):
         base_client_send = AsyncMock()
         monkeypatch.setattr(AsyncClient, "send", base_client_send)
         client = PrefectHttpxClient()
         retry_response = Response(
             error_code,
-            headers={"Retry-After": "0"},
-            request=Request("a test request", "fake.url/fake/route"),
-        )
-        success_response = Response(
-            status.HTTP_200_OK,
             request=Request("a test request", "fake.url/fake/route"),
         )
         base_client_send.side_effect = [
             retry_response,
             retry_response,
             retry_response,
-            success_response,
+            RESPONSE_200,
         ]
         response = await client.post(
             url="fake.url/fake/route", data={"evenmorefake": "data"}
@@ -40,23 +53,111 @@ async def test_prefect_httpx_client_retries_on_designated_error_codes(
         assert response.status_code == status.HTTP_200_OK
         assert base_client_send.call_count == 4
 
-    async def test_prefect_httpx_client_retries_429s_up_to_five_times(
-        self, monkeypatch
+        # We log on retry
+        assert "Received response with retryable status code" in caplog.text
+        assert "Another attempt will be made in 2s" in caplog.text
+        assert "This is attempt 1/6" in caplog.text
+
+        # A traceback should not be included
+        assert "Traceback" not in caplog.text
+
+        # Ensure the messaging changes
+        assert "Another attempt will be made in 4s" in caplog.text
+        assert "This is attempt 2/6" in caplog.text
+
+    @pytest.mark.usefixtures("mock_anyio_sleep")
+    @pytest.mark.parametrize(
+        "exception_type",
+        [
+            httpx.RemoteProtocolError,
+            httpx.ReadError,
+            httpx.LocalProtocolError,
+            httpx.PoolTimeout,
+            httpx.ReadTimeout,
+        ],
+    )
+    async def test_prefect_httpx_client_retries_on_designated_exceptions(
+        self,
+        monkeypatch,
+        exception_type,
+        caplog,
     ):
-        client = PrefectHttpxClient()
         base_client_send = AsyncMock()
         monkeypatch.setattr(AsyncClient, "send", base_client_send)
+        client = PrefectHttpxClient()
 
-        retry_response = Response(
-            status.HTTP_429_TOO_MANY_REQUESTS,
-            headers={"Retry-After": "0"},
-            request=Request("a test request", "fake.url/fake/route"),
+        base_client_send.side_effect = [
+            exception_type("test"),
+            exception_type("test"),
+            exception_type("test"),
+            RESPONSE_200,
+        ]
+        response = await client.post(
+            url="fake.url/fake/route", data={"evenmorefake": "data"}
         )
+        assert response.status_code == status.HTTP_200_OK
+        assert base_client_send.call_count == 4
+
+        # We log on retry
+        assert "Encountered retryable exception during request" in caplog.text
+        assert "Another attempt will be made in 2s" in caplog.text
+        assert "This is attempt 1/6" in caplog.text
+
+        # The traceback should be included
+        assert "Traceback" in caplog.text
+
+        # Ensure the messaging changes
+        assert "Another attempt will be made in 4s" in caplog.text
+        assert "This is attempt 2/6" in caplog.text
+
+    @pytest.mark.usefixtures("mock_anyio_sleep")
+    @pytest.mark.parametrize(
+        "response_or_exc",
+        [RESPONSE_429_RETRY_AFTER_0, httpx.RemoteProtocolError("test")],
+    )
+    async def test_prefect_httpx_client_retries_up_to_five_times(
+        self,
+        monkeypatch,
+        response_or_exc,
+    ):
+        client = PrefectHttpxClient()
+        base_client_send = AsyncMock()
+        monkeypatch.setattr(AsyncClient, "send", base_client_send)
+
+        # Return more than 6 retryable responses
+        base_client_send.side_effect = [response_or_exc] * 10
+
+        with pytest.raises(Exception):
+            await client.post(
+                url="fake.url/fake/route",
+                data={"evenmorefake": "data"},
+            )
+
+        # 5 retries + 1 first attempt
+        assert base_client_send.call_count == 6
 
-        # Return more than 6 retry responses
-        base_client_send.side_effect = [retry_response] * 7
+    @pytest.mark.usefixtures("mock_anyio_sleep")
+    @pytest.mark.parametrize(
+        "final_response,expected_error_type",
+        [
+            (
+                RESPONSE_429_RETRY_AFTER_0,
+                httpx.HTTPStatusError,
+            ),
+            (httpx.RemoteProtocolError("test"), httpx.RemoteProtocolError),
+        ],
+    )
+    async def test_prefect_httpx_client_raises_final_error_after_retries(
+        self, monkeypatch, final_response, expected_error_type
+    ):
+        client = PrefectHttpxClient()
+        base_client_send = AsyncMock()
+        monkeypatch.setattr(AsyncClient, "send", base_client_send)
+
+        # First throw a bunch of retryable errors, then the final one
+        base_client_send.side_effect = [httpx.ReadError("test")] * 5 + [final_response]
 
-        with pytest.raises(HTTPStatusError, match="429"):
+        with pytest.raises(expected_error_type):
             await client.post(
                 url="fake.url/fake/route",
                 data={"evenmorefake": "data"},
@@ -78,14 +179,9 @@ async def test_prefect_httpx_client_respects_retry_header(
             request=Request("a test request", "fake.url/fake/route"),
         )
 
-        success_response = Response(
-            status.HTTP_200_OK,
-            request=Request("a test request", "fake.url/fake/route"),
-        )
-
         base_client_send.side_effect = [
             retry_response,
-            success_response,
+            RESPONSE_200,
         ]
 
         with mock_anyio_sleep.assert_sleeps_for(5):
@@ -94,28 +190,23 @@ async def test_prefect_httpx_client_respects_retry_header(
             )
         assert response.status_code == status.HTTP_200_OK
 
-    async def test_prefect_httpx_client_falls_back_to_exponential_backoff(
-        self, mock_anyio_sleep, monkeypatch
+    @pytest.mark.parametrize(
+        "response_or_exc",
+        [RESPONSE_429_RETRY_AFTER_MISSING, httpx.RemoteProtocolError("test")],
+    )
+    async def test_prefect_httpx_client_uses_exponential_backoff_without_retry_after_header(
+        self, mock_anyio_sleep, response_or_exc, monkeypatch
     ):
         base_client_send = AsyncMock()
         monkeypatch.setattr(AsyncClient, "send", base_client_send)
 
         client = PrefectHttpxClient()
-        retry_response = Response(
-            status.HTTP_429_TOO_MANY_REQUESTS,
-            request=Request("a test request", "fake.url/fake/route"),
-        )
-
-        success_response = Response(
-            status.HTTP_200_OK,
-            request=Request("a test request", "fake.url/fake/route"),
-        )
 
         base_client_send.side_effect = [
-            retry_response,
-            retry_response,
-            retry_response,
-            success_response,
+            response_or_exc,
+            response_or_exc,
+            response_or_exc,
+            RESPONSE_200,
         ]
 
         with mock_anyio_sleep.assert_sleeps_for(2 + 4 + 8):
@@ -133,25 +224,17 @@ async def test_prefect_httpx_client_respects_retry_header_per_response(
 
         client = PrefectHttpxClient()
 
-        def make_retry_response(retry_after):
-            return Response(
+        base_client_send.side_effect = [
+            # Generate responses with retry after headers
+            Response(
                 status.HTTP_429_TOO_MANY_REQUESTS,
                 headers={"Retry-After": str(retry_after)},
                 request=Request("a test request", "fake.url/fake/route"),
             )
-
-        success_response = Response(
-            status.HTTP_200_OK,
-            request=Request("a test request", "fake.url/fake/route"),
-        )
-
-        base_client_send.side_effect = [
-            make_retry_response(5),
-            make_retry_response(0),
-            make_retry_response(10),
-            make_retry_response(2.0),
-            success_response,
-        ]
+            for retry_after in [5, 0, 10, 2.0]
+        ] + [
+            RESPONSE_200
+        ]  # Then succeed
 
         with mock_anyio_sleep.assert_sleeps_for(5 + 10 + 2):
             response = await client.post(
@@ -159,3 +242,35 @@ def make_retry_response(retry_after):
             )
         assert response.status_code == status.HTTP_200_OK
         mock_anyio_sleep.assert_has_awaits([call(5), call(0), call(10), call(2.0)])
+
+    async def test_prefect_httpx_client_does_not_retry_other_exceptions(
+        self, mock_anyio_sleep, monkeypatch
+    ):
+        base_client_send = AsyncMock()
+        monkeypatch.setattr(AsyncClient, "send", base_client_send)
+
+        client = PrefectHttpxClient()
+
+        base_client_send.side_effect = [TypeError("This error should not be retried")]
+
+        with pytest.raises(TypeError, match="This error should not be retried"):
+            await client.post(url="fake.url/fake/route", data={"evenmorefake": "data"})
+
+        mock_anyio_sleep.assert_not_called()
+
+    async def test_prefect_httpx_client_does_not_sleep_after_final_attempt(
+        self, mock_anyio_sleep, monkeypatch
+    ):
+        base_client_send = AsyncMock()
+        monkeypatch.setattr(AsyncClient, "send", base_client_send)
+
+        client = PrefectHttpxClient()
+
+        # Every attempt (1 first attempt + 5 retries) receives a retryable response
+        base_client_send.side_effect = [RESPONSE_429_RETRY_AFTER_MISSING] * 6
+
+        with pytest.raises(httpx.HTTPStatusError):
+            await client.post(url="fake.url/fake/route", data={"evenmorefake": "data"})
+
+        # Only the five retries are delayed; nothing is awaited after the last attempt
+        assert mock_anyio_sleep.await_count == 5