Skip to content

Commit

Permalink
add support for latent consistency models
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix committed Oct 26, 2023
1 parent cadfab7 commit ca84adb
Show file tree
Hide file tree
Showing 6 changed files with 157 additions and 15 deletions.
4 changes: 4 additions & 0 deletions optimum/intel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
"OVStableDiffusionInpaintPipeline",
"OVStableDiffusionXLPipeline",
"OVStableDiffusionXLImg2ImgPipeline",
"OVLatentConsistencyModelPipeline",
]
else:
_import_structure["openvino"].extend(
Expand All @@ -71,6 +72,7 @@
"OVStableDiffusionInpaintPipeline",
"OVStableDiffusionXLPipeline",
"OVStableDiffusionXLImg2ImgPipeline",
"OVLatentConsistencyModelPipeline",
]
)

Expand Down Expand Up @@ -162,6 +164,7 @@
OVStableDiffusionInpaintPipeline,
OVStableDiffusionPipeline,
OVStableDiffusionXLImg2ImgPipeline,
OVLatentConsistencyModelPipeline,
OVStableDiffusionXLPipeline,
)
else:
Expand All @@ -170,6 +173,7 @@
OVStableDiffusionInpaintPipeline,
OVStableDiffusionPipeline,
OVStableDiffusionXLImg2ImgPipeline,
OVLatentConsistencyModelPipeline,
OVStableDiffusionXLPipeline,
)

Expand Down
1 change: 1 addition & 0 deletions optimum/intel/openvino/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,5 @@
OVStableDiffusionPipeline,
OVStableDiffusionXLImg2ImgPipeline,
OVStableDiffusionXLPipeline,
OVLatentConsistencyModelPipeline,
)
82 changes: 67 additions & 15 deletions optimum/intel/openvino/modeling_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
from optimum.pipelines.diffusers.pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipelineMixin
from optimum.pipelines.diffusers.pipeline_stable_diffusion_xl import StableDiffusionXLPipelineMixin
from optimum.pipelines.diffusers.pipeline_stable_diffusion_xl_img2img import StableDiffusionXLImg2ImgPipelineMixin
from optimum.pipelines.diffusers.pipeline_latent_consistency import LatentConsistencyPipelineMixin

