Skip to content

Commit a5864c3

Browse files
authored
Merge pull request #8435 from madhavajay/madhava/fix_hagrid
Fixed issue where new args in hagrid break syft code
2 parents 5130907 + f3dc151 commit a5864c3

File tree

2 files changed

+55
-61
lines changed

2 files changed

+55
-61
lines changed

packages/hagrid/hagrid/orchestra.py

+29-50
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Python Level API to launch Docker Containers using Hagrid"""
2+
23
# future
34
from __future__ import annotations
45

@@ -255,38 +256,34 @@ def deploy_to_python(
255256
print("Staging Protocol Changes...")
256257
stage_protocol_changes()
257258

259+
kwargs = {
260+
"name": name,
261+
"host": host,
262+
"port": port,
263+
"reset": reset,
264+
"processes": processes,
265+
"dev_mode": dev_mode,
266+
"tail": tail,
267+
"node_type": node_type_enum,
268+
"node_side_type": node_side_type,
269+
"enable_warnings": enable_warnings,
270+
# new kwargs
271+
"queue_port": queue_port,
272+
"n_consumers": n_consumers,
273+
"create_producer": create_producer,
274+
}
275+
258276
if port:
277+
kwargs["in_memory_workers"] = True
259278
if port == "auto":
260279
# dont use default port to prevent port clashes in CI
261280
port = find_available_port(host="localhost", port=None, search=True)
281+
kwargs["port"] = port
282+
262283
sig = inspect.signature(sy.serve_node)
263-
if "node_type" in sig.parameters.keys():
264-
start, stop = sy.serve_node(
265-
name=name,
266-
host=host,
267-
port=port,
268-
reset=reset,
269-
processes=processes,
270-
queue_port=queue_port,
271-
n_consumers=n_consumers,
272-
create_producer=create_producer,
273-
dev_mode=dev_mode,
274-
tail=tail,
275-
node_type=node_type_enum,
276-
node_side_type=node_side_type,
277-
enable_warnings=enable_warnings,
278-
in_memory_workers=True, # Only in-memory workers supported for python mode
279-
)
280-
else:
281-
# syft <= 0.8.1
282-
start, stop = sy.serve_node(
283-
name=name,
284-
host=host,
285-
port=port,
286-
reset=reset,
287-
dev_mode=dev_mode,
288-
tail=tail,
289-
)
284+
supported_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters}
285+
286+
start, stop = sy.serve_node(**supported_kwargs)
290287
start()
291288
return NodeHandle(
292289
node_type=node_type_enum,
@@ -298,33 +295,15 @@ def deploy_to_python(
298295
node_side_type=node_side_type,
299296
)
300297
else:
298+
kwargs["local_db"] = local_db
299+
kwargs["thread_workers"] = thread_workers
301300
if node_type_enum in worker_classes:
302301
worker_class = worker_classes[node_type_enum]
303302
sig = inspect.signature(worker_class.named)
303+
supported_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters}
304304
if "node_type" in sig.parameters.keys():
305-
worker = worker_class.named(
306-
dev_mode=dev_mode,
307-
name=name,
308-
processes=processes,
309-
reset=reset,
310-
local_db=local_db,
311-
node_type=node_type_enum,
312-
node_side_type=node_side_type,
313-
enable_warnings=enable_warnings,
314-
n_consumers=n_consumers,
315-
thread_workers=thread_workers,
316-
create_producer=create_producer,
317-
queue_port=queue_port,
318-
migrate=True,
319-
)
320-
else:
321-
# syft <= 0.8.1
322-
worker = worker_class.named(
323-
name=name,
324-
processes=processes,
325-
reset=reset,
326-
local_db=local_db,
327-
)
305+
supported_kwargs["migrate"] = True
306+
worker = worker_class.named(**supported_kwargs)
328307
else:
329308
raise NotImplementedError(f"node_type: {node_type_enum} is not supported")
330309

packages/syft/src/syft/node/run.py

+26-11
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,19 @@
11
# stdlib
22
import argparse
3+
from typing import Optional
34

45
# relative
56
from ..client.deploy import Orchestra
67

78

9+
def str_to_bool(bool_str: Optional[str]) -> bool:
10+
result = False
11+
bool_str = str(bool_str).lower()
12+
if bool_str == "true" or bool_str == "1":
13+
result = True
14+
return result
15+
16+
817
def run():
918
parser = argparse.ArgumentParser()
1019
parser.add_argument("command", help="command: launch", type=str, default="none")
@@ -28,43 +37,43 @@ def run():
2837
parser.add_argument(
2938
"--dev-mode",
3039
help="developer mode",
31-
type=bool,
32-
default=True,
40+
type=str,
41+
default="True",
3342
dest="dev_mode",
3443
)
3544
parser.add_argument(
3645
"--reset",
3746
help="reset",
38-
type=bool,
39-
default=True,
47+
type=str,
48+
default="True",
4049
dest="reset",
4150
)
4251
parser.add_argument(
4352
"--local-db",
4453
help="reset",
45-
type=bool,
46-
default=False,
54+
type=str,
55+
default="False",
4756
dest="local_db",
4857
)
4958
parser.add_argument(
5059
"--processes",
5160
help="processing mode",
5261
type=int,
53-
default=False,
62+
default=0,
5463
dest="processes",
5564
)
5665
parser.add_argument(
5766
"--tail",
5867
help="tail mode",
59-
type=bool,
60-
default=True,
68+
type=str,
69+
default="True",
6170
dest="tail",
6271
)
6372
parser.add_argument(
6473
"--cmd",
6574
help="cmd mode",
66-
type=bool,
67-
default=False,
75+
type=str,
76+
default="False",
6877
dest="cmd",
6978
)
7079

@@ -73,6 +82,12 @@ def run():
7382
if args.command != "launch":
7483
print("syft launch is the only command currently supported")
7584

85+
args.dev_mode = str_to_bool(args.dev_mode)
86+
args.reset = str_to_bool(args.reset)
87+
args.local_db = str_to_bool(args.local_db)
88+
args.tail = str_to_bool(args.tail)
89+
args.cmd = str_to_bool(args.cmd)
90+
7691
node = Orchestra.launch(
7792
name=args.name,
7893
node_type=args.node_type,

0 commit comments

Comments
 (0)