-
Notifications
You must be signed in to change notification settings - Fork 13
/
StableDiffusionWebUI.py
422 lines (357 loc) · 15.1 KB
/
StableDiffusionWebUI.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
# Stablender Diffusion 0.0.1
# author: @shellworld
# license: MIT
# imports
import uuid
import bpy
import os
import sys
import requests
import json
from base64 import b64decode, b64encode
import io
from PIL import Image
import cv2
import numpy as np
from datetime import datetime
from enum import Enum
# constants
# function indices
class FunctionIndex(Enum):
    """Endpoint selector for predict(); each value is sent as the request's ``fn_index``.

    NOTE(review): these numbers presumably index the function table of one
    particular WebUI build — confirm they still match the server being called.
    """

    TEXT_TO_IMAGE = 4               # txt2img payload
    IMAGE_TO_IMAGE = 17             # img2img payload, "Crop" mode
    IMAGE_TO_IMAGE_WITH_MASK = 16   # img2img payload, "Mask" mode (image + mask)
# User settings: edit these before running the script. When driving the script
# from Blender operators, pass the operator's input values instead of these
# module-level defaults.
# Replace with the public URL of your Stable Diffusion WebUI (Gradio) instance.
url = 'https://#####.gradio.app'
prompt = "landscape painting, oil on canvas, high detail 4k"
cfg_scale = 7.5
width = 512
height = 512
steps = 50
num_batches = 1
batch_size = 1
strength = .71
# Headers sent with every /api/predict/ request, imitating a same-origin
# browser call. origin/referer are taken from `url` when this module runs, so
# set `url` above first.
# Deliberately NOT sent (they were in the original browser capture):
#   * authority/method/path/scheme: these are HTTP/2 pseudo-headers; requests
#     speaks HTTP/1.1, so they only went out as meaningless ordinary headers.
#   * "br" in accept-encoding: requests can only decode Brotli when an optional
#     brotli package is installed, so a Brotli reply could make
#     response.json() fail.
headers = {
    "accept": "*/*",
    "accept-encoding": "gzip, deflate",
    "accept-language": "en-US,en;q=0.9",
    "dnt": "1",
    "origin": f"{url}",
    "referer": f"{url}",
    "sec-ch-ua": '"Chromium";v="104", " Not A;Brand";v="99", "Google Chrome";v="104"',
    "sec-ch-ua-mobile": "?0",
    "sec-ch-ua-platform": '"Windows"',
    "sec-fetch-dest": "empty",
    "sec-fetch-mode": "cors",
    "sec-fetch-site": "same-origin",
}
# Validation limits checked by predict() before any request is sent.
# Sampling steps: accepted range is MIN_STEPS..MAX_STEPS inclusive.
MAX_STEPS: int = 400
MIN_STEPS: int = 10
# Longest accepted prompt, in characters.
MAX_PROMPT_LENGTH: int = 1000
# Accepted cfg_scale range (inclusive).
CFG_SCALE_FLOOR: float = 0.1
CFG_SCALE_CEILING: float = 10.0
# Accepted width/height range in pixels (inclusive).
SIZE_MIN: int = 512
SIZE_MAX: int = 768
# num_batches above HIGH_BATCH_NUMBER is refused unless the flag is switched on.
HIGH_BATCH_NUMBER: int = 16
HIGH_BATCH_NUMBER_ENABLED: bool = False
# Number of prompt characters kept when building output file names.
MAX_LENGTH_FILE_NAME: int = 24
# predict
# ---------------------------------------------------------------------------
# Input-validation helpers used by predict(). They raise TypeError/ValueError,
# both subclasses of Exception, so callers catching Exception keep working.
# ---------------------------------------------------------------------------
_PNG_DATA_URI_PREFIX = "data:image/png;base64,"

# Option checkboxes the request payload enables in every mode.
_WEBUI_OUTPUT_OPTIONS = (
    "Normalize Prompt Weights (ensure sum of weights add up to 1.0)",
    "Save individual images",
    "Save grid",
    "Sort samples by prompt",
    "Write sample info files",
)


