From 021e8fee7ecebe6a00edf92363288cfc79250b2a Mon Sep 17 00:00:00 2001 From: Anca Lita <27920906+ancalita@users.noreply.github.com> Date: Tue, 14 May 2024 15:56:50 +0100 Subject: [PATCH 01/11] update Sanic version and Sanic app instantiation for testing purposes --- poetry.lock | 35 ++++++++++++++++++----------------- pyproject.toml | 4 ++-- rasa_sdk/endpoint.py | 25 ++++++++++++++++++------- 3 files changed, 38 insertions(+), 26 deletions(-) diff --git a/poetry.lock b/poetry.lock index 46e670166..2f688c463 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1335,30 +1335,31 @@ files = [ [[package]] name = "sanic" -version = "21.12.2" +version = "22.12.0" description = "A web server and web framework that's written to go fast. Build fast. Run fast." optional = false python-versions = ">=3.7" files = [ - {file = "sanic-21.12.2-py3-none-any.whl", hash = "sha256:ef4edddfba46f2f8728400470f84deb91d9e5fc21cf2531512acf70638699620"}, - {file = "sanic-21.12.2.tar.gz", hash = "sha256:c426e15aac6984860c6d1221329be17e02e2dfed4ce0abf8532ab096026ee5e3"}, + {file = "sanic-22.12.0-py3-none-any.whl", hash = "sha256:84edf46cc17d13264ccec0ae6622e43087498f95644dc336ade74a2d5e6c88cb"}, + {file = "sanic-22.12.0.tar.gz", hash = "sha256:e5f81115f838956957046b6c52e7a08c1bd6e8ff530ee1376471eaf1579bfffa"}, ] [package.dependencies] aiofiles = ">=0.6.0" httptools = ">=0.0.10" -multidict = ">=5.0,<6.0" -sanic-routing = ">=0.7,<1.0" +multidict = ">=5.0,<7.0" +sanic-routing = ">=22.8.0" ujson = {version = ">=1.35", markers = "sys_platform != \"win32\" and implementation_name == \"cpython\""} -uvloop = {version = ">=0.5.3", markers = "sys_platform != \"win32\" and implementation_name == \"cpython\""} +uvloop = {version = ">=0.15.0", markers = "sys_platform != \"win32\" and implementation_name == \"cpython\""} websockets = ">=10.0" [package.extras] -all = ["bandit", "beautifulsoup4", "black", "chardet (==3.*)", "coverage (==5.3)", "cryptography", "docutils", "flake8", "gunicorn (==20.0.4)", "isort (>=5.0.0)", "m2r2", "mistune (<2.0.0)", "mypy (>=0.901,<0.910)", "pygments", "pytest (==6.2.5)", "pytest-benchmark", "pytest-cov", "pytest-sanic", "pytest-sugar", "sanic-testing (>=0.7.0)", "sphinx (>=2.1.2)", "sphinx-rtd-theme (>=0.4.3)", "towncrier", "tox", "types-ujson", "uvicorn (<0.15.0)"] -dev = ["bandit", "beautifulsoup4", "black", "chardet (==3.*)", "coverage (==5.3)", "cryptography", "docutils", "flake8", "gunicorn (==20.0.4)", "isort (>=5.0.0)", "mypy (>=0.901,<0.910)", "pygments", "pytest (==6.2.5)", "pytest-benchmark", "pytest-cov", "pytest-sanic", "pytest-sugar", "sanic-testing (>=0.7.0)", "towncrier", "tox", "types-ujson", "uvicorn (<0.15.0)"] -docs = ["docutils", "m2r2", "mistune (<2.0.0)", "pygments", "sphinx (>=2.1.2)", "sphinx-rtd-theme (>=0.4.3)"] +all = ["bandit", "beautifulsoup4", "black", "chardet (==3.*)", "coverage", "cryptography", "docutils", "enum-tools[sphinx]", "flake8", "isort (>=5.0.0)", "m2r2", "mistune (<2.0.0)", "mypy (>=0.901,<0.910)", "pygments", "pytest (==7.1.*)", "pytest-benchmark", "pytest-sanic", "sanic-testing (>=22.9.0)", "slotscheck (>=0.8.0,<1)", "sphinx (>=2.1.2)", "sphinx-rtd-theme (>=0.4.3)", "towncrier", "tox", "types-ujson", "uvicorn (<0.15.0)"] +dev = ["bandit", "beautifulsoup4", "black", "chardet (==3.*)", "coverage", "cryptography", "docutils", "flake8", "isort (>=5.0.0)", "mypy (>=0.901,<0.910)", "pygments", "pytest (==7.1.*)", "pytest-benchmark", "pytest-sanic", "sanic-testing (>=22.9.0)", "slotscheck (>=0.8.0,<1)", "towncrier", "tox", "types-ujson", "uvicorn (<0.15.0)"] +docs = ["docutils", "enum-tools[sphinx]", "m2r2", "mistune (<2.0.0)", "pygments", "sphinx (>=2.1.2)", "sphinx-rtd-theme (>=0.4.3)"] ext = ["sanic-ext"] -test = ["bandit", "beautifulsoup4", "black", "chardet (==3.*)", "coverage (==5.3)", "docutils", "flake8", "gunicorn (==20.0.4)", "isort (>=5.0.0)", "mypy (>=0.901,<0.910)", "pygments", "pytest (==6.2.5)", "pytest-benchmark", "pytest-cov", "pytest-sanic", "pytest-sugar", "sanic-testing (>=0.7.0)", "types-ujson", "uvicorn (<0.15.0)"] +http3 = ["aioquic"] +test = ["bandit", "beautifulsoup4", "black", "chardet (==3.*)", "coverage", "docutils", "flake8", "isort (>=5.0.0)", "mypy (>=0.901,<0.910)", "pygments", "pytest (==7.1.*)", "pytest-benchmark", "pytest-sanic", "sanic-testing (>=22.9.0)", "slotscheck (>=0.8.0,<1)", "types-ujson", "uvicorn (<0.15.0)"] [[package]] name = "sanic-cors" @@ -1376,24 +1377,24 @@ sanic = ">=21.9.3" [[package]] name = "sanic-routing" -version = "0.7.2" +version = "23.12.0" description = "Core routing component for Sanic" optional = false python-versions = "*" files = [ - {file = "sanic-routing-0.7.2.tar.gz", hash = "sha256:139ce88b3f054e7aa336e2ecc8459837092b103b275d3a97609a34092c55374d"}, - {file = "sanic_routing-0.7.2-py3-none-any.whl", hash = "sha256:523034ffd07aca056040e08de438269c9a880722eee1ace3a32e4f74b394d9aa"}, + {file = "sanic-routing-23.12.0.tar.gz", hash = "sha256:1dcadc62c443e48c852392dba03603f9862b6197fc4cba5bbefeb1ace0848b04"}, + {file = "sanic_routing-23.12.0-py3-none-any.whl", hash = "sha256:1558a72afcb9046ed3134a5edae02fc1552cff08f0fff2e8d5de0877ea43ed73"}, ] [[package]] name = "sanic-testing" -version = "22.6.0" +version = "22.12.0" description = "Core testing clients for Sanic" optional = false python-versions = "*" files = [ - {file = "sanic-testing-22.6.0.tar.gz", hash = "sha256:8f006d2332106539cd6f3da8a5c0d1f31472261f3293e43e2c9bbad605e72c5b"}, - {file = "sanic_testing-22.6.0-py3-none-any.whl", hash = "sha256:d84303e83066de7f18e8c3a0cd04512ba1517dbc31123f14e8aec318b22c008c"}, + {file = "sanic-testing-22.12.0.tar.gz", hash = "sha256:c9582c9bb9aabd82d3bf9fba2514a0274d0d741d84ce600e3ba2bef7b6c87aed"}, + {file = "sanic_testing-22.12.0-py3-none-any.whl", hash = "sha256:2cc3338207c6aab4cdc6b89264744a3d51ea66685fe1f30f81f9c376f3ee93a3"}, ] [package.dependencies] @@ -1850,4 +1851,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.8,<3.11" -content-hash = "77b9438ee3a1fcd07e5f6891a815463ee35d488b4d8ac1ddb43eaf6afee2c4c6" +content-hash = "0c854c620eb789aa7b5697a6430b0af5b2c5d9f29a2f045adfe71be21b35ed48" diff --git a/pyproject.toml b/pyproject.toml index 7e0fbadd2..624260c81 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,7 +76,7 @@ select = [ "D", "E", "F", "W", "RUF",] [tool.poetry.dependencies] python = ">=3.8,<3.11" coloredlogs = ">=10,<16" -sanic = "^21.12.0" +sanic = "^22.12" typing-extensions = ">=4.1.1,<5.0.0" Sanic-Cors = "^2.0.0" prompt-toolkit = "^3.0,<3.0.29" @@ -99,7 +99,7 @@ toml = "^0.10.0" pep440-version-utils = "^0.3.0" semantic_version = "^2.8.5" mypy = "^1.5" -sanic-testing = "^22.3.0, <22.9.0" +sanic-testing = "^22.12" [tool.ruff.pydocstyle] convention = "google" diff --git a/rasa_sdk/endpoint.py b/rasa_sdk/endpoint.py index eb922c51a..4b9ae08ba 100644 --- a/rasa_sdk/endpoint.py +++ b/rasa_sdk/endpoint.py @@ -5,10 +5,12 @@ import warnings import zlib import json +from functools import partial from typing import List, Text, Union, Optional from ssl import SSLContext from sanic import Sanic, response from sanic.response import HTTPResponse +from sanic.worker.loader import AppLoader # catching: # - all `pkg_resources` deprecation warning from multiple dependencies @@ -178,22 +180,31 @@ def run( ) -> None: """Starts the action endpoint server with given config values.""" logger.info("Starting action endpoint server...") - app = create_app( - action_package_name, - cors_origins=cors_origins, - auto_reload=auto_reload, - tracer_provider=tracer_provider, + loader = AppLoader( + factory=partial( + create_app, + action_package_name, + cors_origins=cors_origins, + auto_reload=auto_reload, + tracer_provider=tracer_provider, + ), ) + app = loader.load() app.config.KEEP_ALIVE_TIMEOUT = keep_alive_timeout - ## Attach additional sanic extensions: listeners, middleware and routing + + # Attach additional sanic extensions: listeners, middleware and routing logger.info("Starting plugins...") plugin_manager().hook.attach_sanic_app_extensions(app=app) + ssl_context = create_ssl_context(ssl_certificate, ssl_keyfile, ssl_password) protocol = "https" if ssl_context else "http" host = os.environ.get("SANIC_HOST", "0.0.0.0") logger.info(f"Action endpoint is up and running on {protocol}://{host}:{port}") - app.run(host, port, ssl=ssl_context, workers=utils.number_of_sanic_workers()) + app.prepare( + host=host, port=port, ssl=ssl_context, workers=utils.number_of_sanic_workers() + ) + Sanic.serve(primary=app, app_loader=loader) if __name__ == "__main__": From 5a4b705a3cb39855f232d1f9dfaf848522824502 Mon Sep 17 00:00:00 2001 From: Anca Lita <27920906+ancalita@users.noreply.github.com> Date: Tue, 14 May 2024 16:45:58 +0100 Subject: [PATCH 02/11] add listener to share context between Sanic workers --- rasa_sdk/endpoint.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/rasa_sdk/endpoint.py b/rasa_sdk/endpoint.py index 4b9ae08ba..66160ba11 100644 --- a/rasa_sdk/endpoint.py +++ b/rasa_sdk/endpoint.py @@ -104,6 +104,10 @@ def create_app( executor = ActionExecutor() executor.register_package(action_package_name) + @app.main_process_start + async def main_process_start(app: Sanic): + app.shared_ctx.tracer_provider = tracer_provider + @app.get("/health") async def health(_) -> HTTPResponse: """Ping endpoint to check if the server is running and well.""" @@ -113,7 +117,7 @@ async def health(_) -> HTTPResponse: @app.post("/webhook") async def webhook(request: Request) -> HTTPResponse: """Webhook to retrieve action calls.""" - tracer, context, span_name = get_tracer_and_context(tracer_provider, request) + tracer, context, span_name = get_tracer_and_context(app.shared_ctx.tracer_provider, request) with tracer.start_as_current_span(span_name, context=context) as span: if request.headers.get("Content-Encoding") == "deflate": From 63245c8559884eeb430f183fd72fbf08ac631588 Mon Sep 17 00:00:00 2001 From: Anca Lita <27920906+ancalita@users.noreply.github.com> Date: Tue, 14 May 2024 17:01:28 +0100 Subject: [PATCH 03/11] update sharing of context --- rasa_sdk/endpoint.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/rasa_sdk/endpoint.py b/rasa_sdk/endpoint.py index 66160ba11..193843d9a 100644 --- a/rasa_sdk/endpoint.py +++ b/rasa_sdk/endpoint.py @@ -44,7 +44,6 @@ def configure_cors( app: Sanic, cors_origins: Union[Text, List[Text], None] = "" ) -> None: """Configure CORS origins for the given app.""" - CORS( app, resources={r"/*": {"origins": cors_origins or ""}}, automatic_options=True ) @@ -56,7 +55,6 @@ def create_ssl_context( ssl_password: Optional[Text] = None, ) -> Optional[SSLContext]: """Create a SSL context if a certificate is passed.""" - if ssl_certificate: import ssl @@ -71,7 +69,6 @@ def create_ssl_context( def create_argument_parser(): """Parse all the command line arguments for the run script.""" - parser = argparse.ArgumentParser(description="starts the action endpoint") add_endpoint_arguments(parser) utils.add_logging_level_option_arguments(parser) @@ -79,6 +76,11 @@ def create_argument_parser(): return parser +async def load_tracer_provider(app: Sanic, tracer_provider: Optional[TracerProvider]): + """Load the tracer provider into the Sanic app.""" + app.shared_ctx.tracer_provider = tracer_provider + + def create_app( action_package_name: Union[Text, types.ModuleType], cors_origins: Union[Text, List[Text], None] = "*", @@ -104,9 +106,10 @@ def create_app( executor = ActionExecutor() executor.register_package(action_package_name) - @app.main_process_start - async def main_process_start(app: Sanic): - app.shared_ctx.tracer_provider = tracer_provider + app.register_listener( + partial(load_tracer_provider, tracer_provider=tracer_provider), + "main_process_start", + ) @app.get("/health") async def health(_) -> HTTPResponse: @@ -117,7 +120,9 @@ async def health(_) -> HTTPResponse: @app.post("/webhook") async def webhook(request: Request) -> HTTPResponse: """Webhook to retrieve action calls.""" - tracer, context, span_name = get_tracer_and_context(app.shared_ctx.tracer_provider, request) + tracer, context, span_name = get_tracer_and_context( + app.shared_ctx.tracer_provider, request + ) with tracer.start_as_current_span(span_name, context=context) as span: if request.headers.get("Content-Encoding") == "deflate": From 0f24207d150c0f443c6ac7b1b12e3aedb2b8cf71 Mon Sep 17 00:00:00 2001 From: Anca Lita <27920906+ancalita@users.noreply.github.com> Date: Tue, 14 May 2024 20:31:51 +0100 Subject: [PATCH 04/11] update how the tracer provider listener is shared across sanic workers --- rasa_sdk/__main__.py | 4 +--- rasa_sdk/endpoint.py | 32 ++++++++++++++++++++------------ rasa_sdk/tracing/utils.py | 20 ++++++-------------- tests/test_endpoint.py | 31 +++++++++++++++---------------- 4 files changed, 42 insertions(+), 45 deletions(-) diff --git a/rasa_sdk/__main__.py b/rasa_sdk/__main__.py index 05b1805cb..67b9fb497 100644 --- a/rasa_sdk/__main__.py +++ b/rasa_sdk/__main__.py @@ -3,7 +3,6 @@ from rasa_sdk import utils from rasa_sdk.endpoint import create_argument_parser, run from rasa_sdk.constants import APPLICATION_ROOT_LOGGER_NAME -from rasa_sdk.tracing.utils import get_tracer_provider def main_from_args(args): @@ -18,7 +17,6 @@ def main_from_args(args): args.logging_config_file, ) utils.update_sanic_log_level() - tracer_provider = get_tracer_provider(args) run( args.actions, @@ -28,7 +26,7 @@ def main_from_args(args): args.ssl_keyfile, args.ssl_password, args.auto_reload, - tracer_provider, + args.endpoints, ) diff --git a/rasa_sdk/endpoint.py b/rasa_sdk/endpoint.py index 193843d9a..75630b707 100644 --- a/rasa_sdk/endpoint.py +++ b/rasa_sdk/endpoint.py @@ -26,16 +26,23 @@ category=DeprecationWarning, message="distutils Version classes are deprecated", ) - from opentelemetry.sdk.trace import TracerProvider from sanic_cors import CORS from sanic.request import Request from rasa_sdk import utils from rasa_sdk.cli.arguments import add_endpoint_arguments - from rasa_sdk.constants import DEFAULT_KEEP_ALIVE_TIMEOUT, DEFAULT_SERVER_PORT + from rasa_sdk.constants import ( + DEFAULT_ENDPOINTS_PATH, + DEFAULT_KEEP_ALIVE_TIMEOUT, + DEFAULT_SERVER_PORT, + ) from rasa_sdk.executor import ActionExecutor from rasa_sdk.interfaces import ActionExecutionRejection, ActionNotFoundException from rasa_sdk.plugin import plugin_manager - from rasa_sdk.tracing.utils import get_tracer_and_context, set_span_attributes + from rasa_sdk.tracing.utils import ( + get_tracer_and_context, + get_tracer_provider, + set_span_attributes, + ) logger = logging.getLogger(__name__) @@ -76,8 +83,9 @@ def create_argument_parser(): return parser -async def load_tracer_provider(app: Sanic, tracer_provider: Optional[TracerProvider]): +async def load_tracer_provider(endpoints: str, app: Sanic): """Load the tracer provider into the Sanic app.""" + tracer_provider = get_tracer_provider(endpoints) app.shared_ctx.tracer_provider = tracer_provider @@ -85,7 +93,6 @@ def create_app( action_package_name: Union[Text, types.ModuleType], cors_origins: Union[Text, List[Text], None] = "*", auto_reload: bool = False, - tracer_provider: Optional[TracerProvider] = None, ) -> Sanic: """Create a Sanic application and return it. @@ -94,7 +101,6 @@ def create_app( from. cors_origins: CORS origins to allow. auto_reload: When `True`, auto-reloading of actions is enabled. - tracer_provider: Tracer provider to use for tracing. Returns: A new Sanic application ready to be run. @@ -106,10 +112,7 @@ def create_app( executor = ActionExecutor() executor.register_package(action_package_name) - app.register_listener( - partial(load_tracer_provider, tracer_provider=tracer_provider), - "main_process_start", - ) + app.shared_ctx.tracer_provider = None @app.get("/health") async def health(_) -> HTTPResponse: @@ -184,7 +187,7 @@ def run( ssl_keyfile: Optional[Text] = None, ssl_password: Optional[Text] = None, auto_reload: bool = False, - tracer_provider: Optional[TracerProvider] = None, + endpoints: str = DEFAULT_ENDPOINTS_PATH, keep_alive_timeout: int = DEFAULT_KEEP_ALIVE_TIMEOUT, ) -> None: """Starts the action endpoint server with given config values.""" @@ -195,12 +198,17 @@ def run( action_package_name, cors_origins=cors_origins, auto_reload=auto_reload, - tracer_provider=tracer_provider, ), ) app = loader.load() + app.config.KEEP_ALIVE_TIMEOUT = keep_alive_timeout + app.register_listener( + partial(load_tracer_provider, endpoints=endpoints), + "main_process_start", + ) + # Attach additional sanic extensions: listeners, middleware and routing logger.info("Starting plugins...") plugin_manager().hook.attach_sanic_app_extensions(app=app) diff --git a/rasa_sdk/tracing/utils.py b/rasa_sdk/tracing/utils.py index cd3f66630..32b759b24 100644 --- a/rasa_sdk/tracing/utils.py +++ b/rasa_sdk/tracing/utils.py @@ -1,4 +1,3 @@ -import argparse from rasa_sdk.tracing import config from opentelemetry import trace from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator @@ -9,25 +8,18 @@ from typing import Optional, Tuple, Any, Text -def get_tracer_provider( - cmdline_arguments: argparse.Namespace, -) -> Optional[TracerProvider]: +def get_tracer_provider(endpoints_file: str) -> Optional[TracerProvider]: """Gets the tracer provider from the command line arguments.""" - tracer_provider = None - endpoints_file = "" - if "endpoints" in cmdline_arguments: - endpoints_file = cmdline_arguments.endpoints - - if endpoints_file is not None: - tracer_provider = config.get_tracer_provider(endpoints_file) - config.configure_tracing(tracer_provider) + tracer_provider = config.get_tracer_provider(endpoints_file) + config.configure_tracing(tracer_provider) + return tracer_provider def get_tracer_and_context( tracer_provider: Optional[TracerProvider], request: Request ) -> Tuple[Any, Any, Text]: - """Gets tracer and context""" + """Gets tracer and context.""" span_name = "create_app.webhook" if tracer_provider is None: tracer = trace.get_tracer(span_name) @@ -39,7 +31,7 @@ def get_tracer_and_context( def set_span_attributes(span: Any, action_call: dict) -> None: - """Sets span attributes""" + """Sets span attributes.""" tracker = action_call.get("tracker", {}) set_span_attributes = { "http.method": "POST", diff --git a/tests/test_endpoint.py b/tests/test_endpoint.py index 9f4090ca1..101741a5a 100644 --- a/tests/test_endpoint.py +++ b/tests/test_endpoint.py @@ -4,49 +4,48 @@ import zlib import pytest +from sanic import Sanic import rasa_sdk.endpoint as ep from rasa_sdk.events import SlotSet from tests.conftest import get_stack -# noinspection PyTypeChecker -app = ep.create_app(None) - logger = logging.getLogger(__name__) +@pytest.fixture +def app(): + return ep.create_app("tests.test_actions") + + def test_endpoint_exit_for_unknown_actions_package(): with pytest.raises(SystemExit): ep.create_app("non-existing-actions-package") -def test_server_health_returns_200(): +def test_server_health_returns_200(app: Sanic): request, response = app.test_client.get("/health") assert response.status == 200 assert response.json == {"status": "ok"} -def test_server_list_actions_returns_200(): +def test_server_list_actions_returns_200(app: Sanic): request, response = app.test_client.get("/actions") assert response.status == 200 - assert len(response.json) == 6 + assert len(response.json) == 4 - # ENSURE TO UPDATE AS MORE ACTIONS ARE ADDED IN OTHER TESTS expected = [ # defined in tests/test_actions.py {"name": "custom_async_action"}, {"name": "custom_action"}, {"name": "custom_action_exception"}, {"name": "custom_action_with_dialogue_stack"}, - # defined in tests/tracing/instrumentation/conftest.py - {"name": "mock_validation_action"}, - {"name": "mock_form_validation_action"}, ] assert response.json == expected -def test_server_webhook_unknown_action_returns_404(): +def test_server_webhook_unknown_action_returns_404(app: Sanic): data = { "next_action": "test_action_1", "tracker": {"sender_id": "1", "conversation_id": "default"}, @@ -55,7 +54,7 @@ def test_server_webhook_unknown_action_returns_404(): assert response.status == 404 -def test_server_webhook_handles_action_exception(): +def test_server_webhook_handles_action_exception(app: Sanic): data = { "next_action": "custom_action_exception", "tracker": {"sender_id": "1", "conversation_id": "default"}, @@ -66,7 +65,7 @@ def test_server_webhook_handles_action_exception(): assert response.json.get("request_body") == data -def test_server_webhook_custom_action_returns_200(): +def test_server_webhook_custom_action_returns_200(app: Sanic): data = { "next_action": "custom_action", "tracker": {"sender_id": "1", "conversation_id": "default"}, @@ -78,7 +77,7 @@ def test_server_webhook_custom_action_returns_200(): assert response.status == 200 -def test_server_webhook_custom_async_action_returns_200(): +def test_server_webhook_custom_async_action_returns_200(app: Sanic): data = { "next_action": "custom_async_action", "tracker": {"sender_id": "1", "conversation_id": "default"}, @@ -109,7 +108,7 @@ def test_arg_parser_actions_params_module_style(): assert cmdline_args.actions == "actions.act" -def test_server_webhook_custom_action_encoded_data_returns_200(): +def test_server_webhook_custom_action_encoded_data_returns_200(app: Sanic): data = { "next_action": "custom_action", "tracker": {"sender_id": "1", "conversation_id": "default"}, @@ -135,7 +134,7 @@ def test_server_webhook_custom_action_encoded_data_returns_200(): ], ) def test_server_webhook_custom_action_with_dialogue_stack_returns_200( - stack_state: Dict[Text, Any], dialogue_stack: List[Dict[Text, Any]] + stack_state: Dict[Text, Any], dialogue_stack: List[Dict[Text, Any]], app: Sanic ): data = { "next_action": "custom_action_with_dialogue_stack", From 85e3233d5412e5cf09aecfac4caadc3dbacad674 Mon Sep 17 00:00:00 2001 From: Anca Lita <27920906+ancalita@users.noreply.github.com> Date: Tue, 14 May 2024 21:58:40 +0100 Subject: [PATCH 05/11] update some tests --- tests/test_actions.py | 58 -------- .../__init.py__ => test_actions/__init__.py} | 0 tests/test_actions/test_actions.py | 132 ++++++++++++++++++ tests/test_endpoint.py | 51 ++++--- tests/test_executor.py | 12 +- tests/test_plugin.py | 5 +- .../action_fixtures/__init__.py | 0 tests/tracing/instrumentation/conftest.py | 70 +--------- .../test_form_validation_action.py | 2 +- tests/tracing/instrumentation/test_tracing.py | 16 ++- .../instrumentation/test_validation_action.py | 2 +- tests/tracing/test_utils.py | 7 +- 12 files changed, 189 insertions(+), 166 deletions(-) delete mode 100644 tests/test_actions.py rename tests/{tracing/instrumentation/action_fixtures/__init.py__ => test_actions/__init__.py} (100%) create mode 100644 tests/test_actions/test_actions.py create mode 100644 tests/tracing/instrumentation/action_fixtures/__init__.py diff --git a/tests/test_actions.py b/tests/test_actions.py deleted file mode 100644 index e054d361c..000000000 --- a/tests/test_actions.py +++ /dev/null @@ -1,58 +0,0 @@ -from typing import List, Dict, Text, Any - -from rasa_sdk import Action, Tracker -from rasa_sdk.events import SlotSet -from rasa_sdk.executor import CollectingDispatcher -from rasa_sdk.types import DomainDict - - -class CustomAsyncAction(Action): - def name(cls) -> Text: - return "custom_async_action" - - async def run( - self, - dispatcher: CollectingDispatcher, - tracker: Tracker, - domain: DomainDict, - ) -> List[Dict[Text, Any]]: - return [SlotSet("test", "foo"), SlotSet("test2", "boo")] - - -class CustomAction(Action): - def name(cls) -> Text: - return "custom_action" - - def run( - self, - dispatcher: CollectingDispatcher, - tracker: Tracker, - domain: DomainDict, - ) -> List[Dict[Text, Any]]: - return [SlotSet("test", "bar")] - - -class CustomActionRaisingException(Action): - def name(cls) -> Text: - return "custom_action_exception" - - def run( - self, - dispatcher: CollectingDispatcher, - tracker: Tracker, - domain: DomainDict, - ) -> List[Dict[Text, Any]]: - raise Exception("test exception") - - -class CustomActionWithDialogueStack(Action): - def name(cls) -> Text: - return "custom_action_with_dialogue_stack" - - def run( - self, - dispatcher: CollectingDispatcher, - tracker: Tracker, - domain: DomainDict, - ) -> List[Dict[Text, Any]]: - return [SlotSet("stack", tracker.stack)] diff --git a/tests/tracing/instrumentation/action_fixtures/__init.py__ b/tests/test_actions/__init__.py similarity index 100% rename from tests/tracing/instrumentation/action_fixtures/__init.py__ rename to tests/test_actions/__init__.py diff --git a/tests/test_actions/test_actions.py b/tests/test_actions/test_actions.py new file mode 100644 index 000000000..fcc7fa237 --- /dev/null +++ b/tests/test_actions/test_actions.py @@ -0,0 +1,132 @@ +from typing import List, Dict, Text, Any + +import pytest + +from rasa_sdk import Action, FormValidationAction, Tracker, ValidationAction +from rasa_sdk.events import SlotSet +from rasa_sdk.executor import CollectingDispatcher +from rasa_sdk.types import DomainDict + + +class CustomAsyncAction(Action): + def name(cls) -> Text: + return "custom_async_action" + + async def run( + self, + dispatcher: CollectingDispatcher, + tracker: Tracker, + domain: DomainDict, + ) -> List[Dict[Text, Any]]: + return [SlotSet("test", "foo"), SlotSet("test2", "boo")] + + +class CustomAction(Action): + def name(cls) -> Text: + return "custom_action" + + def run( + self, + dispatcher: CollectingDispatcher, + tracker: Tracker, + domain: DomainDict, + ) -> List[Dict[Text, Any]]: + return [SlotSet("test", "bar")] + + +class CustomActionRaisingException(Action): + def name(cls) -> Text: + return "custom_action_exception" + + def run( + self, + dispatcher: CollectingDispatcher, + tracker: Tracker, + domain: DomainDict, + ) -> List[Dict[Text, Any]]: + raise Exception("test exception") + + +class CustomActionWithDialogueStack(Action): + def name(cls) -> Text: + return "custom_action_with_dialogue_stack" + + def run( + self, + dispatcher: CollectingDispatcher, + tracker: Tracker, + domain: DomainDict, + ) -> List[Dict[Text, Any]]: + return [SlotSet("stack", tracker.stack)] + + +class MockFormValidationAction(FormValidationAction): + def __init__(self) -> None: + self.fail_if_undefined("run") + + def fail_if_undefined(self, method_name: str) -> None: + if not ( + hasattr(self.__class__.__base__, method_name) + and callable(getattr(self.__class__.__base__, method_name)) + ): + pytest.fail( + f"method '{method_name}' not found in {self.__class__.__base__}. " + f"This likely means the method was renamed, which means the " + f"instrumentation needs to be adapted!" + ) + + async def _extract_validation_events( + self, + dispatcher: "CollectingDispatcher", + tracker: "Tracker", + domain: "DomainDict", + ) -> None: + return tracker.events + + def name(self) -> str: + return "mock_form_validation_action" + + +class MockValidationAction(ValidationAction): + def __init__(self) -> None: + self.fail_if_undefined("run") + + def fail_if_undefined(self, method_name: Text) -> None: + if not ( + hasattr(self.__class__.__base__, method_name) + and callable(getattr(self.__class__.__base__, method_name)) + ): + pytest.fail( + f"method '{method_name}' not found in {self.__class__.__base__}. " + f"This likely means the method was renamed, which means the " + f"instrumentation needs to be adapted!" + ) + + async def run( + self, + dispatcher: "CollectingDispatcher", + tracker: "Tracker", + domain: "DomainDict", + ) -> None: + pass + + def name(self) -> Text: + return "mock_validation_action" + + async def _extract_validation_events( + self, + dispatcher: "CollectingDispatcher", + tracker: "Tracker", + domain: "DomainDict", + ) -> None: + return tracker.events + + +class SubclassTestActionA(Action): + def name(self): + return "subclass_test_action_a" + + +class SubclassTestActionB(SubclassTestActionA): + def name(self): + return "subclass_test_action_b" diff --git a/tests/test_endpoint.py b/tests/test_endpoint.py index 101741a5a..7912738a9 100644 --- a/tests/test_endpoint.py +++ b/tests/test_endpoint.py @@ -15,8 +15,8 @@ @pytest.fixture -def app(): - return ep.create_app("tests.test_actions") +def sanic_app(): + return ep.create_app("tests") def test_endpoint_exit_for_unknown_actions_package(): @@ -24,65 +24,72 @@ def test_endpoint_exit_for_unknown_actions_package(): ep.create_app("non-existing-actions-package") -def test_server_health_returns_200(app: Sanic): - request, response = app.test_client.get("/health") +def test_server_health_returns_200(sanic_app: Sanic): + request, response = sanic_app.test_client.get("/health") assert response.status == 200 assert response.json == {"status": "ok"} -def test_server_list_actions_returns_200(app: Sanic): - request, response = app.test_client.get("/actions") +def test_server_list_actions_returns_200(sanic_app: Sanic): + request, response = sanic_app.test_client.get("/actions") assert response.status == 200 - assert len(response.json) == 4 - + assert len(response.json) == 9 + print(response.json) expected = [ - # defined in tests/test_actions.py + # defined in tests/test_actions {"name": "custom_async_action"}, {"name": "custom_action"}, {"name": "custom_action_exception"}, {"name": "custom_action_with_dialogue_stack"}, + {"name": "subclass_test_action_a"}, + {"name": "mock_validation_action"}, + {"name": "mock_form_validation_action"}, + # defined in tests/test_forms.py + {"name": "some_form"}, + # defined in tests/test_actions + {"name": "subclass_test_action_b"}, ] assert response.json == expected -def test_server_webhook_unknown_action_returns_404(app: Sanic): +def test_server_webhook_unknown_action_returns_404(sanic_app: Sanic): data = { "next_action": "test_action_1", "tracker": {"sender_id": "1", "conversation_id": "default"}, } - request, response = app.test_client.post("/webhook", data=json.dumps(data)) + request, response = sanic_app.test_client.post("/webhook", data=json.dumps(data)) assert response.status == 404 -def test_server_webhook_handles_action_exception(app: Sanic): +def test_server_webhook_handles_action_exception(sanic_app: Sanic): data = { "next_action": "custom_action_exception", "tracker": {"sender_id": "1", "conversation_id": "default"}, } - request, response = app.test_client.post("/webhook", data=json.dumps(data)) + request, response = sanic_app.test_client.post("/webhook", data=json.dumps(data)) assert response.status == 500 assert response.json.get("error") == "test exception" assert response.json.get("request_body") == data -def test_server_webhook_custom_action_returns_200(app: Sanic): +def test_server_webhook_custom_action_returns_200(sanic_app: Sanic): data = { "next_action": "custom_action", "tracker": {"sender_id": "1", "conversation_id": "default"}, } - request, response = app.test_client.post("/webhook", data=json.dumps(data)) + request, response = sanic_app.test_client.post("/webhook", data=json.dumps(data)) events = response.json.get("events") assert events == [SlotSet("test", "bar")] assert response.status == 200 -def test_server_webhook_custom_async_action_returns_200(app: Sanic): +def test_server_webhook_custom_async_action_returns_200(sanic_app: Sanic): data = { "next_action": "custom_async_action", "tracker": {"sender_id": "1", "conversation_id": "default"}, } - request, response = app.test_client.post("/webhook", data=json.dumps(data)) + request, response = sanic_app.test_client.post("/webhook", data=json.dumps(data)) events = response.json.get("events") assert events == [SlotSet("test", "foo"), SlotSet("test2", "boo")] @@ -108,14 +115,14 @@ def test_arg_parser_actions_params_module_style(): assert cmdline_args.actions == "actions.act" -def test_server_webhook_custom_action_encoded_data_returns_200(app: Sanic): +def test_server_webhook_custom_action_encoded_data_returns_200(sanic_app: Sanic): data = { "next_action": "custom_action", "tracker": {"sender_id": "1", "conversation_id": "default"}, "domain": {"intents": ["greet", "goodbye"]}, } - request, response = app.test_client.post( + request, response = sanic_app.test_client.post( "/webhook", data=zlib.compress(json.dumps(data).encode()), headers={"Content-encoding": "deflate"}, @@ -134,13 +141,15 @@ def test_server_webhook_custom_action_encoded_data_returns_200(app: Sanic): ], ) def test_server_webhook_custom_action_with_dialogue_stack_returns_200( - stack_state: Dict[Text, Any], dialogue_stack: List[Dict[Text, Any]], app: Sanic + stack_state: Dict[Text, Any], + dialogue_stack: List[Dict[Text, Any]], + sanic_app: Sanic, ): data = { "next_action": "custom_action_with_dialogue_stack", "tracker": {"sender_id": "1", "conversation_id": "default", **stack_state}, } - _, response = app.test_client.post("/webhook", data=json.dumps(data)) + _, response = sanic_app.test_client.post("/webhook", data=json.dumps(data)) events = response.json.get("events") assert events == [SlotSet("stack", dialogue_stack)] diff --git a/tests/test_executor.py b/tests/test_executor.py index 3493a9983..bc917334b 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -4,11 +4,11 @@ import string import time -from rasa_sdk import Action from typing import Text, Optional, Generator import pytest from rasa_sdk.executor import ActionExecutor, CollectingDispatcher +from tests.test_actions.test_actions import SubclassTestActionA, SubclassTestActionB TEST_PACKAGE_BASE = "tests/executor_test_packages" @@ -237,16 +237,6 @@ async def test_reload_module( } -class SubclassTestActionA(Action): - def name(self): - return "subclass_test_action_a" - - -class SubclassTestActionB(SubclassTestActionA): - def name(self): - return "subclass_test_action_b" - - def test_load_subclasses(executor: ActionExecutor): executor.register_action(SubclassTestActionB) assert list(executor.actions) == ["subclass_test_action_b"] diff --git a/tests/test_plugin.py b/tests/test_plugin.py index c984f7e73..af08614ab 100644 --- a/tests/test_plugin.py +++ b/tests/test_plugin.py @@ -34,6 +34,7 @@ def test_plugin_attach_sanic_app_extension( monkeypatch.setattr( manager.hook, "attach_sanic_app_extensions", MagicMock(return_value=None) ) + monkeypatch.setattr("rasa_sdk.endpoint.Sanic.serve", MagicMock(return_value=None)) app_mock = MagicMock() # Create a MagicMock object to replace the create_app() method @@ -42,8 +43,8 @@ def test_plugin_attach_sanic_app_extension( # Set the create_app() method to return create_app_mock monkeypatch.setattr("rasa_sdk.endpoint.create_app", create_app_mock) - # Set the return value of app_mock.run() to None - app_mock.run.return_value = None + # Set the return value of app_mock.prepare() to None + app_mock.prepare.return_value = None with warnings.catch_warnings(): warnings.simplefilter("error") diff --git a/tests/tracing/instrumentation/action_fixtures/__init__.py b/tests/tracing/instrumentation/action_fixtures/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/tracing/instrumentation/conftest.py b/tests/tracing/instrumentation/conftest.py index 5b7570f4b..e54961d56 100644 --- a/tests/tracing/instrumentation/conftest.py +++ b/tests/tracing/instrumentation/conftest.py @@ -5,10 +5,8 @@ from opentelemetry.sdk.trace.export import SimpleSpanProcessor from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter -from rasa_sdk.executor import ActionExecutor, CollectingDispatcher -from rasa_sdk.forms import ValidationAction, FormValidationAction -from rasa_sdk.types import ActionCall, DomainDict -from rasa_sdk import Tracker +from rasa_sdk.executor import ActionExecutor +from rasa_sdk.types import ActionCall @pytest.fixture(scope="session") @@ -23,7 +21,7 @@ def span_exporter(tracer_provider: TracerProvider) -> InMemorySpanExporter: return exporter -@pytest.fixture(scope="function") +@pytest.fixture(scope="session") def previous_num_captured_spans(span_exporter: InMemorySpanExporter) -> int: captured_spans = span_exporter.get_finished_spans() # type: ignore return len(captured_spans) @@ -52,65 +50,3 @@ def _create_api_response( events: List[Dict[Text, Any]], messages: List[Dict[Text, Any]] ) -> None: pass - - -class MockValidationAction(ValidationAction): - def __init__(self) -> None: - self.fail_if_undefined("run") - - def fail_if_undefined(self, method_name: Text) -> None: - if not ( - hasattr(self.__class__.__base__, method_name) - and callable(getattr(self.__class__.__base__, method_name)) - ): - pytest.fail( - f"method '{method_name}' not found in {self.__class__.__base__}. " - f"This likely means the method was renamed, which means the " - f"instrumentation needs to be adapted!" - ) - - async def run( - self, - dispatcher: "CollectingDispatcher", - tracker: "Tracker", - domain: "DomainDict", - ) -> None: - pass - - def name(self) -> Text: - return "mock_validation_action" - - async def _extract_validation_events( - self, - dispatcher: "CollectingDispatcher", - tracker: "Tracker", - domain: "DomainDict", - ) -> None: - return tracker.events - - -class MockFormValidationAction(FormValidationAction): - def __init__(self) -> None: - self.fail_if_undefined("run") - - def fail_if_undefined(self, method_name: Text) -> None: - if not ( - hasattr(self.__class__.__base__, method_name) - and callable(getattr(self.__class__.__base__, method_name)) - ): - pytest.fail( - f"method '{method_name}' not found in {self.__class__.__base__}. " - f"This likely means the method was renamed, which means the " - f"instrumentation needs to be adapted!" - ) - - async def _extract_validation_events( - self, - dispatcher: "CollectingDispatcher", - tracker: "Tracker", - domain: "DomainDict", - ) -> None: - return tracker.events - - def name(self) -> Text: - return "mock_form_validation_action" diff --git a/tests/tracing/instrumentation/test_form_validation_action.py b/tests/tracing/instrumentation/test_form_validation_action.py index 66374c6e3..a52154758 100644 --- a/tests/tracing/instrumentation/test_form_validation_action.py +++ b/tests/tracing/instrumentation/test_form_validation_action.py @@ -5,7 +5,7 @@ from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter from rasa_sdk.tracing.instrumentation import instrumentation -from tests.tracing.instrumentation.conftest import MockFormValidationAction +from tests.test_actions.test_actions import MockFormValidationAction from rasa_sdk import Tracker from rasa_sdk.executor import CollectingDispatcher from rasa_sdk.events import ActionExecuted, SlotSet diff --git a/tests/tracing/instrumentation/test_tracing.py b/tests/tracing/instrumentation/test_tracing.py index e1af9b46d..c85871b18 100644 --- a/tests/tracing/instrumentation/test_tracing.py +++ b/tests/tracing/instrumentation/test_tracing.py @@ -1,10 +1,13 @@ import json +from functools import partial + import pytest import rasa_sdk.endpoint as ep from typing import Sequence from opentelemetry.sdk.trace import TracerProvider +from pytest import MonkeyPatch from opentelemetry.sdk.trace import ReadableSpan from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter @@ -35,11 +38,20 @@ def test_server_webhook_custom_action_is_instrumented( previous_num_captured_spans: int, action_name: str, action_package: str, + monkeypatch: MonkeyPatch, ) -> None: """Tests that the custom action is instrumented.""" - + monkeypatch.setattr( + "rasa_sdk.endpoint.get_tracer_provider", lambda _: tracer_provider + ) data["next_action"] = action_name - app = ep.create_app(action_package, tracer_provider=tracer_provider) + app = ep.create_app(action_package) + + app.register_listener( + partial(ep.load_tracer_provider, endpoints="endpoints.yml"), + "main_process_start", + ) + _, response = app.test_client.post("/webhook", data=json.dumps(data)) assert response.status == 200 diff --git a/tests/tracing/instrumentation/test_validation_action.py b/tests/tracing/instrumentation/test_validation_action.py index 860d5032d..7105bd03c 100644 --- a/tests/tracing/instrumentation/test_validation_action.py +++ b/tests/tracing/instrumentation/test_validation_action.py @@ -5,7 +5,7 @@ from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter from rasa_sdk.tracing.instrumentation import instrumentation -from tests.tracing.instrumentation.conftest import ( +from tests.test_actions.test_actions import ( MockValidationAction, ) from rasa_sdk import Tracker diff --git a/tests/tracing/test_utils.py b/tests/tracing/test_utils.py index f4ebe5729..00683c040 100644 --- a/tests/tracing/test_utils.py +++ b/tests/tracing/test_utils.py @@ -20,10 +20,11 @@ def test_get_tracer_provider_returns_none_if_no_endpoints_file() -> None: """Tests that get_tracer_provider returns None if no endpoints file is provided.""" parser = argparse.ArgumentParser() parser.add_argument("--random_args", type=str, default=None) + parser.add_argument("--endpoints", default="endpoints.yml") args = parser.parse_args(["--random_args", "random text"]) - tracer_provider = get_tracer_provider(args) + tracer_provider = get_tracer_provider(args.endpoints) assert tracer_provider is None @@ -36,7 +37,7 @@ def test_get_tracer_provider_returns_none_if_tracing_is_not_configured() -> None endpoints_file = str(TRACING_TESTS_FIXTURES_DIRECTORY / "no_tracing.yml") args = parser.parse_args(["--endpoints", endpoints_file]) - tracer_provider = get_tracer_provider(args) + tracer_provider = get_tracer_provider(args.endpoints) assert tracer_provider is None @@ -50,7 +51,7 @@ def test_get_tracer_provider_returns_provider() -> None: endpoints_file = str(TRACING_TESTS_FIXTURES_DIRECTORY / "jaeger_endpoints.yml") args = parser.parse_args(["--endpoints", endpoints_file]) - tracer_provider = get_tracer_provider(args) + tracer_provider = get_tracer_provider(args.endpoints) assert tracer_provider is not None assert isinstance(tracer_provider, TracerProvider) From b1ee81661933b86a6524bd53f8c1c62b4988ad03 Mon Sep 17 00:00:00 2001 From: Anca Lita <27920906+ancalita@users.noreply.github.com> Date: Tue, 14 May 2024 22:01:29 +0100 Subject: [PATCH 06/11] revert scope --- tests/tracing/instrumentation/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tracing/instrumentation/conftest.py b/tests/tracing/instrumentation/conftest.py index e54961d56..0e2ea108d 100644 --- a/tests/tracing/instrumentation/conftest.py +++ b/tests/tracing/instrumentation/conftest.py @@ -21,7 +21,7 @@ def span_exporter(tracer_provider: TracerProvider) -> InMemorySpanExporter: return exporter -@pytest.fixture(scope="session") +@pytest.fixture(scope="function") def previous_num_captured_spans(span_exporter: InMemorySpanExporter) -> int: captured_spans = span_exporter.get_finished_spans() # type: ignore return len(captured_spans) From 24ad595d9bbb9eae3f324c6d7955ac3a853b20ad Mon Sep 17 00:00:00 2001 From: Anca Lita <27920906+ancalita@users.noreply.github.com> Date: Tue, 14 May 2024 22:14:59 +0100 Subject: [PATCH 07/11] update listener --- rasa_sdk/endpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rasa_sdk/endpoint.py b/rasa_sdk/endpoint.py index 75630b707..eab013d36 100644 --- a/rasa_sdk/endpoint.py +++ b/rasa_sdk/endpoint.py @@ -205,7 +205,7 @@ def run( app.config.KEEP_ALIVE_TIMEOUT = keep_alive_timeout app.register_listener( - partial(load_tracer_provider, endpoints=endpoints), + partial(load_tracer_provider, endpoints), "main_process_start", ) From 158634f67a00d96629dc62c84b3da895fe93655d Mon Sep 17 00:00:00 2001 From: Anca Lita <27920906+ancalita@users.noreply.github.com> Date: Wed, 15 May 2024 09:36:30 +0100 Subject: [PATCH 08/11] change listener and context types --- rasa_sdk/endpoint.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/rasa_sdk/endpoint.py b/rasa_sdk/endpoint.py index eab013d36..7237440f0 100644 --- a/rasa_sdk/endpoint.py +++ b/rasa_sdk/endpoint.py @@ -86,7 +86,7 @@ def create_argument_parser(): async def load_tracer_provider(endpoints: str, app: Sanic): """Load the tracer provider into the Sanic app.""" tracer_provider = get_tracer_provider(endpoints) - app.shared_ctx.tracer_provider = tracer_provider + app.ctx.tracer_provider = tracer_provider def create_app( @@ -112,7 +112,7 @@ def create_app( executor = ActionExecutor() executor.register_package(action_package_name) - app.shared_ctx.tracer_provider = None + app.ctx.tracer_provider = None @app.get("/health") async def health(_) -> HTTPResponse: @@ -124,7 +124,7 @@ async def health(_) -> HTTPResponse: async def webhook(request: Request) -> HTTPResponse: """Webhook to retrieve action calls.""" tracer, context, span_name = get_tracer_and_context( - app.shared_ctx.tracer_provider, request + request.app.ctx.tracer_provider, request ) with tracer.start_as_current_span(span_name, context=context) as span: @@ -206,7 +206,7 @@ def run( app.register_listener( partial(load_tracer_provider, endpoints), - "main_process_start", + "before_server_start", ) # Attach additional sanic extensions: listeners, middleware and routing From 29d5272193a958ce65016e2105aedd6355acb08a Mon Sep 17 00:00:00 2001 From: Anca Lita <27920906+ancalita@users.noreply.github.com> Date: Wed, 15 May 2024 13:16:04 +0100 Subject: [PATCH 09/11] use legacy arg in app.run --- rasa_sdk/endpoint.py | 9 ++++++--- tests/tracing/instrumentation/test_tracing.py | 4 ++-- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/rasa_sdk/endpoint.py b/rasa_sdk/endpoint.py index 7237440f0..ec8d4fb86 100644 --- a/rasa_sdk/endpoint.py +++ b/rasa_sdk/endpoint.py @@ -218,10 +218,13 @@ def run( host = os.environ.get("SANIC_HOST", "0.0.0.0") logger.info(f"Action endpoint is up and running on {protocol}://{host}:{port}") - app.prepare( - host=host, port=port, ssl=ssl_context, workers=utils.number_of_sanic_workers() + app.run( + host=host, + port=port, + ssl=ssl_context, + workers=utils.number_of_sanic_workers(), + legacy=True, ) - Sanic.serve(primary=app, app_loader=loader) if __name__ == "__main__": diff --git a/tests/tracing/instrumentation/test_tracing.py b/tests/tracing/instrumentation/test_tracing.py index c85871b18..f1c7a411e 100644 --- a/tests/tracing/instrumentation/test_tracing.py +++ b/tests/tracing/instrumentation/test_tracing.py @@ -48,8 +48,8 @@ def test_server_webhook_custom_action_is_instrumented( app = ep.create_app(action_package) app.register_listener( - partial(ep.load_tracer_provider, endpoints="endpoints.yml"), - "main_process_start", + partial(ep.load_tracer_provider, "endpoints.yml"), + "before_server_start", ) _, response = app.test_client.post("/webhook", data=json.dumps(data)) From 87e5341f51862c8f54ab346cc94c875f4b508ac6 Mon Sep 17 00:00:00 2001 From: Anca Lita <27920906+ancalita@users.noreply.github.com> Date: Mon, 3 Jun 2024 15:41:11 +0100 Subject: [PATCH 10/11] add changelog entry --- changelog/1103.misc.md | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 changelog/1103.misc.md diff --git a/changelog/1103.misc.md b/changelog/1103.misc.md new file mode 100644 index 000000000..68115810d --- /dev/null +++ b/changelog/1103.misc.md @@ -0,0 +1,2 @@ +Upgrade Sanic to v22.12LTS. +Refactor loading of tracer provider to be triggered by Sanic `before_server_start` event listener. From 609e2b8ea66644ce142466fc58b9a628a0a07618 Mon Sep 17 00:00:00 2001 From: Anca Lita <27920906+ancalita@users.noreply.github.com> Date: Tue, 4 Jun 2024 15:29:10 +0100 Subject: [PATCH 11/11] move custom actions used in tests to conftest.py --- tests/conftest.py | 133 ++++++++++++++++++ tests/test_actions/__init__.py | 0 tests/test_actions/test_actions.py | 132 ----------------- tests/test_endpoint.py | 4 +- tests/test_executor.py | 2 +- .../test_form_validation_action.py | 2 +- .../instrumentation/test_validation_action.py | 2 +- 7 files changed, 138 insertions(+), 137 deletions(-) delete mode 100644 tests/test_actions/__init__.py delete mode 100644 tests/test_actions/test_actions.py diff --git a/tests/conftest.py b/tests/conftest.py index 4b3c5abc8..22eac5ff1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,14 @@ +from typing import List, Dict, Text, Any + from sanic import Sanic +import pytest + +from rasa_sdk import Action, FormValidationAction, Tracker, ValidationAction +from rasa_sdk.events import SlotSet +from rasa_sdk.executor import CollectingDispatcher +from rasa_sdk.types import DomainDict + Sanic.test_mode = True @@ -14,3 +23,127 @@ def get_stack(): } ] return dialogue_stack + + +class CustomAsyncAction(Action): + def name(cls) -> Text: + return "custom_async_action" + + async def run( + self, + dispatcher: CollectingDispatcher, + tracker: Tracker, + domain: DomainDict, + ) -> List[Dict[Text, Any]]: + return [SlotSet("test", "foo"), SlotSet("test2", "boo")] + + +class CustomAction(Action): + def name(cls) -> Text: + return "custom_action" + + def run( + self, + dispatcher: CollectingDispatcher, + tracker: Tracker, + domain: DomainDict, + ) -> List[Dict[Text, Any]]: + return [SlotSet("test", "bar")] + + +class CustomActionRaisingException(Action): + def name(cls) -> Text: + return "custom_action_exception" + + def run( + self, + dispatcher: CollectingDispatcher, + tracker: Tracker, + domain: DomainDict, + ) -> List[Dict[Text, Any]]: + raise Exception("test exception") + + +class CustomActionWithDialogueStack(Action): + def name(cls) -> Text: + return "custom_action_with_dialogue_stack" + + def run( + self, + dispatcher: CollectingDispatcher, + tracker: Tracker, + domain: DomainDict, + ) -> List[Dict[Text, Any]]: + return [SlotSet("stack", tracker.stack)] + + +class MockFormValidationAction(FormValidationAction): + def __init__(self) -> None: + self.fail_if_undefined("run") + + def fail_if_undefined(self, method_name: str) -> None: + if not ( + hasattr(self.__class__.__base__, method_name) + and callable(getattr(self.__class__.__base__, method_name)) + ): + pytest.fail( + f"method '{method_name}' not found in {self.__class__.__base__}. " + f"This likely means the method was renamed, which means the " + f"instrumentation needs to be adapted!" + ) + + async def _extract_validation_events( + self, + dispatcher: "CollectingDispatcher", + tracker: "Tracker", + domain: "DomainDict", + ) -> None: + return tracker.events + + def name(self) -> str: + return "mock_form_validation_action" + + +class MockValidationAction(ValidationAction): + def __init__(self) -> None: + self.fail_if_undefined("run") + + def fail_if_undefined(self, method_name: Text) -> None: + if not ( + hasattr(self.__class__.__base__, method_name) + and callable(getattr(self.__class__.__base__, method_name)) + ): + pytest.fail( + f"method '{method_name}' not found in {self.__class__.__base__}. " + f"This likely means the method was renamed, which means the " + f"instrumentation needs to be adapted!" + ) + + async def run( + self, + dispatcher: "CollectingDispatcher", + tracker: "Tracker", + domain: "DomainDict", + ) -> None: + pass + + def name(self) -> Text: + return "mock_validation_action" + + async def _extract_validation_events( + self, + dispatcher: "CollectingDispatcher", + tracker: "Tracker", + domain: "DomainDict", + ) -> None: + return tracker.events + + +class SubclassTestActionA(Action): + def name(self): + return "subclass_test_action_a" + + +class SubclassTestActionB(SubclassTestActionA): + def name(self): + return "subclass_test_action_b" diff --git a/tests/test_actions/__init__.py b/tests/test_actions/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/test_actions/test_actions.py b/tests/test_actions/test_actions.py deleted file mode 100644 index fcc7fa237..000000000 --- a/tests/test_actions/test_actions.py +++ /dev/null @@ -1,132 +0,0 @@ -from typing import List, Dict, Text, Any - -import pytest - -from rasa_sdk import Action, FormValidationAction, Tracker, ValidationAction -from rasa_sdk.events import SlotSet -from rasa_sdk.executor import CollectingDispatcher -from rasa_sdk.types import DomainDict - - -class CustomAsyncAction(Action): - def name(cls) -> Text: - return "custom_async_action" - - async def run( - self, - dispatcher: CollectingDispatcher, - tracker: Tracker, - domain: DomainDict, - ) -> List[Dict[Text, Any]]: - return [SlotSet("test", "foo"), SlotSet("test2", "boo")] - - -class CustomAction(Action): - def name(cls) -> Text: - return "custom_action" - - def run( - self, - dispatcher: CollectingDispatcher, - tracker: Tracker, - domain: DomainDict, - ) -> List[Dict[Text, Any]]: - return [SlotSet("test", "bar")] - - -class CustomActionRaisingException(Action): - def name(cls) -> Text: - return "custom_action_exception" - - def run( - self, - dispatcher: CollectingDispatcher, - tracker: Tracker, - domain: DomainDict, - ) -> List[Dict[Text, Any]]: - raise Exception("test exception") - - -class CustomActionWithDialogueStack(Action): - def name(cls) -> Text: - return "custom_action_with_dialogue_stack" - - def run( - self, - dispatcher: CollectingDispatcher, - tracker: Tracker, - domain: DomainDict, - ) -> List[Dict[Text, Any]]: - return [SlotSet("stack", tracker.stack)] - - -class MockFormValidationAction(FormValidationAction): - def __init__(self) -> None: - self.fail_if_undefined("run") - - def fail_if_undefined(self, method_name: str) -> None: - if not ( - hasattr(self.__class__.__base__, method_name) - and callable(getattr(self.__class__.__base__, method_name)) - ): - pytest.fail( - f"method '{method_name}' not found in {self.__class__.__base__}. " - f"This likely means the method was renamed, which means the " - f"instrumentation needs to be adapted!" - ) - - async def _extract_validation_events( - self, - dispatcher: "CollectingDispatcher", - tracker: "Tracker", - domain: "DomainDict", - ) -> None: - return tracker.events - - def name(self) -> str: - return "mock_form_validation_action" - - -class MockValidationAction(ValidationAction): - def __init__(self) -> None: - self.fail_if_undefined("run") - - def fail_if_undefined(self, method_name: Text) -> None: - if not ( - hasattr(self.__class__.__base__, method_name) - and callable(getattr(self.__class__.__base__, method_name)) - ): - pytest.fail( - f"method '{method_name}' not found in {self.__class__.__base__}. " - f"This likely means the method was renamed, which means the " - f"instrumentation needs to be adapted!" - ) - - async def run( - self, - dispatcher: "CollectingDispatcher", - tracker: "Tracker", - domain: "DomainDict", - ) -> None: - pass - - def name(self) -> Text: - return "mock_validation_action" - - async def _extract_validation_events( - self, - dispatcher: "CollectingDispatcher", - tracker: "Tracker", - domain: "DomainDict", - ) -> None: - return tracker.events - - -class SubclassTestActionA(Action): - def name(self): - return "subclass_test_action_a" - - -class SubclassTestActionB(SubclassTestActionA): - def name(self): - return "subclass_test_action_b" diff --git a/tests/test_endpoint.py b/tests/test_endpoint.py index 7912738a9..315a4f3a9 100644 --- a/tests/test_endpoint.py +++ b/tests/test_endpoint.py @@ -36,7 +36,7 @@ def test_server_list_actions_returns_200(sanic_app: Sanic): assert len(response.json) == 9 print(response.json) expected = [ - # defined in tests/test_actions + # defined in tests/conftest.py {"name": "custom_async_action"}, {"name": "custom_action"}, {"name": "custom_action_exception"}, @@ -46,7 +46,7 @@ def test_server_list_actions_returns_200(sanic_app: Sanic): {"name": "mock_form_validation_action"}, # defined in tests/test_forms.py {"name": "some_form"}, - # defined in tests/test_actions + # defined in tests/conftest.py {"name": "subclass_test_action_b"}, ] assert response.json == expected diff --git a/tests/test_executor.py b/tests/test_executor.py index bc917334b..cbad5aef2 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -8,7 +8,7 @@ import pytest from rasa_sdk.executor import ActionExecutor, CollectingDispatcher -from tests.test_actions.test_actions import SubclassTestActionA, SubclassTestActionB +from tests.conftest import SubclassTestActionA, SubclassTestActionB TEST_PACKAGE_BASE = "tests/executor_test_packages" diff --git a/tests/tracing/instrumentation/test_form_validation_action.py b/tests/tracing/instrumentation/test_form_validation_action.py index a52154758..37c849f0d 100644 --- a/tests/tracing/instrumentation/test_form_validation_action.py +++ b/tests/tracing/instrumentation/test_form_validation_action.py @@ -5,7 +5,7 @@ from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter from rasa_sdk.tracing.instrumentation import instrumentation -from tests.test_actions.test_actions import MockFormValidationAction +from tests.conftest import MockFormValidationAction from rasa_sdk import Tracker from rasa_sdk.executor import CollectingDispatcher from rasa_sdk.events import ActionExecuted, SlotSet diff --git a/tests/tracing/instrumentation/test_validation_action.py b/tests/tracing/instrumentation/test_validation_action.py index 7105bd03c..6e11c1e68 100644 --- a/tests/tracing/instrumentation/test_validation_action.py +++ b/tests/tracing/instrumentation/test_validation_action.py @@ -5,7 +5,7 @@ from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter from rasa_sdk.tracing.instrumentation import instrumentation -from tests.test_actions.test_actions import ( +from tests.conftest import ( MockValidationAction, ) from rasa_sdk import Tracker