merge main branch
PenghuiCheng committed Oct 18, 2023
2 parents 5c4d92b + 65b20f5 commit 2f9916b
Showing 7 changed files with 149 additions and 103 deletions.
1 change: 1 addition & 0 deletions optimum/intel/__init__.py
@@ -114,6 +114,7 @@
 else:
     _import_structure["neural_compressor"] = [
         "INCConfig",
+        "INCModel",
         "INCModelForCausalLM",
         "INCModelForMaskedLM",
         "INCModelForMultipleChoice",
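Note: with the entry above, INCModel joins the package's lazy _import_structure, so it becomes importable from the package root (assuming the matching TYPE_CHECKING branch was updated in the same commit). A quick smoke test:

# Minimal check that the new symbol is exported at the top level.
from optimum.intel import INCModel

print(INCModel.export_feature)  # "feature-extraction", per modeling_base.py below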
96 changes: 31 additions & 65 deletions optimum/intel/neural_compressor/modeling_base.py
@@ -31,6 +31,7 @@
     AutoModelForSequenceClassification,
     AutoModelForTokenClassification,
     AutoModelForVision2Seq,
+    GenerationMixin,
     PretrainedConfig,
     XLNetLMHeadModel,
 )
@@ -39,11 +40,8 @@
 from transformers.utils import is_ipex_available
 from transformers.utils.generic import ContextManagers

-from ...exporters import TasksManager
 from ...modeling_base import OptimizedModel
-from ..generation.modeling import jit_trace
 from ..utils.import_utils import _torch_version, is_torch_version
-from ..utils.modeling_utils import patch_decoder_attention_mask
 from .configuration import INCConfig
 from .utils import WEIGHTS_NAME

@@ -65,6 +63,7 @@

 class INCModel(OptimizedModel):
     auto_model_class = AutoModel
+    export_feature = "feature-extraction"
     base_model_prefix = "inc_model"

     def __init__(
@@ -76,12 +75,13 @@ def __init__(
         inc_config: Dict = None,
         **kwargs,
     ):
-        super().__init__(model=model, config=config)
-
+        super().__init__(model=model, config=config, **kwargs)
         self.inc_config = inc_config
         self._q_config = q_config
         self.model_save_dir = model_save_dir
-        self.is_quantized = q_config is not None
+        self._device = getattr(self.model, "device", None) or torch.device(
+            "cuda:0" if torch.cuda.is_available() else "cpu"
+        )

         if getattr(self.config, "backend", None) == "ipex":
             if not is_ipex_available():
@@ -109,9 +109,10 @@ def _from_pretrained(
         revision: Optional[Union[str, None]] = None,
         force_download: bool = False,
         cache_dir: Optional[str] = None,
-        file_name: Optional[str] = WEIGHTS_NAME,
+        file_name: str = WEIGHTS_NAME,
         local_files_only: bool = False,
         subfolder: str = "",
+        trust_remote_code: bool = False,
         **kwargs,
     ):
         model_name_or_path = kwargs.pop("model_name_or_path", None)
@@ -143,8 +144,10 @@
             if not is_torch_version("==", inc_config.torch_version):
                 msg = f"Quantized model was obtained with torch version {inc_config.torch_version} but {_torch_version} was found."
                 logger.warning(f"{msg}")
-        except Exception:
-            logger.info("Couldn't verify torch version.")
+        except EnvironmentError:
+            msg = (
+                f"Please check if torch quantization the model was obtained with is compatible with {_torch_version}."
+            )

         if getattr(config, "backend", None) == "ipex" or getattr(config, "torchscript", False):
             # NOTE: Will improve to use load function when Intel Neural Compressor next 2.1 release.
@@ -195,63 +198,26 @@ def forward(self, *args, **kwargs):

     def eval(self):
         self.model.eval()
+        return self

-    @classmethod
-    def _from_transformers(
-        cls,
-        model_id: str,
-        config: PretrainedConfig,
-        use_auth_token: Optional[Union[bool, str]] = None,
-        revision: Optional[str] = None,
-        force_download: bool = False,
-        cache_dir: Optional[str] = None,
-        subfolder: str = "",
-        local_files_only: bool = False,
-        use_cache: bool = True,
-        torch_dtype: Optional[Union[str, "torch.dtype"]] = None,
-        **kwargs,
-    ):
-        if is_torch_version("<", "2.0.0"):
-            raise ImportError("`torch>=2.0.0` is needed to trace your model")
-
-        task = cls.export_feature
-        kwargs.get("file_name", None)
-
-        model_kwargs = {
-            "revision": revision,
-            "use_auth_token": use_auth_token,
-            "cache_dir": cache_dir,
-            "subfolder": subfolder,
-            "local_files_only": local_files_only,
-            "force_download": force_download,
-            "torch_dtype": torch_dtype,
-        }
-
-        if config.torch_dtype == "int8" or config.torch_dtype == torch.int8:
-            raise ValueError("quantized model cannot be exported")
-
-        model = TasksManager.get_model_from_task(task, model_id, **model_kwargs)
-
-        if task == "text-generation":
-            model = patch_decoder_attention_mask(model)
-
-        traced_model = jit_trace(model, task, use_cache)
-        save_dir = TemporaryDirectory()
-        save_dir_path = Path(save_dir.name)
-        torch.jit.save(traced_model, save_dir_path / WEIGHTS_NAME)
-        config.torchscript = True
-
-        return cls._from_pretrained(
-            model_id=save_dir_path,
-            config=config,
-            use_cache=use_cache,
-            use_auth_token=use_auth_token,
-            revision=revision,
-            force_download=force_download,
-            cache_dir=cache_dir,
-            local_files_only=local_files_only,
-            **kwargs,
-        )
+    @property
+    def device(self) -> torch.device:
+        return self._device
+
+    def to(self, device: Union[torch.device, str]):
+        self._device = device if isinstance(device, torch.device) else torch.device(device)
+        self.model.to(self._device)
+        return self
+
+    def can_generate(self):
+        return isinstance(self.model, GenerationMixin)
+
+    def generate(self, *args, **kwargs):
+        if not self.can_generate():
+            raise TypeError(
+                f"The current model class {self.model.__class__} is not compatible with `.generate()`, as it doesn't have a language model head."
+            )
+        return self.model.generate(*args, **kwargs)


 class INCModelForQuestionAnswering(INCModel):
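Note: taken together, the replacement methods give INCModel an eager-model surface: device is backed by the new _device attribute, to() moves the wrapped model, and generate() is gated by can_generate(). A minimal sketch of the resulting behavior, with a hypothetical model id:

import torch
from optimum.intel import INCModelForSequenceClassification

# Hypothetical quantized checkpoint; any INC-saved model directory works.
model = INCModelForSequenceClassification.from_pretrained("my-org/my-inc-model")

model.to("cpu")  # updates _device and moves the wrapped self.model
assert model.device == torch.device("cpu")

# generate() is forwarded only when the wrapped model is a GenerationMixin
# (a TorchScript-traced model, for instance, is not); otherwise it raises TypeError.
if model.can_generate():
    print("wrapped model supports .generate()")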
14 changes: 12 additions & 2 deletions optimum/intel/neural_compressor/modeling_decoder.py
@@ -15,7 +15,7 @@
 import logging
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import Optional, Union
+from typing import Dict, Optional, Union

 from transformers import AutoModelForCausalLM, PretrainedConfig
 from transformers.file_utils import add_start_docstrings
@@ -39,15 +39,25 @@ class INCModelForCausalLM(INCModel, BaseModelForCausalLM):
     auto_model_class = AutoModelForCausalLM
     export_feature = "text-generation"
     forward = BaseModelForCausalLM.forward
+    generate = BaseModelForCausalLM.generate
+    can_generate = BaseModelForCausalLM.can_generate

     def __init__(
         self,
         model,
         config: PretrainedConfig = None,
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
+        q_config: Dict = None,
+        inc_config: Dict = None,
         use_cache: bool = True,
         **kwargs,
     ):
         super(INCModelForCausalLM, self).__init__(
-            model=model, config=config, model_save_dir=model_save_dir, use_cache=use_cache, **kwargs
+            model=model,
+            config=config,
+            model_save_dir=model_save_dir,
+            q_config=q_config,
+            inc_config=inc_config,
+            use_cache=use_cache,
+            **kwargs,
         )
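Note: the quantization metadata (q_config, inc_config) now travels through the causal-LM constructor explicitly rather than riding in **kwargs, and generate/can_generate are taken from BaseModelForCausalLM instead of the new INCModel fallbacks. End-user usage is unchanged; a sketch with a hypothetical model id:

from transformers import AutoTokenizer
from optimum.intel import INCModelForCausalLM

model = INCModelForCausalLM.from_pretrained("my-org/my-inc-quantized-gpt2")
tokenizer = AutoTokenizer.from_pretrained("my-org/my-inc-quantized-gpt2")

inputs = tokenizer("Quantized models still", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=8)  # BaseModelForCausalLM.generate
print(tokenizer.decode(outputs[0], skip_special_tokens=True))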
13 changes: 12 additions & 1 deletion optimum/intel/neural_compressor/quantization.py
@@ -52,7 +52,18 @@
     is_ipex_version,
     is_neural_compressor_version,
 )
-from .configuration import INCConfig, WeightOnlyQuantConfig
+from .configuration import INCConfig
+from .modeling_base import (  # noqa
+    INCModel,
+    INCModelForMaskedLM,
+    INCModelForMultipleChoice,
+    INCModelForQuestionAnswering,
+    INCModelForSeq2SeqLM,
+    INCModelForSequenceClassification,
+    INCModelForTokenClassification,
+    INCModelForVision2Seq,
+    INCModelForXLNetLM,
+)
 from .utils import INCDataLoader, _cfgs_to_fx_cfgs


1 change: 1 addition & 0 deletions optimum/intel/neural_compressor/utils.py
@@ -41,6 +41,7 @@
     "question-answering": "INCModelForQuestionAnswering",
     "multiple-choice": "INCModelForMultipleChoice",
     "stable-diffusion": "INCStableDiffusionPipeline",
+    "feature-extraction": "INCModel",
 }


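Note: the new "feature-extraction": "INCModel" entry completes the task-to-class map, so a task name can be resolved to its INC wrapper class. A sketch of the lookup pattern this mapping supports (the helper below is illustrative, not part of the codebase):

import importlib

def inc_class_for_task(task_map: dict, task: str):
    # Resolve e.g. "feature-extraction" -> the INCModel class object.
    module = importlib.import_module("optimum.intel.neural_compressor")
    return getattr(module, task_map[task])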
5 changes: 4 additions & 1 deletion optimum/intel/utils/import_utils.py
@@ -205,7 +205,10 @@ def is_torch_version(operation: str, version: str):
     """
     if not _torch_available:
         return False
-    return compare_versions(parse(_torch_version), operation, version)
+
+    import torch
+
+    return compare_versions(parse(parse(torch.__version__).base_version), operation, version)


 def is_ipex_version(operation: str, version: str):
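Note: the rewritten check re-reads torch.__version__ at call time and strips any local version segment (e.g. "+cpu", "+cu118") via base_version, so locally tagged torch builds compare cleanly against plain version strings. A small illustration with packaging:

from packaging.version import parse

installed = parse("2.0.1+cpu")  # the kind of value torch.__version__ can carry
print(installed == parse("2.0.1"))                      # False: local segment differs
print(parse(installed.base_version) == parse("2.0.1"))  # True: "+cpu" stripped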