def _check_required(value, name):
    """Raise ValueError when *value* is None or the empty string."""
    if value is None or value == "":
        raise ValueError(f"{name} is required")


def _check_int_in_range(value, name, low, high, unit=""):
    """Require a plain int (bool rejected) within [low, high]."""
    _check_required(value, name)
    # Exact type check: bool is an int subclass and is deliberately refused.
    if type(value) is not int:
        raise TypeError(f"{name} must be an integer")
    if not low <= value <= high:
        raise ValueError(f"{name} must be between {low}{unit} and {high}{unit}")


def _check_number_in_range(value, name, low, high):
    """Require an int/float (bool rejected) within [low, high]; return it as float.

    The chained comparison is False for NaN, so NaN is rejected too.
    """
    _check_required(value, name)
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        raise TypeError(f"{name} must be a number")
    if not low <= value <= high:
        raise ValueError(f"{name} must be between {low} and {high}")
    return float(value)


def _check_positive_int(value, name):
    """Require a plain int greater than zero."""
    _check_required(value, name)
    if type(value) is not int or value <= 0:
        raise ValueError(f"{name} must be an integer greater than zero")


def _check_png_data_uri(value, name):
    """Require a non-empty ``data:image/png;base64,...`` string."""
    _check_required(value, name)
    if not isinstance(value, str):
        raise TypeError(f"{name} must be a string, but it is currently {type(value)}")
    if not value.startswith(_PNG_DATA_URI_PREFIX):
        raise ValueError(f"{name} must be a valid base64 png datastream")


def _build_request_data(function_index, prompt, steps, cfg_scale, width, height,
                        num_batches, batch_size, image_string, mask_string, strength):
    """Return the positional ``data`` list sent to the WebUI for *function_index*.

    NOTE(review): the field order mirrors one WebUI build's form inputs —
    confirm against the server if requests start failing.
    """
    options = list(_WEBUI_OUTPUT_OPTIONS)
    if function_index == FunctionIndex.TEXT_TO_IMAGE:
        return [prompt, steps, "k_lms", options, "RealESRGAN_x4plus", 0,
                num_batches, batch_size, cfg_scale, "", width, height, None, 0, ""]
    if function_index == FunctionIndex.IMAGE_TO_IMAGE:
        mode, init_image = "Crop", image_string
    elif function_index == FunctionIndex.IMAGE_TO_IMAGE_WITH_MASK:
        mode, init_image = "Mask", {"image": image_string, "mask": mask_string}
    else:
        raise ValueError(f"unsupported function_index: {function_index}")
    return [prompt, mode, init_image, "Keep masked area", 3, steps, "k_lms",
            options, "RealESRGAN_x4plus", num_batches, batch_size, cfg_scale,
            strength, None, width, height, "Just resize", None]


