Skip to content

Commit

Permalink
add safety_checker to save_pretrained
Browse files Browse the repository at this point in the history
  • Loading branch information
eaidova committed Nov 7, 2024
1 parent 9d729ae commit 29e5174
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions optimum/intel/openvino/modeling_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,8 @@ def _save_pretrained(self, save_directory: Union[str, Path]):
self.tokenizer_3.save_pretrained(save_directory / "tokenizer_3")
if self.feature_extractor is not None:
self.feature_extractor.save_pretrained(save_directory / "feature_extractor")
if getattr(self, "safety_checker", None) is not None:
self.safety_checker.save_pretrained(save_directory / "safety_checker")

self._save_openvino_config(save_directory)

Expand Down Expand Up @@ -422,12 +424,15 @@ def _from_pretrained(
module = getattr(pipelines, module_name)
else:
module = importlib.import_module(module_name)
logger.warn(module)
class_obj = getattr(module, module_class)
load_method = getattr(class_obj, "from_pretrained")
# Check if the module is in a subdirectory
if (model_save_path / name).is_dir():
logger.warn(name)
submodels[name] = load_method(model_save_path / name)
else:
logger.warn(name)
submodels[name] = load_method(model_save_path)

models = {
Expand All @@ -449,8 +454,10 @@ def _from_pretrained(
if (quantization_config is None or quantization_config.dataset is None) and not compile_only:
for name, path in models.items():
if name in kwargs:
logger.warn(name)
models[name] = kwargs.pop(name)
else:
logger.warn(name)
models[name] = cls.load_model(path, quantization_config) if path.is_file() else None
elif compile_only:
ov_config = kwargs.get("ov_config", {})
Expand Down

0 comments on commit 29e5174

Please sign in to comment.