From 2f81f4f8197ee9e2134470fe705f883a600e27ab Mon Sep 17 00:00:00 2001 From: Richard Franklin Date: Mon, 16 Sep 2024 17:59:20 -0700 Subject: [PATCH 01/14] Skeleton for prompting --- graphbook/custom_nodes.py | 2 + graphbook/decorators.py | 3 +- graphbook/processing/web_processor.py | 60 +++++++++++++- graphbook/prompts.py | 19 +++++ graphbook/steps/__init__.py | 2 + graphbook/steps/base.py | 27 ++++++- graphbook/utils.py | 2 +- graphbook/viewer.py | 30 ++++++- graphbook/web.py | 8 ++ web/src/api.ts | 4 + web/src/components/Monitor.tsx | 2 +- web/src/components/Nodes/Node.tsx | 17 +++- web/src/components/Nodes/Resource.tsx | 2 +- web/src/components/Nodes/node.css | 29 ++++++- .../components/Nodes/widgets/NotePreview.tsx | 80 +++++++++++++++++++ web/src/components/Nodes/widgets/Prompts.tsx | 72 +++++++++++++++++ .../Nodes/{ => widgets}/Widgets.tsx | 22 +++-- web/src/hooks/Prompts.ts | 28 +++++++ 18 files changed, 387 insertions(+), 22 deletions(-) create mode 100644 graphbook/prompts.py create mode 100644 web/src/components/Nodes/widgets/NotePreview.tsx create mode 100644 web/src/components/Nodes/widgets/Prompts.tsx rename web/src/components/Nodes/{ => widgets}/Widgets.tsx (94%) create mode 100644 web/src/hooks/Prompts.ts diff --git a/graphbook/custom_nodes.py b/graphbook/custom_nodes.py index ffdbdc1..0fba930 100644 --- a/graphbook/custom_nodes.py +++ b/graphbook/custom_nodes.py @@ -13,6 +13,7 @@ from graphbook.steps import ( Step, BatchStep, + PromptStep, SourceStep, GeneratorSourceStep, AsyncStep, @@ -25,6 +26,7 @@ BUILT_IN_STEPS = [ Step, BatchStep, + PromptStep, SourceStep, GeneratorSourceStep, AsyncStep, diff --git a/graphbook/decorators.py b/graphbook/decorators.py index 17849bf..6cc7ccd 100644 --- a/graphbook/decorators.py +++ b/graphbook/decorators.py @@ -31,14 +31,13 @@ def param( "required": required, "description": description, } + self.parameter_type_casts[name] = cast_as if cast_as is None: # Default casts if type == "function": self.parameter_type_casts[name] = transform_function_string if type == "int": self.parameter_type_casts[name] = int - else: - self.parameter_type_casts[name] = cast_as @abc.abstractmethod def build(): diff --git a/graphbook/processing/web_processor.py b/graphbook/processing/web_processor.py index fb293dd..99ec80b 100644 --- a/graphbook/processing/web_processor.py +++ b/graphbook/processing/web_processor.py @@ -1,4 +1,11 @@ -from graphbook.steps import Step, SourceStep, GeneratorSourceStep, AsyncStep, StepOutput +from graphbook.steps import ( + Step, + SourceStep, + GeneratorSourceStep, + AsyncStep, + PromptStep, + StepOutput, +) from graphbook.dataloading import Dataloader, setup_global_dl from graphbook.utils import MP_WORKER_TIMEOUT, ProcessorStateRequest, transform_json_log from graphbook.state import GraphState, StepState, NodeInstantiationError @@ -38,6 +45,7 @@ def __init__( self.cmd_queue = cmd_queue self.close_event = close_event self.pause_event = pause_event + self.prompted_pause_event = False self.view_manager = ViewManagerInterface(view_manager_queue) self.img_mem = img_mem self.graph_state = GraphState(custom_nodes_path, view_manager_queue) @@ -48,8 +56,9 @@ def __init__( self.steps = {} self.dataloader = Dataloader(self.num_workers) setup_global_dl(self.dataloader) + self.unhandled_prompts = {} self.state_client = ProcessorStateClient( - server_request_conn, close_event, self.graph_state, self.dataloader + server_request_conn, close_event, self.graph_state, self.dataloader, self.unhandled_prompts ) self.is_running = False self.filename = None @@ -129,6 +138,36 @@ def exec_step( self.view_manager.handle_time(step.id, time.time() - start_time) return outputs + def handle_prompt_step(self, step: PromptStep, input: Note | None = None): + unhandled_prompt = self.unhandled_prompts.get(step.id) + if unhandled_prompt: + response = unhandled_prompt.get("response") + if response: + note = unhandled_prompt.get("note") + step.on_prompt_response(note, response) + self.unhandled_prompts.pop(step.id) + outputs = self.exec_step(step, unhandled_prompt.get("note")) + return outputs + return {} + + if input is None: + return {} + + prompt = step.get_prompt(input) + print("prompt", prompt) + if prompt is None: + return {} + + if prompt.get("pause"): + self.prompted_pause_event = True + self.view_manager.handle_prompt(step.id, prompt) + self.unhandled_prompts[step.id] = { + "note": input, + "prompt": prompt, + "response": None, + } + return {} + def handle_steps(self, steps: List[Step]) -> bool: is_active = False for step in steps: @@ -145,7 +184,10 @@ def handle_steps(self, steps: List[Step]) -> bool: except StopIteration: input = None - if isinstance(step, AsyncStep): + if isinstance(step, PromptStep): + self.handle_prompt_step(step, input) + is_active = is_active or len(self.unhandled_prompts) > 0 + elif isinstance(step, AsyncStep): if is_active: # parent is active # Proceed with normal step execution output = self.exec_step(step, input) @@ -199,6 +241,9 @@ def run(self, step_id: str = None): and not self.dataloader.is_failed() ): dag_is_active = self.handle_steps(steps) + print(self.unhandled_prompts) + if self.unhandled_prompts: + print("truth") finally: self.view_manager.handle_end() for step in steps: @@ -296,12 +341,14 @@ def __init__( close_event: mp.Event, graph_state: GraphState, dataloader: Dataloader, + unhandled_prompts: dict, ): self.server_request_conn = server_request_conn self.close_event = close_event self.curr_task = None self.graph_state = graph_state self.dataloader = dataloader + self.unhandled_prompts = unhandled_prompts self.running_state = {} def _loop(self): @@ -323,6 +370,13 @@ def _loop(self): output = self.dataloader.get_all_sizes() elif req["cmd"] == ProcessorStateRequest.GET_RUNNING_STATE: output = self.running_state + elif req["cmd"] == ProcessorStateRequest.PROMPT_RESPONSE: + step_id = req.get("step_id") + if not self.unhandled_prompts.get(step_id): + output = {} + else: + self.unhandled_prompts[step_id]["response"] = req.get("response") + output = self.unhandled_prompts[step_id] else: output = {} entry = {"res": req["cmd"], "data": output} diff --git a/graphbook/prompts.py b/graphbook/prompts.py new file mode 100644 index 0000000..c71d6ff --- /dev/null +++ b/graphbook/prompts.py @@ -0,0 +1,19 @@ +from typing import Any +from .note import Note +from .utils import transform_json_log + +def prompt(note: Note, msg: str, show_images: bool=False, default: Any=""): + return { + "note": transform_json_log(note), + "msg": msg, + "show_images": show_images, + "def": default + } + +def bool_prompt(note: Note, msg: str="", style: str="yes/no", default: bool=False, show_images: bool=False): + p = prompt(note, msg, default=default, show_images=show_images) + p["type"] = "bool" + p["options"] = { + "style": style + } + return p diff --git a/graphbook/steps/__init__.py b/graphbook/steps/__init__.py index 22ecd60..f2a2bab 100644 --- a/graphbook/steps/__init__.py +++ b/graphbook/steps/__init__.py @@ -3,6 +3,7 @@ SourceStep, GeneratorSourceStep, BatchStep, + PromptStep, StepOutput, AsyncStep, Split, @@ -16,6 +17,7 @@ "SourceStep", "GeneratorSourceStep", "BatchStep", + "PromptStep", "StepOutput", "AsyncStep", "Split", diff --git a/graphbook/steps/base.py b/graphbook/steps/base.py index a08b95a..7c691b9 100644 --- a/graphbook/steps/base.py +++ b/graphbook/steps/base.py @@ -1,8 +1,13 @@ from __future__ import annotations from typing import List, Dict, Tuple, Generator, Any -from ..utils import transform_function_string, convert_dict_values_to_list, is_batchable from graphbook import Note +from graphbook.utils import ( + transform_function_string, + convert_dict_values_to_list, + is_batchable, +) from graphbook.logger import log +import graphbook.prompts as prompts import graphbook.dataloading as dataloader import warnings @@ -489,6 +494,26 @@ def is_active(self) -> bool: ) +class PromptStep(Step): + def __init__(self): + super().__init__() + self.waiting = False + + def get_prompt(self, note: Note) -> dict: + return prompts.bool_prompt(note, "Continue?", "yes/no") + + def on_prompt_response(self, note: Note, response: Any): + raise NotImplementedError( + "on_prompt_response must be implemented for PromptStep" + ) + + def __call__(self, note: Note) -> StepOutput: + if self.waiting: + return {} + else: + return + + class Split(Step): """ Routes incoming Notes into either of two output slots, A or B. If split_fn diff --git a/graphbook/utils.py b/graphbook/utils.py index 28d4601..a33e724 100644 --- a/graphbook/utils.py +++ b/graphbook/utils.py @@ -16,7 +16,7 @@ MP_WORKER_TIMEOUT = 5.0 ProcessorStateRequest = Enum( "ProcessorStateRequest", - ["GET_OUTPUT_NOTE", "GET_WORKER_QUEUE_SIZES", "GET_RUNNING_STATE"], + ["GET_OUTPUT_NOTE", "GET_WORKER_QUEUE_SIZES", "GET_RUNNING_STATE", "PROMPT_RESPONSE"], ) diff --git a/graphbook/viewer.py b/graphbook/viewer.py index f1fd16c..0bd9998 100644 --- a/graphbook/viewer.py +++ b/graphbook/viewer.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Tuple +from typing import Dict, List from aiohttp.web import WebSocketResponse import uuid import asyncio @@ -8,7 +8,7 @@ import queue import copy import psutil -from .utils import MP_WORKER_TIMEOUT, get_gpu_util, ProcessorStateRequest, poll_conn_for, transform_json_log +from .utils import MP_WORKER_TIMEOUT, get_gpu_util, ProcessorStateRequest, poll_conn_for class Viewer: @@ -177,6 +177,19 @@ def get_next(self): } +class PromptViewer(Viewer): + def __init__(self): + super().__init__("prompt") + self.prompts = {} + + def handle_prompt(self, node_id: str, prompt: dict): + self.prompts[node_id] = prompt + + def get_next(self): + print(self.prompts) + return self.prompts + + DEFAULT_CLIENT_OPTIONS = {"SEND_EVERY": 0.5} @@ -224,11 +237,13 @@ def __init__( self.node_stats_viewer = NodeStatsViewer() self.logs_viewer = NodeLogsViewer() self.system_util_viewer = SystemUtilViewer(processor_state_conn) + self.prompt_viewer = PromptViewer() self.viewers: List[Viewer] = [ self.data_viewer, self.node_stats_viewer, self.logs_viewer, self.system_util_viewer, + self.prompt_viewer, ] self.clients: Dict[str, Client] = {} self.work_queue = work_queue @@ -275,6 +290,9 @@ def handle_clear(self, node_id: str | None): def handle_log(self, node_id: str, log: str, type: str): self.logs_viewer.handle_log(node_id, log, type) + def handle_prompt(self, node_id: str, prompt: dict): + self.prompt_viewer.handle_prompt(node_id, prompt) + def handle_end(self): for viewer in self.viewers: viewer.handle_end() @@ -315,6 +333,9 @@ def _loop(self): self.handle_run_state(work["is_running"], work["filename"]) elif work["cmd"] == "handle_clear": self.handle_clear(work["node_id"]) + elif work["cmd"] == "handle_prompt": + self.handle_prompt(work["node_id"], work["prompt"]) + except queue.Empty: pass @@ -364,3 +385,8 @@ def handle_run_state(self, run_state: dict): def handle_clear(self, node_id: str | None): self.view_manager_queue.put({"cmd": "handle_clear", "node_id": node_id}) + + def handle_prompt(self, node_id: str, prompt: dict): + self.view_manager_queue.put( + {"cmd": "handle_prompt", "node_id": node_id, "prompt": prompt} + ) diff --git a/graphbook/web.py b/graphbook/web.py index fc8f2cf..77c9d20 100644 --- a/graphbook/web.py +++ b/graphbook/web.py @@ -205,6 +205,14 @@ async def clear(request: web.Request) -> web.Response: } ) return web.json_response({"success": True}) + + @routes.post("/prompt_response/{id}") + async def prompt_response(request: web.Request) -> web.Response: + step_id = request.match_info.get("id") + data = await request.json() + response = data.get("response") + poll_conn_for(state_conn, ProcessorStateRequest.PROMPT_RESPONSE, {"step_id": step_id, "response": response}) + return web.json_response({"success": True}) @routes.get("/nodes") async def get_nodes(request: web.Request) -> web.Response: diff --git a/web/src/api.ts b/web/src/api.ts index e1d501a..b1f4842 100644 --- a/web/src/api.ts +++ b/web/src/api.ts @@ -248,6 +248,10 @@ export class ServerAPI { return await this.get('state'); } + public async respondToPrompt(stepId: string, response: any) { + return await this.post(`prompt_response/${stepId}`, { response }); + } + /** * File API */ diff --git a/web/src/components/Monitor.tsx b/web/src/components/Monitor.tsx index 3b48029..b3d10ac 100644 --- a/web/src/components/Monitor.tsx +++ b/web/src/components/Monitor.tsx @@ -403,7 +403,7 @@ function NotesView({ stepId, numNotes, type }: NotesViewProps) { ([]); const [recordCount, setRecordCount] = useState({}); const [errored, setErrored] = useState(false); + const [prompt, setPrompt] = useState(null); const [parentSelected, setParentSelected] = useState(false); const [runState, runStateShouldChange] = useRunState(); const nodes = useNodes(); @@ -47,6 +50,11 @@ export function WorkflowStep({ id, data, selected }) { setLogsData(prev => getMergedLogs(prev, newEntries)); }, [setLogsData])); + usePrompt(id, (data: PromptProps) => { + console.log(data); + setPrompt({ ...data, stepId: id }); + }); + useEffect(() => { for (const log of logsData) { if (log.type === 'error') { @@ -164,6 +172,11 @@ export function WorkflowStep({ id, data, selected }) { }).filter(x => x) } + {prompt && +
+ +
+ } {!data.isCollapsed && } @@ -229,7 +242,7 @@ function QuickviewCollapse({ data }) { span { + font-size: .6em; } .workflow-node .ant-btn-icon { @@ -142,10 +158,21 @@ textarea.code { font-size: .6em; } -.workflow-node .outputs .output .label{ +.workflow-node .outputs .output .label { + font-size: .6em; +} + +.workflow-node .widgets .prompt .ant-typography { font-size: .6em; } +.workflow-node .widgets .prompt .ant-radio-button-wrapper { + font-size: 10px; + line-height: 1.2; + height: 14px; + padding: 0 4px; +} + .workflow-node .handles { display: flex; flex-direction: row; diff --git a/web/src/components/Nodes/widgets/NotePreview.tsx b/web/src/components/Nodes/widgets/NotePreview.tsx new file mode 100644 index 0000000..2f75d5d --- /dev/null +++ b/web/src/components/Nodes/widgets/NotePreview.tsx @@ -0,0 +1,80 @@ +import React, { CSSProperties, useMemo } from 'react'; +import { Image, Space, Flex } from 'antd'; +import { theme } from 'antd'; +import { useSettings } from '../../../hooks/Settings'; +import { getMediaPath } from '../../../utils'; +import type { ImageRef } from '../../../utils'; +import ReactJson from '@microlink/react-json-view'; + +type QuickViewEntry = { + [key: string]: any; +}; + +export function NotePreview({ data, show_images }: { data: QuickViewEntry, show_images: boolean }) { + const globalTheme = theme.useToken().theme; + + if (!show_images) { + return ( + + ); + + } + return ( + + ); + +} + +function EntryImages({ entry, style }: { entry: QuickViewEntry, style: CSSProperties | undefined }) { + const [settings, _] = useSettings(); + + const imageEntries = useMemo(() => { + let entries: { [key: string]: ImageRef[] } = {}; + Object.entries(entry).forEach(([key, item]) => { + let imageItems: any = []; + if (Array.isArray(item)) { + imageItems = item.filter(item => item.type?.slice(0, 5) === 'image'); + } else { + if (item.type?.slice(0, 5) === 'image') { + imageItems.push(item); + } + } + if (imageItems.length > 0) { + entries[key] = imageItems; + } + }); + return entries; + }, [entry]); + + return ( + + { + Object.entries(imageEntries).map(([key, images]) => { + return ( + +
{key}
+ + { + images.map((image, i) => ( + + )) + } + +
+ + ); + }) + } +
+ ); +} diff --git a/web/src/components/Nodes/widgets/Prompts.tsx b/web/src/components/Nodes/widgets/Prompts.tsx new file mode 100644 index 0000000..7761ae9 --- /dev/null +++ b/web/src/components/Nodes/widgets/Prompts.tsx @@ -0,0 +1,72 @@ +import React, { useCallback, useMemo, useState } from 'react'; +import { Typography, Flex, Button } from 'antd'; +import { usePluginWidgets } from '../../../hooks/Plugins'; +import { NotePreview } from './NotePreview'; +import { NumberWidget, StringWidget, BooleanWidget, FunctionWidget, DictWidget, ListWidget } from './Widgets'; +import { useAPI } from '../../../hooks/API'; + +const { Text } = Typography; +export type PromptProps = { + stepId: string, + note: any, + msg: string, + type: string, + def: any, + show_images?: boolean, + options?: any, +}; + +const getWidgetLookup = (pluginWidgets) => { + const lookup = { + number: NumberWidget, + string: StringWidget, + boolean: BooleanWidget, + bool: BooleanWidget, + function: FunctionWidget, + dict: DictWidget, + }; + pluginWidgets.forEach((widget) => { + lookup[widget.type] = widget.children; + }); + return lookup; +}; + +function Widget({ type, options, value, onChange }) { + const pluginWidgets = usePluginWidgets(); + const widgets = useMemo(() => { + return getWidgetLookup(pluginWidgets); + }, [pluginWidgets]); + + if (type.startsWith('list')) { + return + } + + if (widgets[type]) { + return widgets[type]({ name: "Answer", def: value, onChange, style: options.style }); + } +} + + +export function Prompt({ stepId, note, msg, type, options, def, show_images }: PromptProps) { + const API = useAPI(); + const [value, setValue] = useState(def); + const onChange = useCallback((value) => { + setValue(value); + }, []); + + const onSubmit = useCallback(() => { + if (API) { + API.respondToPrompt(stepId, value); + } + }, [value]); + + return ( + + Prompted: + + {msg} + + + + ); +} \ No newline at end of file diff --git a/web/src/components/Nodes/Widgets.tsx b/web/src/components/Nodes/widgets/Widgets.tsx similarity index 94% rename from web/src/components/Nodes/Widgets.tsx rename to web/src/components/Nodes/widgets/Widgets.tsx index 5ebff5f..7baaf31 100644 --- a/web/src/components/Nodes/Widgets.tsx +++ b/web/src/components/Nodes/widgets/Widgets.tsx @@ -1,13 +1,13 @@ -import { Switch, Typography, theme, Flex, Button } from 'antd'; -import { PlusOutlined, MinusCircleOutlined, MinusOutlined } from '@ant-design/icons'; +import { Switch, Typography, theme, Flex, Button, Radio } from 'antd'; +import { PlusOutlined, MinusOutlined } from '@ant-design/icons'; import React, { useCallback, useState, useMemo } from 'react'; import CodeMirror from '@uiw/react-codemirror'; import { python } from '@codemirror/lang-python'; import { basicDark } from '@uiw/codemirror-theme-basic'; import { bbedit } from '@uiw/codemirror-theme-bbedit'; -import { Graph } from '../../graph'; +import { Graph } from '../../../graph'; import { useReactFlow } from 'reactflow'; -import { usePluginWidgets } from '../../hooks/Plugins'; +import { usePluginWidgets } from '../../../hooks/Plugins'; const { Text } = Typography; const { useToken } = theme; @@ -60,13 +60,19 @@ export function StringWidget({ name, def, onChange }) { ); } -export function BooleanWidget({ name, def, onChange }) { +export function BooleanWidget({ name, def, onChange, style }) { + const input = useMemo(() => { + if (style === "yes/no") { + return onChange(e.target.value)} value={def} optionType="button" /> + } + return + }, [style, def]); return ( - + {name} - + {input} - ) + ); } export function FunctionWidget({ name, def, onChange }) { diff --git a/web/src/hooks/Prompts.ts b/web/src/hooks/Prompts.ts new file mode 100644 index 0000000..d400dd0 --- /dev/null +++ b/web/src/hooks/Prompts.ts @@ -0,0 +1,28 @@ +import { useAPINodeMessage } from "./API"; +import { useFilename } from "./Filename" +import { useEffect, useCallback } from "react"; + + +let globalPrompts = {}; + +export function usePrompt(nodeId: string, callback: Function) { + const filename = useFilename(); + + useEffect(() => { + return () => { + delete globalPrompts[nodeId]; + }; + }, [nodeId]); + + const internalCallback = useCallback((data) => { + if (globalPrompts[nodeId]) { + return; + } + + console.log(data); + globalPrompts[nodeId] = data; + callback(data); + }, [nodeId]); + + useAPINodeMessage("prompt", nodeId, filename, internalCallback); +} From 21df3309b9a0a1146bfe29c877d5b647ecce92c1 Mon Sep 17 00:00:00 2001 From: Richard Franklin Date: Tue, 17 Sep 2024 20:43:23 -0700 Subject: [PATCH 02/14] new logic --- graphbook/logger.py | 16 +++++++++ graphbook/processing/web_processor.py | 33 +++++++++--------- graphbook/prompts.py | 1 + graphbook/state.py | 8 ++++- graphbook/steps/base.py | 49 +++++++++++++++++++++------ 5 files changed, 79 insertions(+), 28 deletions(-) diff --git a/graphbook/logger.py b/graphbook/logger.py index 40dd64c..925e3fb 100644 --- a/graphbook/logger.py +++ b/graphbook/logger.py @@ -2,6 +2,7 @@ import multiprocessing as mp from typing import Dict, Tuple, Any import inspect +from graphbook.note import Note from graphbook.viewer import ViewManagerInterface from graphbook.utils import transform_json_log @@ -44,3 +45,18 @@ def log(msg: Any, type: LogType = "info", caller_id: int | None = None): else: raise ValueError(f"Unknown log type {type}") view_manager.handle_log(node_id, msg, type) + +def prompt(prompt: dict, caller_id: int | None = None): + if caller_id is None: + prev_frame = inspect.currentframe().f_back + caller = prev_frame.f_locals.get("self") + if caller is not None: + caller_id = id(caller) + + node = logging_nodes.get(caller_id, None) + if node is None: + raise ValueError( + f"Can't find node id in {caller}. Only initialized steps can log." + ) + node_id, node_name = node + view_manager.handle_prompt(node_id, prompt) diff --git a/graphbook/processing/web_processor.py b/graphbook/processing/web_processor.py index 99ec80b..5a7ee41 100644 --- a/graphbook/processing/web_processor.py +++ b/graphbook/processing/web_processor.py @@ -3,6 +3,7 @@ SourceStep, GeneratorSourceStep, AsyncStep, + BatchStep, PromptStep, StepOutput, ) @@ -58,7 +59,11 @@ def __init__( setup_global_dl(self.dataloader) self.unhandled_prompts = {} self.state_client = ProcessorStateClient( - server_request_conn, close_event, self.graph_state, self.dataloader, self.unhandled_prompts + server_request_conn, + close_event, + self.graph_state, + self.dataloader, + self.unhandled_prompts, ) self.is_running = False self.filename = None @@ -139,6 +144,7 @@ def exec_step( return outputs def handle_prompt_step(self, step: PromptStep, input: Note | None = None): + print("Handling prompt with input", input) unhandled_prompt = self.unhandled_prompts.get(step.id) if unhandled_prompt: response = unhandled_prompt.get("response") @@ -149,12 +155,11 @@ def handle_prompt_step(self, step: PromptStep, input: Note | None = None): outputs = self.exec_step(step, unhandled_prompt.get("note")) return outputs return {} - + if input is None: return {} prompt = step.get_prompt(input) - print("prompt", prompt) if prompt is None: return {} @@ -184,10 +189,11 @@ def handle_steps(self, steps: List[Step]) -> bool: except StopIteration: input = None - if isinstance(step, PromptStep): - self.handle_prompt_step(step, input) - is_active = is_active or len(self.unhandled_prompts) > 0 - elif isinstance(step, AsyncStep): + # if isinstance(step, PromptStep): + # self.handle_prompt_step(step, input) + # print("ON PROMPT", len(self.unhandled_prompts)) + # is_active = is_active or len(self.unhandled_prompts) > 0 + if isinstance(step, AsyncStep): if is_active: # parent is active # Proceed with normal step execution output = self.exec_step(step, input) @@ -241,9 +247,8 @@ def run(self, step_id: str = None): and not self.dataloader.is_failed() ): dag_is_active = self.handle_steps(steps) - print(self.unhandled_prompts) - if self.unhandled_prompts: - print("truth") + print("Done") + print(dag_is_active, self.pause_event.is_set(), self.close_event.is_set(), self.dataloader.is_failed()) finally: self.view_manager.handle_end() for step in steps: @@ -277,7 +282,7 @@ def cleanup(self): self.dataloader.shutdown() def setup_dataloader(self, steps: List[Step]): - dataloader_consumers = [step for step in steps if isinstance(step, AsyncStep)] + dataloader_consumers = [step for step in steps if isinstance(step, BatchStep)] consumer_ids = [id(c) for c in dataloader_consumers] consumer_load_fn = [ c.load_fn if hasattr(c, "load_fn") else None for c in dataloader_consumers @@ -372,11 +377,7 @@ def _loop(self): output = self.running_state elif req["cmd"] == ProcessorStateRequest.PROMPT_RESPONSE: step_id = req.get("step_id") - if not self.unhandled_prompts.get(step_id): - output = {} - else: - self.unhandled_prompts[step_id]["response"] = req.get("response") - output = self.unhandled_prompts[step_id] + self.graph_state.handle_prompt_response(step_id, req.get("response")) else: output = {} entry = {"res": req["cmd"], "data": output} diff --git a/graphbook/prompts.py b/graphbook/prompts.py index c71d6ff..0133781 100644 --- a/graphbook/prompts.py +++ b/graphbook/prompts.py @@ -11,6 +11,7 @@ def prompt(note: Note, msg: str, show_images: bool=False, default: Any=""): } def bool_prompt(note: Note, msg: str="", style: str="yes/no", default: bool=False, show_images: bool=False): + default = "Yes" if default else "No" p = prompt(note, msg, default=default, show_images=show_images) p["type"] = "bool" p["options"] = { diff --git a/graphbook/state.py b/graphbook/state.py index ce35071..5e471cf 100644 --- a/graphbook/state.py +++ b/graphbook/state.py @@ -2,7 +2,7 @@ from aiohttp.web import WebSocketResponse from typing import Dict, Tuple, List, Iterator, Set from graphbook.note import Note -from graphbook.steps import Step, StepOutput as Outputs +from graphbook.steps import Step, PromptStep, StepOutput as Outputs from graphbook.resources import Resource from graphbook.decorators import get_steps, get_resources from graphbook.viewer import ViewManagerInterface @@ -400,6 +400,12 @@ def get_output_note(self, step_id: str, pin_id: str, index: int) -> dict: note = internal_list[index] entry.update(data=note.items) return entry + + def handle_prompt_response(self, step_id: str, response: dict): + step = self._steps.get(step_id) + if not isinstance(step, PromptStep): + return + step.handle_prompt_response(response) def get_step(self, step_id: str): return self._steps.get(step_id) diff --git a/graphbook/steps/base.py b/graphbook/steps/base.py index 7c691b9..8313d3e 100644 --- a/graphbook/steps/base.py +++ b/graphbook/steps/base.py @@ -6,7 +6,7 @@ convert_dict_values_to_list, is_batchable, ) -from graphbook.logger import log +from graphbook.logger import log, prompt import graphbook.prompts as prompts import graphbook.dataloading as dataloader import warnings @@ -226,8 +226,8 @@ class AsyncStep(Step): def __init__(self, item_key=None): super().__init__(item_key) - self._is_processing = True self._in_queue = [] + self._out_queue = [] def in_q(self, note: Note | None): if note is None: @@ -235,7 +235,18 @@ def in_q(self, note: Note | None): self._in_queue.append(note) def is_active(self) -> bool: - return self._is_processing + print("In queue", len(self._in_queue)) + return len(self._in_queue) > 0 + + def __call__(self) -> StepOutput: + # 1. on_note -> 2. on_item -> 3. on_after_item -> 4. forward_note + if len(self._out_queue) == 0: + return {} + note = self._out_queue.pop(0) + return super().__call__(note) + + def all(self) -> StepOutput: + return self.__call__() class NoteItemHolders: @@ -494,10 +505,18 @@ def is_active(self) -> bool: ) -class PromptStep(Step): +class PromptStep(AsyncStep): def __init__(self): super().__init__() - self.waiting = False + self._is_awaiting_response = False + self._awaiting_note = None + + def handle_prompt_response(self, response: dict): + note = self._awaiting_note + self.on_prompt_response(note, response) + self._out_queue.append(note) + self._is_awaiting_response = False + self._awaiting_note = None def get_prompt(self, note: Note) -> dict: return prompts.bool_prompt(note, "Continue?", "yes/no") @@ -505,13 +524,21 @@ def get_prompt(self, note: Note) -> dict: def on_prompt_response(self, note: Note, response: Any): raise NotImplementedError( "on_prompt_response must be implemented for PromptStep" - ) + ) + + def __call__(self): + # Handle Prompt + if not self._is_awaiting_response and len(self._in_queue) > 0: + note = self._in_queue.pop(0) + prompt(self.get_prompt(note)) + self._is_awaiting_response = True + self._awaiting_note = note + return super().__call__() + + def is_active(self) -> bool: + print("In queue", len(self._in_queue)) + return len(self._in_queue) > 0 or self._awaiting_note is not None - def __call__(self, note: Note) -> StepOutput: - if self.waiting: - return {} - else: - return class Split(Step): From 76062314195bfdfb1a9c52e0b1506deac020a037 Mon Sep 17 00:00:00 2001 From: Richard Franklin Date: Wed, 18 Sep 2024 18:50:39 -0700 Subject: [PATCH 03/14] pdate --- graphbook/steps/base.py | 14 +++-- graphbook/viewer.py | 6 ++- web/src/components/Nodes/Node.tsx | 22 +++----- .../components/Nodes/widgets/NotePreview.tsx | 4 +- web/src/components/Nodes/widgets/Prompts.tsx | 28 ++++++---- web/src/hooks/API.ts | 2 + web/src/hooks/Prompts.ts | 51 ++++++++++++++++--- 7 files changed, 86 insertions(+), 41 deletions(-) diff --git a/graphbook/steps/base.py b/graphbook/steps/base.py index 8313d3e..7bdf183 100644 --- a/graphbook/steps/base.py +++ b/graphbook/steps/base.py @@ -235,16 +235,15 @@ def in_q(self, note: Note | None): self._in_queue.append(note) def is_active(self) -> bool: - print("In queue", len(self._in_queue)) return len(self._in_queue) > 0 - + def __call__(self) -> StepOutput: # 1. on_note -> 2. on_item -> 3. on_after_item -> 4. forward_note if len(self._out_queue) == 0: return {} note = self._out_queue.pop(0) return super().__call__(note) - + def all(self) -> StepOutput: return self.__call__() @@ -513,6 +512,7 @@ def __init__(self): def handle_prompt_response(self, response: dict): note = self._awaiting_note + assert note is not None, "PromptStep is not awaiting a response." self.on_prompt_response(note, response) self._out_queue.append(note) self._is_awaiting_response = False @@ -524,8 +524,8 @@ def get_prompt(self, note: Note) -> dict: def on_prompt_response(self, note: Note, response: Any): raise NotImplementedError( "on_prompt_response must be implemented for PromptStep" - ) - + ) + def __call__(self): # Handle Prompt if not self._is_awaiting_response and len(self._in_queue) > 0: @@ -534,11 +534,9 @@ def __call__(self): self._is_awaiting_response = True self._awaiting_note = note return super().__call__() - + def is_active(self) -> bool: - print("In queue", len(self._in_queue)) return len(self._in_queue) > 0 or self._awaiting_note is not None - class Split(Step): diff --git a/graphbook/viewer.py b/graphbook/viewer.py index 0bd9998..3a35c7f 100644 --- a/graphbook/viewer.py +++ b/graphbook/viewer.py @@ -183,10 +183,14 @@ def __init__(self): self.prompts = {} def handle_prompt(self, node_id: str, prompt: dict): + prev_prompt = self.prompts.get(node_id) + idx = 0 + if prev_prompt: + idx = prev_prompt["idx"] + 1 + prompt["idx"] = idx self.prompts[node_id] = prompt def get_next(self): - print(self.prompts) return self.prompts diff --git a/web/src/components/Nodes/Node.tsx b/web/src/components/Nodes/Node.tsx index 617f0e8..4796c5e 100644 --- a/web/src/components/Nodes/Node.tsx +++ b/web/src/components/Nodes/Node.tsx @@ -12,10 +12,10 @@ import { getMergedLogs, getMediaPath } from '../../utils'; import { useNotification } from '../../hooks/Notification'; import { useSettings } from '../../hooks/Settings'; import { SerializationErrorMessages } from '../Errors'; -import { Prompt, PromptProps } from './widgets/Prompts'; -import type { LogEntry, Parameter, ImageRef } from '../../utils'; +import { Prompt } from './widgets/Prompts'; import ReactJson from '@microlink/react-json-view'; -import { usePrompt } from '../../hooks/Prompts'; +import type { LogEntry, Parameter, ImageRef } from '../../utils'; + const { Panel } = Collapse; const { useToken } = theme; @@ -29,7 +29,6 @@ export function WorkflowStep({ id, data, selected }) { const [logsData, setLogsData] = useState([]); const [recordCount, setRecordCount] = useState({}); const [errored, setErrored] = useState(false); - const [prompt, setPrompt] = useState(null); const [parentSelected, setParentSelected] = useState(false); const [runState, runStateShouldChange] = useRunState(); const nodes = useNodes(); @@ -40,6 +39,8 @@ export function WorkflowStep({ id, data, selected }) { const API = useAPI(); const filename = useFilename(); + console.log("WorkflowStep"); + useAPINodeMessage('stats', id, filename, (msg) => { setRecordCount(msg.queue_size || {}); }); @@ -50,11 +51,6 @@ export function WorkflowStep({ id, data, selected }) { setLogsData(prev => getMergedLogs(prev, newEntries)); }, [setLogsData])); - usePrompt(id, (data: PromptProps) => { - console.log(data); - setPrompt({ ...data, stepId: id }); - }); - useEffect(() => { for (const log of logsData) { if (log.type === 'error') { @@ -172,11 +168,9 @@ export function WorkflowStep({ id, data, selected }) { }).filter(x => x) } - {prompt && -
- -
- } +
+ +
{!data.isCollapsed && } diff --git a/web/src/components/Nodes/widgets/NotePreview.tsx b/web/src/components/Nodes/widgets/NotePreview.tsx index 2f75d5d..33ee4dd 100644 --- a/web/src/components/Nodes/widgets/NotePreview.tsx +++ b/web/src/components/Nodes/widgets/NotePreview.tsx @@ -10,10 +10,10 @@ type QuickViewEntry = { [key: string]: any; }; -export function NotePreview({ data, show_images }: { data: QuickViewEntry, show_images: boolean }) { +export function NotePreview({ data, showImages }: { data: QuickViewEntry, showImages: boolean }) { const globalTheme = theme.useToken().theme; - if (!show_images) { + if (!showImages) { return ( { setValue(value); }, []); const onSubmit = useCallback(() => { if (API) { - API.respondToPrompt(stepId, value); + API.respondToPrompt(nodeId, value); } - }, [value]); + }, [value, API, nodeId]); + + useEffect(() => { + console.log('Prompted:', prompt); + }, [prompt]); + + if (!prompt) { + return null; + } return ( Prompted: - - {msg} - + + {prompt.msg} + ); -} \ No newline at end of file +} diff --git a/web/src/hooks/API.ts b/web/src/hooks/API.ts index 2581793..44e4171 100644 --- a/web/src/hooks/API.ts +++ b/web/src/hooks/API.ts @@ -77,6 +77,8 @@ export function useAPINodeMessage(event_type: string, node_id: string, filename: callback(msg[node_id]); } }, [node_id, callback, filename])); + + console.log("useAPINodeMessage"); } diff --git a/web/src/hooks/Prompts.ts b/web/src/hooks/Prompts.ts index d400dd0..5a86f33 100644 --- a/web/src/hooks/Prompts.ts +++ b/web/src/hooks/Prompts.ts @@ -1,28 +1,65 @@ import { useAPINodeMessage } from "./API"; import { useFilename } from "./Filename" -import { useEffect, useCallback } from "react"; +import { useEffect, useCallback, useState } from "react"; let globalPrompts = {}; +let localSetters: Function[] = []; -export function usePrompt(nodeId: string, callback: Function) { +type PromptData = { + "idx": number, + "note": object, + "msg": string, + "show_images": boolean, + "def": any, + "options": object, + "type": string +}; + +type Prompt = { + "note": object, + "msg": string, + "showImages": boolean, + "def": any, + "options": object, + "type": string +}; + +export function usePrompt(nodeId: string): Prompt | null { const filename = useFilename(); + const [_, setPrompt] = useState(null); useEffect(() => { + localSetters.push(setPrompt); return () => { + localSetters = localSetters.filter((setter) => setter !== setPrompt); delete globalPrompts[nodeId]; }; }, [nodeId]); - const internalCallback = useCallback((data) => { + const onPrompt = useCallback((data: PromptData) => { if (globalPrompts[nodeId]) { - return; + if (globalPrompts[nodeId].idx === data.idx) { + return; + } } - console.log(data); globalPrompts[nodeId] = data; - callback(data); }, [nodeId]); - useAPINodeMessage("prompt", nodeId, filename, internalCallback); + useAPINodeMessage("prompt", nodeId, filename, onPrompt); + + const promptData = globalPrompts[nodeId]; + if (!promptData) { + return null; + } + + return { + note: promptData.note, + msg: promptData.msg, + showImages: promptData.show_images, + def: promptData.def, + options: promptData.options, + type: promptData.type + }; } From 3810e709d2a4ce4e51bd29152213e37374ac348a Mon Sep 17 00:00:00 2001 From: Richard Franklin Date: Thu, 19 Sep 2024 15:58:59 -0700 Subject: [PATCH 04/14] Working, finished --- graphbook/processing/web_processor.py | 44 ++------------------ graphbook/prompts.py | 5 ++- graphbook/state.py | 12 ++++-- graphbook/steps/base.py | 19 +++++---- graphbook/web.py | 9 ++-- web/src/components/Nodes/Node.tsx | 2 - web/src/components/Nodes/widgets/Prompts.tsx | 19 +++++---- web/src/hooks/API.ts | 2 - web/src/hooks/Prompts.ts | 15 +++++-- 9 files changed, 55 insertions(+), 72 deletions(-) diff --git a/graphbook/processing/web_processor.py b/graphbook/processing/web_processor.py index 5a7ee41..d7f129c 100644 --- a/graphbook/processing/web_processor.py +++ b/graphbook/processing/web_processor.py @@ -4,7 +4,6 @@ GeneratorSourceStep, AsyncStep, BatchStep, - PromptStep, StepOutput, ) from graphbook.dataloading import Dataloader, setup_global_dl @@ -134,45 +133,13 @@ def exec_step( id(step), ) return None - self.handle_images(outputs) self.graph_state.handle_outputs( step.id, outputs if not self.copy_outputs else copy.deepcopy(outputs) ) - self.view_manager.handle_outputs(step.id, transform_json_log(outputs)) - self.view_manager.handle_time(step.id, time.time() - start_time) + self.view_manager.handle_time(step.id, time.time() - start_time) return outputs - def handle_prompt_step(self, step: PromptStep, input: Note | None = None): - print("Handling prompt with input", input) - unhandled_prompt = self.unhandled_prompts.get(step.id) - if unhandled_prompt: - response = unhandled_prompt.get("response") - if response: - note = unhandled_prompt.get("note") - step.on_prompt_response(note, response) - self.unhandled_prompts.pop(step.id) - outputs = self.exec_step(step, unhandled_prompt.get("note")) - return outputs - return {} - - if input is None: - return {} - - prompt = step.get_prompt(input) - if prompt is None: - return {} - - if prompt.get("pause"): - self.prompted_pause_event = True - self.view_manager.handle_prompt(step.id, prompt) - self.unhandled_prompts[step.id] = { - "note": input, - "prompt": prompt, - "response": None, - } - return {} - def handle_steps(self, steps: List[Step]) -> bool: is_active = False for step in steps: @@ -189,10 +156,6 @@ def handle_steps(self, steps: List[Step]) -> bool: except StopIteration: input = None - # if isinstance(step, PromptStep): - # self.handle_prompt_step(step, input) - # print("ON PROMPT", len(self.unhandled_prompts)) - # is_active = is_active or len(self.unhandled_prompts) > 0 if isinstance(step, AsyncStep): if is_active: # parent is active # Proceed with normal step execution @@ -247,8 +210,6 @@ def run(self, step_id: str = None): and not self.dataloader.is_failed() ): dag_is_active = self.handle_steps(steps) - print("Done") - print(dag_is_active, self.pause_event.is_set(), self.close_event.is_set(), self.dataloader.is_failed()) finally: self.view_manager.handle_end() for step in steps: @@ -377,7 +338,8 @@ def _loop(self): output = self.running_state elif req["cmd"] == ProcessorStateRequest.PROMPT_RESPONSE: step_id = req.get("step_id") - self.graph_state.handle_prompt_response(step_id, req.get("response")) + succeeded = self.graph_state.handle_prompt_response(step_id, req.get("response")) + output = {"ok": succeeded} else: output = {} entry = {"res": req["cmd"], "data": output} diff --git a/graphbook/prompts.py b/graphbook/prompts.py index 0133781..89cff4a 100644 --- a/graphbook/prompts.py +++ b/graphbook/prompts.py @@ -2,6 +2,9 @@ from .note import Note from .utils import transform_json_log +def none(): + return {"type": None} + def prompt(note: Note, msg: str, show_images: bool=False, default: Any=""): return { "note": transform_json_log(note), @@ -10,7 +13,7 @@ def prompt(note: Note, msg: str, show_images: bool=False, default: Any=""): "def": default } -def bool_prompt(note: Note, msg: str="", style: str="yes/no", default: bool=False, show_images: bool=False): +def bool_prompt(note: Note, msg: str="Continue?", style: str="yes/no", default: bool=False, show_images: bool=False): default = "Yes" if default else "No" p = prompt(note, msg, default=default, show_images=show_images) p["type"] = "bool" diff --git a/graphbook/state.py b/graphbook/state.py index 5e471cf..4be7d56 100644 --- a/graphbook/state.py +++ b/graphbook/state.py @@ -8,6 +8,7 @@ from graphbook.viewer import ViewManagerInterface from graphbook.plugins import setup_plugins from graphbook.logger import setup_logging_nodes +from graphbook.utils import transform_json_log import multiprocessing as mp import importlib, importlib.util, inspect import graphbook.exports as exports @@ -344,6 +345,7 @@ def handle_outputs(self, step_id: str, outputs: Outputs): self._step_states[step_id].add(StepState.EXECUTED) self._step_states[step_id].add(StepState.EXECUTED_THIS_RUN) self.view_manager.handle_queue_size(step_id, self._queues[step_id].dict_sizes()) + self.view_manager.handle_outputs(step_id, transform_json_log(outputs)) def clear_outputs(self, node_id: str | None = None): if node_id is None: @@ -401,11 +403,15 @@ def get_output_note(self, step_id: str, pin_id: str, index: int) -> dict: entry.update(data=note.items) return entry - def handle_prompt_response(self, step_id: str, response: dict): + def handle_prompt_response(self, step_id: str, response: dict) -> bool: step = self._steps.get(step_id) if not isinstance(step, PromptStep): - return - step.handle_prompt_response(response) + return False + try: + step.handle_prompt_response(response) + return True + except: + return False def get_step(self, step_id: str): return self._steps.get(step_id) diff --git a/graphbook/steps/base.py b/graphbook/steps/base.py index 7bdf183..cf65155 100644 --- a/graphbook/steps/base.py +++ b/graphbook/steps/base.py @@ -10,7 +10,7 @@ import graphbook.prompts as prompts import graphbook.dataloading as dataloader import warnings - +import traceback warnings.simplefilter("default", DeprecationWarning) @@ -510,16 +510,22 @@ def __init__(self): self._is_awaiting_response = False self._awaiting_note = None - def handle_prompt_response(self, response: dict): + def handle_prompt_response(self, response: Any): note = self._awaiting_note - assert note is not None, "PromptStep is not awaiting a response." - self.on_prompt_response(note, response) - self._out_queue.append(note) + try: + assert note is not None, "PromptStep is not awaiting a response." + self.on_prompt_response(note, response) + self._out_queue.append(note) + except Exception as e: + self.log(f"{type(e).__name__}: {str(e)}", "error") + traceback.print_exc() + self._is_awaiting_response = False self._awaiting_note = None + prompt(prompts.none()) def get_prompt(self, note: Note) -> dict: - return prompts.bool_prompt(note, "Continue?", "yes/no") + return prompts.bool_prompt(note) def on_prompt_response(self, note: Note, response: Any): raise NotImplementedError( @@ -527,7 +533,6 @@ def on_prompt_response(self, note: Note, response: Any): ) def __call__(self): - # Handle Prompt if not self._is_awaiting_response and len(self._in_queue) > 0: note = self._in_queue.pop(0) prompt(self.get_prompt(note)) diff --git a/graphbook/web.py b/graphbook/web.py index 77c9d20..648cf02 100644 --- a/graphbook/web.py +++ b/graphbook/web.py @@ -211,8 +211,8 @@ async def prompt_response(request: web.Request) -> web.Response: step_id = request.match_info.get("id") data = await request.json() response = data.get("response") - poll_conn_for(state_conn, ProcessorStateRequest.PROMPT_RESPONSE, {"step_id": step_id, "response": response}) - return web.json_response({"success": True}) + res = poll_conn_for(state_conn, ProcessorStateRequest.PROMPT_RESPONSE, {"step_id": step_id, "response": response}) + return web.json_response(res) @routes.get("/nodes") async def get_nodes(request: web.Request) -> web.Response: @@ -524,7 +524,10 @@ def signal_handler(*_): close_event.set() if img_mem: - img_mem.close() + try: + img_mem.close() + except FileNotFoundError: + pass raise KeyboardInterrupt() diff --git a/web/src/components/Nodes/Node.tsx b/web/src/components/Nodes/Node.tsx index 4796c5e..065e457 100644 --- a/web/src/components/Nodes/Node.tsx +++ b/web/src/components/Nodes/Node.tsx @@ -39,8 +39,6 @@ export function WorkflowStep({ id, data, selected }) { const API = useAPI(); const filename = useFilename(); - console.log("WorkflowStep"); - useAPINodeMessage('stats', id, filename, (msg) => { setRecordCount(msg.queue_size || {}); }); diff --git a/web/src/components/Nodes/widgets/Prompts.tsx b/web/src/components/Nodes/widgets/Prompts.tsx index f89e213..32dad7f 100644 --- a/web/src/components/Nodes/widgets/Prompts.tsx +++ b/web/src/components/Nodes/widgets/Prompts.tsx @@ -50,23 +50,24 @@ function Widget({ type, options, value, onChange }) { export function Prompt({ nodeId }: { nodeId: string }) { const API = useAPI(); - const prompt = usePrompt(nodeId); + const [prompt, setSubmitted] = usePrompt(nodeId); const [value, setValue] = useState(null); + const [loading, setLoading] = useState(false); const onChange = useCallback((value) => { setValue(value); }, []); - const onSubmit = useCallback(() => { + const onSubmit = useCallback(async () => { if (API) { - API.respondToPrompt(nodeId, value); + setLoading(true); + const res = await API.respondToPrompt(nodeId, value); + console.log(res); + setLoading(false); + setSubmitted(); } }, [value, API, nodeId]); - useEffect(() => { - console.log('Prompted:', prompt); - }, [prompt]); - - if (!prompt) { + if (!prompt || prompt.type === null) { return null; } @@ -76,7 +77,7 @@ export function Prompt({ nodeId }: { nodeId: string }) { {prompt.msg} - +
); } diff --git a/web/src/hooks/API.ts b/web/src/hooks/API.ts index 44e4171..2581793 100644 --- a/web/src/hooks/API.ts +++ b/web/src/hooks/API.ts @@ -77,8 +77,6 @@ export function useAPINodeMessage(event_type: string, node_id: string, filename: callback(msg[node_id]); } }, [node_id, callback, filename])); - - console.log("useAPINodeMessage"); } diff --git a/web/src/hooks/Prompts.ts b/web/src/hooks/Prompts.ts index 5a86f33..23a508b 100644 --- a/web/src/hooks/Prompts.ts +++ b/web/src/hooks/Prompts.ts @@ -25,7 +25,7 @@ type Prompt = { "type": string }; -export function usePrompt(nodeId: string): Prompt | null { +export function usePrompt(nodeId: string): [Prompt | null, Function] { const filename = useFilename(); const [_, setPrompt] = useState(null); @@ -47,19 +47,26 @@ export function usePrompt(nodeId: string): Prompt | null { globalPrompts[nodeId] = data; }, [nodeId]); + const setSubmitted = useCallback(() => { + if (!globalPrompts[nodeId]) { + return; + } + globalPrompts[nodeId] = { ...globalPrompts[nodeId], "type": null }; + }, []); + useAPINodeMessage("prompt", nodeId, filename, onPrompt); const promptData = globalPrompts[nodeId]; if (!promptData) { - return null; + return [null, setSubmitted]; } - return { + return [{ note: promptData.note, msg: promptData.msg, showImages: promptData.show_images, def: promptData.def, options: promptData.options, type: promptData.type - }; + }, setSubmitted]; } From 34b260d81b589b29dcdb40c86d7b26c009975810 Mon Sep 17 00:00:00 2001 From: Richard Franklin Date: Thu, 19 Sep 2024 17:45:54 -0700 Subject: [PATCH 05/14] yes/no prompt default value fix --- graphbook/logger.py | 2 +- web/src/components/Nodes/widgets/Prompts.tsx | 3 +-- web/src/components/Nodes/widgets/Widgets.tsx | 10 +++++++++- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/graphbook/logger.py b/graphbook/logger.py index 925e3fb..0d80417 100644 --- a/graphbook/logger.py +++ b/graphbook/logger.py @@ -58,5 +58,5 @@ def prompt(prompt: dict, caller_id: int | None = None): raise ValueError( f"Can't find node id in {caller}. Only initialized steps can log." ) - node_id, node_name = node + node_id, _ = node view_manager.handle_prompt(node_id, prompt) diff --git a/web/src/components/Nodes/widgets/Prompts.tsx b/web/src/components/Nodes/widgets/Prompts.tsx index 32dad7f..e3befdd 100644 --- a/web/src/components/Nodes/widgets/Prompts.tsx +++ b/web/src/components/Nodes/widgets/Prompts.tsx @@ -60,8 +60,7 @@ export function Prompt({ nodeId }: { nodeId: string }) { const onSubmit = useCallback(async () => { if (API) { setLoading(true); - const res = await API.respondToPrompt(nodeId, value); - console.log(res); + await API.respondToPrompt(nodeId, value); setLoading(false); setSubmitted(); } diff --git a/web/src/components/Nodes/widgets/Widgets.tsx b/web/src/components/Nodes/widgets/Widgets.tsx index 7baaf31..9fd09af 100644 --- a/web/src/components/Nodes/widgets/Widgets.tsx +++ b/web/src/components/Nodes/widgets/Widgets.tsx @@ -63,7 +63,15 @@ export function StringWidget({ name, def, onChange }) { export function BooleanWidget({ name, def, onChange, style }) { const input = useMemo(() => { if (style === "yes/no") { - return onChange(e.target.value)} value={def} optionType="button" /> + const M = { + "Yes": true, + "No": false, + }; + const M_ = { + true: "Yes", + false: "No", + }; + return onChange(M[e.target.value])} value={M_[def]} optionType="button" /> } return }, [style, def]); From 7d60f1e687cac43d932755504535f2f54a8c09e7 Mon Sep 17 00:00:00 2001 From: Richard Franklin Date: Thu, 19 Sep 2024 18:04:19 -0700 Subject: [PATCH 06/14] cleaning up --- graphbook/processing/web_processor.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/graphbook/processing/web_processor.py b/graphbook/processing/web_processor.py index d7f129c..e6c6d32 100644 --- a/graphbook/processing/web_processor.py +++ b/graphbook/processing/web_processor.py @@ -45,7 +45,6 @@ def __init__( self.cmd_queue = cmd_queue self.close_event = close_event self.pause_event = pause_event - self.prompted_pause_event = False self.view_manager = ViewManagerInterface(view_manager_queue) self.img_mem = img_mem self.graph_state = GraphState(custom_nodes_path, view_manager_queue) @@ -56,13 +55,11 @@ def __init__( self.steps = {} self.dataloader = Dataloader(self.num_workers) setup_global_dl(self.dataloader) - self.unhandled_prompts = {} self.state_client = ProcessorStateClient( server_request_conn, close_event, self.graph_state, self.dataloader, - self.unhandled_prompts, ) self.is_running = False self.filename = None @@ -307,14 +304,12 @@ def __init__( close_event: mp.Event, graph_state: GraphState, dataloader: Dataloader, - unhandled_prompts: dict, ): self.server_request_conn = server_request_conn self.close_event = close_event self.curr_task = None self.graph_state = graph_state self.dataloader = dataloader - self.unhandled_prompts = unhandled_prompts self.running_state = {} def _loop(self): From 3362944e252928d78f06cf7b92ffb9dc1e5da3a4 Mon Sep 17 00:00:00 2001 From: Richard Franklin Date: Thu, 19 Sep 2024 18:04:25 -0700 Subject: [PATCH 07/14] cleaning up --- graphbook/prompts.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/graphbook/prompts.py b/graphbook/prompts.py index 89cff4a..42f7825 100644 --- a/graphbook/prompts.py +++ b/graphbook/prompts.py @@ -2,22 +2,31 @@ from .note import Note from .utils import transform_json_log + def none(): return {"type": None} -def prompt(note: Note, msg: str, show_images: bool=False, default: Any=""): + +def prompt(note: Note, msg: str, show_images: bool = False, default: Any = ""): return { "note": transform_json_log(note), "msg": msg, "show_images": show_images, - "def": default + "def": default, } -def bool_prompt(note: Note, msg: str="Continue?", style: str="yes/no", default: bool=False, show_images: bool=False): + +def bool_prompt( + note: Note, + *, + msg: str = "Continue?", + style: str = "yes/no", + default: bool = False, + show_images: bool = False, + pause: bool = True, +): default = "Yes" if default else "No" p = prompt(note, msg, default=default, show_images=show_images) p["type"] = "bool" - p["options"] = { - "style": style - } + p["options"] = {"style": style} return p From f4defe8d3e851ee2e920f6635d5304a23504f79f Mon Sep 17 00:00:00 2001 From: Richard Franklin Date: Thu, 19 Sep 2024 18:04:33 -0700 Subject: [PATCH 08/14] cleaning up --- graphbook/prompts.py | 1 - 1 file changed, 1 deletion(-) diff --git a/graphbook/prompts.py b/graphbook/prompts.py index 42f7825..db18e2f 100644 --- a/graphbook/prompts.py +++ b/graphbook/prompts.py @@ -23,7 +23,6 @@ def bool_prompt( style: str = "yes/no", default: bool = False, show_images: bool = False, - pause: bool = True, ): default = "Yes" if default else "No" p = prompt(note, msg, default=default, show_images=show_images) From 510aa6bc41da0f7d1c4cb7c283d0a0f4e6db9e1e Mon Sep 17 00:00:00 2001 From: Richard Franklin Date: Thu, 19 Sep 2024 18:08:31 -0700 Subject: [PATCH 09/14] set --- graphbook/prompts.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/graphbook/prompts.py b/graphbook/prompts.py index db18e2f..fd2880d 100644 --- a/graphbook/prompts.py +++ b/graphbook/prompts.py @@ -7,7 +7,7 @@ def none(): return {"type": None} -def prompt(note: Note, msg: str, show_images: bool = False, default: Any = ""): +def prompt(note: Note, *, msg: str = "", show_images: bool = False, default: Any = ""): return { "note": transform_json_log(note), "msg": msg, @@ -25,7 +25,7 @@ def bool_prompt( show_images: bool = False, ): default = "Yes" if default else "No" - p = prompt(note, msg, default=default, show_images=show_images) + p = prompt(note, msg=msg, default=default, show_images=show_images) p["type"] = "bool" p["options"] = {"style": style} return p From 3d5892e6780a8475f9487955c47e87498315bff2 Mon Sep 17 00:00:00 2001 From: "Richard S. Franklin" Date: Tue, 24 Sep 2024 18:09:30 -0700 Subject: [PATCH 10/14] New dropdown widget (#96) --- graphbook/prompts.py | 131 +++++++++++++++- graphbook/steps/base.py | 5 + web/src/components/Flow.tsx | 4 +- web/src/components/Nodes/Node.tsx | 2 +- web/src/components/Nodes/node.css | 45 ++++++ web/src/components/Nodes/widgets/Prompts.tsx | 55 +++---- web/src/components/Nodes/widgets/Widgets.tsx | 149 +++++++++++++------ web/src/graph.ts | 8 +- web/src/hooks/Prompts.ts | 25 ++-- web/src/utils.ts | 8 + 10 files changed, 332 insertions(+), 100 deletions(-) diff --git a/graphbook/prompts.py b/graphbook/prompts.py index fd2880d..04b4b24 100644 --- a/graphbook/prompts.py +++ b/graphbook/prompts.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, List from .note import Note from .utils import transform_json_log @@ -24,8 +24,135 @@ def bool_prompt( default: bool = False, show_images: bool = False, ): - default = "Yes" if default else "No" + """ + Prompt the user with a yes/no for binary questions. + + Args: + note (Note): The current note that triggered the prompt. + msg (str): An informative message or inquiry to display to the user. + style (str): The style of the bool prompt. Can be "yes/no" or "switch". + default (bool): The default bool value. + show_images (bool): Whether to present the images (instead of the Note object) to the user. + """ p = prompt(note, msg=msg, default=default, show_images=show_images) p["type"] = "bool" p["options"] = {"style": style} return p + +def selection_prompt( + note: Note, + choices: List[str], + *, + msg: str = "Select an option:", + default: List[str] | str = None, + show_images: bool = False, + multiple_allowed: bool = False +): + """ + Prompt the user to select an option from a list of choices. + + Args: + note (Note): The current note that triggered the prompt. + choices (List[str]): A list of strings representing the options the user can select. + msg (str): An informative message or inquiry to display to the user. + default (List[str] | str): The default value. If multiple_allowed is True, this should be a list of strings. + show_images (bool): Whether to present the images (instead of the Note object) to the user. + multiple_allowed (bool): Whether the user can select multiple options from the list of given choices. + """ + assert len(choices) > 0, "Choices must not be empty in selection prompt." + if default is None: + if multiple_allowed: + default = [choices[0]] + else: + default = choices[0] + + p = prompt(note, msg=msg, default=default, show_images=show_images) + p["type"] = "selection" + p["options"] = { + "choices": choices, + "multiple_allowed": multiple_allowed, + } + return p + +def text_prompt( + note: Note, + *, + msg: str = "Enter text:", + default: str = "", + show_images: bool = False, +): + """ + Prompt the user to enter text. + + Args: + note (Note): The current note that triggered the prompt. + msg (str): An informative message or inquiry to display to the user. + default (str): The default text value. + show_images (bool): Whether to present the images (instead of the Note object) to the user. + """ + p = prompt(note, msg=msg, default=default, show_images=show_images) + p["type"] = "string" + return p + +def number_prompt( + note: Note, + *, + msg: str = "Enter a number:", + default: float = 0.0, + show_images: bool = False, +): + """ + Prompt the user to enter a number. + + Args: + note (Note): The current note that triggered the prompt. + msg (str): An informative message or inquiry to display to the user. + default (float): The default number value. + show_images (bool): Whether to present the images (instead of the Note object) to the user. + """ + p = prompt(note, msg=msg, default=default, show_images=show_images) + p["type"] = "number" + return p + +def dict_prompt( + note: Note, + *, + msg: str = "Enter a dictionary:", + default: dict = {}, + show_images: bool = False, +): + """ + Prompt the user to enter a dictionary. + + Args: + note (Note): The current note that triggered the prompt. + msg (str): An informative message or inquiry to display to the user. + default (dict): The default dictionary value. + show_images (bool): Whether to present the images (instead of the Note object) to the user. + """ + p = prompt(note, msg=msg, default=default, show_images=show_images) + p["type"] = "dict" + return p + +def list_prompt( + note: Note, + type: str = "string", + *, + msg: str = "Enter a list:", + default: list = [], + show_images: bool = False, +): + """ + Prompt the user to enter a list. + + Args: + note (Note): The current note that triggered the prompt. + type (str): The type of the list elements. Can be "string", "number", "dict", or "bool". + msg (str): An informative message or inquiry to display to the user. + default (list): The default list value. + show_images (bool): Whether to present the images (instead of the Note object) to the user. + """ + assert type in ["string", "number", "dict", "bool"], "Invalid type in list prompt." + p = prompt(note, msg=msg, default=default, show_images=show_images) + p["type"] = f"list[{type}]" + return p diff --git a/graphbook/steps/base.py b/graphbook/steps/base.py index cf65155..4c1d992 100644 --- a/graphbook/steps/base.py +++ b/graphbook/steps/base.py @@ -523,6 +523,11 @@ def handle_prompt_response(self, response: Any): self._is_awaiting_response = False self._awaiting_note = None prompt(prompts.none()) + + def on_clear(self): + self._is_awaiting_response = False + self._awaiting_note = None + prompt(prompts.none()) def get_prompt(self, note: Note) -> dict: return prompts.bool_prompt(note) diff --git a/web/src/components/Flow.tsx b/web/src/components/Flow.tsx index 1f6e174..8e463ac 100644 --- a/web/src/components/Flow.tsx +++ b/web/src/components/Flow.tsx @@ -118,7 +118,7 @@ export default function Flow({ filename }) { return onNodesChange(newChanges); } onNodesChange(changes); - }, []); + }, [runState]); const onEdgesChangeCallback = useCallback((changes) => { setIsAddNodeActive(false); @@ -135,7 +135,7 @@ export default function Flow({ filename }) { return onEdgesChange(newChanges); } onEdgesChange(changes); - }, []); + }, [runState]); const onConnect = useCallback((params) => { const targetNode = nodes.find(n => n.id === params.target); diff --git a/web/src/components/Nodes/Node.tsx b/web/src/components/Nodes/Node.tsx index 065e457..743f2c3 100644 --- a/web/src/components/Nodes/Node.tsx +++ b/web/src/components/Nodes/Node.tsx @@ -158,7 +158,7 @@ export function WorkflowStep({ id, data, selected }) { if (isWidgetType(parameter.type)) { return (
- +
); } diff --git a/web/src/components/Nodes/node.css b/web/src/components/Nodes/node.css index 420a6d9..6416aaa 100644 --- a/web/src/components/Nodes/node.css +++ b/web/src/components/Nodes/node.css @@ -64,6 +64,51 @@ textarea.code { outline: none; } +.workflow-node .widgets .ant-select-selection-item { + font-size: .6em; + line-height: 1.2; + padding-inline-end: 0px; + padding: 0; +} + +.workflow-node .widgets .ant-select-arrow { + font-size: .6em; +} + +.workflow-node .widgets .ant-select-selector { + padding: 0 4px; + line-height: 1; +} + +.workflow-node .widgets .ant-select-selector::after { + line-height: unset; +} + +.workflow-node .widgets .ant-select-arrow { + display: none; +} + +.workflow-node .widgets .ant-select { + height: unset; +} + +.workflow-node .widgets input.ant-select-selection-search-input { + height: unset; +} + +.workflow-node .widgets .ant-select-selection-item-remove svg { + width: 8px; + height: 8px; +} + +.workflow-node .widgets .ant-select-selection-item-content { + line-height: 1; +} + +.workflow-node .widgets .ant-select-selection-item { + height: 10px; +} + .workflow-node .collapsed { display: flex; flex-direction: row; diff --git a/web/src/components/Nodes/widgets/Prompts.tsx b/web/src/components/Nodes/widgets/Prompts.tsx index e3befdd..4849dfe 100644 --- a/web/src/components/Nodes/widgets/Prompts.tsx +++ b/web/src/components/Nodes/widgets/Prompts.tsx @@ -1,38 +1,16 @@ -import React, { useCallback, useMemo, useState, useEffect } from 'react'; +import React, { useCallback, useMemo, useState } from 'react'; import { Typography, Flex, Button } from 'antd'; import { usePluginWidgets } from '../../../hooks/Plugins'; import { NotePreview } from './NotePreview'; -import { NumberWidget, StringWidget, BooleanWidget, FunctionWidget, DictWidget, ListWidget } from './Widgets'; +import { ListWidget, getWidgetLookup } from './Widgets'; import { useAPI } from '../../../hooks/API'; import { usePrompt } from '../../../hooks/Prompts'; +import type { Prompt } from '../../../hooks/Prompts'; +import { parseDictWidgetValue } from '../../../utils'; const { Text } = Typography; -export type PromptProps = { - stepId: string, - note: any, - msg: string, - type: string, - def: any, - show_images?: boolean, - options?: any, -}; -const getWidgetLookup = (pluginWidgets) => { - const lookup = { - number: NumberWidget, - string: StringWidget, - boolean: BooleanWidget, - bool: BooleanWidget, - function: FunctionWidget, - dict: DictWidget, - }; - pluginWidgets.forEach((widget) => { - lookup[widget.type] = widget.children; - }); - return lookup; -}; - -function Widget({ type, options, value, onChange }) { +function WidgetPrompt({ type, options, value, onChange }) { const pluginWidgets = usePluginWidgets(); const widgets = useMemo(() => { return getWidgetLookup(pluginWidgets); @@ -43,16 +21,21 @@ function Widget({ type, options, value, onChange }) { } if (widgets[type]) { - return widgets[type]({ name: "Answer", def: value, onChange, style: options.style }); + return widgets[type]({ name: "Answer", def: value, onChange, ...options }); } } - export function Prompt({ nodeId }: { nodeId: string }) { const API = useAPI(); - const [prompt, setSubmitted] = usePrompt(nodeId); - const [value, setValue] = useState(null); + const [value, setValue] = useState(null); const [loading, setLoading] = useState(false); + + const onPromptChange = useCallback((prompt) => { + setValue(prompt?.def); + }, []); + + const [prompt, setSubmitted] = usePrompt(nodeId, onPromptChange); + const onChange = useCallback((value) => { setValue(value); }, []); @@ -60,7 +43,11 @@ export function Prompt({ nodeId }: { nodeId: string }) { const onSubmit = useCallback(async () => { if (API) { setLoading(true); - await API.respondToPrompt(nodeId, value); + let answer = value; + if (prompt?.type === 'dict') { + answer = parseDictWidgetValue(answer); + } + await API.respondToPrompt(nodeId, answer); setLoading(false); setSubmitted(); } @@ -73,9 +60,9 @@ export function Prompt({ nodeId }: { nodeId: string }) { return ( Prompted: - + {prompt.msg} - + ); diff --git a/web/src/components/Nodes/widgets/Widgets.tsx b/web/src/components/Nodes/widgets/Widgets.tsx index 9fd09af..751f701 100644 --- a/web/src/components/Nodes/widgets/Widgets.tsx +++ b/web/src/components/Nodes/widgets/Widgets.tsx @@ -1,6 +1,6 @@ -import { Switch, Typography, theme, Flex, Button, Radio } from 'antd'; +import { Switch, Typography, theme, Flex, Button, Radio, Select as ASelect } from 'antd'; import { PlusOutlined, MinusOutlined } from '@ant-design/icons'; -import React, { useCallback, useState, useMemo } from 'react'; +import React, { useCallback, useState, useMemo, useEffect } from 'react'; import CodeMirror from '@uiw/react-codemirror'; import { python } from '@codemirror/lang-python'; import { basicDark } from '@uiw/codemirror-theme-basic'; @@ -11,7 +11,7 @@ import { usePluginWidgets } from '../../../hooks/Plugins'; const { Text } = Typography; const { useToken } = theme; -const getWidgetLookup = (pluginWidgets) => { +export const getWidgetLookup = (pluginWidgets) => { const lookup = { number: NumberWidget, string: StringWidget, @@ -19,14 +19,16 @@ const getWidgetLookup = (pluginWidgets) => { bool: BooleanWidget, function: FunctionWidget, dict: DictWidget, + selection: SelectionWidget, }; + pluginWidgets.forEach((widget) => { lookup[widget.type] = widget.children; }); return lookup; }; -export function Widget({ id, type, name, value }) { +export function Widget({ id, type, name, value, ...props }) { const { setNodes } = useReactFlow(); const pluginWidgets = usePluginWidgets(); const widgets = useMemo(() => { @@ -44,7 +46,7 @@ export function Widget({ id, type, name, value }) { } if (widgets[type]) { - return widgets[type]({ name, def: value, onChange }); + return widgets[type]({ name, def: value, onChange, ...props }); } } @@ -71,7 +73,7 @@ export function BooleanWidget({ name, def, onChange, style }) { true: "Yes", false: "No", }; - return onChange(M[e.target.value])} value={M_[def]} optionType="button" /> + return onChange(M[e.target.value])} value={M_[def]} optionType="button" /> } return }, [style, def]); @@ -162,10 +164,10 @@ export function ListWidget({ name, def, onChange, type }) { - ) + ); } -export function DictWidget({ name, def, onChange, type }) { +export function DictWidget({ name, def, onChange }) { const { token } = useToken(); const value = useMemo(() => { if (!def) { @@ -224,10 +226,10 @@ export function DictWidget({ name, def, onChange, type }) { const selectedInputs = useMemo(() => { return value.map((item, i) => { if (item[0] === 'string') { - return onValueChange(i, value)} value={item[2]} /> + return onValueChange(i, value)} value={item[2]} /> } if (item[0] === 'float' || item[0] === 'int' || item[0] === 'number') { - return onValueChange(i, value)} value={item[2]} /> + return onValueChange(i, value)} value={item[2]} /> } if (item[0] === 'boolean') { return onValueChange(i, value)} value={item[2]} /> @@ -244,66 +246,113 @@ export function DictWidget({ name, def, onChange, type }) { }, []); return ( - + {name} {value && value.map((item, i) => { return ( - - + + onKeyChange(i, value)} value={item[1]} /> - - - - { - selectedInputs[i] - } -