diff --git a/src/python/enfugue/diffusion/manager.py b/src/python/enfugue/diffusion/manager.py index a0c2a1ea..8023241d 100644 --- a/src/python/enfugue/diffusion/manager.py +++ b/src/python/enfugue/diffusion/manager.py @@ -673,7 +673,6 @@ def refiner_vae( """ Sets a new refiner vae. """ - pretrained_path = self.get_vae_path(new_vae) existing_vae = getattr(self, "_refiner_vae", None) if ( @@ -686,8 +685,9 @@ def refiner_vae( self._refiner_vae = None self.unload_refiner("VAE resetting to default") else: + vae_path = self.check_download_model(self.engine_vae_dir, new_vae) self._refiner_vae_name = new_vae - self._refiner_vae = self.get_vae(pretrained_path) + self._refiner_vae = self.get_vae(vae_path) if self.refiner_tensorrt_is_ready and "vae" in self.TENSORRT_STAGES: self.unload_refiner("VAE changing") elif hasattr(self, "_refiner_pipeline"): @@ -695,7 +695,7 @@ def refiner_vae( self._refiner_pipeline.vae = self._refiner_vae # type: ignore[assignment] if self.refiner_is_sdxl: self._refiner_pipeline.register_to_config( # type: ignore[attr-defined] - force_full_precision_vae = new_vae in ["xl", "stabilityai/sdxl-vae"] or (new_vae.endswith("sdxl_vae.safetensors") and "16" not in new_vae) + force_full_precision_vae = "xl" in new_vae and "16" not in new_vae ) @property @@ -724,7 +724,6 @@ def inpainter_vae( """ Sets a new inpainter vae. """ - pretrained_path = self.get_vae_path(new_vae) existing_vae = getattr(self, "_inpainter_vae", None) if ( @@ -737,8 +736,9 @@ def inpainter_vae( self._inpainter_vae = None self.unload_inpainter("VAE resetting to default") else: + vae_path = self.check_download_model(self.engine_vae_dir, new_vae) self._inpainter_vae_name = new_vae - self._inpainter_vae = self.get_vae(pretrained_path) + self._inpainter_vae = self.get_vae(vae_path) if self.inpainter_tensorrt_is_ready and "vae" in self.TENSORRT_STAGES: self.unload_inpainter("VAE changing") elif hasattr(self, "_inpainter_pipeline"): @@ -746,7 +746,7 @@ def inpainter_vae( self._inpainter_pipeline.vae = self._inpainter_vae # type: ignore[assignment] if self.inpainter_is_sdxl: self._inpainter_pipeline.register_to_config( # type: ignore[attr-defined] - force_full_precision_vae = new_vae in ["xl", "stabilityai/sdxl-vae"] or (new_vae.endswith("sdxl_vae.safetensors") and "16" not in new_vae) + force_full_precision_vae = "xl" in new_vae and "16" not in new_vae ) @property @@ -775,7 +775,6 @@ def animator_vae( """ Sets a new animator vae. """ - pretrained_path = self.get_vae_path(new_vae) existing_vae = getattr(self, "_animator_vae", None) if ( @@ -788,8 +787,9 @@ def animator_vae( self._animator_vae = None self.unload_animator("VAE resetting to default") else: + vae_path = self.check_download_model(self.engine_vae_dir, new_vae) self._animator_vae_name = new_vae - self._animator_vae = self.get_vae(pretrained_path) + self._animator_vae = self.get_vae(vae_path) if self.animator_tensorrt_is_ready and "vae" in self.TENSORRT_STAGES: self.unload_animator("VAE changing") elif hasattr(self, "_animator_pipeline"): @@ -797,7 +797,7 @@ def animator_vae( self._animator_pipeline.vae = self._animator_vae # type: ignore [assignment] if self.animator_is_sdxl: self._animator_pipeline.register_to_config( # type: ignore[attr-defined] - force_full_precision_vae = new_vae in ["xl", "stabilityai/sdxl-vae"] + force_full_precision_vae = "xl" in new_vae and "16" not in new_vae ) @property