Skip to content

Commit

Permalink
Use GooeyEnum for ControlNetModels
Browse files Browse the repository at this point in the history
  • Loading branch information
nikochiko committed Sep 13, 2024
1 parent bee12c4 commit c8f296d
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 58 deletions.
13 changes: 6 additions & 7 deletions daras_ai_v2/img_model_settings_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
InpaintingModels,
Img2ImgModels,
ControlNetModels,
controlnet_model_explanations,
Schedulers,
)

Expand Down Expand Up @@ -130,9 +129,7 @@ def controlnet_settings(
if not models:
return

if extra_explanations is None:
extra_explanations = {}
explanations = controlnet_model_explanations | extra_explanations
extra_explanations = extra_explanations or {}

state_values = gui.session_state.get("controlnet_conditioning_scale", [])
new_values = []
Expand All @@ -157,7 +154,9 @@ def controlnet_settings(
pass
new_values.append(
controlnet_weight_setting(
selected_controlnet_model=model, explanations=explanations, key=key
selected_controlnet_model=model,
extra_explanations=extra_explanations,
key=key,
),
)
gui.session_state["controlnet_conditioning_scale"] = new_values
Expand All @@ -166,13 +165,13 @@ def controlnet_settings(
def controlnet_weight_setting(
*,
selected_controlnet_model: str,
explanations: dict[ControlNetModels, str],
extra_explanations: dict[ControlNetModels, str],
key: str = "controlnet_conditioning_scale",
):
model = ControlNetModels[selected_controlnet_model]
return gui.slider(
label=f"""
{explanations[model]}.
{extra_explanations.get(model, model.explanation)}.
""",
key=key,
min_value=CONTROLNET_CONDITIONING_SCALE_RANGE[0],
Expand Down
105 changes: 63 additions & 42 deletions daras_ai_v2/stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
resize_img_fit,
get_downscale_factor,
)
from daras_ai_v2.custom_enum import GooeyEnum
from daras_ai_v2.exceptions import (
raise_for_status,
UserError,
Expand Down Expand Up @@ -119,47 +120,68 @@ def _deprecated(cls):
}


class ControlNetModels(Enum):
    """ControlNet model choices.

    Each member's *name* is the identifier used for lookups
    (``ControlNetModels[name]``); each member's *value* is the short
    human-readable label for that model.
    """

    sd_controlnet_canny = "Canny"
    sd_controlnet_depth = "Depth"
    sd_controlnet_hed = "HED Boundary"
    sd_controlnet_mlsd = "M-LSD Straight Line"
    sd_controlnet_normal = "Normal Map"
    sd_controlnet_openpose = "Human Pose"
    sd_controlnet_scribble = "Scribble"
    sd_controlnet_seg = "Image Segmentation"
    sd_controlnet_tile = "Tiling"
    sd_controlnet_brightness = "Brightness"
    control_v1p_sd15_qrcode_monster_v2 = "QR Monster V2"


# Maps every ControlNetModels member to a short explanation of what that
# model does. The settings widget merges caller-supplied overrides on top of
# this mapping and shows the resulting text as the label of the
# conditioning-scale slider.
controlnet_model_explanations = {
ControlNetModels.sd_controlnet_canny: "Canny edge detection",
ControlNetModels.sd_controlnet_depth: "Depth estimation",
ControlNetModels.sd_controlnet_hed: "HED edge detection",
ControlNetModels.sd_controlnet_mlsd: "M-LSD straight line detection",
ControlNetModels.sd_controlnet_normal: "Normal map estimation",
ControlNetModels.sd_controlnet_openpose: "Human pose estimation",
ControlNetModels.sd_controlnet_scribble: "Scribble",
ControlNetModels.sd_controlnet_seg: "Image segmentation",
ControlNetModels.sd_controlnet_tile: "Tiling: to preserve small details",
ControlNetModels.sd_controlnet_brightness: "Brightness: to increase contrast naturally",
ControlNetModels.control_v1p_sd15_qrcode_monster_v2: "QR Monster: make beautiful QR codes that still scan with a controlnet specifically trained for this purpose",
}
class ControlNetModel(typing.NamedTuple):
    """Immutable metadata record describing a single ControlNet model."""

    # Short human-readable name of the model (e.g. "Canny").
    label: str
    # Identifier string sent to the image pipeline as `controlnet_model_id`
    # (e.g. "lllyasviel/sd-controlnet-canny").
    model_id: str
    # Longer description; used as the default slider label text in the
    # ControlNet settings widget.
    explanation: str

# Maps every ControlNetModels member to the model identifier string that is
# passed to the image pipeline as `controlnet_model_id`.
# NOTE(review): the values look like Hugging Face repo paths — confirm.
controlnet_model_ids = {
ControlNetModels.sd_controlnet_canny: "lllyasviel/sd-controlnet-canny",
ControlNetModels.sd_controlnet_depth: "lllyasviel/sd-controlnet-depth",
ControlNetModels.sd_controlnet_hed: "lllyasviel/sd-controlnet-hed",
ControlNetModels.sd_controlnet_mlsd: "lllyasviel/sd-controlnet-mlsd",
ControlNetModels.sd_controlnet_normal: "lllyasviel/sd-controlnet-normal",
ControlNetModels.sd_controlnet_openpose: "lllyasviel/sd-controlnet-openpose",
ControlNetModels.sd_controlnet_scribble: "lllyasviel/sd-controlnet-scribble",
ControlNetModels.sd_controlnet_seg: "lllyasviel/sd-controlnet-seg",
ControlNetModels.sd_controlnet_tile: "lllyasviel/control_v11f1e_sd15_tile",
ControlNetModels.sd_controlnet_brightness: "ioclab/control_v1p_sd15_brightness",
ControlNetModels.control_v1p_sd15_qrcode_monster_v2: "monster-labs/control_v1p_sd15_qrcode_monster/v2",
}

class ControlNetModels(ControlNetModel, GooeyEnum):
    """ControlNet model choices, each carrying its own metadata.

    Every member's value is a ``ControlNetModel`` record, so ``label``,
    ``model_id`` and ``explanation`` are read directly off the member, e.g.
    ``ControlNetModels[name].model_id``. This keeps the display label, the
    pipeline model id and the explanation text for a model in one place
    instead of in separate lookup dicts keyed by the enum.

    NOTE(review): request models reference ``ControlNetModels.api_enum``;
    that attribute is expected to be provided by ``GooeyEnum`` — confirm in
    ``daras_ai_v2.custom_enum``.
    """

    sd_controlnet_canny = ControlNetModel(
        label="Canny",
        explanation="Canny edge detection",
        model_id="lllyasviel/sd-controlnet-canny",
    )
    sd_controlnet_depth = ControlNetModel(
        label="Depth",
        explanation="Depth estimation",
        model_id="lllyasviel/sd-controlnet-depth",
    )
    sd_controlnet_hed = ControlNetModel(
        label="HED Boundary",
        explanation="HED edge detection",
        model_id="lllyasviel/sd-controlnet-hed",
    )
    sd_controlnet_mlsd = ControlNetModel(
        label="M-LSD Straight Line",
        explanation="M-LSD straight line detection",
        model_id="lllyasviel/sd-controlnet-mlsd",
    )
    sd_controlnet_normal = ControlNetModel(
        label="Normal Map",
        explanation="Normal map estimation",
        model_id="lllyasviel/sd-controlnet-normal",
    )
    sd_controlnet_openpose = ControlNetModel(
        label="Human Pose",
        explanation="Human pose estimation",
        model_id="lllyasviel/sd-controlnet-openpose",
    )
    sd_controlnet_scribble = ControlNetModel(
        label="Scribble",
        explanation="Scribble",
        model_id="lllyasviel/sd-controlnet-scribble",
    )
    sd_controlnet_seg = ControlNetModel(
        label="Image Segmentation",
        explanation="Image segmentation",
        model_id="lllyasviel/sd-controlnet-seg",
    )
    sd_controlnet_tile = ControlNetModel(
        label="Tiling",
        explanation="Tiling: to preserve small details",
        model_id="lllyasviel/control_v11f1e_sd15_tile",
    )
    sd_controlnet_brightness = ControlNetModel(
        label="Brightness",
        explanation="Brightness: to increase contrast naturally",
        model_id="ioclab/control_v1p_sd15_brightness",
    )
    control_v1p_sd15_qrcode_monster_v2 = ControlNetModel(
        label="QR Monster V2",
        explanation="QR Monster: make beautiful QR codes that still scan with a controlnet specifically trained for this purpose",
        model_id="monster-labs/control_v1p_sd15_qrcode_monster/v2",
    )


class Schedulers(models.TextChoices):
Expand Down Expand Up @@ -463,8 +485,7 @@ def controlnet(
),
"disable_safety_checker": True,
"controlnet_model_id": [
controlnet_model_ids[ControlNetModels[model]]
for model in selected_controlnet_model
ControlNetModels[model].model_id for model in selected_controlnet_model
],
},
inputs={
Expand Down
4 changes: 1 addition & 3 deletions recipes/Img2Img.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,7 @@ class RequestModel(BasePage.RequestModel):

selected_model: typing.Literal[tuple(e.name for e in Img2ImgModels)] | None
selected_controlnet_model: (
list[typing.Literal[tuple(e.name for e in ControlNetModels)]]
| typing.Literal[tuple(e.name for e in ControlNetModels)]
| None
list[ControlNetModels.api_enum] | ControlNetModels.api_enum | None
)
negative_prompt: str | None

Expand Down
8 changes: 2 additions & 6 deletions recipes/QRCodeGenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,18 +90,14 @@ class RequestModel(BasePage.RequestModel):
text_prompt: str
negative_prompt: str | None
image_prompt: str | None
image_prompt_controlnet_models: (
list[typing.Literal[tuple(e.name for e in ControlNetModels)], ...] | None
)
image_prompt_controlnet_models: list[ControlNetModels.api_enum] | None
image_prompt_strength: float | None
image_prompt_scale: float | None
image_prompt_pos_x: float | None
image_prompt_pos_y: float | None

selected_model: typing.Literal[tuple(e.name for e in Text2ImgModels)] | None
selected_controlnet_model: (
list[typing.Literal[tuple(e.name for e in ControlNetModels)], ...] | None
)
selected_controlnet_model: list[ControlNetModels.api_enum] | None

output_width: int | None
output_height: int | None
Expand Down

0 comments on commit c8f296d

Please sign in to comment.