Skip to content

Commit

Permalink
rename ConsumerHandler
Browse files Browse the repository at this point in the history
  • Loading branch information
abyesilyurt committed Sep 9, 2024
1 parent 97c5ba1 commit 3c2114a
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 22 deletions.
10 changes: 5 additions & 5 deletions packages/syft/src/syft/orchestra.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from .server.enclave import Enclave
from .server.gateway import Gateway
from .server.uvicorn import serve_server
from .service.queue.queue import Handler
from .service.queue.queue import ConsumerType
from .service.response import SyftInfo
from .types.errors import SyftException
from .util.util import get_random_available_port
Expand Down Expand Up @@ -183,7 +183,7 @@ def deploy_to_python(
log_level: str | int | None = None,
debug: bool = False,
migrate: bool = False,
handler_type: Handler | None = None,
consumer_type: ConsumerType | None = None,
) -> ServerHandle:
worker_classes = {
ServerType.DATASITE: Datasite,
Expand Down Expand Up @@ -215,7 +215,7 @@ def deploy_to_python(
"debug": debug,
"migrate": migrate,
"deployment_type": deployment_type_enum,
"handler_type": handler_type,
"consumer_type": consumer_type,
}

if port:
Expand Down Expand Up @@ -328,7 +328,7 @@ def launch(
debug: bool = False,
migrate: bool = False,
from_state_folder: str | Path | None = None,
handler_type: Handler | None = None,
consumer_type: ConsumerType | None = None,
) -> ServerHandle:
if from_state_folder is not None:
with open(f"{from_state_folder}/config.json") as f:
Expand Down Expand Up @@ -377,7 +377,7 @@ def launch(
background_tasks=background_tasks,
debug=debug,
migrate=migrate,
handler_type=handler_type,
consumer_type=consumer_type,
)
display(
SyftInfo(
Expand Down
20 changes: 11 additions & 9 deletions packages/syft/src/syft/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
from ..service.queue.base_queue import QueueConsumer
from ..service.queue.base_queue import QueueProducer
from ..service.queue.queue import APICallMessageHandler
from ..service.queue.queue import Handler
from ..service.queue.queue import ConsumerType
from ..service.queue.queue import QueueManager
from ..service.queue.queue_stash import APIEndpointQueueItem
from ..service.queue.queue_stash import ActionQueueItem
Expand Down Expand Up @@ -339,7 +339,7 @@ def __init__(
smtp_host: str | None = None,
association_request_auto_approval: bool = False,
background_tasks: bool = False,
handler_type: Handler | None = None,
consumer_type: ConsumerType | None = None,
):
# 🟡 TODO 22: change our ENV variable format and default init args to make this
# less horrible or add some convenience functions
Expand Down Expand Up @@ -383,13 +383,15 @@ def __init__(

self.association_request_auto_approval = association_request_auto_approval

handler_type = (
handler_type or Handler.Thread if thread_workers else Handler.Process
consumer_type = (
consumer_type or ConsumerType.Thread
if thread_workers
else ConsumerType.Process
)
self.queue_config = self.create_queue_config(
n_consumers=n_consumers,
create_producer=create_producer,
handler_type=handler_type,
consumer_type=consumer_type,
queue_port=queue_port,
queue_config=queue_config,
)
Expand Down Expand Up @@ -583,7 +585,7 @@ def create_queue_config(
self,
n_consumers: int,
create_producer: bool,
handler_type: Handler,
consumer_type: ConsumerType,
queue_port: int | None,
queue_config: QueueConfig | None,
) -> QueueConfig:
Expand All @@ -599,7 +601,7 @@ def create_queue_config(
queue_port=queue_port,
n_consumers=n_consumers,
),
handler_type=handler_type,
consumer_type=consumer_type,
)
else:
queue_config_ = ZMQQueueConfig()
Expand Down Expand Up @@ -733,7 +735,7 @@ def named(
in_memory_workers: bool = True,
association_request_auto_approval: bool = False,
background_tasks: bool = False,
handler_type: Handler | None = None,
consumer_type: ConsumerType | None = None,
) -> Server:
uid = get_named_server_uid(name)
name_hash = hashlib.sha256(name.encode("utf8")).digest()
Expand Down Expand Up @@ -764,7 +766,7 @@ def named(
reset=reset,
association_request_auto_approval=association_request_auto_approval,
background_tasks=background_tasks,
handler_type=handler_type,
consumer_type=consumer_type,
)

def is_root(self, credentials: SyftVerifyKey) -> bool:
Expand Down
10 changes: 5 additions & 5 deletions packages/syft/src/syft/service/queue/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@


@serializable(canonical_name="WorkerType", version=1)
class Handler(str, Enum):
class ConsumerType(str, Enum):
Thread = "thread"
Process = "process"
Synchronous = "synchronous"
Expand Down Expand Up @@ -308,17 +308,17 @@ def handle_message(message: bytes, syft_worker_id: UID) -> None:
logger.info(
f"Handling queue item: id={queue_item.id}, method={queue_item.method} "
f"args={queue_item.args}, kwargs={queue_item.kwargs} "
f"service={queue_item.service}, as={queue_config.handler_type}"
f"service={queue_item.service}, as={queue_config.consumer_type}"
)

if queue_config.handler_type == Handler.Thread:
if queue_config.consumer_type == ConsumerType.Thread:
thread = Thread(
target=handle_message_multiprocessing,
args=(worker_settings, queue_item, credentials),
)
thread.start()
thread.join()
elif queue_config.handler_type == Handler.Process:
elif queue_config.consumer_type == ConsumerType.Process:
# if psutil.pid_exists(job_item.job_pid):
# psutil.Process(job_item.job_pid).terminate()
process = Process(
Expand All @@ -329,5 +329,5 @@ def handle_message(message: bytes, syft_worker_id: UID) -> None:
job_item.job_pid = process.pid
worker.job_stash.set_result(credentials, job_item).unwrap()
process.join()
elif queue_config.handler_type == Handler.Synchronous:
elif queue_config.consumer_type == ConsumerType.Synchronous:
handle_message_multiprocessing(worker_settings, queue_item, credentials)
6 changes: 3 additions & 3 deletions packages/syft/src/syft/service/queue/zmq_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from .base_queue import QueueClient
from .base_queue import QueueClientConfig
from .base_queue import QueueConfig
from .queue import Handler
from .queue import ConsumerType
from .queue_stash import QueueStash
from .zmq_consumer import ZMQConsumer
from .zmq_producer import ZMQProducer
Expand Down Expand Up @@ -185,8 +185,8 @@ def __init__(
self,
client_type: type[ZMQClient] | None = None,
client_config: ZMQClientConfig | None = None,
handler_type: Handler = Handler.Process,
consumer_type: ConsumerType = ConsumerType.Process,
):
self.client_type = client_type or ZMQClient
self.client_config: ZMQClientConfig = client_config or ZMQClientConfig()
self.handler_type = handler_type
self.consumer_type = consumer_type

0 comments on commit 3c2114a

Please sign in to comment.