Skip to content

Commit

Permalink
fix #131 for real this time
Browse files Browse the repository at this point in the history
  • Loading branch information
painebenjamin committed Jan 23, 2024
1 parent 88120e7 commit 1d94cb6
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 8 deletions.
11 changes: 8 additions & 3 deletions src/python/enfugue/api/controller/invocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,9 +177,14 @@ def invoke_engine(self, request: Request, response: Response) -> Dict[str, Any]:
request.parsed.pop("inpainter_size", None)
request.parsed.pop("inpainter_vae", None)

vae = request.parsed.pop("vae", None)
if vae is not None:
plan_kwargs["vae"] = self.check_find_model("vae", vae)
for vae_key in ["vae", "inpainter_vae", "refiner_vae"]:
maybe_vae = request.parsed.pop(vae_key, None)
if maybe_vae is not None:
plan_kwargs[vae_key] = self.check_find_model("vae", maybe_vae)

motion_module = request.parsed.pop("motion_module", None)
if motion_module is not None:
plan_kwargs["motion_module"] = self.check_find_model("motion", motion_module)

lora = request.parsed.pop("lora", [])
plan_kwargs["lora"] = self.check_find_adaptations("lora", True, lora) if lora else None
Expand Down
13 changes: 8 additions & 5 deletions src/python/enfugue/diffusion/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,21 +267,24 @@ def check_download_model(self, local_dir: str, remote_url: str) -> str:
Downloads a model directly to the model folder if enabled.
"""
if not remote_url.startswith("http"):
return remote_url
output_file = get_file_name_from_url(remote_url)
output_file = os.path.basename(remote_url)
else:
output_file = get_file_name_from_url(remote_url)

output_path = os.path.join(local_dir, output_file)

for directory in [local_dir, self.engine_root]:
found_path = find_file_in_directory(
self.engine_root,
directory,
os.path.splitext(output_file)[0],
extensions = [".ckpt", ".bin", ".pt", ".pth", ".safetensors"]
)

if found_path:
return found_path

if self.offline:
if not remote_url.startswith("http"):
raise ValueError(f"Resource '{remote_url}' is not a URL and cannot be found on-disk.")
elif self.offline:
raise ValueError(f"File {output_file} does not exist in {local_dir} and offline mode is enabled, refusing to download from {remote_url}")

file_label = "{0} from {1}".format(
Expand Down

0 comments on commit 1d94cb6

Please sign in to comment.