From b4f8da984f376d125b79ae8f948a2b2cc43f5547 Mon Sep 17 00:00:00 2001 From: Spandan Mondal Date: Fri, 25 Oct 2024 13:22:17 +0530 Subject: [PATCH] retry action multiple times(3), lowered timeout (#1583) * retry action multiple times(3), lowered timeout * retry action multiple times(3), lowered timeout --------- Co-authored-by: spandan.mondal --- kairon/chat/actions.py | 96 +++++++++++++++- tests/unit_test/action/kremote_action_test.py | 103 ++++++++++++++++++ 2 files changed, 193 insertions(+), 6 deletions(-) create mode 100644 tests/unit_test/action/kremote_action_test.py diff --git a/kairon/chat/actions.py b/kairon/chat/actions.py index f793c4ced..e84760b12 100644 --- a/kairon/chat/actions.py +++ b/kairon/chat/actions.py @@ -1,3 +1,5 @@ +import os +import ssl import ujson as json import logging from typing import ( @@ -13,9 +15,10 @@ import aiohttp import rasa.core import rasa.shared.utils.io +from aiohttp import ContentTypeError +from aiohttp_retry import ExponentialRetry, RetryClient from rasa.core.actions.constants import DEFAULT_SELECTIVE_DOMAIN, SELECTIVE_DOMAIN from rasa.core.constants import ( - DEFAULT_REQUEST_TIMEOUT, COMPRESS_ACTION_SERVER_REQUEST_ENV_NAME, DEFAULT_COMPRESS_ACTION_SERVER_REQUEST, ) @@ -30,10 +33,10 @@ BotUttered, ) from rasa.shared.core.trackers import DialogueStateTracker -from rasa.shared.exceptions import RasaException +from rasa.shared.exceptions import RasaException, FileNotFoundException from rasa.shared.utils.schemas.events import EVENTS_SCHEMA from rasa.utils.common import get_bool_env_variable -from rasa.utils.endpoints import EndpointConfig, ClientResponseError +from rasa.utils.endpoints import EndpointConfig, ClientResponseError, concat_url from rasa.core.actions.action import Action, ActionExecutionRejection, create_bot_utterance if TYPE_CHECKING: @@ -42,12 +45,15 @@ logger = logging.getLogger(__name__) +ACTION_SERVER_REQUEST_TIMEOUT = 30 # seconds + class KRemoteAction(Action): - def __init__(self, name: Text, action_endpoint: Optional[EndpointConfig]) -> None: + def __init__(self, name: Text, action_endpoint: Optional[EndpointConfig], retry_attempts=3) -> None: self._name = name self.action_endpoint = action_endpoint + self.retry_attempts = retry_attempts def _action_call_format( self, @@ -179,11 +185,13 @@ async def run( modified_json = plugin_manager().hook.prefix_stripping_for_custom_actions( json_body=json_body ) - response: Any = await self.action_endpoint.request( + response: Any = await KRemoteAction.multi_try_rasa_request( + endpoint_config=self.action_endpoint, json=modified_json if modified_json else json_body, method="post", - timeout=DEFAULT_REQUEST_TIMEOUT, + timeout=ACTION_SERVER_REQUEST_TIMEOUT, compress=should_compress, + retry_attempts=self.retry_attempts ) if modified_json: plugin_manager().hook.prefixing_custom_actions_response( @@ -239,5 +247,81 @@ async def run( "Error: {}".format(self.name(), status, e) ) + @staticmethod + async def multi_try_rasa_request( + endpoint_config: EndpointConfig, + method: Text = "post", + subpath: Optional[Text] = None, + content_type: Optional[Text] = "application/json", + compress: bool = False, + retry_attempts: int = 3, + **kwargs: Any, + ) -> Optional[Any]: + """Send a HTTP request to the endpoint. Return json response, if available. + + All additional arguments will get passed through + to aiohttp's `session.request`. + """ + # create the appropriate headers + headers = {} + if content_type: + headers["Content-Type"] = content_type + + if "headers" in kwargs: + headers.update(kwargs["headers"]) + del kwargs["headers"] + + if endpoint_config.headers: + headers.update(endpoint_config.headers) + + url = concat_url(endpoint_config.url, subpath) + + sslcontext = None + if endpoint_config.cafile: + try: + sslcontext = ssl.create_default_context(cafile=endpoint_config.cafile) + except FileNotFoundError as e: + raise FileNotFoundException( + f"Failed to find certificate file, " + f"'{os.path.abspath(endpoint_config.cafile)}' does not exist." + ) from e + + if endpoint_config.basic_auth: + auth = aiohttp.BasicAuth( + endpoint_config.basic_auth["username"], endpoint_config.basic_auth["password"] + ) + else: + auth = None + + retry_options = ExponentialRetry(attempts=retry_attempts) + session = RetryClient( + raise_for_status=False, # Set this to True if you want to raise an exception for non-200 responses + retry_options=retry_options, + headers=endpoint_config.headers, + auth=auth, + timeout=aiohttp.ClientTimeout(total=ACTION_SERVER_REQUEST_TIMEOUT), + ) + + async with session: + async with session.request( + method, + url, + headers=headers, + params=endpoint_config.combine_parameters(kwargs), + compress=compress, + ssl=sslcontext, + **kwargs, + ) as response: + if response.status >= 400: + raise ClientResponseError( + response.status, + response.reason, + str(await response.content.read()), + ) + try: + return await response.json() + except ContentTypeError: + return None + def name(self) -> Text: return self._name diff --git a/tests/unit_test/action/kremote_action_test.py b/tests/unit_test/action/kremote_action_test.py new file mode 100644 index 000000000..98d61b768 --- /dev/null +++ b/tests/unit_test/action/kremote_action_test.py @@ -0,0 +1,103 @@ +import pytest +from aioresponses import aioresponses +from yarl import URL + +from kairon.chat.actions import KRemoteAction +from rasa.utils.endpoints import EndpointConfig +from rasa.utils.endpoints import ClientResponseError + + +@pytest.mark.asyncio +async def test_multi_try_rasa_request_success(): + endpoint_config = EndpointConfig(url="http://test.com", headers={"Authorization": "Bearer token"}) + subpath = "path/to/resource" + method = "post" + response_data = {"key": "value"} + + with aioresponses() as mock: + mock.post(f"http://test.com/{subpath}", payload=response_data, status=200) + + result = await KRemoteAction.multi_try_rasa_request( + endpoint_config=endpoint_config, method=method, subpath=subpath + ) + + assert result == response_data + +@pytest.mark.asyncio +async def test_multi_try_rasa_request_failure(): + endpoint_config = EndpointConfig(url="http://test.com") + subpath = "invalid/path" + method = "post" + + with aioresponses() as mock: + # Simulate a 404 error response with a JSON body as bytes + mock.post( + f"http://test.com/{subpath}", + status=404, + body=b'{"error": "Not Found"}' + ) + + with pytest.raises(ClientResponseError) as exc_info: + await KRemoteAction.multi_try_rasa_request( + endpoint_config=endpoint_config, method=method, subpath=subpath + ) + assert exc_info.value.status == 404 + assert exc_info.value.message == "Not Found" + + + +@pytest.mark.asyncio +async def test_multi_try_rasa_request_ssl_error(): + endpoint_config = EndpointConfig(url="https://test.com", cafile="invalid/path/to/cafile") + subpath = "path/to/resource" + method = "post" + + with pytest.raises(FileNotFoundError): + await KRemoteAction.multi_try_rasa_request( + endpoint_config=endpoint_config, method=method, subpath=subpath + ) + + +@pytest.mark.asyncio +async def test_multi_try_rasa_request_retry_success(): + endpoint_config = EndpointConfig(url="http://test.com", headers={"Authorization": "Bearer token"}) + subpath = "path/to/resource" + method = "post" + success_response_data = {"key": "value"} + + with aioresponses() as mock: + mock.post(f"http://test.com/{subpath}", status=500) + mock.post(f"http://test.com/{subpath}", status=200, payload=success_response_data) + + result = await KRemoteAction.multi_try_rasa_request( + endpoint_config=endpoint_config, + method=method, + subpath=subpath, + retry_attempts=2 + ) + + assert result == success_response_data + +@pytest.mark.asyncio +async def test_multi_try_rasa_request_retry_fail(): + endpoint_config = EndpointConfig(url="http://test.com", headers={"Authorization": "Bearer token"}) + subpath = "path/to/resource" + method = "post" + success_response_data = {"key": "value"} + + with aioresponses() as mock: + mock.post(f"http://test.com/{subpath}", status=500) + mock.post(f"http://test.com/{subpath}", status=500) + mock.post(f"http://test.com/{subpath}", status=500) + mock.post(f"http://test.com/{subpath}", status=200, payload=success_response_data) + + with pytest.raises(ClientResponseError) as exc_info: + result = await KRemoteAction.multi_try_rasa_request( + endpoint_config=endpoint_config, + method=method, + subpath=subpath, + retry_attempts=2 + ) + + assert exc_info.value.status == 500 +