def predict(prompt: str, steps: int, cfg_scale: float, width: int, height: int,
            num_batches=1, batch_size=1, image_string="", mask_string="",
            function_index=FunctionIndex.TEXT_TO_IMAGE, strength=.50, timeout=None):
    """Validate the inputs, POST them to ``{url}/api/predict/`` and return the JSON reply.

    Args:
        prompt: text prompt, at most MAX_PROMPT_LENGTH characters.
        steps: sampling steps, MIN_STEPS..MAX_STEPS inclusive.
        cfg_scale: CFG_SCALE_FLOOR..CFG_SCALE_CEILING inclusive (int or float).
        width, height: SIZE_MIN..SIZE_MAX pixels inclusive.
        num_batches: > 0; above HIGH_BATCH_NUMBER only if HIGH_BATCH_NUMBER_ENABLED.
        batch_size: must currently be 1.
        image_string: PNG data URI; required for the two img2img modes.
        mask_string: PNG data URI; required for IMAGE_TO_IMAGE_WITH_MASK.
        function_index: FunctionIndex member selecting the payload layout.
        strength: 0..1 inclusive; only checked and sent for img2img modes.
        timeout: seconds passed to ``requests.post``; None (default) waits
            indefinitely, as earlier versions did.

    Returns:
        The decoded JSON reply (see parseResults()).

    Raises:
        TypeError / ValueError: an input failed validation; nothing is sent.
        RuntimeError: the server did not answer with HTTP 200.
    """
    if not isinstance(function_index, FunctionIndex):
        raise TypeError("function_index must be a FunctionIndex member")

    _check_required(prompt, "prompt")
    if not isinstance(prompt, str):
        raise TypeError("prompt must be a string")
    if len(prompt) > MAX_PROMPT_LENGTH:
        raise ValueError(f"prompt must be at most {MAX_PROMPT_LENGTH} characters")
    _check_int_in_range(steps, "steps", MIN_STEPS, MAX_STEPS)
    cfg_scale = _check_number_in_range(cfg_scale, "cfg_scale", CFG_SCALE_FLOOR, CFG_SCALE_CEILING)
    _check_int_in_range(width, "width", SIZE_MIN, SIZE_MAX, unit="px")
    _check_int_in_range(height, "height", SIZE_MIN, SIZE_MAX, unit="px")
    _check_positive_int(num_batches, "num_batches")
    if num_batches > HIGH_BATCH_NUMBER and not HIGH_BATCH_NUMBER_ENABLED:
        raise ValueError(
            f"num_batches must be at most {HIGH_BATCH_NUMBER} or set HIGH_BATCH_NUMBER_ENABLED to True")
    _check_positive_int(batch_size, "batch_size")
    # Developer note: it might make sense to hide the batch size option in your interface.
    if batch_size > 1:
        raise ValueError("batch_size must be set to one for now.")
    if function_index in (FunctionIndex.IMAGE_TO_IMAGE, FunctionIndex.IMAGE_TO_IMAGE_WITH_MASK):
        _check_png_data_uri(image_string, "image_string")
        strength = _check_number_in_range(strength, "strength", 0, 1)
    if function_index == FunctionIndex.IMAGE_TO_IMAGE_WITH_MASK:
        _check_png_data_uri(mask_string, "mask_string")

    data = _build_request_data(function_index, prompt, steps, cfg_scale, width, height,
                               num_batches, batch_size, image_string, mask_string, strength)
    response = requests.post(url + '/api/predict/', headers=headers,
                             json={"fn_index": function_index.value, "data": data},
                             timeout=timeout)
    if response.status_code != 200:
        # Include the start of the body: the bare Response repr only shows the code.
        raise RuntimeError(f"Something went wrong: {response} {response.text[:500]}")
    return response.json()
# utility functions
# function that converts png datastream to Image
def stringToRGB(base64_string: str):
    """Decode a base64-encoded image into a PIL image.

    Accepts either a data URI (``data:image/png;base64,...``) or the bare
    base64 payload. Despite the name, the image keeps whatever mode the
    encoded file has; it is not converted to RGB.
    """
    # Everything after the first comma is the payload when a data-URI header is present.
    _, separator, encoded = base64_string.partition(",")
    if not separator:
        # No header at all: the whole string is base64 (this used to crash on unpacking).
        encoded = base64_string
    return Image.open(io.BytesIO(b64decode(encoded)))
# function that converts Image to png datastream
def rgbToString(image):
    """Encode an image as base64 PNG text, without a ``data:`` URI prefix.

    *image* may be a PIL image, or a Blender image whose ``pixels`` is a flat
    RGBA float buffer (0..1) stored bottom row first. The previous version
    called ``.reshape`` on ``image.pixels``, which fails for both kinds.
    Note: predict() expects the ``data:image/png;base64,`` prefix, which
    callers must prepend.
    """
    if isinstance(image, Image.Image):
        pil_image = image
    else:
        width, height = image.size
        rgba = np.asarray(image.pixels[:], dtype=np.float32).reshape(height, width, 4)
        # Blender's buffer starts at the bottom row; PNG rows run top-down.
        rgba = np.flipud(rgba)
        # Scale normalized floats to 8-bit so PIL can build an RGBA image.
        pil_image = Image.fromarray((np.clip(rgba, 0.0, 1.0) * 255.0 + 0.5).astype(np.uint8))
    buffer = io.BytesIO()
    pil_image.save(buffer, format='PNG')
    return b64encode(buffer.getvalue()).decode('ascii')
