Skip to content

Commit

Permalink
Exclude control images from img2img input
Browse files Browse the repository at this point in the history
  • Loading branch information
Acly committed Oct 17, 2023
1 parent 7ec6ee9 commit 7aca2b0
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 9 deletions.
19 changes: 12 additions & 7 deletions ai_diffusion/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,19 +67,24 @@ def create_mask_from_selection(self, grow: float, feather: float, padding: float
data = selection.pixelData(*bounds)
return Mask(bounds, data)

def get_image(self, bounds: Bounds | None = None, exclude_layer: krita.Node | None = None):
restore_layer = False
if exclude_layer and exclude_layer.visible():
exclude_layer.setVisible(False)
def get_image(
self, bounds: Bounds | None = None, exclude_layers: list[krita.Node] | None = None
):
excluded: list[krita.Node] = []
if exclude_layers:
for layer in filter(lambda l: l.visible(), exclude_layers):
layer.setVisible(False)
excluded.append(layer)
if len(excluded) > 0:
# This is quite slow and blocks the UI. Maybe async spinning on tryBarrierLock works?
self._doc.refreshProjection()
restore_layer = True

bounds = bounds or Bounds(0, 0, self._doc.width(), self._doc.height())
img = QImage(self._doc.pixelData(*bounds), *bounds.extent, QImage.Format_ARGB32)

if exclude_layer and restore_layer:
exclude_layer.setVisible(True)
for layer in excluded:
layer.setVisible(True)
if len(excluded) > 0:
self._doc.refreshProjection()
return Image(img)

Expand Down
12 changes: 10 additions & 2 deletions ai_diffusion/ui/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from collections import deque
from datetime import datetime
from enum import Enum, Flag
from typing import Deque, List, Sequence, NamedTuple, Optional, Callable
from typing import Deque, Optional, cast
from PyQt5.QtCore import Qt, QObject, pyqtSignal

from .. import eventloop, Document, workflow, NetworkError, settings, util
Expand Down Expand Up @@ -170,7 +170,7 @@ def generate(self):
)
image_bounds = workflow.compute_bounds(extent, mask.bounds if mask else None, self.strength)
if mask is not None or self.strength < 1.0:
image = self._doc.get_image(image_bounds, exclude_layer=self._layer)
image = self._get_current_image(image_bounds)

control = [self._get_control_image(c, image_bounds) for c in self.control]
conditioning = Conditioning(self.prompt, control)
Expand Down Expand Up @@ -217,6 +217,14 @@ async def _generate(
self.jobs.add(job_id, conditioning.prompt, bounds)
self.changed.emit()

def _get_current_image(self, bounds: Bounds):
exclude = [ # exclude control inputs
cast(krita.Node, c.image) for c in self.control if c.mode is not ControlMode.image
]
if self._layer: # exclude preview layer
exclude.append(self._layer)
return self._doc.get_image(bounds, exclude_layers=exclude)

def _get_control_image(self, control: Control, bounds: Optional[Bounds]):
if control.mode is ControlMode.image:
bounds = None # ignore mask bounds, use layer bounds
Expand Down

0 comments on commit 7aca2b0

Please sign in to comment.