From e35eb60bff15bb9961e74a633c1a1517a722ac9c Mon Sep 17 00:00:00 2001 From: Acly Date: Sun, 19 Nov 2023 12:42:38 +0100 Subject: [PATCH] Improve model detection #61 #57 #50 * Be more strict about LCM lora names to avoid similarly named loras * Match complete substring (still allows custom folders) * Log which models are found and which are missing * Log available models and search paths if a required model isn't found * Prioritize models that are in a "krita" subfolder --- ai_diffusion/client.py | 84 ++++++++++++++++++++++++++------------- ai_diffusion/resources.py | 2 +- doc/comfy-requirements.md | 8 ++++ 3 files changed, 65 insertions(+), 29 deletions(-) diff --git a/ai_diffusion/client.py b/ai_diffusion/client.py index f770a5666..4a006e15f 100644 --- a/ai_diffusion/client.py +++ b/ai_diffusion/client.py @@ -157,7 +157,7 @@ async def connect(url=default_url): # Retrieve CLIPVision models cv = nodes["CLIPVisionLoader"]["input"]["required"]["clip_name"][0] - client.clip_vision_model = _find_clip_vision_model(cv, "SD1.5") + client.clip_vision_model = _find_clip_vision_model(cv) # Retrieve IP-Adapter model ip = nodes["IPAdapterModelLoader"]["input"]["required"]["ipadapter_file"][0] @@ -425,37 +425,62 @@ def filter_supported_styles(styles: Styles, client: Optional[Client] = None): return list(styles) -def _find_control_model(model_list: Sequence[str], mode: ControlMode): - def match_filename(path: str, name: str): - path_sep = "\\" if is_windows else "/" - return path.startswith(name) or path.split(path_sep)[-1].startswith(name) +def _find_model( + kind: ResourceKind, + sdver: SDVersion, + model_list: Sequence[str], + search_paths: Sequence[str], + info: ControlMode | None = None, +): + sanitize = lambda m: m.replace("\\", "/").lower() + matches = (m for m in model_list if any(p in sanitize(m) for p in search_paths)) + # if there are multiple matches, prefer the one with "krita" in the path + prio = sorted(matches, key=lambda m: 0 if "krita" in m else 1) + model_name = next(iter(prio), None) + + model_id = kind.value if not info else f"{kind.value} {info.name}" + is_optional = kind is ResourceKind.controlnet and not ( + sdver is SDVersion.sd15 and info in [ControlMode.inpaint, ControlMode.blur] + ) + if model_name is None and not is_optional: + log.warning(f"Missing {model_id} for {sdver.value}") + log.info(f"Available {model_id}s: {', '.join(sanitize(m) for m in model_list)}") + log.info(f"No model matches {model_id} search paths: {', '.join(search_paths)}") + elif model_name is None: + log.info( + f"Optional {model_id} for {sdver.value} not found (search path:" + f" {', '.join(search_paths)})" + ) + else: + log.info(f"Found {model_id} for {sdver.value}: {model_name}") + return model_name - def find(name: Union[str, list, None]): + +def _find_control_model(model_list: Sequence[str], mode: ControlMode): + def find(sdver: SDVersion): + name = mode.filenames(sdver) if name is None: return None names = [name] if isinstance(name, str) else name - matches_name = lambda model: any(match_filename(model, name) for name in names) - model = next((model for model in model_list if matches_name(model)), None) - return model + return _find_model(ResourceKind.controlnet, sdver, model_list, names, mode) - return {version: find(mode.filenames(version)) for version in [SDVersion.sd15, SDVersion.sdxl]} + return {version: find(version) for version in [SDVersion.sd15, SDVersion.sdxl]} -def _find_clip_vision_model(model_list: Sequence[str], sdver: str): - assert sdver == "SD1.5", "Using SD1.5 clip vision model also for SDXL IP-adapter" - model_name = "model." - match = lambda x: sdver in x and model_name in x - model = next((m for m in model_list if match(m)), None) +def _find_clip_vision_model(model_list: Sequence[str]): + search_paths = ["sd1.5/pytorch_model.bin", "sd1.5/model.safetensors"] + model = _find_model(ResourceKind.clip_vision, SDVersion.all, model_list, search_paths) if model is None: - full_name = f"{sdver}/model.safetensors" - raise MissingResource(ResourceKind.clip_vision, [full_name]) + raise MissingResource(ResourceKind.clip_vision, ["SD1.5/model.safetensors"]) return model def _find_ip_adapter(model_list: Sequence[str], sdver: SDVersion): - model_name = "ip-adapter_sd15" if sdver is SDVersion.sd15 else "ip-adapter_sdxl_vit-h" - model = next((m for m in model_list if model_name in m), None) - return model + search_paths = { + SDVersion.sd15: ["ip-adapter_sd15"], + SDVersion.sdxl: ["ip-adapter_sdxl_vit-h"], + }[sdver] + return _find_model(ResourceKind.ip_adapter, sdver, model_list, search_paths) def _find_upscaler(model_list: Sequence[str], model_name: str): @@ -466,14 +491,17 @@ def _find_upscaler(model_list: Sequence[str], model_name: str): def _find_lcm(model_list: Sequence[str], sdver: SDVersion): - verstr = ["sd15", "sdv1-5", "sd1.5"] if sdver is SDVersion.sd15 else ["sdxl"] - base = ["lcm_lora", "lcm-lora", "pytorch_lora_weights"] - - def match_path(path: str): - path = path.lower() - return any(x in path for x in verstr) and any(x in path for x in base) and "lcm" in path - - return next((m for m in model_list if match_path(m)), None) + search_paths = { + SDVersion.sd15: [ + "lcm-lora-sdv1-5.safetensors", + "lcm/sd1.5/pytorch_lora_weights.safetensors", + ], + SDVersion.sdxl: [ + "lcm-lora-sdxl.safetensors", + "lcm/sdxl/pytorch_lora_weights.safetensors", + ], + }[sdver] + return _find_model(ResourceKind.lcm_lora, sdver, model_list, search_paths) def _ensure_supported_style(client: Client): diff --git a/ai_diffusion/resources.py b/ai_diffusion/resources.py index b9706000f..5422f3d7b 100644 --- a/ai_diffusion/resources.py +++ b/ai_diffusion/resources.py @@ -401,7 +401,7 @@ def filenames(self, sd_version: SDVersion): }, ControlMode.pose: { SDVersion.sd15: ["control_v11p_sd15_openpose", "control_lora_rank128_v11p_sd15_openpose"], - SDVersion.sdxl: ["control-lora-openposeXL2-rank", "thibaud_xl_openpose"], + SDVersion.sdxl: ["control-lora-openposexl2-rank", "thibaud_xl_openpose"], }, ControlMode.segmentation: { SDVersion.sd15: ["control_v11p_sd15_seg", "control_lora_rank128_v11p_sd15_seg"], diff --git a/doc/comfy-requirements.md b/doc/comfy-requirements.md index 2f3c9f6bf..6ee980039 100644 --- a/doc/comfy-requirements.md +++ b/doc/comfy-requirements.md @@ -35,3 +35,11 @@ The following checkpoints are used by the default styles: * [JuggernautXL](https://civitai.com/api/download/models/198530) At least one checkpoint is required, but it doesn't have to be one of the above. + +## Troubleshooting +If you're getting errors about missing resources, or workload not being installed, it's probably because one of the models wasn't found. +You can find the `client.log` file in the `.logs` folder where you installed the plugin. Check the log for warnings. Here you will also +find which models were found in your installation, and the patterns the plugin looks for. + +Model paths must contain one of the search patterns entirely to match. The model path is allowed to be longer though: you may place models +in arbitrary subfolders and they will still be found. If there are multiple matches, any files placed inside a `krita` subfolder are prioritized. \ No newline at end of file