diff --git a/optimum/intel/__init__.py b/optimum/intel/__init__.py index 67a01011a..0230394d2 100644 --- a/optimum/intel/__init__.py +++ b/optimum/intel/__init__.py @@ -106,6 +106,8 @@ "OVLatentConsistencyModelPipeline", "OVLatentConsistencyModelImg2ImgPipeline", "OVFluxPipeline", + "OVFluxImg2ImgPipeline", + "OVFluxInpaintPipeline", "OVPipelineForImage2Image", "OVPipelineForText2Image", "OVPipelineForInpainting", @@ -126,6 +128,8 @@ "OVLatentConsistencyModelPipeline", "OVLatentConsistencyModelImg2ImgPipeline", "OVFluxPipeline", + "OVFluxImg2ImgPipeline", + "OVFluxInpaintPipeline", "OVPipelineForImage2Image", "OVPipelineForText2Image", "OVPipelineForInpainting", diff --git a/optimum/intel/openvino/__init__.py b/optimum/intel/openvino/__init__.py index 589a0938e..373773332 100644 --- a/optimum/intel/openvino/__init__.py +++ b/optimum/intel/openvino/__init__.py @@ -82,6 +82,8 @@ if is_diffusers_available(): from .modeling_diffusion import ( OVDiffusionPipeline, + OVFluxImg2ImgPipeline, + OVFluxInpaintPipeline, OVFluxPipeline, OVLatentConsistencyModelImg2ImgPipeline, OVLatentConsistencyModelPipeline, diff --git a/optimum/intel/openvino/modeling_diffusion.py b/optimum/intel/openvino/modeling_diffusion.py index 3ce1cc73f..9c53994b8 100644 --- a/optimum/intel/openvino/modeling_diffusion.py +++ b/optimum/intel/openvino/modeling_diffusion.py @@ -86,13 +86,20 @@ if is_diffusers_version(">=", "0.29.0"): from diffusers import StableDiffusion3Img2ImgPipeline, StableDiffusion3Pipeline else: - StableDiffusion3Pipeline, StableDiffusion3Img2ImgPipeline = StableDiffusionPipeline, StableDiffusionImg2ImgPipeline + StableDiffusion3Pipeline, StableDiffusion3Img2ImgPipeline = object, object if is_diffusers_version(">=", "0.30.0"): from diffusers import FluxPipeline, StableDiffusion3InpaintPipeline else: - StableDiffusion3InpaintPipeline = StableDiffusionInpaintPipeline - FluxPipeline = StableDiffusionPipeline + StableDiffusion3InpaintPipeline = object + FluxPipeline = object + + +if is_diffusers_version(">=", "0.31.0"): + from diffusers import FluxImg2ImgPipeline, FluxInpaintPipeline +else: + FluxImg2ImgPipeline = object + FluxInpaintPipeline = object DIFFUSION_MODEL_TRANSFORMER_SUBFOLDER = "transformer" @@ -887,9 +894,6 @@ def compile(self): def _load_config(cls, config_name_or_path: Union[str, os.PathLike], **kwargs): return cls.load_config(config_name_or_path, **kwargs) - def _save_config(self, save_directory): - self.save_config(save_directory) - @property def components(self) -> Dict[str, Any]: components = { @@ -1447,6 +1451,18 @@ class OVFluxPipeline(OVDiffusionPipeline, OVTextualInversionLoaderMixin, FluxPip auto_model_class = FluxPipeline +class OVFluxImg2ImgPipeline(OVDiffusionPipeline, OVTextualInversionLoaderMixin, FluxImg2ImgPipeline): + main_input_name = "prompt" + export_feature = "image-to-image" + auto_model_class = FluxImg2ImgPipeline + + +class OVFluxInpaintPipeline(OVDiffusionPipeline, OVTextualInversionLoaderMixin, FluxInpaintPipeline): + main_input_name = "prompt" + export_feature = "inpainting" + auto_model_class = FluxInpaintPipeline + + SUPPORTED_OV_PIPELINES = [ OVStableDiffusionPipeline, OVStableDiffusionImg2ImgPipeline, @@ -1510,6 +1526,10 @@ def _get_ov_class(pipeline_class_name: str, throw_error_if_not_exist: bool = Tru OV_INPAINT_PIPELINES_MAPPING["stable-diffusion-3"] = OVStableDiffusion3InpaintPipeline OV_TEXT2IMAGE_PIPELINES_MAPPING["flux"] = OVFluxPipeline +if is_diffusers_version(">=", "0.31.0"): + SUPPORTED_OV_PIPELINES.extend([OVFluxImg2ImgPipeline, OVFluxInpaintPipeline]) + OV_INPAINT_PIPELINES_MAPPING["flux"] = OVFluxInpaintPipeline + OV_IMAGE2IMAGE_PIPELINES_MAPPING["flux"] = OVFluxImg2ImgPipeline SUPPORTED_OV_PIPELINES_MAPPINGS = [ OV_TEXT2IMAGE_PIPELINES_MAPPING, diff --git a/optimum/intel/utils/dummy_openvino_and_diffusers_objects.py b/optimum/intel/utils/dummy_openvino_and_diffusers_objects.py index 38aea6c1f..a6a10651c 100644 --- a/optimum/intel/utils/dummy_openvino_and_diffusers_objects.py +++ b/optimum/intel/utils/dummy_openvino_and_diffusers_objects.py @@ -189,3 +189,25 @@ def __init__(self, *args, **kwargs): @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["openvino", "diffusers"]) + + +class OVFluxImg2ImgPipeline(metaclass=DummyObject): + _backends = ["openvino", "diffusers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["openvino", "diffusers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["openvino", "diffusers"]) + + +class OVFluxInpaintPipeline(metaclass=DummyObject): + _backends = ["openvino", "diffusers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["openvino", "diffusers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["openvino", "diffusers"]) diff --git a/tests/openvino/test_diffusion.py b/tests/openvino/test_diffusion.py index 2baeba9a4..5f151cb83 100644 --- a/tests/openvino/test_diffusion.py +++ b/tests/openvino/test_diffusion.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import unittest from pathlib import Path @@ -134,8 +135,8 @@ def test_compare_to_diffusers_pipeline(self, model_arch: str): height, width, batch_size = 128, 128, 1 inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) - ov_pipeline = self.OVMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch], text_encoder_3=None) - diffusers_pipeline = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch], text_encoder_3=None) + ov_pipeline = self.OVMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch]) + diffusers_pipeline = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch]) for output_type in ["latent", "np", "pt"]: inputs["output_type"] = output_type @@ -330,6 +331,15 @@ def test_load_and_save_pipeline_with_safety_checker(self): ]: subdir_path = Path(tmpdirname) / subdir self.assertTrue(subdir_path.is_dir()) + # check that config contains original model classes + pipeline_config = Path(tmpdirname) / "model_index.json" + self.assertTrue(pipeline_config.exists()) + with pipeline_config.open("r") as f: + config = json.load(f) + for key in ["unet", "vae", "text_encoder"]: + model_lib, model_class = config[key] + self.assertTrue(model_lib in ["diffusers", "transformers"]) + self.assertFalse(model_class.startswith("OV")) loaded_pipeline = self.OVMODEL_CLASS.from_pretrained(tmpdirname) self.assertTrue(loaded_pipeline.safety_checker is not None) self.assertIsInstance(loaded_pipeline.safety_checker, StableDiffusionSafetyChecker) @@ -398,19 +408,24 @@ class OVPipelineForImage2ImageTest(unittest.TestCase): SUPPORTED_ARCHITECTURES = ["stable-diffusion", "stable-diffusion-xl", "latent-consistency"] if is_transformers_version(">=", "4.40.0"): SUPPORTED_ARCHITECTURES.append("stable-diffusion-3") + SUPPORTED_ARCHITECTURES.append("flux") AUTOMODEL_CLASS = AutoPipelineForImage2Image OVMODEL_CLASS = OVPipelineForImage2Image TASK = "image-to-image" - def generate_inputs(self, height=128, width=128, batch_size=1, channel=3, input_type="pil"): + def generate_inputs(self, height=128, width=128, batch_size=1, channel=3, input_type="pil", model_type=None): inputs = _generate_prompts(batch_size=batch_size) inputs["image"] = _generate_images( height=height, width=width, batch_size=batch_size, channel=channel, input_type=input_type ) + if "flux" == model_type: + inputs["height"] = height + inputs["width"] = width + inputs["strength"] = 0.75 return inputs @@ -439,7 +454,9 @@ def test_num_images_per_prompt(self, model_arch: str): for height in [64, 128]: for width in [64, 128]: for num_images_per_prompt in [1, 3]: - inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + inputs = self.generate_inputs( + height=height, width=width, batch_size=batch_size, model_type=model_arch + ) outputs = pipeline(**inputs, num_images_per_prompt=num_images_per_prompt).images self.assertEqual(outputs.shape, (batch_size * num_images_per_prompt, height, width, 3)) @@ -447,7 +464,7 @@ def test_num_images_per_prompt(self, model_arch: str): @require_diffusers def test_callback(self, model_arch: str): height, width, batch_size = 32, 64, 1 - inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size, model_type=model_arch) class Callback: def __init__(self): @@ -478,7 +495,9 @@ def test_shape(self, model_arch: str): height, width, batch_size = 128, 64, 1 for input_type in ["pil", "np", "pt"]: - inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size, input_type=input_type) + inputs = self.generate_inputs( + height=height, width=width, batch_size=batch_size, input_type=input_type, model_type=model_arch + ) for output_type in ["pil", "np", "pt", "latent"]: inputs["output_type"] = output_type @@ -490,29 +509,35 @@ def test_shape(self, model_arch: str): elif output_type == "pt": self.assertEqual(outputs.shape, (batch_size, 3, height, width)) else: - out_channels = ( - pipeline.unet.config.out_channels - if pipeline.unet is not None - else pipeline.transformer.config.out_channels - ) - self.assertEqual( - outputs.shape, - ( - batch_size, - out_channels, - height // pipeline.vae_scale_factor, - width // pipeline.vae_scale_factor, - ), - ) + if model_arch != "flux": + out_channels = ( + pipeline.unet.config.out_channels + if pipeline.unet is not None + else pipeline.transformer.config.out_channels + ) + self.assertEqual( + outputs.shape, + ( + batch_size, + out_channels, + height // pipeline.vae_scale_factor, + width // pipeline.vae_scale_factor, + ), + ) + else: + packed_height = height // pipeline.vae_scale_factor + packed_width = width // pipeline.vae_scale_factor + channels = pipeline.transformer.config.in_channels + self.assertEqual(outputs.shape, (batch_size, packed_height * packed_width, channels)) @parameterized.expand(SUPPORTED_ARCHITECTURES) @require_diffusers def test_compare_to_diffusers_pipeline(self, model_arch: str): height, width, batch_size = 128, 128, 1 - inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size, model_type=model_arch) - diffusers_pipeline = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch], text_encoder_3=None) - ov_pipeline = self.OVMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch], text_encoder_3=None) + diffusers_pipeline = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch]) + ov_pipeline = self.OVMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch]) for output_type in ["latent", "np", "pt"]: print(output_type) @@ -529,7 +554,7 @@ def test_image_reproducibility(self, model_arch: str): pipeline = self.OVMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch]) height, width, batch_size = 64, 64, 1 - inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size, model_type=model_arch) for generator_framework in ["np", "pt"]: ov_outputs_1 = pipeline(**inputs, generator=get_generator(generator_framework, SEED)) @@ -551,7 +576,7 @@ def test_safety_checker(self, model_arch: str): self.assertIsInstance(ov_pipeline.safety_checker, StableDiffusionSafetyChecker) height, width, batch_size = 32, 64, 1 - inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size, model_type=model_arch) ov_output = ov_pipeline(**inputs, generator=get_generator("pt", SEED)) diffusers_output = pipeline(**inputs, generator=get_generator("pt", SEED)) @@ -586,9 +611,13 @@ def test_height_width_properties(self, model_arch: str): self.assertFalse(ov_pipeline.is_dynamic) expected_batch = batch_size * num_images_per_prompt - if ov_pipeline.unet is None or "timestep_cond" not in { - inputs.get_any_name() for inputs in ov_pipeline.unet.model.inputs - }: + if ( + ov_pipeline.unet is not None + and "timestep_cond" not in {inputs.get_any_name() for inputs in ov_pipeline.unet.model.inputs} + ) or ( + ov_pipeline.transformer is not None + and "txt_ids" not in {inputs.get_any_name() for inputs in ov_pipeline.transformer.model.inputs} + ): expected_batch *= 2 self.assertEqual(ov_pipeline.batch_size, expected_batch) self.assertEqual(ov_pipeline.height, height) @@ -604,7 +633,7 @@ def test_textual_inversion(self): model_id = "runwayml/stable-diffusion-v1-5" ti_id = "sd-concepts-library/cat-toy" - inputs = self.generate_inputs() + inputs = self.generate_inputs(model_type="stable-diffusion") inputs["prompt"] = "A backpack" diffusers_pipeline = self.AUTOMODEL_CLASS.from_pretrained(model_id, safety_checker=None) @@ -624,6 +653,7 @@ class OVPipelineForInpaintingTest(unittest.TestCase): if is_transformers_version(">=", "4.40.0"): SUPPORTED_ARCHITECTURES.append("stable-diffusion-3") + SUPPORTED_ARCHITECTURES.append("flux") AUTOMODEL_CLASS = AutoPipelineForInpainting OVMODEL_CLASS = OVPipelineForInpainting @@ -721,20 +751,26 @@ def test_shape(self, model_arch: str): elif output_type == "pt": self.assertEqual(outputs.shape, (batch_size, 3, height, width)) else: - out_channels = ( - pipeline.unet.config.out_channels - if pipeline.unet is not None - else pipeline.transformer.config.out_channels - ) - self.assertEqual( - outputs.shape, - ( - batch_size, - out_channels, - height // pipeline.vae_scale_factor, - width // pipeline.vae_scale_factor, - ), - ) + if model_arch != "flux": + out_channels = ( + pipeline.unet.config.out_channels + if pipeline.unet is not None + else pipeline.transformer.config.out_channels + ) + self.assertEqual( + outputs.shape, + ( + batch_size, + out_channels, + height // pipeline.vae_scale_factor, + width // pipeline.vae_scale_factor, + ), + ) + else: + packed_height = height // pipeline.vae_scale_factor + packed_width = width // pipeline.vae_scale_factor + channels = pipeline.transformer.config.in_channels + self.assertEqual(outputs.shape, (batch_size, packed_height * packed_width, channels)) @parameterized.expand(SUPPORTED_ARCHITECTURES) @require_diffusers @@ -816,9 +852,13 @@ def test_height_width_properties(self, model_arch: str): self.assertFalse(ov_pipeline.is_dynamic) expected_batch = batch_size * num_images_per_prompt - if ov_pipeline.unet is None or "timestep_cond" not in { - inputs.get_any_name() for inputs in ov_pipeline.unet.model.inputs - }: + if ( + ov_pipeline.unet is not None + and "timestep_cond" not in {inputs.get_any_name() for inputs in ov_pipeline.unet.model.inputs} + ) or ( + ov_pipeline.transformer is not None + and "txt_ids" not in {inputs.get_any_name() for inputs in ov_pipeline.transformer.model.inputs} + ): expected_batch *= 2 self.assertEqual( ov_pipeline.batch_size,