diff --git a/rasa_sdk/__main__.py b/rasa_sdk/__main__.py index 7433b9e5..dc337130 100644 --- a/rasa_sdk/__main__.py +++ b/rasa_sdk/__main__.py @@ -29,7 +29,7 @@ def main_from_args(args): args.port, args.ssl_certificate, args.ssl_keyfile, - args.ssl_password, + args.ssl_ca_cert_file, args.endpoints, ) ) @@ -47,7 +47,7 @@ def main_from_args(args): def main(): - # Running as standalone python application + """Runs the action server as standalone application.""" arg_parser = create_argument_parser() cmdline_args = arg_parser.parse_args() diff --git a/rasa_sdk/cli/arguments.py b/rasa_sdk/cli/arguments.py index 5a805ef8..f595ac22 100644 --- a/rasa_sdk/cli/arguments.py +++ b/rasa_sdk/cli/arguments.py @@ -3,17 +3,31 @@ from rasa_sdk.constants import DEFAULT_SERVER_PORT, DEFAULT_ENDPOINTS_PATH -def action_arg(action): - if "/" in action: +def action_arg(actions_module_path: str) -> str: + """Validate the action module path. + + Valid action module path is python module, so it should not contain a slash. + + Args: + actions_module_path: Path to the actions python module. + + Returns: + actions_module_path: If provided module path is valid. + + Raises: + argparse.ArgumentTypeError: If the module path is invalid. + """ + if "/" in actions_module_path: raise argparse.ArgumentTypeError( "Invalid actions format. Actions file should be a python module " "and passed with module notation (e.g. directory.actions)." ) else: - return action + return actions_module_path -def add_endpoint_arguments(parser): +def add_endpoint_arguments(parser: argparse.ArgumentParser) -> None: + """Add all the arguments to the argument parser.""" parser.add_argument( "-p", "--port", @@ -47,7 +61,15 @@ def add_endpoint_arguments(parser): "--ssl-password", default=None, help="If your ssl-keyfile is protected by a password, you can specify it " - "using this paramer.", + "using this parameter. " + "Not supported in grpc mode.", + ) + parser.add_argument( + "--ssl-ca-cert-file", + default=None, + help="If you want to authenticate the client using a certificate, you can " + "specify the CA certificate of the client using this parameter. " + "Supported only in grpc mode.", ) parser.add_argument( "--auto-reload", diff --git a/rasa_sdk/grpc_server.py b/rasa_sdk/grpc_server.py index 1b4f98d1..8f7115c5 100644 --- a/rasa_sdk/grpc_server.py +++ b/rasa_sdk/grpc_server.py @@ -3,7 +3,6 @@ import signal import asyncio -import ssl import grpc import logging @@ -147,7 +146,7 @@ async def run_grpc( port: int = DEFAULT_SERVER_PORT, ssl_certificate: Optional[str] = None, ssl_keyfile: Optional[str] = None, - ssl_password: Optional[str] = None, + ssl_ca_cert_file: Optional[str] = None, endpoints: str = DEFAULT_ENDPOINTS_PATH, ): """Start a gRPC server to handle incoming action requests. @@ -157,7 +156,7 @@ async def run_grpc( port: Port to start the server on. ssl_certificate: File path to the SSL certificate. ssl_keyfile: File path to the SSL key file. - ssl_password: Password for the SSL key file. + ssl_ca_cert_file: File path to the SSL CA certificate file. endpoints: Path to the endpoints file. """ workers = number_of_sanic_workers() @@ -170,14 +169,11 @@ async def run_grpc( action_webhook_pb2_grpc.add_ActionServerWebhookServicer_to_server( GRPCActionServerWebhook(executor, tracer_provider), server ) + + ca_cert = open(ssl_ca_cert_file, 'rb').read() if ssl_ca_cert_file else None + if ssl_certificate and ssl_keyfile: # Use SSL/TLS if certificate and key are provided - ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) - ssl_context.load_cert_chain( - ssl_certificate, - keyfile=ssl_keyfile, - password=ssl_password if ssl_password else None, - ) grpc.ssl_channel_credentials() private_key = open(ssl_keyfile, "rb").read() certificate_chain = open(ssl_certificate, "rb").read() @@ -187,6 +183,8 @@ async def run_grpc( server_credentials=grpc.ssl_server_credentials( [(private_key, certificate_chain)] ), + root_certificates=ca_cert, + require_client_auth=True if ca_cert else False, ) else: logger.info(f"Starting gRPC server without SSL on port {port}")