Skip to content

Commit

Permalink
Remove SSL password provider function
Browse files Browse the repository at this point in the history
  • Loading branch information
radovanZRasa committed Jun 17, 2024
1 parent fc6bc9b commit 5f94cab
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 52 deletions.
1 change: 1 addition & 0 deletions rasa_sdk/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@
"https://docs.python.org/3/library/logging.config.html#dictionary-schema-details"
)
DEFAULT_ENDPOINTS_PATH = "endpoints.yml"
NO_GRACE_PERIOD = 0
5 changes: 4 additions & 1 deletion rasa_sdk/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,10 @@ def _create_api_response(
return {"events": events, "responses": messages}

@staticmethod
def validate_events(events: List[Dict[Text, Any]], action_name: Text) -> List[Dict[Text, Any]]:
def validate_events(
events: List[Dict[Text, Any]],
action_name: Text,
) -> List[Dict[Text, Any]]:
"""Validate the events returned by the action.
Args:
Expand Down
21 changes: 20 additions & 1 deletion rasa_sdk/grpc_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,33 @@
from enum import Enum


class ActionExecutionFailed(BaseModel):
"""Error which indicates that an action execution failed.
Attributes:
action_name: Name of the action that failed.
message: Message which describes the error.
"""

action_name: str = Field(alias="action_name")
message: str = Field(alias="message")


class ResourceNotFoundType(str, Enum):
"""Type of resource that was not found."""

ACTION = "ACTION"
DOMAIN = "DOMAIN"


class ResourceNotFound(BaseModel):
"""Error which indicates that a resource was not found."""
"""Error which indicates that a resource was not found.
Attributes:
action_name: Name of the action that was not found.
message: Message which describes the error.
"""

action_name: str = Field(alias="action_name")
message: str = Field(alias="message")
resource_type: ResourceNotFoundType = Field(alias="resource_type")
93 changes: 43 additions & 50 deletions rasa_sdk/grpc_server.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
from __future__ import annotations

import time

import sys

import signal

import asyncio
Expand All @@ -13,21 +9,32 @@
import logging
import ssl
import types
from typing import Union, Optional, Callable
from typing import Union, Optional
from concurrent import futures
from grpc import aio
from google.protobuf.json_format import MessageToDict, ParseDict

from rasa_sdk.constants import DEFAULT_SERVER_PORT, DEFAULT_ENDPOINTS_PATH
from rasa_sdk.constants import (
DEFAULT_SERVER_PORT,
DEFAULT_ENDPOINTS_PATH,
NO_GRACE_PERIOD,
)
from rasa_sdk.executor import ActionExecutor
from rasa_sdk.grpc_errors import ResourceNotFound, ResourceNotFoundType
from rasa_sdk.grpc_errors import (
ResourceNotFound,
ResourceNotFoundType,
ActionExecutionFailed,
)
from rasa_sdk.grpc_py import action_webhook_pb2, action_webhook_pb2_grpc
from rasa_sdk.grpc_py.action_webhook_pb2 import WebhookRequest
from rasa_sdk.interfaces import ActionExecutionRejection, ActionNotFoundException, ActionMissingDomainException
from rasa_sdk.interfaces import (
ActionExecutionRejection,
ActionNotFoundException,
ActionMissingDomainException,
)
from rasa_sdk.tracing.utils import (
get_tracer_and_context,
TracerProvider,
get_tracer_provider,
)

logger = logging.getLogger(__name__)
Expand All @@ -48,9 +55,11 @@ def __init__(
self.tracer_provider = tracer_provider
self.executor = executor

async def webhook(self,
request: WebhookRequest, context,
) -> action_webhook_pb2.WebhookResponse:
async def webhook(
self,
request: WebhookRequest,
context,
) -> action_webhook_pb2.WebhookResponse:
"""Handle RPC request for the webhook.
Args:
Expand All @@ -61,36 +70,37 @@ async def webhook(self,
tracer, tracer_context, span_name = get_tracer_and_context(
self.tracer_provider, request
)
with tracer.start_as_current_span(span_name, context=tracer_context) as span:
with tracer.start_as_current_span(span_name, context=tracer_context):
utils.check_version_compatibility(request.version)
try:
action_call = MessageToDict(request, preserving_proto_field_name=True)
result = await self.executor.run(action_call)
except ActionExecutionRejection as e:
logger.debug(e)
body = {"error": e.message, "action_name": e.action_name}

body = ActionExecutionFailed(
action_name=e.action_name, message=e.message
).model_dump()
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(str(body))
context.set_details(body)
return action_webhook_pb2.WebhookResponse()
except ActionNotFoundException as e:
logger.error(e)
resource_not_found = ResourceNotFound(
body = ResourceNotFound(
action_name=e.action_name,
message=e.message,
resource_type=ResourceNotFoundType.ACTION,
)
body = resource_not_found.json()
).model_dump()
context.set_code(grpc.StatusCode.NOT_FOUND)
context.set_details(body)
return action_webhook_pb2.WebhookResponse()
except ActionMissingDomainException as e:
logger.error(e)
resource_not_found = ResourceNotFound(
body = ResourceNotFound(
action_name=e.action_name,
message=e.message,
resource_type=ResourceNotFoundType.DOMAIN,
)
body = resource_not_found.json()
).model_dump()
context.set_code(grpc.StatusCode.NOT_FOUND)
context.set_details(body)
return action_webhook_pb2.WebhookResponse()
Expand All @@ -102,22 +112,6 @@ async def webhook(self,
return ParseDict(result, response)


def get_ssl_password_callback(ssl_password: str) -> Callable[[], bytes]:
"""Return a password callback function for the SSL key file.
Args:
ssl_password: Password for the SSL key file.
Returns:
A password callback function.
"""
def password_callback() -> bytes:
"""Return the SSL password as bytes."""
return ssl_password.encode() if ssl_password else None

return password_callback


def get_signal_name(signal_number: int) -> str:
"""Return the signal name for the given signal number."""
return signal.Signals(signal_number).name
Expand All @@ -128,16 +122,19 @@ def initialise_interrupts(server: grpc.aio.Server) -> None:

async def handle_sigint(signal_received: int):
logger.info(
f"Received {get_signal_name(signal_received)} signal. Stopping gRPC server..."
f"Received {get_signal_name(signal_received)} signal."
"Stopping gRPC server..."
)
await server.stop(0)
await server.stop(NO_GRACE_PERIOD)
logger.info("gRPC server stopped.")

loop = asyncio.get_event_loop()
loop.add_signal_handler(signal.SIGINT,
lambda: asyncio.create_task(handle_sigint(signal.SIGINT)))
loop.add_signal_handler(signal.SIGTERM,
lambda: asyncio.create_task(handle_sigint(signal.SIGTERM)))
loop.add_signal_handler(
signal.SIGINT, lambda: asyncio.create_task(handle_sigint(signal.SIGINT))
)
loop.add_signal_handler(
signal.SIGTERM, lambda: asyncio.create_task(handle_sigint(signal.SIGTERM))
)


async def run_grpc(
Expand Down Expand Up @@ -174,18 +171,14 @@ async def run_grpc(
ssl_context.load_cert_chain(
ssl_certificate,
keyfile=ssl_keyfile,
password=get_ssl_password_callback(ssl_password),
)
logger.info(
f"Starting gRPC server with SSL support on port {port}"
password=ssl_password if ssl_password else None,
)
logger.info(f"Starting gRPC server with SSL support on port {port}")
server.add_secure_port(
f"[::]:{port}", server_credentials=grpc.ssl_server_credentials(ssl_context)
)
else:
logger.info(
f"Starting gRPC server without SSL on port {port}"
)
logger.info(f"Starting gRPC server without SSL on port {port}")
# Use insecure connection if no SSL/TLS information is provided
server.add_insecure_port(f"[::]:{port}")

Expand Down

0 comments on commit 5f94cab

Please sign in to comment.