def parseResults(response):
    """Pull the generated images out of a /api/predict/ JSON reply.

    ``response["data"][0]`` holds the image list. A single entry is returned
    as-is; with several entries the first one is the grid image and is
    skipped, so only the individual images come back.

    Raises:
        Exception: when the image list is empty.
    """
    gallery = response["data"][0]
    count = len(gallery)
    if not count:
        raise Exception("No data returned")
    first_kept = 0 if count == 1 else 1
    return [gallery[idx] for idx in range(first_kept, count)]
def convertPNGDatastreamsToBPYImages(base64_strings):
    """Create one Blender image per base64 PNG datastream and return them.

    Image *i* is named ``image_{i}``; ``filepath_raw``/``file_format`` are set
    to ``image_{i}.png``/PNG, but nothing is written to disk here. The previous
    version replaced each decoded image with a new blank bpy image and then
    copied that image's own pixels onto itself, so all results were empty.

    Args:
        base64_strings: iterable of base64 image strings (data URIs accepted).
    """
    # Decode everything first so a bad entry fails before any bpy image is created.
    decoded = [stringToRGB(b64) for b64 in base64_strings]
    bpy_images = []
    for i, pil_image in enumerate(decoded):
        rgba_image = pil_image.convert('RGBA')
        bpy_image = bpy.data.images.new(f"image_{i}", rgba_image.size[0], rgba_image.size[1], alpha=True)
        # bpy pixels: flat, normalized RGBA floats, bottom row first -> flip and scale.
        flipped = rgba_image.transpose(Image.FLIP_TOP_BOTTOM)
        bpy_image.pixels = (np.asarray(flipped, dtype=np.float32) * (1.0 / 255.0)).ravel()
        bpy_image.filepath_raw = f"image_{i}.png"
        bpy_image.file_format = 'PNG'
        bpy_images.append(bpy_image)
    return bpy_images
def convertBPYImageToBase64PNGDataStream(referenceToBPYImage):
    """Return the file behind ``bpy.data.images[referenceToBPYImage]`` as a PNG data URI.

    The image is re-read from its file (``filepath_raw`` resolved with
    ``bpy.path.abspath``), so unsaved/generated images without a readable
    file cannot be converted. The pixels are encoded as PNG to match the
    ``data:image/png`` prefix; previously they were JPEG-encoded under a PNG label.

    Raises:
        FileNotFoundError: the image file could not be read by OpenCV.
        RuntimeError: PNG encoding failed.
    """
    selected_image = bpy.data.images[referenceToBPYImage]
    path = bpy.path.abspath(selected_image.filepath_raw)
    pixels = cv2.imread(path)
    if pixels is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f"could not read image file for '{referenceToBPYImage}': {path}")
    ok, encoded = cv2.imencode('.png', pixels)
    if not ok:
        raise RuntimeError(f"could not encode '{referenceToBPYImage}' as PNG")
    return "data:image/png;base64," + b64encode(encoded.tobytes()).decode('ascii')
