diff --git a/README.md b/README.md
index 5ddeeed..acf6ef8 100644
--- a/README.md
+++ b/README.md
@@ -5,6 +5,34 @@

Graphbook

+ GitHub License
+ GitHub Actions Workflow Status
+ Docker Pulls
+ PyPI Downloads
+ PyPI - Version
+
+ Join Discord
+
+ Discord

The ML workflow framework
@@ -23,7 +51,7 @@

 ## Overview
-Graphbook is a framework for building efficient, visual DAG-structured ML workflows composed of nodes written in Python. Graphbook provides common ML processing features such as multiprocessing IO and automatic batching, and it features a web-based UI to assemble, monitor, and execute data processing workflows. It can be used to prepare training data for custom ML models, experiment with custom trained or off-the-shelf models, and to build ML-based ETL applications. Custom nodes can be built in Python, and Graphbook will behave like a framework and call lifecycle methods on those nodes.
+Graphbook is a framework for building efficient, visual DAG-structured ML workflows composed of nodes written in Python. Graphbook provides common ML processing features such as multiprocessing IO and automatic batching for PyTorch tensors, and it features a web-based UI to assemble, monitor, and execute data processing workflows. It can be used to prepare training data for custom ML models, experiment with custom trained or off-the-shelf models, and to build ML-based ETL applications. Custom nodes can be built in Python, and Graphbook will behave like a framework and call lifecycle methods on those nodes.

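To make the overview concrete, here is a minimal sketch of a custom node. It assumes the `step` and `param` decorators exported from `graphbook` (see the `__init__.py` hunk below); the category string, parameter names, and the assumption that `@param` values are exposed as attributes on `ctx` are illustrative, not taken from this diff.

```python
from graphbook import step, param, Note

# Minimal sketch of a custom node (illustrative names and parameters).
# Graphbook calls the decorated function once per Note that reaches the step.
@step("Examples/TagSource")
@param("source_name", type="string", default="webcam")
def tag_source(ctx, note: Note):
    # Assumes @param values are exposed as attributes on ctx.
    note["source"] = ctx.source_name
    ctx.log(f"Tagged note from {ctx.source_name}")
```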
diff --git a/graphbook/__init__.py b/graphbook/__init__.py
index 44e8cc5..14fb2eb 100644
--- a/graphbook/__init__.py
+++ b/graphbook/__init__.py
@@ -1,4 +1,4 @@
 from .note import Note
-from .decorators import step, param, source, output, batch, resource, event
+from .decorators import step, param, source, output, batch, resource, event, prompt
 
-__all__ = ["step", "param", "source", "output", "batch", "resource", "event", "Note"]
+__all__ = ["step", "param", "source", "output", "batch", "resource", "event", "prompt", "Note"]
diff --git a/graphbook/custom_nodes.py b/graphbook/custom_nodes.py
index ffdbdc1..0fba930 100644
--- a/graphbook/custom_nodes.py
+++ b/graphbook/custom_nodes.py
@@ -13,6 +13,7 @@ from graphbook.steps import (
     Step,
     BatchStep,
+    PromptStep,
     SourceStep,
     GeneratorSourceStep,
     AsyncStep,
@@ -25,6 +26,7 @@ BUILT_IN_STEPS = [
     Step,
     BatchStep,
+    PromptStep,
     SourceStep,
     GeneratorSourceStep,
     AsyncStep,
diff --git a/graphbook/decorators.py b/graphbook/decorators.py
index 17849bf..16cd909 100644
--- a/graphbook/decorators.py
+++ b/graphbook/decorators.py
@@ -31,14 +31,13 @@ def param(
             "required": required,
             "description": description,
         }
+        self.parameter_type_casts[name] = cast_as
         if cast_as is None:
             # Default casts
             if type == "function":
                 self.parameter_type_casts[name] = transform_function_string
             if type == "int":
                 self.parameter_type_casts[name] = int
-        else:
-            self.parameter_type_casts[name] = cast_as
 
     @abc.abstractmethod
     def build():
@@ -85,6 +84,11 @@ def batch(
         if dump_fn is not None:
             self.event("dump_fn", dump_fn)
 
+    def prompt(self, get_prompt=None):
+        self.BaseClass = steps.PromptStep
+        if get_prompt is not None:
+            self.event("get_prompt", get_prompt)
+
     def build(self):
         def __init__(cls, **kwargs):
             if self.BaseClass == steps.BatchStep:
@@ -212,6 +216,8 @@ def decorator(func):
             factory.event("on_note", func)
         elif factory.BaseClass == steps.BatchStep:
             factory.event("on_item_batch", func)
+        elif factory.BaseClass == steps.PromptStep:
+            factory.event("on_prompt_response", func)
         else:
             factory.event("load", func)
 
@@ -456,3 +462,65 @@ def wrapper(*args, **kwargs):
         return wrapper
 
     return decorator
+
+
+def prompt(get_prompt: callable = None):
+    """
+    Marks a function as a step that is capable of prompting the user.
+    This is useful for interactive workflows where data labeling, model evaluation, or any other human input is required.
+    Events ``get_prompt(ctx, note: Note)`` and ``on_prompt_response(ctx, note: Note, response: Any)`` are required to be implemented.
+    The decorator accepts the ``get_prompt`` function that returns a prompt to display to the user.
+    If nothing is passed as an argument, a ``bool_prompt`` will be used by default.
+    If the function returns **None** on any given note, no prompt will be displayed for that note, allowing for conditional prompts based on the note's content.
+    Available prompts are located in the ``graphbook.prompts`` module.
+    The function that this decorator decorates is ``on_prompt_response`` and will be called when a response to a prompt is obtained from a user.
+    Once the prompt is handled, the execution lifecycle of the Step will proceed normally.
+
+    Args:
+        get_prompt (callable): A function that returns a prompt. Default is ``bool_prompt``.
+
+    Examples:
+        .. highlight:: python
+        .. code-block:: python
+
+            def dog_or_cat(ctx, note: Note):
+                return selection_prompt(note, choices=["dog", "cat"], show_images=True)
+
+
+            @step("Prompts/Label")
+            @prompt(dog_or_cat)
+            def label_images(ctx, note: Note, response: str):
+                note["label"] = response
+
+
+            def corrective_prompt(ctx, note: Note):
+                if note["prediction_confidence"] < 0.65:
+                    return bool_prompt(
+                        note,
+                        msg=f"Model prediction ({note['pred']}) was uncertain. Is its prediction correct?",
+                        show_images=True,
+                    )
+                else:
+                    return None
+
+
+            @step("Prompts/CorrectModelLabel")
+            @prompt(corrective_prompt)
+            def correct_model_labels(ctx, note: Note, response: bool):
+                if response:
+                    ctx.log("Model is correct!")
+                    note["label"] = note["pred"]
+                else:
+                    ctx.log("Model is incorrect!")
+                    if note["pred"] == "dog":
+                        note["label"] = "cat"
+                    else:
+                        note["label"] = "dog"
+    """
+    def decorator(func):
+        def set_prompt(factory: StepClassFactory):
+            factory.prompt(get_prompt)
+
+        return DecoratorFunction(func, set_prompt)
+
+    return decorator
diff --git a/graphbook/logger.py b/graphbook/logger.py
index 40dd64c..0d80417 100644
--- a/graphbook/logger.py
+++ b/graphbook/logger.py
@@ -2,6 +2,7 @@ import multiprocessing as mp
 from typing import Dict, Tuple, Any
 import inspect
+from graphbook.note import Note
 from graphbook.viewer import ViewManagerInterface
 from graphbook.utils import transform_json_log
@@ -44,3 +45,18 @@ def log(msg: Any, type: LogType = "info", caller_id: int | None = None):
         else:
             raise ValueError(f"Unknown log type {type}")
     view_manager.handle_log(node_id, msg, type)
+
+def prompt(prompt: dict, caller_id: int | None = None):
+    if caller_id is None:
+        prev_frame = inspect.currentframe().f_back
+        caller = prev_frame.f_locals.get("self")
+        if caller is not None:
+            caller_id = id(caller)
+
+    node = logging_nodes.get(caller_id, None)
+    if node is None:
+        raise ValueError(
+            f"Can't find node id in {caller}. Only initialized steps can prompt."
+ ) + node_id, _ = node + view_manager.handle_prompt(node_id, prompt) diff --git a/graphbook/processing/web_processor.py b/graphbook/processing/web_processor.py index fb293dd..a26c6ca 100644 --- a/graphbook/processing/web_processor.py +++ b/graphbook/processing/web_processor.py @@ -1,4 +1,11 @@ -from graphbook.steps import Step, SourceStep, GeneratorSourceStep, AsyncStep, StepOutput +from graphbook.steps import ( + Step, + SourceStep, + GeneratorSourceStep, + AsyncStep, + BatchStep, + StepOutput, +) from graphbook.dataloading import Dataloader, setup_global_dl from graphbook.utils import MP_WORKER_TIMEOUT, ProcessorStateRequest, transform_json_log from graphbook.state import GraphState, StepState, NodeInstantiationError @@ -49,7 +56,10 @@ def __init__( self.dataloader = Dataloader(self.num_workers) setup_global_dl(self.dataloader) self.state_client = ProcessorStateClient( - server_request_conn, close_event, self.graph_state, self.dataloader + server_request_conn, + close_event, + self.graph_state, + self.dataloader, ) self.is_running = False self.filename = None @@ -120,13 +130,11 @@ def exec_step( id(step), ) return None - self.handle_images(outputs) self.graph_state.handle_outputs( step.id, outputs if not self.copy_outputs else copy.deepcopy(outputs) ) - self.view_manager.handle_outputs(step.id, transform_json_log(outputs)) - self.view_manager.handle_time(step.id, time.time() - start_time) + self.view_manager.handle_time(step.id, time.time() - start_time) return outputs def handle_steps(self, steps: List[Step]) -> bool: @@ -182,13 +190,25 @@ def step_until_received_output(self, steps: List[Step], step_id: str): step_executed = self.graph_state.get_state( step_id, StepState.EXECUTED_THIS_RUN ) + + def try_execute_step_event(self, step: Step, event: str): + try: + if hasattr(step, event): + getattr(step, event)() + return True + except Exception as e: + log(f"{type(e).__name__}: {str(e)}", "error", id(step)) + traceback.print_exc() + return False def run(self, step_id: str = None): steps: List[Step] = self.graph_state.get_processing_steps(step_id) - self.setup_dataloader(steps) for step in steps: self.view_manager.handle_start(step.id) - step.on_start() + succeeded = self.try_execute_step_event(step, "on_start") + if not succeeded: + return + self.setup_dataloader(steps) self.pause_event.clear() dag_is_active = True try: @@ -201,24 +221,26 @@ def run(self, step_id: str = None): dag_is_active = self.handle_steps(steps) finally: self.view_manager.handle_end() - for step in steps: - step.on_end() self.dataloader.stop() + for step in steps: + self.try_execute_step_event(step, "on_end") def step(self, step_id: str = None): steps: List[Step] = self.graph_state.get_processing_steps(step_id) - self.setup_dataloader(steps) for step in steps: self.view_manager.handle_start(step.id) - step.on_start() + succeeded = self.try_execute_step_event(step, "on_start") + if not succeeded: + return + self.setup_dataloader(steps) self.pause_event.clear() try: self.step_until_received_output(steps, step_id) finally: self.view_manager.handle_end() - for step in steps: - step.on_end() self.dataloader.stop() + for step in steps: + self.try_execute_step_event(step, "on_end") def set_is_running(self, is_running: bool = True, filename: str | None = None): self.is_running = is_running @@ -232,7 +254,7 @@ def cleanup(self): self.dataloader.shutdown() def setup_dataloader(self, steps: List[Step]): - dataloader_consumers = [step for step in steps if isinstance(step, AsyncStep)] + dataloader_consumers = [step for step in 
steps if isinstance(step, BatchStep)] consumer_ids = [id(c) for c in dataloader_consumers] consumer_load_fn = [ c.load_fn if hasattr(c, "load_fn") else None for c in dataloader_consumers @@ -323,6 +345,10 @@ def _loop(self): output = self.dataloader.get_all_sizes() elif req["cmd"] == ProcessorStateRequest.GET_RUNNING_STATE: output = self.running_state + elif req["cmd"] == ProcessorStateRequest.PROMPT_RESPONSE: + step_id = req.get("step_id") + succeeded = self.graph_state.handle_prompt_response(step_id, req.get("response")) + output = {"ok": succeeded} else: output = {} entry = {"res": req["cmd"], "data": output} diff --git a/graphbook/prompts.py b/graphbook/prompts.py new file mode 100644 index 0000000..04b4b24 --- /dev/null +++ b/graphbook/prompts.py @@ -0,0 +1,158 @@ +from typing import Any, List +from .note import Note +from .utils import transform_json_log + + +def none(): + return {"type": None} + + +def prompt(note: Note, *, msg: str = "", show_images: bool = False, default: Any = ""): + return { + "note": transform_json_log(note), + "msg": msg, + "show_images": show_images, + "def": default, + } + + +def bool_prompt( + note: Note, + *, + msg: str = "Continue?", + style: str = "yes/no", + default: bool = False, + show_images: bool = False, +): + """ + Prompt the user with a yes/no for binary questions. + + Args: + note (Note): The current note that triggered the prompt. + msg (str): An informative message or inquiry to display to the user. + style (str): The style of the bool prompt. Can be "yes/no" or "switch". + default (bool): The default bool value. + show_images (bool): Whether to present the images (instead of the Note object) to the user. + """ + p = prompt(note, msg=msg, default=default, show_images=show_images) + p["type"] = "bool" + p["options"] = {"style": style} + return p + +def selection_prompt( + note: Note, + choices: List[str], + *, + msg: str = "Select an option:", + default: List[str] | str = None, + show_images: bool = False, + multiple_allowed: bool = False +): + """ + Prompt the user to select an option from a list of choices. + + Args: + note (Note): The current note that triggered the prompt. + choices (List[str]): A list of strings representing the options the user can select. + msg (str): An informative message or inquiry to display to the user. + default (List[str] | str): The default value. If multiple_allowed is True, this should be a list of strings. + show_images (bool): Whether to present the images (instead of the Note object) to the user. + multiple_allowed (bool): Whether the user can select multiple options from the list of given choices. + """ + assert len(choices) > 0, "Choices must not be empty in selection prompt." + if default is None: + if multiple_allowed: + default = [choices[0]] + else: + default = choices[0] + + p = prompt(note, msg=msg, default=default, show_images=show_images) + p["type"] = "selection" + p["options"] = { + "choices": choices, + "multiple_allowed": multiple_allowed, + } + return p + +def text_prompt( + note: Note, + *, + msg: str = "Enter text:", + default: str = "", + show_images: bool = False, +): + """ + Prompt the user to enter text. + + Args: + note (Note): The current note that triggered the prompt. + msg (str): An informative message or inquiry to display to the user. + default (str): The default text value. + show_images (bool): Whether to present the images (instead of the Note object) to the user. 
+ """ + p = prompt(note, msg=msg, default=default, show_images=show_images) + p["type"] = "string" + return p + +def number_prompt( + note: Note, + *, + msg: str = "Enter a number:", + default: float = 0.0, + show_images: bool = False, +): + """ + Prompt the user to enter a number. + + Args: + note (Note): The current note that triggered the prompt. + msg (str): An informative message or inquiry to display to the user. + default (float): The default number value. + show_images (bool): Whether to present the images (instead of the Note object) to the user. + """ + p = prompt(note, msg=msg, default=default, show_images=show_images) + p["type"] = "number" + return p + +def dict_prompt( + note: Note, + *, + msg: str = "Enter a dictionary:", + default: dict = {}, + show_images: bool = False, +): + """ + Prompt the user to enter a dictionary. + + Args: + note (Note): The current note that triggered the prompt. + msg (str): An informative message or inquiry to display to the user. + default (dict): The default dictionary value. + show_images (bool): Whether to present the images (instead of the Note object) to the user. + """ + p = prompt(note, msg=msg, default=default, show_images=show_images) + p["type"] = "dict" + return p + +def list_prompt( + note: Note, + type: str = "string", + *, + msg: str = "Enter a list:", + default: list = [], + show_images: bool = False, +): + """ + Prompt the user to enter a list. + + Args: + note (Note): The current note that triggered the prompt. + type (str): The type of the list elements. Can be "string", "number", "dict", or "bool". + msg (str): An informative message or inquiry to display to the user. + default (list): The default list value. + show_images (bool): Whether to present the images (instead of the Note object) to the user. + """ + assert type in ["string", "number", "dict", "bool"], "Invalid type in list prompt." 
+ p = prompt(note, msg=msg, default=default, show_images=show_images) + p["type"] = f"list[{type}]" + return p diff --git a/graphbook/state.py b/graphbook/state.py index ce35071..4be7d56 100644 --- a/graphbook/state.py +++ b/graphbook/state.py @@ -2,12 +2,13 @@ from aiohttp.web import WebSocketResponse from typing import Dict, Tuple, List, Iterator, Set from graphbook.note import Note -from graphbook.steps import Step, StepOutput as Outputs +from graphbook.steps import Step, PromptStep, StepOutput as Outputs from graphbook.resources import Resource from graphbook.decorators import get_steps, get_resources from graphbook.viewer import ViewManagerInterface from graphbook.plugins import setup_plugins from graphbook.logger import setup_logging_nodes +from graphbook.utils import transform_json_log import multiprocessing as mp import importlib, importlib.util, inspect import graphbook.exports as exports @@ -344,6 +345,7 @@ def handle_outputs(self, step_id: str, outputs: Outputs): self._step_states[step_id].add(StepState.EXECUTED) self._step_states[step_id].add(StepState.EXECUTED_THIS_RUN) self.view_manager.handle_queue_size(step_id, self._queues[step_id].dict_sizes()) + self.view_manager.handle_outputs(step_id, transform_json_log(outputs)) def clear_outputs(self, node_id: str | None = None): if node_id is None: @@ -400,6 +402,16 @@ def get_output_note(self, step_id: str, pin_id: str, index: int) -> dict: note = internal_list[index] entry.update(data=note.items) return entry + + def handle_prompt_response(self, step_id: str, response: dict) -> bool: + step = self._steps.get(step_id) + if not isinstance(step, PromptStep): + return False + try: + step.handle_prompt_response(response) + return True + except: + return False def get_step(self, step_id: str): return self._steps.get(step_id) diff --git a/graphbook/steps/__init__.py b/graphbook/steps/__init__.py index 22ecd60..f2a2bab 100644 --- a/graphbook/steps/__init__.py +++ b/graphbook/steps/__init__.py @@ -3,6 +3,7 @@ SourceStep, GeneratorSourceStep, BatchStep, + PromptStep, StepOutput, AsyncStep, Split, @@ -16,6 +17,7 @@ "SourceStep", "GeneratorSourceStep", "BatchStep", + "PromptStep", "StepOutput", "AsyncStep", "Split", diff --git a/graphbook/steps/base.py b/graphbook/steps/base.py index a08b95a..b0fe684 100644 --- a/graphbook/steps/base.py +++ b/graphbook/steps/base.py @@ -1,11 +1,16 @@ from __future__ import annotations from typing import List, Dict, Tuple, Generator, Any -from ..utils import transform_function_string, convert_dict_values_to_list, is_batchable from graphbook import Note -from graphbook.logger import log +from graphbook.utils import ( + transform_function_string, + convert_dict_values_to_list, + is_batchable, +) +from graphbook.logger import log, prompt +import graphbook.prompts as prompts import graphbook.dataloading as dataloader import warnings - +import traceback warnings.simplefilter("default", DeprecationWarning) @@ -221,8 +226,12 @@ class AsyncStep(Step): def __init__(self, item_key=None): super().__init__(item_key) - self._is_processing = True self._in_queue = [] + self._out_queue = [] + + def on_clear(self): + self._in_queue = [] + self._out_queue = [] def in_q(self, note: Note | None): if note is None: @@ -230,7 +239,17 @@ def in_q(self, note: Note | None): self._in_queue.append(note) def is_active(self) -> bool: - return self._is_processing + return len(self._in_queue) > 0 + + def __call__(self) -> StepOutput: + # 1. on_note -> 2. on_item -> 3. on_after_item -> 4. 
forward_note + if len(self._out_queue) == 0: + return {} + note = self._out_queue.pop(0) + return super().__call__(note) + + def all(self) -> StepOutput: + return self.__call__() class NoteItemHolders: @@ -489,6 +508,83 @@ def is_active(self) -> bool: ) +class PromptStep(AsyncStep): + """ + A Step that is capable of prompting the user for input. + This is useful for interactive workflows where data labeling, model evaluation, or any other human input is required. + Once the prompt is handled, the execution lifecycle of the Step will proceed, normally. + """ + def __init__(self): + super().__init__() + self._is_awaiting_response = False + self._awaiting_note = None + + def handle_prompt_response(self, response: Any): + note = self._awaiting_note + try: + assert note is not None, "PromptStep is not awaiting a response." + self.on_prompt_response(note, response) + self._out_queue.append(note) + except Exception as e: + self.log(f"{type(e).__name__}: {str(e)}", "error") + traceback.print_exc() + + self._is_awaiting_response = False + self._awaiting_note = None + prompt(prompts.none()) + + def on_clear(self): + """ + Clears any awaiting prompts and the prompt queue. + If you plan on overriding this method, make sure to call super().on_clear() to ensure the prompt queue is cleared. + """ + self._is_awaiting_response = False + self._awaiting_note = None + prompt(prompts.none()) + super().on_clear() + + def get_prompt(self, note: Note) -> dict: + """ + Returns the prompt to be displayed to the user. + This method can be overriden by the subclass. + By default, it will return a boolean prompt. + If None is returned, the prompt will be skipped on this note. + A list of available prompts can be found in ``graphbook.prompts``. + + Args: + note (Note): The Note input to display to the user + """ + return prompts.bool_prompt(note) + + def on_prompt_response(self, note: Note, response: Any): + """ + Called when the user responds to the prompt. + This method must be overriden by the subclass. + + Args: + note (Note): The Note input that was prompted + response (Any): The user's response + """ + raise NotImplementedError( + "on_prompt_response must be implemented for PromptStep" + ) + + def __call__(self): + if not self._is_awaiting_response and len(self._in_queue) > 0: + note = self._in_queue.pop(0) + p = self.get_prompt(note) + if p: + prompt(self.get_prompt(note)) + self._is_awaiting_response = True + self._awaiting_note = note + else: + self._out_queue.append(note) + return super().__call__() + + def is_active(self) -> bool: + return len(self._in_queue) > 0 or self._awaiting_note is not None + + class Split(Step): """ Routes incoming Notes into either of two output slots, A or B. 
If split_fn diff --git a/graphbook/utils.py b/graphbook/utils.py index 28d4601..a33e724 100644 --- a/graphbook/utils.py +++ b/graphbook/utils.py @@ -16,7 +16,7 @@ MP_WORKER_TIMEOUT = 5.0 ProcessorStateRequest = Enum( "ProcessorStateRequest", - ["GET_OUTPUT_NOTE", "GET_WORKER_QUEUE_SIZES", "GET_RUNNING_STATE"], + ["GET_OUTPUT_NOTE", "GET_WORKER_QUEUE_SIZES", "GET_RUNNING_STATE", "PROMPT_RESPONSE"], ) diff --git a/graphbook/viewer.py b/graphbook/viewer.py index f1fd16c..3a35c7f 100644 --- a/graphbook/viewer.py +++ b/graphbook/viewer.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Tuple +from typing import Dict, List from aiohttp.web import WebSocketResponse import uuid import asyncio @@ -8,7 +8,7 @@ import queue import copy import psutil -from .utils import MP_WORKER_TIMEOUT, get_gpu_util, ProcessorStateRequest, poll_conn_for, transform_json_log +from .utils import MP_WORKER_TIMEOUT, get_gpu_util, ProcessorStateRequest, poll_conn_for class Viewer: @@ -177,6 +177,23 @@ def get_next(self): } +class PromptViewer(Viewer): + def __init__(self): + super().__init__("prompt") + self.prompts = {} + + def handle_prompt(self, node_id: str, prompt: dict): + prev_prompt = self.prompts.get(node_id) + idx = 0 + if prev_prompt: + idx = prev_prompt["idx"] + 1 + prompt["idx"] = idx + self.prompts[node_id] = prompt + + def get_next(self): + return self.prompts + + DEFAULT_CLIENT_OPTIONS = {"SEND_EVERY": 0.5} @@ -224,11 +241,13 @@ def __init__( self.node_stats_viewer = NodeStatsViewer() self.logs_viewer = NodeLogsViewer() self.system_util_viewer = SystemUtilViewer(processor_state_conn) + self.prompt_viewer = PromptViewer() self.viewers: List[Viewer] = [ self.data_viewer, self.node_stats_viewer, self.logs_viewer, self.system_util_viewer, + self.prompt_viewer, ] self.clients: Dict[str, Client] = {} self.work_queue = work_queue @@ -275,6 +294,9 @@ def handle_clear(self, node_id: str | None): def handle_log(self, node_id: str, log: str, type: str): self.logs_viewer.handle_log(node_id, log, type) + def handle_prompt(self, node_id: str, prompt: dict): + self.prompt_viewer.handle_prompt(node_id, prompt) + def handle_end(self): for viewer in self.viewers: viewer.handle_end() @@ -315,6 +337,9 @@ def _loop(self): self.handle_run_state(work["is_running"], work["filename"]) elif work["cmd"] == "handle_clear": self.handle_clear(work["node_id"]) + elif work["cmd"] == "handle_prompt": + self.handle_prompt(work["node_id"], work["prompt"]) + except queue.Empty: pass @@ -364,3 +389,8 @@ def handle_run_state(self, run_state: dict): def handle_clear(self, node_id: str | None): self.view_manager_queue.put({"cmd": "handle_clear", "node_id": node_id}) + + def handle_prompt(self, node_id: str, prompt: dict): + self.view_manager_queue.put( + {"cmd": "handle_prompt", "node_id": node_id, "prompt": prompt} + ) diff --git a/graphbook/web.py b/graphbook/web.py index fc8f2cf..648cf02 100644 --- a/graphbook/web.py +++ b/graphbook/web.py @@ -205,6 +205,14 @@ async def clear(request: web.Request) -> web.Response: } ) return web.json_response({"success": True}) + + @routes.post("/prompt_response/{id}") + async def prompt_response(request: web.Request) -> web.Response: + step_id = request.match_info.get("id") + data = await request.json() + response = data.get("response") + res = poll_conn_for(state_conn, ProcessorStateRequest.PROMPT_RESPONSE, {"step_id": step_id, "response": response}) + return web.json_response(res) @routes.get("/nodes") async def get_nodes(request: web.Request) -> web.Response: @@ -516,7 +524,10 @@ def signal_handler(*_): 
close_event.set() if img_mem: - img_mem.close() + try: + img_mem.close() + except FileNotFoundError: + pass raise KeyboardInterrupt() diff --git a/pyproject.toml b/pyproject.toml index 2083781..ad1e100 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "graphbook" -version = "0.7.0" +version = "0.8.0" authors = ["Richard Franklin "] description = "An extensible ML workflow framework built for data scientists and ML engineers." keywords = ["ml", "workflow", "framework", "pytorch", "data science", "machine learning", "ai"] diff --git a/web/package.json b/web/package.json index 1e80146..a64c9c3 100644 --- a/web/package.json +++ b/web/package.json @@ -1,7 +1,7 @@ { "name": "graphbook-web", "private": true, - "version": "0.7.0", + "version": "0.8.0", "type": "module", "scripts": { "dev": "vite", diff --git a/web/src/api.ts b/web/src/api.ts index e1d501a..b1f4842 100644 --- a/web/src/api.ts +++ b/web/src/api.ts @@ -248,6 +248,10 @@ export class ServerAPI { return await this.get('state'); } + public async respondToPrompt(stepId: string, response: any) { + return await this.post(`prompt_response/${stepId}`, { response }); + } + /** * File API */ diff --git a/web/src/components/Flow.tsx b/web/src/components/Flow.tsx index 1f6e174..8e463ac 100644 --- a/web/src/components/Flow.tsx +++ b/web/src/components/Flow.tsx @@ -118,7 +118,7 @@ export default function Flow({ filename }) { return onNodesChange(newChanges); } onNodesChange(changes); - }, []); + }, [runState]); const onEdgesChangeCallback = useCallback((changes) => { setIsAddNodeActive(false); @@ -135,7 +135,7 @@ export default function Flow({ filename }) { return onEdgesChange(newChanges); } onEdgesChange(changes); - }, []); + }, [runState]); const onConnect = useCallback((params) => { const targetNode = nodes.find(n => n.id === params.target); diff --git a/web/src/components/Monitor.tsx b/web/src/components/Monitor.tsx index 3b48029..b3d10ac 100644 --- a/web/src/components/Monitor.tsx +++ b/web/src/components/Monitor.tsx @@ -403,7 +403,7 @@ function NotesView({ stepId, numNotes, type }: NotesViewProps) { - + ); } @@ -164,6 +166,9 @@ export function WorkflowStep({ id, data, selected }) { }).filter(x => x) } +

+ +
{!data.isCollapsed && } @@ -229,7 +234,7 @@ function QuickviewCollapse({ data }) { span { + font-size: .6em; } .workflow-node .ant-btn-icon { @@ -142,10 +203,21 @@ textarea.code { font-size: .6em; } -.workflow-node .outputs .output .label{ +.workflow-node .outputs .output .label { + font-size: .6em; +} + +.workflow-node .widgets .prompt .ant-typography { font-size: .6em; } +.workflow-node .widgets .prompt .ant-radio-button-wrapper { + font-size: 10px; + line-height: 1.2; + height: 14px; + padding: 0 4px; +} + .workflow-node .handles { display: flex; flex-direction: row; diff --git a/web/src/components/Nodes/widgets/NotePreview.tsx b/web/src/components/Nodes/widgets/NotePreview.tsx new file mode 100644 index 0000000..33ee4dd --- /dev/null +++ b/web/src/components/Nodes/widgets/NotePreview.tsx @@ -0,0 +1,80 @@ +import React, { CSSProperties, useMemo } from 'react'; +import { Image, Space, Flex } from 'antd'; +import { theme } from 'antd'; +import { useSettings } from '../../../hooks/Settings'; +import { getMediaPath } from '../../../utils'; +import type { ImageRef } from '../../../utils'; +import ReactJson from '@microlink/react-json-view'; + +type QuickViewEntry = { + [key: string]: any; +}; + +export function NotePreview({ data, showImages }: { data: QuickViewEntry, showImages: boolean }) { + const globalTheme = theme.useToken().theme; + + if (!showImages) { + return ( + + ); + + } + return ( + + ); + +} + +function EntryImages({ entry, style }: { entry: QuickViewEntry, style: CSSProperties | undefined }) { + const [settings, _] = useSettings(); + + const imageEntries = useMemo(() => { + let entries: { [key: string]: ImageRef[] } = {}; + Object.entries(entry).forEach(([key, item]) => { + let imageItems: any = []; + if (Array.isArray(item)) { + imageItems = item.filter(item => item.type?.slice(0, 5) === 'image'); + } else { + if (item.type?.slice(0, 5) === 'image') { + imageItems.push(item); + } + } + if (imageItems.length > 0) { + entries[key] = imageItems; + } + }); + return entries; + }, [entry]); + + return ( + + { + Object.entries(imageEntries).map(([key, images]) => { + return ( + +
{key}
+ + { + images.map((image, i) => ( + + )) + } + +
+ + ); + }) + } +
+ ); +} diff --git a/web/src/components/Nodes/widgets/Prompts.tsx b/web/src/components/Nodes/widgets/Prompts.tsx new file mode 100644 index 0000000..3681b4f --- /dev/null +++ b/web/src/components/Nodes/widgets/Prompts.tsx @@ -0,0 +1,73 @@ +import React, { useCallback, useMemo, useState } from 'react'; +import { Typography, Flex, Button } from 'antd'; +import { usePluginWidgets } from '../../../hooks/Plugins'; +import { NotePreview } from './NotePreview'; +import { ListWidget, getWidgetLookup } from './Widgets'; +import { useAPI } from '../../../hooks/API'; +import { usePrompt } from '../../../hooks/Prompts'; +import type { Prompt } from '../../../hooks/Prompts'; +import { parseDictWidgetValue } from '../../../utils'; + +const { Text } = Typography; + +function WidgetPrompt({ type, options, value, onChange }) { + const pluginWidgets = usePluginWidgets(); + const widgets = useMemo(() => { + return getWidgetLookup(pluginWidgets); + }, [pluginWidgets]); + + if (type.startsWith('list')) { + return + } + + if (widgets[type]) { + return widgets[type]({ name: "Answer", def: value, onChange, ...options }); + } +} + +export function Prompt({ nodeId }: { nodeId: string }) { + const API = useAPI(); + const [value, setValue] = useState(null); + const [loading, setLoading] = useState(false); + + const onPromptChange = useCallback((prompt) => { + setValue(prompt?.def); + }, []); + + const [prompt, setSubmitted] = usePrompt(nodeId, onPromptChange); + + const onChange = useCallback((value) => { + setValue(value); + }, []); + + const onSubmit = useCallback(async () => { + if (API) { + setLoading(true); + try { + let answer = value; + if (prompt?.type === 'dict') { + answer = parseDictWidgetValue(answer); + } + await API.respondToPrompt(nodeId, answer); + } catch (e) { + console.error(e); + } + setLoading(false); + setSubmitted(); + } + }, [value, API, nodeId]); + + if (!prompt || prompt.type === null) { + return null; + } + + return ( + + Prompted: + + {prompt.msg} + + + + ); +} diff --git a/web/src/components/Nodes/Widgets.tsx b/web/src/components/Nodes/widgets/Widgets.tsx similarity index 65% rename from web/src/components/Nodes/Widgets.tsx rename to web/src/components/Nodes/widgets/Widgets.tsx index 5ebff5f..921416a 100644 --- a/web/src/components/Nodes/Widgets.tsx +++ b/web/src/components/Nodes/widgets/Widgets.tsx @@ -1,17 +1,17 @@ -import { Switch, Typography, theme, Flex, Button } from 'antd'; -import { PlusOutlined, MinusCircleOutlined, MinusOutlined } from '@ant-design/icons'; -import React, { useCallback, useState, useMemo } from 'react'; +import { Switch, Typography, theme, Flex, Button, Radio, Select as ASelect } from 'antd'; +import { PlusOutlined, MinusOutlined } from '@ant-design/icons'; +import React, { useCallback, useState, useMemo, useEffect } from 'react'; import CodeMirror from '@uiw/react-codemirror'; import { python } from '@codemirror/lang-python'; import { basicDark } from '@uiw/codemirror-theme-basic'; import { bbedit } from '@uiw/codemirror-theme-bbedit'; -import { Graph } from '../../graph'; +import { Graph } from '../../../graph'; import { useReactFlow } from 'reactflow'; -import { usePluginWidgets } from '../../hooks/Plugins'; +import { usePluginWidgets } from '../../../hooks/Plugins'; const { Text } = Typography; const { useToken } = theme; -const getWidgetLookup = (pluginWidgets) => { +export const getWidgetLookup = (pluginWidgets) => { const lookup = { number: NumberWidget, string: StringWidget, @@ -19,14 +19,16 @@ const getWidgetLookup = (pluginWidgets) => { bool: 
BooleanWidget, function: FunctionWidget, dict: DictWidget, + selection: SelectionWidget, }; + pluginWidgets.forEach((widget) => { lookup[widget.type] = widget.children; }); return lookup; }; -export function Widget({ id, type, name, value }) { +export function Widget({ id, type, name, value, ...props }) { const { setNodes } = useReactFlow(); const pluginWidgets = usePluginWidgets(); const widgets = useMemo(() => { @@ -44,7 +46,7 @@ export function Widget({ id, type, name, value }) { } if (widgets[type]) { - return widgets[type]({ name, def: value, onChange }); + return widgets[type]({ name, def: value, onChange, ...props }); } } @@ -60,13 +62,27 @@ export function StringWidget({ name, def, onChange }) { ); } -export function BooleanWidget({ name, def, onChange }) { +export function BooleanWidget({ name, def, onChange, style }) { + const input = useMemo(() => { + if (style === "yes/no") { + const M = { + "Yes": true, + "No": false, + }; + const M_ = { + true: "Yes", + false: "No", + }; + return onChange(M[e.target.value])} value={M_[def]} optionType="button" /> + } + return + }, [style, def]); return ( - + {name} - + {input} - ) + ); } export function FunctionWidget({ name, def, onChange }) { @@ -148,29 +164,35 @@ export function ListWidget({ name, def, onChange, type }) { - ) + ); } -export function DictWidget({ name, def, onChange, type }) { +export function DictWidget({ name, def, onChange }) { const { token } = useToken(); const value = useMemo(() => { - if (!def) { - return []; - } - if (Array.isArray(def)) { return def; } + // Needs to be converted to an array of 3-tuples + + if (!def) { + onChange([]); + return []; + } + try { - return Object.entries(def).map(([key, value]) => { + const newValue = Object.entries(def).map(([key, value]) => { let type = typeof value as string; if (type === 'number') { type = 'float'; } return [type, key, value]; }); + onChange(newValue); + return newValue; } catch (e) { + onChange([]); return []; } }, [def]); @@ -210,10 +232,10 @@ export function DictWidget({ name, def, onChange, type }) { const selectedInputs = useMemo(() => { return value.map((item, i) => { if (item[0] === 'string') { - return onValueChange(i, value)} value={item[2]} /> + return onValueChange(i, value)} value={item[2]} /> } if (item[0] === 'float' || item[0] === 'int' || item[0] === 'number') { - return onValueChange(i, value)} value={item[2]} /> + return onValueChange(i, value)} value={item[2]} /> } if (item[0] === 'boolean') { return onValueChange(i, value)} value={item[2]} /> @@ -230,66 +252,113 @@ export function DictWidget({ name, def, onChange, type }) { }, []); return ( - + {name} {value && value.map((item, i) => { return ( - - + + onKeyChange(i, value)} value={item[1]} /> - - - - { - selectedInputs[i] - } -