diff --git a/changelog/1103.misc.md b/changelog/1103.misc.md new file mode 100644 index 000000000..68115810d --- /dev/null +++ b/changelog/1103.misc.md @@ -0,0 +1,2 @@ +Upgrade Sanic to v22.12LTS. +Refactor loading of tracer provider to be triggered by Sanic `before_server_start` event listener. diff --git a/poetry.lock b/poetry.lock index 46e670166..2f688c463 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1335,30 +1335,31 @@ files = [ [[package]] name = "sanic" -version = "21.12.2" +version = "22.12.0" description = "A web server and web framework that's written to go fast. Build fast. Run fast." optional = false python-versions = ">=3.7" files = [ - {file = "sanic-21.12.2-py3-none-any.whl", hash = "sha256:ef4edddfba46f2f8728400470f84deb91d9e5fc21cf2531512acf70638699620"}, - {file = "sanic-21.12.2.tar.gz", hash = "sha256:c426e15aac6984860c6d1221329be17e02e2dfed4ce0abf8532ab096026ee5e3"}, + {file = "sanic-22.12.0-py3-none-any.whl", hash = "sha256:84edf46cc17d13264ccec0ae6622e43087498f95644dc336ade74a2d5e6c88cb"}, + {file = "sanic-22.12.0.tar.gz", hash = "sha256:e5f81115f838956957046b6c52e7a08c1bd6e8ff530ee1376471eaf1579bfffa"}, ] [package.dependencies] aiofiles = ">=0.6.0" httptools = ">=0.0.10" -multidict = ">=5.0,<6.0" -sanic-routing = ">=0.7,<1.0" +multidict = ">=5.0,<7.0" +sanic-routing = ">=22.8.0" ujson = {version = ">=1.35", markers = "sys_platform != \"win32\" and implementation_name == \"cpython\""} -uvloop = {version = ">=0.5.3", markers = "sys_platform != \"win32\" and implementation_name == \"cpython\""} +uvloop = {version = ">=0.15.0", markers = "sys_platform != \"win32\" and implementation_name == \"cpython\""} websockets = ">=10.0" [package.extras] -all = ["bandit", "beautifulsoup4", "black", "chardet (==3.*)", "coverage (==5.3)", "cryptography", "docutils", "flake8", "gunicorn (==20.0.4)", "isort (>=5.0.0)", "m2r2", "mistune (<2.0.0)", "mypy (>=0.901,<0.910)", "pygments", "pytest (==6.2.5)", "pytest-benchmark", "pytest-cov", "pytest-sanic", "pytest-sugar", "sanic-testing (>=0.7.0)", "sphinx (>=2.1.2)", "sphinx-rtd-theme (>=0.4.3)", "towncrier", "tox", "types-ujson", "uvicorn (<0.15.0)"] -dev = ["bandit", "beautifulsoup4", "black", "chardet (==3.*)", "coverage (==5.3)", "cryptography", "docutils", "flake8", "gunicorn (==20.0.4)", "isort (>=5.0.0)", "mypy (>=0.901,<0.910)", "pygments", "pytest (==6.2.5)", "pytest-benchmark", "pytest-cov", "pytest-sanic", "pytest-sugar", "sanic-testing (>=0.7.0)", "towncrier", "tox", "types-ujson", "uvicorn (<0.15.0)"] -docs = ["docutils", "m2r2", "mistune (<2.0.0)", "pygments", "sphinx (>=2.1.2)", "sphinx-rtd-theme (>=0.4.3)"] +all = ["bandit", "beautifulsoup4", "black", "chardet (==3.*)", "coverage", "cryptography", "docutils", "enum-tools[sphinx]", "flake8", "isort (>=5.0.0)", "m2r2", "mistune (<2.0.0)", "mypy (>=0.901,<0.910)", "pygments", "pytest (==7.1.*)", "pytest-benchmark", "pytest-sanic", "sanic-testing (>=22.9.0)", "slotscheck (>=0.8.0,<1)", "sphinx (>=2.1.2)", "sphinx-rtd-theme (>=0.4.3)", "towncrier", "tox", "types-ujson", "uvicorn (<0.15.0)"] +dev = ["bandit", "beautifulsoup4", "black", "chardet (==3.*)", "coverage", "cryptography", "docutils", "flake8", "isort (>=5.0.0)", "mypy (>=0.901,<0.910)", "pygments", "pytest (==7.1.*)", "pytest-benchmark", "pytest-sanic", "sanic-testing (>=22.9.0)", "slotscheck (>=0.8.0,<1)", "towncrier", "tox", "types-ujson", "uvicorn (<0.15.0)"] +docs = ["docutils", "enum-tools[sphinx]", "m2r2", "mistune (<2.0.0)", "pygments", "sphinx (>=2.1.2)", "sphinx-rtd-theme (>=0.4.3)"] ext = ["sanic-ext"] -test = ["bandit", "beautifulsoup4", "black", "chardet (==3.*)", "coverage (==5.3)", "docutils", "flake8", "gunicorn (==20.0.4)", "isort (>=5.0.0)", "mypy (>=0.901,<0.910)", "pygments", "pytest (==6.2.5)", "pytest-benchmark", "pytest-cov", "pytest-sanic", "pytest-sugar", "sanic-testing (>=0.7.0)", "types-ujson", "uvicorn (<0.15.0)"] +http3 = ["aioquic"] +test = ["bandit", "beautifulsoup4", "black", "chardet (==3.*)", "coverage", "docutils", "flake8", "isort (>=5.0.0)", "mypy (>=0.901,<0.910)", "pygments", "pytest (==7.1.*)", "pytest-benchmark", "pytest-sanic", "sanic-testing (>=22.9.0)", "slotscheck (>=0.8.0,<1)", "types-ujson", "uvicorn (<0.15.0)"] [[package]] name = "sanic-cors" @@ -1376,24 +1377,24 @@ sanic = ">=21.9.3" [[package]] name = "sanic-routing" -version = "0.7.2" +version = "23.12.0" description = "Core routing component for Sanic" optional = false python-versions = "*" files = [ - {file = "sanic-routing-0.7.2.tar.gz", hash = "sha256:139ce88b3f054e7aa336e2ecc8459837092b103b275d3a97609a34092c55374d"}, - {file = "sanic_routing-0.7.2-py3-none-any.whl", hash = "sha256:523034ffd07aca056040e08de438269c9a880722eee1ace3a32e4f74b394d9aa"}, + {file = "sanic-routing-23.12.0.tar.gz", hash = "sha256:1dcadc62c443e48c852392dba03603f9862b6197fc4cba5bbefeb1ace0848b04"}, + {file = "sanic_routing-23.12.0-py3-none-any.whl", hash = "sha256:1558a72afcb9046ed3134a5edae02fc1552cff08f0fff2e8d5de0877ea43ed73"}, ] [[package]] name = "sanic-testing" -version = "22.6.0" +version = "22.12.0" description = "Core testing clients for Sanic" optional = false python-versions = "*" files = [ - {file = "sanic-testing-22.6.0.tar.gz", hash = "sha256:8f006d2332106539cd6f3da8a5c0d1f31472261f3293e43e2c9bbad605e72c5b"}, - {file = "sanic_testing-22.6.0-py3-none-any.whl", hash = "sha256:d84303e83066de7f18e8c3a0cd04512ba1517dbc31123f14e8aec318b22c008c"}, + {file = "sanic-testing-22.12.0.tar.gz", hash = "sha256:c9582c9bb9aabd82d3bf9fba2514a0274d0d741d84ce600e3ba2bef7b6c87aed"}, + {file = "sanic_testing-22.12.0-py3-none-any.whl", hash = "sha256:2cc3338207c6aab4cdc6b89264744a3d51ea66685fe1f30f81f9c376f3ee93a3"}, ] [package.dependencies] @@ -1850,4 +1851,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.8,<3.11" -content-hash = "77b9438ee3a1fcd07e5f6891a815463ee35d488b4d8ac1ddb43eaf6afee2c4c6" +content-hash = "0c854c620eb789aa7b5697a6430b0af5b2c5d9f29a2f045adfe71be21b35ed48" diff --git a/pyproject.toml b/pyproject.toml index 6bfb676b2..0ad2b41e4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,7 +76,7 @@ select = [ "D", "E", "F", "W", "RUF",] [tool.poetry.dependencies] python = ">=3.8,<3.11" coloredlogs = ">=10,<16" -sanic = "^21.12.0" +sanic = "^22.12" typing-extensions = ">=4.1.1,<5.0.0" Sanic-Cors = "^2.0.0" prompt-toolkit = "^3.0,<3.0.29" @@ -99,7 +99,7 @@ toml = "^0.10.0" pep440-version-utils = "^0.3.0" semantic_version = "^2.8.5" mypy = "^1.5" -sanic-testing = "^22.3.0, <22.9.0" +sanic-testing = "^22.12" [tool.ruff.pydocstyle] convention = "google" diff --git a/rasa_sdk/__main__.py b/rasa_sdk/__main__.py index 05b1805cb..67b9fb497 100644 --- a/rasa_sdk/__main__.py +++ b/rasa_sdk/__main__.py @@ -3,7 +3,6 @@ from rasa_sdk import utils from rasa_sdk.endpoint import create_argument_parser, run from rasa_sdk.constants import APPLICATION_ROOT_LOGGER_NAME -from rasa_sdk.tracing.utils import get_tracer_provider def main_from_args(args): @@ -18,7 +17,6 @@ def main_from_args(args): args.logging_config_file, ) utils.update_sanic_log_level() - tracer_provider = get_tracer_provider(args) run( args.actions, @@ -28,7 +26,7 @@ def main_from_args(args): args.ssl_keyfile, args.ssl_password, args.auto_reload, - tracer_provider, + args.endpoints, ) diff --git a/rasa_sdk/endpoint.py b/rasa_sdk/endpoint.py index eb922c51a..ec8d4fb86 100644 --- a/rasa_sdk/endpoint.py +++ b/rasa_sdk/endpoint.py @@ -5,10 +5,12 @@ import warnings import zlib import json +from functools import partial from typing import List, Text, Union, Optional from ssl import SSLContext from sanic import Sanic, response from sanic.response import HTTPResponse +from sanic.worker.loader import AppLoader # catching: # - all `pkg_resources` deprecation warning from multiple dependencies @@ -24,16 +26,23 @@ category=DeprecationWarning, message="distutils Version classes are deprecated", ) - from opentelemetry.sdk.trace import TracerProvider from sanic_cors import CORS from sanic.request import Request from rasa_sdk import utils from rasa_sdk.cli.arguments import add_endpoint_arguments - from rasa_sdk.constants import DEFAULT_KEEP_ALIVE_TIMEOUT, DEFAULT_SERVER_PORT + from rasa_sdk.constants import ( + DEFAULT_ENDPOINTS_PATH, + DEFAULT_KEEP_ALIVE_TIMEOUT, + DEFAULT_SERVER_PORT, + ) from rasa_sdk.executor import ActionExecutor from rasa_sdk.interfaces import ActionExecutionRejection, ActionNotFoundException from rasa_sdk.plugin import plugin_manager - from rasa_sdk.tracing.utils import get_tracer_and_context, set_span_attributes + from rasa_sdk.tracing.utils import ( + get_tracer_and_context, + get_tracer_provider, + set_span_attributes, + ) logger = logging.getLogger(__name__) @@ -42,7 +51,6 @@ def configure_cors( app: Sanic, cors_origins: Union[Text, List[Text], None] = "" ) -> None: """Configure CORS origins for the given app.""" - CORS( app, resources={r"/*": {"origins": cors_origins or ""}}, automatic_options=True ) @@ -54,7 +62,6 @@ def create_ssl_context( ssl_password: Optional[Text] = None, ) -> Optional[SSLContext]: """Create a SSL context if a certificate is passed.""" - if ssl_certificate: import ssl @@ -69,7 +76,6 @@ def create_ssl_context( def create_argument_parser(): """Parse all the command line arguments for the run script.""" - parser = argparse.ArgumentParser(description="starts the action endpoint") add_endpoint_arguments(parser) utils.add_logging_level_option_arguments(parser) @@ -77,11 +83,16 @@ def create_argument_parser(): return parser +async def load_tracer_provider(endpoints: str, app: Sanic): + """Load the tracer provider into the Sanic app.""" + tracer_provider = get_tracer_provider(endpoints) + app.ctx.tracer_provider = tracer_provider + + def create_app( action_package_name: Union[Text, types.ModuleType], cors_origins: Union[Text, List[Text], None] = "*", auto_reload: bool = False, - tracer_provider: Optional[TracerProvider] = None, ) -> Sanic: """Create a Sanic application and return it. @@ -90,7 +101,6 @@ def create_app( from. cors_origins: CORS origins to allow. auto_reload: When `True`, auto-reloading of actions is enabled. - tracer_provider: Tracer provider to use for tracing. Returns: A new Sanic application ready to be run. @@ -102,6 +112,8 @@ def create_app( executor = ActionExecutor() executor.register_package(action_package_name) + app.ctx.tracer_provider = None + @app.get("/health") async def health(_) -> HTTPResponse: """Ping endpoint to check if the server is running and well.""" @@ -111,7 +123,9 @@ async def health(_) -> HTTPResponse: @app.post("/webhook") async def webhook(request: Request) -> HTTPResponse: """Webhook to retrieve action calls.""" - tracer, context, span_name = get_tracer_and_context(tracer_provider, request) + tracer, context, span_name = get_tracer_and_context( + request.app.ctx.tracer_provider, request + ) with tracer.start_as_current_span(span_name, context=context) as span: if request.headers.get("Content-Encoding") == "deflate": @@ -173,27 +187,44 @@ def run( ssl_keyfile: Optional[Text] = None, ssl_password: Optional[Text] = None, auto_reload: bool = False, - tracer_provider: Optional[TracerProvider] = None, + endpoints: str = DEFAULT_ENDPOINTS_PATH, keep_alive_timeout: int = DEFAULT_KEEP_ALIVE_TIMEOUT, ) -> None: """Starts the action endpoint server with given config values.""" logger.info("Starting action endpoint server...") - app = create_app( - action_package_name, - cors_origins=cors_origins, - auto_reload=auto_reload, - tracer_provider=tracer_provider, + loader = AppLoader( + factory=partial( + create_app, + action_package_name, + cors_origins=cors_origins, + auto_reload=auto_reload, + ), ) + app = loader.load() + app.config.KEEP_ALIVE_TIMEOUT = keep_alive_timeout - ## Attach additional sanic extensions: listeners, middleware and routing + + app.register_listener( + partial(load_tracer_provider, endpoints), + "before_server_start", + ) + + # Attach additional sanic extensions: listeners, middleware and routing logger.info("Starting plugins...") plugin_manager().hook.attach_sanic_app_extensions(app=app) + ssl_context = create_ssl_context(ssl_certificate, ssl_keyfile, ssl_password) protocol = "https" if ssl_context else "http" host = os.environ.get("SANIC_HOST", "0.0.0.0") logger.info(f"Action endpoint is up and running on {protocol}://{host}:{port}") - app.run(host, port, ssl=ssl_context, workers=utils.number_of_sanic_workers()) + app.run( + host=host, + port=port, + ssl=ssl_context, + workers=utils.number_of_sanic_workers(), + legacy=True, + ) if __name__ == "__main__": diff --git a/rasa_sdk/tracing/utils.py b/rasa_sdk/tracing/utils.py index cd3f66630..32b759b24 100644 --- a/rasa_sdk/tracing/utils.py +++ b/rasa_sdk/tracing/utils.py @@ -1,4 +1,3 @@ -import argparse from rasa_sdk.tracing import config from opentelemetry import trace from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator @@ -9,25 +8,18 @@ from typing import Optional, Tuple, Any, Text -def get_tracer_provider( - cmdline_arguments: argparse.Namespace, -) -> Optional[TracerProvider]: +def get_tracer_provider(endpoints_file: str) -> Optional[TracerProvider]: """Gets the tracer provider from the command line arguments.""" - tracer_provider = None - endpoints_file = "" - if "endpoints" in cmdline_arguments: - endpoints_file = cmdline_arguments.endpoints - - if endpoints_file is not None: - tracer_provider = config.get_tracer_provider(endpoints_file) - config.configure_tracing(tracer_provider) + tracer_provider = config.get_tracer_provider(endpoints_file) + config.configure_tracing(tracer_provider) + return tracer_provider def get_tracer_and_context( tracer_provider: Optional[TracerProvider], request: Request ) -> Tuple[Any, Any, Text]: - """Gets tracer and context""" + """Gets tracer and context.""" span_name = "create_app.webhook" if tracer_provider is None: tracer = trace.get_tracer(span_name) @@ -39,7 +31,7 @@ def get_tracer_and_context( def set_span_attributes(span: Any, action_call: dict) -> None: - """Sets span attributes""" + """Sets span attributes.""" tracker = action_call.get("tracker", {}) set_span_attributes = { "http.method": "POST", diff --git a/tests/conftest.py b/tests/conftest.py index 4b3c5abc8..22eac5ff1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,14 @@ +from typing import List, Dict, Text, Any + from sanic import Sanic +import pytest + +from rasa_sdk import Action, FormValidationAction, Tracker, ValidationAction +from rasa_sdk.events import SlotSet +from rasa_sdk.executor import CollectingDispatcher +from rasa_sdk.types import DomainDict + Sanic.test_mode = True @@ -14,3 +23,127 @@ def get_stack(): } ] return dialogue_stack + + +class CustomAsyncAction(Action): + def name(cls) -> Text: + return "custom_async_action" + + async def run( + self, + dispatcher: CollectingDispatcher, + tracker: Tracker, + domain: DomainDict, + ) -> List[Dict[Text, Any]]: + return [SlotSet("test", "foo"), SlotSet("test2", "boo")] + + +class CustomAction(Action): + def name(cls) -> Text: + return "custom_action" + + def run( + self, + dispatcher: CollectingDispatcher, + tracker: Tracker, + domain: DomainDict, + ) -> List[Dict[Text, Any]]: + return [SlotSet("test", "bar")] + + +class CustomActionRaisingException(Action): + def name(cls) -> Text: + return "custom_action_exception" + + def run( + self, + dispatcher: CollectingDispatcher, + tracker: Tracker, + domain: DomainDict, + ) -> List[Dict[Text, Any]]: + raise Exception("test exception") + + +class CustomActionWithDialogueStack(Action): + def name(cls) -> Text: + return "custom_action_with_dialogue_stack" + + def run( + self, + dispatcher: CollectingDispatcher, + tracker: Tracker, + domain: DomainDict, + ) -> List[Dict[Text, Any]]: + return [SlotSet("stack", tracker.stack)] + + +class MockFormValidationAction(FormValidationAction): + def __init__(self) -> None: + self.fail_if_undefined("run") + + def fail_if_undefined(self, method_name: str) -> None: + if not ( + hasattr(self.__class__.__base__, method_name) + and callable(getattr(self.__class__.__base__, method_name)) + ): + pytest.fail( + f"method '{method_name}' not found in {self.__class__.__base__}. " + f"This likely means the method was renamed, which means the " + f"instrumentation needs to be adapted!" + ) + + async def _extract_validation_events( + self, + dispatcher: "CollectingDispatcher", + tracker: "Tracker", + domain: "DomainDict", + ) -> None: + return tracker.events + + def name(self) -> str: + return "mock_form_validation_action" + + +class MockValidationAction(ValidationAction): + def __init__(self) -> None: + self.fail_if_undefined("run") + + def fail_if_undefined(self, method_name: Text) -> None: + if not ( + hasattr(self.__class__.__base__, method_name) + and callable(getattr(self.__class__.__base__, method_name)) + ): + pytest.fail( + f"method '{method_name}' not found in {self.__class__.__base__}. " + f"This likely means the method was renamed, which means the " + f"instrumentation needs to be adapted!" + ) + + async def run( + self, + dispatcher: "CollectingDispatcher", + tracker: "Tracker", + domain: "DomainDict", + ) -> None: + pass + + def name(self) -> Text: + return "mock_validation_action" + + async def _extract_validation_events( + self, + dispatcher: "CollectingDispatcher", + tracker: "Tracker", + domain: "DomainDict", + ) -> None: + return tracker.events + + +class SubclassTestActionA(Action): + def name(self): + return "subclass_test_action_a" + + +class SubclassTestActionB(SubclassTestActionA): + def name(self): + return "subclass_test_action_b" diff --git a/tests/test_actions.py b/tests/test_actions.py deleted file mode 100644 index e054d361c..000000000 --- a/tests/test_actions.py +++ /dev/null @@ -1,58 +0,0 @@ -from typing import List, Dict, Text, Any - -from rasa_sdk import Action, Tracker -from rasa_sdk.events import SlotSet -from rasa_sdk.executor import CollectingDispatcher -from rasa_sdk.types import DomainDict - - -class CustomAsyncAction(Action): - def name(cls) -> Text: - return "custom_async_action" - - async def run( - self, - dispatcher: CollectingDispatcher, - tracker: Tracker, - domain: DomainDict, - ) -> List[Dict[Text, Any]]: - return [SlotSet("test", "foo"), SlotSet("test2", "boo")] - - -class CustomAction(Action): - def name(cls) -> Text: - return "custom_action" - - def run( - self, - dispatcher: CollectingDispatcher, - tracker: Tracker, - domain: DomainDict, - ) -> List[Dict[Text, Any]]: - return [SlotSet("test", "bar")] - - -class CustomActionRaisingException(Action): - def name(cls) -> Text: - return "custom_action_exception" - - def run( - self, - dispatcher: CollectingDispatcher, - tracker: Tracker, - domain: DomainDict, - ) -> List[Dict[Text, Any]]: - raise Exception("test exception") - - -class CustomActionWithDialogueStack(Action): - def name(cls) -> Text: - return "custom_action_with_dialogue_stack" - - def run( - self, - dispatcher: CollectingDispatcher, - tracker: Tracker, - domain: DomainDict, - ) -> List[Dict[Text, Any]]: - return [SlotSet("stack", tracker.stack)] diff --git a/tests/test_endpoint.py b/tests/test_endpoint.py index 9f4090ca1..315a4f3a9 100644 --- a/tests/test_endpoint.py +++ b/tests/test_endpoint.py @@ -4,86 +4,92 @@ import zlib import pytest +from sanic import Sanic import rasa_sdk.endpoint as ep from rasa_sdk.events import SlotSet from tests.conftest import get_stack -# noinspection PyTypeChecker -app = ep.create_app(None) - logger = logging.getLogger(__name__) +@pytest.fixture +def sanic_app(): + return ep.create_app("tests") + + def test_endpoint_exit_for_unknown_actions_package(): with pytest.raises(SystemExit): ep.create_app("non-existing-actions-package") -def test_server_health_returns_200(): - request, response = app.test_client.get("/health") +def test_server_health_returns_200(sanic_app: Sanic): + request, response = sanic_app.test_client.get("/health") assert response.status == 200 assert response.json == {"status": "ok"} -def test_server_list_actions_returns_200(): - request, response = app.test_client.get("/actions") +def test_server_list_actions_returns_200(sanic_app: Sanic): + request, response = sanic_app.test_client.get("/actions") assert response.status == 200 - assert len(response.json) == 6 - - # ENSURE TO UPDATE AS MORE ACTIONS ARE ADDED IN OTHER TESTS + assert len(response.json) == 9 + print(response.json) expected = [ - # defined in tests/test_actions.py + # defined in tests/conftest.py {"name": "custom_async_action"}, {"name": "custom_action"}, {"name": "custom_action_exception"}, {"name": "custom_action_with_dialogue_stack"}, - # defined in tests/tracing/instrumentation/conftest.py + {"name": "subclass_test_action_a"}, {"name": "mock_validation_action"}, {"name": "mock_form_validation_action"}, + # defined in tests/test_forms.py + {"name": "some_form"}, + # defined in tests/conftest.py + {"name": "subclass_test_action_b"}, ] assert response.json == expected -def test_server_webhook_unknown_action_returns_404(): +def test_server_webhook_unknown_action_returns_404(sanic_app: Sanic): data = { "next_action": "test_action_1", "tracker": {"sender_id": "1", "conversation_id": "default"}, } - request, response = app.test_client.post("/webhook", data=json.dumps(data)) + request, response = sanic_app.test_client.post("/webhook", data=json.dumps(data)) assert response.status == 404 -def test_server_webhook_handles_action_exception(): +def test_server_webhook_handles_action_exception(sanic_app: Sanic): data = { "next_action": "custom_action_exception", "tracker": {"sender_id": "1", "conversation_id": "default"}, } - request, response = app.test_client.post("/webhook", data=json.dumps(data)) + request, response = sanic_app.test_client.post("/webhook", data=json.dumps(data)) assert response.status == 500 assert response.json.get("error") == "test exception" assert response.json.get("request_body") == data -def test_server_webhook_custom_action_returns_200(): +def test_server_webhook_custom_action_returns_200(sanic_app: Sanic): data = { "next_action": "custom_action", "tracker": {"sender_id": "1", "conversation_id": "default"}, } - request, response = app.test_client.post("/webhook", data=json.dumps(data)) + request, response = sanic_app.test_client.post("/webhook", data=json.dumps(data)) events = response.json.get("events") assert events == [SlotSet("test", "bar")] assert response.status == 200 -def test_server_webhook_custom_async_action_returns_200(): +def test_server_webhook_custom_async_action_returns_200(sanic_app: Sanic): data = { "next_action": "custom_async_action", "tracker": {"sender_id": "1", "conversation_id": "default"}, } - request, response = app.test_client.post("/webhook", data=json.dumps(data)) + request, response = sanic_app.test_client.post("/webhook", data=json.dumps(data)) events = response.json.get("events") assert events == [SlotSet("test", "foo"), SlotSet("test2", "boo")] @@ -109,14 +115,14 @@ def test_arg_parser_actions_params_module_style(): assert cmdline_args.actions == "actions.act" -def test_server_webhook_custom_action_encoded_data_returns_200(): +def test_server_webhook_custom_action_encoded_data_returns_200(sanic_app: Sanic): data = { "next_action": "custom_action", "tracker": {"sender_id": "1", "conversation_id": "default"}, "domain": {"intents": ["greet", "goodbye"]}, } - request, response = app.test_client.post( + request, response = sanic_app.test_client.post( "/webhook", data=zlib.compress(json.dumps(data).encode()), headers={"Content-encoding": "deflate"}, @@ -135,13 +141,15 @@ def test_server_webhook_custom_action_encoded_data_returns_200(): ], ) def test_server_webhook_custom_action_with_dialogue_stack_returns_200( - stack_state: Dict[Text, Any], dialogue_stack: List[Dict[Text, Any]] + stack_state: Dict[Text, Any], + dialogue_stack: List[Dict[Text, Any]], + sanic_app: Sanic, ): data = { "next_action": "custom_action_with_dialogue_stack", "tracker": {"sender_id": "1", "conversation_id": "default", **stack_state}, } - _, response = app.test_client.post("/webhook", data=json.dumps(data)) + _, response = sanic_app.test_client.post("/webhook", data=json.dumps(data)) events = response.json.get("events") assert events == [SlotSet("stack", dialogue_stack)] diff --git a/tests/test_executor.py b/tests/test_executor.py index 3493a9983..cbad5aef2 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -4,11 +4,11 @@ import string import time -from rasa_sdk import Action from typing import Text, Optional, Generator import pytest from rasa_sdk.executor import ActionExecutor, CollectingDispatcher +from tests.conftest import SubclassTestActionA, SubclassTestActionB TEST_PACKAGE_BASE = "tests/executor_test_packages" @@ -237,16 +237,6 @@ async def test_reload_module( } -class SubclassTestActionA(Action): - def name(self): - return "subclass_test_action_a" - - -class SubclassTestActionB(SubclassTestActionA): - def name(self): - return "subclass_test_action_b" - - def test_load_subclasses(executor: ActionExecutor): executor.register_action(SubclassTestActionB) assert list(executor.actions) == ["subclass_test_action_b"] diff --git a/tests/test_plugin.py b/tests/test_plugin.py index c984f7e73..af08614ab 100644 --- a/tests/test_plugin.py +++ b/tests/test_plugin.py @@ -34,6 +34,7 @@ def test_plugin_attach_sanic_app_extension( monkeypatch.setattr( manager.hook, "attach_sanic_app_extensions", MagicMock(return_value=None) ) + monkeypatch.setattr("rasa_sdk.endpoint.Sanic.serve", MagicMock(return_value=None)) app_mock = MagicMock() # Create a MagicMock object to replace the create_app() method @@ -42,8 +43,8 @@ def test_plugin_attach_sanic_app_extension( # Set the create_app() method to return create_app_mock monkeypatch.setattr("rasa_sdk.endpoint.create_app", create_app_mock) - # Set the return value of app_mock.run() to None - app_mock.run.return_value = None + # Set the return value of app_mock.prepare() to None + app_mock.prepare.return_value = None with warnings.catch_warnings(): warnings.simplefilter("error") diff --git a/tests/tracing/instrumentation/action_fixtures/__init.py__ b/tests/tracing/instrumentation/action_fixtures/__init__.py similarity index 100% rename from tests/tracing/instrumentation/action_fixtures/__init.py__ rename to tests/tracing/instrumentation/action_fixtures/__init__.py diff --git a/tests/tracing/instrumentation/conftest.py b/tests/tracing/instrumentation/conftest.py index 5b7570f4b..0e2ea108d 100644 --- a/tests/tracing/instrumentation/conftest.py +++ b/tests/tracing/instrumentation/conftest.py @@ -5,10 +5,8 @@ from opentelemetry.sdk.trace.export import SimpleSpanProcessor from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter -from rasa_sdk.executor import ActionExecutor, CollectingDispatcher -from rasa_sdk.forms import ValidationAction, FormValidationAction -from rasa_sdk.types import ActionCall, DomainDict -from rasa_sdk import Tracker +from rasa_sdk.executor import ActionExecutor +from rasa_sdk.types import ActionCall @pytest.fixture(scope="session") @@ -52,65 +50,3 @@ def _create_api_response( events: List[Dict[Text, Any]], messages: List[Dict[Text, Any]] ) -> None: pass - - -class MockValidationAction(ValidationAction): - def __init__(self) -> None: - self.fail_if_undefined("run") - - def fail_if_undefined(self, method_name: Text) -> None: - if not ( - hasattr(self.__class__.__base__, method_name) - and callable(getattr(self.__class__.__base__, method_name)) - ): - pytest.fail( - f"method '{method_name}' not found in {self.__class__.__base__}. " - f"This likely means the method was renamed, which means the " - f"instrumentation needs to be adapted!" - ) - - async def run( - self, - dispatcher: "CollectingDispatcher", - tracker: "Tracker", - domain: "DomainDict", - ) -> None: - pass - - def name(self) -> Text: - return "mock_validation_action" - - async def _extract_validation_events( - self, - dispatcher: "CollectingDispatcher", - tracker: "Tracker", - domain: "DomainDict", - ) -> None: - return tracker.events - - -class MockFormValidationAction(FormValidationAction): - def __init__(self) -> None: - self.fail_if_undefined("run") - - def fail_if_undefined(self, method_name: Text) -> None: - if not ( - hasattr(self.__class__.__base__, method_name) - and callable(getattr(self.__class__.__base__, method_name)) - ): - pytest.fail( - f"method '{method_name}' not found in {self.__class__.__base__}. " - f"This likely means the method was renamed, which means the " - f"instrumentation needs to be adapted!" - ) - - async def _extract_validation_events( - self, - dispatcher: "CollectingDispatcher", - tracker: "Tracker", - domain: "DomainDict", - ) -> None: - return tracker.events - - def name(self) -> Text: - return "mock_form_validation_action" diff --git a/tests/tracing/instrumentation/test_form_validation_action.py b/tests/tracing/instrumentation/test_form_validation_action.py index 66374c6e3..37c849f0d 100644 --- a/tests/tracing/instrumentation/test_form_validation_action.py +++ b/tests/tracing/instrumentation/test_form_validation_action.py @@ -5,7 +5,7 @@ from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter from rasa_sdk.tracing.instrumentation import instrumentation -from tests.tracing.instrumentation.conftest import MockFormValidationAction +from tests.conftest import MockFormValidationAction from rasa_sdk import Tracker from rasa_sdk.executor import CollectingDispatcher from rasa_sdk.events import ActionExecuted, SlotSet diff --git a/tests/tracing/instrumentation/test_tracing.py b/tests/tracing/instrumentation/test_tracing.py index e1af9b46d..f1c7a411e 100644 --- a/tests/tracing/instrumentation/test_tracing.py +++ b/tests/tracing/instrumentation/test_tracing.py @@ -1,10 +1,13 @@ import json +from functools import partial + import pytest import rasa_sdk.endpoint as ep from typing import Sequence from opentelemetry.sdk.trace import TracerProvider +from pytest import MonkeyPatch from opentelemetry.sdk.trace import ReadableSpan from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter @@ -35,11 +38,20 @@ def test_server_webhook_custom_action_is_instrumented( previous_num_captured_spans: int, action_name: str, action_package: str, + monkeypatch: MonkeyPatch, ) -> None: """Tests that the custom action is instrumented.""" - + monkeypatch.setattr( + "rasa_sdk.endpoint.get_tracer_provider", lambda _: tracer_provider + ) data["next_action"] = action_name - app = ep.create_app(action_package, tracer_provider=tracer_provider) + app = ep.create_app(action_package) + + app.register_listener( + partial(ep.load_tracer_provider, "endpoints.yml"), + "before_server_start", + ) + _, response = app.test_client.post("/webhook", data=json.dumps(data)) assert response.status == 200 diff --git a/tests/tracing/instrumentation/test_validation_action.py b/tests/tracing/instrumentation/test_validation_action.py index 860d5032d..6e11c1e68 100644 --- a/tests/tracing/instrumentation/test_validation_action.py +++ b/tests/tracing/instrumentation/test_validation_action.py @@ -5,7 +5,7 @@ from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter from rasa_sdk.tracing.instrumentation import instrumentation -from tests.tracing.instrumentation.conftest import ( +from tests.conftest import ( MockValidationAction, ) from rasa_sdk import Tracker diff --git a/tests/tracing/test_utils.py b/tests/tracing/test_utils.py index f4ebe5729..00683c040 100644 --- a/tests/tracing/test_utils.py +++ b/tests/tracing/test_utils.py @@ -20,10 +20,11 @@ def test_get_tracer_provider_returns_none_if_no_endpoints_file() -> None: """Tests that get_tracer_provider returns None if no endpoints file is provided.""" parser = argparse.ArgumentParser() parser.add_argument("--random_args", type=str, default=None) + parser.add_argument("--endpoints", default="endpoints.yml") args = parser.parse_args(["--random_args", "random text"]) - tracer_provider = get_tracer_provider(args) + tracer_provider = get_tracer_provider(args.endpoints) assert tracer_provider is None @@ -36,7 +37,7 @@ def test_get_tracer_provider_returns_none_if_tracing_is_not_configured() -> None endpoints_file = str(TRACING_TESTS_FIXTURES_DIRECTORY / "no_tracing.yml") args = parser.parse_args(["--endpoints", endpoints_file]) - tracer_provider = get_tracer_provider(args) + tracer_provider = get_tracer_provider(args.endpoints) assert tracer_provider is None @@ -50,7 +51,7 @@ def test_get_tracer_provider_returns_provider() -> None: endpoints_file = str(TRACING_TESTS_FIXTURES_DIRECTORY / "jaeger_endpoints.yml") args = parser.parse_args(["--endpoints", endpoints_file]) - tracer_provider = get_tracer_provider(args) + tracer_provider = get_tracer_provider(args.endpoints) assert tracer_provider is not None assert isinstance(tracer_provider, TracerProvider)