Commit: openvino
IlyasMoutawwakil committed Dec 11, 2024
1 parent fb7a99e commit 3a21aa5
Showing 3 changed files with 44 additions and 122 deletions.
126 changes: 28 additions & 98 deletions optimum_benchmark/backends/openvino/backend.py
@@ -5,17 +5,12 @@

import torch
from hydra.utils import get_class
from openvino.runtime import properties
from optimum.intel.openvino import OVConfig as OVQuantizationConfig # naming conflict
from optimum.intel.openvino import OVQuantizer

from ...generators.dataset_generator import DatasetGenerator
from ...import_utils import is_accelerate_available, is_torch_distributed_available
from ...task_utils import TEXT_GENERATION_TASKS
from ..base import Backend
from ..transformers_utils import fast_weights_init
from .config import OVConfig
from .utils import TASKS_TO_MODEL_TYPES_TO_OVPIPELINE, TASKS_TO_OVMODEL
from .config import OVConfig as OVBackendConfig
from .utils import TASKS_OVPIPELINE, TASKS_TO_OVMODEL

if is_accelerate_available():
from accelerate import Accelerator
@@ -24,53 +19,26 @@
import torch.distributed


class OVBackend(Backend[OVConfig]):
class OVBackend(Backend[OVBackendConfig]):
NAME: str = "openvino"

def __init__(self, config: OVConfig) -> None:
def __init__(self, config: OVBackendConfig) -> None:
super().__init__(config)

if self.config.task in TASKS_TO_OVMODEL:
self.ovmodel_class = get_class(TASKS_TO_OVMODEL[self.config.task])
self.logger.info(f"\t+ Using OVModel class {self.ovmodel_class.__name__}")
elif self.config.task in TASKS_TO_MODEL_TYPES_TO_OVPIPELINE:
if self.config.model_type in TASKS_TO_MODEL_TYPES_TO_OVPIPELINE[self.config.task]:
self.ovmodel_class = get_class(
TASKS_TO_MODEL_TYPES_TO_OVPIPELINE[self.config.task][self.config.model_type]
)
self.logger.info(f"\t+ Using OVPipeline class {self.ovmodel_class.__name__}")
else:
raise NotImplementedError(
f"OVBackend does not support model {self.config.model_type} for task {self.config.task}"
)
elif self.config.task in TASKS_OVPIPELINE:
self.ovmodel_class = get_class(TASKS_OVPIPELINE[self.config.task])
self.logger.info(f"\t+ Using OVDiffusionPipeline class {self.ovmodel_class.__name__}")
else:
raise NotImplementedError(f"OVBackend does not support task {self.config.task}")

if self.config.inter_op_num_threads is not None:
self.logger.info(f"\t+ Setting inter_op_num_threads to {self.config.inter_op_num_threads}")
self.config.openvino_config[properties.inference_num_threads()] = self.config.inter_op_num_threads

def load(self) -> None:
self.logger.info("\t+ Creating backend temporary directory")
self.tmpdir = TemporaryDirectory()

if self.config.quantization:
if self.config.no_weights:
self.logger.info("\t+ Creating no weights AutoModel")
self.create_no_weights_model()
self.logger.info("\t+ Loading no weights AutoModel")
self._load_automodel_with_no_weights()
else:
self.logger.info("\t+ Loading pretrained AutoModel")
self._load_automodel_from_pretrained()
self.logger.info("\t+ Applying post-training quantization")
self.quantize_automodel()
original_model, self.config.model = self.config.model, self.quantized_model
original_export, self.config.export = self.config.export, False
self.logger.info("\t+ Loading quantized OVModel")
self._load_ovmodel_from_pretrained()
self.config.model, self.config.export = original_model, original_export
elif self.config.no_weights:
if self.config.no_weights:
self.logger.info("\t+ Creating no weights OVModel")
self.create_no_weights_model()
self.logger.info("\t+ Loading no weights OVModel")
@@ -85,9 +53,6 @@ def load(self) -> None:
for key, value in self.model_shapes.items()
if key in inspect.getfullargspec(self.pretrained_model.reshape).args
}
if ("sequence_length" in static_shapes) and ("height" in static_shapes) and ("width" in static_shapes):
# for vision models, sequence_length is the number of channels
static_shapes["sequence_length"] = self.model_shapes.get("num_channels")

self.logger.info(f"\t+ Reshaping model with static shapes: {static_shapes}")
self.pretrained_model.reshape(**static_shapes)
@@ -102,26 +67,9 @@ def load(self) -> None:

self.tmpdir.cleanup()

def _load_automodel_from_pretrained(self) -> None:
self.pretrained_model = self.automodel_loader.from_pretrained(self.config.model, **self.config.model_kwargs)

def _load_automodel_with_no_weights(self) -> None:
original_model, self.config.model = self.config.model, self.no_weights_model

with fast_weights_init():
self._load_automodel_from_pretrained()

self.logger.info("\t+ Tying model weights")
self.pretrained_model.tie_weights()

self.config.model = original_model

def _load_ovmodel_from_pretrained(self) -> None:
self.pretrained_model = self.ovmodel_class.from_pretrained(
self.config.model,
export=self.config.export,
ov_config=self.config.openvino_config,
device=self.config.device,
**self.config.model_kwargs,
**self.ovmodel_kwargs,
)
@@ -135,61 +83,36 @@ def _load_ovmodel_with_no_weights(self) -> None:
self.config.export = original_export
self.config.model = original_model

def quantize_automodel(self) -> None:
self.logger.info("\t+ Attempting quantization")
self.quantized_model = f"{self.tmpdir.name}/quantized_model"
self.logger.info("\t+ Processing quantization config")
quantization_config = OVQuantizationConfig(**self.config.quantization_config)
self.logger.info("\t+ Creating quantizer")
quantizer = OVQuantizer.from_pretrained(self.pretrained_model, task=self.config.task, seed=self.config.seed)

if self.config.calibration:
self.logger.info("\t+ Generating calibration dataset")
dataset_shapes = {"dataset_size": 1, "sequence_length": 1, **self.model_shapes}
calibration_dataset = DatasetGenerator(
task=self.config.task, dataset_shapes=dataset_shapes, model_shapes=self.model_shapes
)()
columns_to_be_removed = list(set(calibration_dataset.column_names) - set(quantizer._export_input_names))
calibration_dataset = calibration_dataset.remove_columns(columns_to_be_removed)
else:
calibration_dataset = None

self.logger.info("\t+ Quantizing model")
quantizer.quantize(
save_directory=self.quantized_model,
quantization_config=quantization_config,
calibration_dataset=calibration_dataset,
# TODO: add support for these (maybe)
remove_unused_columns=True,
data_collator=None,
weights_only=False,
file_name=None,
batch_size=1,
)

@property
def ovmodel_kwargs(self) -> Dict[str, Any]:
kwargs = {}

if self.config.task in TEXT_GENERATION_TASKS:
if self.config.export is not None:
kwargs["export"] = self.config.export

if self.config.use_cache is not None:
kwargs["use_cache"] = self.config.use_cache

if self.config.use_merged is not None:
kwargs["use_merged"] = self.config.use_merged

if self.config.load_in_8bit is not None:
kwargs["load_in_8bit"] = self.config.load_in_8bit

if self.config.load_in_4bit is not None:
kwargs["load_in_4bit"] = self.config.load_in_4bit

return kwargs

@property
def split_between_processes(self) -> bool:
return is_torch_distributed_available() and torch.distributed.is_initialized()

def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
def prepare_inputs_before_load(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
if self.split_between_processes:
with Accelerator().split_between_processes(inputs=inputs, apply_padding=False) as process_inputs:
inputs = process_inputs

for key in list(inputs.keys()):
if hasattr(self.pretrained_model, "input_names") and key not in self.pretrained_model.input_names:
inputs.pop(key)

if "input_ids" in inputs:
self.model_shapes.update(dict(zip(["batch_size", "sequence_length"], inputs["input_ids"].shape)))

@@ -200,6 +123,13 @@ def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:

return inputs

def prepare_inputs_after_load(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
for key in list(inputs.keys()):
if hasattr(self.pretrained_model, "input_names") and key not in self.pretrained_model.input_names:
inputs.pop(key)

return inputs

def forward(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict:
return self.pretrained_model.forward(**inputs, **kwargs)

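Note (not part of the commit): the renamed hooks split input preparation into a phase that runs before the model is loaded (recording batch_size and sequence_length so reshape()/compile() can build static shapes) and a phase that runs after it (dropping keys the exported model does not accept). The self-contained sketch below mirrors that logic with plain dicts and a hypothetical DummyOVModel, ignoring the distributed split path; it is an illustration of the behavior in the diff above, not code from the repository.

import torch

def prepare_inputs_before_load(inputs, model_shapes):
    # record dynamic axes before load so static reshaping can use them later
    if "input_ids" in inputs:
        model_shapes.update(dict(zip(["batch_size", "sequence_length"], inputs["input_ids"].shape)))
    return inputs

def prepare_inputs_after_load(inputs, model):
    # drop keys the exported model does not declare as inputs
    for key in list(inputs.keys()):
        if hasattr(model, "input_names") and key not in model.input_names:
            inputs.pop(key)
    return inputs

class DummyOVModel:
    # stand-in for a loaded OVModel; only input_names matters here
    input_names = {"input_ids", "attention_mask"}

model_shapes = {}
inputs = {
    "input_ids": torch.ones(2, 16, dtype=torch.long),
    "attention_mask": torch.ones(2, 16, dtype=torch.long),
    "token_type_ids": torch.zeros(2, 16, dtype=torch.long),
}
inputs = prepare_inputs_before_load(inputs, model_shapes)   # model_shapes -> {"batch_size": 2, "sequence_length": 16}
inputs = prepare_inputs_after_load(inputs, DummyOVModel())  # "token_type_ids" is removed
print(model_shapes, sorted(inputs))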
30 changes: 12 additions & 18 deletions optimum_benchmark/backends/openvino/config.py
@@ -11,28 +11,22 @@ class OVConfig(BackendConfig):
version: Optional[str] = openvino_version()
_target_: str = "optimum_benchmark.backends.openvino.backend.OVBackend"

# load options
no_weights: bool = False

# export options
export: bool = True
use_cache: bool = True
use_merged: bool = False

# openvino config
openvino_config: Dict[str, Any] = field(default_factory=dict)
# ovmodel kwargs
export: Optional[bool] = None
use_cache: Optional[bool] = None
use_merged: Optional[bool] = None
load_in_8bit: Optional[bool] = None
load_in_4bit: Optional[bool] = None

# compilation options
half: bool = False
compile: bool = False
reshape: bool = False

# quantization options
quantization: bool = False
quantization_config: Dict[str, Any] = field(default_factory=dict)

# calibration options
calibration: bool = False
calibration_config: Dict[str, Any] = field(default_factory=dict)
# openvino config
ov_config: Dict[str, Any] = field(default_factory=dict)

def __post_init__(self):
super().__post_init__()
@@ -42,7 +36,7 @@ def __post_init__(self):
raise ValueError(f"OVBackend only supports CPU devices, got {self.device}")

if self.intra_op_num_threads is not None:
raise NotImplementedError("OVBackend does not support intra_op_num_threads")
raise NotImplementedError("OVBackend does not support intra_op_num_threads. Please use the ov_config")

if self.quantization and not self.calibration:
raise ValueError("OpenVINO quantization requires enabling calibration.")
if self.inter_op_num_threads is not None:
raise NotImplementedError("OVBackend does not support inter_op_num_threads. Please use the ov_config")
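Note (not part of the commit): a hedged sketch of how the slimmed-down OVConfig might be filled in after this change. It assumes the inherited BackendConfig fields (task, model, device) are accepted as keyword arguments and that the keys placed in ov_config are standard OpenVINO properties (INFERENCE_NUM_THREADS standing in for the removed inter_op_num_threads option); the model and values are illustrative only.

from optimum_benchmark.backends.openvino.config import OVConfig

config = OVConfig(
    task="text-generation",
    model="gpt2",
    device="cpu",
    export=True,        # forwarded to from_pretrained only because it is not None
    use_cache=True,
    load_in_8bit=None,  # left unset, so it is omitted from the ovmodel kwargs
    ov_config={"INFERENCE_NUM_THREADS": 8},  # thread control now goes through ov_config
)

Since every ovmodel kwarg now defaults to None, only the options a run explicitly sets reach the from_pretrained call, leaving the defaults of optimum-intel itself in charge otherwise.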
10 changes: 4 additions & 6 deletions optimum_benchmark/backends/openvino/utils.py
@@ -10,10 +10,8 @@
"audio-classification": "optimum.intel.openvino.OVModelForAudioClassification",
"pix2struct": "optimum.intel.openvino.OVModelForPix2Struct",
}
TASKS_TO_MODEL_TYPES_TO_OVPIPELINE = {
"text-to-image": {
"lcm": "optimum.intel.openvino.OVLatentConsistencyModelPipeline",
"stable-diffusion": "optimum.intel.openvino.OVStableDiffusionPipeline",
"stable-diffusion-xl": "optimum.intel.openvino.OVStableDiffusionXLPipeline",
},
TASKS_OVPIPELINE = {
"inpainting": "optimum.intel.openvino.OVPipelineForInpainting",
"text-to-image": "optimum.intel.openvino.OVPipelineForText2Image",
"image-to-image": "optimum.intel.openvino.OVPipelineForImage2Image",
}
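Note (not part of the commit): with the per-model-type table flattened into TASKS_OVPIPELINE, class resolution reduces to two dictionary lookups. The sketch below mirrors the dispatch in OVBackend.__init__ from the diff above; it assumes hydra and an optimum-intel version that ships the OVPipelineFor* classes are installed.

from hydra.utils import get_class

from optimum_benchmark.backends.openvino.utils import TASKS_OVPIPELINE, TASKS_TO_OVMODEL

def resolve_ovmodel_class(task: str):
    # same order as OVBackend.__init__: OVModel tasks first, then diffusion pipelines
    if task in TASKS_TO_OVMODEL:
        return get_class(TASKS_TO_OVMODEL[task])
    if task in TASKS_OVPIPELINE:
        return get_class(TASKS_OVPIPELINE[task])
    raise NotImplementedError(f"OVBackend does not support task {task}")

print(resolve_ovmodel_class("text-to-image").__name__)  # OVPipelineForText2Image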
