Skip to content

Commit

Permalink
Add support for client cert auth in gRPC
Browse files Browse the repository at this point in the history
  • Loading branch information
radovanZRasa committed Jun 18, 2024
1 parent 5410a90 commit a5ad38d
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 16 deletions.
4 changes: 2 additions & 2 deletions rasa_sdk/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def main_from_args(args):
args.port,
args.ssl_certificate,
args.ssl_keyfile,
args.ssl_password,
args.ssl_ca_cert_file,
args.endpoints,
)
)
Expand All @@ -47,7 +47,7 @@ def main_from_args(args):


def main():
# Running as standalone python application
"""Runs the action server as standalone application."""
arg_parser = create_argument_parser()
cmdline_args = arg_parser.parse_args()

Expand Down
32 changes: 27 additions & 5 deletions rasa_sdk/cli/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,31 @@
from rasa_sdk.constants import DEFAULT_SERVER_PORT, DEFAULT_ENDPOINTS_PATH


def action_arg(action):
if "/" in action:
def action_arg(actions_module_path: str) -> str:
"""Validate the action module path.
Valid action module path is python module, so it should not contain a slash.
Args:
actions_module_path: Path to the actions python module.
Returns:
actions_module_path: If provided module path is valid.
Raises:
argparse.ArgumentTypeError: If the module path is invalid.
"""
if "/" in actions_module_path:
raise argparse.ArgumentTypeError(
"Invalid actions format. Actions file should be a python module "
"and passed with module notation (e.g. directory.actions)."
)
else:
return action
return actions_module_path


def add_endpoint_arguments(parser):
def add_endpoint_arguments(parser: argparse.ArgumentParser) -> None:
"""Add all the arguments to the argument parser."""
parser.add_argument(
"-p",
"--port",
Expand Down Expand Up @@ -47,7 +61,15 @@ def add_endpoint_arguments(parser):
"--ssl-password",
default=None,
help="If your ssl-keyfile is protected by a password, you can specify it "
"using this paramer.",
"using this parameter. "
"Not supported in grpc mode.",
)
parser.add_argument(
"--ssl-ca-cert-file",
default=None,
help="If you want to authenticate the client using a certificate, you can "
"specify the CA certificate of the client using this parameter. "
"Supported only in grpc mode.",
)
parser.add_argument(
"--auto-reload",
Expand Down
16 changes: 7 additions & 9 deletions rasa_sdk/grpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import signal

import asyncio
import ssl

import grpc
import logging
Expand Down Expand Up @@ -147,7 +146,7 @@ async def run_grpc(
port: int = DEFAULT_SERVER_PORT,
ssl_certificate: Optional[str] = None,
ssl_keyfile: Optional[str] = None,
ssl_password: Optional[str] = None,
ssl_ca_cert_file: Optional[str] = None,
endpoints: str = DEFAULT_ENDPOINTS_PATH,
):
"""Start a gRPC server to handle incoming action requests.
Expand All @@ -157,7 +156,7 @@ async def run_grpc(
port: Port to start the server on.
ssl_certificate: File path to the SSL certificate.
ssl_keyfile: File path to the SSL key file.
ssl_password: Password for the SSL key file.
ssl_ca_cert_file: File path to the SSL CA certificate file.
endpoints: Path to the endpoints file.
"""
workers = number_of_sanic_workers()
Expand All @@ -170,14 +169,11 @@ async def run_grpc(
action_webhook_pb2_grpc.add_ActionServerWebhookServicer_to_server(
GRPCActionServerWebhook(executor, tracer_provider), server
)

ca_cert = open(ssl_ca_cert_file, 'rb').read() if ssl_ca_cert_file else None

if ssl_certificate and ssl_keyfile:
# Use SSL/TLS if certificate and key are provided
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
ssl_context.load_cert_chain(
ssl_certificate,
keyfile=ssl_keyfile,
password=ssl_password if ssl_password else None,
)
grpc.ssl_channel_credentials()
private_key = open(ssl_keyfile, "rb").read()
certificate_chain = open(ssl_certificate, "rb").read()
Expand All @@ -187,6 +183,8 @@ async def run_grpc(
server_credentials=grpc.ssl_server_credentials(
[(private_key, certificate_chain)]
),
root_certificates=ca_cert,
require_client_auth=True if ca_cert else False,
)
else:
logger.info(f"Starting gRPC server without SSL on port {port}")
Expand Down

0 comments on commit a5ad38d

Please sign in to comment.