from optimum.pipelines.diffusers.pipeline_utils import VaeImageProcessor
from optimum.utils import (
DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER,
Expand Down Expand Up @@ -270,20 +272,8 @@ def _from_pretrained(
if model_save_dir is None:
model_save_dir = new_model_save_dir

return cls(
vae_decoder=components["vae_decoder"],
text_encoder=components["text_encoder"],
unet=unet,
config=config,
tokenizer=kwargs.pop("tokenizer", None),
scheduler=kwargs.pop("scheduler"),
feature_extractor=kwargs.pop("feature_extractor", None),
vae_encoder=components["vae_encoder"],
text_encoder_2=components["text_encoder_2"],
tokenizer_2=kwargs.pop("tokenizer_2", None),
model_save_dir=model_save_dir,
**kwargs,
)
return cls(unet=unet, config=config, model_save_dir=model_save_dir, **components, **kwargs)


@classmethod
def _from_transformers(
Expand Down Expand Up @@ -377,8 +367,10 @@ def _reshape_unet(
if batch_size == -1 or num_images_per_prompt == -1:
batch_size = -1
else:
batch_size *= num_images_per_prompt
# The factor of 2 comes from the guidance scale > 1
batch_size = 2 * batch_size * num_images_per_prompt
if "timestep_cond" not in {inputs.get_any_name() for inputs in model.inputs}:
batch_size *= 2

height = height // self.vae_scale_factor if height > 0 else height
width = width // self.vae_scale_factor if width > 0 else width
Expand All @@ -402,6 +394,8 @@ def _reshape_unet(
shapes[inputs] = [batch_size, self.text_encoder_2.config["projection_dim"]]
elif inputs.get_any_name() == "time_ids":
shapes[inputs] = [batch_size, inputs.get_partial_shape()[1]]
elif inputs.get_any_name() == "timestep_cond":
shapes[inputs] = [batch_size, self.unet.config["time_cond_proj_dim"]]
else:
shapes[inputs][0] = batch_size
shapes[inputs][1] = tokenizer_max_length
Expand Down Expand Up @@ -587,6 +581,7 @@ def __call__(
encoder_hidden_states: np.ndarray,
text_embeds: Optional[np.ndarray] = None,
time_ids: Optional[np.ndarray] = None,
timestep_cond: Optional[np.ndarray] = None,
):
self._compile()

Expand All @@ -600,6 +595,8 @@ def __call__(
inputs["text_embeds"] = text_embeds
if time_ids is not None:
inputs["time_ids"] = time_ids
if timestep_cond is not None:
inputs["timestep_cond"] = timestep_cond

outputs = self.request(inputs, shared_memory=True)
return list(outputs.values())
Expand Down Expand Up @@ -932,6 +929,61 @@ def __call__(
)


class OVLatentConsistencyModelPipeline(OVStableDiffusionPipelineBase, LatentConsistencyPipelineMixin):
def __call__(
self,
prompt: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 4,
original_inference_steps: int = None,
guidance_scale: float = 8.5,
num_images_per_prompt: int = 1,
**kwargs,
):
height = height or self.unet.config["sample_size"] * self.vae_scale_factor
width = width or self.unet.config["sample_size"] * self.vae_scale_factor
_height = self.height
_width = self.width
expected_batch_size = self._batch_size

if _height != -1 and height != _height:
logger.warning(
f"`height` was set to {height} but the static model will output images of height {_height}."
"To fix the height, please reshape your model accordingly using the `.reshape()` method."
)
height = _height

if _width != -1 and width != _width:
logger.warning(
f"`width` was set to {width} but the static model will output images of width {_width}."
"To fix the width, please reshape your model accordingly using the `.reshape()` method."
)
width = _width

if expected_batch_size != -1:
if isinstance(prompt, str):
batch_size = 1
elif isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = kwargs.get("prompt_embeds").shape[0]

_raise_invalid_batch_size(expected_batch_size, batch_size, num_images_per_prompt, guidance_scale=0.0)

return LatentConsistencyPipelineMixin.__call__(
self,
prompt=prompt,
height=height,
width=width,
num_inference_steps=num_inference_steps,
original_inference_steps=original_inference_steps,
guidance_scale=guidance_scale,
num_images_per_prompt=num_images_per_prompt,
**kwargs,
)


def _raise_invalid_batch_size(
expected_batch_size: int, batch_size: int, num_images_per_prompt: int, guidance_scale: float
):
Expand Down
13 changes: 13 additions & 0 deletions optimum/intel/utils/dummy_openvino_and_diffusers_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,16 @@ def __init__(self, *args, **kwargs):
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["openvino", "diffusers"])



class OVLatentConsistencyModelPipeline(metaclass=DummyObject):
_backends = ["openvino", "diffusers"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["openvino", "diffusers"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["openvino", "diffusers"])

71 changes: 71 additions & 0 deletions tests/openvino/test_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLPipeline,
)
from packaging.version import Version, parse
from diffusers.utils import load_image
from diffusers.utils.testing_utils import floats_tensor
from openvino.runtime.ie_api import CompiledModel
Expand All @@ -37,7 +38,9 @@
OVStableDiffusionPipeline,
OVStableDiffusionXLImg2ImgPipeline,
OVStableDiffusionXLPipeline,
OVLatentConsistencyModelPipeline,
)
from optimum.utils.import_utils import _diffusers_version
from optimum.intel.openvino.modeling_diffusion import (
OVModelTextEncoder,
OVModelUnet,
Expand Down Expand Up @@ -475,3 +478,71 @@ def generate_inputs(self, height=128, width=128, batch_size=1, input_type="np"):
inputs["image"] = _create_image(height=height, width=width, batch_size=batch_size, input_type=input_type)
inputs["strength"] = 0.75
return inputs


class OVLatentConsistencyModelPipelineTest(unittest.TestCase):
SUPPORTED_ARCHITECTURES = ("latent-consistency",)
MODEL_CLASS = OVLatentConsistencyModelPipeline
TASK = "text-to-image"

@parameterized.expand(SUPPORTED_ARCHITECTURES)
@unittest.skipIf(parse(_diffusers_version) <= Version("0.21.4"), "not supported with this diffusers version")
def test_compare_to_diffusers(self, model_arch: str):
ov_pipeline = self.MODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch], export=True)
self.assertIsInstance(ov_pipeline.text_encoder, OVModelTextEncoder)
self.assertIsInstance(ov_pipeline.vae_encoder, OVModelVaeEncoder)
self.assertIsInstance(ov_pipeline.vae_decoder, OVModelVaeDecoder)
self.assertIsInstance(ov_pipeline.unet, OVModelUnet)
self.assertIsInstance(ov_pipeline.config, Dict)

from diffusers import LatentConsistencyModelPipeline

pipeline = LatentConsistencyModelPipeline.from_pretrained(MODEL_NAMES[model_arch])
batch_size, num_images_per_prompt, height, width = 2, 3, 64, 128
latents = ov_pipeline.prepare_latents(
batch_size * num_images_per_prompt,
ov_pipeline.unet.config["in_channels"],
height,
width,
dtype=np.float32,
generator=np.random.RandomState(0),
)

kwargs = {
"prompt": ["sailing ship in storm by Leonardo da Vinci"] * batch_size,
"num_inference_steps": 1,
"num_images_per_prompt": num_images_per_prompt,
"height": height,
"width": width,
"guidance_scale": 8.5,
}

for output_type in ["latent", "np"]:
ov_outputs = ov_pipeline(latents=latents, output_type=output_type, **kwargs).images
self.assertIsInstance(ov_outputs, np.ndarray)
with torch.no_grad():
outputs = pipeline(latents=torch.from_numpy(latents), output_type=output_type, **kwargs).images

# Compare model outputs
self.assertTrue(np.allclose(ov_outputs, outputs, atol=1e-4))
# Compare model devices
self.assertEqual(pipeline.device.type, ov_pipeline.device)


@parameterized.expand(SUPPORTED_ARCHITECTURES)
@unittest.skipIf(parse(_diffusers_version) <= Version("0.21.4"), "not supported with this diffusers version")
def test_num_images_per_prompt_static_model(self, model_arch: str):
model_id = MODEL_NAMES[model_arch]
pipeline = self.MODEL_CLASS.from_pretrained(model_id, export=True, compile=False, dynamic_shapes=False)
batch_size, num_images, height, width = 3, 4, 128, 64
pipeline.half()
pipeline.reshape(batch_size=batch_size, height=height, width=width, num_images_per_prompt=num_images)
self.assertFalse(pipeline.is_dynamic)
pipeline.compile()

for _height in [height, height + 16]:
inputs = _generate_inputs(batch_size)
outputs = pipeline(**inputs, num_images_per_prompt=num_images, height=_height, width=width).images
self.assertEqual(outputs.shape, (batch_size * num_images, height, width, 3))


1 change: 1 addition & 0 deletions tests/openvino/utils_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
"stable-diffusion": "hf-internal-testing/tiny-stable-diffusion-torch",
"stable-diffusion-xl": "echarlaix/tiny-random-stable-diffusion-xl",
"stable-diffusion-xl-refiner": "echarlaix/tiny-random-stable-diffusion-xl-refiner",
"latent-consistency": "echarlaix/tiny-random-latent-consistency",
"sew": "hf-internal-testing/tiny-random-SEWModel",
"sew_d": "hf-internal-testing/tiny-random-SEWDModel",
"swin": "hf-internal-testing/tiny-random-SwinModel",
Expand Down

0 comments on commit ca84adb

Please sign in to comment.