def pil_to_image(pil_image, name='NewImage'):
    '''
    Copy a PIL image into a new Blender image called *name* and return it.

    PIL image pixels are rows of byte tuples (mode 'RGB'/'RGBA') or bytes
    (mode 'L'), top row first. A bpy image's pixels are a flat array of
    normalized RGBA floats starting at the bottom row, so the image is
    flipped vertically, converted to RGBA and scaled from 0..255 to 0..1.
    '''
    # Flip vertically (this is a row flip, not a colour-channel swap).
    flipped = pil_image.transpose(Image.FLIP_TOP_BOTTOM)
    bpy_image = bpy.data.images.new(name, width=flipped.width, height=flipped.height)
    byte_to_normalized = 1.0 / 255.0
    bpy_image.pixels = (np.asarray(flipped.convert('RGBA'), dtype=np.float32) * byte_to_normalized).ravel()
    return bpy_image
# decode a base64 PNG datastream and save it to disk
def saveImage(image_data, filename):
    """Decode *image_data* (base64 image text), write it to *filename*, and return it.

    The decoded PIL image is returned so callers can reuse it without re-reading the file.
    """
    decoded = stringToRGB(image_data)
    decoded.save(filename)
    return decoded
# generate a safe filename
def generateSafeNameFromPromptAndIndex(prompt, index, max_length=None):
    """Build a filesystem-safe base name ``<cleaned prompt>_<index>``.

    Spaces become "_" and commas "-" (as before). Any other character that is
    not alphanumeric, "-", "_" or "." now also becomes "_", so prompts with
    path separators or characters such as ":" or "?" cannot yield an invalid
    or nested path. The cleaned prompt is cut to *max_length* characters
    (default MAX_LENGTH_FILE_NAME) before the index is appended.
    """
    limit = MAX_LENGTH_FILE_NAME if max_length is None else max_length
    cleaned = prompt.replace(" ", "_").replace(",", "-")
    cleaned = "".join(ch if ch.isalnum() or ch in "-_." else "_" for ch in cleaned)
    return f"{cleaned[:limit]}_{index}"
## MAIN FUNCTIONS
# request an image from the WebUI, save it and load it into Blender
def requestImg(prompt, steps, cfg_scale, width, height, fn, bpy_image=None, mask_image=None,
               batch_num=1, batch_size=1, strength=.50):
    """Run one generation request, save the first result and load it into Blender.

    Args:
        prompt, steps, cfg_scale, width, height: forwarded to predict().
        fn: FunctionIndex member selecting txt2img / img2img / masked img2img.
        bpy_image: PNG data URI of the init image for the img2img modes,
            e.g. from convertBPYImageToBase64PNGDataStream().
        mask_image: PNG data URI of the mask (masked img2img only).
        batch_num: number of batches to request (previously ignored in favour
            of the module-level num_batches).
        batch_size: images per batch; predict() currently only accepts 1.
        strength: img2img strength, validated by predict() to lie in [0, 1].
            Previously it was never forwarded (predict's 0.5 default was used)
            and the old default of 50 would be rejected, hence the new default.

    Returns:
        The name shared by the saved ``<name>.png`` (written relative to the
        current working directory) and the new bpy image.
    """
    response = predict(prompt, steps, cfg_scale, width, height,
                       num_batches=batch_num, batch_size=batch_size,
                       image_string=bpy_image, mask_string=mask_image,
                       function_index=fn, strength=strength)
    # results is a list of base64 image strings; only the first is used.
    results = parseResults(response)
    name = generateSafeNameFromPromptAndIndex(prompt, 0)
    saved = saveImage(results[0], f"{name}.png")
    # Registers the image in bpy.data.images under the same name.
    pil_to_image(saved, name)
    return name
# Usage examples. The __main__ guard keeps a plain `import` of this module from
# sending a network request; running the script from Blender's Text Editor
# still executes it.
if __name__ == "__main__":
    # Text to image example
    requestImg(prompt, steps, cfg_scale, width, height, FunctionIndex.TEXT_TO_IMAGE)
    # Image to image example (replace "V" with the name of an image in bpy.data.images)
    # requestImg(prompt, steps, cfg_scale, width, height, FunctionIndex.IMAGE_TO_IMAGE,
    #            bpy_image=convertBPYImageToBase64PNGDataStream("V"), strength=strength)