diff --git a/daras_ai_v2/img_model_settings_widgets.py b/daras_ai_v2/img_model_settings_widgets.py index 7f62fb878..733465cd1 100644 --- a/daras_ai_v2/img_model_settings_widgets.py +++ b/daras_ai_v2/img_model_settings_widgets.py @@ -6,7 +6,6 @@ InpaintingModels, Img2ImgModels, ControlNetModels, - controlnet_model_explanations, Schedulers, ) @@ -130,9 +129,7 @@ def controlnet_settings( if not models: return - if extra_explanations is None: - extra_explanations = {} - explanations = controlnet_model_explanations | extra_explanations + extra_explanations = extra_explanations or {} state_values = gui.session_state.get("controlnet_conditioning_scale", []) new_values = [] @@ -157,7 +154,9 @@ def controlnet_settings( pass new_values.append( controlnet_weight_setting( - selected_controlnet_model=model, explanations=explanations, key=key + selected_controlnet_model=model, + extra_explanations=extra_explanations, + key=key, ), ) gui.session_state["controlnet_conditioning_scale"] = new_values @@ -166,13 +165,13 @@ def controlnet_settings( def controlnet_weight_setting( *, selected_controlnet_model: str, - explanations: dict[ControlNetModels, str], + extra_explanations: dict[ControlNetModels, str], key: str = "controlnet_conditioning_scale", ): model = ControlNetModels[selected_controlnet_model] return gui.slider( label=f""" - {explanations[model]}. + {extra_explanations.get(model, model.explanation)}. """, key=key, min_value=CONTROLNET_CONDITIONING_SCALE_RANGE[0], diff --git a/daras_ai_v2/stable_diffusion.py b/daras_ai_v2/stable_diffusion.py index f59b044c3..99d3e6608 100644 --- a/daras_ai_v2/stable_diffusion.py +++ b/daras_ai_v2/stable_diffusion.py @@ -13,6 +13,7 @@ resize_img_fit, get_downscale_factor, ) +from daras_ai_v2.custom_enum import GooeyEnum from daras_ai_v2.exceptions import ( raise_for_status, UserError, @@ -119,47 +120,68 @@ def _deprecated(cls): } -class ControlNetModels(Enum): - sd_controlnet_canny = "Canny" - sd_controlnet_depth = "Depth" - sd_controlnet_hed = "HED Boundary" - sd_controlnet_mlsd = "M-LSD Straight Line" - sd_controlnet_normal = "Normal Map" - sd_controlnet_openpose = "Human Pose" - sd_controlnet_scribble = "Scribble" - sd_controlnet_seg = "Image Segmentation" - sd_controlnet_tile = "Tiling" - sd_controlnet_brightness = "Brightness" - control_v1p_sd15_qrcode_monster_v2 = "QR Monster V2" - - -controlnet_model_explanations = { - ControlNetModels.sd_controlnet_canny: "Canny edge detection", - ControlNetModels.sd_controlnet_depth: "Depth estimation", - ControlNetModels.sd_controlnet_hed: "HED edge detection", - ControlNetModels.sd_controlnet_mlsd: "M-LSD straight line detection", - ControlNetModels.sd_controlnet_normal: "Normal map estimation", - ControlNetModels.sd_controlnet_openpose: "Human pose estimation", - ControlNetModels.sd_controlnet_scribble: "Scribble", - ControlNetModels.sd_controlnet_seg: "Image segmentation", - ControlNetModels.sd_controlnet_tile: "Tiling: to preserve small details", - ControlNetModels.sd_controlnet_brightness: "Brightness: to increase contrast naturally", - ControlNetModels.control_v1p_sd15_qrcode_monster_v2: "QR Monster: make beautiful QR codes that still scan with a controlnet specifically trained for this purpose", -} +class ControlNetModel(typing.NamedTuple): + label: str + model_id: str + explanation: str -controlnet_model_ids = { - ControlNetModels.sd_controlnet_canny: "lllyasviel/sd-controlnet-canny", - ControlNetModels.sd_controlnet_depth: "lllyasviel/sd-controlnet-depth", - ControlNetModels.sd_controlnet_hed: "lllyasviel/sd-controlnet-hed", - ControlNetModels.sd_controlnet_mlsd: "lllyasviel/sd-controlnet-mlsd", - ControlNetModels.sd_controlnet_normal: "lllyasviel/sd-controlnet-normal", - ControlNetModels.sd_controlnet_openpose: "lllyasviel/sd-controlnet-openpose", - ControlNetModels.sd_controlnet_scribble: "lllyasviel/sd-controlnet-scribble", - ControlNetModels.sd_controlnet_seg: "lllyasviel/sd-controlnet-seg", - ControlNetModels.sd_controlnet_tile: "lllyasviel/control_v11f1e_sd15_tile", - ControlNetModels.sd_controlnet_brightness: "ioclab/control_v1p_sd15_brightness", - ControlNetModels.control_v1p_sd15_qrcode_monster_v2: "monster-labs/control_v1p_sd15_qrcode_monster/v2", -} + +class ControlNetModels(ControlNetModel, GooeyEnum): + sd_controlnet_canny = ControlNetModel( + label="Canny", + explanation="Canny edge detection", + model_id="lllyasviel/sd-controlnet-canny", + ) + sd_controlnet_depth = ControlNetModel( + label="Depth", + explanation="Depth estimation", + model_id="lllyasviel/sd-controlnet-depth", + ) + sd_controlnet_hed = ControlNetModel( + label="HED Boundary", + explanation="HED edge detection", + model_id="lllyasviel/sd-controlnet-hed", + ) + sd_controlnet_mlsd = ControlNetModel( + label="M-LSD Straight Line", + explanation="M-LSD straight line detection", + model_id="lllyasviel/sd-controlnet-mlsd", + ) + sd_controlnet_normal = ControlNetModel( + label="Normal Map", + explanation="Normal map estimation", + model_id="lllyasviel/sd-controlnet-normal", + ) + sd_controlnet_openpose = ControlNetModel( + label="Human Pose", + explanation="Human pose estimation", + model_id="lllyasviel/sd-controlnet-openpose", + ) + sd_controlnet_scribble = ControlNetModel( + label="Scribble", + explanation="Scribble", + model_id="lllyasviel/sd-controlnet-scribble", + ) + sd_controlnet_seg = ControlNetModel( + label="Image Segmentation", + explanation="Image segmentation", + model_id="lllyasviel/sd-controlnet-seg", + ) + sd_controlnet_tile = ControlNetModel( + label="Tiling", + explanation="Tiling: to preserve small details", + model_id="lllyasviel/control_v11f1e_sd15_tile", + ) + sd_controlnet_brightness = ControlNetModel( + label="Brightness", + explanation="Brightness: to increase contrast naturally", + model_id="ioclab/control_v1p_sd15_brightness", + ) + control_v1p_sd15_qrcode_monster_v2 = ControlNetModel( + label="QR Monster V2", + explanation="QR Monster: make beautiful QR codes that still scan with a controlnet specifically trained for this purpose", + model_id="monster-labs/control_v1p_sd15_qrcode_monster/v2", + ) class Schedulers(models.TextChoices): @@ -463,8 +485,7 @@ def controlnet( ), "disable_safety_checker": True, "controlnet_model_id": [ - controlnet_model_ids[ControlNetModels[model]] - for model in selected_controlnet_model + ControlNetModels[model].model_id for model in selected_controlnet_model ], }, inputs={ diff --git a/recipes/Img2Img.py b/recipes/Img2Img.py index 7cb9283a7..12d2ede32 100644 --- a/recipes/Img2Img.py +++ b/recipes/Img2Img.py @@ -48,9 +48,7 @@ class RequestModel(BasePage.RequestModel): selected_model: typing.Literal[tuple(e.name for e in Img2ImgModels)] | None selected_controlnet_model: ( - list[typing.Literal[tuple(e.name for e in ControlNetModels)]] - | typing.Literal[tuple(e.name for e in ControlNetModels)] - | None + list[ControlNetModels.api_enum] | ControlNetModels.api_enum | None ) negative_prompt: str | None diff --git a/recipes/QRCodeGenerator.py b/recipes/QRCodeGenerator.py index 8a8d4921f..b57905885 100644 --- a/recipes/QRCodeGenerator.py +++ b/recipes/QRCodeGenerator.py @@ -90,18 +90,14 @@ class RequestModel(BasePage.RequestModel): text_prompt: str negative_prompt: str | None image_prompt: str | None - image_prompt_controlnet_models: ( - list[typing.Literal[tuple(e.name for e in ControlNetModels)], ...] | None - ) + image_prompt_controlnet_models: list[ControlNetModels.api_enum] | None image_prompt_strength: float | None image_prompt_scale: float | None image_prompt_pos_x: float | None image_prompt_pos_y: float | None selected_model: typing.Literal[tuple(e.name for e in Text2ImgModels)] | None - selected_controlnet_model: ( - list[typing.Literal[tuple(e.name for e in ControlNetModels)], ...] | None - ) + selected_controlnet_model: list[ControlNetModels.api_enum] | None output_width: int | None output_height: int | None