diff --git a/src/python/enfugue/api/controller/invocation.py b/src/python/enfugue/api/controller/invocation.py index e516bb76..b9434357 100644 --- a/src/python/enfugue/api/controller/invocation.py +++ b/src/python/enfugue/api/controller/invocation.py @@ -177,9 +177,14 @@ def invoke_engine(self, request: Request, response: Response) -> Dict[str, Any]: request.parsed.pop("inpainter_size", None) request.parsed.pop("inpainter_vae", None) - vae = request.parsed.pop("vae", None) - if vae is not None: - plan_kwargs["vae"] = self.check_find_model("vae", vae) + for vae_key in ["vae", "inpainter_vae", "refiner_vae"]: + maybe_vae = request.parsed.pop(vae_key, None) + if maybe_vae is not None: + plan_kwargs[vae_key] = self.check_find_model("vae", maybe_vae) + + motion_module = request.parsed.pop("motion_module", None) + if motion_module is not None: + plan_kwargs["motion_module"] = self.check_find_model("motion", motion_module) lora = request.parsed.pop("lora", []) plan_kwargs["lora"] = self.check_find_adaptations("lora", True, lora) if lora else None diff --git a/src/python/enfugue/diffusion/manager.py b/src/python/enfugue/diffusion/manager.py index f0ffbd1e..fe3634db 100644 --- a/src/python/enfugue/diffusion/manager.py +++ b/src/python/enfugue/diffusion/manager.py @@ -267,21 +267,24 @@ def check_download_model(self, local_dir: str, remote_url: str) -> str: Downloads a model directly to the model folder if enabled. """ if not remote_url.startswith("http"): - return remote_url - output_file = get_file_name_from_url(remote_url) + output_file = os.path.basename(remote_url) + else: + output_file = get_file_name_from_url(remote_url) + output_path = os.path.join(local_dir, output_file) for directory in [local_dir, self.engine_root]: found_path = find_file_in_directory( - self.engine_root, + directory, os.path.splitext(output_file)[0], extensions = [".ckpt", ".bin", ".pt", ".pth", ".safetensors"] ) - if found_path: return found_path - if self.offline: + if not remote_url.startswith("http"): + raise ValueError(f"Resource '{remote_url}' is not a URL and cannot be found on-disk.") + elif self.offline: raise ValueError(f"File {output_file} does not exist in {local_dir} and offline mode is enabled, refusing to download from {remote_url}") file_label = "{0} from {1}".format(