Skip to content

Commit

Permalink
Supporting multiple connections (#125)
Browse files Browse the repository at this point in the history
* Simplifying and fixing multiple client connections
  • Loading branch information
rsamf authored Dec 23, 2024
1 parent bfd884a commit e75b098
Show file tree
Hide file tree
Showing 13 changed files with 653 additions and 705 deletions.
12 changes: 11 additions & 1 deletion graphbook/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
from .note import Note
from .decorators import step, param, source, output, batch, resource, event, prompt

__all__ = ["step", "param", "source", "output", "batch", "resource", "event", "prompt", "Note"]
__all__ = [
"step",
"param",
"source",
"output",
"batch",
"resource",
"event",
"prompt",
"Note",
]
105 changes: 64 additions & 41 deletions graphbook/clients.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from typing import Dict
from typing import List, Dict
import uuid
from aiohttp.web import WebSocketResponse
from .processing.web_processor import WebInstanceProcessor
from .utils import ProcessorStateRequest
from .nodes import NodeHub
from .viewer import ViewManager
import tempfile
Expand All @@ -25,7 +24,6 @@ def __init__(
node_hub: NodeHub,
view_manager: ViewManager,
setup_paths: dict,
options: dict = DEFAULT_CLIENT_OPTIONS,
):
self.sid = sid
self.ws = ws
Expand All @@ -35,8 +33,7 @@ def __init__(
self.root_path = Path(setup_paths["workflow_dir"])
self.docs_path = Path(setup_paths["docs_path"])
self.custom_nodes_path = Path(setup_paths["custom_nodes_path"])
self.options = options
self.curr_task = None
self.close_event = asyncio.Event()

def get_root_path(self) -> Path:
return self.root_path
Expand All @@ -59,31 +56,17 @@ def resource_doc(self, name):
def exec(self, req: dict):
self.processor.exec(req)

def poll(self, cmd: ProcessorStateRequest, data: dict = None):
res = self.processor.poll_client(
cmd,
data,
)
return res
def get_processor(self) -> WebInstanceProcessor:
return self.processor

async def _loop(self):
while True:
await asyncio.sleep(self.options["SEND_EVERY"])
current_view_data = self.view_manager.get_current_view_data()
current_states = self.view_manager.get_current_states()
all_data = [*current_view_data, *current_states]
await asyncio.gather(*[self.ws.send_json(data) for data in all_data])
def get_view_manager(self) -> ViewManager:
return self.view_manager

def start(self):
loop = asyncio.get_event_loop()
self.curr_task = loop.create_task(self._loop())
def get_node_hub(self) -> NodeHub:
return self.node_hub

async def close(self):
if self.curr_task is not None:
self.curr_task.cancel()
await self.ws.close()
self.processor.close()
self.node_hub.stop()


class ClientPool:
Expand All @@ -95,6 +78,7 @@ def __init__(
isolate_users: bool,
no_sample: bool,
close_event: mp.Event,
options: dict = DEFAULT_CLIENT_OPTIONS,
):
self.clients: Dict[str, Client] = {}
self.tmpdirs: Dict[str, str] = {}
Expand All @@ -104,10 +88,12 @@ def __init__(
self.shared_execution = not isolate_users
self.no_sample = no_sample
self.close_event = close_event
self.options = options
if self.shared_execution:
self.shared_resources = self._create_resources(
web_processor_args, setup_paths
)
self.curr_task = None

def _create_resources(self, web_processor_args: dict, setup_paths: dict):
view_queue = mp.Queue()
Expand All @@ -116,9 +102,11 @@ def _create_resources(self, web_processor_args: dict, setup_paths: dict):
"custom_nodes_path": setup_paths["custom_nodes_path"],
"view_manager_queue": view_queue,
}
if not self.shared_execution:
processor_args["cwd"] = setup_paths["workflow_dir"].parent
self._create_dirs(**setup_paths, no_sample=self.no_sample)
processor = WebInstanceProcessor(**processor_args)
view_manager = ViewManager(view_queue, self.close_event, processor)
view_manager = ViewManager(view_queue, processor)
node_hub = NodeHub(setup_paths["custom_nodes_path"], self.plugins, view_manager)
processor.start()
view_manager.start()
Expand Down Expand Up @@ -156,10 +144,12 @@ def create_sample_workflow():
if should_create_sample:
create_sample_workflow()

def add_client(self, ws: WebSocketResponse) -> Client:
async def add_client(self, ws: WebSocketResponse) -> Client:
sid = uuid.uuid4().hex
setup_paths = {**self.setup_paths}
if not self.shared_execution:
if self.shared_execution:
resources = self.shared_resources
else:
root_path = Path(tempfile.mkdtemp())
self.tmpdirs[sid] = root_path
setup_paths = {
Expand All @@ -170,32 +160,65 @@ def add_client(self, ws: WebSocketResponse) -> Client:
"custom_nodes_path": setup_paths["custom_nodes_path"],
}
resources = self._create_resources(web_processor_args, setup_paths)
else:
resources = self.shared_resources

client = Client(sid, ws, **resources, setup_paths=setup_paths)
client.start()
self.clients[sid] = client
asyncio.create_task(ws.send_json({"type": "sid", "data": sid}))
await ws.send_json({"type": "sid", "data": sid})
print(f"{sid}: {client.get_root_path()}")
return client

def get(self, sid: str) -> Client | None:
return self.clients.get(sid, None)

async def remove_client(self, client: Client):
sid = client.sid
if sid in self.clients:
await client.close()
del self.clients[sid]
if not self.shared_execution:
client.get_processor().stop()
client.get_node_hub().stop()
client.get_view_manager().stop()
if sid in self.tmpdirs:
shutil.rmtree(self.tmpdirs[sid])
del self.tmpdirs[sid]

async def remove_all(self):
for sid in self.clients:
await self.clients[sid].close()
for sid in self.tmpdirs:
os.rmdir(self.tmpdirs[sid])
self.clients = {}
self.tmpdirs = {}
async def stop(self):
for client in list(self.clients.values()):
await self.remove_client(client)
if self.curr_task:
self.curr_task.cancel()
if self.shared_execution:
self.shared_resources["processor"].close()
self.shared_resources["node_hub"].stop()

def get(self, sid: str) -> Client | None:
return self.clients.get(sid, None)
async def _loop(self):
def get_view_data(view_manager: ViewManager) -> List[dict]:
current_view_data = view_manager.get_current_view_data()
current_states = view_manager.get_current_states()
return [*current_view_data, *current_states]

while not self.close_event.is_set():
await asyncio.sleep(self.options["SEND_EVERY"])

if self.shared_execution:
all_data = get_view_data(self.shared_resources["view_manager"])
for client in self.clients.values():
try:
await asyncio.gather(
*[client.ws.send_json(data) for data in all_data]
)
except Exception as e:
print(f"Error sending to client: {e}")
else:
for client in self.clients.values():
all_data = get_view_data(client.get_view_manager())
try:
await asyncio.gather(
*[client.ws.send_json(data) for data in all_data]
)
except Exception as e:
print(f"Error sending to client: {e}")

async def start(self):
self.curr_task = asyncio.create_task(self._loop())
61 changes: 0 additions & 61 deletions graphbook/logger.py

This file was deleted.

Loading

0 comments on commit e75b098

Please sign in to comment.