From ed053935b33a44f9bf6f650a6fb4e1592444ee7b Mon Sep 17 00:00:00 2001 From: Radovan Zivkovic <r.zivkovic@rasa.com> Date: Mon, 17 Jun 2024 14:21:51 +0200 Subject: [PATCH] Use partial import form utils --- rasa_sdk/grpc_server.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/rasa_sdk/grpc_server.py b/rasa_sdk/grpc_server.py index 5fce7dc8..27796fe7 100644 --- a/rasa_sdk/grpc_server.py +++ b/rasa_sdk/grpc_server.py @@ -5,7 +5,6 @@ import asyncio import grpc -import utils import logging import ssl import types @@ -36,11 +35,14 @@ get_tracer_and_context, TracerProvider, ) +from rasa_sdk.utils import check_version_compatibility, number_of_sanic_workers logger = logging.getLogger(__name__) -class ActionServerWebhook(action_webhook_pb2_grpc.ActionServerWebhookServicer): +class GRPCActionServerWebhook(action_webhook_pb2_grpc.ActionServerWebhookServicer): + """Runs webhook RPC which is served through gRPC server.""" + def __init__( self, executor: ActionExecutor, @@ -65,13 +67,15 @@ async def webhook( Args: request: The webhook request. context: The context of the request. + + Returns: + gRPC response. """ - await asyncio.sleep(50) tracer, tracer_context, span_name = get_tracer_and_context( self.tracer_provider, request ) with tracer.start_as_current_span(span_name, context=tracer_context): - utils.check_version_compatibility(request.version) + check_version_compatibility(request.version) try: action_call = MessageToDict(request, preserving_proto_field_name=True) result = await self.executor.run(action_call) @@ -121,6 +125,7 @@ def initialise_interrupts(server: grpc.aio.Server) -> None: """Initialise handlers for kernel signal interrupts.""" async def handle_sigint(signal_received: int): + """Handle the received signal.""" logger.info( f"Received {get_signal_name(signal_received)} signal." "Stopping gRPC server..." @@ -155,7 +160,7 @@ async def run_grpc( ssl_password: Password for the SSL key file. endpoints: Path to the endpoints file. """ - workers = utils.number_of_sanic_workers() + workers = number_of_sanic_workers() server = aio.server(futures.ThreadPoolExecutor(max_workers=workers)) initialise_interrupts(server) executor = ActionExecutor() @@ -163,7 +168,7 @@ async def run_grpc( # tracer_provider = get_tracer_provider(endpoints) tracer_provider = None action_webhook_pb2_grpc.add_ActionServerWebhookServicer_to_server( - ActionServerWebhook(executor, tracer_provider), server + GRPCActionServerWebhook(executor, tracer_provider), server ) if ssl_certificate and ssl_keyfile: # Use SSL/TLS if certificate and key are provided