-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlogit_example.py
70 lines (56 loc) · 2.3 KB
/
logit_example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from typing import Any, Tuple
import pandas
import delta.dataset
from delta import DeltaNode
from delta.statsmodel import LogitTask, MNLogitTask
class SpectorLogitTask(LogitTask):
    """Logit task configured for the ``spector.csv`` dataset.

    The constructor forwards a fixed configuration to ``LogitTask``;
    ``dataset`` names the input file and ``preprocess`` splits the loaded
    frame into features and target.
    """

    def __init__(self) -> None:
        """Initialise the task with its fixed example settings."""
        super().__init__(
            # Task name, used for display purposes.
            name="spector_logit",
            # Minimum number of nodes required in each round.
            # NOTE(review): the earlier comment said "must be greater than 2"
            # while 2 is passed here, so the bound is presumably inclusive
            # -- confirm against the Delta documentation.
            min_clients=2,
            # Maximum number of nodes allowed in each round; must be greater
            # than or equal to min_clients.
            max_clients=3,
            # Timeout for calculation.
            wait_timeout=5,
            # Wait timeout for each step.
            connection_timeout=5,
            # Timeout for the final zero-knowledge verification step.
            verify_timeout=360,
            # Whether to enable the final zero-knowledge verification step.
            enable_verify=False,
        )

    def dataset(self):
        """Return the input mapping: key ``"data"`` bound to ``spector.csv``."""
        spector_frame = delta.dataset.DataFrame("spector.csv")
        return {"data": spector_frame}

    def preprocess(self, data: pandas.DataFrame) -> Tuple[Any, Any]:
        """Split ``data`` into an ``(x, y)`` pair.

        The column at position 3 is used as the target ``y``; every other
        column is kept in ``x``.
        """
        target_column = data.columns[3]
        features = data.drop(columns=[target_column])
        target = data[target_column].copy()  # type: ignore
        return features, target
class IrisLogitTask(MNLogitTask):
    """Multinomial logit task configured for the ``iris.csv`` dataset.

    The constructor forwards a fixed configuration to ``MNLogitTask``;
    ``dataset`` names the input file and ``preprocess`` splits the loaded
    frame into features and the ``"target"`` column.
    """

    def __init__(self) -> None:
        """Initialise the task with its fixed example settings."""
        super().__init__(
            # Task name, used for display purposes. It is kept distinct from
            # SpectorLogitTask's "spector_logit" so the two tasks can be told
            # apart wherever the name is shown.
            name="iris_mnlogit",
            # Minimum number of nodes required in each round.
            # NOTE(review): the earlier comment said "must be greater than 2"
            # while 2 is passed here, so the bound is presumably inclusive
            # -- confirm against the Delta documentation.
            min_clients=2,
            # Maximum number of nodes allowed in each round; must be greater
            # than or equal to min_clients.
            max_clients=3,
            # Timeout for calculation.
            wait_timeout=5,
            # Wait timeout for each step.
            connection_timeout=5,
        )

    def dataset(self):
        """Return the input mapping: key ``"data"`` bound to ``iris.csv``."""
        return {"data": delta.dataset.DataFrame("iris.csv")}

    def preprocess(self, data: pandas.DataFrame) -> Tuple[Any, Any]:
        """Split ``data`` into an ``(x, y)`` pair.

        ``y`` is a copy of the ``"target"`` column; ``x`` is ``data`` with
        that column removed.
        """
        y = data["target"].copy()
        x = data.drop(columns=["target"])
        return x, y
if __name__ == "__main__":
    # Build the task to submit. Swap in the commented line below to run the
    # Iris multinomial example instead.
    task = SpectorLogitTask().build()
    # task = IrisLogitTask().build()

    # Address of the Delta node API that receives the task.
    DELTA_NODE_API = "http://127.0.0.1:6700"
    delta_node = DeltaNode(DELTA_NODE_API)

    task_id = delta_node.create_task(task)

    # A falsy result from trace() is reported as a failure; otherwise the
    # task result is fetched and printed.
    if not delta_node.trace(task_id):
        print("Task error")
    else:
        print(delta_node.get_result(task_id))