diff --git a/Makefile b/Makefile index 2f2c72679..173d03dda 100644 --- a/Makefile +++ b/Makefile @@ -20,7 +20,6 @@ install: poetry run python -m pip install -U pip poetry install - clean: find . -name '*.pyc' -exec rm -f {} + find . -name '*.pyo' -exec rm -f {} + @@ -62,3 +61,13 @@ cleanup-generated-changelog: release: poetry run python scripts/release.py +download-protoc-compiler: + curl -0L https://github.com/protocolbuffers/protobuf/releases/download/v25.0/protoc-25.0-osx-aarch_64.zip --output protoc-25.0-osx-aarch_64.zip + +generate-grpc: + python -m grpc_tools.protoc \ + -Irasa_sdk/grpc_py=./proto \ + --python_out=. \ + --grpc_python_out=. \ + --pyi_out=. \ + proto/action_webhook.proto \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index 557bc709d..e603df7b7 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1564,6 +1564,17 @@ tomli = {version = "*", markers = "python_version < \"3.11\""} [package.extras] dev = ["furo", "packaging", "sphinx (>=5)", "twisted"] +[[package]] +name = "types-protobuf" +version = "4.25.0.20240417" +description = "Typing stubs for protobuf" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types-protobuf-4.25.0.20240417.tar.gz", hash = "sha256:c34eff17b9b3a0adb6830622f0f302484e4c089f533a46e3f147568313544352"}, + {file = "types_protobuf-4.25.0.20240417-py3-none-any.whl", hash = "sha256:e9b613227c2127e3d4881d75d93c93b4d6fd97b5f6a099a0b654a05351c8685d"}, +] + [[package]] name = "typing-extensions" version = "4.9.0" @@ -1900,4 +1911,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.8,<3.11" -content-hash = "154b5b360e5a822ebeacf1cc569bfdb45c9c2e318a4ff69031210117abb7c47a" +content-hash = "6fab19810735e6ca4cd28bc038f641e64545379904ba604b9f405088d0e2bfbe" diff --git a/pyproject.toml b/pyproject.toml index 7d4105334..6dfc4372b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -114,3 +114,4 @@ asyncio_mode = "auto" [tool.poetry.group.dev.dependencies] ruff = ">=0.0.256,<0.0.286" pytest-asyncio = "^0.21.0" +types-protobuf = "4.25.0.20240417" diff --git a/rasa_sdk/__main__.py b/rasa_sdk/__main__.py index 67b9fb497..dcd6e70a0 100644 --- a/rasa_sdk/__main__.py +++ b/rasa_sdk/__main__.py @@ -1,8 +1,24 @@ import logging +import asyncio +import signal from rasa_sdk import utils from rasa_sdk.endpoint import create_argument_parser, run from rasa_sdk.constants import APPLICATION_ROOT_LOGGER_NAME +from rasa_sdk.grpc_server import run_grpc + +logger = logging.getLogger(__name__) + + +def initialise_interrupts() -> None: + """Initialise handlers for kernel signal interrupts.""" + + def handle_sigint(signum, frame): + logger.info("Received SIGINT, exiting") + asyncio.get_event_loop().stop() + + signal.signal(signal.SIGINT, handle_sigint) + signal.signal(signal.SIGTERM, handle_sigint) def main_from_args(args): @@ -18,16 +34,30 @@ def main_from_args(args): ) utils.update_sanic_log_level() - run( - args.actions, - args.port, - args.cors, - args.ssl_certificate, - args.ssl_keyfile, - args.ssl_password, - args.auto_reload, - args.endpoints, - ) + initialise_interrupts() + + if args.grpc: + asyncio.run( + run_grpc( + args.actions, + args.port, + args.ssl_certificate, + args.ssl_keyfile, + args.ssl_password, + args.endpoints, + ) + ) + else: + run( + args.actions, + args.port, + args.cors, + args.ssl_certificate, + args.ssl_keyfile, + args.ssl_password, + args.auto_reload, + args.endpoints, + ) def main(): diff --git a/rasa_sdk/cli/arguments.py b/rasa_sdk/cli/arguments.py index c520d72af..5a805ef8d 100644 --- a/rasa_sdk/cli/arguments.py +++ b/rasa_sdk/cli/arguments.py @@ -59,3 +59,8 @@ def add_endpoint_arguments(parser): default=DEFAULT_ENDPOINTS_PATH, help="Configuration file for the assistant as a yml file.", ) + parser.add_argument( + "--grpc", + help="Starts grpc server instead of http", + action="store_true" + ) diff --git a/rasa_sdk/grpc_py/__init__.py b/rasa_sdk/grpc_py/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/rasa_sdk/grpc_py/action_webhook_pb2.py b/rasa_sdk/grpc_py/action_webhook_pb2.py new file mode 100644 index 000000000..9bb03ad36 --- /dev/null +++ b/rasa_sdk/grpc_py/action_webhook_pb2.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: rasa_sdk/grpc_py/action_webhook.proto +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n%rasa_sdk/grpc_py/action_webhook.proto\x12\x15\x61\x63tion_server_webhook\x1a\x19google/protobuf/any.proto\"\xf3\x04\n\x07Tracker\x12\x17\n\x0f\x63onversation_id\x18\x01 \x01(\t\x12\x38\n\x05slots\x18\x02 \x03(\x0b\x32).action_server_webhook.Tracker.SlotsEntry\x12I\n\x0elatest_message\x18\x03 \x03(\x0b\x32\x31.action_server_webhook.Tracker.LatestMessageEntry\x12\x19\n\x11latest_event_time\x18\x04 \x01(\x01\x12\x17\n\x0f\x66ollowup_action\x18\x05 \x01(\t\x12\x0e\n\x06paused\x18\x06 \x01(\x08\x12\x0e\n\x06\x65vents\x18\x07 \x03(\t\x12\x1c\n\x14latest_input_channel\x18\x08 \x01(\t\x12\x43\n\x0b\x61\x63tive_loop\x18\t \x03(\x0b\x32..action_server_webhook.Tracker.ActiveLoopEntry\x12G\n\rlatest_action\x18\n \x03(\x0b\x32\x30.action_server_webhook.Tracker.LatestActionEntry\x1a,\n\nSlotsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x1a\x34\n\x12LatestMessageEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x1a\x31\n\x0f\x41\x63tiveLoopEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x1a\x33\n\x11LatestActionEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xfc\x04\n\x06\x44omain\x12\x39\n\x06\x63onfig\x18\x01 \x03(\x0b\x32).action_server_webhook.Domain.ConfigEntry\x12H\n\x0esession_config\x18\x02 \x03(\x0b\x32\x30.action_server_webhook.Domain.SessionConfigEntry\x12\x0f\n\x07intents\x18\x03 \x03(\t\x12\x10\n\x08\x65ntities\x18\x04 \x03(\t\x12\x37\n\x05slots\x18\x05 \x03(\x0b\x32(.action_server_webhook.Domain.SlotsEntry\x12?\n\tresponses\x18\x06 \x03(\x0b\x32,.action_server_webhook.Domain.ResponsesEntry\x12\x0f\n\x07\x61\x63tions\x18\x07 \x03(\t\x12\x37\n\x05\x66orms\x18\x08 \x03(\x0b\x32(.action_server_webhook.Domain.FormsEntry\x12\x13\n\x0b\x65\x32\x65_actions\x18\t \x03(\t\x1a-\n\x0b\x43onfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x1a\x34\n\x12SessionConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x1a,\n\nSlotsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x1a\x30\n\x0eResponsesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x1a,\n\nFormsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xa9\x01\n\x0eWebhookRequest\x12\x13\n\x0bnext_action\x18\x01 \x01(\t\x12\x11\n\tsender_id\x18\x02 \x01(\t\x12/\n\x07tracker\x18\x03 \x01(\x0b\x32\x1e.action_server_webhook.Tracker\x12-\n\x06\x64omain\x18\x04 \x01(\x0b\x32\x1d.action_server_webhook.Domain\x12\x0f\n\x07version\x18\x05 \x01(\t\"`\n\x0fWebhookResponse\x12$\n\x06\x65vents\x18\x01 \x03(\x0b\x32\x14.google.protobuf.Any\x12\'\n\tresponses\x18\x02 \x03(\x0b\x32\x14.google.protobuf.Any2o\n\x13\x41\x63tionServerWebhook\x12X\n\x07webhook\x12%.action_server_webhook.WebhookRequest\x1a&.action_server_webhook.WebhookResponseb\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'rasa_sdk.grpc_py.action_webhook_pb2', _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + _TRACKER_SLOTSENTRY._options = None + _TRACKER_SLOTSENTRY._serialized_options = b'8\001' + _TRACKER_LATESTMESSAGEENTRY._options = None + _TRACKER_LATESTMESSAGEENTRY._serialized_options = b'8\001' + _TRACKER_ACTIVELOOPENTRY._options = None + _TRACKER_ACTIVELOOPENTRY._serialized_options = b'8\001' + _TRACKER_LATESTACTIONENTRY._options = None + _TRACKER_LATESTACTIONENTRY._serialized_options = b'8\001' + _DOMAIN_CONFIGENTRY._options = None + _DOMAIN_CONFIGENTRY._serialized_options = b'8\001' + _DOMAIN_SESSIONCONFIGENTRY._options = None + _DOMAIN_SESSIONCONFIGENTRY._serialized_options = b'8\001' + _DOMAIN_SLOTSENTRY._options = None + _DOMAIN_SLOTSENTRY._serialized_options = b'8\001' + _DOMAIN_RESPONSESENTRY._options = None + _DOMAIN_RESPONSESENTRY._serialized_options = b'8\001' + _DOMAIN_FORMSENTRY._options = None + _DOMAIN_FORMSENTRY._serialized_options = b'8\001' + _globals['_TRACKER']._serialized_start=92 + _globals['_TRACKER']._serialized_end=719 + _globals['_TRACKER_SLOTSENTRY']._serialized_start=517 + _globals['_TRACKER_SLOTSENTRY']._serialized_end=561 + _globals['_TRACKER_LATESTMESSAGEENTRY']._serialized_start=563 + _globals['_TRACKER_LATESTMESSAGEENTRY']._serialized_end=615 + _globals['_TRACKER_ACTIVELOOPENTRY']._serialized_start=617 + _globals['_TRACKER_ACTIVELOOPENTRY']._serialized_end=666 + _globals['_TRACKER_LATESTACTIONENTRY']._serialized_start=668 + _globals['_TRACKER_LATESTACTIONENTRY']._serialized_end=719 + _globals['_DOMAIN']._serialized_start=722 + _globals['_DOMAIN']._serialized_end=1358 + _globals['_DOMAIN_CONFIGENTRY']._serialized_start=1117 + _globals['_DOMAIN_CONFIGENTRY']._serialized_end=1162 + _globals['_DOMAIN_SESSIONCONFIGENTRY']._serialized_start=1164 + _globals['_DOMAIN_SESSIONCONFIGENTRY']._serialized_end=1216 + _globals['_DOMAIN_SLOTSENTRY']._serialized_start=517 + _globals['_DOMAIN_SLOTSENTRY']._serialized_end=561 + _globals['_DOMAIN_RESPONSESENTRY']._serialized_start=1264 + _globals['_DOMAIN_RESPONSESENTRY']._serialized_end=1312 + _globals['_DOMAIN_FORMSENTRY']._serialized_start=1314 + _globals['_DOMAIN_FORMSENTRY']._serialized_end=1358 + _globals['_WEBHOOKREQUEST']._serialized_start=1361 + _globals['_WEBHOOKREQUEST']._serialized_end=1530 + _globals['_WEBHOOKRESPONSE']._serialized_start=1532 + _globals['_WEBHOOKRESPONSE']._serialized_end=1628 + _globals['_ACTIONSERVERWEBHOOK']._serialized_start=1630 + _globals['_ACTIONSERVERWEBHOOK']._serialized_end=1741 +# @@protoc_insertion_point(module_scope) diff --git a/rasa_sdk/grpc_py/action_webhook_pb2.pyi b/rasa_sdk/grpc_py/action_webhook_pb2.pyi new file mode 100644 index 000000000..23caab556 --- /dev/null +++ b/rasa_sdk/grpc_py/action_webhook_pb2.pyi @@ -0,0 +1,138 @@ +from google.protobuf import any_pb2 as _any_pb2 +from google.protobuf.internal import containers as _containers +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union + +DESCRIPTOR: _descriptor.FileDescriptor + +class Tracker(_message.Message): + __slots__ = ["conversation_id", "slots", "latest_message", "latest_event_time", "followup_action", "paused", "events", "latest_input_channel", "active_loop", "latest_action"] + class SlotsEntry(_message.Message): + __slots__ = ["key", "value"] + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + key: str + value: str + def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... + class LatestMessageEntry(_message.Message): + __slots__ = ["key", "value"] + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + key: str + value: str + def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... + class ActiveLoopEntry(_message.Message): + __slots__ = ["key", "value"] + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + key: str + value: str + def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... + class LatestActionEntry(_message.Message): + __slots__ = ["key", "value"] + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + key: str + value: str + def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... + CONVERSATION_ID_FIELD_NUMBER: _ClassVar[int] + SLOTS_FIELD_NUMBER: _ClassVar[int] + LATEST_MESSAGE_FIELD_NUMBER: _ClassVar[int] + LATEST_EVENT_TIME_FIELD_NUMBER: _ClassVar[int] + FOLLOWUP_ACTION_FIELD_NUMBER: _ClassVar[int] + PAUSED_FIELD_NUMBER: _ClassVar[int] + EVENTS_FIELD_NUMBER: _ClassVar[int] + LATEST_INPUT_CHANNEL_FIELD_NUMBER: _ClassVar[int] + ACTIVE_LOOP_FIELD_NUMBER: _ClassVar[int] + LATEST_ACTION_FIELD_NUMBER: _ClassVar[int] + conversation_id: str + slots: _containers.ScalarMap[str, str] + latest_message: _containers.ScalarMap[str, str] + latest_event_time: float + followup_action: str + paused: bool + events: _containers.RepeatedScalarFieldContainer[str] + latest_input_channel: str + active_loop: _containers.ScalarMap[str, str] + latest_action: _containers.ScalarMap[str, str] + def __init__(self, conversation_id: _Optional[str] = ..., slots: _Optional[_Mapping[str, str]] = ..., latest_message: _Optional[_Mapping[str, str]] = ..., latest_event_time: _Optional[float] = ..., followup_action: _Optional[str] = ..., paused: bool = ..., events: _Optional[_Iterable[str]] = ..., latest_input_channel: _Optional[str] = ..., active_loop: _Optional[_Mapping[str, str]] = ..., latest_action: _Optional[_Mapping[str, str]] = ...) -> None: ... + +class Domain(_message.Message): + __slots__ = ["config", "session_config", "intents", "entities", "slots", "responses", "actions", "forms", "e2e_actions"] + class ConfigEntry(_message.Message): + __slots__ = ["key", "value"] + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + key: str + value: str + def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... + class SessionConfigEntry(_message.Message): + __slots__ = ["key", "value"] + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + key: str + value: str + def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... + class SlotsEntry(_message.Message): + __slots__ = ["key", "value"] + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + key: str + value: str + def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... + class ResponsesEntry(_message.Message): + __slots__ = ["key", "value"] + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + key: str + value: str + def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... + class FormsEntry(_message.Message): + __slots__ = ["key", "value"] + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + key: str + value: str + def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... + CONFIG_FIELD_NUMBER: _ClassVar[int] + SESSION_CONFIG_FIELD_NUMBER: _ClassVar[int] + INTENTS_FIELD_NUMBER: _ClassVar[int] + ENTITIES_FIELD_NUMBER: _ClassVar[int] + SLOTS_FIELD_NUMBER: _ClassVar[int] + RESPONSES_FIELD_NUMBER: _ClassVar[int] + ACTIONS_FIELD_NUMBER: _ClassVar[int] + FORMS_FIELD_NUMBER: _ClassVar[int] + E2E_ACTIONS_FIELD_NUMBER: _ClassVar[int] + config: _containers.ScalarMap[str, str] + session_config: _containers.ScalarMap[str, str] + intents: _containers.RepeatedScalarFieldContainer[str] + entities: _containers.RepeatedScalarFieldContainer[str] + slots: _containers.ScalarMap[str, str] + responses: _containers.ScalarMap[str, str] + actions: _containers.RepeatedScalarFieldContainer[str] + forms: _containers.ScalarMap[str, str] + e2e_actions: _containers.RepeatedScalarFieldContainer[str] + def __init__(self, config: _Optional[_Mapping[str, str]] = ..., session_config: _Optional[_Mapping[str, str]] = ..., intents: _Optional[_Iterable[str]] = ..., entities: _Optional[_Iterable[str]] = ..., slots: _Optional[_Mapping[str, str]] = ..., responses: _Optional[_Mapping[str, str]] = ..., actions: _Optional[_Iterable[str]] = ..., forms: _Optional[_Mapping[str, str]] = ..., e2e_actions: _Optional[_Iterable[str]] = ...) -> None: ... + +class WebhookRequest(_message.Message): + __slots__ = ["next_action", "sender_id", "tracker", "domain", "version"] + NEXT_ACTION_FIELD_NUMBER: _ClassVar[int] + SENDER_ID_FIELD_NUMBER: _ClassVar[int] + TRACKER_FIELD_NUMBER: _ClassVar[int] + DOMAIN_FIELD_NUMBER: _ClassVar[int] + VERSION_FIELD_NUMBER: _ClassVar[int] + next_action: str + sender_id: str + tracker: Tracker + domain: Domain + version: str + def __init__(self, next_action: _Optional[str] = ..., sender_id: _Optional[str] = ..., tracker: _Optional[_Union[Tracker, _Mapping]] = ..., domain: _Optional[_Union[Domain, _Mapping]] = ..., version: _Optional[str] = ...) -> None: ... + +class WebhookResponse(_message.Message): + __slots__ = ["events", "responses"] + EVENTS_FIELD_NUMBER: _ClassVar[int] + RESPONSES_FIELD_NUMBER: _ClassVar[int] + events: _containers.RepeatedCompositeFieldContainer[_any_pb2.Any] + responses: _containers.RepeatedCompositeFieldContainer[_any_pb2.Any] + def __init__(self, events: _Optional[_Iterable[_Union[_any_pb2.Any, _Mapping]]] = ..., responses: _Optional[_Iterable[_Union[_any_pb2.Any, _Mapping]]] = ...) -> None: ... diff --git a/rasa_sdk/grpc_py/action_webhook_pb2_grpc.py b/rasa_sdk/grpc_py/action_webhook_pb2_grpc.py new file mode 100644 index 000000000..215a4062a --- /dev/null +++ b/rasa_sdk/grpc_py/action_webhook_pb2_grpc.py @@ -0,0 +1,66 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +from rasa_sdk.grpc_py import action_webhook_pb2 as rasa__sdk_dot_grpc__py_dot_action__webhook__pb2 + + +class ActionServerWebhookStub(object): + """Missing associated documentation comment in .proto file.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.webhook = channel.unary_unary( + '/action_server_webhook.ActionServerWebhook/webhook', + request_serializer=rasa__sdk_dot_grpc__py_dot_action__webhook__pb2.WebhookRequest.SerializeToString, + response_deserializer=rasa__sdk_dot_grpc__py_dot_action__webhook__pb2.WebhookResponse.FromString, + ) + + +class ActionServerWebhookServicer(object): + """Missing associated documentation comment in .proto file.""" + + def webhook(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_ActionServerWebhookServicer_to_server(servicer, server): + rpc_method_handlers = { + 'webhook': grpc.unary_unary_rpc_method_handler( + servicer.webhook, + request_deserializer=rasa__sdk_dot_grpc__py_dot_action__webhook__pb2.WebhookRequest.FromString, + response_serializer=rasa__sdk_dot_grpc__py_dot_action__webhook__pb2.WebhookResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'action_server_webhook.ActionServerWebhook', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + + + # This class is part of an EXPERIMENTAL API. +class ActionServerWebhook(object): + """Missing associated documentation comment in .proto file.""" + + @staticmethod + def webhook(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/action_server_webhook.ActionServerWebhook/webhook', + rasa__sdk_dot_grpc__py_dot_action__webhook__pb2.WebhookRequest.SerializeToString, + rasa__sdk_dot_grpc__py_dot_action__webhook__pb2.WebhookResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/rasa_sdk/grpc_server.py b/rasa_sdk/grpc_server.py index e69de29bb..78b015de4 100644 --- a/rasa_sdk/grpc_server.py +++ b/rasa_sdk/grpc_server.py @@ -0,0 +1,110 @@ +import grpc +import utils +import logging +import ssl +import types +from typing import Text, Union, Optional +from concurrent import futures +from grpc import aio +from google.protobuf.json_format import MessageToDict, ParseDict + +from rasa_sdk.constants import DEFAULT_SERVER_PORT, DEFAULT_ENDPOINTS_PATH +from rasa_sdk.executor import ActionExecutor +from rasa_sdk.grpc_py import action_webhook_pb2, action_webhook_pb2_grpc +from rasa_sdk.grpc_py.action_webhook_pb2 import WebhookRequest +from rasa_sdk.interfaces import ActionExecutionRejection, ActionNotFoundException +from rasa_sdk.tracing.utils import ( + get_tracer_and_context, + TracerProvider, + get_tracer_provider, +) + +logger = logging.getLogger(__name__) + + +class ActionServerWebhook(action_webhook_pb2_grpc.ActionServerWebhookServicer): + def __init__( + self, + executor: ActionExecutor, + tracer_provider: Optional[TracerProvider] = None, + ) -> None: + """Initializes the ActionServerWebhook. + + Args: + tracer_provider: The tracer provider. + executor: The action executor. + """ + self.tracer_provider = tracer_provider + self.executor = executor + + async def webhook(self, request: WebhookRequest, context): + tracer, context, span_name = get_tracer_and_context( + self.tracer_provider, request + ) + with tracer.start_as_current_span(span_name, context=context) as span: + utils.check_version_compatibility(request.version) + try: + action_call = MessageToDict(request) + result = await self.executor.run(action_call) + except ActionExecutionRejection as e: + logger.debug(e) + body = {"error": e.message, "action_name": e.action_name} + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details(str(body)) + return action_webhook_pb2.WebhookResponse() + except ActionNotFoundException as e: + logger.error(e) + body = {"error": e.message, "action_name": e.action_name} + context.set_code(grpc.StatusCode.NOTFOUND) + context.set_details(str(body)) + return action_webhook_pb2.WebhookResponse() + if not result: + return action_webhook_pb2.WebhookResponse() + # set_span_attributes(span, request) + response = action_webhook_pb2.WebhookResponse() + + return ParseDict(result, response) + + +def get_ssl_password_callback(ssl_password): + def password_callback(*args, **kwargs): + return ssl_password.encode() if ssl_password else None + + return password_callback + + +async def run_grpc( + action_package_name: Union[Text, types.ModuleType], + port: int = DEFAULT_SERVER_PORT, + ssl_certificate: Optional[Text] = None, + ssl_keyfile: Optional[Text] = None, + ssl_password: Optional[Text] = None, + endpoints: str = DEFAULT_ENDPOINTS_PATH, +): + workers = utils.number_of_sanic_workers() + server = aio.server(futures.ThreadPoolExecutor(max_workers=workers)) + executor = ActionExecutor() + executor.register_package(action_package_name) + # tracer_provider = get_tracer_provider(endpoints) + tracer_provider = None + action_webhook_pb2_grpc.add_ActionServerWebhookServicer_to_server( + ActionServerWebhook(executor, tracer_provider), server + ) + if ssl_certificate and ssl_keyfile: + # Use SSL/TLS if certificate and key are provided + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + ssl_context.load_cert_chain( + ssl_certificate, + keyfile=ssl_keyfile, + password=get_ssl_password_callback(ssl_password), + ) + server.add_secure_port( + f"[::]:{port}", server_credentials=grpc.ssl_server_credentials(ssl_context) + ) + else: + # Use insecure connection if no SSL/TLS information is provided + server.add_insecure_port(f"[::]:{port}") + + await server.start() + print(f"gRPC Server started on port {port}") + await server.wait_for_termination() diff --git a/rasa_sdk/tracing/utils.py b/rasa_sdk/tracing/utils.py index 32b759b24..bbf56374c 100644 --- a/rasa_sdk/tracing/utils.py +++ b/rasa_sdk/tracing/utils.py @@ -1,3 +1,4 @@ +from rasa_sdk.grpc_py.action_webhook_pb2 import WebhookRequest from rasa_sdk.tracing import config from opentelemetry import trace from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator @@ -5,7 +6,7 @@ from opentelemetry.sdk.trace import TracerProvider from sanic.request import Request -from typing import Optional, Tuple, Any, Text +from typing import Optional, Tuple, Any, Text, Union def get_tracer_provider(endpoints_file: str) -> Optional[TracerProvider]: @@ -17,11 +18,12 @@ def get_tracer_provider(endpoints_file: str) -> Optional[TracerProvider]: def get_tracer_and_context( - tracer_provider: Optional[TracerProvider], request: Request + tracer_provider: Optional[TracerProvider], request: Union[Request, WebhookRequest] ) -> Tuple[Any, Any, Text]: """Gets tracer and context.""" span_name = "create_app.webhook" - if tracer_provider is None: + + if tracer_provider is None or isinstance(request, WebhookRequest): tracer = trace.get_tracer(span_name) context = None else: