diff --git a/README.md b/README.md
index 5ddeeed..acf6ef8 100644
--- a/README.md
+++ b/README.md
@@ -5,6 +5,34 @@
Graphbook
The ML workflow framework
@@ -23,7 +51,7 @@
## Overview
-Graphbook is a framework for building efficient, visual DAG-structured ML workflows composed of nodes written in Python. Graphbook provides common ML processing features such as multiprocessing IO and automatic batching, and it features a web-based UI to assemble, monitor, and execute data processing workflows. It can be used to prepare training data for custom ML models, experiment with custom trained or off-the-shelf models, and to build ML-based ETL applications. Custom nodes can be built in Python, and Graphbook will behave like a framework and call lifecycle methods on those nodes.
+Graphbook is a framework for building efficient, visual DAG-structured ML workflows composed of nodes written in Python. Graphbook provides common ML processing features such as multiprocessing IO and automatic batching for PyTorch tensors, and it features a web-based UI to assemble, monitor, and execute data processing workflows. It can be used to prepare training data for custom ML models, experiment with custom trained or off-the-shelf models, and build ML-based ETL applications. Custom nodes can be built in Python, and Graphbook will call lifecycle methods on those nodes as the workflow executes.
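+
+For example, a minimal custom step might look like the following (a sketch using the decorator API; the step name, parameter, and `ctx` attribute access are illustrative):
+
+```python
+from graphbook import Note, step, param
+
+@step("Examples/AddLabel")  # category/name shown in the UI
+@param("label", type="string", default="dog")
+def add_label(ctx, note: Note):
+    # Each incoming Note is annotated with the configured label
+    note["label"] = ctx.label
+```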
diff --git a/graphbook/__init__.py b/graphbook/__init__.py
index 44e8cc5..14fb2eb 100644
--- a/graphbook/__init__.py
+++ b/graphbook/__init__.py
@@ -1,4 +1,4 @@
from .note import Note
-from .decorators import step, param, source, output, batch, resource, event
+from .decorators import step, param, source, output, batch, resource, event, prompt
-__all__ = ["step", "param", "source", "output", "batch", "resource", "event", "Note"]
+__all__ = ["step", "param", "source", "output", "batch", "resource", "event", "prompt", "Note"]
diff --git a/graphbook/custom_nodes.py b/graphbook/custom_nodes.py
index ffdbdc1..0fba930 100644
--- a/graphbook/custom_nodes.py
+++ b/graphbook/custom_nodes.py
@@ -13,6 +13,7 @@
from graphbook.steps import (
Step,
BatchStep,
+ PromptStep,
SourceStep,
GeneratorSourceStep,
AsyncStep,
@@ -25,6 +26,7 @@
BUILT_IN_STEPS = [
Step,
BatchStep,
+ PromptStep,
SourceStep,
GeneratorSourceStep,
AsyncStep,
diff --git a/graphbook/decorators.py b/graphbook/decorators.py
index 17849bf..16cd909 100644
--- a/graphbook/decorators.py
+++ b/graphbook/decorators.py
@@ -31,14 +31,13 @@ def param(
"required": required,
"description": description,
}
+ self.parameter_type_casts[name] = cast_as
if cast_as is None:
# Default casts
if type == "function":
self.parameter_type_casts[name] = transform_function_string
if type == "int":
self.parameter_type_casts[name] = int
- else:
- self.parameter_type_casts[name] = cast_as
@abc.abstractmethod
def build():
@@ -85,6 +84,11 @@ def batch(
if dump_fn is not None:
self.event("dump_fn", dump_fn)
+ def prompt(self, get_prompt=None):
+ self.BaseClass = steps.PromptStep
+ if get_prompt is not None:
+ self.event("get_prompt", get_prompt)
+
def build(self):
def __init__(cls, **kwargs):
if self.BaseClass == steps.BatchStep:
@@ -212,6 +216,8 @@ def decorator(func):
factory.event("on_note", func)
elif factory.BaseClass == steps.BatchStep:
factory.event("on_item_batch", func)
+ elif factory.BaseClass == steps.PromptStep:
+ factory.event("on_prompt_response", func)
else:
factory.event("load", func)
@@ -456,3 +462,65 @@ def wrapper(*args, **kwargs):
return wrapper
return decorator
+
+
+def prompt(get_prompt: callable = None):
+ """
+ Marks a function as a step that is capable of prompting the user.
+ This is useful for interactive workflows such as data labeling, model evaluation, or any other task that requires human input.
+ The decorator accepts a ``get_prompt`` function that returns the prompt to display to the user.
+ If no argument is passed, a ``bool_prompt`` will be used by default.
+ If ``get_prompt`` returns **None** for a given note, no prompt will be displayed for that note, allowing for conditional prompts based on the note's content.
+ Available prompts are located in the ``graphbook.prompts`` module.
+ The decorated function serves as the ``on_prompt_response(ctx, note: Note, response: Any)`` event and will be called when a response to a prompt is obtained from a user.
+ Once the prompt is handled, the execution lifecycle of the Step will proceed normally.
+
+ Args:
+ get_prompt (callable): A function that returns a prompt. Default is ``bool_prompt``.
+
+ Examples:
+ .. highlight:: python
+ .. code-block:: python
+
+ def dog_or_cat(ctx, note: Note):
+ return selection_prompt(note, choices=["dog", "cat"], show_images=True)
+
+
+ @step("Prompts/Label")
+ @prompt(dog_or_cat)
+ def label_images(ctx, note: Note, response: str):
+ note["label"] = response
+
+
+ def corrective_prompt(ctx, note: Note):
+ if note["prediction_confidence"] < 0.65:
+ return bool_prompt(
+ note,
+ msg=f"Model prediction ({note['pred']}) was uncertain. Is its prediction correct?",
+ show_images=True,
+ )
+ else:
+ return None
+
+
+ @step("Prompts/CorrectModelLabel")
+ @prompt(corrective_prompt)
+ def correct_model_labels(ctx, note: Note, response: bool):
+ if response:
+ ctx.log("Model is correct!")
+ note["label"] = note["pred"]
+ else:
+ ctx.log("Model is incorrect!")
+ if note["pred"] == "dog":
+ note["label"] = "cat"
+ else:
+ note["label"] = "dog"
+ """
+ def decorator(func):
+ def set_prompt(factory: StepClassFactory):
+ factory.prompt(get_prompt)
+
+ return DecoratorFunction(func, set_prompt)
+
+ return decorator
diff --git a/graphbook/logger.py b/graphbook/logger.py
index 40dd64c..0d80417 100644
--- a/graphbook/logger.py
+++ b/graphbook/logger.py
@@ -2,6 +2,7 @@
import multiprocessing as mp
from typing import Dict, Tuple, Any
import inspect
+from graphbook.note import Note
from graphbook.viewer import ViewManagerInterface
from graphbook.utils import transform_json_log
@@ -44,3 +45,18 @@ def log(msg: Any, type: LogType = "info", caller_id: int | None = None):
else:
raise ValueError(f"Unknown log type {type}")
view_manager.handle_log(node_id, msg, type)
+
+def prompt(prompt: dict, caller_id: int | None = None):
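+ """Sends a prompt to the view manager on behalf of the calling step.
+
+ If caller_id is None, it is inferred from the calling step's ``self`` in the
+ caller's stack frame, mirroring how ``log`` resolves its node id.
+ """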
+ if caller_id is None:
+ prev_frame = inspect.currentframe().f_back
+ caller = prev_frame.f_locals.get("self")
+ if caller is not None:
+ caller_id = id(caller)
+
+ node = logging_nodes.get(caller_id, None)
+ if node is None:
+ raise ValueError(
+ f"Can't find node id in {caller}. Only initialized steps can log."
+ )
+ node_id, _ = node
+ view_manager.handle_prompt(node_id, prompt)
diff --git a/graphbook/processing/web_processor.py b/graphbook/processing/web_processor.py
index fb293dd..a26c6ca 100644
--- a/graphbook/processing/web_processor.py
+++ b/graphbook/processing/web_processor.py
@@ -1,4 +1,11 @@
-from graphbook.steps import Step, SourceStep, GeneratorSourceStep, AsyncStep, StepOutput
+from graphbook.steps import (
+ Step,
+ SourceStep,
+ GeneratorSourceStep,
+ AsyncStep,
+ BatchStep,
+ StepOutput,
+)
from graphbook.dataloading import Dataloader, setup_global_dl
from graphbook.utils import MP_WORKER_TIMEOUT, ProcessorStateRequest, transform_json_log
from graphbook.state import GraphState, StepState, NodeInstantiationError
@@ -49,7 +56,10 @@ def __init__(
self.dataloader = Dataloader(self.num_workers)
setup_global_dl(self.dataloader)
self.state_client = ProcessorStateClient(
- server_request_conn, close_event, self.graph_state, self.dataloader
+ server_request_conn,
+ close_event,
+ self.graph_state,
+ self.dataloader,
)
self.is_running = False
self.filename = None
@@ -120,13 +130,11 @@ def exec_step(
id(step),
)
return None
-
self.handle_images(outputs)
self.graph_state.handle_outputs(
step.id, outputs if not self.copy_outputs else copy.deepcopy(outputs)
)
- self.view_manager.handle_outputs(step.id, transform_json_log(outputs))
- self.view_manager.handle_time(step.id, time.time() - start_time)
+ self.view_manager.handle_time(step.id, time.time() - start_time)
return outputs
def handle_steps(self, steps: List[Step]) -> bool:
@@ -182,13 +190,25 @@ def step_until_received_output(self, steps: List[Step], step_id: str):
step_executed = self.graph_state.get_state(
step_id, StepState.EXECUTED_THIS_RUN
)
+
+ def try_execute_step_event(self, step: Step, event: str):
+ try:
+ if hasattr(step, event):
+ getattr(step, event)()
+ return True
+ except Exception as e:
+ log(f"{type(e).__name__}: {str(e)}", "error", id(step))
+ traceback.print_exc()
+ return False
def run(self, step_id: str = None):
steps: List[Step] = self.graph_state.get_processing_steps(step_id)
- self.setup_dataloader(steps)
for step in steps:
self.view_manager.handle_start(step.id)
- step.on_start()
+ succeeded = self.try_execute_step_event(step, "on_start")
+ if not succeeded:
+ return
+ self.setup_dataloader(steps)
self.pause_event.clear()
dag_is_active = True
try:
@@ -201,24 +221,26 @@ def run(self, step_id: str = None):
dag_is_active = self.handle_steps(steps)
finally:
self.view_manager.handle_end()
- for step in steps:
- step.on_end()
self.dataloader.stop()
+ for step in steps:
+ self.try_execute_step_event(step, "on_end")
def step(self, step_id: str = None):
steps: List[Step] = self.graph_state.get_processing_steps(step_id)
- self.setup_dataloader(steps)
for step in steps:
self.view_manager.handle_start(step.id)
- step.on_start()
+ succeeded = self.try_execute_step_event(step, "on_start")
+ if not succeeded:
+ return
+ self.setup_dataloader(steps)
self.pause_event.clear()
try:
self.step_until_received_output(steps, step_id)
finally:
self.view_manager.handle_end()
- for step in steps:
- step.on_end()
self.dataloader.stop()
+ for step in steps:
+ self.try_execute_step_event(step, "on_end")
def set_is_running(self, is_running: bool = True, filename: str | None = None):
self.is_running = is_running
@@ -232,7 +254,7 @@ def cleanup(self):
self.dataloader.shutdown()
def setup_dataloader(self, steps: List[Step]):
- dataloader_consumers = [step for step in steps if isinstance(step, AsyncStep)]
+ dataloader_consumers = [step for step in steps if isinstance(step, BatchStep)]
consumer_ids = [id(c) for c in dataloader_consumers]
consumer_load_fn = [
c.load_fn if hasattr(c, "load_fn") else None for c in dataloader_consumers
@@ -323,6 +345,10 @@ def _loop(self):
output = self.dataloader.get_all_sizes()
elif req["cmd"] == ProcessorStateRequest.GET_RUNNING_STATE:
output = self.running_state
+ elif req["cmd"] == ProcessorStateRequest.PROMPT_RESPONSE:
+ step_id = req.get("step_id")
+ succeeded = self.graph_state.handle_prompt_response(step_id, req.get("response"))
+ output = {"ok": succeeded}
else:
output = {}
entry = {"res": req["cmd"], "data": output}
diff --git a/graphbook/prompts.py b/graphbook/prompts.py
new file mode 100644
index 0000000..04b4b24
--- /dev/null
+++ b/graphbook/prompts.py
@@ -0,0 +1,158 @@
+from typing import Any, List
+from .note import Note
+from .utils import transform_json_log
+
+
+def none():
+ return {"type": None}
+
+
+def prompt(note: Note, *, msg: str = "", show_images: bool = False, default: Any = ""):
+ return {
+ "note": transform_json_log(note),
+ "msg": msg,
+ "show_images": show_images,
+ "def": default,
+ }
+
+
+def bool_prompt(
+ note: Note,
+ *,
+ msg: str = "Continue?",
+ style: str = "yes/no",
+ default: bool = False,
+ show_images: bool = False,
+):
+ """
+ Prompt the user with a yes/no for binary questions.
+
+ Args:
+ note (Note): The current note that triggered the prompt.
+ msg (str): An informative message or inquiry to display to the user.
+ style (str): The style of the bool prompt. Can be "yes/no" or "switch".
+ default (bool): The default bool value.
+ show_images (bool): Whether to present the images (instead of the Note object) to the user.
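+
+ Example (a minimal ``get_prompt`` event for use with the ``@prompt`` decorator):
+
+ .. code-block:: python
+
+ def confirm(ctx, note: Note):
+ return bool_prompt(note, msg="Is this image correctly labeled?", show_images=True)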
+ """
+ p = prompt(note, msg=msg, default=default, show_images=show_images)
+ p["type"] = "bool"
+ p["options"] = {"style": style}
+ return p
+
+def selection_prompt(
+ note: Note,
+ choices: List[str],
+ *,
+ msg: str = "Select an option:",
+ default: List[str] | str = None,
+ show_images: bool = False,
+ multiple_allowed: bool = False
+):
+ """
+ Prompt the user to select an option from a list of choices.
+
+ Args:
+ note (Note): The current note that triggered the prompt.
+ choices (List[str]): A list of strings representing the options the user can select.
+ msg (str): An informative message or inquiry to display to the user.
+ default (List[str] | str): The default value. If multiple_allowed is True, this should be a list of strings.
+ show_images (bool): Whether to present the images (instead of the Note object) to the user.
+ multiple_allowed (bool): Whether the user can select multiple options from the list of given choices.
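+
+ Example (mirrors the ``@prompt`` decorator usage; the dog/cat choices are illustrative):
+
+ .. code-block:: python
+
+ def dog_or_cat(ctx, note: Note):
+ return selection_prompt(note, choices=["dog", "cat"], show_images=True)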
+ """
+ assert len(choices) > 0, "Choices must not be empty in selection prompt."
+ if default is None:
+ if multiple_allowed:
+ default = [choices[0]]
+ else:
+ default = choices[0]
+
+ p = prompt(note, msg=msg, default=default, show_images=show_images)
+ p["type"] = "selection"
+ p["options"] = {
+ "choices": choices,
+ "multiple_allowed": multiple_allowed,
+ }
+ return p
+
+def text_prompt(
+ note: Note,
+ *,
+ msg: str = "Enter text:",
+ default: str = "",
+ show_images: bool = False,
+):
+ """
+ Prompt the user to enter text.
+
+ Args:
+ note (Note): The current note that triggered the prompt.
+ msg (str): An informative message or inquiry to display to the user.
+ default (str): The default text value.
+ show_images (bool): Whether to present the images (instead of the Note object) to the user.
+ """
+ p = prompt(note, msg=msg, default=default, show_images=show_images)
+ p["type"] = "string"
+ return p
+
+def number_prompt(
+ note: Note,
+ *,
+ msg: str = "Enter a number:",
+ default: float = 0.0,
+ show_images: bool = False,
+):
+ """
+ Prompt the user to enter a number.
+
+ Args:
+ note (Note): The current note that triggered the prompt.
+ msg (str): An informative message or inquiry to display to the user.
+ default (float): The default number value.
+ show_images (bool): Whether to present the images (instead of the Note object) to the user.
+ """
+ p = prompt(note, msg=msg, default=default, show_images=show_images)
+ p["type"] = "number"
+ return p
+
+def dict_prompt(
+ note: Note,
+ *,
+ msg: str = "Enter a dictionary:",
+ default: dict = {},
+ show_images: bool = False,
+):
+ """
+ Prompt the user to enter a dictionary.
+
+ Args:
+ note (Note): The current note that triggered the prompt.
+ msg (str): An informative message or inquiry to display to the user.
+ default (dict): The default dictionary value.
+ show_images (bool): Whether to present the images (instead of the Note object) to the user.
+ """
+ p = prompt(note, msg=msg, default=default, show_images=show_images)
+ p["type"] = "dict"
+ return p
+
+def list_prompt(
+ note: Note,
+ type: str = "string",
+ *,
+ msg: str = "Enter a list:",
+ default: list = [],
+ show_images: bool = False,
+):
+ """
+ Prompt the user to enter a list.
+
+ Args:
+ note (Note): The current note that triggered the prompt.
+ type (str): The type of the list elements. Can be "string", "number", "dict", or "bool".
+ msg (str): An informative message or inquiry to display to the user.
+ default (list): The default list value.
+ show_images (bool): Whether to present the images (instead of the Note object) to the user.
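+
+ Example (a sketch; the tagging use case is illustrative):
+
+ .. code-block:: python
+
+ def ask_for_tags(ctx, note: Note):
+ return list_prompt(note, type="string", msg="Enter tags for this image:")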
+ """
+ assert type in ["string", "number", "dict", "bool"], "Invalid type in list prompt."
+ p = prompt(note, msg=msg, default=default, show_images=show_images)
+ p["type"] = f"list[{type}]"
+ return p
diff --git a/graphbook/state.py b/graphbook/state.py
index ce35071..4be7d56 100644
--- a/graphbook/state.py
+++ b/graphbook/state.py
@@ -2,12 +2,13 @@
from aiohttp.web import WebSocketResponse
from typing import Dict, Tuple, List, Iterator, Set
from graphbook.note import Note
-from graphbook.steps import Step, StepOutput as Outputs
+from graphbook.steps import Step, PromptStep, StepOutput as Outputs
from graphbook.resources import Resource
from graphbook.decorators import get_steps, get_resources
from graphbook.viewer import ViewManagerInterface
from graphbook.plugins import setup_plugins
from graphbook.logger import setup_logging_nodes
+from graphbook.utils import transform_json_log
import multiprocessing as mp
import importlib, importlib.util, inspect
import graphbook.exports as exports
@@ -344,6 +345,7 @@ def handle_outputs(self, step_id: str, outputs: Outputs):
self._step_states[step_id].add(StepState.EXECUTED)
self._step_states[step_id].add(StepState.EXECUTED_THIS_RUN)
self.view_manager.handle_queue_size(step_id, self._queues[step_id].dict_sizes())
+ self.view_manager.handle_outputs(step_id, transform_json_log(outputs))
def clear_outputs(self, node_id: str | None = None):
if node_id is None:
@@ -400,6 +402,16 @@ def get_output_note(self, step_id: str, pin_id: str, index: int) -> dict:
note = internal_list[index]
entry.update(data=note.items)
return entry
+
+ def handle_prompt_response(self, step_id: str, response: dict) -> bool:
+ step = self._steps.get(step_id)
+ if not isinstance(step, PromptStep):
+ return False
+ try:
+ step.handle_prompt_response(response)
+ return True
+ except Exception:
+ return False
def get_step(self, step_id: str):
return self._steps.get(step_id)
diff --git a/graphbook/steps/__init__.py b/graphbook/steps/__init__.py
index 22ecd60..f2a2bab 100644
--- a/graphbook/steps/__init__.py
+++ b/graphbook/steps/__init__.py
@@ -3,6 +3,7 @@
SourceStep,
GeneratorSourceStep,
BatchStep,
+ PromptStep,
StepOutput,
AsyncStep,
Split,
@@ -16,6 +17,7 @@
"SourceStep",
"GeneratorSourceStep",
"BatchStep",
+ "PromptStep",
"StepOutput",
"AsyncStep",
"Split",
diff --git a/graphbook/steps/base.py b/graphbook/steps/base.py
index a08b95a..b0fe684 100644
--- a/graphbook/steps/base.py
+++ b/graphbook/steps/base.py
@@ -1,11 +1,16 @@
from __future__ import annotations
from typing import List, Dict, Tuple, Generator, Any
-from ..utils import transform_function_string, convert_dict_values_to_list, is_batchable
from graphbook import Note
-from graphbook.logger import log
+from graphbook.utils import (
+ transform_function_string,
+ convert_dict_values_to_list,
+ is_batchable,
+)
+from graphbook.logger import log, prompt
+import graphbook.prompts as prompts
import graphbook.dataloading as dataloader
import warnings
-
+import traceback
warnings.simplefilter("default", DeprecationWarning)
@@ -221,8 +226,12 @@ class AsyncStep(Step):
def __init__(self, item_key=None):
super().__init__(item_key)
- self._is_processing = True
self._in_queue = []
+ self._out_queue = []
+
+ def on_clear(self):
+ self._in_queue = []
+ self._out_queue = []
def in_q(self, note: Note | None):
if note is None:
@@ -230,7 +239,17 @@ def in_q(self, note: Note | None):
self._in_queue.append(note)
def is_active(self) -> bool:
- return self._is_processing
+ return len(self._in_queue) > 0
+
+ def __call__(self) -> StepOutput:
+ # 1. on_note -> 2. on_item -> 3. on_after_item -> 4. forward_note
+ if len(self._out_queue) == 0:
+ return {}
+ note = self._out_queue.pop(0)
+ return super().__call__(note)
+
+ def all(self) -> StepOutput:
+ return self.__call__()
class NoteItemHolders:
@@ -489,6 +508,83 @@ def is_active(self) -> bool:
)
+class PromptStep(AsyncStep):
+ """
+ A Step that is capable of prompting the user for input.
+ This is useful for interactive workflows such as data labeling, model evaluation, or any other task that requires human input.
+ Once the prompt is handled, the execution lifecycle of the Step will proceed normally.
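+
+ Example (a minimal sketch of a subclass; assumes ``Any`` from ``typing`` and ``graphbook.prompts`` imported as ``prompts``, as in this module):
+
+ .. code-block:: python
+
+ class ConfirmLabel(PromptStep):
+ def __init__(self):
+ super().__init__()
+
+ def get_prompt(self, note: Note) -> dict:
+ return prompts.bool_prompt(note, msg="Is this label correct?", show_images=True)
+
+ def on_prompt_response(self, note: Note, response: Any):
+ note["verified"] = response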
+ """
+ def __init__(self):
+ super().__init__()
+ self._is_awaiting_response = False
+ self._awaiting_note = None
+
+ def handle_prompt_response(self, response: Any):
+ note = self._awaiting_note
+ try:
+ assert note is not None, "PromptStep is not awaiting a response."
+ self.on_prompt_response(note, response)
+ self._out_queue.append(note)
+ except Exception as e:
+ self.log(f"{type(e).__name__}: {str(e)}", "error")
+ traceback.print_exc()
+
+ self._is_awaiting_response = False
+ self._awaiting_note = None
+ prompt(prompts.none())
+
+ def on_clear(self):
+ """
+ Clears any awaiting prompts and the prompt queue.
+ If you plan on overriding this method, make sure to call super().on_clear() to ensure the prompt queue is cleared.
+ """
+ self._is_awaiting_response = False
+ self._awaiting_note = None
+ prompt(prompts.none())
+ super().on_clear()
+
+ def get_prompt(self, note: Note) -> dict:
+ """
+ Returns the prompt to be displayed to the user.
+ This method can be overridden by the subclass.
+ By default, it will return a boolean prompt.
+ If None is returned, the prompt will be skipped for this note.
+ A list of available prompts can be found in ``graphbook.prompts``.
+
+ Args:
+ note (Note): The Note input to display to the user
+ """
+ return prompts.bool_prompt(note)
+
+ def on_prompt_response(self, note: Note, response: Any):
+ """
+ Called when the user responds to the prompt.
+ This method must be overridden by the subclass.
+
+ Args:
+ note (Note): The Note input that was prompted
+ response (Any): The user's response
+ """
+ raise NotImplementedError(
+ "on_prompt_response must be implemented for PromptStep"
+ )
+
+ def __call__(self):
+ if not self._is_awaiting_response and len(self._in_queue) > 0:
+ note = self._in_queue.pop(0)
+ p = self.get_prompt(note)
+ if p:
+ prompt(p)
+ self._is_awaiting_response = True
+ self._awaiting_note = note
+ else:
+ self._out_queue.append(note)
+ return super().__call__()
+
+ def is_active(self) -> bool:
+ return len(self._in_queue) > 0 or self._awaiting_note is not None
+
+
class Split(Step):
"""
Routes incoming Notes into either of two output slots, A or B. If split_fn
diff --git a/graphbook/utils.py b/graphbook/utils.py
index 28d4601..a33e724 100644
--- a/graphbook/utils.py
+++ b/graphbook/utils.py
@@ -16,7 +16,7 @@
MP_WORKER_TIMEOUT = 5.0
ProcessorStateRequest = Enum(
"ProcessorStateRequest",
- ["GET_OUTPUT_NOTE", "GET_WORKER_QUEUE_SIZES", "GET_RUNNING_STATE"],
+ ["GET_OUTPUT_NOTE", "GET_WORKER_QUEUE_SIZES", "GET_RUNNING_STATE", "PROMPT_RESPONSE"],
)
diff --git a/graphbook/viewer.py b/graphbook/viewer.py
index f1fd16c..3a35c7f 100644
--- a/graphbook/viewer.py
+++ b/graphbook/viewer.py
@@ -1,4 +1,4 @@
-from typing import Dict, List, Tuple
+from typing import Dict, List
from aiohttp.web import WebSocketResponse
import uuid
import asyncio
@@ -8,7 +8,7 @@
import queue
import copy
import psutil
-from .utils import MP_WORKER_TIMEOUT, get_gpu_util, ProcessorStateRequest, poll_conn_for, transform_json_log
+from .utils import MP_WORKER_TIMEOUT, get_gpu_util, ProcessorStateRequest, poll_conn_for
class Viewer:
@@ -177,6 +177,23 @@ def get_next(self):
}
+class PromptViewer(Viewer):
+ def __init__(self):
+ super().__init__("prompt")
+ self.prompts = {}
+
+ def handle_prompt(self, node_id: str, prompt: dict):
+ prev_prompt = self.prompts.get(node_id)
+ idx = 0
+ if prev_prompt:
+ idx = prev_prompt["idx"] + 1
+ prompt["idx"] = idx
+ self.prompts[node_id] = prompt
+
+ def get_next(self):
+ return self.prompts
+
+
DEFAULT_CLIENT_OPTIONS = {"SEND_EVERY": 0.5}
@@ -224,11 +241,13 @@ def __init__(
self.node_stats_viewer = NodeStatsViewer()
self.logs_viewer = NodeLogsViewer()
self.system_util_viewer = SystemUtilViewer(processor_state_conn)
+ self.prompt_viewer = PromptViewer()
self.viewers: List[Viewer] = [
self.data_viewer,
self.node_stats_viewer,
self.logs_viewer,
self.system_util_viewer,
+ self.prompt_viewer,
]
self.clients: Dict[str, Client] = {}
self.work_queue = work_queue
@@ -275,6 +294,9 @@ def handle_clear(self, node_id: str | None):
def handle_log(self, node_id: str, log: str, type: str):
self.logs_viewer.handle_log(node_id, log, type)
+ def handle_prompt(self, node_id: str, prompt: dict):
+ self.prompt_viewer.handle_prompt(node_id, prompt)
+
def handle_end(self):
for viewer in self.viewers:
viewer.handle_end()
@@ -315,6 +337,9 @@ def _loop(self):
self.handle_run_state(work["is_running"], work["filename"])
elif work["cmd"] == "handle_clear":
self.handle_clear(work["node_id"])
+ elif work["cmd"] == "handle_prompt":
+ self.handle_prompt(work["node_id"], work["prompt"])
+
except queue.Empty:
pass
@@ -364,3 +389,8 @@ def handle_run_state(self, run_state: dict):
def handle_clear(self, node_id: str | None):
self.view_manager_queue.put({"cmd": "handle_clear", "node_id": node_id})
+
+ def handle_prompt(self, node_id: str, prompt: dict):
+ self.view_manager_queue.put(
+ {"cmd": "handle_prompt", "node_id": node_id, "prompt": prompt}
+ )
diff --git a/graphbook/web.py b/graphbook/web.py
index fc8f2cf..648cf02 100644
--- a/graphbook/web.py
+++ b/graphbook/web.py
@@ -205,6 +205,14 @@ async def clear(request: web.Request) -> web.Response:
}
)
return web.json_response({"success": True})
+
+ @routes.post("/prompt_response/{id}")
+ async def prompt_response(request: web.Request) -> web.Response:
+ step_id = request.match_info.get("id")
+ data = await request.json()
+ response = data.get("response")
+ res = poll_conn_for(state_conn, ProcessorStateRequest.PROMPT_RESPONSE, {"step_id": step_id, "response": response})
+ return web.json_response(res)
@routes.get("/nodes")
async def get_nodes(request: web.Request) -> web.Response:
@@ -516,7 +524,10 @@ def signal_handler(*_):
close_event.set()
if img_mem:
- img_mem.close()
+ try:
+ img_mem.close()
+ except FileNotFoundError:
+ pass
raise KeyboardInterrupt()
diff --git a/pyproject.toml b/pyproject.toml
index 2083781..ad1e100 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry]
name = "graphbook"
-version = "0.7.0"
+version = "0.8.0"
authors = ["Richard Franklin "]
description = "An extensible ML workflow framework built for data scientists and ML engineers."
keywords = ["ml", "workflow", "framework", "pytorch", "data science", "machine learning", "ai"]
diff --git a/web/package.json b/web/package.json
index 1e80146..a64c9c3 100644
--- a/web/package.json
+++ b/web/package.json
@@ -1,7 +1,7 @@
{
"name": "graphbook-web",
"private": true,
- "version": "0.7.0",
+ "version": "0.8.0",
"type": "module",
"scripts": {
"dev": "vite",
diff --git a/web/src/api.ts b/web/src/api.ts
index e1d501a..b1f4842 100644
--- a/web/src/api.ts
+++ b/web/src/api.ts
@@ -248,6 +248,10 @@ export class ServerAPI {
return await this.get('state');
}
+ public async respondToPrompt(stepId: string, response: any) {
+ return await this.post(`prompt_response/${stepId}`, { response });
+ }
+
/**
* File API
*/
diff --git a/web/src/components/Flow.tsx b/web/src/components/Flow.tsx
index 1f6e174..8e463ac 100644
--- a/web/src/components/Flow.tsx
+++ b/web/src/components/Flow.tsx
@@ -118,7 +118,7 @@ export default function Flow({ filename }) {
return onNodesChange(newChanges);
}
onNodesChange(changes);
- }, []);
+ }, [runState]);
const onEdgesChangeCallback = useCallback((changes) => {
setIsAddNodeActive(false);
@@ -135,7 +135,7 @@ export default function Flow({ filename }) {
return onEdgesChange(newChanges);
}
onEdgesChange(changes);
- }, []);
+ }, [runState]);
const onConnect = useCallback((params) => {
const targetNode = nodes.find(n => n.id === params.target);
diff --git a/web/src/components/Monitor.tsx b/web/src/components/Monitor.tsx
index 3b48029..b3d10ac 100644
--- a/web/src/components/Monitor.tsx
+++ b/web/src/components/Monitor.tsx
@@ -403,7 +403,7 @@ function NotesView({ stepId, numNotes, type }: NotesViewProps) {
-
+
);
}
@@ -164,6 +166,9 @@ export function WorkflowStep({ id, data, selected }) {
}).filter(x => x)
}
+
{!data.isCollapsed && }
@@ -229,7 +234,7 @@ function QuickviewCollapse({ data }) {
span {
+ font-size: .6em;
}
.workflow-node .ant-btn-icon {
@@ -142,10 +203,21 @@ textarea.code {
font-size: .6em;
}
-.workflow-node .outputs .output .label{
+.workflow-node .outputs .output .label {
+ font-size: .6em;
+}
+
+.workflow-node .widgets .prompt .ant-typography {
font-size: .6em;
}
+.workflow-node .widgets .prompt .ant-radio-button-wrapper {
+ font-size: 10px;
+ line-height: 1.2;
+ height: 14px;
+ padding: 0 4px;
+}
+
.workflow-node .handles {
display: flex;
flex-direction: row;
diff --git a/web/src/components/Nodes/widgets/NotePreview.tsx b/web/src/components/Nodes/widgets/NotePreview.tsx
new file mode 100644
index 0000000..33ee4dd
--- /dev/null
+++ b/web/src/components/Nodes/widgets/NotePreview.tsx
@@ -0,0 +1,80 @@
+import React, { CSSProperties, useMemo } from 'react';
+import { Image, Space, Flex } from 'antd';
+import { theme } from 'antd';
+import { useSettings } from '../../../hooks/Settings';
+import { getMediaPath } from '../../../utils';
+import type { ImageRef } from '../../../utils';
+import ReactJson from '@microlink/react-json-view';
+
+type QuickViewEntry = {
+ [key: string]: any;
+};
+
+export function NotePreview({ data, showImages }: { data: QuickViewEntry, showImages: boolean }) {
+ const globalTheme = theme.useToken().theme;
+
+ if (!showImages) {
+ return (
+
+ );
+
+ }
+ return (
+
+ );
+
+}
+
+function EntryImages({ entry, style }: { entry: QuickViewEntry, style: CSSProperties | undefined }) {
+ const [settings, _] = useSettings();
+
+ const imageEntries = useMemo(() => {
+ let entries: { [key: string]: ImageRef[] } = {};
+ Object.entries(entry).forEach(([key, item]) => {
+ let imageItems: any = [];
+ if (Array.isArray(item)) {
+ imageItems = item.filter(item => item.type?.slice(0, 5) === 'image');
+ } else {
+ if (item.type?.slice(0, 5) === 'image') {
+ imageItems.push(item);
+ }
+ }
+ if (imageItems.length > 0) {
+ entries[key] = imageItems;
+ }
+ });
+ return entries;
+ }, [entry]);
+
+ return (
+
+ {
+ Object.entries(imageEntries).map(([key, images]) => {
+ return (
+
+ {key}
+
+ {
+ images.map((image, i) => (
+
+ ))
+ }
+
+
+
+ );
+ })
+ }
+
+ );
+}
diff --git a/web/src/components/Nodes/widgets/Prompts.tsx b/web/src/components/Nodes/widgets/Prompts.tsx
new file mode 100644
index 0000000..3681b4f
--- /dev/null
+++ b/web/src/components/Nodes/widgets/Prompts.tsx
@@ -0,0 +1,73 @@
+import React, { useCallback, useMemo, useState } from 'react';
+import { Typography, Flex, Button } from 'antd';
+import { usePluginWidgets } from '../../../hooks/Plugins';
+import { NotePreview } from './NotePreview';
+import { ListWidget, getWidgetLookup } from './Widgets';
+import { useAPI } from '../../../hooks/API';
+import { usePrompt } from '../../../hooks/Prompts';
+import type { Prompt } from '../../../hooks/Prompts';
+import { parseDictWidgetValue } from '../../../utils';
+
+const { Text } = Typography;
+
+function WidgetPrompt({ type, options, value, onChange }) {
+ const pluginWidgets = usePluginWidgets();
+ const widgets = useMemo(() => {
+ return getWidgetLookup(pluginWidgets);
+ }, [pluginWidgets]);
+
+ if (type.startsWith('list')) {
+ return
+ }
+
+ if (widgets[type]) {
+ return widgets[type]({ name: "Answer", def: value, onChange, ...options });
+ }
+}
+
+export function Prompt({ nodeId }: { nodeId: string }) {
+ const API = useAPI();
+ const [value, setValue] = useState(null);
+ const [loading, setLoading] = useState(false);
+
+ const onPromptChange = useCallback((prompt) => {
+ setValue(prompt?.def);
+ }, []);
+
+ const [prompt, setSubmitted] = usePrompt(nodeId, onPromptChange);
+
+ const onChange = useCallback((value) => {
+ setValue(value);
+ }, []);
+
+ const onSubmit = useCallback(async () => {
+ if (API) {
+ setLoading(true);
+ try {
+ let answer = value;
+ if (prompt?.type === 'dict') {
+ answer = parseDictWidgetValue(answer);
+ }
+ await API.respondToPrompt(nodeId, answer);
+ } catch (e) {
+ console.error(e);
+ }
+ setLoading(false);
+ setSubmitted();
+ }
+ }, [value, API, nodeId, prompt, setSubmitted]);
+
+ if (!prompt || prompt.type === null) {
+ return null;
+ }
+
+ return (
+ <Flex vertical className="prompt">
+ <Text>Prompted:</Text>
+ <NotePreview data={prompt.note} showImages={prompt.showImages} />
+ <Text>{prompt.msg}</Text>
+ <WidgetPrompt type={prompt.type} options={prompt.options} value={value} onChange={onChange} />
+ <Button size="small" loading={loading} onClick={onSubmit}>Submit</Button>
+ </Flex>
+ );
+}
diff --git a/web/src/components/Nodes/Widgets.tsx b/web/src/components/Nodes/widgets/Widgets.tsx
similarity index 65%
rename from web/src/components/Nodes/Widgets.tsx
rename to web/src/components/Nodes/widgets/Widgets.tsx
index 5ebff5f..921416a 100644
--- a/web/src/components/Nodes/Widgets.tsx
+++ b/web/src/components/Nodes/widgets/Widgets.tsx
@@ -1,17 +1,17 @@
-import { Switch, Typography, theme, Flex, Button } from 'antd';
-import { PlusOutlined, MinusCircleOutlined, MinusOutlined } from '@ant-design/icons';
-import React, { useCallback, useState, useMemo } from 'react';
+import { Switch, Typography, theme, Flex, Button, Radio, Select as ASelect } from 'antd';
+import { PlusOutlined, MinusOutlined } from '@ant-design/icons';
+import React, { useCallback, useState, useMemo, useEffect } from 'react';
import CodeMirror from '@uiw/react-codemirror';
import { python } from '@codemirror/lang-python';
import { basicDark } from '@uiw/codemirror-theme-basic';
import { bbedit } from '@uiw/codemirror-theme-bbedit';
-import { Graph } from '../../graph';
+import { Graph } from '../../../graph';
import { useReactFlow } from 'reactflow';
-import { usePluginWidgets } from '../../hooks/Plugins';
+import { usePluginWidgets } from '../../../hooks/Plugins';
const { Text } = Typography;
const { useToken } = theme;
-const getWidgetLookup = (pluginWidgets) => {
+export const getWidgetLookup = (pluginWidgets) => {
const lookup = {
number: NumberWidget,
string: StringWidget,
@@ -19,14 +19,16 @@ const getWidgetLookup = (pluginWidgets) => {
bool: BooleanWidget,
function: FunctionWidget,
dict: DictWidget,
+ selection: SelectionWidget,
};
+
pluginWidgets.forEach((widget) => {
lookup[widget.type] = widget.children;
});
return lookup;
};
-export function Widget({ id, type, name, value }) {
+export function Widget({ id, type, name, value, ...props }) {
const { setNodes } = useReactFlow();
const pluginWidgets = usePluginWidgets();
const widgets = useMemo(() => {
@@ -44,7 +46,7 @@ export function Widget({ id, type, name, value }) {
}
if (widgets[type]) {
- return widgets[type]({ name, def: value, onChange });
+ return widgets[type]({ name, def: value, onChange, ...props });
}
}
@@ -60,13 +62,27 @@ export function StringWidget({ name, def, onChange }) {
);
}
-export function BooleanWidget({ name, def, onChange }) {
+export function BooleanWidget({ name, def, onChange, style }) {
+ const input = useMemo(() => {
+ if (style === "yes/no") {
+ const M = {
+ "Yes": true,
+ "No": false,
+ };
+ const M_ = {
+ true: "Yes",
+ false: "No",
+ };
+ return <Radio.Group options={["Yes", "No"]} onChange={(e) => onChange(M[e.target.value])} value={M_[def]} optionType="button" />
+ }
+ return <Switch size="small" checked={def} onChange={onChange} />
+ }, [style, def]);
return (
-
+
{name}
-
+ {input}
- )
+ );
}
export function FunctionWidget({ name, def, onChange }) {
@@ -148,29 +164,35 @@ export function ListWidget({ name, def, onChange, type }) {
- )
+ );
}
-export function DictWidget({ name, def, onChange, type }) {
+export function DictWidget({ name, def, onChange }) {
const { token } = useToken();
const value = useMemo(() => {
- if (!def) {
- return [];
- }
-
if (Array.isArray(def)) {
return def;
}
+ // Needs to be converted to an array of 3-tuples
+
+ if (!def) {
+ onChange([]);
+ return [];
+ }
+
try {
- return Object.entries(def).map(([key, value]) => {
+ const newValue = Object.entries(def).map(([key, value]) => {
let type = typeof value as string;
if (type === 'number') {
type = 'float';
}
return [type, key, value];
});
+ onChange(newValue);
+ return newValue;
} catch (e) {
+ onChange([]);
return [];
}
}, [def]);
@@ -210,10 +232,10 @@ export function DictWidget({ name, def, onChange, type }) {
const selectedInputs = useMemo(() => {
return value.map((item, i) => {
if (item[0] === 'string') {
- return onValueChange(i, value)} value={item[2]} />
+ return onValueChange(i, value)} value={item[2]} />
}
if (item[0] === 'float' || item[0] === 'int' || item[0] === 'number') {
- return onValueChange(i, value)} value={item[2]} />
+ return onValueChange(i, value)} value={item[2]} />
}
if (item[0] === 'boolean') {
return onValueChange(i, value)} value={item[2]} />
@@ -230,66 +252,113 @@ export function DictWidget({ name, def, onChange, type }) {
}, []);
return (
-
+
{name}
{value && value.map((item, i) => {
return (
-
-
+
+
onTypeChange(i, value)}
options={options}
/>
- onKeyChange(i, value)} value={item[1]} />
-
-
-
- {
- selectedInputs[i]
- }
- } onClick={() => onRemoveItem(i)} />
+ onKeyChange(i, value)} value={item[1]} />
+ {
+ selectedInputs[i]
+ }
+ } onClick={() => onRemoveItem(i)} />
)
})}
- } onClick={onAddItem} />
+ } onClick={onAddItem} />
- )
+ );
}
-function Select({ onChange, options, value }) {
- const { token } = useToken();
+export function SelectionWidget({ name, def, onChange, choices, multiple_allowed }) {
+ const options = useMemo(() => {
+ if (!choices) {
+ return [];
+ }
+ return choices.map((choice) => {
+ return {
+ label: choice,
+ value: choice,
+ };
+ });
+ }, [choices]);
- const inputStyle = {
- backgroundColor: token.colorBgContainer,
- color: token.colorText,
- };
+ return (
+ <Flex vertical>
+ {name && <Text>{name}</Text>}
+ <Select onChange={onChange} options={options} value={def} multipleAllowed={multiple_allowed} />
+ </Flex>
+ );
+}
+type SelectProps = {
+ onChange: (value: string | string[]) => void,
+ options: { label: string, value: string }[],
+ value: any,
+ style?: React.CSSProperties,
+ multipleAllowed?: boolean,
+};
+function Select({ onChange, options, value, style, multipleAllowed }: SelectProps) {
+ const [open, setOpen] = useState(false);
- const onValueChange = useCallback((e) => {
- onChange(e.target.value);
- }, []);
+ const onSelect = useCallback((val) => {
+ if (multipleAllowed) {
+ onChange([...value, val]);
+ } else {
+ onChange(val);
+ }
+ setOpen(false);
+ }, [setOpen, open, multipleAllowed, value]);
+
+ const onDeselect = useCallback((val) => {
+ if (multipleAllowed) {
+ onChange(value.filter((v) => v !== val));
+ }
+ }, [value]);
+
+ const onClick = useCallback(() => {
+ setOpen(!open);
+ }, [setOpen, open]);
return (
-
-
- {
- options.map((option, i) => {
- return {option.label}
- })
- }
-
-
- )
+
+ );
}
-function InputNumber({ onChange, label, value }: { onChange: (value: number) => void, label?: string, value?: number }) {
+type InputNumberProps = {
+ onChange: (value: number) => void,
+ label?: string,
+ value?: number,
+ placeholder?: string,
+};
+
+function InputNumber({ onChange, label, value, placeholder }: InputNumberProps) {
const { token } = useToken();
const defaultFocusedStyle = { border: `1px solid ${token.colorBorder}` };
const [focusedStyle, setFocusedStyle] = useState(defaultFocusedStyle);
@@ -325,13 +394,21 @@ function InputNumber({ onChange, label, value }: { onChange: (value: number) =>
className="input"
type="number"
value={value || 0}
+ placeholder={placeholder}
/>
);
}
+type InputProps = {
+ onChange: (value: string) => void,
+ label?: string,
+ value?: string,
+ placeholder?: string,
+ style?: React.CSSProperties,
+};
-function Input({ onChange, label, value }: { onChange: (value: string) => void, label?: string, value?: string }) {
+function Input({ onChange, label, value, placeholder, style }: InputProps) {
const { token } = useToken();
const defaultFocusedStyle = { border: `1px solid ${token.colorBorder}` };
const [focusedStyle, setFocusedStyle] = useState(defaultFocusedStyle)
@@ -339,6 +416,7 @@ function Input({ onChange, label, value }: { onChange: (value: string) => void,
const inputStyle = {
backgroundColor: token.colorBgContainer,
color: token.colorText,
+ ...style,
};
const labelStyle = {
backgroundColor: token.colorBgContainer
@@ -368,11 +446,12 @@ function Input({ onChange, label, value }: { onChange: (value: string) => void,
className="input"
type="text"
value={value || ''}
+ placeholder={placeholder}
/>
);
}
export const isWidgetType = (type) => {
- return ['number', 'string', 'boolean', 'bool', 'function'].includes(type) || type.startsWith('list') || type.startsWith('dict');
+ return ['number', 'string', 'boolean', 'bool', 'function', 'selection'].includes(type) || type.startsWith('list') || type.startsWith('dict');
};
diff --git a/web/src/graph.ts b/web/src/graph.ts
index 739e47f..7c95916 100644
--- a/web/src/graph.ts
+++ b/web/src/graph.ts
@@ -1,4 +1,4 @@
-import { uniqueIdFrom, getHandle, Parameter } from './utils';
+import { uniqueIdFrom, getHandle, Parameter, parseDictWidgetValue } from './utils';
import { API } from './api';
import type { ServerAPI } from './api';
import type { Node, Edge } from 'reactflow';
@@ -180,11 +180,7 @@ export const Graph = {
parameters[key] = param.value;
if (param.type && param.value) {
if (param.type === 'dict') {
- const d = {};
- for (const [t, k, v] of param.value) {
- d[k] = v;
- }
- parameters[key] = d;
+ parameters[key] = parseDictWidgetValue(param.value);
}
}
}
@@ -416,8 +412,11 @@ export const Graph = {
const parseEdges = (nodes: Node[], edges: Edge[]) => {
return edges.map((edge) => {
- const targetNode = nodes.find(n => n.id === edge.target)!;
- const sourceNode = nodes.find(n => n.id === edge.source)!;
+ const targetNode = nodes.find(n => n.id === edge.target);
+ const sourceNode = nodes.find(n => n.id === edge.source);
+ if (!targetNode || !sourceNode) {
+ return null;
+ }
const targetHandle = getHandle(targetNode, edge.targetHandle!, true);
const sourceHandle = getHandle(sourceNode, edge.sourceHandle!, false);
return {
@@ -431,7 +430,7 @@ export const Graph = {
}
};
- });
+ }).filter(e => e !== null);
};
const { nodes, edges } = graph;
diff --git a/web/src/hooks/Prompts.ts b/web/src/hooks/Prompts.ts
new file mode 100644
index 0000000..030d7b3
--- /dev/null
+++ b/web/src/hooks/Prompts.ts
@@ -0,0 +1,77 @@
+import { useAPINodeMessage } from "./API";
+import { useFilename } from "./Filename"
+import { useEffect, useCallback, useState } from "react";
+
+
+let globalPrompts = {};
+let localSetters: Function[] = [];
+
+type PromptData = {
+ "idx": number,
+ "note": object,
+ "msg": string,
+ "show_images": boolean,
+ "def": any,
+ "options": object,
+ "type": string
+};
+
+export type Prompt = {
+ "note": object,
+ "msg": string,
+ "showImages": boolean,
+ "def": any,
+ "options": object,
+ "type": string
+};
+
+const dataToPrompt = (data: PromptData): Prompt => ({
+ note: data.note,
+ msg: data.msg,
+ showImages: data.show_images,
+ def: data.def,
+ options: data.options,
+ type: data.type
+});
+
+export function usePrompt(nodeId: string, callback?: Function | null): [Prompt | null, Function] {
+ const filename = useFilename();
+ const [_, setPrompt] = useState(null);
+
+ useEffect(() => {
+ localSetters.push(setPrompt);
+ return () => {
+ localSetters = localSetters.filter((setter) => setter !== setPrompt);
+ delete globalPrompts[nodeId];
+ };
+ }, [nodeId]);
+
+ const onPrompt = useCallback((data: PromptData) => {
+ if (globalPrompts[nodeId]) {
+ if (globalPrompts[nodeId].idx === data.idx) {
+ return;
+ }
+ }
+
+ globalPrompts[nodeId] = data;
+ if (callback) {
+ callback(dataToPrompt(data));
+ }
+ }, [nodeId, callback]);
+
+ const setSubmitted = useCallback(() => {
+ if (!globalPrompts[nodeId]) {
+ return;
+ }
+ globalPrompts[nodeId] = { ...globalPrompts[nodeId], "type": null };
+ }, [nodeId]);
+
+ useAPINodeMessage("prompt", nodeId, filename, onPrompt);
+
+ const promptData = globalPrompts[nodeId];
+ if (!promptData) {
+ return [null, setSubmitted];
+ }
+
+ return [dataToPrompt(promptData), setSubmitted];
+}
diff --git a/web/src/utils.ts b/web/src/utils.ts
index 62946b7..8601631 100644
--- a/web/src/utils.ts
+++ b/web/src/utils.ts
@@ -3,6 +3,14 @@ import type { ServerAPI } from "./api";
import type { ReactFlowInstance } from "reactflow";
import type { Node } from "reactflow";
+export const parseDictWidgetValue = (entries) => {
+ const obj = {};
+ for (const [type, key, val] of entries) {
+ obj[key] = val;
+ }
+ return obj;
+};
+
export const keyRecursively = (obj: Array, childrenKey: string = "children"): Array => {
let currKeyVal = 0;
const keyRec = (obj: Array) => {