diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index b3e1521..2470540 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -2,16 +2,26 @@ name: Docker Build on: push: - branches: ["main"] - paths-ignore: - - 'docs/**' - - 'tests/**' + tags: + - "v*" + pull_request: jobs: - docker: + base: runs-on: group: larger-runners steps: + - name: Docker meta + id: meta + uses: docker/metadata-action@v5 + with: + images: | + rsamf/graphbook + tags: | + type=ref,event=tag + type=semver,pattern={{version}} + type=sha + - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 @@ -24,7 +34,46 @@ jobs: - name: Build and push uses: docker/build-push-action@v6 with: - push: true - tags: rsamf/graphbook:latest + push: ${{ github.event_name == 'push' && contains(github.ref, 'refs/tags') }} + file: ./docker/Dockerfile + tags: ${{ steps.meta.outputs.tags }} + labels: ${{ steps.meta.outputs.labels }} cache-from: type=registry,ref=rsamf/graphbook:latest cache-to: type=inline + + hfspace: + needs: base + runs-on: + group: larger-runners + steps: + - name: Docker meta + id: meta + uses: docker/metadata-action@v5 + with: + flavor: | + suffix=-space + images: | + rsamf/graphbook + tags: | + type=ref,event=tag + type=semver,pattern={{version}} + type=sha + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Login to Docker Hub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + + - name: Build and push + uses: docker/build-push-action@v6 + with: + push: ${{ github.event_name == 'push' && contains(github.ref, 'refs/tags') }} + file: ./docker/Dockerfile.hfspace + tags: ${{ steps.meta.outputs.tags }} + labels: ${{ steps.meta.outputs.labels }} + cache-from: type=registry,ref=rsamf/graphbook:latest-space + cache-to: type=inline diff --git a/README.md b/README.md index 3a6dc60..af59050 100644 --- a/README.md +++ b/README.md @@ -133,3 +133,4 @@ You can use any other virtual environment solution, but it is highly adviced to 1. `cd web` 1. `deno install` 1. `deno run dev` +1. In your browser, navigate to localhost:5173, and in the settings, change your **Graph Server Host** to `localhost:8005`. diff --git a/Dockerfile b/docker/Dockerfile similarity index 100% rename from Dockerfile rename to docker/Dockerfile diff --git a/docker/Dockerfile.hfspace b/docker/Dockerfile.hfspace new file mode 100644 index 0000000..20b13e4 --- /dev/null +++ b/docker/Dockerfile.hfspace @@ -0,0 +1,5 @@ +FROM rsamf/graphbook:latest + +RUN chown -R 1000:1000 . + +CMD ["python", "-m", "graphbook.main", "--isolate-users"] diff --git a/graphbook/clients.py b/graphbook/clients.py new file mode 100644 index 0000000..63ed3e7 --- /dev/null +++ b/graphbook/clients.py @@ -0,0 +1,201 @@ +from typing import Dict +import uuid +from aiohttp.web import WebSocketResponse +from .processing.web_processor import WebInstanceProcessor +from .utils import ProcessorStateRequest +from .nodes import NodeHub +from .viewer import ViewManager +import tempfile +import os.path as osp +from pathlib import Path +import multiprocessing as mp +import os +import asyncio +import shutil + +DEFAULT_CLIENT_OPTIONS = {"SEND_EVERY": 0.5} + + +class Client: + def __init__( + self, + sid: str, + ws: WebSocketResponse, + processor: WebInstanceProcessor, + node_hub: NodeHub, + view_manager: ViewManager, + setup_paths: dict, + options: dict = DEFAULT_CLIENT_OPTIONS, + ): + self.sid = sid + self.ws = ws + self.processor = processor + self.node_hub = node_hub + self.view_manager = view_manager + self.root_path = Path(setup_paths["workflow_dir"]) + self.docs_path = Path(setup_paths["docs_path"]) + self.custom_nodes_path = Path(setup_paths["custom_nodes_path"]) + self.options = options + self.curr_task = None + + def get_root_path(self) -> Path: + return self.root_path + + def get_docs_path(self) -> Path: + return self.docs_path + + def get_custom_nodes_path(self) -> Path: + return self.custom_nodes_path + + def nodes(self): + return self.node_hub.get_exported_nodes() + + def step_doc(self, name): + return self.node_hub.get_step_docstring(name) + + def resource_doc(self, name): + return self.node_hub.get_resource_docstring(name) + + def exec(self, req: dict): + self.processor.exec(req) + + def poll(self, cmd: ProcessorStateRequest, data: dict = None): + res = self.processor.poll_client( + cmd, + data, + ) + return res + + async def _loop(self): + while True: + await asyncio.sleep(self.options["SEND_EVERY"]) + current_view_data = self.view_manager.get_current_view_data() + current_states = self.view_manager.get_current_states() + all_data = [*current_view_data, *current_states] + await asyncio.gather(*[self.ws.send_json(data) for data in all_data]) + + def start(self): + loop = asyncio.get_event_loop() + self.curr_task = loop.create_task(self._loop()) + + async def close(self): + if self.curr_task is not None: + self.curr_task.cancel() + await self.ws.close() + self.processor.close() + self.node_hub.stop() + + +class ClientPool: + def __init__( + self, + web_processor_args: dict, + setup_paths: dict, + plugins: tuple, + isolate_users: bool, + no_sample: bool, + close_event: mp.Event, + ): + self.clients: Dict[str, Client] = {} + self.tmpdirs: Dict[str, str] = {} + self.web_processor_args = web_processor_args + self.setup_paths = setup_paths + self.plugins = plugins + self.shared_execution = not isolate_users + self.no_sample = no_sample + self.close_event = close_event + if self.shared_execution: + self.shared_resources = self._create_resources( + web_processor_args, setup_paths + ) + + def _create_resources(self, web_processor_args: dict, setup_paths: dict): + view_queue = mp.Queue() + processor_args = { + **web_processor_args, + "custom_nodes_path": setup_paths["custom_nodes_path"], + "view_manager_queue": view_queue, + } + self._create_dirs(**setup_paths, no_sample=self.no_sample) + processor = WebInstanceProcessor(**processor_args) + view_manager = ViewManager(view_queue, self.close_event, processor) + node_hub = NodeHub(setup_paths["custom_nodes_path"], self.plugins, view_manager) + processor.start() + view_manager.start() + node_hub.start() + return { + "processor": processor, + "node_hub": node_hub, + "view_manager": view_manager, + } + + def _create_dirs( + self, workflow_dir: str, custom_nodes_path: str, docs_path: str, no_sample: bool + ): + def create_sample_workflow(): + import shutil + + project_path = Path(__file__).parent + assets_dir = project_path.joinpath("sample_assets") + n = "SampleWorkflow.json" + shutil.copyfile(assets_dir.joinpath(n), Path(workflow_dir).joinpath(n)) + n = "SampleWorkflow.md" + shutil.copyfile(assets_dir.joinpath(n), Path(docs_path).joinpath(n)) + n = "sample_nodes.py" + shutil.copyfile(assets_dir.joinpath(n), Path(custom_nodes_path).joinpath(n)) + + should_create_sample = False + if not osp.exists(workflow_dir): + should_create_sample = not no_sample + os.mkdir(workflow_dir) + if not osp.exists(custom_nodes_path): + os.mkdir(custom_nodes_path) + if not osp.exists(docs_path): + os.mkdir(docs_path) + + if should_create_sample: + create_sample_workflow() + + def add_client(self, ws: WebSocketResponse) -> Client: + sid = uuid.uuid4().hex + setup_paths = {**self.setup_paths} + if not self.shared_execution: + root_path = Path(tempfile.mkdtemp()) + self.tmpdirs[sid] = root_path + setup_paths = { + key: root_path.joinpath(path) for key, path in setup_paths.items() + } + web_processor_args = { + **self.web_processor_args, + "custom_nodes_path": setup_paths["custom_nodes_path"], + } + resources = self._create_resources(web_processor_args, setup_paths) + else: + resources = self.shared_resources + + client = Client(sid, ws, **resources, setup_paths=setup_paths) + client.start() + self.clients[sid] = client + asyncio.create_task(ws.send_json({"type": "sid", "data": sid})) + print(f"{sid}: {client.get_root_path()}") + return client + + async def remove_client(self, client: Client): + sid = client.sid + if sid in self.clients: + await client.close() + del self.clients[sid] + if sid in self.tmpdirs: + shutil.rmtree(self.tmpdirs[sid]) + del self.tmpdirs[sid] + + async def remove_all(self): + for sid in self.clients: + await self.clients[sid].close() + for sid in self.tmpdirs: + os.rmdir(self.tmpdirs[sid]) + self.clients = {} + self.tmpdirs = {} + + def get(self, sid: str) -> Client | None: + return self.clients.get(sid, None) diff --git a/graphbook/custom_nodes.py b/graphbook/custom_nodes.py deleted file mode 100644 index 8f880c0..0000000 --- a/graphbook/custom_nodes.py +++ /dev/null @@ -1,152 +0,0 @@ -import asyncio -from watchdog.events import FileSystemEvent, FileSystemEventHandler -from watchdog.observers import Observer -import importlib -import importlib.util -import hashlib -import sys -import os -import os.path as osp -import inspect -import traceback -from .decorators import get_steps, get_resources -from .steps import ( - Step, - BatchStep, - PromptStep, - SourceStep, - GeneratorSourceStep, - AsyncStep, - Split, - SplitNotesByItems, - SplitItemField, - Copy, -) -from .resources import Resource, NumberResource, FunctionResource, ListResource, DictResource - -BUILT_IN_STEPS = [ - Step, - BatchStep, - PromptStep, - SourceStep, - GeneratorSourceStep, - AsyncStep, - Split, - SplitNotesByItems, - SplitItemField, - Copy, -] -BUILT_IN_RESOURCES = [Resource, NumberResource, FunctionResource, ListResource, DictResource] - - -class CustomModuleEventHandler(FileSystemEventHandler): - def __init__(self, root_path, handler): - super().__init__() - self.root_path = osp.abspath(root_path) - self.handler = handler - self.ha = {} - - def on_created(self, event): - if event.is_directory: - return - self.handle_new_file_sync(event.src_path) - - def on_modified(self, event): - if event.is_directory: - return - self.handle_new_file_sync(event.src_path) - - def on_deleted(self, event): - if event.is_directory: - return - - def on_moved(self, event: FileSystemEvent) -> None: - if event.is_directory: - return - self.handle_new_file_sync(event.dest_path) - - async def handle_new_file(self, filename: str): - filename = osp.abspath(filename) - assert filename.startswith( - self.root_path - ), f"Received extraneous file {filename} during tracking of {self.root_path}" - if not filename.endswith(".py"): - return - - with open(filename, "r") as f: - contents = f.read() - - hash_code = hashlib.md5(contents.encode()).hexdigest() - og_hash_code = self.ha.get(filename, None) - if hash_code == og_hash_code: - return - - self.ha[filename] = hash_code - filename = filename[len(self.root_path) + 1 :] - components = filename[: filename.index(".py")].split("/") - module_name = ".".join(components) - - try: - if og_hash_code is None: - importlib.import_module(module_name) - else: - module = importlib.import_module(module_name) - importlib.reload(module) - except Exception as e: - print(f"Error loading {module_name}:") - traceback.print_exc() - return - - module = sys.modules[module_name] - await self.handler(filename, module) - - def handle_new_file_sync(self, filename: str): - asyncio.run(self.handle_new_file(filename)) - - async def init_custom_nodes(self): - for root, dirs, files in os.walk(self.root_path): - for file in files: - await self.handle_new_file(osp.join(root, file)) - - def init_custom_nodes_sync(self): - asyncio.run(self.init_custom_nodes()) - - -class CustomNodeImporter: - def __init__(self, path, step_handler, resource_handler): - self.websocket = None - self.path = path - self.step_handler = step_handler - self.resource_handler = resource_handler - sys.path.append(path) - self.observer = Observer() - self.event_handler = CustomModuleEventHandler(path, self.on_module) - self.event_handler.init_custom_nodes_sync() - - def set_websocket(self, websocket): - self.websocket = websocket - - async def on_module(self, filename, mod): - for name, obj in inspect.getmembers(mod): - if inspect.isclass(obj): - if issubclass(obj, Step) and not obj in BUILT_IN_STEPS: - self.step_handler(filename, name, obj) - if issubclass(obj, Resource) and not obj in BUILT_IN_RESOURCES: - self.resource_handler(filename, name, obj) - - for name, cls in get_steps().items(): - self.step_handler(filename, name, cls) - for name, cls in get_resources().items(): - self.resource_handler(filename, name, cls) - - if self.websocket is not None and not self.websocket.closed: - print("Sending node updated") - await self.websocket.send_json({"type": "node_updated"}) - - def start_observer(self): - self.observer.schedule(self.event_handler, self.path, recursive=True) - self.observer.start() - - def stop_observer(self): - self.observer.stop() - self.observer.join() diff --git a/graphbook/exports.py b/graphbook/exports.py deleted file mode 100644 index 80156a8..0000000 --- a/graphbook/exports.py +++ /dev/null @@ -1,124 +0,0 @@ -from . import steps, resources, custom_nodes -from .doc2md import convert_to_md -from .plugins import setup_plugins -from aiohttp import web - -default_exported_steps = { - "Split": steps.Split, - "SplitNotesByItems": steps.SplitNotesByItems, - "SplitItemField": steps.SplitItemField, - "Copy": steps.Copy, - "DumpJSONL": steps.DumpJSONL, - "LoadJSONL": steps.LoadJSONL, -} - -default_exported_resources = { - "Text": resources.Resource, - "Number": resources.NumberResource, - "Function": resources.FunctionResource, - "List": resources.ListResource, - "Dict": resources.DictResource, -} - - -class NodeHub: - def __init__(self, path): - self.exported_steps = default_exported_steps - self.exported_resources = default_exported_resources - self.custom_node_importer = custom_nodes.CustomNodeImporter( - path, self.handle_step, self.handle_resource - ) - self.plugins = setup_plugins() - steps, resources, web = self.plugins - for plugin in steps: - self.exported_steps.update(steps[plugin]) - for plugin in resources: - self.exported_resources.update(resources[plugin]) - self.web_plugins = web - - def start(self): - self.custom_node_importer.start_observer() - - def stop(self): - self.custom_node_importer.stop_observer() - - def handle_step(self, filename, name, step): - print(f"{filename}: {name} (step)") - self.exported_steps[name] = step - - def handle_resource(self, filename, name, resource): - print(f"{filename}: {name} (resource)") - self.exported_resources[name] = resource - - def get_steps(self): - return self.exported_steps - - def get_resources(self): - return self.exported_resources - - def get_all(self): - return {"steps": self.get_steps(), "resources": self.get_resources()} - - def get_web_plugins(self): - return self.web_plugins - - def get_step_docstring(self, name): - if name in self.exported_steps: - docstring = self.exported_steps[name].__doc__ - if docstring is not None: - docstring = convert_to_md(docstring) - return docstring - return None - - def get_resource_docstring(self, name): - if name in self.exported_resources: - docstring = self.exported_resources[name].__doc__ - if docstring is not None: - docstring = convert_to_md(docstring) - return docstring - return None - - def get_exported_nodes(self): - # Create directory structure for nodes based on their category - def create_dir_structure(nodes): - node_tree = {} - for node_name in nodes: - node = nodes[node_name] - if node["category"] == "": - node_tree[node_name] = node - else: - category_tree = node["category"].split("/") - curr_category = node_tree - for category in category_tree: - if curr_category.get(category) is None: - curr_category[category] = {"children": {}} - curr_category = curr_category[category]["children"] - curr_category[node_name] = node - return node_tree - - steps = { - k: { - "name": k, - "parameters": v.Parameters, - "inputs": ["in"] if v.RequiresInput else [], - "outputs": v.Outputs, - "category": v.Category, - } - for k, v in self.get_steps().items() - } - resources = { - k: { - "name": k, - "parameters": v.Parameters, - "category": v.Category, - } - for k, v in self.get_resources().items() - } - - return { - "steps": create_dir_structure(steps), - "resources": create_dir_structure(resources), - } - - def set_websocket(self, websocket: web.WebSocketResponse): - self.custom_node_importer.set_websocket(websocket) diff --git a/graphbook/main.py b/graphbook/main.py index 160fa67..8de5f8f 100644 --- a/graphbook/main.py +++ b/graphbook/main.py @@ -1,6 +1,6 @@ #!/usr/bin/env python import argparse -import os.path as osp +from pathlib import Path from graphbook.web import start_web from graphbook import config @@ -8,7 +8,7 @@ Graphbook | ML Workflow Framework """ -workflow_dir = "./workflow" +workflow_dir = "workflow" nodes_dir = "custom_nodes" docs_dir = "docs" @@ -48,13 +48,13 @@ def get_args(): parser.add_argument( "--nodes_dir", type=str, - default=osp.join(workflow_dir, nodes_dir), + default=str(Path(workflow_dir).joinpath(nodes_dir)), help="Path to the custom nodes directory", ) parser.add_argument( "--docs_dir", type=str, - default=osp.join(workflow_dir, docs_dir), + default=str(Path(workflow_dir).joinpath(docs_dir)), help="Path to the docs directory", ) parser.add_argument( @@ -73,6 +73,11 @@ def get_args(): action="store_true", help="Use the spawn start method for multiprocessing", ) + parser.add_argument( + "--isolate-users", + action="store_true", + help="Isolate each user in their own execution environment. Does NOT prevent users from accessing each other's files.", + ) return parser.parse_args() @@ -80,8 +85,9 @@ def get_args(): def resolve_paths(args): if args.root_dir: args.workflow_dir = args.root_dir - args.nodes_dir = osp.join(args.root_dir, nodes_dir) - args.docs_dir = osp.join(args.root_dir, docs_dir) + args.nodes_dir = str(Path(args.root_dir).joinpath(nodes_dir)) + args.docs_dir = str(Path(args.root_dir).joinpath(docs_dir)) + print(args.nodes_dir) return args diff --git a/graphbook/media.py b/graphbook/media.py index 9b32f88..028fc33 100644 --- a/graphbook/media.py +++ b/graphbook/media.py @@ -1,6 +1,7 @@ import asyncio from aiohttp import web import os.path as osp +from pathlib import Path @web.middleware @@ -46,8 +47,8 @@ async def set_var_handler(request: web.Request): @routes.get(r"/{path:.*}") async def handle(request: web.Request) -> web.Response: path = request.match_info["path"] - full_path = osp.join(self.root_path, path) - if not osp.exists(full_path): + full_path = Path(self.root_path).joinpath(path) + if not full_path.exists(): return web.HTTPNotFound() return web.FileResponse(full_path) diff --git a/graphbook/nodes.py b/graphbook/nodes.py new file mode 100644 index 0000000..bc4ee0e --- /dev/null +++ b/graphbook/nodes.py @@ -0,0 +1,255 @@ +from typing import Tuple +from . import steps, resources +from .doc2md import convert_to_md +from watchdog.events import FileSystemEvent, FileSystemEventHandler +from watchdog.observers import Observer +from pathlib import Path +import importlib +import importlib.util +import hashlib +import sys +import os +import inspect +import traceback +from .decorators import get_steps, get_resources +from .viewer import ViewManager + + +BUILT_IN_STEPS = [ + steps.Step, + steps.BatchStep, + steps.PromptStep, + steps.SourceStep, + steps.GeneratorSourceStep, + steps.AsyncStep, + steps.Split, + steps.SplitNotesByItems, + steps.SplitItemField, + steps.Copy, +] +BUILT_IN_RESOURCES = [ + resources.Resource, + resources.NumberResource, + resources.FunctionResource, + resources.ListResource, + resources.DictResource, +] + +default_exported_steps = { + "Split": steps.Split, + "SplitNotesByItems": steps.SplitNotesByItems, + "SplitItemField": steps.SplitItemField, + "Copy": steps.Copy, + "DumpJSONL": steps.DumpJSONL, + "LoadJSONL": steps.LoadJSONL, +} + +default_exported_resources = { + "Text": resources.Resource, + "Number": resources.NumberResource, + "Function": resources.FunctionResource, + "List": resources.ListResource, + "Dict": resources.DictResource, +} + + +class NodeHub: + def __init__( + self, path: Path, plugins: Tuple[dict, dict], view_manager: ViewManager + ): + self.exported_steps = default_exported_steps + self.exported_resources = default_exported_resources + self.view_manager = view_manager + self.custom_node_importer = CustomNodeImporter( + path, self.handle_step, self.handle_resource, self.handle_module + ) + plugin_steps, plugin_resources = plugins + for plugin in plugin_steps: + self.exported_steps.update(plugin_steps[plugin]) + for plugin in plugin_resources: + self.exported_resources.update(plugin_resources[plugin]) + + def start(self): + self.custom_node_importer.start_observer() + + def stop(self): + self.custom_node_importer.stop_observer() + + def handle_module(self, filename, module): + self.view_manager.set_state("node_updated") + + def handle_step(self, filename: Path, name: str, step: steps.Step): + print(f"{filename.name}: {name} (step)") + self.exported_steps[name] = step + + def handle_resource(self, filename: Path, name: str, resource: resources.Resource): + print(f"{filename.name}: {name} (resource)") + self.exported_resources[name] = resource + + def get_steps(self): + return self.exported_steps + + def get_resources(self): + return self.exported_resources + + def get_all(self): + return {"steps": self.get_steps(), "resources": self.get_resources()} + + def get_step_docstring(self, name): + if name in self.exported_steps: + docstring = self.exported_steps[name].__doc__ + if docstring is not None: + docstring = convert_to_md(docstring) + return docstring + return None + + def get_resource_docstring(self, name): + if name in self.exported_resources: + docstring = self.exported_resources[name].__doc__ + if docstring is not None: + docstring = convert_to_md(docstring) + return docstring + return None + + def get_exported_nodes(self): + # Create directory structure for nodes based on their category + def create_dir_structure(nodes): + node_tree = {} + for node_name in nodes: + node = nodes[node_name] + if node["category"] == "": + node_tree[node_name] = node + else: + category_tree = node["category"].split("/") + curr_category = node_tree + for category in category_tree: + if curr_category.get(category) is None: + curr_category[category] = {"children": {}} + curr_category = curr_category[category]["children"] + curr_category[node_name] = node + return node_tree + + steps = { + k: { + "name": k, + "parameters": v.Parameters, + "inputs": ["in"] if v.RequiresInput else [], + "outputs": v.Outputs, + "category": v.Category, + } + for k, v in self.get_steps().items() + } + resources = { + k: { + "name": k, + "parameters": v.Parameters, + "category": v.Category, + } + for k, v in self.get_resources().items() + } + + return { + "steps": create_dir_structure(steps), + "resources": create_dir_structure(resources), + } + + +class CustomModuleEventHandler(FileSystemEventHandler): + def __init__(self, root_path: Path, handler: callable): + super().__init__() + self.root_path = Path(root_path).absolute() + self.handler = handler + self.ha = {} + + def on_created(self, event): + if event.is_directory: + return + self.handle_new_file(event.src_path) + + def on_modified(self, event): + if event.is_directory: + return + self.handle_new_file(event.src_path) + + def on_deleted(self, event): + if event.is_directory: + return + + def on_moved(self, event: FileSystemEvent) -> None: + if event.is_directory: + return + self.handle_new_file(event.dest_path) + + def handle_new_file(self, filename: str): + filepath = Path(filename) + if not filepath.suffix == ".py": + return + + contents = filepath.read_text() + hash_code = hashlib.md5(contents.encode()).hexdigest() + og_hash_code = self.ha.get(filename, None) + if hash_code == og_hash_code: + return + + self.ha[filename] = hash_code + try: + module_spec = importlib.util.spec_from_file_location( + "transient_module", filepath + ) + module = importlib.util.module_from_spec(module_spec) + module_spec.loader.exec_module(module) + except Exception: + print(f"Error loading {filename}:") + traceback.print_exc() + return + + self.handler(filepath, module) + + def init_custom_nodes(self): + for root, dirs, files in os.walk(self.root_path): + for file in files: + self.handle_new_file(Path(root).joinpath(file)) + + +class CustomNodeImporter: + def __init__( + self, + path: Path, + step_handler: callable, + resource_handler: callable, + module_handler: callable, + ): + self.websocket = None + self.path = path + self.step_handler = step_handler + self.resource_handler = resource_handler + self.module_handler = module_handler + self.observer = Observer() + self.event_handler = CustomModuleEventHandler(path, self.on_module) + self.event_handler.init_custom_nodes() + + def on_module(self, filename: Path, mod): + for name, obj in inspect.getmembers(mod): + if inspect.isclass(obj): + if issubclass(obj, steps.Step) and not obj in BUILT_IN_STEPS: + self.step_handler(filename, name, obj) + if ( + issubclass(obj, resources.Resource) + and not obj in BUILT_IN_RESOURCES + ): + self.resource_handler(filename, name, obj) + + for name, cls in get_steps().items(): + self.step_handler(filename, name, cls) + for name, cls in get_resources().items(): + self.resource_handler(filename, name, cls) + + self.module_handler(filename, mod) + + def start_observer(self): + self.observer.schedule(self.event_handler, self.path, recursive=True) + self.observer.start() + + def stop_observer(self): + self.observer.stop() + self.observer.join() diff --git a/graphbook/processing/web_processor.py b/graphbook/processing/web_processor.py index 49d535a..08776b9 100644 --- a/graphbook/processing/web_processor.py +++ b/graphbook/processing/web_processor.py @@ -14,9 +14,11 @@ from ..shm import SharedMemoryManager from ..note import Note from typing import List +from pathlib import Path import queue import multiprocessing as mp import multiprocessing.connection as mpc +import threading as th import traceback import asyncio import time @@ -31,39 +33,37 @@ class WebInstanceProcessor: def __init__( self, - cmd_queue: mp.Queue, - server_request_conn: mpc.Connection, view_manager_queue: mp.Queue, img_mem: SharedMemoryManager | None, continue_on_failure: bool, copy_outputs: bool, - custom_nodes_path: str, - close_event: mp.Event, - pause_event: mp.Event, - spawn_method: bool, + custom_nodes_path: Path, + spawn: bool, num_workers: int = 1, ): - self.cmd_queue = cmd_queue - self.close_event = close_event - self.pause_event = pause_event + self.cmd_queue = mp.Queue() self.view_manager = ViewManagerInterface(view_manager_queue) self.img_mem = img_mem self.graph_state = GraphState(custom_nodes_path, view_manager_queue) self.continue_on_failure = continue_on_failure self.copy_outputs = copy_outputs - self.custom_nodes_path = custom_nodes_path self.num_workers = num_workers self.steps = {} - self.dataloader = Dataloader(self.num_workers, spawn_method) + self.dataloader = Dataloader(self.num_workers, spawn) setup_global_dl(self.dataloader) + self.server_request_conn, client_request_conn = mpc.Pipe() + self.close_event = mp.Event() + self.pause_event = mp.Event() self.state_client = ProcessorStateClient( - server_request_conn, - close_event, + client_request_conn, + self.close_event, + self.pause_event, self.graph_state, self.dataloader, ) self.is_running = False self.filename = None + self.thread = th.Thread(target=self.start_loop, daemon=True) def handle_images(self, outputs: StepOutput): if self.img_mem is None: @@ -252,7 +252,7 @@ def set_is_running(self, is_running: bool = True, filename: str | None = None): if filename is not None: self.filename = filename run_state = {"is_running": is_running, "filename": self.filename} - self.view_manager.handle_run_state(run_state) + self.view_manager.set_state("run_state", run_state) self.state_client.set_running_state(run_state) def cleanup(self): @@ -282,7 +282,7 @@ def try_update_state(self, local_graph: dict) -> bool: traceback.print_exc() return False - def exec(self, work: dict): + def _exec(self, work: dict): self.set_is_running(True, work["filename"]) if not self.try_update_state(work): return @@ -294,8 +294,8 @@ def exec(self, work: dict): elif work["cmd"] == "step": self.step(work["step_id"]) - async def start_loop(self): - loop = asyncio.get_running_loop() + def start_loop(self): + loop = asyncio.new_event_loop() loop.run_in_executor(None, self.state_client.start) exec_cmds = ["run_all", "run", "step"] while not self.close_event.is_set(): @@ -303,7 +303,7 @@ async def start_loop(self): try: work = self.cmd_queue.get(timeout=MP_WORKER_TIMEOUT) if work["cmd"] in exec_cmds: - self.exec(work) + self._exec(work) elif work["cmd"] == "clear": self.graph_state.clear_outputs(work.get("node_id")) self.view_manager.handle_clear(work.get("node_id")) @@ -314,6 +314,30 @@ async def start_loop(self): break except queue.Empty: pass + + def start(self): + self.thread.start() + + def close(self): + self.close_event.set() + self.state_client.close() + self.cleanup() + + def exec(self, work: dict): + self.cmd_queue.put(work) + + def poll_client( + self, req: ProcessorStateRequest, body: dict = None + ) -> dict: + req_data = {"cmd": req} + if body: + req_data.update(body) + self.server_request_conn.send(req_data) + if self.server_request_conn.poll(timeout=MP_WORKER_TIMEOUT): + res = self.server_request_conn.recv() + if res.get("res") == req: + return res.get("data") + return {} class ProcessorStateClient: @@ -321,11 +345,13 @@ def __init__( self, server_request_conn: mpc.Connection, close_event: mp.Event, + pause_event: mp.Event, graph_state: GraphState, dataloader: Dataloader, ): self.server_request_conn = server_request_conn self.close_event = close_event + self.pause_event = pause_event self.curr_task = None self.graph_state = graph_state self.dataloader = dataloader @@ -356,6 +382,9 @@ def _loop(self): step_id, req.get("response") ) output = {"ok": succeeded} + elif req["cmd"] == ProcessorStateRequest.PAUSE: + self.pause_event.set() + output = {} else: output = {} entry = {"res": req["cmd"], "data": output} @@ -370,17 +399,3 @@ def close(self): def set_running_state(self, state: dict): self.running_state = state - - -def poll_conn_for( - conn: mpc.Connection, req: ProcessorStateRequest, body: dict = None -) -> dict: - req_data = {"cmd": req} - if body: - req_data.update(body) - conn.send(req_data) - if conn.poll(timeout=MP_WORKER_TIMEOUT): - res = conn.recv() - if res.get("res") == req: - return res.get("data") - return {} diff --git a/graphbook/sample_assets/SampleWorkflow.json b/graphbook/sample_assets/SampleWorkflow.json index 268359a..cbe002e 100644 --- a/graphbook/sample_assets/SampleWorkflow.json +++ b/graphbook/sample_assets/SampleWorkflow.json @@ -6,8 +6,8 @@ "id": "1", "type": "step", "position": { - "x": 583.4070722675033, - "y": 307.30402857209594 + "x": 430.5, + "y": 122 }, "data": { "name": "CalcMean", @@ -20,24 +20,15 @@ ], "category": "", "label": "CalcMean", - "key": 31, "isCollapsed": false - }, - "width": 150, - "height": 107, - "positionAbsolute": { - "x": 583.4070722675033, - "y": 307.30402857209594 - }, - "selected": false, - "dragging": true + } }, { "id": "2", "type": "step", "position": { - "x": 585.5262378436905, - "y": 430.5629360957575 + "x": 430.5, + "y": 227 }, "data": { "name": "CalcRunningMean", @@ -50,24 +41,15 @@ ], "category": "", "label": "CalcRunningMean", - "key": 32, "isCollapsed": false - }, - "width": 150, - "height": 107, - "selected": false, - "positionAbsolute": { - "x": 585.5262378436905, - "y": 430.5629360957575 - }, - "dragging": true + } }, { "id": "3", "type": "step", "position": { - "x": 924.5511397425128, - "y": 306.8951569327104 + "x": 634.5, + "y": 122 }, "data": { "name": "Split", @@ -85,24 +67,15 @@ ], "category": "Filtering", "label": "Split", - "key": 2, "isCollapsed": false - }, - "width": 150, - "height": 120, - "selected": false, - "positionAbsolute": { - "x": 924.5511397425128, - "y": 306.8951569327104 - }, - "dragging": true + } }, { "id": "4", "type": "resource", "position": { - "x": 753.7478523416509, - "y": 360.9425229173932 + "x": 15.5, + "y": 122 }, "data": { "name": "Function", @@ -114,24 +87,15 @@ }, "category": "Util", "label": "Function", - "key": 37, "isCollapsed": true - }, - "width": 150, - "height": 24, - "selected": false, - "positionAbsolute": { - "x": 753.7478523416509, - "y": 360.9425229173932 - }, - "dragging": false + } }, { "id": "7", "type": "step", "position": { - "x": 392.76540599258135, - "y": 307.35784194806917 + "x": 233.5, + "y": 105 }, "data": { "name": "Transform", @@ -159,24 +123,15 @@ ], "category": "", "label": "Transform", - "key": 22, "isCollapsed": false - }, - "width": 154, - "height": 144, - "selected": false, - "positionAbsolute": { - "x": 392.76540599258135, - "y": 307.35784194806917 - }, - "dragging": true + } }, { "id": "8", "type": "step", "position": { - "x": 185, - "y": 307.5 + "x": 15.5, + "y": 196 }, "data": { "name": "GenerateTensors", @@ -195,17 +150,8 @@ ], "category": "", "label": "GenerateTensors", - "key": 20, "isCollapsed": false - }, - "width": 175, - "height": 127, - "selected": false, - "positionAbsolute": { - "x": 185, - "y": 307.5 - }, - "dragging": true + } } ], "edges": [ diff --git a/graphbook/state.py b/graphbook/state.py index 84f3e3e..504b4cd 100644 --- a/graphbook/state.py +++ b/graphbook/state.py @@ -1,5 +1,4 @@ from __future__ import annotations -from aiohttp.web import WebSocketResponse from typing import Dict, Tuple, List, Iterator, Set from .note import Note from .steps import Step, PromptStep, StepOutput as Outputs @@ -9,14 +8,13 @@ from .plugins import setup_plugins from .logger import setup_logging_nodes from .utils import transform_json_log -from . import exports +from . import nodes import multiprocessing as mp import importlib, importlib.util, inspect import sys, os -import os.path as osp -import json import hashlib from enum import Enum +from pathlib import Path class NodeInstantiationError(Exception): @@ -27,39 +25,12 @@ def __init__(self, message: str, node_id: str, node_name: str): self.node_name = node_name -class UIState: - def __init__(self, root_path: str, websocket: WebSocketResponse): - self.root_path = root_path - self.ws = websocket - self.nodes = {} - self.edges = {} - - def cmd(self, req: dict): - if req["cmd"] == "put_graph": - filename = req["filename"] - nodes = req["nodes"] - edges = req["edges"] - self.put_graph(filename, nodes, edges) - - def put_graph(self, filename: str, nodes: list, edges: list): - full_path = osp.join(self.root_path, filename) - with open(full_path, "w") as f: - serialized = { - "version": "0", - "type": "workflow", - "nodes": nodes, - "edges": edges, - } - json.dump(serialized, f) - - class NodeCatalog: - def __init__(self, custom_nodes_path: str): - sys.path.append(custom_nodes_path) + def __init__(self, custom_nodes_path: Path): self.custom_nodes_path = custom_nodes_path self.nodes = {"steps": {}, "resources": {}} - self.nodes["steps"] |= exports.default_exported_steps - self.nodes["resources"] |= exports.default_exported_resources + self.nodes["steps"] |= nodes.default_exported_steps + self.nodes["resources"] |= nodes.default_exported_resources self.plugins = setup_plugins() steps, resources, _ = self.plugins for plugin in steps: @@ -100,7 +71,7 @@ def _update_custom_nodes(self) -> dict: "steps": {k: False for k in self.nodes["steps"]}, "resources": {k: False for k in self.nodes["resources"]}, } - for root, dirs, files in os.walk(self.custom_nodes_path): + for root, dirs, files in os.walk(str(self.custom_nodes_path)): for file in files: if not file.endswith(".py"): continue @@ -111,11 +82,8 @@ def _update_custom_nodes(self) -> dict: while module_name.startswith("."): module_name = module_name[1:] # import - if module_name in sys.modules: - mod = sys.modules[module_name] - mod = importlib.reload(mod) - else: - mod = importlib.import_module(module_name) + mod = self._get_module(os.path.join(root, file)) + # get node classes for name, obj in inspect.getmembers(mod): if inspect.isclass(obj): @@ -144,9 +112,7 @@ def get_nodes(self) -> Tuple[dict, dict]: class GraphState: - def __init__(self, custom_nodes_path: str, view_manager_queue: mp.Queue): - sys.path.append(custom_nodes_path) - self.custom_nodes_path = custom_nodes_path + def __init__(self, custom_nodes_path: Path, view_manager_queue: mp.Queue): self.view_manager_queue = view_manager_queue self.view_manager = ViewManagerInterface(view_manager_queue) self._dict_graph = {} @@ -204,6 +170,12 @@ def set_resource_value(resource_id, resource_data): del curr_resource, self._dict_resources[resource_id] try: resource = resource_hub[resource_name](**p) + except KeyError: + raise NodeInstantiationError( + f"No resource node with name {resource_name} found", + resource_id, + resource_name, + ) except Exception as e: raise NodeInstantiationError(str(e), resource_id, resource_name) resource_values[resource_id] = resource.value() @@ -249,6 +221,10 @@ def set_resource_value(resource_id, resource_data): try: step = step_hub[step_name](**step_input) step.id = step_id + except KeyError: + raise NodeInstantiationError( + f"No step node with name {step_name} found", step_id, step_name + ) except Exception as e: raise NodeInstantiationError(str(e), step_id, step_name) steps[step_id] = step diff --git a/graphbook/utils.py b/graphbook/utils.py index 7b97c07..0bdff41 100644 --- a/graphbook/utils.py +++ b/graphbook/utils.py @@ -21,6 +21,7 @@ "GET_WORKER_QUEUE_SIZES", "GET_RUNNING_STATE", "PROMPT_RESPONSE", + "PAUSE", ], ) @@ -29,7 +30,7 @@ def is_batchable(obj: Any) -> bool: return isinstance(obj, list) or isinstance(obj, Tensor) -def transform_function_string(func_str): +def transform_function_string(func_str: str): """ This function is used to convert a string to a function by interpreting the string as a python-typed function diff --git a/graphbook/viewer.py b/graphbook/viewer.py index 3a35c7f..38ec00f 100644 --- a/graphbook/viewer.py +++ b/graphbook/viewer.py @@ -1,14 +1,13 @@ -from typing import Dict, List -from aiohttp.web import WebSocketResponse -import uuid +from typing import Dict, List, Any import asyncio import time import multiprocessing as mp -import multiprocessing.connection as mpc import queue import copy import psutil -from .utils import MP_WORKER_TIMEOUT, get_gpu_util, ProcessorStateRequest, poll_conn_for +from .utils import MP_WORKER_TIMEOUT, get_gpu_util +# from .processing.web_processor import WebInstanceProcessor +from .utils import ProcessorStateRequest class Viewer: @@ -40,9 +39,8 @@ class DataViewer(Viewer): so that the data can be displayed in the web interface. """ - def __init__(self, deque_max_size=5): + def __init__(self): super().__init__("view") - self.deque_max_size = deque_max_size self.last_outputs: Dict[str, dict] = {} self.filename = None @@ -151,9 +149,9 @@ class SystemUtilViewer(Viewer): Tracks system utilization: CPU util, CPU memory, GPU util, GPU memory """ - def __init__(self, processor_state_conn: mpc.Connection): + def __init__(self, processor): super().__init__("system_util") - self.processor_state_conn = processor_state_conn + self.processor = processor def get_cpu_usage(self): return psutil.cpu_percent() @@ -166,8 +164,8 @@ def get_gpu_usage(self): return gpus def get_next(self): - sizes = poll_conn_for( - self.processor_state_conn, ProcessorStateRequest.GET_WORKER_QUEUE_SIZES + sizes = self.processor.poll_client( + ProcessorStateRequest.GET_WORKER_QUEUE_SIZES ) return { "cpu": self.get_cpu_usage(), @@ -194,53 +192,17 @@ def get_next(self): return self.prompts -DEFAULT_CLIENT_OPTIONS = {"SEND_EVERY": 0.5} - - -class Client: - def __init__( - self, - ws: WebSocketResponse, - producers: List[DataViewer], - options: dict = DEFAULT_CLIENT_OPTIONS, - ): - self.ws = ws - self.producers = producers - self.options = options - self.curr_task = None - - async def _loop(self): - while True: - await asyncio.sleep(self.options["SEND_EVERY"]) - sends = [] - for producer in self.producers: - next_entry = producer.get_next() - if next_entry is not None: - entry = {"type": producer.get_event_name(), "data": next_entry} - sends.append(self.ws.send_json(entry)) - await asyncio.gather(*sends) - - def start(self): - loop = asyncio.get_event_loop() - self.curr_task = loop.create_task(self._loop()) - - async def close(self): - if self.curr_task is not None: - self.curr_task.cancel() - await self.ws.close() - - class ViewManager: def __init__( self, work_queue: mp.Queue, close_event: mp.Event, - processor_state_conn: mpc.Connection, + processor, ): self.data_viewer = DataViewer() self.node_stats_viewer = NodeStatsViewer() self.logs_viewer = NodeLogsViewer() - self.system_util_viewer = SystemUtilViewer(processor_state_conn) + self.system_util_viewer = SystemUtilViewer(processor) self.prompt_viewer = PromptViewer() self.viewers: List[Viewer] = [ self.data_viewer, @@ -249,27 +211,12 @@ def __init__( self.system_util_viewer, self.prompt_viewer, ] - self.clients: Dict[str, Client] = {} + self.states: Dict[str, Any] = {} self.work_queue = work_queue self.close_event = close_event - self.curr_task = None - - def add_client(self, ws: WebSocketResponse) -> str: - sid = uuid.uuid4().hex - client = Client(ws, self.viewers) - self.clients[sid] = client - client.start() - return sid - - async def remove_client(self, sid: str): - if sid in self.clients: - await self.clients[sid].close() - del self.clients[sid] - - def close_all(self): - for sid in self.clients: - self.clients[sid].close() - self.clients = {} + + def get_viewers(self): + return self.viewers def handle_outputs(self, node_id: str, outputs: dict): if len(outputs) == 0: @@ -300,22 +247,32 @@ def handle_prompt(self, node_id: str, prompt: dict): def handle_end(self): for viewer in self.viewers: viewer.handle_end() - - def send_to_clients(self, type: str, data: dict): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - sends = [] - for client in self.clients.values(): - entry = {"type": type, "data": data} - sends.append(client.ws.send_json(entry)) - loop.run_until_complete(asyncio.gather(*sends)) - loop.close() - - def handle_run_state(self, is_running: bool, filename: str): - self.data_viewer.set_filename(filename) - self.send_to_clients( - "run_state", {"is_running": is_running, "filename": filename} - ) + + def set_state(self, type: str, data: Any = None): + """ + Set state data for a specific type + """ + self.states[type] = data + + def get_current_states(self): + """ + Retrieve all current state data + """ + states = [{"type": key, "data": self.states[key]} for key in self.states] + self.states.clear() + return states + + def get_current_view_data(self): + """ + Get the current data from all viewer classes + """ + view_data = [] + for viewer in self.viewers: + next_entry = viewer.get_next() + if next_entry is not None: + entry = {"type": viewer.get_event_name(), "data": next_entry} + view_data.append(entry) + return view_data def _loop(self): while not self.close_event.is_set(): @@ -333,24 +290,19 @@ def _loop(self): self.handle_end() elif work["cmd"] == "handle_log": self.handle_log(work["node_id"], work["log"], work["type"]) - elif work["cmd"] == "handle_run_state": - self.handle_run_state(work["is_running"], work["filename"]) elif work["cmd"] == "handle_clear": self.handle_clear(work["node_id"]) elif work["cmd"] == "handle_prompt": self.handle_prompt(work["node_id"], work["prompt"]) + elif work["cmd"] == "set_state": + self.set_state(work["type"], work["data"]) except queue.Empty: pass def start(self): - self._loop() - - async def close(self): - if self.curr_task is not None: - self.curr_task.cancel() - await asyncio.gather(*[client.close() for client in self.clients.values()]) - + loop = asyncio.new_event_loop() + loop.run_in_executor(None, self._loop) class ViewManagerInterface: def __init__(self, view_manager_queue: mp.Queue): @@ -384,9 +336,6 @@ def handle_start(self, node_id: str): def handle_end(self): self.view_manager_queue.put({"cmd": "handle_end"}) - def handle_run_state(self, run_state: dict): - self.view_manager_queue.put({"cmd": "handle_run_state"} | run_state) - def handle_clear(self, node_id: str | None): self.view_manager_queue.put({"cmd": "handle_clear", "node_id": node_id}) @@ -394,3 +343,8 @@ def handle_prompt(self, node_id: str, prompt: dict): self.view_manager_queue.put( {"cmd": "handle_prompt", "node_id": node_id, "prompt": prompt} ) + + def set_state(self, type: str, data: Any): + self.view_manager_queue.put( + {"cmd": "set_state", "type": type, "data": data} + ) diff --git a/graphbook/web.py b/graphbook/web.py index 19f5052..158ae7b 100644 --- a/graphbook/web.py +++ b/graphbook/web.py @@ -1,21 +1,19 @@ import os -import os.path as osp import re import signal import multiprocessing as mp -import multiprocessing.connection as mpc import asyncio import base64 import hashlib import aiohttp from aiohttp import web -from .processing.web_processor import WebInstanceProcessor -from .viewer import ViewManager -from .exports import NodeHub -from .state import UIState +from pathlib import Path from .media import create_media_server -from .utils import poll_conn_for, ProcessorStateRequest +from .utils import ProcessorStateRequest from .shm import SharedMemoryManager +from .clients import ClientPool, Client +from .plugins import setup_plugins +import json @web.middleware @@ -30,7 +28,9 @@ async def cors_middleware(request: web.Request, handler): response.headers["Access-Control-Allow-Origin"] = "*" response.headers["Access-Control-Allow-Methods"] = "POST, GET, DELETE, PUT, OPTIONS" - response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization" + response.headers["Access-Control-Allow-Headers"] = ( + "Content-Type, Authorization, sid" + ) response.headers["Access-Control-Allow-Credentials"] = "true" return response @@ -38,15 +38,12 @@ async def cors_middleware(request: web.Request, handler): class GraphServer: def __init__( self, - processor_queue: mp.Queue, - state_conn: mpc.Connection, - processor_pause_event: mp.Event, - view_manager_queue: mp.Queue, + web_processor_args: dict, img_mem_args: dict, + setup_paths: dict, + isolate_users: bool, + no_sample: bool, close_event: mp.Event, - root_path: str, - custom_nodes_path: str, - docs_path: str, web_dir: str | None = None, host="0.0.0.0", port=8005, @@ -56,51 +53,76 @@ def __init__( self.close_event = close_event self.web_dir = web_dir if self.web_dir is None: - self.web_dir = osp.join(osp.dirname(__file__), "web") - self.node_hub = NodeHub(custom_nodes_path) - self.ui_state = None + self.web_dir = Path(__file__).parent.joinpath("web") routes = web.RouteTableDef() self.routes = routes - self.view_manager = ViewManager(view_manager_queue, close_event, state_conn) self.img_mem = SharedMemoryManager(**img_mem_args) if img_mem_args else None - abs_root_path = osp.abspath(root_path) middlewares = [cors_middleware] max_upload_size = 100 # MB max_upload_size = round(max_upload_size * 1024 * 1024) self.app = web.Application( client_max_size=max_upload_size, middlewares=middlewares ) + self.plugins = setup_plugins() + self.plugin_steps, self.plugin_resources, self.web_plugins = self.plugins + node_plugins = (self.plugin_steps, self.plugin_resources) + self.client_pool = ClientPool( + web_processor_args, + setup_paths, + node_plugins, + isolate_users, + no_sample, + close_event, + ) - if not osp.isdir(self.web_dir): + if not self.web_dir.is_dir(): print( f"Couldn't find web files inside {self.web_dir}. Will not serve web files." ) self.web_dir = None @routes.get("/ws") - async def websocket_handler(request): + async def websocket_handler(request: web.Request, *_) -> web.WebSocketResponse: if self.close_event.is_set(): + print("Server is shutting down. Rejecting new client.") raise web.HTTPServiceUnavailable() - ws = web.WebSocketResponse() - await ws.prepare(request) - self.ui_state = UIState(root_path, ws) - self.node_hub.set_websocket(ws) # Set the WebSocket in NodeHub + ws = web.WebSocketResponse() + try: + await ws.prepare(request) + except Exception as e: + print(f"Error preparing websocket: {e}") + return ws + client = self.client_pool.add_client(ws) + + def put_graph(req: dict): + filename = req["filename"] + nodes = req["nodes"] + edges = req["edges"] + full_path = client.get_root_path().joinpath(filename) + print(f"Saving graph to {full_path}") + with open(full_path, "w") as f: + serialized = { + "version": "0", + "type": "workflow", + "nodes": nodes, + "edges": edges, + } + json.dump(serialized, f) - sid = self.view_manager.add_client(ws) try: async for msg in ws: if msg.type == aiohttp.WSMsgType.TEXT: if msg.data == "close": - await ws.close() + await self.client_pool.remove_client(client) else: req = msg.json() # Unhandled - if req["api"] == "graph": - self.ui_state.cmd(req) + if req["api"] == "graph" and req["cmd"] == "put_graph": + put_graph(req) elif msg.type == aiohttp.WSMsgType.ERROR: print("ws connection closed with exception %s" % ws.exception()) finally: - await self.view_manager.remove_client(sid) + await self.client_pool.remove_client(client) return ws @@ -108,15 +130,15 @@ async def websocket_handler(request): async def get(request: web.Request) -> web.Response: if self.web_dir is None: raise web.HTTPNotFound(body="No web files found.") - return web.FileResponse(osp.join(self.web_dir, "index.html")) + return web.FileResponse(self.web_dir.joinpath("index.html")) @routes.get("/media") async def get_media(request: web.Request) -> web.Response: - path = request.query.get("path", None) + path = Path(request.query.get("path", None)) shm_id = request.query.get("shm_id", None) if path is not None: - if not osp.exists(path): + if not path.exists(): raise web.HTTPNotFound() return web.FileResponse(path) @@ -130,100 +152,108 @@ async def get_media(request: web.Request) -> web.Response: @routes.post("/run") async def run_all(request: web.Request) -> web.Response: + client = get_client(request) data = await request.json() graph = data.get("graph", {}) resources = data.get("resources", {}) filename = data.get("filename", "") - processor_queue.put( + client.exec( { "cmd": "run_all", "graph": graph, "resources": resources, "filename": filename, - } + }, ) return web.json_response({"success": True}) @routes.post("/run/{id}") async def run(request: web.Request) -> web.Response: + client = get_client(request) step_id = request.match_info.get("id") data = await request.json() graph = data.get("graph", {}) resources = data.get("resources", {}) filename = data.get("filename", "") - processor_queue.put( + client.exec( { "cmd": "run", "graph": graph, "resources": resources, "step_id": step_id, "filename": filename, - } + }, ) return web.json_response({"success": True}) @routes.post("/step/{id}") async def step(request: web.Request) -> web.Response: + client = get_client(request) step_id = request.match_info.get("id") data = await request.json() graph = data.get("graph", {}) resources = data.get("resources", {}) filename = data.get("filename", "") - processor_queue.put( + client.exec( { "cmd": "step", "graph": graph, "resources": resources, "step_id": step_id, "filename": filename, - } + }, ) return web.json_response({"success": True}) @routes.post("/pause") async def pause(request: web.Request) -> web.Response: - processor_pause_event.set() + client = get_client(request) + client.poll(ProcessorStateRequest.PAUSE) return web.json_response({"success": True}) @routes.post("/clear") @routes.post("/clear/{id}") async def clear(request: web.Request) -> web.Response: + client = get_client(request) node_id = request.match_info.get("id") - processor_queue.put( - { - "cmd": "clear", - "node_id": node_id, - } - ) + client.exec({"cmd": "clear", "node_id": node_id}) return web.json_response({"success": True}) @routes.post("/prompt_response/{id}") async def prompt_response(request: web.Request) -> web.Response: + client = get_client(request) step_id = request.match_info.get("id") data = await request.json() response = data.get("response") - res = poll_conn_for( - state_conn, + res = client.poll( ProcessorStateRequest.PROMPT_RESPONSE, - {"step_id": step_id, "response": response}, + { + "step_id": step_id, + "response": response, + }, ) return web.json_response(res) @routes.get("/nodes") async def get_nodes(request: web.Request) -> web.Response: - return web.json_response(self.node_hub.get_exported_nodes()) + client = get_client(request) + nodes = client.nodes() + return web.json_response(nodes) @routes.get("/state/{step_id}/{pin_id}/{index}") async def get_output_note(request: web.Request) -> web.Response: + client = get_client(request) step_id = request.match_info.get("step_id") pin_id = request.match_info.get("pin_id") index = int(request.match_info.get("index")) - res = poll_conn_for( - state_conn, + res = client.poll( ProcessorStateRequest.GET_OUTPUT_NOTE, - {"step_id": step_id, "pin_id": pin_id, "index": index}, + { + "step_id": step_id, + "pin_id": pin_id, + "index": index, + }, ) - if ( res and res.get("step_id") == step_id @@ -236,15 +266,32 @@ async def get_output_note(request: web.Request) -> web.Response: @routes.get("/state") async def get_run_state(request: web.Request) -> web.Response: - res = poll_conn_for(state_conn, ProcessorStateRequest.GET_RUNNING_STATE) + client = get_client(request) + res = client.poll(ProcessorStateRequest.GET_RUNNING_STATE) return web.json_response(res) + @routes.get("/step_docstring/{name}") + async def get_step_docstring(request: web.Request): + client = get_client(request) + name = request.match_info.get("name") + docstring = client.step_doc(name) + return web.json_response({"content": docstring}) + + @routes.get("/resource_docstring/{name}") + async def get_resource_docstring(request: web.Request): + client = get_client(request) + name = request.match_info.get("name") + docstring = client.resource_doc(name) + return web.json_response({"content": docstring}) + @routes.get(r"/docs/{path:.+}") - async def get_docs(request: web.Request): - path = request.match_info.get("path") - fullpath = osp.join(docs_path, path) + async def get_docs(request: web.Request) -> web.Response: + client = get_client(request) + path = request.match_info.get("path", "") + docs_path = client.get_docs_path() + fullpath = docs_path.joinpath(path) - if osp.exists(fullpath): + if fullpath.exists(): with open(fullpath, "r") as f: file_contents = f.read() d = {"content": file_contents} @@ -254,65 +301,54 @@ async def get_docs(request: web.Request): {"reason": "/%s: No such file or directory." % fullpath}, status=404 ) - @routes.get("/step_docstring/{name}") - async def get_step_docstring(request: web.Request): - name = request.match_info.get("name") - docstring = self.node_hub.get_step_docstring(name) - return web.json_response({"content": docstring}) - - @routes.get("/resource_docstring/{name}") - async def get_resource_docstring(request: web.Request): - name = request.match_info.get("name") - docstring = self.node_hub.get_resource_docstring(name) - return web.json_response({"content": docstring}) - @routes.get("/fs") @routes.get(r"/fs/{path:.+}") - async def get(request: web.Request): + async def get(request: web.Request) -> web.Response: + client = get_client(request) path = request.match_info.get("path", "") - fullpath = osp.join(abs_root_path, path) - assert fullpath.startswith( - abs_root_path - ), f"{fullpath} must be within {abs_root_path}" - - def handle_fs_tree(p: str, fn: callable) -> dict: - if osp.isdir(p): + client_path = client.get_root_path() + fullpath = client_path.joinpath(path) + assert str(fullpath).startswith( + str(client_path) + ), f"{fullpath} must be within {client_path}" + + def handle_fs_tree(p: Path, fn: callable) -> dict: + if Path.is_dir(p): p_data = fn(p) p_data["children"] = [ - handle_fs_tree(osp.join(p, f), fn) for f in os.listdir(p) + handle_fs_tree(f, fn) for f in Path.iterdir(p) ] return p_data else: return fn(p) - def get_stat(path): - stat = os.stat(path) - rel_path = osp.relpath(path, abs_root_path) + def get_stat(path: Path) -> dict: + stat = path.stat() + rel_path = path.relative_to(fullpath) st = { - "title": osp.basename(rel_path), - "path": rel_path, - "path_from_cwd": osp.join(root_path, rel_path), - "dirname": osp.dirname(rel_path), - "from_root": osp.basename(abs_root_path), + "title": path.name, + "path": str(rel_path), + "path_from_cwd": str(fullpath.joinpath(rel_path)), + "dirname": str(Path(rel_path).parent), "access_time": int(stat.st_atime), "modification_time": int(stat.st_mtime), "change_time": int(stat.st_ctime), } - if not osp.isdir(path): + if not path.is_dir(): st["size"] = int(stat.st_size) return st - if osp.exists(fullpath): + if fullpath.exists(): if request.query.get("stat", False): stats = handle_fs_tree(fullpath, get_stat) res = web.json_response(stats) res.headers["Content-Type"] = "application/json; charset=utf-8" return res - if osp.isdir(fullpath): - res = web.json_response(os.listdir(fullpath)) + if fullpath.is_dir(): + res = web.json_response(list(fullpath.iterdir())) res.headers["Content-Type"] = "application/json; charset=utf-8" return res else: @@ -334,12 +370,14 @@ def get_stat(path): @routes.put("/fs") @routes.put(r"/fs/{path:.+}") - async def put(request: web.Request): - path = request.match_info.get("path") - fullpath = osp.join(root_path, path) + async def put(request: web.Request) -> web.Response: + client = get_client(request) + path = request.match_info.get("path", "") + client_path = client.get_root_path() + fullpath = client_path.joinpath(path) data = await request.json() if request.query.get("mv"): - topath = osp.join(root_path, request.query.get("mv")) + topath = client_path.joinpath(request.query.get("mv")) os.rename(fullpath, topath) return web.json_response({}, status=200) @@ -355,7 +393,7 @@ async def put(request: web.Request): if encoding == "base64": file_contents = base64.b64decode(file_contents) - if osp.exists(fullpath): + if fullpath.exists(): with open(fullpath, "r") as f: current_hash = hashlib.md5(f.read().encode()).hexdigest() if current_hash != hash_key: @@ -369,96 +407,99 @@ async def put(request: web.Request): return web.json_response({}, status=201) @routes.delete("/fs/{path:.+}") - async def delete(request): + async def delete(request: web.Request) -> web.Response: + client = get_client(request) path = request.match_info.get("path") - fullpath = osp.join(root_path, path) - assert fullpath.startswith( - root_path - ), f"{fullpath} must be within {root_path}" - - if osp.exists(fullpath): - if osp.isdir(fullpath): - if os.listdir(fullpath) == []: - os.rmdir(fullpath) + client_path = client.get_root_path() + fullpath = client_path.joinpath(path) + assert str(fullpath).startswith( + client_path + ), f"{fullpath} must be within {client_path}" + + if fullpath.exists(): + if fullpath.is_dir(): + try: + fullpath.rmdir(fullpath) return web.json_response({"success": True}, status=204) - else: + except Exception as e: return web.json_response( - {"reason": "/%s: Directory is not empty." % path}, - status=403, + {"reason": f"Error deleting directory {fullpath}: {e}"}, + status=400, ) else: - os.remove(fullpath) + fullpath.unlink(True) return web.json_response({"success": True}, status=204) else: return web.json_response( - {"reason": "/%s: No such file or directory." % path}, status=404 + {"reason": f"No such file or directory {fullpath}."}, status=404 ) @routes.get("/plugins") - async def get_plugins(request): - plugin_list = list(self.node_hub.get_web_plugins().keys()) + async def get_plugins(request: web.Request) -> web.Response: + plugin_list = list(self.web_plugins.keys()) return web.json_response(plugin_list) @routes.get("/plugins/{name}") - async def get_plugin(request): + async def get_plugin(request: web.Request) -> web.Response: plugin_name = request.match_info.get("name") - plugin_location = self.node_hub.get_web_plugins().get(plugin_name) + plugin_location = self.web_plugins.get(plugin_name) if plugin_location is None: raise web.HTTPNotFound(body=f"Plugin {plugin_name} not found.") return web.FileResponse(plugin_location) + def get_client(request: web.Request) -> Client: + sid = request.headers.get("sid") + if sid is None: + raise web.HTTPUnauthorized() + client = self.client_pool.get(sid) + if client is None: + raise web.HTTPUnauthorized() + return client + async def _async_start(self): runner = web.AppRunner(self.app) await runner.setup() site = web.TCPSite(runner, self.host, self.port) await site.start() - loop = asyncio.get_running_loop() - loop.run_in_executor(None, self.view_manager.start) await asyncio.Event().wait() + async def on_shutdown(self): + self.client_pool.remove_all() + print("Shutting down graph server") + def start(self): self.app.router.add_routes(self.routes) - - web_plugins = self.node_hub.get_web_plugins() - if web_plugins: + if self.web_plugins: print("Loaded web plugins:") - print(web_plugins) + print(self.web_plugins) if self.web_dir is not None: self.app.router.add_routes([web.static("/", self.web_dir)]) + self.app.on_shutdown.append(self.on_shutdown) + print(f"Starting graph server at {self.host}:{self.port}") - self.node_hub.start() try: asyncio.run(self._async_start()) except KeyboardInterrupt: - self.node_hub.stop() print("Exiting graph server") def create_graph_server( args, - cmd_queue, - state_conn, - processor_pause_event, - view_manager_queue, - close_event, + web_processor_args, img_mem_args, - root_path, - custom_nodes_path, - docs_path, + setup_paths, + close_event, web_dir, ): server = GraphServer( - cmd_queue, - state_conn, - processor_pause_event, - view_manager_queue, + web_processor_args, img_mem_args, + setup_paths, + args.isolate_users, + args.no_sample, close_event, - root_path=root_path, - custom_nodes_path=custom_nodes_path, - docs_path=docs_path, web_dir=web_dir, host=args.host, port=args.port, @@ -466,68 +507,31 @@ def create_graph_server( server.start() -def create_sample_workflow(workflow_dir, custom_nodes_path, docs_path): - import shutil - - assets_dir = osp.join(osp.dirname(__file__), "sample_assets") - n = "SampleWorkflow.json" - shutil.copyfile(osp.join(assets_dir, n), osp.join(workflow_dir, n)) - n = "SampleWorkflow.md" - shutil.copyfile(osp.join(assets_dir, n), osp.join(docs_path, n)) - n = "sample_nodes.py" - shutil.copyfile(osp.join(assets_dir, n), osp.join(custom_nodes_path, n)) - - def start_web(args): # The start method on some systems like Mac default to spawn if not args.spawn and mp.get_start_method() == "spawn": mp.set_start_method("fork", force=True) - cmd_queue = mp.Queue() - parent_conn, child_conn = mp.Pipe() - view_manager_queue = mp.Queue() img_mem = ( SharedMemoryManager(size=args.img_shm_size) if args.img_shm_size > 0 else None ) close_event = mp.Event() - pause_event = mp.Event() - workflow_dir = args.workflow_dir - custom_nodes_path = args.nodes_dir - docs_path = args.docs_dir - should_create_sample = False - if not osp.exists(workflow_dir): - should_create_sample = not args.no_sample - os.mkdir(workflow_dir) - if not osp.exists(custom_nodes_path): - os.mkdir(custom_nodes_path) - if not osp.exists(docs_path): - os.mkdir(docs_path) - - if should_create_sample: - create_sample_workflow(workflow_dir, custom_nodes_path, docs_path) - - processes = [ - mp.Process( - target=create_graph_server, - args=( - args, - cmd_queue, - child_conn, - pause_event, - view_manager_queue, - close_event, - img_mem.get_shared_args() if img_mem else {}, - workflow_dir, - custom_nodes_path, - docs_path, - args.web_dir, - ), - ), - ] - if args.start_media_server: - processes.append(mp.Process(target=create_media_server, args=(args,))) + setup_paths = dict( + workflow_dir=args.workflow_dir, + custom_nodes_path=args.nodes_dir, + docs_path=args.docs_dir, + ) + + web_processor_args = dict( + img_mem=img_mem, + continue_on_failure=args.continue_on_failure, + copy_outputs=args.copy_outputs, + spawn=args.spawn, + num_workers=args.num_workers, + ) - for p in processes: + if args.start_media_server: + p = mp.Process(target=create_media_server, args=(args,)) p.daemon = True p.start() @@ -547,23 +551,11 @@ def signal_handler(*_): signal.signal(signal.SIGTERM, signal_handler) signal.signal(signal.SIGINT, signal_handler) - async def start(): - processor = WebInstanceProcessor( - cmd_queue, - parent_conn, - view_manager_queue, - img_mem, - args.continue_on_failure, - args.copy_outputs, - custom_nodes_path, - close_event, - pause_event, - args.spawn, - args.num_workers, - ) - try: - await processor.start_loop() - finally: - cleanup() - - asyncio.run(start()) + create_graph_server( + args, + web_processor_args, + img_mem.get_shared_args() if img_mem else {}, + setup_paths, + close_event, + args.web_dir, + ) diff --git a/scripts/graphbook b/scripts/graphbook deleted file mode 100755 index b984229..0000000 --- a/scripts/graphbook +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/sh - -REPO_PATH=$(dirname $(realpath $0))/.. -SCRIPT_DIR=$0 -SERVER_PATH=$REPO_PATH/graphbook/main.py -WEB_PATH=$REPO_PATH/web -WEB_BUILD_PATH=$WEB_PATH/dist - -python $SERVER_PATH --web_dir $WEB_BUILD_PATH --workflow_dir . --nodes_dir ./nodes $@ diff --git a/web/src/api.ts b/web/src/api.ts index 1aefd80..c55118b 100644 --- a/web/src/api.ts +++ b/web/src/api.ts @@ -8,12 +8,37 @@ export class ServerAPI { private nodes: any; private websocket: WebSocket | null; private listeners: Set<[string, EventListenerOrEventListenerObject]> - private reconnectListeners: Set; + private reconnectTimerListeners: Set; + private onConnectStateListeners: Set; + private protocol: string; + private wsProtocol: string; + private sid?: string; constructor() { this.nodes = {}; this.listeners = new Set(); - this.reconnectListeners = new Set(); + this.reconnectTimerListeners = new Set(); + this.onConnectStateListeners = new Set(); + this.protocol = window.location.protocol; + this.wsProtocol = this.protocol === 'https:' ? 'wss:' : 'ws:'; + this.addWSMessageListener(this.sidSetter.bind(this)); + } + + private sidSetter(res) { + const msg = JSON.parse(res.data); + if (msg.type === "sid") { + if (this.sid) { + console.error("Unexpected request from server to change SID when it has already been set."); + return; + } + this.sid = msg.data; + console.log(`SID: ${this.sid}`); + + this.refreshNodeCatalogue(); + for (const callback of this.onConnectStateListeners) { + callback(true); + } + } } public connect(host: string, mediaHost: string) { @@ -28,12 +53,42 @@ export class ServerAPI { this.websocket.close(); this.websocket = null; } + this.sid = undefined; + for (const callback of this.onConnectStateListeners) { + callback(false); + } + } + + public onConnectStateChange(callback: Function): Function { + this.onConnectStateListeners.add(callback); + + return () => { + this.onConnectStateListeners.delete(callback); + }; + } + + public onReconnectTimerChange(callback: Function): Function { + this.reconnectTimerListeners.add(callback); + + return () => { + this.reconnectTimerListeners.delete(callback); + }; + } + + + public addReconnectTimerListener(callback: Function) { + this.reconnectTimerListeners.add(callback); } + public removeReconnectTimerListener(callback: Function) { + this.reconnectTimerListeners.delete(callback); + } + + private connectWebSocket() { const connect = () => { try { - this.websocket = new WebSocket(`ws://${this.host}/ws`); + this.websocket = new WebSocket(`${this.wsProtocol}//${this.host}/ws`); } catch (e) { console.error(e); return; @@ -43,10 +98,13 @@ export class ServerAPI { } this.websocket.onopen = () => { console.log("Connected to server."); - this.refreshNodeCatalogue(); }; this.websocket.onclose = () => { this.retryWebSocketConnection(); + this.sid = undefined; + for (const callback of this.onConnectStateListeners) { + callback(false); + } }; }; if (this.websocket) { @@ -75,7 +133,7 @@ export class ServerAPI { } else { reconnectSecond(i - 1); } - for (const callback of this.reconnectListeners) { + for (const callback of this.reconnectTimerListeners) { callback(i); } }, 1000); @@ -114,50 +172,73 @@ export class ServerAPI { this.nodes = nodesRes; } - private async post(path, data): Promise { + private async post(path, data): Promise { + if (!this.sid) { + throw Error("Cannot make request without SID."); + } + try { - const response = await fetch(`http://${this.host}/${path}`, { + const response = await fetch(`${this.protocol}//${this.host}/${path}`, { method: 'POST', headers: { - 'Content-Type': 'application/json' + 'Content-Type': 'application/json', + 'sid': this.sid, }, body: JSON.stringify(data) }); - if (response.ok) { - return await response.json(); - } + return response; } catch (e) { - console.error(e); - return null; + console.error(`POST request error: ${e}`); + throw e; } } private async put(path, data): Promise { - const response = await fetch(`http://${this.host}/${path}`, { - method: 'PUT', - headers: { - 'Content-Type': 'application/json' - }, - body: JSON.stringify(data) - }); - return response; + if (!this.sid) { + throw Error("Cannot make request without SID."); + } + + try { + const response = await fetch(`${this.protocol}//${this.host}/${path}`, { + method: 'PUT', + headers: { + 'Content-Type': 'application/json', + 'sid': this.sid, + }, + body: JSON.stringify(data) + }); + return response; + } catch (e) { + console.error(`PUT request error: ${e}`); + throw e; + } } private async get(path): Promise { + if (!this.sid) { + throw Error("Cannot make request without SID."); + } + try { - const response = await fetch(`http://${this.host}/${path}`); + const response = await fetch(`${this.protocol}//${this.host}/${path}`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + 'sid': this.sid, + } + }); if (response.ok) { return await response.json(); } } catch (e) { - console.error(e); - return null; + console.error(`POST request error: ${e}`); + throw e; } } private async delete(path): Promise { try { - const response = await fetch(`http://${this.host}/${path}`, { + const response = await fetch(`${this.host}/${path}`, { method: 'DELETE', headers: { 'Content-Type': 'application/json' @@ -194,14 +275,6 @@ export class ServerAPI { this.removeWsEventListener('message', callback); } - public addReconnectListener(callback: Function) { - this.reconnectListeners.add(callback); - } - - public removeReconnectListener(callback: Function) { - this.reconnectListeners.delete(callback); - } - /** * Processor API */ @@ -317,12 +390,12 @@ export class ServerAPI { * Image/Media API */ public getImagePath(imageName: string) { - return `http://${this.mediaHost}/${imageName}`; + return `${this.protocol}//${this.mediaHost}/${imageName}`; } private async mediaServerPut(path, data) { try { - const response = await fetch(`http://${this.mediaHost}/${path}`, { + const response = await fetch(`${this.protocol}//${this.mediaHost}/${path}`, { method: 'PUT', headers: { 'Content-Type': 'application/json' @@ -343,7 +416,7 @@ export class ServerAPI { } public async getMediaServerVars(): Promise { - return this.mediaServerPut('set', { }); + return this.mediaServerPut('set', {}); } /** diff --git a/web/src/components/Flow.tsx b/web/src/components/Flow.tsx index 605490e..1a88b75 100644 --- a/web/src/components/Flow.tsx +++ b/web/src/components/Flow.tsx @@ -305,21 +305,19 @@ export default function Flow({ filename }) { } const searchNodes = (catalogue, name, category) => { - if (!category) { - return null; - } - const categories = category.split('/'); - let c = catalogue[categories[0]]; - for (let i = 1; i < categories.length; i++) { - c = c?.children[categories[i]]; + const categories = category === '' ? [] : category.split('/'); + let collection = catalogue; + for (let i = 0; i < categories.length; i++) { + collection = collection?.children?.[categories[i]]; } - if (!c) { + if (!collection) { return null; } - return c.children?.[name]; + return collection[name]; }; const updatedNodes = await API.getNodes(); + console.log(updatedNodes.steps); setNodes(nodes => { const mergedNodes = nodes.map(node => { diff --git a/web/src/components/LeftPanel/Filesystem.tsx b/web/src/components/LeftPanel/Filesystem.tsx index f7e980e..7f28469 100644 --- a/web/src/components/LeftPanel/Filesystem.tsx +++ b/web/src/components/LeftPanel/Filesystem.tsx @@ -41,11 +41,12 @@ export default function Filesystem({ setWorkflow, onBeginEdit }) { return; } const files = await API.listFiles(); + console.log(files); if (!files) { return; } - const filesRoot = files.from_root.toUpperCase(); + const filesRoot = files.title.toUpperCase(); const setKey = (data) => { data.forEach((item) => { item.key = item.path; @@ -98,6 +99,7 @@ export default function Filesystem({ setWorkflow, onBeginEdit }) { return; } if (filename.slice(-3) == '.py') { + console.log("Attempting to open", filename) onBeginEdit({ name: filename }); } else if (filename.slice(-5) == '.json') { setWorkflow(filename); diff --git a/web/src/graphstore.ts b/web/src/graphstore.ts index 3e109d2..5980df2 100644 --- a/web/src/graphstore.ts +++ b/web/src/graphstore.ts @@ -56,6 +56,12 @@ export class GraphStore { const storedNodes = nodes.map(node => { const storedNode = { ...node, data: { ...node.data } }; delete storedNode.data.properties; + delete storedNode.data.key; + delete storedNode.selected; + delete storedNode.dragging; + delete storedNode.positionAbsolute; + delete storedNode.width; + delete storedNode.height; return storedNode; }); const storedEdges = edges.map(edge => { @@ -79,8 +85,8 @@ export class GraphStore { for (let i = 0; i < prev.nodes.length; i++) { // If node's data is updated - const prevData = { ...prev.nodes[i].data, properties: undefined }; - const nextData = { ...next.nodes[i].data, properties: undefined }; + const prevData = JSON.stringify({ ...prev.nodes[i].data, properties: undefined }); + const nextData = JSON.stringify({ ...next.nodes[i].data, properties: undefined }); if (prevData !== nextData) { return true; } diff --git a/web/src/hooks/API.ts b/web/src/hooks/API.ts index 2581793..5cf25c9 100644 --- a/web/src/hooks/API.ts +++ b/web/src/hooks/API.ts @@ -6,8 +6,13 @@ let globalAPI: ServerAPI | null = null; let localSetters: Function[] = []; let initialized = false; -const initialize = () => setGlobalAPI(API); -const disable = () => setGlobalAPI(null); +const onConnectStateChange = (isConnected: boolean) => { + if (!isConnected) { + setGlobalAPI(null); + } else { + setGlobalAPI(API); + } +}; function setGlobalAPI(api: ServerAPI | null) { globalAPI = api; @@ -20,21 +25,22 @@ export function useAPI() { const [_, setAPI] = useState(globalAPI); useEffect(() => { - localSetters.push(setAPI); - if (!initialized) { - API.addWsEventListener('open', initialize); - API.addWsEventListener('close', disable); initialized = true; + const discard = API.onConnectStateChange(onConnectStateChange); + + return () => { + discard(); + initialized = false + }; } + }, []); + + useEffect(() => { + localSetters.push(setAPI); + return () => { localSetters = localSetters.filter((setter) => setter !== setAPI); - - if (localSetters.length === 0) { - API.removeWsEventListener('open', initialize); - API.removeWsEventListener('close', disable); - initialized = false; - } } }, []); @@ -94,19 +100,23 @@ export function useAPIReconnectTimer() { }; useEffect(() => { - localReconnectListeners.push(reconnectTime); + if (!reconnectInitialized) { - if(!reconnectInitialized) { - API.addReconnectListener(onTimerChanged); + const discard = API.onReconnectTimerChange(onTimerChanged); reconnectInitialized = true; + + return () => { + discard(); + reconnectInitialized = false; + }; } + }); + + useEffect(() => { + localReconnectListeners.push(reconnectTime); return () => { localReconnectListeners = localReconnectListeners.filter((listener) => listener !== reconnectTime); - if(localReconnectListeners.length === 0) { - API.removeReconnectListener(onTimerChanged); - reconnectInitialized = false; - } } }, []); diff --git a/web/src/hooks/Settings.ts b/web/src/hooks/Settings.ts index 4aff49b..bbb07df 100644 --- a/web/src/hooks/Settings.ts +++ b/web/src/hooks/Settings.ts @@ -1,11 +1,12 @@ import { useState, useEffect, useCallback } from "react"; const MONITOR_DATA_COLUMNS = ['stats', 'logs', 'notes', 'images']; +const defaultPort = window.location.port === '' ? '' : `:${window.location.port}`; let settings = { theme: "Light", disableTooltips: false, - graphServerHost: "localhost:8005", - mediaServerHost: "localhost:8006", + graphServerHost: `${window.location.hostname}${defaultPort}`, + mediaServerHost: `${window.location.hostname}:8006`, useExternalMediaServer: false, monitorDataColumns: MONITOR_DATA_COLUMNS, monitorLogsShouldScrollToBottom: true, diff --git a/web/src/plugins.ts b/web/src/plugins.ts index 7b11de8..06590f5 100644 --- a/web/src/plugins.ts +++ b/web/src/plugins.ts @@ -25,7 +25,8 @@ class PluginManager { for await (const p of plugins) { if (!this.plugins.has(p)) { try { - const url = `http://${API.getHost()}/plugins/${p}`; + const protocol = window.location.protocol; + const url = `${protocol}//${API.getHost()}/plugins/${p}`; console.log("Loading plugin from", url); const module = await import(/* @vite-ignore */url); this.plugins.set(p, module); diff --git a/web/src/utils.ts b/web/src/utils.ts index 58032fe..be2ac98 100644 --- a/web/src/utils.ts +++ b/web/src/utils.ts @@ -47,20 +47,12 @@ export const getMediaPath = (settings: any, item: ImageRef): string => { return ''; } + const protocol = window.location.protocol; if (!settings.useExternalMediaServer) { - let graphHost = settings.graphServerHost; - if (!graphHost.startsWith('http')) { - graphHost = 'http://' + graphHost; - } - return `${graphHost}/media${query}`; - } - - let mediaHost = settings.mediaServerHost; - if (!mediaHost.startsWith('http')) { - mediaHost = 'http://' + query; + return `${protocol}//${settings.graphServerHost}/media${query}`; } - return mediaHost + query; + return `${protocol}//${settings.mediaServerHost}${query}`; } export const uniqueIdFrom = (obj: any): string => {