Merge pull request #119 from graphbookai/dev
Various improvements
rsamf authored Nov 15, 2024
2 parents 7bebeb0 + 5c223c7 commit e2ea62e
Showing 20 changed files with 365 additions and 192 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
@@ -33,7 +33,7 @@ ENV PATH=$PATH:/root/.local/bin
 # Setup app
 WORKDIR /app
 COPY pyproject.toml poetry.lock ./
-RUN poetry install --no-root --no-directory --with peer
+RUN poetry install --no-directory --with peer
 COPY . .
 RUN make web
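Note: dropping --no-root means poetry install now installs the graphbook project itself into the image, not just its dependencies (Poetry's --no-root flag skips installing the root package).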

2 changes: 2 additions & 0 deletions graphbook/custom_nodes.py
@@ -20,6 +20,7 @@
     Split,
     SplitNotesByItems,
     SplitItemField,
+    Copy,
 )
 from graphbook.resources import Resource, NumberResource, FunctionResource, ListResource, DictResource
 
@@ -33,6 +34,7 @@
     Split,
     SplitNotesByItems,
     SplitItemField,
+    Copy,
 ]
 BUILT_IN_RESOURCES = [Resource, NumberResource, FunctionResource, ListResource, DictResource]
48 changes: 25 additions & 23 deletions graphbook/dataloading.py
@@ -1,12 +1,11 @@
 from typing import List, Dict, Tuple, Any
 import queue
-import torch
-from torch import Tensor
+from torch import set_num_threads
 import torch.multiprocessing as mp
 import traceback
 from .utils import MP_WORKER_TIMEOUT
 
-torch.set_num_threads(1)
+set_num_threads(1)
 MAX_RESULT_QUEUE_SIZE = 32
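Note: narrowing the import to set_num_threads (and dropping Tensor) goes hand in hand with the removal, further down in this file, of the CUDA-tensor cloning in get_load; this module no longer references tensors directly.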


@@ -15,12 +14,12 @@ def do_load(
 ) -> Tuple[bool, Any]:
 
     try:
-        item, index, note_id = work_queue.get(False)
+        item, index, note_id, params = work_queue.get(False)
     except queue.Empty:
         return True, None
 
     try:
-        output = load_fn(item)
+        output = load_fn(item, **params)
         result = (output, index)
         to_return = (result, note_id)
     except Exception as e:
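In plain terms, each queued work item now carries a per-batch params dict that is splatted into the consumer-supplied load function. A minimal standalone sketch of that contract (the load_fn and values here are hypothetical, not graphbook's):

```python
import queue

def load_fn(item, scale=1.0):
    # Hypothetical loader: accepts keyword params forwarded from the queue.
    return item * scale

work_queue = queue.Queue()
# Work items are now 4-tuples: (item, index, note_id, params)
work_queue.put((3, 0, 42, {"scale": 2.0}))

item, index, note_id, params = work_queue.get(False)
output = load_fn(item, **params)
assert output == 6.0
```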
@@ -156,9 +155,15 @@ def dump_loop(
 
 
 class Dataloader:
-    def __init__(self, num_workers: int = 1):
+    def __init__(self, num_workers: int = 1, spawn_method: bool = False):
         self.num_workers = num_workers
-        self.manager = mp.Manager()
+        self.context = mp
+        if spawn_method:
+            print(
+                "Using spawn method is not recommended because it is more error prone. Try to avoid it as much as possible."
+            )
+            self.context = mp.get_context("spawn")
+        self.manager = self.context.Manager()
         self._load_queues: Dict[int, mp.Queue] = self.manager.dict()
         self._dump_queues: Dict[int, mp.Queue] = self.manager.dict()
         self._load_result_queues: Dict[int, mp.Queue] = self.manager.dict()
@@ -171,20 +176,21 @@ def __init__(self, num_workers: int = 1):
         self._pending_dump_results: List[PendingResult] = self.manager.list(
             [None for _ in range(num_workers)]
         )
+
         self._workers: List[mp.Process] = []
         self._loaders: List[mp.Process] = []
         self._dumpers: List[mp.Process] = []
         self._worker_queue_cycle = 0
-        self._close_event: mp.Event = mp.Event()
-        self._fail_event: mp.Event = mp.Event()
+        self._close_event: mp.Event = self.context.Event()
+        self._fail_event: mp.Event = self.context.Event()
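For context, a standalone sketch of the start-method selection this constructor now performs (recreated for illustration; get_context and the context-bound Event/Process are real torch.multiprocessing API, the helper name is invented):

```python
import torch.multiprocessing as mp

def pick_context(spawn_method: bool = False):
    # "spawn" starts fresh interpreters and pickles state over, which is
    # slower and stricter about picklability than the default start method
    # (fork on Linux) -- hence the warning above against using it casually.
    if spawn_method:
        return mp.get_context("spawn")
    return mp  # the module itself acts as the default context

ctx = pick_context(spawn_method=True)
event = ctx.Event()       # synchronization primitives come from the context
worker_cls = ctx.Process  # as do worker processes (see _start_workers below)
```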

     def _start_workers(self):
         if len(self._workers) > 0:
             return
         self._fail_event.clear()
         self._close_event.clear()
         for i in range(self.num_workers):
-            load_process = mp.Process(
+            load_process = self.context.Process(
                 target=load_loop,
                 args=(
                     i,
@@ -199,7 +205,7 @@ def _start_workers(self):
             )
             load_process.daemon = True
             load_process.start()
-            dump_process = mp.Process(
+            dump_process = self.context.Process(
                 target=dump_loop,
                 args=(
                     i,
@@ -292,9 +298,6 @@ def get_all_sizes(self):
         return sz
 
     def clear(self, consumer_id: int | None = None):
-        # There's a weird issue where queue.empty() evaluates to True even though there are still items in the queue.
-        # So we instead close the queue because workers should be killed by now and will need to be restarted
-        # with new queues from the graph.
         def clear_queue(q: mp.Queue):
             while not q.empty():
                 try:
@@ -336,9 +339,13 @@ def clear_queue(q: mp.Queue):
         if consumer_id in self._consumer_dump_fn:
             del self._consumer_dump_fn[consumer_id]
 
-    def put_load(self, items: list, note_id: int, consumer_id: int):
+    def put_load(
+        self, items: list, load_fn_params: dict, note_id: int, consumer_id: int
+    ):
         for i, item in enumerate(items):
-            self._load_queues[consumer_id].put((item, i, note_id), block=False)
+            self._load_queues[consumer_id].put(
+                (item, i, note_id, load_fn_params), block=False
+            )
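Callers now pass one load_fn_params dict per batch: every item in items is enqueued with the same params, and the worker applies them via load_fn(item, **params) in do_load above.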

     def get_load(self, consumer_id):
         if consumer_id not in self._load_result_queues:
@@ -351,11 +358,6 @@ def get_load(self, consumer_id):
             if result is None:
                 return None, note_id
             out, index = result
-            # https://pytorch.org/docs/stable/multiprocessing.html#sharing-cuda-tensors
-            if isinstance(out, Tensor):
-                out_clone = out.clone()
-                del out
-                out = out_clone
             return (out, index), note_id
         except queue.Empty:
             return None
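The deleted branch was a workaround for PyTorch's CUDA-tensor sharing rules (the removed docs link): cloning gave the consumer its own storage so the producer's reference could be released safely. Its removal, together with the dropped Tensor import at the top of the file, suggests load results are no longer expected to be tensors at this layer.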
@@ -395,9 +397,9 @@ def setup_global_dl(dataloader: Dataloader):
     workers = dataloader
 
 
-def put_load(items: list, note_id: int, consumer_id: int):
+def put_load(items: list, load_fn_params: dict, note_id: int, consumer_id: int):
     global workers
-    workers.put_load(items, note_id, consumer_id)
+    workers.put_load(items, load_fn_params, note_id, consumer_id)


def get_load(consumer_id):
1 change: 1 addition & 0 deletions graphbook/exports.py
@@ -9,6 +9,7 @@
     "Split": steps.Split,
     "SplitNotesByItems": steps.SplitNotesByItems,
     "SplitItemField": steps.SplitItemField,
+    "Copy": steps.Copy,
     "DumpJSONL": steps.DumpJSONL,
     "LoadJSONL": steps.LoadJSONL,
 }
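Together with the graphbook/steps/__init__.py change below, the built-in Copy step is now reachable both by name through this export map and as a direct import. A quick sanity check, assuming an install at this commit:

```python
from graphbook.steps import Copy  # re-exported by graphbook/steps/__init__.py
```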
5 changes: 5 additions & 0 deletions graphbook/main.py
@@ -68,6 +68,11 @@ def get_args():
         action="store_true",
         help="Do not create a sample workflow if the workflow directory does not exist",
     )
+    parser.add_argument(
+        "--spawn",
+        action="store_true",
+        help="Use the spawn start method for multiprocessing",
+    )
 
     return parser.parse_args()
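A standalone reproduction of the new flag's behavior (illustrative only; graphbook's real parser defines many more arguments). The flag presumably flows from here into WebInstanceProcessor and then the Dataloader, per the web_processor.py change below:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--spawn",
    action="store_true",
    help="Use the spawn start method for multiprocessing",
)
assert parser.parse_args(["--spawn"]).spawn is True
assert parser.parse_args([]).spawn is False  # store_true defaults to False
```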

3 changes: 2 additions & 1 deletion graphbook/processing/web_processor.py
@@ -40,6 +40,7 @@ def __init__(
         custom_nodes_path: str,
         close_event: mp.Event,
         pause_event: mp.Event,
+        spawn_method: bool,
         num_workers: int = 1,
     ):
         self.cmd_queue = cmd_queue
@@ -53,7 +54,7 @@ def __init__(
         self.custom_nodes_path = custom_nodes_path
         self.num_workers = num_workers
         self.steps = {}
-        self.dataloader = Dataloader(self.num_workers)
+        self.dataloader = Dataloader(self.num_workers, spawn_method)
         setup_global_dl(self.dataloader)
         self.state_client = ProcessorStateClient(
             server_request_conn,
62 changes: 35 additions & 27 deletions graphbook/state.py
@@ -125,7 +125,7 @@ def _update_custom_nodes(self) -> dict:
                 if issubclass(obj, Resource):
                     self.nodes["resources"][name] = obj
                     updated_nodes["resources"][name] = True
-
+
         for name, cls in get_steps().items():
             self.nodes["steps"][name] = cls
             updated_nodes["steps"][name] = True
@@ -158,6 +158,7 @@ def __init__(self, custom_nodes_path: str, view_manager_queue: mp.Queue):
         self._node_catalog = NodeCatalog(custom_nodes_path)
         self._updated_nodes: Dict[str, Dict[str, bool]] = {}
         self._step_states: Dict[str, Set[StepState]] = {}
+        self._step_graph = {"child": {}, "parent": {}}
 
     def update_state(self, graph: dict, graph_resources: dict):
         nodes, is_updated = self._node_catalog.get_nodes()
@@ -214,9 +215,10 @@ def set_resource_value(resource_id, resource_data):
             set_resource_value(resource_id, resource_data)
 
         # Next, create all steps
-        steps = {}
-        queues = {}
-        step_states = {}
+        steps: Dict[str, Step] = {}
+        queues: Dict[str, MultiConsumerStateDictionaryQueue] = {}
+        step_states: Dict[str, Set[StepState]] = {}
+        step_graph = {"child": {}, "parent": {}}
         logger_param_pool = {}
         for step_id, step_data in graph.items():
             step_name = step_data["name"]
@@ -240,6 +242,8 @@ def set_resource_value(resource_id, resource_data):
                 queues[step_id] = self._queues[step_id]
                 step_states[step_id] = self._step_states[step_id]
                 step_states[step_id].discard(StepState.EXECUTED_THIS_RUN)
+                step_graph["parent"][step_id] = self._step_graph["parent"][step_id]
+                step_graph["child"][step_id] = self._step_graph["child"][step_id]
                 logger_param_pool[id(self._steps[step_id])] = (step_id, step_name)
             else:
                 try:
@@ -252,11 +256,15 @@ def set_resource_value(resource_id, resource_data):
                 step_states[step_id] = set()
                 logger_param_pool[id(step)] = (step_id, step_name)
 
             # Remove old consumers from parents
             previous_obj = self._steps.get(step_id)
             if previous_obj is not None:
-                for parent in previous_obj.parents:
-                    if parent.id in self._queues:
-                        self._queues[parent.id].remove_consumer(id(previous_obj))
+                parent_ids = self._step_graph["parent"][previous_obj.id]
+                for parent_id in parent_ids:
+                    if parent_id in self._queues:
+                        self._queues[parent_id].remove_consumer(id(previous_obj))
+            step_graph["parent"][step_id] = set()
+            step_graph["child"][step_id] = set()
 
         # Next, connect the steps
for step_id, step_data in graph.items():
Expand All @@ -265,27 +273,27 @@ def set_resource_value(resource_id, resource_data):
node = input["node"]
slot = input["slot"]
parent_node = steps[node]
if parent_node not in child_node.parents:
parent_node.set_child(child_node, slot)
step_graph["parent"][child_node.id].add(parent_node.id)
step_graph["child"][parent_node.id].add(child_node.id)
# Note: Two objects with non-overlapping lifetimes may have the same id() value.
# But in this case, the below child_node object is not overlapping because at
# this point, any previous nodes in the graph are still in self._steps
queues[parent_node.id].add_consumer(id(child_node), slot)

# Remove consumers from parents that are not children
for step_id in steps:
parent_node = steps[step_id]
children_ids = [
id(child)
for label_steps in parent_node.children.values()
for child in label_steps
id(steps[child_id]) for child_id in step_graph["child"][step_id]
]
queues[step_id].remove_except(children_ids)
queues[step_id].remove_all_except(children_ids)

def get_parent_iterator(step_id):
step = steps[step_id]
p_index = 0
parents = list(self._step_graph["parent"][step_id])
while True:
yield step.parents[p_index]
p_index = (p_index + 1) % len(step.parents)
yield parents[p_index]
p_index = (p_index + 1) % len(parents)

self._parent_iterators = {
step_id: get_parent_iterator(step_id) for step_id in steps
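The rewritten generator round-robins over a snapshot of the step's parent ids, so get_input (below) polls each upstream queue in turn. A rough standalone equivalent (step ids invented for illustration):

```python
from itertools import cycle

parents = ["crop_step", "resize_step"]  # hypothetical parent step ids
rr = cycle(parents)
assert [next(rr) for _ in range(4)] == [
    "crop_step", "resize_step", "crop_step", "resize_step"
]
```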
@@ -300,6 +308,7 @@ def get_parent_iterator(step_id):
         self._queues = queues
         self._resource_values = resource_values
         self._step_states = step_states
+        self._step_graph = step_graph
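The thread running through this file's changes: parent/child relationships are now tracked in an id-keyed adjacency structure instead of being read off the Step objects' parents/children attributes. A minimal sketch of the shape, inferred from this diff:

```python
# Two adjacency maps, step id -> set of step ids (shape inferred from the diff)
step_graph = {"child": {}, "parent": {}}

def connect(parent_id: str, child_id: str):
    step_graph["parent"].setdefault(child_id, set()).add(parent_id)
    step_graph["child"].setdefault(parent_id, set()).add(child_id)

connect("load", "transform")
connect("transform", "dump")
assert step_graph["parent"]["dump"] == {"transform"}
assert step_graph["child"]["load"] == {"transform"}
```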

     def create_parent_subgraph(self, step_id: str):
         new_steps = {}
@@ -311,9 +320,8 @@ def create_parent_subgraph(self, step_id: str):
                 continue
 
             new_steps[step_id] = self._steps[step_id]
-            step = self._steps[step_id]
-            for input in step.parents:
-                q.append(input.id)
+            for parent_id in self._step_graph["parent"][step_id]:
+                q.append(parent_id)
         return new_steps

     def get_processing_steps(self, step_id: str = None):
@@ -330,9 +338,9 @@ def dfs(step_id):
                 return
             visited.add(step_id)
             step = steps[step_id]
-            for child in step.children.values():
-                for c in child:
-                    dfs(c.id)
+            children = self._step_graph["child"][step_id]
+            for child_id in children:
+                dfs(child_id)
             ordered_steps.append(step)
 
         for step_id in steps:
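The DFS above appends a step only after recursing into all of its children, so ordered_steps comes out in reverse-topological order. A self-contained illustration with invented ids:

```python
graph = {"load": {"transform"}, "transform": {"dump"}, "dump": set()}
visited, ordered = set(), []

def dfs(node_id):
    if node_id in visited:
        return
    visited.add(node_id)
    for child_id in graph[node_id]:
        dfs(child_id)
    ordered.append(node_id)  # post-order: children land before parents

for node_id in graph:
    dfs(node_id)
assert ordered == ["dump", "transform", "load"]
```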
@@ -374,12 +382,12 @@ def clear_outputs(self, node_id: str | None = None):
             del self._resource_values[node_id], self._dict_resources[node_id]
 
     def get_input(self, step: Step) -> Note:
-        num_parents = len(step.parents)
+        num_parents = len(self._step_graph["parent"][step.id])
         i = 0
         while i < num_parents:
-            next_parent = next(self._parent_iterators[step.id])
+            next_parent_id = next(self._parent_iterators[step.id])
             try:
-                next_input = self._queues[next_parent.id].dequeue(id(step))
+                next_input = self._queues[next_parent_id].dequeue(id(step))
                 return next_input
             except StopIteration:
                 i += 1
@@ -402,7 +410,7 @@ def get_output_note(self, step_id: str, pin_id: str, index: int) -> dict:
         note = internal_list[index]
         entry.update(data=note.items)
         return entry
-
+
     def handle_prompt_response(self, step_id: str, response: dict) -> bool:
         step = self._steps.get(step_id)
         if not isinstance(step, PromptStep):
@@ -444,7 +452,7 @@ def remove_consumer(self, consumer_id: int):
         del self._consumer_idx[consumer_id]
         del self._consumer_subs[consumer_id]
 
-    def remove_except(self, consumer_ids: List[int]):
+    def remove_all_except(self, consumer_ids: List[int]):
         self._consumer_idx = {
             k: v for k, v in self._consumer_idx.items() if k in consumer_ids
         }
2 changes: 2 additions & 0 deletions graphbook/steps/__init__.py
@@ -9,6 +9,7 @@
     Split,
     SplitNotesByItems,
     SplitItemField,
+    Copy
 )
 from .io import LoadJSONL, DumpJSONL
 
@@ -23,6 +24,7 @@
     "Split",
     "SplitNotesByItems",
     "SplitItemField",
+    "Copy",
     "LoadJSONL",
     "DumpJSONL",
 ]