Skip to content

Commit

Permalink
retry action multiple times(3), lowered timeout (#1583)
Browse files Browse the repository at this point in the history
* retry action multiple times(3), lowered timeout

* retry action multiple times(3), lowered timeout

---------

Co-authored-by: spandan.mondal <[email protected]>
  • Loading branch information
hasinaxp and spandan.mondal authored Oct 25, 2024
1 parent 3ac4d3a commit b4f8da9
Show file tree
Hide file tree
Showing 2 changed files with 193 additions and 6 deletions.
96 changes: 90 additions & 6 deletions kairon/chat/actions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os
import ssl
import ujson as json
import logging
from typing import (
Expand All @@ -13,9 +15,10 @@
import aiohttp
import rasa.core
import rasa.shared.utils.io
from aiohttp import ContentTypeError
from aiohttp_retry import ExponentialRetry, RetryClient
from rasa.core.actions.constants import DEFAULT_SELECTIVE_DOMAIN, SELECTIVE_DOMAIN
from rasa.core.constants import (
DEFAULT_REQUEST_TIMEOUT,
COMPRESS_ACTION_SERVER_REQUEST_ENV_NAME,
DEFAULT_COMPRESS_ACTION_SERVER_REQUEST,
)
Expand All @@ -30,10 +33,10 @@
BotUttered,
)
from rasa.shared.core.trackers import DialogueStateTracker
from rasa.shared.exceptions import RasaException
from rasa.shared.exceptions import RasaException, FileNotFoundException
from rasa.shared.utils.schemas.events import EVENTS_SCHEMA
from rasa.utils.common import get_bool_env_variable
from rasa.utils.endpoints import EndpointConfig, ClientResponseError
from rasa.utils.endpoints import EndpointConfig, ClientResponseError, concat_url
from rasa.core.actions.action import Action, ActionExecutionRejection, create_bot_utterance

if TYPE_CHECKING:
Expand All @@ -42,12 +45,15 @@

logger = logging.getLogger(__name__)

ACTION_SERVER_REQUEST_TIMEOUT = 30 # seconds


class KRemoteAction(Action):
def __init__(self, name: Text, action_endpoint: Optional[EndpointConfig]) -> None:
def __init__(self, name: Text, action_endpoint: Optional[EndpointConfig], retry_attempts=3) -> None:

self._name = name
self.action_endpoint = action_endpoint
self.retry_attempts = retry_attempts

def _action_call_format(
self,
Expand Down Expand Up @@ -179,11 +185,13 @@ async def run(
modified_json = plugin_manager().hook.prefix_stripping_for_custom_actions(
json_body=json_body
)
response: Any = await self.action_endpoint.request(
response: Any = await KRemoteAction.multi_try_rasa_request(
endpoint_config=self.action_endpoint,
json=modified_json if modified_json else json_body,
method="post",
timeout=DEFAULT_REQUEST_TIMEOUT,
timeout=ACTION_SERVER_REQUEST_TIMEOUT,
compress=should_compress,
retry_attempts=self.retry_attempts
)
if modified_json:
plugin_manager().hook.prefixing_custom_actions_response(
Expand Down Expand Up @@ -239,5 +247,81 @@ async def run(
"Error: {}".format(self.name(), status, e)
)

@staticmethod
async def multi_try_rasa_request(
endpoint_config: EndpointConfig,
method: Text = "post",
subpath: Optional[Text] = None,
content_type: Optional[Text] = "application/json",
compress: bool = False,
retry_attempts: int = 3,
**kwargs: Any,
) -> Optional[Any]:
"""Send a HTTP request to the endpoint. Return json response, if available.
All additional arguments will get passed through
to aiohttp's `session.request`.
"""
# create the appropriate headers
headers = {}
if content_type:
headers["Content-Type"] = content_type

if "headers" in kwargs:
headers.update(kwargs["headers"])
del kwargs["headers"]

if endpoint_config.headers:
headers.update(endpoint_config.headers)

url = concat_url(endpoint_config.url, subpath)

sslcontext = None
if endpoint_config.cafile:
try:
sslcontext = ssl.create_default_context(cafile=endpoint_config.cafile)
except FileNotFoundError as e:
raise FileNotFoundException(
f"Failed to find certificate file, "
f"'{os.path.abspath(endpoint_config.cafile)}' does not exist."
) from e

if endpoint_config.basic_auth:
auth = aiohttp.BasicAuth(
endpoint_config.basic_auth["username"], endpoint_config.basic_auth["password"]
)
else:
auth = None

retry_options = ExponentialRetry(attempts=retry_attempts)
session = RetryClient(
raise_for_status=False, # Set this to True if you want to raise an exception for non-200 responses
retry_options=retry_options,
headers=endpoint_config.headers,
auth=auth,
timeout=aiohttp.ClientTimeout(total=ACTION_SERVER_REQUEST_TIMEOUT),
)

async with session:
async with session.request(
method,
url,
headers=headers,
params=endpoint_config.combine_parameters(kwargs),
compress=compress,
ssl=sslcontext,
**kwargs,
) as response:
if response.status >= 400:
raise ClientResponseError(
response.status,
response.reason,
str(await response.content.read()),
)
try:
return await response.json()
except ContentTypeError:
return None

def name(self) -> Text:
return self._name
103 changes: 103 additions & 0 deletions tests/unit_test/action/kremote_action_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import pytest
from aioresponses import aioresponses
from yarl import URL

from kairon.chat.actions import KRemoteAction
from rasa.utils.endpoints import EndpointConfig
from rasa.utils.endpoints import ClientResponseError


@pytest.mark.asyncio
async def test_multi_try_rasa_request_success():
endpoint_config = EndpointConfig(url="http://test.com", headers={"Authorization": "Bearer token"})
subpath = "path/to/resource"
method = "post"
response_data = {"key": "value"}

with aioresponses() as mock:
mock.post(f"http://test.com/{subpath}", payload=response_data, status=200)

result = await KRemoteAction.multi_try_rasa_request(
endpoint_config=endpoint_config, method=method, subpath=subpath
)

assert result == response_data

@pytest.mark.asyncio
async def test_multi_try_rasa_request_failure():
endpoint_config = EndpointConfig(url="http://test.com")
subpath = "invalid/path"
method = "post"

with aioresponses() as mock:
# Simulate a 404 error response with a JSON body as bytes
mock.post(
f"http://test.com/{subpath}",
status=404,
body=b'{"error": "Not Found"}'
)

with pytest.raises(ClientResponseError) as exc_info:
await KRemoteAction.multi_try_rasa_request(
endpoint_config=endpoint_config, method=method, subpath=subpath
)
assert exc_info.value.status == 404
assert exc_info.value.message == "Not Found"



@pytest.mark.asyncio
async def test_multi_try_rasa_request_ssl_error():
endpoint_config = EndpointConfig(url="https://test.com", cafile="invalid/path/to/cafile")
subpath = "path/to/resource"
method = "post"

with pytest.raises(FileNotFoundError):
await KRemoteAction.multi_try_rasa_request(
endpoint_config=endpoint_config, method=method, subpath=subpath
)


@pytest.mark.asyncio
async def test_multi_try_rasa_request_retry_success():
endpoint_config = EndpointConfig(url="http://test.com", headers={"Authorization": "Bearer token"})
subpath = "path/to/resource"
method = "post"
success_response_data = {"key": "value"}

with aioresponses() as mock:
mock.post(f"http://test.com/{subpath}", status=500)
mock.post(f"http://test.com/{subpath}", status=200, payload=success_response_data)

result = await KRemoteAction.multi_try_rasa_request(
endpoint_config=endpoint_config,
method=method,
subpath=subpath,
retry_attempts=2
)

assert result == success_response_data

@pytest.mark.asyncio
async def test_multi_try_rasa_request_retry_fail():
endpoint_config = EndpointConfig(url="http://test.com", headers={"Authorization": "Bearer token"})
subpath = "path/to/resource"
method = "post"
success_response_data = {"key": "value"}

with aioresponses() as mock:
mock.post(f"http://test.com/{subpath}", status=500)
mock.post(f"http://test.com/{subpath}", status=500)
mock.post(f"http://test.com/{subpath}", status=500)
mock.post(f"http://test.com/{subpath}", status=200, payload=success_response_data)

with pytest.raises(ClientResponseError) as exc_info:
result = await KRemoteAction.multi_try_rasa_request(
endpoint_config=endpoint_config,
method=method,
subpath=subpath,
retry_attempts=2
)

assert exc_info.value.status == 500

0 comments on commit b4f8da9

Please sign in to comment.