Skip to content

Commit

Permalink
System Util (#43)
Browse files Browse the repository at this point in the history
  • Loading branch information
rsamf authored Jul 22, 2024
1 parent 4090f51 commit 4147f76
Show file tree
Hide file tree
Showing 22 changed files with 1,829 additions and 366 deletions.
30 changes: 18 additions & 12 deletions graphbook/dataloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ def worker_loop(
dump_ctr = rank
while not close_event.is_set():
do_load(load_queue, load_result_queue)
did_receive_work = do_dump(dump_queue, dump_result_queue, dump_dir, dump_ctr)
did_receive_work = do_dump(
dump_queue, dump_result_queue, dump_dir, dump_ctr
)
if did_receive_work:
dump_ctr += num_processes
except KeyboardInterrupt:
Expand Down Expand Up @@ -126,7 +128,7 @@ def setup(self, consumer_ids: List[int]):
self.consumer_dump_queues[c] = queue.Queue()
unused_ids = set(self.consumer_load_queues.keys()) - set(consumer_ids)
for c in unused_ids:
self.total_consumer_size -= self.consumer_dump_queues[c].qsize()
self.total_consumer_size -= self.consumer_load_queues[c].qsize()
del self.consumer_load_queues[c]
unused_ids = set(self.consumer_dump_queues.keys()) - set(consumer_ids)
for c in unused_ids:
Expand Down Expand Up @@ -163,20 +165,24 @@ def _handle_queues(self):
result, consumer_id = q.get(False)
consumers[consumer_id].put(result, block=False)
self.total_consumer_size += 1

def get_all_sizes(self):
return {
"load": [q.qsize() for q in self._load_queues],
"dump": [q.qsize() for q in self._dump_queues],
"load_result": [q.qsize() for q in self._load_result_queues],
"dump_result": [q.qsize() for q in self._dump_result_queues],
"total_consumer_size": self.total_consumer_size,
}

def put_load(
self, items: list, record_id: int, load_fn: callable, consumer_id: int
) -> bool:
try:
for i, item in enumerate(items):
self._load_queues[self._worker_queue_cycle].put(
(item, i, record_id, load_fn, consumer_id), block=False
)
except queue.Full:
return False
finally:
):
for i, item in enumerate(items):
self._load_queues[self._worker_queue_cycle].put(
(item, i, record_id, load_fn, consumer_id), block=False
)
self._worker_queue_cycle = (self._worker_queue_cycle + 1) % self.num_workers
return True

def get_load(self, consumer_id):
self._handle_queues()
Expand Down
5 changes: 3 additions & 2 deletions graphbook/exports.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@

default_exported_steps = {
"Split": steps.Split,
"SplitRecordsByItems": steps.SplitNotesByItems,
"SplitNotesByItems": steps.SplitNotesByItems,
"SplitItemField": steps.SplitItemField,
"DumpJSONL": steps.DumpJSONL,
"LoadJSONL": steps.LoadJSONL,
}

default_exported_resources = {
"Text": rbase.Resource,
"Function": rbase.FunctionResource
"Number": rbase.NumberResource,
"Function": rbase.FunctionResource,
}


Expand Down
2 changes: 1 addition & 1 deletion graphbook/note.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

class Note:
"""
The unit that passes through workflow steps. A Note contains a dictionary of items related to the record.
The unit that passes through workflow steps. A Note contains a dictionary of items related to the note.
Args:
items (Dict[str, any]): An optional dictionary of items to store in the Note
Expand Down
98 changes: 83 additions & 15 deletions graphbook/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
import queue
import multiprocessing as mp
import multiprocessing.connection as mpc
from graphbook.utils import MP_WORKER_TIMEOUT
from graphbook.utils import MP_WORKER_TIMEOUT, ProcessorStateRequest
from graphbook.state import GraphState, StepState
from graphbook.viewer import ViewManagerInterface
import traceback
import asyncio
import time


class WebInstanceProcessor:
Expand All @@ -28,17 +29,21 @@ def __init__(
self.close_event = close_event
self.pause_event = pause_event
self.view_manager = ViewManagerInterface(view_manager_queue)
self.graph_state = GraphState(custom_nodes_path, view_manager_queue, server_request_conn, close_event)
self.graph_state = GraphState(custom_nodes_path, view_manager_queue)
self.output_dir = output_dir
self.custom_nodes_path = custom_nodes_path
self.num_workers = num_workers
self.steps = {}
self.dataloader = Dataloader(self.output_dir, self.num_workers)
self.state_client = ProcessorStateClient(
server_request_conn, close_event, self.graph_state, self.dataloader
)
self.is_running = False

def exec_step(self, step: Step, input: Note = None, flush: bool = False):
outputs = {}
step_fn = step if not flush else step.all
start_time = time.time()
try:
if input is None:
outputs = step_fn()
Expand All @@ -55,6 +60,8 @@ def exec_step(self, step: Step, input: Note = None, flush: bool = False):
if outputs:
self.graph_state.handle_outputs(step.id, outputs)
self.view_manager.handle_outputs(step.id, outputs)
self.view_manager.handle_time(step.id, time.time() - start_time)

return outputs

def handle_steps(self, steps: List[Step]) -> bool:
Expand Down Expand Up @@ -93,7 +100,9 @@ def step_until_received_output(self, steps: List[Step], step_id: str):
step_executed = False
while is_active and not step_executed and not self.pause_event.is_set():
is_active = self.handle_steps(steps)
step_executed = self.graph_state.get_state(step_id, StepState.EXECUTED_THIS_RUN)
step_executed = self.graph_state.get_state(
step_id, StepState.EXECUTED_THIS_RUN
)

def run(self, step_id: str = None):
self.is_running = True
Expand Down Expand Up @@ -140,36 +149,95 @@ def setup_dataloader(self, steps: List[Step]):

    def __str__(self):
        # Delegate to the root node's string representation.
        # NOTE(review): enclosing class is not visible in this view; presumably
        # self.root is the graph's root step — confirm against the class header.
        return self.root.__str__()
def try_update_state(self, queue_entry: dict):

def try_update_state(self, queue_entry: dict) -> bool:
try:
self.graph_state.update_state(queue_entry["graph"], queue_entry["resources"])
self.graph_state.update_state(
queue_entry["graph"], queue_entry["resources"]
)
return True
except Exception as e:
traceback.print_exc()
return False

async def start_loop(self):
loop = asyncio.get_running_loop()
loop.run_in_executor(None, self.graph_state.start_client_loop)
loop.run_in_executor(None, self.state_client.start)
while not self.close_event.is_set():
if self.is_running:
self.is_running = False
self.view_manager.handle_run_state(False)
try:
work = self.cmd_queue.get(timeout=MP_WORKER_TIMEOUT)
if work["cmd"] == "run_all":
self.try_update_state(work)
self.run()
if self.try_update_state(work):
self.run()
elif work["cmd"] == "run":
self.try_update_state(work)
self.run(work["step_id"])
if self.try_update_state(work):
self.run(work["step_id"])
elif work["cmd"] == "step":
self.try_update_state(work)
self.step(work["step_id"])
if self.try_update_state(work):
self.step(work["step_id"])
elif work["cmd"] == "clear":
self.try_update_state(work)
self.graph_state.clear_outputs(work.get("step_id"))
if self.try_update_state(work):
self.graph_state.clear_outputs(work.get("step_id"))
except KeyboardInterrupt:
self.cleanup()
break
except queue.Empty:
pass


class ProcessorStateClient:
    """Answers processor-state requests arriving over a pipe from the server.

    Polls ``server_request_conn`` until ``close_event`` is set, building a
    response for each request and sending it back on the same connection.
    """

    def __init__(
        self,
        server_request_conn: mpc.Connection,
        close_event: mp.Event,
        graph_state: GraphState,
        dataloader: Dataloader,
    ):
        self.server_request_conn = server_request_conn
        self.close_event = close_event
        self.curr_task = None
        self.graph_state = graph_state
        self.dataloader = dataloader

    def _respond_to(self, req: dict):
        """Build the response payload for one request dict; {} when unhandled."""
        cmd = req["cmd"]
        if cmd == ProcessorStateRequest.GET_OUTPUT_NOTE:
            step_id = req.get("step_id")
            pin_id = req.get("pin_id")
            index = req.get("index")
            if step_id is None or pin_id is None or index is None:
                return {}
            return self.graph_state.get_output_note(step_id, pin_id, index)
        if cmd == ProcessorStateRequest.GET_WORKER_QUEUE_SIZES:
            return self.dataloader.get_all_sizes()
        return {}

    def _loop(self):
        """Serve requests until the close event fires."""
        while not self.close_event.is_set():
            if not self.server_request_conn.poll(timeout=MP_WORKER_TIMEOUT):
                continue
            request = self.server_request_conn.recv()
            reply = {"res": request["cmd"], "data": self._respond_to(request)}
            self.server_request_conn.send(reply)

    def start(self):
        self._loop()

    def close(self):
        # NOTE(review): curr_task is never assigned anywhere visible, so this
        # cancel is currently a no-op — confirm whether it is vestigial.
        if self.curr_task is not None:
            self.curr_task.cancel()


def poll_conn_for(
    conn: mpc.Connection, req: ProcessorStateRequest, body: dict = None
) -> dict:
    """Send a state request over `conn` and wait for its matching response.

    Args:
        conn: Pipe connection to the processor-state server.
        req: Request command; the response must echo it under "res" to match.
        body: Optional extra fields merged into the request message.

    Returns:
        The response's "data" payload, or {} on timeout or a mismatched reply.
    """
    message = {"cmd": req}
    if body:
        message.update(body)
    conn.send(message)
    if not conn.poll(timeout=MP_WORKER_TIMEOUT):
        return {}
    response = conn.recv()
    if response.get("res") != req:
        return {}
    return response.get("data")
2 changes: 1 addition & 1 deletion graphbook/resources/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .base import Resource, FunctionResource
from .base import Resource, NumberResource, FunctionResource
30 changes: 18 additions & 12 deletions graphbook/resources/base.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,31 @@
class Resource:
Parameters = {
"val": {
"type": "string"
}
}
Parameters = {"val": {"type": "string"}}
Category = "Util"

def __init__(self, val):
self.val = val

def value(self):
return self.val

def __str__(self):
return str(self.val)



class NumberResource(Resource):
    """Resource whose value is coerced to float on access.

    The raw parameter may arrive as a string from the UI; ``value()``
    normalizes it to a number. Raises ValueError if it is not numeric.
    """

    # The redundant __init__ that only forwarded to Resource.__init__ was
    # removed; the inherited constructor stores val identically.
    Parameters = {"val": {"type": "number"}}
    Category = "Util"

    def value(self):
        # Coerce so callers always receive a float regardless of input type.
        return float(self.val)


class FunctionResource(Resource):
Parameters = {
"val": {
"type": "function"
}
}
Parameters = {"val": {"type": "function"}}
Category = "Util"

def __init__(self, val):
super().__init__(val)
52 changes: 31 additions & 21 deletions graphbook/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import hashlib
from graphbook.state import UIState
from graphbook.media import MediaServer
from utils import MP_WORKER_TIMEOUT
from graphbook.utils import poll_conn_for, ProcessorStateRequest


@web.middleware
Expand Down Expand Up @@ -57,7 +57,7 @@ def __init__(
self.ui_state = None
routes = web.RouteTableDef()
self.routes = routes
self.view_manager = ViewManager(view_manager_queue, close_event)
self.view_manager = ViewManager(view_manager_queue, close_event, state_conn)
abs_root_path = osp.abspath(root_path)
middlewares = [cors_middleware]
max_upload_size = 100 # MB
Expand Down Expand Up @@ -168,28 +168,29 @@ async def get_output_note(request: web.Request) -> web.Response:
step_id = request.match_info.get("step_id")
pin_id = request.match_info.get("pin_id")
index = int(request.match_info.get("index"))
state_conn.send({
"cmd": "get_output_note",
"step_id": step_id,
"pin_id": pin_id,
"index": int(index)
})
if state_conn.poll(timeout=MP_WORKER_TIMEOUT):
res = state_conn.recv()
if res.get("step_id") == step_id and res.get("pin_id") == pin_id and res.get("index") == index:
return web.json_response(res)
else:
res = {"error": "Mismatched response"}
else:
res = {"error": "Timeout"}
return web.json_response(res)
res = poll_conn_for(
state_conn,
ProcessorStateRequest.GET_OUTPUT_NOTE,
{"step_id": step_id, "pin_id": pin_id, "index": int(index)},
)
if (
res
and res.get("step_id") == step_id
and res.get("pin_id") == pin_id
and res.get("index") == index
):
return web.json_response(res)

return web.json_response({"error": "Could not get output note."})

@routes.get("/fs")
@routes.get(r"/fs/{path:.+}")
def get(request: web.Request):
path = request.match_info.get("path", "")
fullpath = osp.join(root_path, path)
assert fullpath.startswith(root_path), f"{fullpath} must be within {root_path}"
assert fullpath.startswith(
root_path
), f"{fullpath} must be within {root_path}"

def handle_fs_tree(p: str, fn: callable) -> dict:
if osp.isdir(p):
Expand Down Expand Up @@ -340,7 +341,9 @@ def __init__(self, address, port, web_dir):

def start(self):
if not osp.isdir(self.cwd):
print(f"Couldn't find web files inside {self.cwd}. Will not start web server.")
print(
f"Couldn't find web files inside {self.cwd}. Will not start web server."
)
return
os.chdir(self.cwd)
with socketserver.TCPServer((self.address, self.port), self.server) as httpd:
Expand Down Expand Up @@ -368,7 +371,14 @@ def get_args():


def create_graph_server(
args, cmd_queue, state_conn, processor_pause_event, view_manager_queue, close_event, root_path, custom_nodes_path
args,
cmd_queue,
state_conn,
processor_pause_event,
view_manager_queue,
close_event,
root_path,
custom_nodes_path,
):
server = GraphServer(
cmd_queue,
Expand Down Expand Up @@ -443,7 +453,7 @@ def signal_handler(*_):

signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)

async def start():
processor = WebInstanceProcessor(
cmd_queue,
Expand Down
Loading

0 comments on commit 4147f76

Please sign in to comment.