-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathnode.py
134 lines (102 loc) · 4.92 KB
/
node.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import json
import sys
from mqtt_client import MqttClient
import yaml
import pathlib
import os
import shlex, subprocess
import datetime
import time
class Node:
    """Client-side job runner for one user, coordinated by a master over MQTT.

    Construction reads ``config/config.yaml`` next to this file and parses the
    command line ``<user> <job_path> <gpu> <parallel>``. It then publishes the
    job on ``jobs/<user>`` and waits until the master publishes a JSON payload
    with a truthy ``"start"`` on ``master_status/<user>``. Finally it runs the
    job file and publishes ``{"start": false, "end": true}`` on
    ``node_status/<user>``.

    The job file holds one command per line, split with :func:`shlex.split`.
    The second token (``args[1]``) must be an existing path, e.g.
    ``python script.py``. Lines run in batches of ``parallel`` concurrent
    processes. Each batch gets ``timeout`` seconds (from the config); any
    process still running after that is killed.
    """

    # Seconds between checks of ``waiting_status_msg`` while waiting for the
    # master's start signal.
    _POLL_INTERVAL = 0.1

    def callback_msg(self, topic: str, payload: str) -> None:
        """Handle an incoming MQTT message.

        Only messages on ``master_status/<user>`` are decoded. When their JSON
        payload is an object with a truthy ``"start"`` field, the node is
        marked as started. Messages for other users are ignored unparsed, and
        a malformed payload is reported instead of raising in the MQTT
        callback.
        """
        if topic != "master_status/" + str(self.user):
            return
        try:
            pay = json.loads(payload)
        except ValueError as e:  # json.JSONDecodeError subclasses ValueError
            print(f'Ignoring malformed status message on {topic}: {e}')
            return
        if isinstance(pay, dict) and pay.get("start"):
            print('Your job is starting...\n')
            self.waiting_status_msg = True

    def __init__(self):
        """Load the config, submit the job given on the command line, and run it.

        Expects ``sys.argv`` to be ``[prog, user, job_path, gpu, parallel]``.
        Returns early, without contacting the master, when arguments are
        missing or ``job_path`` does not exist.
        """
        if len(sys.argv) < 5:
            print(f'Usage: {sys.argv[0]} <user> <job_path> <gpu> <parallel>')
            return
        config_path = pathlib.Path(__file__).parent.resolve() / 'config' / 'config.yaml'
        with open(config_path) as file:
            # The config is plain data; safe_load refuses arbitrary object tags.
            self.file_config = yaml.safe_load(file)
        self.timeout = self.file_config['timeout']
        # Parse argv and initialise state *before* the MQTT client exists, so
        # an early callback never sees a missing ``self.user`` and a
        # ``waiting_status_msg = True`` it sets is never overwritten.
        self.user = sys.argv[1]
        self.job_path = sys.argv[2]
        self.gpu = sys.argv[3]
        self.parallel = int(sys.argv[4])
        self.waiting_status_msg = False
        # Validate before publishing so the master is never handed a job
        # this node is about to abandon.
        if not os.path.exists(self.job_path):
            print('No such file')
            print(f'Path provided: {self.job_path}')
            return
        self.client = MqttClient(callback=self.callback_msg)
        self.client.subscribe("master_status/#")
        data = {"user": str(self.user), "path": self.job_path, "gpu": self.gpu, "parallel": self.parallel}
        self.client.publish("jobs/" + str(self.user), json.dumps(data))
        self.loop()

    def loop(self):
        """Block until the master signals start, run the job file, then report the end.

        Any error while reading or running the job file is printed, and the
        node still publishes its end status so the master is not left waiting.
        """
        while not self.waiting_status_msg:
            # Poll instead of spinning; the flag is set from the MQTT callback.
            time.sleep(self._POLL_INTERVAL)
        aborted = False
        try:
            aborted = not self._run_job_file()
        except Exception as e:  # top-level boundary: always report the end
            print(e)
        self._finish()
        if aborted:
            # Kept from the original flow: pause after a rejected job line
            # (presumably so the message stays visible; TODO confirm intent).
            time.sleep(10)

    def _finish(self):
        """Publish this user's end-of-jobs status and stop the MQTT client."""
        data = {"start": False, "end": True}
        self.client.publish("node_status/" + str(self.user), json.dumps(data))
        self.client.stop()
        print('\nEnd all your jobs!')

    def _run_job_file(self) -> bool:
        """Run the commands in ``self.job_path`` one batch at a time.

        Blank lines are skipped. Returns ``False`` as soon as a batch is
        rejected because a line names a missing file, and ``True`` otherwise.
        """
        with open(self.job_path) as file:
            lines = [line for line in file if line.strip()]
        # A non-positive ``parallel`` would never advance; run serially instead.
        batch_size = max(1, self.parallel)
        for start in range(0, len(lines), batch_size):
            if not self._run_batch(lines[start:start + batch_size]):
                return False
        return True

    def _run_batch(self, block_lines) -> bool:
        """Run ``block_lines`` concurrently, allowing ``self.timeout`` seconds in total.

        Each line is split with :func:`shlex.split`, and its second token must
        be an existing path. If it is not, the problem is printed, processes
        already started for this batch are killed, and ``False`` is returned.
        Processes still running at the deadline are killed and reaped. A
        ``self.timeout`` of ``None`` waits without a limit.
        """
        procs = []
        try:
            for line in block_lines:
                args = shlex.split(line)
                if len(args) < 2 or not os.path.exists(args[1]):
                    print('No such python file')
                    print(f'Path provided: {args[1] if len(args) > 1 else line.strip()}')
                    return False
                procs.append(subprocess.Popen(args))
            # Monotonic clock: a real elapsed-time deadline, unlike the
            # seconds field of datetime.now(), which wraps every minute.
            deadline = None if self.timeout is None else time.monotonic() + self.timeout
            for proc in procs:
                remaining = None if deadline is None else max(0.0, deadline - time.monotonic())
                try:
                    proc.wait(timeout=remaining)
                except subprocess.TimeoutExpired:
                    pass  # past the deadline; killed in the ``finally`` below
        finally:
            # Never leave children running past their batch, even on error.
            for proc in procs:
                if proc.poll() is None:
                    proc.kill()
                    proc.wait()  # reap so no zombie is left behind
        return True
if __name__ == "__main__":
    # Script entry point: constructing a Node reads argv, submits the job
    # to the master and runs it.
    Node()