From 7aca2b0549cfbd9fe96bdcf97fc346067029ed83 Mon Sep 17 00:00:00 2001 From: Acly Date: Tue, 17 Oct 2023 23:55:11 +0200 Subject: [PATCH] Exclude control images from img2img input --- ai_diffusion/document.py | 19 ++++++++++++------- ai_diffusion/ui/model.py | 12 ++++++++++-- 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/ai_diffusion/document.py b/ai_diffusion/document.py index 83c7e03654..0832e6317d 100644 --- a/ai_diffusion/document.py +++ b/ai_diffusion/document.py @@ -67,19 +67,24 @@ def create_mask_from_selection(self, grow: float, feather: float, padding: float data = selection.pixelData(*bounds) return Mask(bounds, data) - def get_image(self, bounds: Bounds | None = None, exclude_layer: krita.Node | None = None): - restore_layer = False - if exclude_layer and exclude_layer.visible(): - exclude_layer.setVisible(False) + def get_image( + self, bounds: Bounds | None = None, exclude_layers: list[krita.Node] | None = None + ): + excluded: list[krita.Node] = [] + if exclude_layers: + for layer in filter(lambda l: l.visible(), exclude_layers): + layer.setVisible(False) + excluded.append(layer) + if len(excluded) > 0: # This is quite slow and blocks the UI. Maybe async spinning on tryBarrierLock works? self._doc.refreshProjection() - restore_layer = True bounds = bounds or Bounds(0, 0, self._doc.width(), self._doc.height()) img = QImage(self._doc.pixelData(*bounds), *bounds.extent, QImage.Format_ARGB32) - if exclude_layer and restore_layer: - exclude_layer.setVisible(True) + for layer in excluded: + layer.setVisible(True) + if len(excluded) > 0: self._doc.refreshProjection() return Image(img) diff --git a/ai_diffusion/ui/model.py b/ai_diffusion/ui/model.py index bbbc95dfd4..b9ea76faf0 100644 --- a/ai_diffusion/ui/model.py +++ b/ai_diffusion/ui/model.py @@ -3,7 +3,7 @@ from collections import deque from datetime import datetime from enum import Enum, Flag -from typing import Deque, List, Sequence, NamedTuple, Optional, Callable +from typing import Deque, Optional, cast from PyQt5.QtCore import Qt, QObject, pyqtSignal from .. import eventloop, Document, workflow, NetworkError, settings, util @@ -170,7 +170,7 @@ def generate(self): ) image_bounds = workflow.compute_bounds(extent, mask.bounds if mask else None, self.strength) if mask is not None or self.strength < 1.0: - image = self._doc.get_image(image_bounds, exclude_layer=self._layer) + image = self._get_current_image(image_bounds) control = [self._get_control_image(c, image_bounds) for c in self.control] conditioning = Conditioning(self.prompt, control) @@ -217,6 +217,14 @@ async def _generate( self.jobs.add(job_id, conditioning.prompt, bounds) self.changed.emit() + def _get_current_image(self, bounds: Bounds): + exclude = [ # exclude control inputs + cast(krita.Node, c.image) for c in self.control if c.mode is not ControlMode.image + ] + if self._layer: # exclude preview layer + exclude.append(self._layer) + return self._doc.get_image(bounds, exclude_layers=exclude) + def _get_control_image(self, control: Control, bounds: Optional[Bounds]): if control.mode is ControlMode.image: bounds = None # ignore mask bounds, use layer bounds