Skip to content

Commit

Permalink
add saving safety_checker (#990)
Browse files Browse the repository at this point in the history
* add saving safety_checker during conversion

* add safety_checker to save_pretrained

* add test

* Update modeling_diffusion.py
  • Loading branch information
eaidova authored Nov 12, 2024
1 parent b3cbc95 commit 790244d
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 0 deletions.
3 changes: 3 additions & 0 deletions optimum/exporters/openvino/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,6 +740,9 @@ def export_from_model(
tokenizer_3 = getattr(model, "tokenizer_3", None)
if tokenizer_3 is not None:
tokenizer_3.save_pretrained(output.joinpath("tokenizer_3"))
safety_checker = getattr(model, "safety_checker", None)
if safety_checker is not None:
safety_checker.save_pretrained(output.joinpath("safety_checker"))

model.save_config(output)

Expand Down
2 changes: 2 additions & 0 deletions optimum/intel/openvino/modeling_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,8 @@ def _save_pretrained(self, save_directory: Union[str, Path]):
self.tokenizer_3.save_pretrained(save_directory / "tokenizer_3")
if self.feature_extractor is not None:
self.feature_extractor.save_pretrained(save_directory / "feature_extractor")
if getattr(self, "safety_checker", None) is not None:
self.safety_checker.save_pretrained(save_directory / "safety_checker")

self._save_openvino_config(save_directory)

Expand Down
27 changes: 27 additions & 0 deletions tests/openvino/test_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import unittest
from pathlib import Path

import numpy as np
import pytest
Expand All @@ -35,6 +36,7 @@
OVPipelineForInpainting,
OVPipelineForText2Image,
)
from optimum.intel.openvino.utils import TemporaryDirectory
from optimum.intel.utils.import_utils import is_transformers_version
from optimum.utils.testing_utils import require_diffusers

Expand Down Expand Up @@ -309,6 +311,31 @@ def test_safety_checker(self, model_arch: str):

np.testing.assert_allclose(ov_images, diffusers_images, atol=1e-4, rtol=1e-2)

@require_diffusers
def test_load_and_save_pipeline_with_safety_checker(self):
model_id = "katuni4ka/tiny-random-stable-diffusion-with-safety-checker"
ov_pipeline = self.OVMODEL_CLASS.from_pretrained(model_id)
self.assertTrue(ov_pipeline.safety_checker is not None)
self.assertIsInstance(ov_pipeline.safety_checker, StableDiffusionSafetyChecker)
with TemporaryDirectory() as tmpdirname:
ov_pipeline.save_pretrained(tmpdirname)
for subdir in [
"text_encoder",
"tokenizer",
"unet",
"vae_encoder",
"vae_decoder",
"scheduler",
"feature_extractor",
]:
subdir_path = Path(tmpdirname) / subdir
self.assertTrue(subdir_path.is_dir())
loaded_pipeline = self.OVMODEL_CLASS.from_pretrained(tmpdirname)
self.assertTrue(loaded_pipeline.safety_checker is not None)
self.assertIsInstance(loaded_pipeline.safety_checker, StableDiffusionSafetyChecker)
del loaded_pipeline
del ov_pipeline

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_height_width_properties(self, model_arch: str):
batch_size, height, width, num_images_per_prompt = 2, 128, 64, 4
Expand Down

0 comments on commit 790244d

Please sign in to comment.