From 17c5906b0847b45aacd9fdf7c7e2f4e9e10723a3 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Tue, 5 Sep 2023 16:46:20 +0200 Subject: [PATCH 1/5] add vae image processor --- .github/workflows/test_openvino.yml | 1 + optimum/intel/openvino/modeling_diffusion.py | 22 +++++-- setup.py | 3 +- tests/openvino/test_stable_diffusion.py | 65 +++++++++++++------- 4 files changed, 63 insertions(+), 28 deletions(-) diff --git a/.github/workflows/test_openvino.yml b/.github/workflows/test_openvino.yml index cb58f412a6..f0b40fa5d1 100644 --- a/.github/workflows/test_openvino.yml +++ b/.github/workflows/test_openvino.yml @@ -32,6 +32,7 @@ jobs: python -m pip install --upgrade pip # install PyTorch CPU version to avoid installing CUDA packages on GitHub runner without GPU pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu + pip install git+https://github.com/huggingface/optimum.git pip install .[openvino,nncf,tests,diffusers] - name: Test with Pytest run: | diff --git a/optimum/intel/openvino/modeling_diffusion.py b/optimum/intel/openvino/modeling_diffusion.py index 73ec66d473..1085c9e81c 100644 --- a/optimum/intel/openvino/modeling_diffusion.py +++ b/optimum/intel/openvino/modeling_diffusion.py @@ -30,7 +30,7 @@ StableDiffusionXLPipeline, ) from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME -from diffusers.utils import CONFIG_NAME +from diffusers.utils import CONFIG_NAME, is_invisible_watermark_available from huggingface_hub import snapshot_download from openvino._offline_transformations import compress_model_transformation from openvino.runtime import Core @@ -42,6 +42,7 @@ from optimum.pipelines.diffusers.pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipelineMixin from optimum.pipelines.diffusers.pipeline_stable_diffusion_xl import StableDiffusionXLPipelineMixin from optimum.pipelines.diffusers.pipeline_stable_diffusion_xl_img2img import StableDiffusionXLImg2ImgPipelineMixin +from optimum.pipelines.diffusers.pipeline_utils import VaeImageProcessor from optimum.utils import ( DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER, DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER, @@ -106,6 +107,8 @@ def __init__( else: self.vae_scale_factor = 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.tokenizer = tokenizer self.tokenizer_2 = tokenizer_2 self.scheduler = scheduler @@ -687,12 +690,21 @@ class OVStableDiffusionXLPipelineBase(OVStableDiffusionPipelineBase): auto_model_class = StableDiffusionXLPipeline export_feature = "stable-diffusion-xl" - def __init__(self, *args, **kwargs): + def __init__(self, *args, add_watermarker: Optional[bool] = None, **kwargs): super().__init__(*args, **kwargs) - # additional invisible-watermark dependency for SD XL - from optimum.pipelines.diffusers.watermark import StableDiffusionXLWatermarker - self.watermark = StableDiffusionXLWatermarker() + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + if not is_invisible_watermark_available(): + raise ImportError( + "`add_watermarker` requires invisible-watermark to be installed, which can be installed with `pip install invisible-watermark`." + ) + from optimum.pipelines.diffusers.watermark import StableDiffusionXLWatermarker + + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None class OVStableDiffusionXLPipeline(OVStableDiffusionXLPipelineBase, StableDiffusionXLPipelineMixin): diff --git a/setup.py b/setup.py index c35640226d..8e5130c672 100644 --- a/setup.py +++ b/setup.py @@ -31,6 +31,7 @@ "torchaudio", "rjieba", "timm", + "invisible-watermark>=0.2.0", ] QUALITY_REQUIRE = ["black~=23.1", "ruff>=0.0.241"] @@ -44,7 +45,7 @@ "openvino": ["openvino>=2023.0.0", "onnx", "onnxruntime"], "nncf": ["nncf>=2.5.0", "openvino-dev>=2023.0.0"], "ipex": ["transformers<4.32.0", "intel-extension-for-pytorch", "onnx"], - "diffusers": ["diffusers", "invisible-watermark>=0.2.0"], + "diffusers": ["diffusers"], "quality": QUALITY_REQUIRE, "tests": TESTS_REQUIRE, } diff --git a/tests/openvino/test_stable_diffusion.py b/tests/openvino/test_stable_diffusion.py index 35e85eeaa7..6b92bcc37c 100644 --- a/tests/openvino/test_stable_diffusion.py +++ b/tests/openvino/test_stable_diffusion.py @@ -60,12 +60,26 @@ def _generate_inputs(batch_size=1): return inputs -def _create_image(height=128, width=128): - image = load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - "/in_paint/overture-creations-5sI6fQgYIuo.png" - ) - return image.resize((width, height)) +def _create_image(height=128, width=128, batch_size=1, channel=3, input_type="pil"): + if input_type == "pil": + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/overture-creations-5sI6fQgYIuo.png" + ).resize((width, height)) + elif input_type == "np": + image = np.random.rand(height, width, channel) + elif input_type == "pt": + image = torch.rand((channel, height, width)) + + return [image] * batch_size + + +def to_np(image): + if isinstance(image[0], PIL.Image.Image): + return np.stack([np.array(i) for i in image], axis=0) + elif isinstance(image, torch.Tensor): + return image.cpu().numpy().transpose(0, 2, 3, 1) + return image class OVStableDiffusionPipelineBaseTest(unittest.TestCase): @@ -120,8 +134,10 @@ class OVStableDiffusionImg2ImgPipelineTest(OVStableDiffusionPipelineBaseTest): def test_compare_diffusers_pipeline(self, model_arch: str): model_id = MODEL_NAMES[model_arch] pipeline = self.MODEL_CLASS.from_pretrained(model_id, export=True) - inputs = self.generate_inputs() + height, width, batch_size = 128, 128, 1 + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) inputs["prompt"] = "A painting of a squirrel eating a burger" + inputs["image"] = floats_tensor((batch_size, 3, height, width), rng=random.Random(SEED)) np.random.seed(0) output = pipeline(**inputs).images[0, -3:, -3:, -1] # https://github.com/huggingface/diffusers/blob/v0.17.1/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_img2img.py#L71 @@ -139,9 +155,9 @@ def test_num_images_per_prompt_static_model(self, model_arch: str): outputs = pipeline(**inputs, num_images_per_prompt=num_images, generator=np.random.RandomState(0)).images self.assertEqual(outputs.shape, (batch_size * num_images, height, width, 3)) - def generate_inputs(self, height=128, width=128, batch_size=1): + def generate_inputs(self, height=128, width=128, batch_size=1, input_type="np"): inputs = _generate_inputs(batch_size) - inputs["image"] = floats_tensor((batch_size, 3, height, width), rng=random.Random(SEED)) + inputs["image"] = _create_image(height=height, width=width, batch_size=batch_size, input_type=input_type) inputs["strength"] = 0.75 return inputs @@ -262,6 +278,17 @@ def test_compare_diffusers_pipeline(self, model_arch: str): generator=np.random.RandomState(0), ) inputs = self.generate_inputs(height=height, width=width) + + inputs["image"] = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/overture-creations-5sI6fQgYIuo.png" + ).resize((width, height)) + + inputs["mask_image"] = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/overture-creations-5sI6fQgYIuo_mask.png" + ).resize((width, height)) + outputs = pipeline(**inputs, latents=latents).images self.assertEqual(outputs.shape, (batch_size * num_images, height, width, 3)) @@ -285,16 +312,8 @@ def test_num_images_per_prompt_static_model(self, model_arch: str): def generate_inputs(self, height=128, width=128, batch_size=1): inputs = super(OVStableDiffusionInpaintPipelineTest, self).generate_inputs(height, width, batch_size) - inputs["image"] = load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - "/in_paint/overture-creations-5sI6fQgYIuo.png" - ).resize((width, height)) - - inputs["mask_image"] = load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - "/in_paint/overture-creations-5sI6fQgYIuo_mask.png" - ).resize((width, height)) - + inputs["image"] = _create_image(height=height, width=width, batch_size=1, input_type="pil")[0] + inputs["mask_image"] = _create_image(height=height, width=width, batch_size=1, input_type="pil")[0] return inputs @@ -396,7 +415,9 @@ def test_inference(self): pipeline.save_pretrained(tmp_dir) pipeline = self.MODEL_CLASS.from_pretrained(tmp_dir) - inputs = self.generate_inputs() + batch_size, height, width = 1, 128, 128 + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + inputs["image"] = floats_tensor((batch_size, 3, height, width), rng=random.Random(SEED)) np.random.seed(0) output = pipeline(**inputs).images[0, -3:, -3:, -1] expected_slice = np.array([0.5675, 0.5108, 0.4758, 0.5280, 0.5080, 0.5473, 0.4789, 0.4286, 0.4861]) @@ -413,8 +434,8 @@ def test_num_images_per_prompt_static_model(self, model_arch: str): outputs = pipeline(**inputs, num_images_per_prompt=num_images, generator=np.random.RandomState(0)).images self.assertEqual(outputs.shape, (batch_size * num_images, height, width, 3)) - def generate_inputs(self, height=128, width=128, batch_size=1): + def generate_inputs(self, height=128, width=128, batch_size=1, input_type="np"): inputs = _generate_inputs(batch_size) - inputs["image"] = floats_tensor((batch_size, 3, height, width), rng=random.Random(SEED)) + inputs["image"] = _create_image(height=height, width=width, batch_size=batch_size, input_type=input_type) inputs["strength"] = 0.75 return inputs From b00ec488ef8d755ca84852ae582a2157c728310a Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Wed, 6 Sep 2023 15:09:49 +0200 Subject: [PATCH 2/5] fix tests --- tests/openvino/test_stable_diffusion.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/openvino/test_stable_diffusion.py b/tests/openvino/test_stable_diffusion.py index 6b92bcc37c..f5095c4a93 100644 --- a/tests/openvino/test_stable_diffusion.py +++ b/tests/openvino/test_stable_diffusion.py @@ -18,6 +18,7 @@ from typing import Dict import numpy as np +import PIL import torch from diffusers import ( StableDiffusionPipeline, @@ -420,7 +421,7 @@ def test_inference(self): inputs["image"] = floats_tensor((batch_size, 3, height, width), rng=random.Random(SEED)) np.random.seed(0) output = pipeline(**inputs).images[0, -3:, -3:, -1] - expected_slice = np.array([0.5675, 0.5108, 0.4758, 0.5280, 0.5080, 0.5473, 0.4789, 0.4286, 0.4861]) + expected_slice = np.array([0.5683, 0.5121, 0.4767, 0.5253, 0.5072, 0.5462, 0.4766, 0.4279, 0.4855]) self.assertTrue(np.allclose(output.flatten(), expected_slice, atol=1e-3)) @parameterized.expand(SUPPORTED_ARCHITECTURES) From bbec65acc371744ea2b397a7a8e1f5490f643e24 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Wed, 6 Sep 2023 15:15:20 +0200 Subject: [PATCH 3/5] add test --- tests/openvino/test_stable_diffusion.py | 37 +++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/tests/openvino/test_stable_diffusion.py b/tests/openvino/test_stable_diffusion.py index f5095c4a93..3c0c90475a 100644 --- a/tests/openvino/test_stable_diffusion.py +++ b/tests/openvino/test_stable_diffusion.py @@ -86,6 +86,7 @@ def to_np(image): class OVStableDiffusionPipelineBaseTest(unittest.TestCase): SUPPORTED_ARCHITECTURES = ("stable-diffusion",) MODEL_CLASS = OVStableDiffusionPipeline + TASK = "text-to-image" @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_num_images_per_prompt(self, model_arch: str): @@ -119,6 +120,37 @@ def callback_fn(step: int, timestep: int, latents: np.ndarray) -> None: self.assertTrue(callback_fn.has_been_called) self.assertEqual(callback_fn.number_of_steps, inputs["num_inference_steps"]) + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @require_diffusers + def test_shape(self, model_arch: str): + height, width, batch_size = 128, 64, 1 + pipeline = self.MODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch], export=True) + + if self.TASK == "image-to-image": + input_types = ["np", "pil", "pt"] + elif self.TASK == "text-to-image": + input_types = ["np"] + else: + input_types = ["pil"] + + for input_type in input_types: + if self.TASK == "image-to-image": + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size, input_type=input_type) + else: + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + for output_type in ["np", "pil", "latent"]: + inputs["output_type"] = output_type + outputs = pipeline(**inputs).images + if output_type == "pil": + self.assertEqual((len(outputs), outputs[0].height, outputs[0].width), (batch_size, height, width)) + elif output_type == "np": + self.assertEqual(outputs.shape, (batch_size, height, width, 3)) + else: + self.assertEqual( + outputs.shape, + (batch_size, 4, height // pipeline.vae_scale_factor, width // pipeline.vae_scale_factor), + ) + def generate_inputs(self, height=128, width=128, batch_size=1): inputs = _generate_inputs(batch_size) inputs["height"] = height @@ -130,6 +162,7 @@ class OVStableDiffusionImg2ImgPipelineTest(OVStableDiffusionPipelineBaseTest): SUPPORTED_ARCHITECTURES = ("stable-diffusion",) MODEL_CLASS = OVStableDiffusionImg2ImgPipeline ORT_MODEL_CLASS = ORTStableDiffusionImg2ImgPipeline + TASK = "image-to-image" @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_compare_diffusers_pipeline(self, model_arch: str): @@ -166,6 +199,7 @@ def generate_inputs(self, height=128, width=128, batch_size=1, input_type="np"): class OVStableDiffusionPipelineTest(unittest.TestCase): SUPPORTED_ARCHITECTURES = ("stable-diffusion",) MODEL_CLASS = OVStableDiffusionPipeline + TASK = "text-to-image" @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_compare_to_diffusers(self, model_arch: str): @@ -264,6 +298,7 @@ class OVStableDiffusionInpaintPipelineTest(OVStableDiffusionPipelineBaseTest): SUPPORTED_ARCHITECTURES = ("stable-diffusion",) MODEL_CLASS = OVStableDiffusionInpaintPipeline ORT_MODEL_CLASS = ORTStableDiffusionInpaintPipeline + TASK = "inpaint" @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_compare_diffusers_pipeline(self, model_arch: str): @@ -323,6 +358,7 @@ class OVtableDiffusionXLPipelineTest(unittest.TestCase): MODEL_CLASS = OVStableDiffusionXLPipeline ORT_MODEL_CLASS = ORTStableDiffusionXLPipeline PT_MODEL_CLASS = StableDiffusionXLPipeline + TASK = "text-to-image" @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_compare_to_diffusers(self, model_arch: str): @@ -407,6 +443,7 @@ class OVStableDiffusionXLImg2ImgPipelineTest(unittest.TestCase): MODEL_CLASS = OVStableDiffusionXLImg2ImgPipeline ORT_MODEL_CLASS = ORTStableDiffusionXLImg2ImgPipeline PT_MODEL_CLASS = StableDiffusionXLImg2ImgPipeline + TASK = "image-to-image" def test_inference(self): model_id = "hf-internal-testing/tiny-stable-diffusion-xl-pipe" From 1a7e27d8e5d4381f65a8f2e001aabc67246edbf7 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Thu, 7 Sep 2023 17:11:37 +0200 Subject: [PATCH 4/5] add optimum min version --- .github/workflows/test_openvino.yml | 1 - setup.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/test_openvino.yml b/.github/workflows/test_openvino.yml index f0b40fa5d1..cb58f412a6 100644 --- a/.github/workflows/test_openvino.yml +++ b/.github/workflows/test_openvino.yml @@ -32,7 +32,6 @@ jobs: python -m pip install --upgrade pip # install PyTorch CPU version to avoid installing CUDA packages on GitHub runner without GPU pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu - pip install git+https://github.com/huggingface/optimum.git pip install .[openvino,nncf,tests,diffusers] - name: Test with Pytest run: | diff --git a/setup.py b/setup.py index 8e5130c672..769431c31c 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,7 @@ assert False, "Error: Could not open '%s' due %s\n" % (filepath, error) INSTALL_REQUIRE = [ - "optimum>=1.10.0", + "optimum>=1.13.0", "transformers>=4.20.0", "datasets>=1.4.0", "sentencepiece", From 0b0b7d37d42e334c1ab487a8f5cd05bb3bf937ab Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Fri, 8 Sep 2023 11:41:06 +0200 Subject: [PATCH 5/5] fix --- optimum/intel/openvino/modeling_base.py | 8 -------- tests/openvino/test_stable_diffusion.py | 1 - 2 files changed, 9 deletions(-) diff --git a/optimum/intel/openvino/modeling_base.py b/optimum/intel/openvino/modeling_base.py index 14ac76137f..59fc89649a 100644 --- a/optimum/intel/openvino/modeling_base.py +++ b/optimum/intel/openvino/modeling_base.py @@ -65,7 +65,6 @@ class PreTrainedModel(OptimizedModel): """, ) class OVBaseModel(PreTrainedModel): - _AUTOMODELS_TO_TASKS = {cls_name: task for task, cls_name in TasksManager._TASKS_TO_AUTOMODELS.items()} auto_model_class = None export_feature = None @@ -391,13 +390,6 @@ def _ensure_supported_device(self, device: str = None): def forward(self, *args, **kwargs): raise NotImplementedError - @classmethod - def _auto_model_to_task(cls, auto_model_class): - """ - Get the task corresponding to a class (for example AutoModelForXXX in transformers). - """ - return cls._AUTOMODELS_TO_TASKS[auto_model_class.__name__] - def can_generate(self) -> bool: """ Returns whether this model can generate sequences with `.generate()`. diff --git a/tests/openvino/test_stable_diffusion.py b/tests/openvino/test_stable_diffusion.py index 3c0c90475a..e04e2d6fd3 100644 --- a/tests/openvino/test_stable_diffusion.py +++ b/tests/openvino/test_stable_diffusion.py @@ -121,7 +121,6 @@ def callback_fn(step: int, timestep: int, latents: np.ndarray) -> None: self.assertEqual(callback_fn.number_of_steps, inputs["num_inference_steps"]) @parameterized.expand(SUPPORTED_ARCHITECTURES) - @require_diffusers def test_shape(self, model_arch: str): height, width, batch_size = 128, 64, 1 pipeline = self.MODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch], export=True)