Merge pull request #98 from graphbookai/dev
rsamf authored Sep 26, 2024
2 parents c271e71 + af46121 commit 4f3d6cd
Showing 27 changed files with 951 additions and 105 deletions.
30 changes: 29 additions & 1 deletion README.md
@@ -5,6 +5,34 @@

<h1 align="center">Graphbook</h1>

+<p align="center">
+  <a href="https://github.com/graphbookai/graphbook/blob/main/LICENSE">
+    <img alt="GitHub License" src="https://img.shields.io/github/license/graphbookai/graphbook">
+  </a>
+  <a href="https://github.com/graphbookai/graphbook/actions/workflows/pypi.yml">
+    <img alt="GitHub Actions Workflow Status" src="https://img.shields.io/github/actions/workflow/status/graphbookai/graphbook/pypi.yml">
+  </a>
+  <a href="https://hub.docker.com/r/rsamf/graphbook">
+    <img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/rsamf/graphbook">
+  </a>
+  <a href="https://www.pepy.tech/projects/graphbook">
+    <img alt="PyPI Downloads" src="https://static.pepy.tech/badge/graphbook">
+  </a>
+  <a href="https://pypi.org/project/graphbook/">
+    <img alt="PyPI - Version" src="https://img.shields.io/pypi/v/graphbook">
+  </a>
+</p>
+<div align="center">
+  <a href="https://discord.gg/XukMUDmjnt">
+    <img alt="Join Discord" src="https://img.shields.io/badge/Join%20our%20Discord-5865F2?style=for-the-badge&logo=discord&logoColor=white">
+  </a>
+</div>
+<p align="center">
+  <a href="https://discord.gg/XukMUDmjnt">
+    <img alt="Discord" src="https://img.shields.io/discord/1199855707567177860">
+  </a>
+</p>
+
<p align="center">
The ML workflow framework
<br>
@@ -23,7 +51,7 @@
</p>

## Overview
-Graphbook is a framework for building efficient, visual DAG-structured ML workflows composed of nodes written in Python. Graphbook provides common ML processing features such as multiprocessing IO and automatic batching, and it features a web-based UI to assemble, monitor, and execute data processing workflows. It can be used to prepare training data for custom ML models, experiment with custom trained or off-the-shelf models, and to build ML-based ETL applications. Custom nodes can be built in Python, and Graphbook will behave like a framework and call lifecycle methods on those nodes.
+Graphbook is a framework for building efficient, visual DAG-structured ML workflows composed of nodes written in Python. Graphbook provides common ML processing features such as multiprocessing IO and automatic batching for PyTorch tensors, and it features a web-based UI to assemble, monitor, and execute data processing workflows. It can be used to prepare training data for custom ML models, experiment with custom trained or off-the-shelf models, and to build ML-based ETL applications. Custom nodes can be built in Python, and Graphbook will behave like a framework and call lifecycle methods on those nodes.

<p align="center">
<a href="https://graphbook.ai">
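To make the README's overview concrete, here is a minimal sketch of a custom node written with graphbook's decorator API. The node path, the "string" parameter type, and the ``ctx.greeting`` attribute access are illustrative assumptions, not taken from this diff; only ``@step``, ``@param``, ``ctx.log``, and the ``Note`` mapping appear in the code below.

from graphbook import step, param
from graphbook.note import Note

@step("Examples/Greet")  # path under which the node appears in the UI
@param("greeting", type="string", default="Hello")
def greet(ctx, note: Note):
    # Graphbook calls this lifecycle method for each Note that reaches the node
    note["message"] = f"{ctx.greeting} {note['name']}"
    ctx.log(note["message"])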
4 changes: 2 additions & 2 deletions graphbook/__init__.py
@@ -1,4 +1,4 @@
from .note import Note
-from .decorators import step, param, source, output, batch, resource, event
+from .decorators import step, param, source, output, batch, resource, event, prompt

-__all__ = ["step", "param", "source", "output", "batch", "resource", "event", "Note"]
+__all__ = ["step", "param", "source", "output", "batch", "resource", "event", "prompt", "Note"]
2 changes: 2 additions & 0 deletions graphbook/custom_nodes.py
@@ -13,6 +13,7 @@
from graphbook.steps import (
    Step,
    BatchStep,
+    PromptStep,
    SourceStep,
    GeneratorSourceStep,
    AsyncStep,
@@ -25,6 +26,7 @@
BUILT_IN_STEPS = [
    Step,
    BatchStep,
+    PromptStep,
    SourceStep,
    GeneratorSourceStep,
    AsyncStep,
72 changes: 70 additions & 2 deletions graphbook/decorators.py
@@ -31,14 +31,13 @@ def param(
"required": required,
"description": description,
}
self.parameter_type_casts[name] = cast_as
if cast_as is None:
# Default casts
if type == "function":
self.parameter_type_casts[name] = transform_function_string
if type == "int":
self.parameter_type_casts[name] = int
else:
self.parameter_type_casts[name] = cast_as

@abc.abstractmethod
def build():
@@ -85,6 +84,11 @@ def batch(
        if dump_fn is not None:
            self.event("dump_fn", dump_fn)

+    def prompt(self, get_prompt=None):
+        self.BaseClass = steps.PromptStep
+        if get_prompt is not None:
+            self.event("get_prompt", get_prompt)
+
    def build(self):
        def __init__(cls, **kwargs):
            if self.BaseClass == steps.BatchStep:
@@ -212,6 +216,8 @@ def decorator(func):
factory.event("on_note", func)
elif factory.BaseClass == steps.BatchStep:
factory.event("on_item_batch", func)
elif factory.BaseClass == steps.PromptStep:
factory.event("on_prompt_response", func)
else:
factory.event("load", func)

@@ -456,3 +462,65 @@ def wrapper(*args, **kwargs):
        return wrapper

    return decorator


+def prompt(get_prompt: callable = None):
+    """
+    Marks a function as a step that is capable of prompting the user.
+    This is useful for interactive workflows where data labeling, model evaluation, or any other human input is required.
+
+    The events ``get_prompt(ctx, note: Note)`` and ``on_prompt_response(ctx, note: Note, response: Any)`` must be implemented.
+    The decorator accepts a ``get_prompt`` function that returns a prompt to display to the user.
+    If nothing is passed as an argument, ``bool_prompt`` is used by default.
+    If the function returns **None** for a given note, no prompt is displayed for that note, allowing for conditional prompts based on the note's content.
+    Available prompts are located in the ``graphbook.prompts`` module.
+    The function that this decorator decorates serves as ``on_prompt_response`` and is called when a response to a prompt is obtained from a user.
+    Once the prompt is handled, the execution lifecycle of the Step proceeds normally.
+
+    Args:
+        get_prompt (callable): A function that returns a prompt. Default is ``bool_prompt``.
+
+    Examples:
+        .. highlight:: python
+        .. code-block:: python
+
+            def dog_or_cat(ctx, note: Note):
+                return selection_prompt(note, choices=["dog", "cat"], show_images=True)
+
+            @step("Prompts/Label")
+            @prompt(dog_or_cat)
+            def label_images(ctx, note: Note, response: str):
+                note["label"] = response
+
+            def corrective_prompt(ctx, note: Note):
+                if note["prediction_confidence"] < 0.65:
+                    return bool_prompt(
+                        note,
+                        msg=f"Model prediction ({note['pred']}) was uncertain. Is its prediction correct?",
+                        show_images=True,
+                    )
+                else:
+                    return None
+
+            @step("Prompts/CorrectModelLabel")
+            @prompt(corrective_prompt)
+            def correct_model_labels(ctx, note: Note, response: bool):
+                if response:
+                    ctx.log("Model is correct!")
+                    note["label"] = note["pred"]
+                else:
+                    ctx.log("Model is incorrect!")
+                    if note["pred"] == "dog":
+                        note["label"] = "cat"
+                    else:
+                        note["label"] = "dog"
+    """
+
+    def decorator(func):
+        def set_prompt(factory: StepClassFactory):
+            factory.prompt(get_prompt)
+
+        return DecoratorFunction(func, set_prompt)
+
+    return decorator
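The docstring's examples, restated as a self-contained script with the imports they assume. ``graphbook.prompts`` is named in the docstring and ``bool_prompt``'s arguments match its second example; the note fields here are illustrative.

from graphbook import step, prompt
from graphbook.note import Note
from graphbook.prompts import bool_prompt

def needs_review(ctx, note: Note):
    # Returning None skips the prompt for this note (conditional prompting)
    if note["prediction_confidence"] >= 0.65:
        return None
    return bool_prompt(note, msg="Is the model's prediction correct?", show_images=True)

@step("Prompts/Review")
@prompt(needs_review)
def on_review(ctx, note: Note, response: bool):
    # Called once a user answers the prompt in the web UI; afterwards the
    # Step's normal execution lifecycle resumes
    note["verified"] = response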
16 changes: 16 additions & 0 deletions graphbook/logger.py
@@ -2,6 +2,7 @@
import multiprocessing as mp
from typing import Dict, Tuple, Any
import inspect
+from graphbook.note import Note
from graphbook.viewer import ViewManagerInterface
from graphbook.utils import transform_json_log

@@ -44,3 +45,18 @@ def log(msg: Any, type: LogType = "info", caller_id: int | None = None):
    else:
        raise ValueError(f"Unknown log type {type}")
    view_manager.handle_log(node_id, msg, type)
+
+def prompt(prompt: dict, caller_id: int | None = None):
+    if caller_id is None:
+        prev_frame = inspect.currentframe().f_back
+        caller = prev_frame.f_locals.get("self")
+        if caller is not None:
+            caller_id = id(caller)
+
+    node = logging_nodes.get(caller_id, None)
+    if node is None:
+        raise ValueError(
+            f"Can't find node id in {caller}. Only initialized steps can log."
+        )
+    node_id, _ = node
+    view_manager.handle_prompt(node_id, prompt)
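The new ``prompt`` function reuses the caller-discovery trick already used by ``log``: when no ``caller_id`` is passed, it walks one frame up the stack, grabs ``self`` from the caller's locals, and uses ``id(self)`` as the key into the node registry. A standalone sketch of that pattern follows (illustrative names, not graphbook's actual registry):

import inspect

registry = {}  # id(instance) -> node id, filled in when a step is initialized

class Step:
    def __init__(self, node_id):
        registry[id(self)] = node_id

    def work(self):
        emit("working")  # no explicit caller_id: emit discovers it from the stack

def emit(msg, caller_id=None):
    if caller_id is None:
        # One frame up is the caller; its `self` identifies the step instance
        caller = inspect.currentframe().f_back.f_locals.get("self")
        if caller is not None:
            caller_id = id(caller)
    node_id = registry.get(caller_id)
    if node_id is None:
        raise ValueError("Only initialized steps can emit.")
    print(f"[{node_id}] {msg}")

Step("step-1").work()  # prints: [step-1] working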
54 changes: 40 additions & 14 deletions graphbook/processing/web_processor.py
@@ -1,4 +1,11 @@
-from graphbook.steps import Step, SourceStep, GeneratorSourceStep, AsyncStep, StepOutput
+from graphbook.steps import (
+    Step,
+    SourceStep,
+    GeneratorSourceStep,
+    AsyncStep,
+    BatchStep,
+    StepOutput,
+)
from graphbook.dataloading import Dataloader, setup_global_dl
from graphbook.utils import MP_WORKER_TIMEOUT, ProcessorStateRequest, transform_json_log
from graphbook.state import GraphState, StepState, NodeInstantiationError
@@ -49,7 +56,10 @@ def __init__(
        self.dataloader = Dataloader(self.num_workers)
        setup_global_dl(self.dataloader)
        self.state_client = ProcessorStateClient(
-            server_request_conn, close_event, self.graph_state, self.dataloader
+            server_request_conn,
+            close_event,
+            self.graph_state,
+            self.dataloader,
        )
        self.is_running = False
        self.filename = None
@@ -120,13 +130,11 @@ def exec_step(
                id(step),
            )
            return None
-
        self.handle_images(outputs)
        self.graph_state.handle_outputs(
            step.id, outputs if not self.copy_outputs else copy.deepcopy(outputs)
        )
        self.view_manager.handle_outputs(step.id, transform_json_log(outputs))
-        self.view_manager.handle_time(step.id, time.time() - start_time)
        self.view_manager.handle_time(step.id, time.time() - start_time)
        return outputs

    def handle_steps(self, steps: List[Step]) -> bool:
@@ -182,13 +190,25 @@ def step_until_received_output(self, steps: List[Step], step_id: str):
            step_executed = self.graph_state.get_state(
                step_id, StepState.EXECUTED_THIS_RUN
            )

+    def try_execute_step_event(self, step: Step, event: str):
+        try:
+            if hasattr(step, event):
+                getattr(step, event)()
+            return True
+        except Exception as e:
+            log(f"{type(e).__name__}: {str(e)}", "error", id(step))
+            traceback.print_exc()
+            return False
+
    def run(self, step_id: str = None):
        steps: List[Step] = self.graph_state.get_processing_steps(step_id)
-        self.setup_dataloader(steps)
        for step in steps:
            self.view_manager.handle_start(step.id)
-            step.on_start()
+            succeeded = self.try_execute_step_event(step, "on_start")
+            if not succeeded:
+                return
+        self.setup_dataloader(steps)
        self.pause_event.clear()
        dag_is_active = True
        try:
@@ -201,24 +221,26 @@
            dag_is_active = self.handle_steps(steps)
        finally:
            self.view_manager.handle_end()
-            for step in steps:
-                step.on_end()
            self.dataloader.stop()
+            for step in steps:
+                self.try_execute_step_event(step, "on_end")

    def step(self, step_id: str = None):
        steps: List[Step] = self.graph_state.get_processing_steps(step_id)
-        self.setup_dataloader(steps)
        for step in steps:
            self.view_manager.handle_start(step.id)
-            step.on_start()
+            succeeded = self.try_execute_step_event(step, "on_start")
+            if not succeeded:
+                return
+        self.setup_dataloader(steps)
        self.pause_event.clear()
        try:
            self.step_until_received_output(steps, step_id)
        finally:
            self.view_manager.handle_end()
-            for step in steps:
-                step.on_end()
            self.dataloader.stop()
+            for step in steps:
+                self.try_execute_step_event(step, "on_end")

    def set_is_running(self, is_running: bool = True, filename: str | None = None):
        self.is_running = is_running
@@ -232,7 +254,7 @@ def cleanup(self):
        self.dataloader.shutdown()

    def setup_dataloader(self, steps: List[Step]):
-        dataloader_consumers = [step for step in steps if isinstance(step, AsyncStep)]
+        dataloader_consumers = [step for step in steps if isinstance(step, BatchStep)]
        consumer_ids = [id(c) for c in dataloader_consumers]
        consumer_load_fn = [
            c.load_fn if hasattr(c, "load_fn") else None for c in dataloader_consumers
@@ -323,6 +345,10 @@ def _loop(self):
                output = self.dataloader.get_all_sizes()
            elif req["cmd"] == ProcessorStateRequest.GET_RUNNING_STATE:
                output = self.running_state
+            elif req["cmd"] == ProcessorStateRequest.PROMPT_RESPONSE:
+                step_id = req.get("step_id")
+                succeeded = self.graph_state.handle_prompt_response(step_id, req.get("response"))
+                output = {"ok": succeeded}
            else:
                output = {}
            entry = {"res": req["cmd"], "data": output}
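The processor changes above share one pattern: lifecycle hooks now run through try_execute_step_event, so a failing on_start aborts the run before the dataloader is even set up, while on_end is still attempted for every step from the finally block. A distilled sketch of that error-isolation pattern follows (hypothetical step objects, not the processor's real interfaces):

def try_execute_event(step, event: str) -> bool:
    # Report per-step failures instead of letting them crash the whole run
    try:
        if hasattr(step, event):
            getattr(step, event)()
        return True
    except Exception as e:
        print(f"{type(e).__name__} in {event}: {e}")
        return False

def run(steps):
    for step in steps:
        if not try_execute_event(step, "on_start"):
            return  # abort before any expensive setup or processing
    try:
        for step in steps:
            step.process()
    finally:
        # Best-effort cleanup: every step gets its on_end, even after errors
        for step in steps:
            try_execute_event(step, "on_end")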
[The remaining 21 changed files are not shown.]