forked from DAGWorks-Inc/burr
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapplication.py
110 lines (86 loc) · 3.14 KB
/
application.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import burr.core.application
from burr.core import Action, Condition, State, default
class ProcessDataAction(Action):
    """Produce the training and evaluation datasets from ``data_path``.

    ``run`` is a placeholder: it must return a dict containing
    ``"training_data"`` and ``"evaluation_data"``, both of which ``update``
    writes back into state.
    """

    @property
    def reads(self) -> list[str]:
        return ["data_path"]

    def run(self, state: State) -> dict:
        # Placeholder -- load/split the data here.
        pass

    @property
    def writes(self) -> list[str]:
        return ["training_data", "evaluation_data"]

    def update(self, result: dict, state: State) -> State:
        # Write BOTH declared outputs. ValidateModel reads evaluation_data, so
        # writing only training_data would leave it at its initial value.
        return state.update(
            training_data=result["training_data"],
            evaluation_data=result["evaluation_data"],
        )
class TrainModel(Action):
    """Run one training step and record the resulting model and metrics.

    ``run`` is a placeholder: it must return a dict with the keys
    ``"epochs"``, ``"model"`` and ``"metrics"``, which ``update`` consumes.
    """

    @property
    def reads(self) -> list[str]:
        return ["training_data", "epochs"]

    def run(self, state: State) -> dict:
        # Placeholder -- implement one training step here.
        pass

    @property
    def writes(self) -> list[str]:
        return ["models", "training_metrics", "epochs"]

    def update(self, result: dict, state: State) -> State:
        return state.update(
            epochs=result["epochs"],  # overwrite each epoch
        ).append(
            # Appending keeps every model in state, which can get big if your
            # model is big -- consider overwriting and logging elsewhere.
            models=result["model"],
            # Key must match the declared write "training_metrics"; appending to
            # a differently named key would leave training_metrics untouched.
            training_metrics=result["metrics"],
        )
class ValidateModel(Action):
    """Validation step over ``models`` and ``evaluation_data``.

    Each ``update`` appends one entry to the ``validation_metrics`` list.
    ``run`` is a placeholder and currently returns ``None``.
    """

    @property
    def reads(self) -> list[str]:
        """State fields this action consumes."""
        return ["models", "evaluation_data"]

    @property
    def writes(self) -> list[str]:
        """State field produced by ``update``."""
        return ["validation_metrics"]

    def run(self, state: State) -> dict:
        """Placeholder; expected to return ``{"validation_metrics": ...}``."""
        return None

    def update(self, result: dict, state: State) -> State:
        latest_metrics = result["validation_metrics"]
        return state.append(validation_metrics=latest_metrics)
class BestModel(Action):
    """Pick a final model using ``validation_metrics`` and ``models``.

    The choice is stored in the ``best_model`` state field. ``run`` is a
    placeholder and currently returns ``None``.
    """

    @property
    def reads(self) -> list[str]:
        """State fields this action consumes."""
        return ["validation_metrics", "models"]

    @property
    def writes(self) -> list[str]:
        """State field produced by ``update``."""
        return ["best_model"]

    def run(self, state: State) -> dict:
        """Placeholder; expected to return ``{"best_model": ...}``."""
        return None

    def update(self, result: dict, state: State) -> State:
        chosen = result["best_model"]
        return state.update(best_model=chosen)
def application(
    epochs: int, *, data_path: str = "data.csv"
) -> burr.core.application.Application:
    """Build the ML-training state machine.

    Graph: process_data -> train_model -> validate_model, then either back to
    train_model or on to best_model once the ``epochs`` state value exceeds
    ``epochs``.

    Args:
        epochs: Stop threshold -- validate_model moves to best_model once the
            ``epochs`` state value is strictly greater than this number.
        data_path: Initial value of the ``data_path`` state field, which
            process_data reads. Defaults to ``"data.csv"``.

    Returns:
        The built (not yet executed) burr Application.
    """
    return (
        burr.core.ApplicationBuilder()
        .with_state(
            data_path=data_path,
            # Epoch counter, overwritten by train_model after each step. It
            # starts at 0 so the ``epochs`` threshold alone decides how many
            # training passes run (a fixed non-zero start skewed small values).
            epochs=0,
            training_data=None,
            evaluation_data=None,
            models=[],
            training_metrics=[],
            validation_metrics=[],
            best_model=None,
        )
        .with_actions(
            process_data=ProcessDataAction(),
            train_model=TrainModel(),
            validate_model=ValidateModel(),
            best_model=BestModel(),
        )
        .with_transitions(
            ("process_data", "train_model", default),
            ("train_model", "validate_model", default),
            # The conditional exit is listed before the default loop-back so it
            # takes precedence when the threshold is reached.
            ("validate_model", "best_model", Condition.expr(f"epochs>{epochs}")),
            ("validate_model", "train_model", default),
        )
        .with_entrypoint("process_data")
        .build()
    )
if __name__ == "__main__":
    app = application(100)  # doing good data science is up to you...
    # Only render the state-machine graph. Every action's run() is a
    # placeholder that returns None, so executing the application (e.g.
    # app.run(halt_after=["best_model"])) would fail inside update() until
    # those run() methods are implemented.
    app.visualize(
        output_file_path="ml_training.png",
        include_conditions=True,
        view=True,
    )