From 737a5259923370698365687b3771c68e16821fd1 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Mon, 9 Sep 2024 11:19:01 +0200 Subject: [PATCH 1/3] add syncronous message handler --- packages/syft/src/syft/orchestra.py | 5 ++ packages/syft/src/syft/server/server.py | 14 +++-- packages/syft/src/syft/service/queue/queue.py | 16 ++++-- .../syft/src/syft/service/queue/zmq_client.py | 6 ++- .../src/syft/service/queue/zmq_consumer.py | 51 +++++++++++++++++++ 5 files changed, 84 insertions(+), 8 deletions(-) diff --git a/packages/syft/src/syft/orchestra.py b/packages/syft/src/syft/orchestra.py index 5921b53b434..6dcba7bf154 100644 --- a/packages/syft/src/syft/orchestra.py +++ b/packages/syft/src/syft/orchestra.py @@ -28,6 +28,7 @@ from .server.enclave import Enclave from .server.gateway import Gateway from .server.uvicorn import serve_server +from .service.queue.queue import Handler from .service.response import SyftInfo from .types.errors import SyftException from .util.util import get_random_available_port @@ -182,6 +183,7 @@ def deploy_to_python( log_level: str | int | None = None, debug: bool = False, migrate: bool = False, + handler_type: Handler | None = None, ) -> ServerHandle: worker_classes = { ServerType.DATASITE: Datasite, @@ -213,6 +215,7 @@ def deploy_to_python( "debug": debug, "migrate": migrate, "deployment_type": deployment_type_enum, + "handler_type": handler_type, } if port: @@ -325,6 +328,7 @@ def launch( debug: bool = False, migrate: bool = False, from_state_folder: str | Path | None = None, + handler_type: Handler | None = None, ) -> ServerHandle: if from_state_folder is not None: with open(f"{from_state_folder}/config.json") as f: @@ -373,6 +377,7 @@ def launch( background_tasks=background_tasks, debug=debug, migrate=migrate, + handler_type=handler_type, ) display( SyftInfo( diff --git a/packages/syft/src/syft/server/server.py b/packages/syft/src/syft/server/server.py index 102b34ad99a..c73016c4419 100644 --- a/packages/syft/src/syft/server/server.py +++ b/packages/syft/src/syft/server/server.py @@ -61,6 +61,7 @@ from ..service.queue.base_queue import QueueConsumer from ..service.queue.base_queue import QueueProducer from ..service.queue.queue import APICallMessageHandler +from ..service.queue.queue import Handler from ..service.queue.queue import QueueManager from ..service.queue.queue_stash import APIEndpointQueueItem from ..service.queue.queue_stash import ActionQueueItem @@ -338,6 +339,7 @@ def __init__( smtp_host: str | None = None, association_request_auto_approval: bool = False, background_tasks: bool = False, + handler_type: Handler | None = None, ): # 🟡 TODO 22: change our ENV variable format and default init args to make this # less horrible or add some convenience functions @@ -381,10 +383,13 @@ def __init__( self.association_request_auto_approval = association_request_auto_approval + handler_type = ( + handler_type or Handler.Thread if thread_workers else Handler.Process + ) self.queue_config = self.create_queue_config( n_consumers=n_consumers, create_producer=create_producer, - thread_workers=thread_workers, + handler_type=handler_type, queue_port=queue_port, queue_config=queue_config, ) @@ -578,7 +583,7 @@ def create_queue_config( self, n_consumers: int, create_producer: bool, - thread_workers: bool, + handler_type: Handler, queue_port: int | None, queue_config: QueueConfig | None, ) -> QueueConfig: @@ -587,13 +592,14 @@ def create_queue_config( elif queue_port is not None or n_consumers > 0 or create_producer: if not create_producer and queue_port is 
None: logger.warn("No queue port defined to bind consumers.") + queue_config_ = ZMQQueueConfig( client_config=ZMQClientConfig( create_producer=create_producer, queue_port=queue_port, n_consumers=n_consumers, ), - thread_workers=thread_workers, + handler_type=handler_type, ) else: queue_config_ = ZMQQueueConfig() @@ -727,6 +733,7 @@ def named( in_memory_workers: bool = True, association_request_auto_approval: bool = False, background_tasks: bool = False, + handler_type: Handler | None = None, ) -> Server: uid = get_named_server_uid(name) name_hash = hashlib.sha256(name.encode("utf8")).digest() @@ -757,6 +764,7 @@ def named( reset=reset, association_request_auto_approval=association_request_auto_approval, background_tasks=background_tasks, + handler_type=handler_type, ) def is_root(self, credentials: SyftVerifyKey) -> bool: diff --git a/packages/syft/src/syft/service/queue/queue.py b/packages/syft/src/syft/service/queue/queue.py index da1ded8bd70..59cfbda374a 100644 --- a/packages/syft/src/syft/service/queue/queue.py +++ b/packages/syft/src/syft/service/queue/queue.py @@ -1,4 +1,5 @@ # stdlib +from enum import Enum import logging from multiprocessing import Process import threading @@ -35,6 +36,13 @@ logger = logging.getLogger(__name__) +@serializable(canonical_name="WorkerType", version=1) +class Handler(str, Enum): + Thread = "thread" + Process = "process" + Synchronous = "synchronous" + + class MonitorThread(threading.Thread): def __init__( self, @@ -300,17 +308,17 @@ def handle_message(message: bytes, syft_worker_id: UID) -> None: logger.info( f"Handling queue item: id={queue_item.id}, method={queue_item.method} " f"args={queue_item.args}, kwargs={queue_item.kwargs} " - f"service={queue_item.service}, as_thread={queue_config.thread_workers}" + f"service={queue_item.service}, as={queue_config.handler_type}" ) - if queue_config.thread_workers: + if queue_config.handler_type == Handler.Thread: thread = Thread( target=handle_message_multiprocessing, args=(worker_settings, queue_item, credentials), ) thread.start() thread.join() - else: + elif queue_config.handler_type == Handler.Process: # if psutil.pid_exists(job_item.job_pid): # psutil.Process(job_item.job_pid).terminate() process = Process( @@ -321,3 +329,5 @@ def handle_message(message: bytes, syft_worker_id: UID) -> None: job_item.job_pid = process.pid worker.job_stash.set_result(credentials, job_item).unwrap() process.join() + elif queue_config.handler_type == Handler.Synchronous: + handle_message_multiprocessing(worker_settings, queue_item, credentials) diff --git a/packages/syft/src/syft/service/queue/zmq_client.py b/packages/syft/src/syft/service/queue/zmq_client.py index deeeb97a32b..ab74704795e 100644 --- a/packages/syft/src/syft/service/queue/zmq_client.py +++ b/packages/syft/src/syft/service/queue/zmq_client.py @@ -16,6 +16,7 @@ from .base_queue import QueueClient from .base_queue import QueueClientConfig from .base_queue import QueueConfig +from .queue import Handler from .queue_stash import QueueStash from .zmq_consumer import ZMQConsumer from .zmq_producer import ZMQProducer @@ -76,6 +77,7 @@ def add_producer( else: port = self.config.queue_port + print(f"Adding producer for queue: {queue_name} on: {get_queue_address(port)}") producer = ZMQProducer( queue_name=queue_name, queue_stash=queue_stash, @@ -183,8 +185,8 @@ def __init__( self, client_type: type[ZMQClient] | None = None, client_config: ZMQClientConfig | None = None, - thread_workers: bool = False, + handler_type: Handler = Handler.Process, ): self.client_type = 
client_type or ZMQClient self.client_config: ZMQClientConfig = client_config or ZMQClientConfig() - self.thread_workers = thread_workers + self.handler_type = handler_type diff --git a/packages/syft/src/syft/service/queue/zmq_consumer.py b/packages/syft/src/syft/service/queue/zmq_consumer.py index 4de8da60494..97b8d30c4bc 100644 --- a/packages/syft/src/syft/service/queue/zmq_consumer.py +++ b/packages/syft/src/syft/service/queue/zmq_consumer.py @@ -1,5 +1,6 @@ # stdlib import logging +import subprocess import threading from threading import Event @@ -28,6 +29,26 @@ logger = logging.getLogger(__name__) +def last_created_port() -> int: + command = ( + "lsof -i -P -n | grep '*:[0-9]* (LISTEN)' | grep python | awk '{print $9, $1, $2}' | " + "sort -k2,2 -k3,3n | tail -n 1 | awk '{print $1}' | cut -d':' -f2" + ) + # 1. Lists open files (including network connections) with lsof -i -P -n + # 2. Filters for listening ports with grep '*:[0-9]* (LISTEN)' + # 3. Further filters for Python processes with grep python + # 4. Sorts based on the 9th field (which is likely the port number) with sort -k9 + # 5. Takes the last 10 entries with tail -n 10 + # 6. Prints only the 9th field (port and address) with awk '{print $9}' + # 7. Extracts only the port number with cut -d':' -f2 + + process = subprocess.Popen( + command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True + ) + out, err = process.communicate() + return int(out.decode("utf-8").strip()) + + @serializable(attrs=["_subscriber"], canonical_name="ZMQConsumer", version=1) class ZMQConsumer(QueueConsumer): def __init__( @@ -54,6 +75,36 @@ def __init__( self.worker_stash = worker_stash self.post_init() + @classmethod + def default(cls, address: str | None = None, **kwargs: dict) -> "ZMQConsumer": + # relative + from ...types.uid import UID + from ..worker.utils import DEFAULT_WORKER_POOL_NAME + from .queue import APICallMessageHandler + + if address is None: + try: + address = f"tcp://localhost:{last_created_port()}" + except Exception: + raise Exception( + "Could not auto-assign ZMQConsumer address. Please provide one." + ) + print(f"Auto-assigning ZMQConsumer address: {address}. Please verify.") + default_kwargs = { + "message_handler": APICallMessageHandler, + "queue_name": APICallMessageHandler.queue_name, + "service_name": DEFAULT_WORKER_POOL_NAME, + "syft_worker_id": UID(), + "verbose": True, + "address": address, + } + + for key, value in kwargs.items(): + if key in default_kwargs: + default_kwargs[key] = value + + return cls(**default_kwargs) + def reconnect_to_producer(self) -> None: """Connect or reconnect to producer""" if self.socket: From 97c5ba11e306bc5a6cd4fe409453a8f9e362aece Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Mon, 9 Sep 2024 16:14:13 +0200 Subject: [PATCH 2/3] ignore sec warning --- packages/syft/src/syft/service/queue/zmq_consumer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/syft/src/syft/service/queue/zmq_consumer.py b/packages/syft/src/syft/service/queue/zmq_consumer.py index 97b8d30c4bc..f6993d6b032 100644 --- a/packages/syft/src/syft/service/queue/zmq_consumer.py +++ b/packages/syft/src/syft/service/queue/zmq_consumer.py @@ -1,6 +1,6 @@ # stdlib import logging -import subprocess +import subprocess # nosec import threading from threading import Event @@ -42,7 +42,7 @@ def last_created_port() -> int: # 6. Prints only the 9th field (port and address) with awk '{print $9}' # 7. 
Extracts only the port number with cut -d':' -f2 - process = subprocess.Popen( + process = subprocess.Popen( # nosec command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True ) out, err = process.communicate() From 3c2114a10e2237627052505f72482e51a7ba28bd Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Mon, 9 Sep 2024 16:36:42 +0200 Subject: [PATCH 3/3] rename ConsumerHandler --- packages/syft/src/syft/orchestra.py | 10 +++++----- packages/syft/src/syft/server/server.py | 20 ++++++++++--------- packages/syft/src/syft/service/queue/queue.py | 10 +++++----- .../syft/src/syft/service/queue/zmq_client.py | 6 +++--- 4 files changed, 24 insertions(+), 22 deletions(-) diff --git a/packages/syft/src/syft/orchestra.py b/packages/syft/src/syft/orchestra.py index 6dcba7bf154..0d295b81982 100644 --- a/packages/syft/src/syft/orchestra.py +++ b/packages/syft/src/syft/orchestra.py @@ -28,7 +28,7 @@ from .server.enclave import Enclave from .server.gateway import Gateway from .server.uvicorn import serve_server -from .service.queue.queue import Handler +from .service.queue.queue import ConsumerType from .service.response import SyftInfo from .types.errors import SyftException from .util.util import get_random_available_port @@ -183,7 +183,7 @@ def deploy_to_python( log_level: str | int | None = None, debug: bool = False, migrate: bool = False, - handler_type: Handler | None = None, + consumer_type: ConsumerType | None = None, ) -> ServerHandle: worker_classes = { ServerType.DATASITE: Datasite, @@ -215,7 +215,7 @@ def deploy_to_python( "debug": debug, "migrate": migrate, "deployment_type": deployment_type_enum, - "handler_type": handler_type, + "consumer_type": consumer_type, } if port: @@ -328,7 +328,7 @@ def launch( debug: bool = False, migrate: bool = False, from_state_folder: str | Path | None = None, - handler_type: Handler | None = None, + consumer_type: ConsumerType | None = None, ) -> ServerHandle: if from_state_folder is not None: with open(f"{from_state_folder}/config.json") as f: @@ -377,7 +377,7 @@ def launch( background_tasks=background_tasks, debug=debug, migrate=migrate, - handler_type=handler_type, + consumer_type=consumer_type, ) display( SyftInfo( diff --git a/packages/syft/src/syft/server/server.py b/packages/syft/src/syft/server/server.py index c73016c4419..f9a05ca1279 100644 --- a/packages/syft/src/syft/server/server.py +++ b/packages/syft/src/syft/server/server.py @@ -61,7 +61,7 @@ from ..service.queue.base_queue import QueueConsumer from ..service.queue.base_queue import QueueProducer from ..service.queue.queue import APICallMessageHandler -from ..service.queue.queue import Handler +from ..service.queue.queue import ConsumerType from ..service.queue.queue import QueueManager from ..service.queue.queue_stash import APIEndpointQueueItem from ..service.queue.queue_stash import ActionQueueItem @@ -339,7 +339,7 @@ def __init__( smtp_host: str | None = None, association_request_auto_approval: bool = False, background_tasks: bool = False, - handler_type: Handler | None = None, + consumer_type: ConsumerType | None = None, ): # 🟡 TODO 22: change our ENV variable format and default init args to make this # less horrible or add some convenience functions @@ -383,13 +383,15 @@ def __init__( self.association_request_auto_approval = association_request_auto_approval - handler_type = ( - handler_type or Handler.Thread if thread_workers else Handler.Process + consumer_type = ( + consumer_type or ConsumerType.Thread + if thread_workers + else ConsumerType.Process ) 
self.queue_config = self.create_queue_config( n_consumers=n_consumers, create_producer=create_producer, - handler_type=handler_type, + consumer_type=consumer_type, queue_port=queue_port, queue_config=queue_config, ) @@ -583,7 +585,7 @@ def create_queue_config( self, n_consumers: int, create_producer: bool, - handler_type: Handler, + consumer_type: ConsumerType, queue_port: int | None, queue_config: QueueConfig | None, ) -> QueueConfig: @@ -599,7 +601,7 @@ def create_queue_config( queue_port=queue_port, n_consumers=n_consumers, ), - handler_type=handler_type, + consumer_type=consumer_type, ) else: queue_config_ = ZMQQueueConfig() @@ -733,7 +735,7 @@ def named( in_memory_workers: bool = True, association_request_auto_approval: bool = False, background_tasks: bool = False, - handler_type: Handler | None = None, + consumer_type: ConsumerType | None = None, ) -> Server: uid = get_named_server_uid(name) name_hash = hashlib.sha256(name.encode("utf8")).digest() @@ -764,7 +766,7 @@ def named( reset=reset, association_request_auto_approval=association_request_auto_approval, background_tasks=background_tasks, - handler_type=handler_type, + consumer_type=consumer_type, ) def is_root(self, credentials: SyftVerifyKey) -> bool: diff --git a/packages/syft/src/syft/service/queue/queue.py b/packages/syft/src/syft/service/queue/queue.py index 59cfbda374a..aa2e99c6ba4 100644 --- a/packages/syft/src/syft/service/queue/queue.py +++ b/packages/syft/src/syft/service/queue/queue.py @@ -37,7 +37,7 @@ @serializable(canonical_name="WorkerType", version=1) -class Handler(str, Enum): +class ConsumerType(str, Enum): Thread = "thread" Process = "process" Synchronous = "synchronous" @@ -308,17 +308,17 @@ def handle_message(message: bytes, syft_worker_id: UID) -> None: logger.info( f"Handling queue item: id={queue_item.id}, method={queue_item.method} " f"args={queue_item.args}, kwargs={queue_item.kwargs} " - f"service={queue_item.service}, as={queue_config.handler_type}" + f"service={queue_item.service}, as={queue_config.consumer_type}" ) - if queue_config.handler_type == Handler.Thread: + if queue_config.consumer_type == ConsumerType.Thread: thread = Thread( target=handle_message_multiprocessing, args=(worker_settings, queue_item, credentials), ) thread.start() thread.join() - elif queue_config.handler_type == Handler.Process: + elif queue_config.consumer_type == ConsumerType.Process: # if psutil.pid_exists(job_item.job_pid): # psutil.Process(job_item.job_pid).terminate() process = Process( @@ -329,5 +329,5 @@ def handle_message(message: bytes, syft_worker_id: UID) -> None: job_item.job_pid = process.pid worker.job_stash.set_result(credentials, job_item).unwrap() process.join() - elif queue_config.handler_type == Handler.Synchronous: + elif queue_config.consumer_type == ConsumerType.Synchronous: handle_message_multiprocessing(worker_settings, queue_item, credentials) diff --git a/packages/syft/src/syft/service/queue/zmq_client.py b/packages/syft/src/syft/service/queue/zmq_client.py index ab74704795e..9265d9edd3d 100644 --- a/packages/syft/src/syft/service/queue/zmq_client.py +++ b/packages/syft/src/syft/service/queue/zmq_client.py @@ -16,7 +16,7 @@ from .base_queue import QueueClient from .base_queue import QueueClientConfig from .base_queue import QueueConfig -from .queue import Handler +from .queue import ConsumerType from .queue_stash import QueueStash from .zmq_consumer import ZMQConsumer from .zmq_producer import ZMQProducer @@ -185,8 +185,8 @@ def __init__( self, client_type: type[ZMQClient] | None = None, 
client_config: ZMQClientConfig | None = None, - handler_type: Handler = Handler.Process, + consumer_type: ConsumerType = ConsumerType.Process, ): self.client_type = client_type or ZMQClient self.client_config: ZMQClientConfig = client_config or ZMQClientConfig() - self.handler_type = handler_type + self.consumer_type = consumer_type
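
Usage sketch (not part of the patch series): the snippet below shows how the new
consumer_type option and the ZMQConsumer.default() helper introduced above might be
exercised. Import paths follow the files touched in these patches; the launch
arguments other than consumer_type, and the consumer.run() call, are assumed from the
existing orchestra/ZMQConsumer API rather than taken from this diff.

    import syft as sy
    from syft.service.queue.queue import ConsumerType
    from syft.service.queue.zmq_consumer import ZMQConsumer

    # Launch an in-process datasite whose queue items are executed inline in the
    # consumer's own thread (ConsumerType.Synchronous) instead of being handed off
    # to a new thread or subprocess per message.
    server = sy.orchestra.launch(
        name="sync-queue-datasite",
        reset=True,
        n_consumers=1,
        create_producer=True,
        consumer_type=ConsumerType.Synchronous,
    )

    # Attach an extra consumer by hand. Passing an explicit address is safer than
    # relying on last_created_port(), which only guesses the most recently opened
    # Python listening port via lsof.
    consumer = ZMQConsumer.default(address="tcp://localhost:5556")
    consumer.run()  # start polling; method name assumed from the existing consumer API

With ConsumerType.Synchronous the job result is available as soon as handle_message
returns, which is the behaviour the first patch targets for single-process setups.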