Skip to content

Commit

Permalink
Send text from custom workflows to Krita #1285
Browse files Browse the repository at this point in the history
  • Loading branch information
Acly committed Oct 25, 2024
1 parent 180b004 commit b9a1d8a
Show file tree
Hide file tree
Showing 6 changed files with 217 additions and 17 deletions.
23 changes: 17 additions & 6 deletions ai_diffusion/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,30 @@ class ClientEvent(Enum):
queued = 6
upload = 7
published = 8
output = 9


class TextOutput(NamedTuple):
key: str
name: str
text: str
mime: str


class SharedWorkflow(NamedTuple):
publisher: str
workflow: dict


ClientOutput = dict | SharedWorkflow | TextOutput


class ClientMessage(NamedTuple):
event: ClientEvent
job_id: str = ""
progress: float = 0
images: ImageCollection | None = None
result: dict | SharedWorkflow | None = None
result: ClientOutput | None = None
error: str | None = None


Expand Down Expand Up @@ -69,11 +85,6 @@ def parse(data: dict):
return DeviceInfo("cpu", "unknown", 0)


class SharedWorkflow(NamedTuple):
publisher: str
workflow: dict


class CheckpointInfo(NamedTuple):
filename: str
arch: Arch
Expand Down
37 changes: 32 additions & 5 deletions ai_diffusion/comfy_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from .api import WorkflowInput
from .client import Client, CheckpointInfo, ClientMessage, ClientEvent, DeviceInfo, ClientModels
from .client import SharedWorkflow, TranslationPackage, ClientFeatures
from .client import SharedWorkflow, TranslationPackage, ClientFeatures, TextOutput
from .client import filter_supported_styles, loras_to_upload
from .files import FileFormat
from .image import Image, ImageCollection
Expand Down Expand Up @@ -308,10 +308,13 @@ async def _listen_websocket(self, websocket: websockets_client.WebSocketClientPr
log.error(f"Received message {msg} but there is no active job")

if msg["type"] == "executed":
job = self._get_active_job(msg["data"]["prompt_id"])
pose_json = _extract_pose_json(msg)
if job and pose_json:
result = pose_json
if job := self._get_active_job(msg["data"]["prompt_id"]):
text_output = _extract_text_output(job.local_id, msg)
if text_output is not None:
await self._messages.put(text_output)
pose_json = _extract_pose_json(msg)
if pose_json is not None:
result = pose_json

if msg["type"] == "execution_error":
job = self._get_active_job(msg["data"]["prompt_id"])
Expand Down Expand Up @@ -733,3 +736,27 @@ def _extract_pose_json(msg: dict):
except Exception as e:
log.warning(f"Error processing message, error={str(e)}, msg={msg}")
return None


def _extract_text_output(job_id: str, msg: dict):
try:
output = msg["data"]["output"]
if output is not None and "text" in output:
key = msg["data"].get("node")
payload = output["text"]
name, text, mime = (None, None, "text/plain")
if isinstance(payload, list) and len(payload) >= 1:
payload = payload[0]
if isinstance(payload, dict):
text = payload.get("text")
name = payload.get("name")
mime = payload.get("content-type", mime)
elif isinstance(payload, str):
text = payload
name = f"Node {key}"
if text is not None and name is not None:
result = TextOutput(key, name, text, mime)
return ClientMessage(ClientEvent.output, job_id, result=result)
except Exception as e:
log.warning(f"Error processing message, error={str(e)}, msg={msg}")
return None
17 changes: 17 additions & 0 deletions ai_diffusion/custom_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from PyQt5.QtCore import pyqtSignal

from .api import WorkflowInput
from .client import TextOutput, ClientOutput
from .comfy_workflow import ComfyWorkflow, ComfyNode
from .connection import Connection, ConnectionState
from .image import Bounds, Image
Expand Down Expand Up @@ -323,6 +324,7 @@ class CustomWorkspace(QObject, ObservableProperties):
mode = Property(CustomGenerationMode.regular, setter="_set_mode")
is_live = Property(False, setter="toggle_live")
has_result = Property(False)
outputs = Property({})

workflow_id_changed = pyqtSignal(str)
graph_changed = pyqtSignal()
Expand All @@ -331,6 +333,7 @@ class CustomWorkspace(QObject, ObservableProperties):
is_live_changed = pyqtSignal(bool)
result_available = pyqtSignal(Image)
has_result_changed = pyqtSignal(bool)
outputs_changed = pyqtSignal(dict)
modified = pyqtSignal(QObject, str)

_live_poll_rate = 0.1
Expand All @@ -345,6 +348,7 @@ def __init__(self, workflows: WorkflowCollection, generator: ImageGenerator, job
self._last_input: WorkflowInput | None = None
self._last_result: Image | None = None
self._last_job: JobParams | None = None
self._new_outputs: list[str] = []

jobs.job_finished.connect(self._handle_job_finished)
workflows.dataChanged.connect(self._update_workflow)
Expand Down Expand Up @@ -463,7 +467,20 @@ def collect_parameters(self, layers: "LayerManager", bounds: Bounds):

return params

def show_output(self, output: ClientOutput | None):
if isinstance(output, TextOutput):
self._new_outputs.append(output.key)
self.outputs[output.key] = output
self.outputs_changed.emit(self.outputs)

def _handle_job_finished(self, job: Job):
to_remove = [k for k in self.outputs.keys() if k not in self._new_outputs]
for key in to_remove:
del self.outputs[key]
if len(to_remove) > 0:
self.outputs_changed.emit(self.outputs)
self._new_outputs.clear()

if job.kind is JobKind.live_preview:
if len(job.results) > 0:
self._last_result = job.results[0]
Expand Down
6 changes: 4 additions & 2 deletions ai_diffusion/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from .settings import ApplyBehavior, settings
from .network import NetworkError
from .image import Extent, Image, Mask, Bounds, DummyImage
from .client import ClientMessage, ClientEvent, SharedWorkflow
from .client import ClientMessage, ClientEvent, ClientOutput
from .client import filter_supported_styles, resolve_arch
from .custom_workflow import CustomWorkspace, WorkflowCollection, CustomGenerationMode
from .document import Document, KritaDocument
Expand Down Expand Up @@ -461,6 +461,8 @@ def handle_message(self, message: ClientMessage):
self.jobs.notify_started(job)
self.progress_kind = ProgressKind.upload
self.progress = message.progress
elif message.event is ClientEvent.output:
self.custom.show_output(message.result)
elif message.event is ClientEvent.finished:
if message.images:
self.jobs.set_results(job, message.images)
Expand Down Expand Up @@ -604,7 +606,7 @@ def apply_generated_result(self, job_id: str, index: int):
self.jobs.selection = None
self.jobs.notify_used(job_id, index)

def add_control_layer(self, job: Job, result: dict | SharedWorkflow | None):
def add_control_layer(self, job: Job, result: ClientOutput | None):
assert job.kind is JobKind.control_layer and job.control
if job.control.mode is ControlMode.pose and isinstance(result, (dict, list)):
pose = Pose.from_open_pose_json(result)
Expand Down
103 changes: 101 additions & 2 deletions ai_diffusion/ui/custom_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
from PyQt5.QtWidgets import QComboBox, QFileDialog, QFrame, QGridLayout, QHBoxLayout, QMenu
from PyQt5.QtWidgets import QLabel, QLineEdit, QListWidgetItem, QMessageBox, QSpinBox, QAction
from PyQt5.QtWidgets import QToolButton, QVBoxLayout, QWidget, QSlider, QDoubleSpinBox
from PyQt5.QtWidgets import QScrollArea, QTextEdit, QSizePolicy

from ..custom_workflow import CustomParam, ParamKind, SortedWorkflows, WorkflowSource
from ..custom_workflow import CustomGenerationMode
from ..client import TextOutput
from ..jobs import JobKind
from ..model import Model, ApplyBehavior
from ..properties import Binding, Bind, bind, bind_combo
Expand All @@ -22,6 +24,7 @@
from .live import LivePreviewArea
from .switch import SwitchWidget
from .widget import TextPromptWidget, WorkspaceSelectWidget, StyleSelectWidget
from .settings_widgets import ExpanderButton
from . import theme


Expand Down Expand Up @@ -384,6 +387,93 @@ def value(self, values: dict[str, Any]):
widget.value = value


class WorkflowOutputsWidget(QWidget):
def __init__(self, parent: QWidget):
super().__init__(parent)
self._value: dict[str, TextOutput] = {}

self._scroll_area = QScrollArea(self)
self._scroll_area.setFrameShape(QFrame.Shape.NoFrame)
self._scroll_area.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff)

self.expander = ExpanderButton(_("Text Output"), self)
self.expander.setStyleSheet("QToolButton { border: none; }")
self.expander.setChecked(True)
self.expander.toggled.connect(self._scroll_area.setVisible)

layout = QVBoxLayout(self)
layout.setContentsMargins(0, 0, 0, 0)
layout.addWidget(self.expander)
layout.addWidget(self._scroll_area)

@property
def value(self):
return self._value

@value.setter
def value(self, value: dict[str, TextOutput]):
self._value = value
self._update()

def _update(self):
if len(self._value) == 0:
self.expander.hide()
self._scroll_area.hide()
return
elif not self.expander.isVisible():
self.expander.show()
self._scroll_area.show()

widget = QWidget(self._scroll_area)
layout = QGridLayout()
layout.setContentsMargins(0, 0, 0, 0)
layout.setColumnMinimumWidth(1, 8)
layout.setColumnStretch(2, 1)
widget.setLayout(layout)

line = 0
text_areas: list[QTextEdit] = []
for output in self._value.values():
label = QLabel(output.name, widget)
if (not output.mime or output.mime == "text/plain") and len(output.text) < 40:
value = QLabel(output.text, widget)
value.setWordWrap(True)
value.setMinimumWidth(40)
layout.addWidget(label, line, 0)
layout.addWidget(value, line, 2)
line += 1
else:
value = QTextEdit(widget)
value.setFrameShape(QFrame.Shape.StyledPanel)
value.setStyleSheet(
"QTextEdit { background: transparent; border-left: 1px solid %s; padding-left: 2px; }"
% theme.line
)
value.setReadOnly(True)
match output.mime:
case "" | "text/plain":
value.setPlainText(output.text)
case "text/html":
value.setHtml(output.text)
case "text/markdown":
value.setMarkdown(output.text)
layout.addWidget(label, line, 0, 1, 3)
layout.addWidget(value, line + 1, 0, 1, 3)
text_areas.append(value)
line += 2

layout.setRowStretch(line, 1)
widget.setFixedWidth(self._scroll_area.width() - 8)
self._scroll_area.setWidget(widget)
if self.expander.isChecked():
widget.show()

for w in text_areas:
size = ensure(w.document()).size().toSize()
w.setFixedHeight(max(size.height() + 2, self.fontMetrics().height() + 6))
widget.adjustSize()


def popup_on_error(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
Expand Down Expand Up @@ -487,6 +577,9 @@ def __init__(self):
self._progress_bar = ProgressBar(self)
self._error_text = create_error_label(self)

self._outputs = WorkflowOutputsWidget(self)
self._outputs.expander.toggled.connect(self._update_layout)

self._history = HistoryWidget(self)
self._history.item_activated.connect(self.apply_result)

Expand Down Expand Up @@ -525,12 +618,17 @@ def __init__(self):
self._layout.addLayout(actions_layout)
self._layout.addWidget(self._progress_bar)
self._layout.addWidget(self._error_text)
self._layout.addWidget(self._history)
self._layout.addWidget(self._live_preview)
self._layout.addWidget(self._outputs, stretch=1)
self._layout.addWidget(self._history, stretch=3)
self._layout.addWidget(self._live_preview, stretch=5)
self.setLayout(self._layout)

self._update_ui()

def _update_layout(self):
stretch = 1 if self._outputs.expander.isChecked() else 0
self._layout.setStretchFactor(self._outputs, stretch)

@property
def model(self):
return self._model
Expand All @@ -543,6 +641,7 @@ def model(self, model: Model):
self._model_bindings = [
bind(model, "workspace", self._workspace_select, "value", Bind.one_way),
bind_combo(model.custom, "workflow_id", self._workflow_select, Bind.one_way),
bind(model.custom, "outputs", self._outputs, "value", Bind.one_way),
model.workspace_changed.connect(self._cancel_name),
model.custom.graph_changed.connect(self._update_current_workflow),
model.error_changed.connect(self._error_text.setText),
Expand Down
Loading

0 comments on commit b9a1d8a

Please sign in to comment.