From ca84adb74a2b41953d62de93192abb459d85a51a Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Thu, 26 Oct 2023 10:51:03 +0200 Subject: [PATCH] add support for latent consistency models --- optimum/intel/__init__.py | 4 + optimum/intel/openvino/__init__.py | 1 + optimum/intel/openvino/modeling_diffusion.py | 82 +++++++++++++++---- .../dummy_openvino_and_diffusers_objects.py | 13 +++ tests/openvino/test_stable_diffusion.py | 71 ++++++++++++++++ tests/openvino/utils_tests.py | 1 + 6 files changed, 157 insertions(+), 15 deletions(-) diff --git a/optimum/intel/__init__.py b/optimum/intel/__init__.py index dcd2827eab..cd93a68c66 100644 --- a/optimum/intel/__init__.py +++ b/optimum/intel/__init__.py @@ -62,6 +62,7 @@ "OVStableDiffusionInpaintPipeline", "OVStableDiffusionXLPipeline", "OVStableDiffusionXLImg2ImgPipeline", + "OVLatentConsistencyModelPipeline", ] else: _import_structure["openvino"].extend( @@ -71,6 +72,7 @@ "OVStableDiffusionInpaintPipeline", "OVStableDiffusionXLPipeline", "OVStableDiffusionXLImg2ImgPipeline", + "OVLatentConsistencyModelPipeline", ] ) @@ -162,6 +164,7 @@ OVStableDiffusionInpaintPipeline, OVStableDiffusionPipeline, OVStableDiffusionXLImg2ImgPipeline, + OVLatentConsistencyModelPipeline, OVStableDiffusionXLPipeline, ) else: @@ -170,6 +173,7 @@ OVStableDiffusionInpaintPipeline, OVStableDiffusionPipeline, OVStableDiffusionXLImg2ImgPipeline, + OVLatentConsistencyModelPipeline, OVStableDiffusionXLPipeline, ) diff --git a/optimum/intel/openvino/__init__.py b/optimum/intel/openvino/__init__.py index 6b023accda..a7bcbfbce0 100644 --- a/optimum/intel/openvino/__init__.py +++ b/optimum/intel/openvino/__init__.py @@ -56,4 +56,5 @@ OVStableDiffusionPipeline, OVStableDiffusionXLImg2ImgPipeline, OVStableDiffusionXLPipeline, + OVLatentConsistencyModelPipeline, ) diff --git a/optimum/intel/openvino/modeling_diffusion.py b/optimum/intel/openvino/modeling_diffusion.py index 1ca0b93643..bc63268bdb 100644 --- a/optimum/intel/openvino/modeling_diffusion.py +++ 
b/optimum/intel/openvino/modeling_diffusion.py @@ -42,6 +42,8 @@ from optimum.pipelines.diffusers.pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipelineMixin from optimum.pipelines.diffusers.pipeline_stable_diffusion_xl import StableDiffusionXLPipelineMixin from optimum.pipelines.diffusers.pipeline_stable_diffusion_xl_img2img import StableDiffusionXLImg2ImgPipelineMixin +from optimum.pipelines.diffusers.pipeline_latent_consistency import LatentConsistencyPipelineMixin + from optimum.pipelines.diffusers.pipeline_utils import VaeImageProcessor from optimum.utils import ( DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER, @@ -270,20 +272,8 @@ def _from_pretrained( if model_save_dir is None: model_save_dir = new_model_save_dir - return cls( - vae_decoder=components["vae_decoder"], - text_encoder=components["text_encoder"], - unet=unet, - config=config, - tokenizer=kwargs.pop("tokenizer", None), - scheduler=kwargs.pop("scheduler"), - feature_extractor=kwargs.pop("feature_extractor", None), - vae_encoder=components["vae_encoder"], - text_encoder_2=components["text_encoder_2"], - tokenizer_2=kwargs.pop("tokenizer_2", None), - model_save_dir=model_save_dir, - **kwargs, - ) + return cls(unet=unet, config=config, model_save_dir=model_save_dir, **components, **kwargs) + @classmethod def _from_transformers( @@ -377,8 +367,10 @@ def _reshape_unet( if batch_size == -1 or num_images_per_prompt == -1: batch_size = -1 else: + batch_size *= num_images_per_prompt # The factor of 2 comes from the guidance scale > 1 - batch_size = 2 * batch_size * num_images_per_prompt + if "timestep_cond" not in {inputs.get_any_name() for inputs in model.inputs}: + batch_size *= 2 height = height // self.vae_scale_factor if height > 0 else height width = width // self.vae_scale_factor if width > 0 else width @@ -402,6 +394,8 @@ def _reshape_unet( shapes[inputs] = [batch_size, self.text_encoder_2.config["projection_dim"]] elif inputs.get_any_name() == "time_ids": shapes[inputs] = [batch_size, 
class OVLatentConsistencyModelPipeline(OVStableDiffusionPipelineBase, LatentConsistencyPipelineMixin):
    """OpenVINO text-to-image pipeline for latent consistency models.

    The class validates the requested image size and batch size against the
    shapes the model was statically reshaped to (if any), then delegates the
    actual generation to `LatentConsistencyPipelineMixin.__call__`.
    """

    def __call__(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 4,
        original_inference_steps: Optional[int] = None,
        guidance_scale: float = 8.5,
        num_images_per_prompt: int = 1,
        **kwargs,
    ):
        """Generate images for `prompt` with the latent consistency model.

        Args:
            prompt: A single prompt or a list of prompts. May be omitted when
                `prompt_embeds` is passed through `kwargs`.
            height: Requested image height. Falls back to
                `unet.config["sample_size"] * vae_scale_factor` when falsy.
            width: Requested image width, with the same fallback as `height`.
            num_inference_steps: Forwarded unchanged to the mixin.
            original_inference_steps: Forwarded unchanged to the mixin.
            guidance_scale: Forwarded unchanged to the mixin.
            num_images_per_prompt: Number of images generated per prompt.
            **kwargs: Extra arguments forwarded to the mixin; `prompt_embeds`
                is also read here to infer the batch size.

        Returns:
            Whatever `LatentConsistencyPipelineMixin.__call__` returns.

        Raises:
            ValueError: If the model has a static batch size and neither a
                `str`/`list` prompt nor `prompt_embeds` was provided.
        """
        # The default size is only computed when the caller did not give one.
        height = height or self.unet.config["sample_size"] * self.vae_scale_factor
        width = width or self.unet.config["sample_size"] * self.vae_scale_factor

        # A value of -1 is treated below as "this dimension is not fixed".
        _height = self.height
        _width = self.width
        expected_batch_size = self._batch_size

        if _height != -1 and height != _height:
            logger.warning(
                f"`height` was set to {height} but the static model will output images of height {_height}. "
                "To fix the height, please reshape your model accordingly using the `.reshape()` method."
            )
            height = _height

        if _width != -1 and width != _width:
            logger.warning(
                f"`width` was set to {width} but the static model will output images of width {_width}. "
                "To fix the width, please reshape your model accordingly using the `.reshape()` method."
            )
            width = _width

        if expected_batch_size != -1:
            if isinstance(prompt, str):
                batch_size = 1
            elif isinstance(prompt, list):
                batch_size = len(prompt)
            else:
                prompt_embeds = kwargs.get("prompt_embeds")
                if prompt_embeds is None:
                    # Fail with an actionable message instead of an
                    # AttributeError on `None.shape`.
                    raise ValueError(
                        "Either `prompt` (a string or a list of strings) or `prompt_embeds` must be provided "
                        "to infer the batch size of a statically reshaped model."
                    )
                batch_size = prompt_embeds.shape[0]

            # NOTE(review): guidance_scale is forced to 0.0 for this check only,
            # presumably so the helper does not expect the batch to be doubled
            # for classifier-free guidance -- confirm against
            # `_raise_invalid_batch_size`. The user's `guidance_scale` is still
            # forwarded to the mixin below.
            _raise_invalid_batch_size(expected_batch_size, batch_size, num_images_per_prompt, guidance_scale=0.0)

        return LatentConsistencyPipelineMixin.__call__(
            self,
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            original_inference_steps=original_inference_steps,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_images_per_prompt,
            **kwargs,
        )
class OVLatentConsistencyModelPipelineTest(unittest.TestCase):
    """Tests for the OpenVINO latent consistency model pipeline.

    Both tests are skipped when the installed diffusers version is 0.21.4 or
    older.
    """

    SUPPORTED_ARCHITECTURES = ("latent-consistency",)
    MODEL_CLASS = OVLatentConsistencyModelPipeline
    TASK = "text-to-image"

    @parameterized.expand(SUPPORTED_ARCHITECTURES)
    @unittest.skipIf(parse(_diffusers_version) <= Version("0.21.4"), "not supported with this diffusers version")
    def test_compare_to_diffusers(self, model_arch: str):
        """Check the exported pipeline matches the diffusers reference within 1e-4."""
        ov_pipeline = self.MODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch], export=True)

        # Every sub-model must be wrapped in its OpenVINO counterpart.
        expected_component_types = (
            ("text_encoder", OVModelTextEncoder),
            ("vae_encoder", OVModelVaeEncoder),
            ("vae_decoder", OVModelVaeDecoder),
            ("unet", OVModelUnet),
            ("config", Dict),
        )
        for attribute, expected_type in expected_component_types:
            self.assertIsInstance(getattr(ov_pipeline, attribute), expected_type)

        # Imported here so the module still loads when the installed diffusers
        # version does not provide this pipeline (the test is skipped then).
        from diffusers import LatentConsistencyModelPipeline

        pipeline = LatentConsistencyModelPipeline.from_pretrained(MODEL_NAMES[model_arch])

        batch_size = 2
        num_images_per_prompt = 3
        height = 64
        width = 128

        # Both pipelines start from the same seeded latents so that their
        # outputs can be compared directly.
        latents = ov_pipeline.prepare_latents(
            batch_size * num_images_per_prompt,
            ov_pipeline.unet.config["in_channels"],
            height,
            width,
            dtype=np.float32,
            generator=np.random.RandomState(0),
        )

        kwargs = dict(
            prompt=["sailing ship in storm by Leonardo da Vinci"] * batch_size,
            num_inference_steps=1,
            num_images_per_prompt=num_images_per_prompt,
            height=height,
            width=width,
            guidance_scale=8.5,
        )

        for output_type in ("latent", "np"):
            ov_result = ov_pipeline(latents=latents, output_type=output_type, **kwargs)
            ov_outputs = ov_result.images
            self.assertIsInstance(ov_outputs, np.ndarray)

            with torch.no_grad():
                reference = pipeline(latents=torch.from_numpy(latents), output_type=output_type, **kwargs)
            outputs = reference.images

            # Compare model outputs
            outputs_match = np.allclose(ov_outputs, outputs, atol=1e-4)
            self.assertTrue(outputs_match)
            # Compare model devices
            self.assertEqual(pipeline.device.type, ov_pipeline.device)

    @parameterized.expand(SUPPORTED_ARCHITECTURES)
    @unittest.skipIf(parse(_diffusers_version) <= Version("0.21.4"), "not supported with this diffusers version")
    def test_num_images_per_prompt_static_model(self, model_arch: str):
        """Check a statically reshaped pipeline keeps its fixed output shape."""
        pipeline = self.MODEL_CLASS.from_pretrained(
            MODEL_NAMES[model_arch],
            export=True,
            compile=False,
            dynamic_shapes=False,
        )

        batch_size = 3
        num_images = 4
        height = 128
        width = 64

        pipeline.half()
        pipeline.reshape(batch_size=batch_size, height=height, width=width, num_images_per_prompt=num_images)
        self.assertFalse(pipeline.is_dynamic)
        pipeline.compile()

        expected_shape = (batch_size * num_images, height, width, 3)
        # The second requested height differs from the static one; the output
        # is still expected to have the static height.
        for requested_height in (height, height + 16):
            inputs = _generate_inputs(batch_size)
            result = pipeline(
                **inputs,
                num_images_per_prompt=num_images,
                height=requested_height,
                width=width,
            )
            self.assertEqual(result.images.shape, expected_shape)
"hf-internal-testing/tiny-random-SEWDModel", "swin": "hf-internal-testing/tiny-random-SwinModel",