Improve model detection #61 #57 #50
* Be more strict about LCM LoRA names to avoid matching similarly named LoRAs
* Match the complete search pattern as a substring (custom parent folders are still allowed)
* Log which models are found and which are missing
* Log available models and search paths if a required model isn't found
* Prioritize models that are in a "krita" subfolder
Acly committed Nov 19, 2023
1 parent 8d0102a commit e35eb60
Showing 3 changed files with 65 additions and 29 deletions.
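The commit message describes the new matching rule: a model file is picked when its path contains one of the plugin's search patterns as a complete substring, and among multiple candidates a file inside a "krita" subfolder wins. A minimal sketch of that rule (illustrative only, not the plugin's exact code; `pick_model` is a hypothetical helper, `model_list` stands for the paths reported by ComfyUI and `search_paths` for the patterns defined in `resources.py`):

```python
from typing import Optional, Sequence


def pick_model(model_list: Sequence[str], search_paths: Sequence[str]) -> Optional[str]:
    """Return a model whose path contains one of the search patterns entirely."""
    sanitize = lambda m: m.replace("\\", "/").lower()
    candidates = [m for m in model_list if any(p in sanitize(m) for p in search_paths)]
    # Prefer files placed in a "krita" subfolder over other matches.
    candidates.sort(key=lambda m: 0 if "krita" in sanitize(m) else 1)
    return candidates[0] if candidates else None


# Hypothetical folder layout, for illustration:
models = [
    "lcm_weights/other-lora.safetensors",   # similar name, no longer matches
    "custom/lcm-lora-sdv1-5.safetensors",   # matches despite the custom folder
    "krita/lcm-lora-sdv1-5.safetensors",    # matches and is preferred
]
print(pick_model(models, ["lcm-lora-sdv1-5.safetensors"]))
# -> krita/lcm-lora-sdv1-5.safetensors
```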
84 changes: 56 additions & 28 deletions ai_diffusion/client.py
@@ -157,7 +157,7 @@ async def connect(url=default_url):

         # Retrieve CLIPVision models
         cv = nodes["CLIPVisionLoader"]["input"]["required"]["clip_name"][0]
-        client.clip_vision_model = _find_clip_vision_model(cv, "SD1.5")
+        client.clip_vision_model = _find_clip_vision_model(cv)

         # Retrieve IP-Adapter model
         ip = nodes["IPAdapterModelLoader"]["input"]["required"]["ipadapter_file"][0]
@@ -425,37 +425,62 @@ def filter_supported_styles(styles: Styles, client: Optional[Client] = None):
     return list(styles)


-def _find_control_model(model_list: Sequence[str], mode: ControlMode):
-    def match_filename(path: str, name: str):
-        path_sep = "\\" if is_windows else "/"
-        return path.startswith(name) or path.split(path_sep)[-1].startswith(name)
+def _find_model(
+    kind: ResourceKind,
+    sdver: SDVersion,
+    model_list: Sequence[str],
+    search_paths: Sequence[str],
+    info: ControlMode | None = None,
+):
+    sanitize = lambda m: m.replace("\\", "/").lower()
+    matches = (m for m in model_list if any(p in sanitize(m) for p in search_paths))
+    # if there are multiple matches, prefer the one with "krita" in the path
+    prio = sorted(matches, key=lambda m: 0 if "krita" in m else 1)
+    model_name = next(iter(prio), None)
+
+    model_id = kind.value if not info else f"{kind.value} {info.name}"
+    is_optional = kind is ResourceKind.controlnet and not (
+        sdver is SDVersion.sd15 and info in [ControlMode.inpaint, ControlMode.blur]
+    )
+    if model_name is None and not is_optional:
+        log.warning(f"Missing {model_id} for {sdver.value}")
+        log.info(f"Available {model_id}s: {', '.join(sanitize(m) for m in model_list)}")
+        log.info(f"No model matches {model_id} search paths: {', '.join(search_paths)}")
+    elif model_name is None:
+        log.info(
+            f"Optional {model_id} for {sdver.value} not found (search path:"
+            f" {', '.join(search_paths)})"
+        )
+    else:
+        log.info(f"Found {model_id} for {sdver.value}: {model_name}")
+    return model_name

-    def find(name: Union[str, list, None]):
+
+def _find_control_model(model_list: Sequence[str], mode: ControlMode):
+    def find(sdver: SDVersion):
+        name = mode.filenames(sdver)
         if name is None:
             return None
         names = [name] if isinstance(name, str) else name
-        matches_name = lambda model: any(match_filename(model, name) for name in names)
-        model = next((model for model in model_list if matches_name(model)), None)
-        return model
+        return _find_model(ResourceKind.controlnet, sdver, model_list, names, mode)

-    return {version: find(mode.filenames(version)) for version in [SDVersion.sd15, SDVersion.sdxl]}
+    return {version: find(version) for version in [SDVersion.sd15, SDVersion.sdxl]}


-def _find_clip_vision_model(model_list: Sequence[str], sdver: str):
-    assert sdver == "SD1.5", "Using SD1.5 clip vision model also for SDXL IP-adapter"
-    model_name = "model."
-    match = lambda x: sdver in x and model_name in x
-    model = next((m for m in model_list if match(m)), None)
+def _find_clip_vision_model(model_list: Sequence[str]):
+    search_paths = ["sd1.5/pytorch_model.bin", "sd1.5/model.safetensors"]
+    model = _find_model(ResourceKind.clip_vision, SDVersion.all, model_list, search_paths)
     if model is None:
-        full_name = f"{sdver}/model.safetensors"
-        raise MissingResource(ResourceKind.clip_vision, [full_name])
+        raise MissingResource(ResourceKind.clip_vision, ["SD1.5/model.safetensors"])
     return model


 def _find_ip_adapter(model_list: Sequence[str], sdver: SDVersion):
-    model_name = "ip-adapter_sd15" if sdver is SDVersion.sd15 else "ip-adapter_sdxl_vit-h"
-    model = next((m for m in model_list if model_name in m), None)
-    return model
+    search_paths = {
+        SDVersion.sd15: ["ip-adapter_sd15"],
+        SDVersion.sdxl: ["ip-adapter_sdxl_vit-h"],
+    }[sdver]
+    return _find_model(ResourceKind.ip_adapter, sdver, model_list, search_paths)


def _find_upscaler(model_list: Sequence[str], model_name: str):
@@ -466,14 +491,17 @@ def _find_upscaler(model_list: Sequence[str], model_name: str):


 def _find_lcm(model_list: Sequence[str], sdver: SDVersion):
-    verstr = ["sd15", "sdv1-5", "sd1.5"] if sdver is SDVersion.sd15 else ["sdxl"]
-    base = ["lcm_lora", "lcm-lora", "pytorch_lora_weights"]
-
-    def match_path(path: str):
-        path = path.lower()
-        return any(x in path for x in verstr) and any(x in path for x in base) and "lcm" in path
-
-    return next((m for m in model_list if match_path(m)), None)
+    search_paths = {
+        SDVersion.sd15: [
+            "lcm-lora-sdv1-5.safetensors",
+            "lcm/sd1.5/pytorch_lora_weights.safetensors",
+        ],
+        SDVersion.sdxl: [
+            "lcm-lora-sdxl.safetensors",
+            "lcm/sdxl/pytorch_lora_weights.safetensors",
+        ],
+    }[sdver]
+    return _find_model(ResourceKind.lcm_lora, sdver, model_list, search_paths)


def _ensure_supported_style(client: Client):
2 changes: 1 addition & 1 deletion ai_diffusion/resources.py
@@ -401,7 +401,7 @@ def filenames(self, sd_version: SDVersion):
     },
     ControlMode.pose: {
         SDVersion.sd15: ["control_v11p_sd15_openpose", "control_lora_rank128_v11p_sd15_openpose"],
-        SDVersion.sdxl: ["control-lora-openposeXL2-rank", "thibaud_xl_openpose"],
+        SDVersion.sdxl: ["control-lora-openposexl2-rank", "thibaud_xl_openpose"],
     },
     ControlMode.segmentation: {
         SDVersion.sd15: ["control_v11p_sd15_seg", "control_lora_rank128_v11p_sd15_seg"],
8 changes: 8 additions & 0 deletions doc/comfy-requirements.md
@@ -35,3 +35,11 @@ The following checkpoints are used by the default styles:
 * [JuggernautXL](https://civitai.com/api/download/models/198530)

 At least one checkpoint is required, but it doesn't have to be one of the above.
+
+## Troubleshooting
+If you get errors about missing resources, or about the workload not being installed, it is probably because one of the models wasn't found.
+You can find the `client.log` file in the `.logs` folder where you installed the plugin. Check the log for warnings: it also lists
+which models were found in your installation and the search patterns the plugin looks for.
+
+A model path must contain one of the search patterns in full to match. The path may be longer than the pattern: models placed
+in arbitrary subfolders are still found. If there are multiple matches, files inside a `krita` subfolder are prioritized.
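To make the matching rule concrete, a quick Python check using one of the LCM search patterns added in this commit (the file names below are hypothetical):

```python
pattern = "lcm/sd1.5/pytorch_lora_weights.safetensors"  # search pattern from this commit

# A custom parent folder is fine, because the full pattern is still contained:
ok = "loras/my-stuff/lcm/sd1.5/pytorch_lora_weights.safetensors"
# Only part of the pattern is contained, so this file is not picked up:
bad = "loras/pytorch_lora_weights.safetensors"

print(pattern in ok.lower())   # True
print(pattern in bad.lower())  # False
```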
