From 5ee162820f8b50376c43c6e16333aefa1289f0de Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Wed, 18 Oct 2023 19:00:20 +0200 Subject: [PATCH 01/29] Enable openvino inference for gpt big code models --- optimum/intel/openvino/modeling_base.py | 60 ++++-- optimum/intel/openvino/modeling_decoder.py | 222 +++++++++++++++------ optimum/intel/utils/modeling_utils.py | 3 + setup.py | 2 +- tests/openvino/test_modeling.py | 2 + tests/openvino/utils_tests.py | 5 +- 6 files changed, 214 insertions(+), 80 deletions(-) diff --git a/optimum/intel/openvino/modeling_base.py b/optimum/intel/openvino/modeling_base.py index 58eb2163d0..321f50f570 100644 --- a/optimum/intel/openvino/modeling_base.py +++ b/optimum/intel/openvino/modeling_base.py @@ -153,6 +153,7 @@ def _from_pretrained( force_download: bool = False, cache_dir: Optional[str] = None, file_name: Optional[str] = None, + subfolder: str = "", from_onnx: bool = False, local_files_only: bool = False, load_in_8bit: bool = False, @@ -184,38 +185,59 @@ def _from_pretrained( local_files_only(`bool`, *optional*, defaults to `False`): Whether or not to only look at local files (i.e., do not try to download the model). """ + + model_path = Path(model_id) default_file_name = ONNX_WEIGHTS_NAME if from_onnx else OV_XML_FILE_NAME file_name = file_name or default_file_name - # Load the model from local directory - if os.path.isdir(model_id): - file_name = os.path.join(model_id, file_name) - model_save_dir = model_id - # Download the model from the hub - else: - model_file_names = [file_name] - # If not ONNX then OpenVINO IR + model_cache_path = cls._cached_file( + model_path=model_path, + use_auth_token=use_auth_token, + revision=revision, + force_download=force_download, + cache_dir=cache_dir, + file_name=file_name, + subfolder=subfolder, + local_files_only=local_files_only, + ) + model = cls.load_model(model_cache_path, load_in_8bit=load_in_8bit) + return cls(model, config=config, model_save_dir=model_cache_path.parent, **kwargs) - if not from_onnx: - model_file_names.append(file_name.replace(".xml", ".bin")) - file_names = [] + @staticmethod + def _cached_file( + model_path: Union[Path, str], + use_auth_token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, + force_download: bool = False, + cache_dir: Optional[str] = None, + file_name: Optional[str] = None, + subfolder: str = "", + local_files_only: bool = False, + ): + # locates a file in a local folder and repo, downloads and cache it if necessary. 
+        model_path = Path(model_path)
+        if model_path.is_dir():
+            model_cache_path = model_path / file_name
+        else:
+            file_name = Path(file_name)
+            if file_name.suffix != ".onnx":
+                model_file_names = [file_name.with_suffix(".bin"), file_name]
+            else:
+                model_file_names = [file_name]
             for file_name in model_file_names:
                 model_cache_path = hf_hub_download(
-                    repo_id=model_id,
-                    filename=file_name,
+                    repo_id=model_path.as_posix(),
+                    filename=file_name.as_posix(),
+                    subfolder=subfolder,
                     use_auth_token=use_auth_token,
                     revision=revision,
                     cache_dir=cache_dir,
                     force_download=force_download,
                     local_files_only=local_files_only,
                 )
-                file_names.append(model_cache_path)
-            model_save_dir = Path(model_cache_path).parent
-            file_name = file_names[0]
-
-        model = cls.load_model(file_name, load_in_8bit=load_in_8bit)
+                model_cache_path = Path(model_cache_path)
-        return cls(model, config=config, model_save_dir=model_save_dir, **kwargs)
+        return model_cache_path
     @classmethod
     def _from_transformers(
diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py
index 68d737fe74..96df4a2b84 100644
--- a/optimum/intel/openvino/modeling_decoder.py
+++ b/optimum/intel/openvino/modeling_decoder.py
@@ -31,6 +31,7 @@
 from ...exporters.openvino import main_export
 from ..utils.import_utils import is_transformers_version
+from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS
 from .modeling import _TOKENIZER_FOR_DOC, INPUTS_DOCSTRING, MODEL_START_DOCSTRING, OVModel
 from .utils import OV_XML_FILE_NAME, STR_TO_OV_TYPE
@@ -78,8 +79,9 @@
     "bloom",
     "codegen",
     "gpt2",
-    "gpt_neo",
-    "gpt_neox",
+    "gpt-bigcode",
+    "gpt-neo",
+    "gpt-neox",
     "llama",
     "marian",
     "opt",
@@ -213,7 +215,7 @@ def _from_transformers(
         load_in_8bit: bool = False,
         **kwargs,
     ):
-        if config.model_type not in _SUPPORTED_ARCHITECTURES:
+        if config.model_type.replace("_", "-") not in _SUPPORTED_ARCHITECTURES:
             logger.warning(
                 f"This architecture : {config.model_type} was not validated, only :{', '.join(_SUPPORTED_ARCHITECTURES)} architectures were "
                 "validated, use at your own risk."
@@ -315,62 +317,61 @@ def forward( **kwargs, ) -> CausalLMOutputWithPast: self.compile() + inputs = {} if self.use_cache and past_key_values is not None: input_ids = input_ids[:, -1:] - inputs = {} if past_key_values is not None: - if self._pkv_precision == Type.bf16: - # numpy does not support bf16, pretending f16, should change to bf16 - past_key_values = tuple( - Tensor(past_key_value, past_key_value.shape, Type.bf16) - for pkv_per_layer in past_key_values - for past_key_value in pkv_per_layer - ) - else: - # Flatten the past_key_values - past_key_values = tuple( - past_key_value for pkv_per_layer in past_key_values for past_key_value in pkv_per_layer - ) + if self.config.model_type not in MULTI_QUERY_ATTN_MODELS: + if self._pkv_precision == Type.bf16: + # numpy does not support bf16, pretending f16, should change to bf16 + past_key_values = tuple( + Tensor(past_key_value, past_key_value.shape, Type.bf16) + for pkv_per_layer in past_key_values + for past_key_value in pkv_per_layer + ) + else: + # Flatten the past_key_values + past_key_values = tuple( + past_key_value for pkv_per_layer in past_key_values for past_key_value in pkv_per_layer + ) # Add the past_key_values to the decoder inputs inputs = dict(zip(self.key_value_input_names, past_key_values)) # Create empty past_key_values for decoder_with_past first generation step elif self.use_cache: - shape_input_ids = input_ids.shape - num_attention_heads = ( - self.normalized_config.num_attention_heads if self.config.model_type == "bloom" else 1 - ) + batch_size = input_ids.shape[0] + if self.config.model_type == "bloom": + batch_size *= self.normalized_config.num_attention_heads + for input_name in self.key_value_input_names: model_inputs = self.model.input(input_name) shape = model_inputs.get_partial_shape() - shape[0] = shape_input_ids[0] * num_attention_heads + shape[0] = batch_size if shape[2].is_dynamic: shape[2] = 0 - if shape[1].is_dynamic: + else: shape[1] = 0 inputs[input_name] = Tensor(model_inputs.get_element_type(), shape.get_shape()) inputs["input_ids"] = np.array(input_ids) - # Add the attention_mask inputs when needed - if "attention_mask" in self.input_names and attention_mask is not None: + if attention_mask is not None: inputs["attention_mask"] = np.array(attention_mask) # Run inference - self.request.start_async(inputs, shared_memory=True) - self.request.wait() - - logits = torch.from_numpy(self.request.get_tensor("logits").data).to(self.device) + results = self.request.infer(inputs, share_inputs=True, share_outputs=True) + logits = torch.from_numpy(results["logits"]).to(self.device) if self.use_cache: # Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the self-attention layer) - past_key_values = tuple(self.request.get_tensor(key).data for key in self.key_value_output_names) - # Tuple of tuple of length `n_layers`, with each tuple of length equal to 2 (k/v of self-attention) - past_key_values = tuple( - past_key_values[i : i + self.num_pkv] for i in range(0, len(past_key_values), self.num_pkv) - ) + past_key_values = tuple(results[key] for key in self.key_value_output_names) + if self.config.model_type not in MULTI_QUERY_ATTN_MODELS: + # Tuple of tuple of length `n_layers`, with each tuple of length equal to 2 (k/v of self-attention) + past_key_values = tuple( + past_key_values[i : i + self.num_pkv] for i in range(0, len(past_key_values), self.num_pkv) + ) else: past_key_values = None @@ -378,41 +379,115 @@ def forward( # Adapted from 
transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): - past_key_values = past_key_values or kwargs.get("past", None) - - # `past_key_values` may be in the stardard format (e.g. in contrastive search), converts to bloom's format if needed - if past_key_values is not None and self.config.model_type == "bloom": - if past_key_values[0][0].shape[0] == input_ids.shape[0]: - past_key_values = self._convert_to_bloom_cache(past_key_values) + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + attention_mask = kwargs.get("attention_mask", None) + use_cache = kwargs.get("use_cache", None) + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -1].unsqueeze(-1) return { "input_ids": input_ids, "past_key_values": past_key_values, - "use_cache": self.use_cache, - "position_ids": None, - "attention_mask": kwargs.get("attention_mask", None), - "token_type_ids": None, + "use_cache": use_cache, + "position_ids": position_ids, + "attention_mask": attention_mask, } + # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache + @staticmethod def _reorder_cache( - self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor ) -> Tuple[Tuple[torch.Tensor]]: """ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct beam_idx at every generation step. 
""" - - if self.config.model_type == "bloom": - return self._reorder_cache_bloom(past_key_values, beam_idx) - - # from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache return tuple( tuple(np.take(past_state, beam_idx, 0) for past_state in layer_past) for layer_past in past_key_values ) - # Copied from transformers.models.bloom.modeling_bloom.BloomForCausalLM._reorder_cache - def _reorder_cache_bloom( + def can_generate(self): + """Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate.""" + return True + + @classmethod + def _from_pretrained( + cls, + model_id: Union[str, Path], + config: PretrainedConfig, + use_auth_token: Optional[Union[bool, str, None]] = None, + revision: Optional[Union[str, None]] = None, + force_download: bool = False, + cache_dir: Optional[str] = None, + file_name: Optional[str] = None, + subfolder: str = "", + from_onnx: bool = False, + local_files_only: bool = False, + load_in_8bit: bool = False, + **kwargs, + ): + model_path = Path(model_id) + + default_file_name = ONNX_WEIGHTS_NAME if from_onnx else OV_XML_FILE_NAME + file_name = file_name or default_file_name + + model_cache_path = cls._cached_file( + model_path=model_path, + use_auth_token=use_auth_token, + revision=revision, + force_download=force_download, + cache_dir=cache_dir, + file_name=file_name, + subfolder=subfolder, + local_files_only=local_files_only, + ) + + model = cls.load_model(model_cache_path, load_in_8bit=load_in_8bit) + + model_type = config.model_type.replace("_", "-") + if model_type == "bloom": + init_cls = OVBloomForCausalLM + elif model_type == "mpt": + init_cls = OVMPTForCausalLM + elif model_type == "opt": + init_cls = OVOPTForCausalLM + elif model_type == "gpt-bigcode": + init_cls = OVGPTBigCodeForCausalLM + else: + init_cls = OVModelForCausalLM + + return init_cls(model=model, config=config, model_save_dir=model_cache_path.parent, **kwargs) + + +class OVBloomForCausalLM(OVModelForCausalLM): + # Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM.prepare_inputs_for_generation + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): + attention_mask = kwargs.get("attention_mask", None) + use_cache = kwargs.get("use_cache", None) + + # only last token for input_ids if past is not None + if past_key_values: + # the cache may be in the stardard format (e.g. in contrastive search), convert to bloom's format if needed + if past_key_values[0][0].shape[0] == input_ids.shape[0]: + past_key_values = self._convert_to_bloom_cache(past_key_values) + + return { + "input_ids": input_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "position_ids": None, + "attention_mask": attention_mask, + } + + # Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM._reorder_cache + def _reorder_cache( self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor ) -> Tuple[Tuple[torch.Tensor]]: """ @@ -421,7 +496,6 @@ def _reorder_cache_bloom( This is required to match `past_key_values` with the correct beam_idx at every generation step. """ standardized_past = self._convert_to_standard_cache(past_key_values, batch_size=len(beam_idx)) - reordered_past = tuple( ( np.take(layer_past[0], beam_idx, 0), @@ -456,9 +530,6 @@ def _convert_to_standard_cache( """ Standardizes the format of the cache so as to match most implementations, i.e. 
to tuple(tuple([batch_size, num_heads, ...])) """ - if self.config.model_type != "bloom": - return past_key_value - batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape num_heads = batch_size_times_num_heads // batch_size # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length] @@ -471,6 +542,41 @@ def _convert_to_standard_cache( for layer_past in past_key_value ) - def can_generate(self): - """Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate.""" - return True + +class OVOPTForCausalLM(OVModelForCausalLM): + # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): + attention_mask = kwargs.get("attention_mask", None) + use_cache = kwargs.get("use_cache", None) + + return { + "input_ids": input_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "position_ids": None, + "attention_mask": attention_mask, + } + + +class OVMPTForCausalLM(OVModelForCausalLM): + # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): + attention_mask = kwargs.get("attention_mask", None) + use_cache = kwargs.get("use_cache", None) + + return { + "input_ids": input_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "position_ids": None, + "attention_mask": attention_mask, + } + + +class OVGPTBigCodeForCausalLM(OVModelForCausalLM): + # Adapted from transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM._reorder_cache + @staticmethod + def _reorder_cache( + past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: + return tuple(np.take(layer_past, beam_idx, 0) for layer_past in past_key_values) diff --git a/optimum/intel/utils/modeling_utils.py b/optimum/intel/utils/modeling_utils.py index 17abf1059e..6b6cd30999 100644 --- a/optimum/intel/utils/modeling_utils.py +++ b/optimum/intel/utils/modeling_utils.py @@ -18,6 +18,9 @@ from transformers.modeling_utils import PreTrainedModel +MULTI_QUERY_ATTN_MODELS = {"falcon", "gpt_bigcode"} + + # Modified from transformers.models.bloom.modeling_bloom._make_causal_mask def _make_causal_mask( input_ids_shape: torch.Size, diff --git a/setup.py b/setup.py index 5fc16692ee..6d81b98b2a 100644 --- a/setup.py +++ b/setup.py @@ -44,7 +44,7 @@ ], "openvino": ["openvino>=2023.1.0", "onnx", "onnxruntime"], "nncf": ["nncf>=2.6.0"], - "ipex": ["transformers<4.32.0", "intel-extension-for-pytorch", "onnx", "torch<2.1.0"], + "ipex": ["transformers<4.32.0", "intel-extension-for-pytorch", "onnx"], "diffusers": ["diffusers"], "quality": QUALITY_REQUIRE, "tests": TESTS_REQUIRE, diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py index c381341f83..895c9b42f7 100644 --- a/tests/openvino/test_modeling.py +++ b/tests/openvino/test_modeling.py @@ -449,6 +449,7 @@ def test_pipeline(self, model_arch): class OVModelForCausalLMIntegrationTest(unittest.TestCase): SUPPORTED_ARCHITECTURES = ( "bart", + "big_code", "blenderbot", "blenderbot-small", "bloom", @@ -459,6 +460,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase): "gpt_neox", "llama", "marian", + # "mistral", "mpt", "opt", "pegasus", diff --git a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py index 
c2feb4d264..2fa77052eb 100644 --- a/tests/openvino/utils_tests.py +++ b/tests/openvino/utils_tests.py @@ -39,7 +39,7 @@ "distilbert": "hf-internal-testing/tiny-random-distilbert", "electra": "hf-internal-testing/tiny-random-electra", "flaubert": "hf-internal-testing/tiny-random-flaubert", - # "gpt_bigcode": "bigcode/tiny_starcoder_py", + "gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel", "gpt2": "hf-internal-testing/tiny-random-gpt2", "gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel", "gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM", @@ -51,8 +51,9 @@ "llama": "fxmarty/tiny-llama-fast-tokenizer", "m2m_100": "hf-internal-testing/tiny-random-m2m_100", "opt": "hf-internal-testing/tiny-random-OPTModel", - "marian": "sshleifer/tiny-marian-en-de", # hf-internal-testing ones are broken + "marian": "sshleifer/tiny-marian-en-de", "mbart": "hf-internal-testing/tiny-random-mbart", + "mistral": "echarlaix/tiny-random-mistral", "mobilebert": "hf-internal-testing/tiny-random-MobileBertModel", "mobilenet_v1": "google/mobilenet_v1_0.75_192", "mobilenet_v2": "hf-internal-testing/tiny-random-MobileNetV2Model", From fec765559f49e12d6e3d9185e73d1707b570b564 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Mon, 23 Oct 2023 16:02:18 +0200 Subject: [PATCH 02/29] fix --- optimum/intel/openvino/modeling_decoder.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py index 749963553c..0a5747336a 100644 --- a/optimum/intel/openvino/modeling_decoder.py +++ b/optimum/intel/openvino/modeling_decoder.py @@ -357,8 +357,8 @@ def forward( inputs = {} past_len = 0 if past_key_values is not None: - past_len = past_key_values[0][1].shape[-2] if self.config.model_type not in MULTI_QUERY_ATTN_MODELS: + past_len = past_key_values[0][1].shape[-2] if self._pkv_precision == Type.bf16: # numpy does not support bf16, pretending f16, should change to bf16 past_key_values = tuple( @@ -371,6 +371,8 @@ def forward( past_key_values = tuple( past_key_value for pkv_per_layer in past_key_values for past_key_value in pkv_per_layer ) + else: + past_len = past_key_values[0].shape[-2] # Add the past_key_values to the decoder inputs inputs = dict(zip(self.key_value_input_names, past_key_values)) From 327533e4bf08d88bebcfcc66d3d2872236213ae8 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Mon, 23 Oct 2023 16:04:31 +0200 Subject: [PATCH 03/29] format --- optimum/intel/openvino/modeling_decoder.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py index 0a5747336a..8af37de8a4 100644 --- a/optimum/intel/openvino/modeling_decoder.py +++ b/optimum/intel/openvino/modeling_decoder.py @@ -33,7 +33,7 @@ from ..utils.import_utils import is_transformers_version from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS from .modeling import _TOKENIZER_FOR_DOC, INPUTS_DOCSTRING, MODEL_START_DOCSTRING, OVModel -from .utils import OV_XML_FILE_NAME, STR_TO_OV_TYPE +from .utils import ONNX_WEIGHTS_NAME, OV_XML_FILE_NAME, STR_TO_OV_TYPE if is_transformers_version("<", "4.25.0"): @@ -477,7 +477,6 @@ def _from_pretrained( **kwargs, ): model_path = Path(model_id) - default_file_name = ONNX_WEIGHTS_NAME if from_onnx else OV_XML_FILE_NAME file_name = file_name or default_file_name From 36d482e9744331db9672511bd345717f9f4fb988 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Mon, 23 Oct 2023 19:38:49 +0200 Subject: 
[PATCH 04/29] fix --- tests/openvino/test_modeling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py index c7968d6e5a..bf1a007844 100644 --- a/tests/openvino/test_modeling.py +++ b/tests/openvino/test_modeling.py @@ -449,7 +449,7 @@ def test_pipeline(self, model_arch): class OVModelForCausalLMIntegrationTest(unittest.TestCase): SUPPORTED_ARCHITECTURES = ( "bart", - "big_code", + "gpt_bigcode", "blenderbot", "blenderbot-small", "bloom", From f8ba216b153d3ca9f9dea7e1f782622dc5b8e285 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Mon, 23 Oct 2023 19:41:48 +0200 Subject: [PATCH 05/29] fix --- optimum/intel/openvino/modeling_decoder.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py index 8af37de8a4..0e018f9f62 100644 --- a/optimum/intel/openvino/modeling_decoder.py +++ b/optimum/intel/openvino/modeling_decoder.py @@ -404,12 +404,13 @@ def forward( ) # Run inference - results = self.request.infer(inputs, share_inputs=True, share_outputs=True) - logits = torch.from_numpy(results["logits"]).to(self.device) + self.request.start_async(inputs, shared_memory=True) + self.request.wait() + logits = torch.from_numpy(self.request.get_tensor("logits").data).to(self.device) if self.use_cache: # Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the self-attention layer) - past_key_values = tuple(results[key] for key in self.key_value_output_names) + past_key_values = tuple(self.request.get_tensor(key).data for key in self.key_value_output_names) if self.config.model_type not in MULTI_QUERY_ATTN_MODELS: # Tuple of tuple of length `n_layers`, with each tuple of length equal to 2 (k/v of self-attention) past_key_values = tuple( From f52960bdb42400a0dbd0305b4007e3a0a34cd857 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Mon, 23 Oct 2023 21:32:40 +0200 Subject: [PATCH 06/29] fix input names --- optimum/intel/openvino/modeling_base.py | 8 +++++++- optimum/intel/utils/modeling_utils.py | 5 +++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/optimum/intel/openvino/modeling_base.py b/optimum/intel/openvino/modeling_base.py index 321f50f570..9384477eb9 100644 --- a/optimum/intel/openvino/modeling_base.py +++ b/optimum/intel/openvino/modeling_base.py @@ -79,7 +79,13 @@ def __init__( height = -1 if self.export_feature == "image-classification" else None width = -1 if self.export_feature == "image-classification" else None model = self._reshape(model, -1, -1, height, width) - self.input_names = {key.get_any_name(): idx for idx, key in enumerate(model.inputs)} + + input_names = {} + for idx, key in enumerate(model.inputs): + names = tuple(key.get_names()) + input_names[next((name for name in names if "/" not in name), names[0])] = idx + self.input_names = input_names + self.model = model self.request = None if enable_compilation: diff --git a/optimum/intel/utils/modeling_utils.py b/optimum/intel/utils/modeling_utils.py index 6b6cd30999..b56e5e4f2d 100644 --- a/optimum/intel/utils/modeling_utils.py +++ b/optimum/intel/utils/modeling_utils.py @@ -18,6 +18,9 @@ from transformers.modeling_utils import PreTrainedModel +# from ...utils.modeling_utils import _prepare_decoder_sliding_window_attention_mask + + MULTI_QUERY_ATTN_MODELS = {"falcon", "gpt_bigcode"} @@ -109,6 +112,8 @@ def patch_decoder_attention_mask(model: "PreTrainedModel"): 
model.transformer._prepare_attn_mask = _prepare_attn_mask elif model.config.model_type == "llama": model.model._prepare_decoder_attention_mask = _prepare_decoder_attention_mask + # elif model.config.model_type == "mistral": + # model.model._prepare_decoder_attention_mask = _prepare_decoder_sliding_window_attention_mask elif model.config.model_type in {"blenderbot-small", "blenderbot", "opt", "pegasus", "bart"}: model.model.decoder._prepare_decoder_attention_mask = _prepare_decoder_attention_mask return model From 764b23b1085091c30e02d503008073da103ef5bc Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Mon, 23 Oct 2023 22:55:09 +0200 Subject: [PATCH 07/29] Fix export optimum modifications --- .github/workflows/test_openvino.yml | 1 + optimum/exporters/openvino/__main__.py | 29 +++--- optimum/intel/openvino/modeling_decoder.py | 5 + optimum/intel/openvino/quantization.py | 8 +- optimum/intel/utils/modeling_utils.py | 102 --------------------- 5 files changed, 21 insertions(+), 124 deletions(-) diff --git a/.github/workflows/test_openvino.yml b/.github/workflows/test_openvino.yml index cb58f412a6..f0b40fa5d1 100644 --- a/.github/workflows/test_openvino.yml +++ b/.github/workflows/test_openvino.yml @@ -32,6 +32,7 @@ jobs: python -m pip install --upgrade pip # install PyTorch CPU version to avoid installing CUDA packages on GitHub runner without GPU pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu + pip install git+https://github.com/huggingface/optimum.git pip install .[openvino,nncf,tests,diffusers] - name: Test with Pytest run: | diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py index 782aa0bc0d..d25c2c5f3a 100644 --- a/optimum/exporters/openvino/__main__.py +++ b/optimum/exporters/openvino/__main__.py @@ -27,7 +27,6 @@ from optimum.utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors from ...intel.utils.import_utils import is_nncf_available -from ...intel.utils.modeling_utils import patch_decoder_attention_mask from .convert import export_models @@ -222,24 +221,18 @@ def main_export( preprocessors = maybe_load_preprocessors( model_name_or_path, subfolder=subfolder, trust_remote_code=trust_remote_code ) - if not task.startswith("text-generation"): - onnx_config, models_and_onnx_configs = optimum_main._get_submodels_and_onnx_configs( - model=model, - task=task, - monolith=False, - custom_onnx_configs=custom_onnx_configs if custom_onnx_configs is not None else {}, - custom_architecture=custom_architecture, - fn_get_submodels=fn_get_submodels, - preprocessors=preprocessors, - _variant="default", - ) - else: - # TODO : ModelPatcher will be added in next optimum release - model = patch_decoder_attention_mask(model) - onnx_config_constructor = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task) - onnx_config = onnx_config_constructor(model.config) - models_and_onnx_configs = {"model": (model, onnx_config)} + onnx_config, models_and_onnx_configs = optimum_main._get_submodels_and_onnx_configs( + model=model, + task=task, + monolith=False, + custom_onnx_configs=custom_onnx_configs if custom_onnx_configs is not None else {}, + custom_architecture=custom_architecture, + fn_get_submodels=fn_get_submodels, + preprocessors=preprocessors, + _variant="default", + legacy=False, + ) if int8 is None: int8 = False diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py index 0e018f9f62..02fa96bd2c 100644 --- 
a/optimum/intel/openvino/modeling_decoder.py
+++ b/optimum/intel/openvino/modeling_decoder.py
@@ -346,6 +346,7 @@ def forward(
         input_ids: torch.LongTensor,
         attention_mask: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        position_ids: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> CausalLMOutputWithPast:
         self.compile()
@@ -403,6 +404,10 @@ def forward(
                 (input_ids.shape[0], input_ids.shape[1] + past_len), dtype=inputs["input_ids"].dtype
             )
+        # Add the position_ids inputs when needed
+        if "position_ids" in self.input_names and position_ids is not None:
+            inputs["position_ids"] = np.array(position_ids)
+
         # Run inference
         self.request.start_async(inputs, shared_memory=True)
         self.request.wait()
diff --git a/optimum/intel/openvino/quantization.py b/optimum/intel/openvino/quantization.py
index bcc7c2908b..b94f61214d 100644
--- a/optimum/intel/openvino/quantization.py
+++ b/optimum/intel/openvino/quantization.py
@@ -39,7 +39,6 @@
 from ...exporters.openvino import export, export_pytorch_via_onnx
 from ..utils.constant import _TASK_ALIASES
-from ..utils.modeling_utils import patch_decoder_attention_mask
 from .configuration import OVConfig
 from .modeling_base import OVBaseModel
 from .modeling_decoder import OVBaseDecoderModel
@@ -394,9 +393,10 @@ def _quantize_torchmodel(
             task = self.task
             model = self.model
             self.model.config.save_pretrained(save_directory)
-            model = patch_decoder_attention_mask(model)
-            if task == "text-generation":
-                onnx_config = onnx_config_class(model.config, use_past=model.config.use_cache)
+            if task.startswith("text-generation"):
+                onnx_config = onnx_config_class(
+                    model.config, use_past=model.config.use_cache, use_past_in_inputs=model.config.use_cache
+                )
             else:
                 onnx_config = onnx_config_class(model.config)
diff --git a/optimum/intel/utils/modeling_utils.py b/optimum/intel/utils/modeling_utils.py
index b56e5e4f2d..fa8a5d5d4d 100644
--- a/optimum/intel/utils/modeling_utils.py
+++ b/optimum/intel/utils/modeling_utils.py
@@ -12,108 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Tuple
-
-import torch
-from transformers.modeling_utils import PreTrainedModel
-
-
-# from ...utils.modeling_utils import _prepare_decoder_sliding_window_attention_mask
 MULTI_QUERY_ATTN_MODELS = {"falcon", "gpt_bigcode"}
-
-
-# Modified from transformers.models.bloom.modeling_bloom._make_causal_mask
-def _make_causal_mask(
-    input_ids_shape: torch.Size,
-    device: torch.device,
-    past_key_values_length: int,
-    dtype: torch.dtype = torch.bool,
-) -> torch.BoolTensor:
-    """
-    Make causal mask used for bi-directional self-attention.
- """ - batch_size, target_length = input_ids_shape - mask = torch.zeros((target_length, target_length + past_key_values_length), dtype=dtype, device=device) - seq_ids = torch.arange(target_length, device=device) - - mask[:, past_key_values_length:] = ( - (seq_ids[:, None] < seq_ids[None, :]) * torch.finfo(dtype).min - if torch.is_floating_point(mask) - else seq_ids[:, None] < seq_ids[None, :] - ) - - return mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length) - - -# Modified from transformers.models..bloom.modeling_bloom._prepare_attn_mask -def _prepare_attn_mask( - attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int -) -> torch.BoolTensor: - from transformers.models.bloom.modeling_bloom import _expand_mask - - # create causal mask - # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length] - combined_attention_mask = None - device = attention_mask.device - _, src_length = input_shape - - combined_attention_mask = _make_causal_mask( - input_shape, device=device, past_key_values_length=past_key_values_length - ) - # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]_prepare_decoder_attention_mask - expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask - ) - - return combined_attention_mask - - -# Modified from transformers.models.llama.modeling_llama._prepare_decoder_attention_mask -def _prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds, past_key_values_length): - from transformers.models.llama.modeling_llama import _expand_mask - - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - - combined_attention_mask = _make_causal_mask( - input_shape, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - dtype=inputs_embeds.dtype, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - - -def patch_decoder_attention_mask(model: "PreTrainedModel"): - """ - Apply patch on decoder with past model forward to resolve first inference based on model architecture - - Args: - model (PretrainedModel): The model to patch. 
- - Returns: - model with applied patch - """ - if model.config.model_type in {"bloom", "mpt"}: - model.transformer._prepare_attn_mask = _prepare_attn_mask - elif model.config.model_type == "llama": - model.model._prepare_decoder_attention_mask = _prepare_decoder_attention_mask - # elif model.config.model_type == "mistral": - # model.model._prepare_decoder_attention_mask = _prepare_decoder_sliding_window_attention_mask - elif model.config.model_type in {"blenderbot-small", "blenderbot", "opt", "pegasus", "bart"}: - model.model.decoder._prepare_decoder_attention_mask = _prepare_decoder_attention_mask - return model From b111ca9a766b483fda00287500682e44dd34422e Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Tue, 31 Oct 2023 11:57:48 +0100 Subject: [PATCH 08/29] add test --- tests/generation/test_modeling.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/tests/generation/test_modeling.py b/tests/generation/test_modeling.py index 0fd668ad8f..2b7c16140a 100644 --- a/tests/generation/test_modeling.py +++ b/tests/generation/test_modeling.py @@ -20,6 +20,7 @@ from parameterized import parameterized from transformers import AutoModelForCausalLM, AutoTokenizer, PretrainedConfig, pipeline, set_seed +from optimum.exporters.onnx import MODEL_TYPES_REQUIRING_POSITION_IDS from optimum.intel.generation.modeling import TSModelForCausalLM @@ -28,6 +29,9 @@ "gptj": "hf-internal-testing/tiny-random-gptj", "gpt2": "hf-internal-testing/tiny-random-gpt2", "gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel", + "mistral": "echarlaix/tiny-random-mistral", + "llama": "fxmarty/tiny-llama-fast-tokenizer", + "gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel", } SEED = 42 @@ -48,7 +52,11 @@ class ModelingIntegrationTest(unittest.TestCase): "gpt2", "gptj", "gpt_neo", + "mistral", + "llama", + "gpt_bigcode", ) + GENERATION_LENGTH = 100 SPEEDUP_CACHE = 1.1 @@ -61,7 +69,12 @@ def test_compare_to_transformers(self, model_arch): trfs_model = AutoModelForCausalLM.from_pretrained(model_id) tokenizer = AutoTokenizer.from_pretrained(model_id) tokens = tokenizer("This is a sample", return_tensors="pt") - outputs = model(**tokens) + + position_ids = None + if model_arch.replace("_", "-") in MODEL_TYPES_REQUIRING_POSITION_IDS: + input_shape = tokens["input_ids"].shape + position_ids = torch.arange(0, input_shape[-1], dtype=torch.long).unsqueeze(0).view(-1, input_shape[-1]) + outputs = model(**tokens, position_ids=position_ids) self.assertIsInstance(outputs.logits, torch.Tensor) with torch.no_grad(): trfs_outputs = trfs_model(**tokens) @@ -71,7 +84,8 @@ def test_compare_to_transformers(self, model_arch): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) loaded_model = TSModelForCausalLM.from_pretrained(tmpdirname) - loaded_model_outputs = loaded_model(**tokens) + loaded_model_outputs = loaded_model(**tokens, position_ids=position_ids) + self.assertTrue(torch.equal(outputs.logits, loaded_model_outputs.logits)) @parameterized.expand(SUPPORTED_ARCHITECTURES) @@ -120,7 +134,6 @@ def test_compare_with_and_without_past_key_values(self): model_id = MODEL_NAMES["gpt2"] tokenizer = AutoTokenizer.from_pretrained(model_id) tokens = tokenizer("This is a sample input", return_tensors="pt") - model_with_pkv = TSModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True) # Warmup _ = model_with_pkv.generate(**tokens) @@ -136,6 +149,9 @@ def test_compare_with_and_without_past_key_values(self): outputs_model_without_pkv = 
model_without_pkv.generate( **tokens, min_length=self.GENERATION_LENGTH, max_length=self.GENERATION_LENGTH, num_beams=1 ) + self.assertTrue(model_with_pkv.use_cache) + self.assertFalse(model_without_pkv.use_cache) + self.assertTrue(torch.equal(outputs_model_with_pkv, outputs_model_without_pkv)) self.assertEqual(outputs_model_with_pkv.shape[1], self.GENERATION_LENGTH) self.assertEqual(outputs_model_without_pkv.shape[1], self.GENERATION_LENGTH) From d60fe98b6822e25ba337160a53b8173ec7c98061 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Tue, 31 Oct 2023 15:36:37 +0100 Subject: [PATCH 09/29] Fix compatibility --- .github/workflows/test_openvino.yml | 1 - optimum/intel/generation/modeling.py | 137 +++++++++--------- .../intel/neural_compressor/quantization.py | 55 +++++++ optimum/intel/openvino/modeling_base.py | 9 +- optimum/intel/utils/modeling_utils.py | 134 +++++++++++++++++ setup.py | 2 +- 6 files changed, 258 insertions(+), 80 deletions(-) diff --git a/.github/workflows/test_openvino.yml b/.github/workflows/test_openvino.yml index f0b40fa5d1..cb58f412a6 100644 --- a/.github/workflows/test_openvino.yml +++ b/.github/workflows/test_openvino.yml @@ -32,7 +32,6 @@ jobs: python -m pip install --upgrade pip # install PyTorch CPU version to avoid installing CUDA packages on GitHub runner without GPU pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu - pip install git+https://github.com/huggingface/optimum.git pip install .[openvino,nncf,tests,diffusers] - name: Test with Pytest run: | diff --git a/optimum/intel/generation/modeling.py b/optimum/intel/generation/modeling.py index b4c41e0be1..65f83bebf9 100644 --- a/optimum/intel/generation/modeling.py +++ b/optimum/intel/generation/modeling.py @@ -31,7 +31,11 @@ from ..utils.constant import _TASK_ALIASES from ..utils.import_utils import is_torch_version, is_transformers_version -from ..utils.modeling_utils import _prepare_attn_mask, _prepare_decoder_attention_mask +from ..utils.modeling_utils import patch_decoder_attention_mask + +from ..utils.modeling_utils import patch_decoder_attention_mask, MULTI_QUERY_ATTN_MODELS + +from optimum.exporters.onnx import MODEL_TYPES_REQUIRING_POSITION_IDS if is_transformers_version("<", "4.25.0"): @@ -47,43 +51,29 @@ def prepare_jit_inputs(model: PreTrainedModel, task: str, use_cache: bool = Fals task = _TASK_ALIASES.get(task, task) signature = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.__call__) onnx_config_class = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task) - onnx_config = onnx_config_class(model.config) - if task == "text-generation" and use_cache: - onnx_config = onnx_config_class(model.config, use_past=True, use_past_in_inputs=True) + if "text-generation" in task: + onnx_config = onnx_config_class(model.config, use_past=use_cache, use_past_in_inputs=use_cache) + else: + onnx_config = onnx_config_class(model.config) + dummy_inputs = onnx_config.generate_dummy_inputs(framework="pt") - model_inputs = {key: dummy_inputs[key] for key in signature.parameters if dummy_inputs.get(key, None) is not None} - if task == "text-generation" and use_cache and model.config.model_type != "gpt_bigcode": - # WA jit.trace issue of model like llama in https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L464, or else, generation output will be incorrect - pkv = [] - for i in range(len(model_inputs["past_key_values"])): - pkv.append([]) - for j in 
range(len(model_inputs["past_key_values"][0])): - pkv[i].append(model_inputs["past_key_values"][i][j].to(model.dtype)) - pkv[i] = tuple(pkv[i]) - model_inputs["past_key_values"] = tuple(pkv) - i = model_inputs["input_ids"] - a = model_inputs["attention_mask"] - model_inputs["input_ids"] = torch.cat([torch.zeros(i.shape[0], 1), i], -1).to(i.dtype) - model_inputs["attention_mask"] = torch.cat([torch.zeros(a.shape[0], 1), a], -1).to(a.dtype) - return model_inputs + + return {key: dummy_inputs[key] for key in signature.parameters if dummy_inputs.get(key, None) is not None} def jit_trace(model: PreTrainedModel, task: str, use_cache: bool = False): model_inputs = prepare_jit_inputs(model, task, use_cache) # check if the model_inputs is correct. model(**model_inputs) + torch._C._jit_set_texpr_fuser_enabled(False) if "past_key_values" in model_inputs.keys(): model.config.return_dict = False - if is_torch_version(">", "2.0.1"): - traced_model = torch.jit.trace(model, example_kwarg_inputs=model_inputs, strict=False) - else: - traced_model = torch.jit.trace(model, example_inputs=tuple(model_inputs.values()), strict=False) + if is_torch_version(">=", "2.1.0"): + traced_model = torch.jit.trace(model, example_kwarg_inputs=model_inputs, strict=False) else: - if is_torch_version(">=", "2.0.0"): - traced_model = torch.jit.trace(model, example_kwarg_inputs=model_inputs, strict=False) - else: - traced_model = torch.jit.trace(model, example_inputs=tuple(model_inputs.values()), strict=False) + traced_model = torch.jit.trace(model, example_inputs=tuple(model_inputs.values()), strict=False) + traced_model = torch.jit.freeze(traced_model.eval()) traced_model(**model_inputs) traced_model(**model_inputs) @@ -91,11 +81,7 @@ def jit_trace(model: PreTrainedModel, task: str, use_cache: bool = False): return traced_model -class PreTrainedModel(OptimizedModel): - pass - - -class BaseModelForCausalLM(PreTrainedModel, GenerationMixin): +class BaseModelForCausalLM(OptimizedModel, GenerationMixin): auto_model_class = AutoModelForCausalLM export_feature = "text-generation" main_input_name = "input_ids" @@ -156,12 +142,28 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg if past_key_values[0][0].shape[0] == input_ids.shape[0]: past_key_values = self._convert_to_bloom_cache(past_key_values) + + position_ids = kwargs.get("position_ids", None) + + + attention_mask = kwargs.get("attention_mask", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -1].unsqueeze(-1) + + + + return { "input_ids": input_ids, "past_key_values": past_key_values, "use_cache": self.use_cache, - "position_ids": None, - "attention_mask": kwargs.get("attention_mask", None), + "position_ids": position_ids, + "attention_mask": attention_mask, "token_type_ids": None, } @@ -258,6 +260,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.FloatTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + position_ids: Optional[torch.FloatTensor] = None, **kwargs, ) -> CausalLMOutputWithPast: if attention_mask is None: @@ -268,43 +271,40 @@ def forward( "attention_mask": attention_mask, } + model_type = self.config.model_type.replace("_", "-") + if self.use_cache: if past_key_values is None: nb_pkv = 2 num_layers = self.normalized_config.num_layers - 
num_attention_heads = self.normalized_config.num_attention_heads - num_key_value_heads = num_attention_heads - if hasattr(self.normalized_config, "num_key_value_heads"): - num_key_value_heads = self.normalized_config.num_key_value_heads - hidden_size = self.normalized_config.hidden_size - d_k = hidden_size // num_attention_heads - if self.config.model_type == "gpt_bigcode": - new_shape = [input_ids.shape[0], 0, d_k * 2] - empty_tensor = torch.empty(size=new_shape) - if self.model_dtype is not None: - empty_tensor = empty_tensor.to(self.model_dtype) - past_key_values = tuple([empty_tensor] * num_layers) - elif self.config.model_type != "bloom": - new_shape = [input_ids.shape[0], num_key_value_heads, 0, d_k] - empty_tensor = torch.empty(size=new_shape) - if self.model_dtype is not None: - empty_tensor = empty_tensor.to(self.model_dtype) - pkv = tuple(empty_tensor for _ in range(nb_pkv)) + d_k = self.normalized_config.hidden_size // self.normalized_config.num_attention_heads + batch_size = input_ids.shape[0] + + if model_type in {"mistral", "llama"}: + num_attention_heads = self.normalized_config.num_key_value_heads + else: + num_attention_heads = self.normalized_config.num_attention_heads + + if model_type == "bloom": + shape_key = (batch_size * num_attention_heads, d_k, 0) + shape_value = (batch_size * num_attention_heads, 0, d_k) + key = torch.empty(size=shape_key, dtype=self.model_dtype, device=self._device) + value = torch.empty(size=shape_value, dtype=self.model_dtype, device=self._device) + past_key_values = tuple(tuple(key if idx % 2 == 0 else value for idx in range(nb_pkv)) for _ in range(num_layers)) + elif model_type.replace("-", "_") in MULTI_QUERY_ATTN_MODELS: + shape = (batch_size, 0, d_k * 2) + pkv = torch.empty(size=shape, dtype=self.model_dtype, device=self._device) + past_key_values = tuple(pkv for _ in range(num_layers)) else: - pkv = () - for nb_pkv in range(nb_pkv): - if nb_pkv % 2 == 0: - new_shape = [input_ids.shape[0] * num_key_value_heads, d_k, 0] - else: - new_shape = [input_ids.shape[0] * num_key_value_heads, 0, d_k] - empty_tensor = torch.empty(size=new_shape) - if self.model_dtype is not None: - empty_tensor = empty_tensor.to(self.model_dtype) - pkv = pkv + (empty_tensor,) - if past_key_values is None: - past_key_values = tuple(tuple(pkv) for _ in range(num_layers)) + shape = (batch_size, num_attention_heads, 0, d_k) + pkv = torch.empty(size=shape, dtype=self.model_dtype, device=self._device) + past_key_values = tuple(tuple(pkv for _ in range(nb_pkv)) for _ in range(num_layers)) inputs["past_key_values"] = past_key_values + + if position_ids is not None and model_type in MODEL_TYPES_REQUIRING_POSITION_IDS: + inputs["position_ids"] = position_ids + outputs = self.model(**inputs) if isinstance(outputs, (list, tuple)): @@ -389,7 +389,7 @@ def _from_transformers( torch_dtype: Optional[Union[str, "torch.dtype"]] = None, **kwargs, ): - if is_torch_version("<", "2.0.0"): + if is_torch_version("<", "2.1.0"): raise ImportError("`torch>=2.0.0` is needed to trace your model") task = cls.export_feature @@ -405,12 +405,7 @@ def _from_transformers( } model = TasksManager.get_model_from_task(task, model_id, **model_kwargs) - - if model.config.model_type == "bloom": - model.transformer._prepare_attn_mask = _prepare_attn_mask - - if model.config.model_type == "llama": - model.model._prepare_decoder_attention_mask = _prepare_decoder_attention_mask + model = patch_decoder_attention_mask(model) traced_model = jit_trace(model, task, use_cache) save_dir = TemporaryDirectory() diff 
--git a/optimum/intel/neural_compressor/quantization.py b/optimum/intel/neural_compressor/quantization.py index 36f16524c2..599f47e511 100644 --- a/optimum/intel/neural_compressor/quantization.py +++ b/optimum/intel/neural_compressor/quantization.py @@ -30,9 +30,18 @@ from neural_compressor.quantization import fit from torch.utils.data import DataLoader, RandomSampler from transformers import ( + AutoModelForCausalLM, + AutoModelForMaskedLM, + AutoModelForMultipleChoice, + AutoModelForQuestionAnswering, + AutoModelForSeq2SeqLM, + AutoModelForSequenceClassification, + AutoModelForTokenClassification, + AutoModelForVision2Seq, DataCollator, PretrainedConfig, PreTrainedModel, + XLNetLMHeadModel, default_data_collator, ) @@ -528,3 +537,49 @@ def _apply_quantization_from_config(q_config: Dict, model: torch.nn.Module) -> t q_model = convert(q_model, mapping=q_mapping, inplace=True) return q_model + + +class IncQuantizedModel(INCModel): + @classmethod + def from_pretrained(cls, *args, **kwargs): + warnings.warn( + f"The class `{cls.__name__}` has been depreciated and will be removed in optimum-intel v1.12, please use " + f"`{cls.__name__.replace('IncQuantized', 'INC')}` instead." + ) + return super().from_pretrained(*args, **kwargs) + + +class IncQuantizedModelForQuestionAnswering(IncQuantizedModel): + auto_model_class = AutoModelForQuestionAnswering + + +class IncQuantizedModelForSequenceClassification(IncQuantizedModel): + auto_model_class = AutoModelForSequenceClassification + + +class IncQuantizedModelForTokenClassification(IncQuantizedModel): + auto_model_class = AutoModelForTokenClassification + + +class IncQuantizedModelForMultipleChoice(IncQuantizedModel): + auto_model_class = AutoModelForMultipleChoice + + +class IncQuantizedModelForSeq2SeqLM(IncQuantizedModel): + auto_model_class = AutoModelForSeq2SeqLM + + +class IncQuantizedModelForCausalLM(IncQuantizedModel): + auto_model_class = AutoModelForCausalLM + + +class IncQuantizedModelForMaskedLM(IncQuantizedModel): + auto_model_class = AutoModelForMaskedLM + + +class IncQuantizedModelForXLNetLM(IncQuantizedModel): + auto_model_class = XLNetLMHeadModel + + +class IncQuantizedModelForVision2Seq(IncQuantizedModel): + auto_model_class = AutoModelForVision2Seq diff --git a/optimum/intel/openvino/modeling_base.py b/optimum/intel/openvino/modeling_base.py index 9384477eb9..694c4e58c0 100644 --- a/optimum/intel/openvino/modeling_base.py +++ b/optimum/intel/openvino/modeling_base.py @@ -43,17 +43,12 @@ logger = logging.getLogger(__name__) -# workaround to enable compatibility between openvino models and transformers pipelines -class PreTrainedModel(OptimizedModel): - pass - - @add_start_docstrings( """ Base OVModel class. """, ) -class OVBaseModel(PreTrainedModel): +class OVBaseModel(OptimizedModel): auto_model_class = None export_feature = None @@ -302,7 +297,7 @@ def _from_transformers( @classmethod def _to_load( cls, - model: PreTrainedModel, + model: "PreTrainedModel", config: PretrainedConfig, onnx_config: OnnxConfig, use_auth_token: Optional[Union[bool, str]] = None, diff --git a/optimum/intel/utils/modeling_utils.py b/optimum/intel/utils/modeling_utils.py index 4061e1d9d9..1a3b6fbede 100644 --- a/optimum/intel/utils/modeling_utils.py +++ b/optimum/intel/utils/modeling_utils.py @@ -12,5 +12,139 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import Tuple + +import torch +from transformers.modeling_utils import PreTrainedModel + MULTI_QUERY_ATTN_MODELS = {"falcon", "gpt_bigcode"} + + +# Modified from transformers.models.bloom.modeling_bloom._make_causal_mask +def _make_causal_mask( + input_ids_shape: torch.Size, + device: torch.device, + past_key_values_length: int, + dtype: torch.dtype = torch.bool, +) -> torch.BoolTensor: + """ + Make causal mask used for bi-directional self-attention. + """ + batch_size, target_length = input_ids_shape + mask = torch.zeros((target_length, target_length + past_key_values_length), dtype=dtype, device=device) + seq_ids = torch.arange(target_length, device=device) + + mask[:, past_key_values_length:] = ( + (seq_ids[:, None] < seq_ids[None, :]) * torch.finfo(dtype).min + if torch.is_floating_point(mask) + else seq_ids[:, None] < seq_ids[None, :] + ) + + return mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length) + + +# Modified from transformers.models..bloom.modeling_bloom._prepare_attn_mask +def _prepare_attn_mask( + attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int +) -> torch.BoolTensor: + from transformers.models.bloom.modeling_bloom import _expand_mask + + # create causal mask + # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length] + combined_attention_mask = None + device = attention_mask.device + _, src_length = input_shape + + combined_attention_mask = _make_causal_mask( + input_shape, device=device, past_key_values_length=past_key_values_length + ) + # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]_prepare_decoder_attention_mask + expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask + ) + + return combined_attention_mask + + +# Modified from transformers.models.llama.modeling_llama._prepare_decoder_attention_mask +def _prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds, past_key_values_length): + from transformers.models.llama.modeling_llama import _expand_mask + + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + + combined_attention_mask = _make_causal_mask( + input_shape, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + dtype=inputs_embeds.dtype, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( + inputs_embeds.device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + +# Modified from transformers.models.mistral.modeling_mistral._prepare_decoder_sliding_window_attention_mask +def _prepare_decoder_sliding_window_attention_mask( + attention_mask: torch.Tensor, + input_shape: Tuple[int, int], + inputs_embeds: torch.Tensor, + past_key_values_length: int, + sliding_window: int, +): + from transformers.models.mistral.modeling_mistral import _expand_mask, _make_sliding_window_causal_mask + + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + + combined_attention_mask = _make_sliding_window_causal_mask( + input_shape, + device=inputs_embeds.device, + 
dtype=inputs_embeds.dtype, + past_key_values_length=past_key_values_length, + sliding_window=sliding_window, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( + inputs_embeds.device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + +def patch_decoder_attention_mask(model: "PreTrainedModel"): + """ + Apply patch on decoder with past model forward to resolve first inference based on model architecture + + Args: + model (PretrainedModel): The model to patch. + + Returns: + model with applied patch + """ + if model.config.model_type in {"bloom", "mpt"}: + model.transformer._prepare_attn_mask = _prepare_attn_mask + elif model.config.model_type == "llama": + model.model._prepare_decoder_attention_mask = _prepare_decoder_attention_mask + elif model.config.model_type == "mistral": + model.model._prepare_decoder_attention_mask = _prepare_decoder_sliding_window_attention_mask + elif model.config.model_type in {"blenderbot-small", "blenderbot", "opt", "pegasus", "bart"}: + model.model.decoder._prepare_decoder_attention_mask = _prepare_decoder_attention_mask + return model diff --git a/setup.py b/setup.py index 6d81b98b2a..e2da72bf3e 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,7 @@ assert False, "Error: Could not open '%s' due %s\n" % (filepath, error) INSTALL_REQUIRE = [ - "optimum>=1.13.0", + "optimum @ git+https://github.com/huggingface/optimum.git", "transformers>=4.20.0", "datasets>=1.4.0", "sentencepiece", From e1ca1d6ba8318934c93ad14e3603eeba4dbaa34b Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Tue, 31 Oct 2023 16:51:00 +0100 Subject: [PATCH 10/29] style --- optimum/intel/generation/modeling.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/optimum/intel/generation/modeling.py b/optimum/intel/generation/modeling.py index 65f83bebf9..bbfc3db63d 100644 --- a/optimum/intel/generation/modeling.py +++ b/optimum/intel/generation/modeling.py @@ -26,16 +26,13 @@ from transformers.utils import WEIGHTS_NAME from optimum.exporters import TasksManager +from optimum.exporters.onnx import MODEL_TYPES_REQUIRING_POSITION_IDS from optimum.modeling_base import OptimizedModel from optimum.utils import NormalizedConfigManager from ..utils.constant import _TASK_ALIASES from ..utils.import_utils import is_torch_version, is_transformers_version -from ..utils.modeling_utils import patch_decoder_attention_mask - -from ..utils.modeling_utils import patch_decoder_attention_mask, MULTI_QUERY_ATTN_MODELS - -from optimum.exporters.onnx import MODEL_TYPES_REQUIRING_POSITION_IDS +from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, patch_decoder_attention_mask if is_transformers_version("<", "4.25.0"): @@ -142,10 +139,8 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg if past_key_values[0][0].shape[0] == input_ids.shape[0]: past_key_values = self._convert_to_bloom_cache(past_key_values) - position_ids = kwargs.get("position_ids", None) - attention_mask = kwargs.get("attention_mask", None) if attention_mask is not None and position_ids is None: @@ -155,9 +150,6 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg if past_key_values: position_ids = position_ids[:, -1].unsqueeze(-1) - - - return { "input_ids": input_ids, 
"past_key_values": past_key_values, @@ -290,7 +282,9 @@ def forward( shape_value = (batch_size * num_attention_heads, 0, d_k) key = torch.empty(size=shape_key, dtype=self.model_dtype, device=self._device) value = torch.empty(size=shape_value, dtype=self.model_dtype, device=self._device) - past_key_values = tuple(tuple(key if idx % 2 == 0 else value for idx in range(nb_pkv)) for _ in range(num_layers)) + past_key_values = tuple( + tuple(key if idx % 2 == 0 else value for idx in range(nb_pkv)) for _ in range(num_layers) + ) elif model_type.replace("-", "_") in MULTI_QUERY_ATTN_MODELS: shape = (batch_size, 0, d_k * 2) pkv = torch.empty(size=shape, dtype=self.model_dtype, device=self._device) From 187360105d12e0c835b43dbaa9a9e5065465bcaf Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Tue, 31 Oct 2023 16:54:04 +0100 Subject: [PATCH 11/29] fix compatibility --- optimum/intel/neural_compressor/quantization.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/optimum/intel/neural_compressor/quantization.py b/optimum/intel/neural_compressor/quantization.py index 599f47e511..137b16b22d 100644 --- a/optimum/intel/neural_compressor/quantization.py +++ b/optimum/intel/neural_compressor/quantization.py @@ -48,7 +48,7 @@ from optimum.exporters import TasksManager from optimum.exporters.onnx import OnnxConfig from optimum.onnxruntime import ORTModel -from optimum.onnxruntime.modeling_decoder import ORTModelDecoder +from optimum.onnxruntime.modeling_decoder import ORTModelForCausalLM from optimum.onnxruntime.modeling_seq2seq import ORTModelForConditionalGeneration from optimum.onnxruntime.utils import ONNX_DECODER_NAME from optimum.quantization_base import OptimumQuantizer @@ -265,7 +265,7 @@ def quantize( if isinstance(self._original_model, ORTModelForConditionalGeneration): raise RuntimeError("ORTModelForConditionalGeneration not supported for quantization") - if isinstance(self._original_model, ORTModelDecoder): + if isinstance(self._original_model, ORTModelForCausalLM): model_or_path = self._original_model.onnx_paths if len(model_or_path) > 1: raise RuntimeError( From 7a90d645ce09366412929547b431d0b6d2502ea2 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Tue, 31 Oct 2023 16:55:14 +0100 Subject: [PATCH 12/29] remove bigcode --- tests/generation/test_modeling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/generation/test_modeling.py b/tests/generation/test_modeling.py index 2b7c16140a..db36b924f4 100644 --- a/tests/generation/test_modeling.py +++ b/tests/generation/test_modeling.py @@ -54,7 +54,7 @@ class ModelingIntegrationTest(unittest.TestCase): "gpt_neo", "mistral", "llama", - "gpt_bigcode", + # "gpt_bigcode", ) GENERATION_LENGTH = 100 From 0e1788337c6e3c81f32e92cd21a5db5bf766f17d Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Tue, 31 Oct 2023 17:05:03 +0100 Subject: [PATCH 13/29] fix --- optimum/intel/neural_compressor/quantization.py | 1 + optimum/intel/openvino/modeling_base.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/optimum/intel/neural_compressor/quantization.py b/optimum/intel/neural_compressor/quantization.py index 137b16b22d..bb232cdade 100644 --- a/optimum/intel/neural_compressor/quantization.py +++ b/optimum/intel/neural_compressor/quantization.py @@ -20,6 +20,7 @@ from pathlib import Path from typing import Callable, Dict, Optional, Union +import warnings import torch from datasets import Dataset, load_dataset from neural_compressor.adaptor.pytorch import PyTorch_FXAdaptor, _cfg_to_qconfig, 
_propagate_qconfig diff --git a/optimum/intel/openvino/modeling_base.py b/optimum/intel/openvino/modeling_base.py index 694c4e58c0..e8e62e0818 100644 --- a/optimum/intel/openvino/modeling_base.py +++ b/optimum/intel/openvino/modeling_base.py @@ -297,7 +297,7 @@ def _from_transformers( @classmethod def _to_load( cls, - model: "PreTrainedModel", + model, config: PretrainedConfig, onnx_config: OnnxConfig, use_auth_token: Optional[Union[bool, str]] = None, From 79b75b94896609f56b088146c8711a3023ebbf1c Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Tue, 31 Oct 2023 17:05:43 +0100 Subject: [PATCH 14/29] style --- optimum/intel/neural_compressor/quantization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/intel/neural_compressor/quantization.py b/optimum/intel/neural_compressor/quantization.py index bb232cdade..d4846adc15 100644 --- a/optimum/intel/neural_compressor/quantization.py +++ b/optimum/intel/neural_compressor/quantization.py @@ -15,12 +15,12 @@ import copy import inspect import logging +import warnings from enum import Enum from itertools import chain from pathlib import Path from typing import Callable, Dict, Optional, Union -import warnings import torch from datasets import Dataset, load_dataset from neural_compressor.adaptor.pytorch import PyTorch_FXAdaptor, _cfg_to_qconfig, _propagate_qconfig From 8983274f5c5c46f15570171c62791558706436ab Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Thu, 2 Nov 2023 12:04:23 +0100 Subject: [PATCH 15/29] fix test --- tests/openvino/test_modeling.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py index f3978b2965..8efee7bac8 100644 --- a/tests/openvino/test_modeling.py +++ b/tests/openvino/test_modeling.py @@ -77,6 +77,7 @@ DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER, ) from optimum.utils.testing_utils import require_diffusers +from optimum.exporters.onnx import MODEL_TYPES_REQUIRING_POSITION_IDS TENSOR_ALIAS_TO_TYPE = { @@ -496,7 +497,12 @@ def test_compare_to_transformers(self, model_arch): tokens = tokenizer( "This is a sample", return_tensors="pt", return_token_type_ids=False if model_arch == "llama" else None ) - ov_outputs = ov_model(**tokens) + position_ids = None + if model_arch.replace("_", "-") in MODEL_TYPES_REQUIRING_POSITION_IDS: + input_shape = tokens["input_ids"].shape + position_ids = torch.arange(0, input_shape[-1], dtype=torch.long).unsqueeze(0).view(-1, input_shape[-1]) + ov_outputs = ov_model(**tokens, position_ids=position_ids) + self.assertTrue("logits" in ov_outputs) self.assertIsInstance(ov_outputs.logits, torch.Tensor) with torch.no_grad(): From d6cdc10b78c92ce8ba94389979c2dd07ec79c841 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Thu, 2 Nov 2023 14:50:51 +0100 Subject: [PATCH 16/29] fixes --- optimum/intel/openvino/modeling_decoder.py | 22 ++++++++++++++++------ tests/openvino/test_modeling.py | 2 +- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py index 02fa96bd2c..5bd9437391 100644 --- a/optimum/intel/openvino/modeling_decoder.py +++ b/optimum/intel/openvino/modeling_decoder.py @@ -396,17 +396,27 @@ def forward( inputs["input_ids"] = np.array(input_ids) # Add the attention_mask inputs when needed - if "attention_mask" in self.input_names: + if "attention_mask" in self.input_names or "position_ids" in self.input_names: if attention_mask is not None: - inputs["attention_mask"] = 
np.array(attention_mask) + attention_mask = np.array(attention_mask) else: - inputs["attention_mask"] = np.ones( + attention_mask = np.ones( (input_ids.shape[0], input_ids.shape[1] + past_len), dtype=inputs["input_ids"].dtype ) - # Add the attention_mask inputs when needed - if "position_ids" in self.input_names and position_ids is not None: - inputs["position_ids"] = np.array(position_ids) + if "attention_mask" in self.input_names: + inputs["attention_mask"] = attention_mask + + if "position_ids" in self.input_names: + if position_ids is not None: + position_ids = np.array(position_ids) + else: + position_ids = np.cumsum(attention_mask, axis=1) - 1 + position_ids[attention_mask == 0] = 1 + if past_key_values: + position_ids = np.expand_dims(position_ids[:, -1], axis=-1) + + inputs["position_ids"] = position_ids # Run inference self.request.start_async(inputs, shared_memory=True) diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py index 8efee7bac8..c29e8c2eef 100644 --- a/tests/openvino/test_modeling.py +++ b/tests/openvino/test_modeling.py @@ -51,6 +51,7 @@ from transformers.onnx.utils import get_preprocessor from utils_tests import MODEL_NAMES +from optimum.exporters.onnx import MODEL_TYPES_REQUIRING_POSITION_IDS from optimum.intel import ( OVModelForAudioClassification, OVModelForAudioFrameClassification, @@ -77,7 +78,6 @@ DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER, ) from optimum.utils.testing_utils import require_diffusers -from optimum.exporters.onnx import MODEL_TYPES_REQUIRING_POSITION_IDS TENSOR_ALIAS_TO_TYPE = { From 667809e868b6d1f2a3967e89ce0060f19242cedc Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Fri, 3 Nov 2023 11:35:00 +0100 Subject: [PATCH 17/29] trigger test From deff847bc1da29518f0bc0e3532274fe5be190e9 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Fri, 3 Nov 2023 11:56:49 +0100 Subject: [PATCH 18/29] fix trainer --- optimum/intel/openvino/trainer.py | 62 ++++++++++++++++++------------- 1 file changed, 37 insertions(+), 25 deletions(-) diff --git a/optimum/intel/openvino/trainer.py b/optimum/intel/openvino/trainer.py index 0bba054ad3..d7e1659bb5 100644 --- a/optimum/intel/openvino/trainer.py +++ b/optimum/intel/openvino/trainer.py @@ -54,8 +54,8 @@ from transformers import Trainer from transformers.data.data_collator import DataCollator from transformers.debug_utils import DebugOption, DebugUnderflowOverflow -from transformers.deepspeed import deepspeed_init from transformers.integrations import hp_params +from transformers.integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint from transformers.modeling_utils import PreTrainedModel, unwrap_model from transformers.pytorch_utils import is_torch_less_than_1_11 from transformers.tokenization_utils_base import PreTrainedTokenizerBase @@ -65,7 +65,6 @@ from transformers.trainer_utils import ( EvalPrediction, HPSearchBackend, - ShardedDDPOption, TrainOutput, has_length, speed_metrics, @@ -252,7 +251,7 @@ def _inner_training_loop( # number of training epochs: num_train_epochs # number of training steps per epoch: num_update_steps_per_epoch # total number of training steps to execute: max_steps - total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.world_size + total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps * args.world_size len_dataloader = None if has_length(train_dataloader): @@ -296,30 +295,42 @@ def _inner_training_loop( else: debug_overflow = DebugUnderflowOverflow(self.model) # noqa - 
delay_optimizer_creation = ( - self.sharded_ddp is not None - and self.sharded_ddp != ShardedDDPOption.SIMPLE - or is_sagemaker_mp_enabled() - or self.fsdp is not None - ) - if args.deepspeed: - deepspeed_engine, optimizer, lr_scheduler = deepspeed_init( - self, num_training_steps=max_steps, resume_from_checkpoint=resume_from_checkpoint - ) - self.model = deepspeed_engine.module - self.model_wrapped = deepspeed_engine - self.deepspeed = deepspeed_engine - self.optimizer = optimizer - self.lr_scheduler = lr_scheduler - elif not delay_optimizer_creation: + delay_optimizer_creation = is_sagemaker_mp_enabled() or self.fsdp is not None or self.is_fsdp_enabled + + if self.is_deepspeed_enabled: + self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps) + + if not delay_optimizer_creation: self.create_optimizer_and_scheduler(num_training_steps=max_steps) self.state = TrainerState() self.state.is_hyper_param_search = trial is not None + # Compute absolute values for logging, eval, and save if given as ratio + if args.logging_steps is not None: + if args.logging_steps < 1: + self.state.logging_steps = math.ceil(max_steps * args.logging_steps) + else: + self.state.logging_steps = args.logging_steps + if args.eval_steps is not None: + if args.eval_steps < 1: + self.state.eval_steps = math.ceil(max_steps * args.eval_steps) + else: + self.state.eval_steps = args.eval_steps + if args.save_steps is not None: + if args.save_steps < 1: + self.state.save_steps = math.ceil(max_steps * args.save_steps) + else: + self.state.save_steps = args.save_steps + # Activate gradient checkpointing if needed if args.gradient_checkpointing: - self.model.gradient_checkpointing_enable() + if args.gradient_checkpointing_kwargs is None: + gradient_checkpointing_kwargs = {} + else: + gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs + + self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs) if is_transformers_version("<", "4.29.0"): is_distributed = self.args.local_rank != -1 @@ -333,16 +344,17 @@ def _inner_training_loop( model = self._wrap_model(self.model_wrapped) - if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None: - self._load_from_checkpoint(resume_from_checkpoint, model) + # ckpt loading + if resume_from_checkpoint is not None: + if self.is_deepspeed_enabled: + deepspeed_load_checkpoint(self.model_wrapped, resume_from_checkpoint) + elif is_sagemaker_mp_enabled() or self.is_fsdp_enabled: + self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped) # for the rest of this function `model` is the outside model, whether it was wrapped or not if model is not self.model: self.model_wrapped = model - if delay_optimizer_creation: - self.create_optimizer_and_scheduler(num_training_steps=max_steps) - # Check if saved optimizer or scheduler states exist self._load_optimizer_and_scheduler(resume_from_checkpoint) From 0ad7dc253d4680234cc64ff62ee00f515fb989e5 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Fri, 3 Nov 2023 11:59:11 +0100 Subject: [PATCH 19/29] fix trainer --- optimum/intel/openvino/trainer.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/optimum/intel/openvino/trainer.py b/optimum/intel/openvino/trainer.py index d7e1659bb5..065931c4e0 100644 --- a/optimum/intel/openvino/trainer.py +++ b/optimum/intel/openvino/trainer.py @@ -355,6 +355,9 @@ def _inner_training_loop( if model is not self.model: self.model_wrapped = model + if delay_optimizer_creation: + 
self.create_optimizer_and_scheduler(num_training_steps=max_steps) + # Check if saved optimizer or scheduler states exist self._load_optimizer_and_scheduler(resume_from_checkpoint) From a328aa4a98d2202ec6e40fffacd478f4e983f71d Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Fri, 3 Nov 2023 12:09:38 +0100 Subject: [PATCH 20/29] fix trainer --- .github/workflows/test_inc.yml | 3 +- optimum/intel/neural_compressor/trainer.py | 310 ++++++++++++++------- optimum/intel/openvino/trainer.py | 283 +++++++++++++------ 3 files changed, 398 insertions(+), 198 deletions(-) diff --git a/.github/workflows/test_inc.yml b/.github/workflows/test_inc.yml index fd5fd16509..3a15214f99 100644 --- a/.github/workflows/test_inc.yml +++ b/.github/workflows/test_inc.yml @@ -30,7 +30,8 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install .[neural-compressor,ipex,diffusers,tests] + pip install .[neural-compressor,diffusers,tests] + pip install intel-extension-for-pytorch - name: Test with Pytest run: | pytest tests/neural_compressor/ diff --git a/optimum/intel/neural_compressor/trainer.py b/optimum/intel/neural_compressor/trainer.py index 8e8fec1758..5ae4a1f72a 100644 --- a/optimum/intel/neural_compressor/trainer.py +++ b/optimum/intel/neural_compressor/trainer.py @@ -15,6 +15,7 @@ import copy import math import os +import shutil import sys import time from collections.abc import Mapping @@ -28,38 +29,38 @@ from neural_compressor.compression import DistillationCallbacks from neural_compressor.conf.pythonic_config import _BaseQuantizationConfig from neural_compressor.experimental.export import torch_to_fp32_onnx, torch_to_int8_onnx - -# from packaging import version +from packaging import version from torch import nn from torch.utils.data import Dataset, RandomSampler -from torch.utils.data.dataloader import DataLoader -from torch.utils.data.distributed import DistributedSampler -from tqdm.auto import tqdm from transformers import Trainer from transformers.data.data_collator import DataCollator from transformers.debug_utils import DebugOption, DebugUnderflowOverflow -from transformers.deepspeed import deepspeed_init -from transformers.file_utils import WEIGHTS_NAME # Integrations must be imported before ML frameworks: -from transformers.integrations import hp_params +from transformers.integrations import deepspeed_init, deepspeed_load_checkpoint, hp_params, is_deepspeed_available from transformers.modeling_utils import PreTrainedModel, get_parameter_dtype, unwrap_model from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES from transformers.pytorch_utils import is_torch_less_than_1_11 from transformers.tokenization_utils_base import PreTrainedTokenizerBase from transformers.trainer import TRAINER_STATE_NAME from transformers.trainer_callback import TrainerCallback, TrainerState -from transformers.trainer_pt_utils import IterableDatasetShard +from transformers.trainer_pt_utils import get_dataloader_sampler, get_model_param_count from transformers.trainer_utils import ( EvalPrediction, HPSearchBackend, - ShardedDDPOption, TrainOutput, has_length, speed_metrics, ) -from transformers.training_args import TrainingArguments -from transformers.utils import is_apex_available, is_sagemaker_mp_enabled, logging +from transformers.training_args import ParallelMode, TrainingArguments +from transformers.utils import ( + WEIGHTS_NAME, + is_accelerate_available, + is_apex_available, + is_sagemaker_mp_enabled, + is_torch_tpu_available, + logging, +) from 
optimum.exporters import TasksManager @@ -68,12 +69,31 @@ from .configuration import INCConfig +if is_accelerate_available(): + from accelerate import __version__ as accelerate_version + from accelerate import skip_first_batches + + if version.parse(accelerate_version) > version.parse("0.20.3"): + pass + DATA_SAMPLERS = [RandomSampler] + if version.parse(accelerate_version) > version.parse("0.23.0"): + from accelerate.data_loader import SeedableRandomSampler + + DATA_SAMPLERS += [SeedableRandomSampler] + + if is_deepspeed_available(): + pass + + if is_apex_available(): from apex import amp if is_sagemaker_mp_enabled(): import smdistributed.modelparallel.torch as smp +if is_torch_tpu_available(check_device=False): + import torch_xla.core.xla_model as xm + if TYPE_CHECKING: from optimum.exporters.onnx import OnnxConfig @@ -178,7 +198,9 @@ def __init__( def _inner_training_loop( self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None ): + self.accelerator.free_memory() self._train_batch_size = batch_size + logger.debug(f"Currently training with a batch size of: {self._train_batch_size}") # Data loader and number of training steps train_dataloader = self.get_train_dataloader() @@ -186,9 +208,10 @@ def _inner_training_loop( # number of training epochs: num_train_epochs # number of training steps per epoch: num_update_steps_per_epoch # total number of training steps to execute: max_steps - total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.world_size + total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps * args.world_size len_dataloader = None + num_train_tokens = None if has_length(train_dataloader): len_dataloader = len(train_dataloader) num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps @@ -230,58 +253,106 @@ def _inner_training_loop( else: debug_overflow = DebugUnderflowOverflow(self.model) # noqa - delay_optimizer_creation = ( - self.sharded_ddp is not None - and self.sharded_ddp != ShardedDDPOption.SIMPLE - or is_sagemaker_mp_enabled() - or self.fsdp is not None - ) - if args.deepspeed: - deepspeed_engine, optimizer, lr_scheduler = deepspeed_init( - self, num_training_steps=max_steps, resume_from_checkpoint=resume_from_checkpoint - ) - self.model = deepspeed_engine.module - self.model_wrapped = deepspeed_engine - self.deepspeed = deepspeed_engine - self.optimizer = optimizer - self.lr_scheduler = lr_scheduler - elif not delay_optimizer_creation: + delay_optimizer_creation = is_sagemaker_mp_enabled() or self.fsdp is not None or self.is_fsdp_enabled + + if self.is_deepspeed_enabled: + self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps) + + if not delay_optimizer_creation: self.create_optimizer_and_scheduler(num_training_steps=max_steps) self.state = TrainerState() self.state.is_hyper_param_search = trial is not None + # Compute absolute values for logging, eval, and save if given as ratio + if args.logging_steps is not None: + if args.logging_steps < 1: + self.state.logging_steps = math.ceil(max_steps * args.logging_steps) + else: + self.state.logging_steps = args.logging_steps + if args.eval_steps is not None: + if args.eval_steps < 1: + self.state.eval_steps = math.ceil(max_steps * args.eval_steps) + else: + self.state.eval_steps = args.eval_steps + if args.save_steps is not None: + if args.save_steps < 1: + self.state.save_steps = math.ceil(max_steps * args.save_steps) + else: + self.state.save_steps = 
args.save_steps + # Activate gradient checkpointing if needed if args.gradient_checkpointing: - self.model.gradient_checkpointing_enable() + if args.gradient_checkpointing_kwargs is None: + gradient_checkpointing_kwargs = {} + else: + gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs + + self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs) model = self._wrap_model(self.model_wrapped) - if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None: - self._load_from_checkpoint(resume_from_checkpoint, model) + # as the model is wrapped, don't use `accelerator.prepare` + # this is for unhandled cases such as + # FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX + use_accelerator_prepare = True if model is self.model else False + + if delay_optimizer_creation: + if use_accelerator_prepare: + self.model = self.accelerator.prepare(self.model) + self.create_optimizer_and_scheduler(num_training_steps=max_steps) + + # prepare using `accelerator` prepare + if use_accelerator_prepare: + self.model.train() + if hasattr(self.lr_scheduler, "step"): + if self.use_apex: + model = self.accelerator.prepare(self.model) + else: + model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer) + else: + # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config. + model, self.optimizer, self.lr_scheduler = self.accelerator.prepare( + self.model, self.optimizer, self.lr_scheduler + ) + + if self.is_fsdp_enabled: + self.model = self.model_wrapped = model # for the rest of this function `model` is the outside model, whether it was wrapped or not if model is not self.model: self.model_wrapped = model - if delay_optimizer_creation: - self.create_optimizer_and_scheduler(num_training_steps=max_steps) + # backward compatibility + if self.is_deepspeed_enabled: + self.deepspeed = self.model_wrapped + + # ckpt loading + if resume_from_checkpoint is not None: + if self.is_deepspeed_enabled: + deepspeed_load_checkpoint(self.model_wrapped, resume_from_checkpoint) + elif is_sagemaker_mp_enabled() or self.is_fsdp_enabled: + self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped) # Check if saved optimizer or scheduler states exist self._load_optimizer_and_scheduler(resume_from_checkpoint) # important: at this point: # self.model is the Transformers Model - # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc. + # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), + # FSDP(Transformers Model), Dynamo Optimized Module(Transformers Model) etc. # Train! logger.info("***** Running training *****") - logger.info(f" Num examples = {num_examples}") - logger.info(f" Num Epochs = {num_train_epochs}") - logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") - logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}") + logger.info(f" Num examples = {num_examples:,}") + logger.info(f" Num Epochs = {num_train_epochs:,}") + logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}") + if self.args.per_device_train_batch_size != self._train_batch_size: + logger.info(f" Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}") + logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_train_batch_size:,}") logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") - logger.info(f" Total optimization steps = {max_steps}") + logger.info(f" Total optimization steps = {max_steps:,}") + logger.info(f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}") self.state.epoch = 0 start_time = time.time() @@ -306,20 +377,19 @@ def _inner_training_loop( logger.info(f" Continuing training from global step {self.state.global_step}") if not args.ignore_data_skip: logger.info( - f" Will skip the first {epochs_trained} epochs then the first {steps_trained_in_current_epoch} " - "batches in the first epoch. If this takes a lot of time, you can add the `--ignore_data_skip` " - "flag to your launch command, but you will resume the training on data already seen by your model." + f" Will skip the first {epochs_trained} epochs then the first" + f" {steps_trained_in_current_epoch} batches in the first epoch." ) - if self.is_local_process_zero() and not args.disable_tqdm: - steps_trained_progress_bar = tqdm(total=steps_trained_in_current_epoch) - steps_trained_progress_bar.set_description("Skipping the first batches") # Update the references self.callback_handler.model = self.model self.callback_handler.optimizer = self.optimizer self.callback_handler.lr_scheduler = self.lr_scheduler self.callback_handler.train_dataloader = train_dataloader - self.state.trial_name = self.hp_name(trial) if self.hp_name is not None else None + if self.hp_name is not None and self._trial is not None: + # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial + # parameter to Train when using DDP. + self.state.trial_name = self.hp_name(self._trial) if trial is not None: assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial self.state.trial_params = hp_params(assignments) @@ -347,26 +417,26 @@ def _inner_training_loop( # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point. if not args.ignore_data_skip: for epoch in range(epochs_trained): - is_random_sampler = hasattr(train_dataloader, "sampler") and isinstance( - train_dataloader.sampler, RandomSampler - ) + sampler = get_dataloader_sampler(train_dataloader) + sampler_kinds = [RandomSampler] + if version.parse(accelerate_version) > version.parse("0.23.0"): + sampler_kinds.append(SeedableRandomSampler) + is_random_sampler = isinstance(sampler, tuple(sampler_kinds)) if is_torch_less_than_1_11 or not is_random_sampler: # We just need to begin an iteration to create the randomization of the sampler. - # That was before PyTorch 1.11 however... for _ in train_dataloader: break else: # Otherwise we need to call the whooooole sampler cause there is some random operation added # AT THE VERY END! 
- _ = list(train_dataloader.sampler) + sampler = sampler if sampler is not None else [] + _ = list(sampler) + total_batched_samples = 0 for epoch in range(epochs_trained, num_train_epochs): - if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler): - train_dataloader.sampler.set_epoch(epoch) - elif hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDatasetShard): - train_dataloader.dataset.set_epoch(epoch) - epoch_iterator = train_dataloader + if hasattr(epoch_iterator, "set_epoch"): + epoch_iterator.set_epoch(epoch) # Reset the past mems state at the beginning of each epoch if necessary. if args.past_index >= 0: @@ -385,8 +455,21 @@ def _inner_training_loop( if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0: self._load_rng_state(resume_from_checkpoint) + rng_to_sync = False + steps_skipped = 0 + if steps_trained_in_current_epoch > 0: + epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch) + steps_skipped = steps_trained_in_current_epoch + steps_trained_in_current_epoch = 0 + rng_to_sync = True + step = -1 for step, inputs in enumerate(epoch_iterator): + total_batched_samples += 1 + if rng_to_sync: + self._load_rng_state(resume_from_checkpoint) + rng_to_sync = False + # Skip past any already trained steps if resuming training if steps_trained_in_current_epoch > 0: steps_trained_in_current_epoch -= 1 @@ -404,18 +487,14 @@ def _inner_training_loop( if self._compression_manager is not None: self._compression_manager.callbacks.on_step_begin(step) - if ( - ((step + 1) % args.gradient_accumulation_steps != 0) - and args.local_rank != -1 - and args._no_sync_in_gradient_accumulation - ): - # Avoid unnecessary DDP synchronization since there will be no backward pass on this example. - with model.no_sync(): - tr_loss_step = self.training_step(model, inputs) - else: + with self.accelerator.accumulate(model): tr_loss_step = self.training_step(model, inputs) - if args.logging_nan_inf_filter and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)): + if ( + args.logging_nan_inf_filter + and not is_torch_tpu_available() + and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)) + ): # if loss is nan or inf simply add the average of previous logged losses tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged) else: @@ -423,35 +502,38 @@ def _inner_training_loop( self.current_flos += float(self.floating_point_ops(inputs)) - # Optimizer step for deepspeed must be called on every step regardless of the value of gradient_accumulation_steps - if self.deepspeed: - self.deepspeed.step() + is_last_step_and_steps_less_than_grad_acc = ( + steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch + ) - if (step + 1) % args.gradient_accumulation_steps == 0 or ( + if ( + total_batched_samples % args.gradient_accumulation_steps == 0 + or # last step in epoch but step is always smaller than gradient_accumulation_steps - steps_in_epoch <= args.gradient_accumulation_steps - and (step + 1) == steps_in_epoch + is_last_step_and_steps_less_than_grad_acc ): + # the `or` condition of `is_last_step_and_steps_less_than_grad_acc` is not covered + # in accelerate. So, explicitly enable sync gradients to True in that case. 
+ if is_last_step_and_steps_less_than_grad_acc or ( + version.parse(accelerate_version) <= version.parse("0.20.3") + ): + self.accelerator.gradient_state._set_sync_gradients(True) + # Gradient clipping - if args.max_grad_norm is not None and args.max_grad_norm > 0 and not self.deepspeed: + if args.max_grad_norm is not None and args.max_grad_norm > 0: # deepspeed does its own clipping - if self.do_grad_scaling: - # AMP: gradients need unscaling - self.scaler.unscale_(self.optimizer) - if is_sagemaker_mp_enabled() and args.fp16: self.optimizer.clip_master_grads(args.max_grad_norm) - elif hasattr(self.optimizer, "clip_grad_norm"): - # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping - self.optimizer.clip_grad_norm(args.max_grad_norm) - elif hasattr(model, "clip_grad_norm_"): - # Some models (like FullyShardedDDP) have a specific way to do gradient clipping - model.clip_grad_norm_(args.max_grad_norm) - else: + elif self.use_apex: # Revert to normal clipping otherwise, handling Apex or full precision nn.utils.clip_grad_norm_( - amp.master_params(self.optimizer) if self.use_apex else model.parameters(), + amp.master_params(self.optimizer), + args.max_grad_norm, + ) + else: + self.accelerator.clip_grad_norm_( + model.parameters(), args.max_grad_norm, ) @@ -459,27 +541,20 @@ def _inner_training_loop( self._compression_manager.callbacks.on_before_optimizer_step() # Optimizer step - optimizer_was_run = True - if self.deepspeed: - pass # called outside the loop - elif self.do_grad_scaling: - scale_before = self.scaler.get_scale() - self.scaler.step(self.optimizer) - self.scaler.update() - scale_after = self.scaler.get_scale() - optimizer_was_run = scale_before <= scale_after - else: - self.optimizer.step() + self.optimizer.step() if self._compression_manager is not None: self._compression_manager.callbacks.on_after_optimizer_step() - if optimizer_was_run and not self.deepspeed: - self.lr_scheduler.step() + optimizer_was_run = not self.accelerator.optimizer_step_was_skipped + if optimizer_was_run: + # Delay optimizer scheduling until metrics are generated + if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + self.lr_scheduler.step() model.zero_grad() self.state.global_step += 1 - self.state.epoch = epoch + (step + 1) / steps_in_epoch + self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch self.control = self.callback_handler.on_step_end(args, self.state, self.control) if self._compression_manager is not None: self._compression_manager.callbacks.on_step_end() @@ -501,7 +576,6 @@ def _inner_training_loop( self.control = self.callback_handler.on_epoch_end(args, self.state, self.control) if self._compression_manager is not None: self._compression_manager.callbacks.on_epoch_end() - self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval) if self.control.should_training_stop: @@ -513,9 +587,10 @@ def _inner_training_loop( logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n") if args.load_best_model_at_end and self.state.best_model_checkpoint is not None: - # Wait for everyone to get here so we are sur the model has been saved by process 0. - - if args.local_rank != -1: + # Wait for everyone to get here so we are sure the model has been saved by process 0. 
+ if is_torch_tpu_available(): + xm.rendezvous("load_best_model_at_end") + elif args.parallel_mode == ParallelMode.DISTRIBUTED: dist.barrier() elif is_sagemaker_mp_enabled(): smp.barrier() @@ -526,7 +601,13 @@ def _inner_training_loop( self._total_loss_scalar += tr_loss.item() train_loss = self._total_loss_scalar / self.state.global_step - metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps) + metrics = speed_metrics( + "train", + start_time, + num_samples=num_train_samples, + num_steps=self.state.max_steps, + num_tokens=num_train_tokens, + ) self.store_flos() metrics["total_flos"] = self.state.total_flos metrics["train_loss"] = train_loss @@ -537,7 +618,26 @@ def _inner_training_loop( self.log(metrics) + run_dir = self._get_output_dir(trial) + checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir) + + # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save. + if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1: + for checkpoint in checkpoints_sorted: + if not os.path.samefile(checkpoint, self.state.best_model_checkpoint): + logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit") + shutil.rmtree(checkpoint) + self.control = self.callback_handler.on_train_end(args, self.state, self.control) + + # Wait for the checkpoint to be uploaded. + self._finish_current_push() + + # After training we make sure to retrieve back the original forward pass method + # for the embedding layer by removing the forward post hook. + if self.neftune_noise_alpha is not None: + self._deactivate_neftune(self.model) + if self._compression_manager is not None: self._compression_manager.callbacks.on_train_end() diff --git a/optimum/intel/openvino/trainer.py b/optimum/intel/openvino/trainer.py index 065931c4e0..ccc716fadd 100644 --- a/optimum/intel/openvino/trainer.py +++ b/optimum/intel/openvino/trainer.py @@ -16,12 +16,12 @@ import io import math import os +import shutil import sys import time -from collections import defaultdict from itertools import chain from pathlib import Path -from typing import Callable, Dict, List, Optional, Tuple, Type, Union +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Type, Union import openvino import openvino.runtime @@ -46,22 +46,23 @@ compress_quantize_weights_transformation, ) from openvino.runtime import Core, PartialShape, save_model +from packaging import version +from torch import nn from torch.onnx import export as onnx_export from torch.utils._pytree import tree_map -from torch.utils.data import DataLoader, Dataset, RandomSampler -from torch.utils.data.distributed import DistributedSampler -from tqdm.auto import tqdm +from torch.utils.data import Dataset, RandomSampler from transformers import Trainer from transformers.data.data_collator import DataCollator from transformers.debug_utils import DebugOption, DebugUnderflowOverflow -from transformers.integrations import hp_params -from transformers.integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint + +# Integrations must be imported before ML frameworks: +from transformers.integrations import deepspeed_init, deepspeed_load_checkpoint, hp_params, is_deepspeed_available from transformers.modeling_utils import PreTrainedModel, unwrap_model from transformers.pytorch_utils import is_torch_less_than_1_11 from transformers.tokenization_utils_base import 
PreTrainedTokenizerBase from transformers.trainer import TRAINER_STATE_NAME, TRAINING_ARGS_NAME from transformers.trainer_callback import TrainerCallback, TrainerState -from transformers.trainer_pt_utils import IterableDatasetShard +from transformers.trainer_pt_utils import get_dataloader_sampler, get_model_param_count from transformers.trainer_utils import ( EvalPrediction, HPSearchBackend, @@ -69,16 +70,18 @@ has_length, speed_metrics, ) +from transformers.training_args import ParallelMode from transformers.utils import ( WEIGHTS_NAME, + is_accelerate_available, is_apex_available, is_sagemaker_mp_enabled, is_torch_tpu_available, logging, ) +from optimum.exporters import TasksManager from optimum.exporters.onnx import OnnxConfig -from optimum.exporters.tasks import TasksManager from ..utils.constant import _TASK_ALIASES from ..utils.import_utils import is_transformers_version @@ -94,6 +97,22 @@ ) +if is_accelerate_available(): + from accelerate import __version__ as accelerate_version + from accelerate import skip_first_batches + + if version.parse(accelerate_version) > version.parse("0.20.3"): + pass + DATA_SAMPLERS = [RandomSampler] + if version.parse(accelerate_version) > version.parse("0.23.0"): + from accelerate.data_loader import SeedableRandomSampler + + DATA_SAMPLERS += [SeedableRandomSampler] + + if is_deepspeed_available(): + pass + + if is_apex_available(): from apex import amp @@ -243,7 +262,9 @@ def _set_signature_columns_if_needed(self): def _inner_training_loop( self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None ): + self.accelerator.free_memory() self._train_batch_size = batch_size + logger.debug(f"Currently training with a batch size of: {self._train_batch_size}") # Data loader and number of training steps train_dataloader = self.get_train_dataloader() @@ -254,6 +275,7 @@ def _inner_training_loop( total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps * args.world_size len_dataloader = None + num_train_tokens = None if has_length(train_dataloader): len_dataloader = len(train_dataloader) num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps @@ -267,10 +289,16 @@ def _inner_training_loop( # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's # the best we can do. num_train_samples = args.max_steps * total_train_batch_size + if args.include_tokens_per_second: + num_train_tokens = ( + self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps + ) else: max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch) num_train_epochs = math.ceil(args.num_train_epochs) num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs + if args.include_tokens_per_second: + num_train_tokens = self.num_tokens(train_dataloader) * args.num_train_epochs elif args.max_steps > 0: # Rely on max_steps when dataloader does not have a working size max_steps = args.max_steps # Setting a very large number of epochs so we go as many times as necessary over the iterator. 
@@ -278,6 +306,8 @@ def _inner_training_loop( num_update_steps_per_epoch = max_steps num_examples = total_train_batch_size * args.max_steps num_train_samples = args.max_steps * total_train_batch_size + if args.include_tokens_per_second: + num_train_tokens = self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps else: raise ValueError( "args.max_steps must be set to a positive value if dataloader does not have a length, was" @@ -286,7 +316,7 @@ def _inner_training_loop( if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug: if self.args.n_gpu > 1: - # torch.nn.DataParallel(model) replicates the model, creating new variables and module + # nn.DataParallel(model) replicates the model, creating new variables and module # references registered here no longer work on other gpus, breaking the module raise ValueError( "Currently --debug underflow_overflow is not supported under DP. Please use DDP" @@ -297,6 +327,11 @@ def _inner_training_loop( delay_optimizer_creation = is_sagemaker_mp_enabled() or self.fsdp is not None or self.is_fsdp_enabled + # We need to reset the scheduler, as its parameters may be different on subsequent calls + if self._created_lr_scheduler: + self.lr_scheduler = None + self._created_lr_scheduler = False + if self.is_deepspeed_enabled: self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps) @@ -344,6 +379,41 @@ def _inner_training_loop( model = self._wrap_model(self.model_wrapped) + # as the model is wrapped, don't use `accelerator.prepare` + # this is for unhandled cases such as + # FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX + use_accelerator_prepare = True if model is self.model else False + + if delay_optimizer_creation: + if use_accelerator_prepare: + self.model = self.accelerator.prepare(self.model) + self.create_optimizer_and_scheduler(num_training_steps=max_steps) + + # prepare using `accelerator` prepare + if use_accelerator_prepare: + self.model.train() + if hasattr(self.lr_scheduler, "step"): + if self.use_apex: + model = self.accelerator.prepare(self.model) + else: + model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer) + else: + # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config. + model, self.optimizer, self.lr_scheduler = self.accelerator.prepare( + self.model, self.optimizer, self.lr_scheduler + ) + + if self.is_fsdp_enabled: + self.model = self.model_wrapped = model + + # for the rest of this function `model` is the outside model, whether it was wrapped or not + if model is not self.model: + self.model_wrapped = model + + # backward compatibility + if self.is_deepspeed_enabled: + self.deepspeed = self.model_wrapped + # ckpt loading if resume_from_checkpoint is not None: if self.is_deepspeed_enabled: @@ -351,28 +421,25 @@ def _inner_training_loop( elif is_sagemaker_mp_enabled() or self.is_fsdp_enabled: self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped) - # for the rest of this function `model` is the outside model, whether it was wrapped or not - if model is not self.model: - self.model_wrapped = model - - if delay_optimizer_creation: - self.create_optimizer_and_scheduler(num_training_steps=max_steps) - # Check if saved optimizer or scheduler states exist self._load_optimizer_and_scheduler(resume_from_checkpoint) # important: at this point: # self.model is the Transformers Model - # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc. 
+ # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), + # FSDP(Transformers Model), Dynamo Optimized Module(Transformers Model) etc. # Train! logger.info("***** Running training *****") - logger.info(f" Num examples = {num_examples}") - logger.info(f" Num Epochs = {num_train_epochs}") - logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") - logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}") + logger.info(f" Num examples = {num_examples:,}") + logger.info(f" Num Epochs = {num_train_epochs:,}") + logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}") + if self.args.per_device_train_batch_size != self._train_batch_size: + logger.info(f" Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}") logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") - logger.info(f" Total optimization steps = {max_steps}") + logger.info(f" Total optimization steps = {max_steps:,}") + logger.info(f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}") self.state.epoch = 0 start_time = time.time() @@ -397,20 +464,19 @@ def _inner_training_loop( logger.info(f" Continuing training from global step {self.state.global_step}") if not args.ignore_data_skip: logger.info( - f" Will skip the first {epochs_trained} epochs then the first {steps_trained_in_current_epoch} " - "batches in the first epoch. If this takes a lot of time, you can add the `--ignore_data_skip` " - "flag to your launch command, but you will resume the training on data already seen by your model." + f" Will skip the first {epochs_trained} epochs then the first" + f" {steps_trained_in_current_epoch} batches in the first epoch." ) - if self.is_local_process_zero() and not args.disable_tqdm: - steps_trained_progress_bar = tqdm(total=steps_trained_in_current_epoch) - steps_trained_progress_bar.set_description("Skipping the first batches") # Update the references self.callback_handler.model = self.model self.callback_handler.optimizer = self.optimizer self.callback_handler.lr_scheduler = self.lr_scheduler self.callback_handler.train_dataloader = train_dataloader - self.state.trial_name = self.hp_name(trial) if self.hp_name is not None else None + if self.hp_name is not None and self._trial is not None: + # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial + # parameter to Train when using DDP. 
+ self.state.trial_name = self.hp_name(self._trial) if trial is not None: assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial self.state.trial_params = hp_params(assignments) @@ -423,8 +489,8 @@ def _inner_training_loop( self.state.is_local_process_zero = self.is_local_process_zero() self.state.is_world_process_zero = self.is_world_process_zero() + # tr_loss is a tensor to avoid synchronization of TPUs through .item() tr_loss = torch.tensor(0.0).to(args.device) - self.compression_metrics = defaultdict(lambda: torch.tensor(0.0).to(args.device)) # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses self._total_loss_scalar = 0.0 self._globalstep_last_logged = self.state.global_step @@ -435,31 +501,33 @@ def _inner_training_loop( # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point. if not args.ignore_data_skip: for epoch in range(epochs_trained): - is_random_sampler = hasattr(train_dataloader, "sampler") and isinstance( - train_dataloader.sampler, RandomSampler - ) + sampler = get_dataloader_sampler(train_dataloader) + sampler_kinds = [RandomSampler] + if version.parse(accelerate_version) > version.parse("0.23.0"): + sampler_kinds.append(SeedableRandomSampler) + is_random_sampler = isinstance(sampler, tuple(sampler_kinds)) if is_torch_less_than_1_11 or not is_random_sampler: # We just need to begin an iteration to create the randomization of the sampler. - # That was before PyTorch 1.11 however... for _ in train_dataloader: break else: - # Otherwise we need to call the whole sampler cause there is some random operation added + # Otherwise we need to call the whooooole sampler cause there is some random operation added # AT THE VERY END! - _ = list(train_dataloader.sampler) + sampler = sampler if sampler is not None else [] + _ = list(sampler) + total_batched_samples = 0 for epoch in range(epochs_trained, num_train_epochs): - if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler): - train_dataloader.sampler.set_epoch(epoch) - elif hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDatasetShard): - train_dataloader.dataset.set_epoch(epoch) + epoch_iterator = train_dataloader + if hasattr(epoch_iterator, "set_epoch"): + epoch_iterator.set_epoch(epoch) # Reset the past mems state at the beginning of each epoch if necessary. 
if args.past_index >= 0: self._past = None steps_in_epoch = ( - len(train_dataloader) + len(epoch_iterator) if len_dataloader is not None else args.max_steps * args.gradient_accumulation_steps ) @@ -475,8 +543,21 @@ def _inner_training_loop( if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0: self._load_rng_state(resume_from_checkpoint) + rng_to_sync = False + steps_skipped = 0 + if steps_trained_in_current_epoch > 0: + epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch) + steps_skipped = steps_trained_in_current_epoch + steps_trained_in_current_epoch = 0 + rng_to_sync = True + step = -1 - for step, inputs in enumerate(train_dataloader): + for step, inputs in enumerate(epoch_iterator): + total_batched_samples += 1 + if rng_to_sync: + self._load_rng_state(resume_from_checkpoint) + rng_to_sync = False + # Skip past any already trained steps if resuming training if steps_trained_in_current_epoch > 0: steps_trained_in_current_epoch -= 1 @@ -495,17 +576,14 @@ def _inner_training_loop( # Must be called at the beginning of each training step to prepare the compression method self.compression_controller.scheduler.step() + with self.accelerator.accumulate(model): + tr_loss_step = self.training_step(model, inputs) + if ( - ((step + 1) % args.gradient_accumulation_steps != 0) - and args.local_rank != -1 - and args._no_sync_in_gradient_accumulation + args.logging_nan_inf_filter + and not is_torch_tpu_available() + and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)) ): - # Avoid unnecessary DDP synchronization since there will be no backward pass on this example. - with model.no_sync(): - tr_loss_step = self.training_step(model, inputs) - else: - tr_loss_step = self.training_step(model, inputs) - if args.logging_nan_inf_filter and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)): # if loss is nan or inf simply add the average of previous logged losses tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged) else: @@ -513,57 +591,52 @@ def _inner_training_loop( self.current_flos += float(self.floating_point_ops(inputs)) - # Optimizer step for deepspeed must be called on every step regardless of the value of gradient_accumulation_steps - if self.deepspeed: - self.deepspeed.step() + is_last_step_and_steps_less_than_grad_acc = ( + steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch + ) - if (step + 1) % args.gradient_accumulation_steps == 0 or ( + if ( + total_batched_samples % args.gradient_accumulation_steps == 0 + or # last step in epoch but step is always smaller than gradient_accumulation_steps - steps_in_epoch <= args.gradient_accumulation_steps - and (step + 1) == steps_in_epoch + is_last_step_and_steps_less_than_grad_acc ): + # the `or` condition of `is_last_step_and_steps_less_than_grad_acc` is not covered + # in accelerate. So, explicitly enable sync gradients to True in that case. 
+ if is_last_step_and_steps_less_than_grad_acc or ( + version.parse(accelerate_version) <= version.parse("0.20.3") + ): + self.accelerator.gradient_state._set_sync_gradients(True) + # Gradient clipping - if args.max_grad_norm is not None and args.max_grad_norm > 0 and not self.deepspeed: + if args.max_grad_norm is not None and args.max_grad_norm > 0: # deepspeed does its own clipping - if self.do_grad_scaling: - # AMP: gradients need unscaling - self.scaler.unscale_(self.optimizer) - if is_sagemaker_mp_enabled() and args.fp16: self.optimizer.clip_master_grads(args.max_grad_norm) - elif hasattr(self.optimizer, "clip_grad_norm"): - # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping - self.optimizer.clip_grad_norm(args.max_grad_norm) - elif hasattr(model, "clip_grad_norm_"): - # Some models (like FullyShardedDDP) have a specific way to do gradient clipping - model.clip_grad_norm_(args.max_grad_norm) - else: + elif self.use_apex: # Revert to normal clipping otherwise, handling Apex or full precision - torch.nn.utils.clip_grad_norm_( - amp.master_params(self.optimizer) if self.use_apex else model.parameters(), + nn.utils.clip_grad_norm_( + amp.master_params(self.optimizer), + args.max_grad_norm, + ) + else: + self.accelerator.clip_grad_norm_( + model.parameters(), args.max_grad_norm, ) # Optimizer step - optimizer_was_run = True - if self.deepspeed: - pass # called outside the loop - elif self.do_grad_scaling: - scale_before = self.scaler.get_scale() - self.scaler.step(self.optimizer) - self.scaler.update() - scale_after = self.scaler.get_scale() - optimizer_was_run = scale_before <= scale_after - else: - self.optimizer.step() - - if optimizer_was_run and not self.deepspeed: - self.lr_scheduler.step() + self.optimizer.step() + optimizer_was_run = not self.accelerator.optimizer_step_was_skipped + if optimizer_was_run: + # Delay optimizer scheduling until metrics are generated + if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + self.lr_scheduler.step() model.zero_grad() self.state.global_step += 1 - self.state.epoch = epoch + (step + 1) / steps_in_epoch + self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch self.control = self.callback_handler.on_step_end(args, self.state, self.control) self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval) @@ -574,7 +647,7 @@ def _inner_training_loop( break if step < 0: logger.warning( - "There seems to be not a single sample in your train_dataloader, stopping training at step" + "There seems to be not a single sample in your epoch_iterator, stopping training at step" f" {self.state.global_step}! This is expected if you're using an IterableDataset and set" f" num_steps ({max_steps}) higher than the number of available samples." ) @@ -582,7 +655,7 @@ def _inner_training_loop( self.control = self.callback_handler.on_epoch_end(args, self.state, self.control) self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval) - + if self.control.should_training_stop: break @@ -592,8 +665,10 @@ def _inner_training_loop( logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n") if args.load_best_model_at_end and self.state.best_model_checkpoint is not None: - # Wait for everyone to get here so we are sur the model has been saved by process 0. - if args.local_rank != -1: + # Wait for everyone to get here so we are sure the model has been saved by process 0. 
+ if is_torch_tpu_available(): + xm.rendezvous("load_best_model_at_end") + elif args.parallel_mode == ParallelMode.DISTRIBUTED: dist.barrier() elif is_sagemaker_mp_enabled(): smp.barrier() @@ -604,7 +679,13 @@ def _inner_training_loop( self._total_loss_scalar += tr_loss.item() train_loss = self._total_loss_scalar / self.state.global_step - metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps) + metrics = speed_metrics( + "train", + start_time, + num_samples=num_train_samples, + num_steps=self.state.max_steps, + num_tokens=num_train_tokens, + ) self.store_flos() metrics["total_flos"] = self.state.total_flos metrics["train_loss"] = train_loss @@ -615,8 +696,26 @@ def _inner_training_loop( self.log(metrics) + run_dir = self._get_output_dir(trial) + checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir) + + # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save. + if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1: + for checkpoint in checkpoints_sorted: + if not os.path.samefile(checkpoint, self.state.best_model_checkpoint): + logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit") + shutil.rmtree(checkpoint) + self.control = self.callback_handler.on_train_end(args, self.state, self.control) + # Wait for the checkpoint to be uploaded. + self._finish_current_push() + + # After training we make sure to retrieve back the original forward pass method + # for the embedding layer by removing the forward post hook. + if self.neftune_noise_alpha is not None: + self._deactivate_neftune(self.model) + return TrainOutput(self.state.global_step, train_loss, metrics) def compute_distillation_loss(self, inputs, student_outputs): From bb8925e6d3a066976040e028123ca826c819a09d Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Fri, 3 Nov 2023 14:56:09 +0100 Subject: [PATCH 21/29] fix test --- optimum/intel/openvino/trainer.py | 4 ++-- tests/neural_compressor/test_onnx.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/optimum/intel/openvino/trainer.py b/optimum/intel/openvino/trainer.py index ccc716fadd..5133f40bfb 100644 --- a/optimum/intel/openvino/trainer.py +++ b/optimum/intel/openvino/trainer.py @@ -21,7 +21,7 @@ import time from itertools import chain from pathlib import Path -from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Callable, Dict, List, Optional, Tuple, Type, Union import openvino import openvino.runtime @@ -655,7 +655,7 @@ def _inner_training_loop( self.control = self.callback_handler.on_epoch_end(args, self.state, self.control) self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval) - + if self.control.should_training_stop: break diff --git a/tests/neural_compressor/test_onnx.py b/tests/neural_compressor/test_onnx.py index f5dc0b7c66..387c369dd1 100644 --- a/tests/neural_compressor/test_onnx.py +++ b/tests/neural_compressor/test_onnx.py @@ -54,7 +54,7 @@ def test_static_quantization(self, task, model_name, expected_quantized_matmuls) tokenizer.pad_token = tokenizer.eos_token quantizer = INCQuantizer.from_pretrained(model, task=task) calibration_dataset = _generate_dataset(quantizer, tokenizer, num_samples=num_samples) - save_onnx_model = True + save_onnx_model = False op_type_dict = ( {"Embedding": {"weight": {"dtype": ["fp32"]}, "activation": {"dtype": 
["fp32"]}}}
 if save_onnx_model

From c21f7364adf3e17a3c39e7be63edd54e9662d960 Mon Sep 17 00:00:00 2001
From: Ella Charlaix
Date: Fri, 3 Nov 2023 15:32:53 +0100
Subject: [PATCH 22/29] fix trainer

---
 optimum/intel/openvino/trainer.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/optimum/intel/openvino/trainer.py b/optimum/intel/openvino/trainer.py
index 5133f40bfb..7cf7b017a1 100644
--- a/optimum/intel/openvino/trainer.py
+++ b/optimum/intel/openvino/trainer.py
@@ -19,6 +19,7 @@
 import shutil
 import sys
 import time
+from collections import defaultdict
 from itertools import chain
 from pathlib import Path
 from typing import Callable, Dict, List, Optional, Tuple, Type, Union
@@ -491,6 +492,7 @@ def _inner_training_loop(
 # tr_loss is a tensor to avoid synchronization of TPUs through .item()
 tr_loss = torch.tensor(0.0).to(args.device)
+ self.compression_metrics = defaultdict(lambda: torch.tensor(0.0).to(args.device))
 # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
 self._total_loss_scalar = 0.0
 self._globalstep_last_logged = self.state.global_step

From 339605f94250325e7b69d1318a380f370ced1cd4 Mon Sep 17 00:00:00 2001
From: Ella Charlaix
Date: Mon, 6 Nov 2023 12:16:03 +0100
Subject: [PATCH 23/29] fix conflicts

---
 setup.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/setup.py b/setup.py
index 49fcef2408..ca47ac1ad7 100644
--- a/setup.py
+++ b/setup.py
@@ -12,7 +12,6 @@
 assert False, "Error: Could not open '%s' due %s\n" % (filepath, error)
 INSTALL_REQUIRE = [
-<<<<<<< HEAD
 "optimum @ git+https://github.com/huggingface/optimum.git",
 "transformers>=4.20.0",
 "datasets>=1.4.0",

From f188c0f2b14960bd0942f581c5df0836ec5c90e3 Mon Sep 17 00:00:00 2001
From: Ella Charlaix
Date: Mon, 6 Nov 2023 12:16:33 +0100
Subject: [PATCH 24/29] format

---
 optimum/intel/neural_compressor/trainer.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/optimum/intel/neural_compressor/trainer.py b/optimum/intel/neural_compressor/trainer.py
index 56a97dbdcc..c648f10432 100644
--- a/optimum/intel/neural_compressor/trainer.py
+++ b/optimum/intel/neural_compressor/trainer.py
@@ -44,7 +44,6 @@
 from transformers.data.data_collator import DataCollator
 from transformers.debug_utils import DebugOption, DebugUnderflowOverflow
-
 # Integrations must be imported before ML frameworks:
 from transformers.integrations import deepspeed_init, deepspeed_load_checkpoint, hp_params, is_deepspeed_available
 from transformers.modeling_utils import PreTrainedModel, get_parameter_dtype, unwrap_model

From 32457320ead060a65bbd55e1bd53f38ffe9c6e19 Mon Sep 17 00:00:00 2001
From: Ella Charlaix
Date: Mon, 6 Nov 2023 13:09:35 +0100
Subject: [PATCH 25/29] format

---
 optimum/intel/neural_compressor/trainer.py | 3 ---
 optimum/intel/openvino/trainer.py | 3 ---
 2 files changed, 6 deletions(-)

diff --git a/optimum/intel/neural_compressor/trainer.py b/optimum/intel/neural_compressor/trainer.py
index c648f10432..918a2e4885 100644
--- a/optimum/intel/neural_compressor/trainer.py
+++ b/optimum/intel/neural_compressor/trainer.py
@@ -43,9 +43,6 @@
 from transformers import Trainer
 from transformers.data.data_collator import DataCollator
 from transformers.debug_utils import DebugOption, DebugUnderflowOverflow
-
-# Integrations must be imported before ML frameworks:
-from transformers.integrations import deepspeed_init, deepspeed_load_checkpoint, hp_params, is_deepspeed_available
 from transformers.modeling_utils import PreTrainedModel, get_parameter_dtype, unwrap_model
 from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
 from transformers.pytorch_utils import is_torch_less_than_1_11
diff --git a/optimum/intel/openvino/trainer.py b/optimum/intel/openvino/trainer.py
index 7bbfc4599e..dfc659882c 100644
--- a/optimum/intel/openvino/trainer.py
+++ b/optimum/intel/openvino/trainer.py
@@ -62,9 +62,6 @@
 from transformers import Trainer
 from transformers.data.data_collator import DataCollator
 from transformers.debug_utils import DebugOption, DebugUnderflowOverflow
-
-# Integrations must be imported before ML frameworks:
-from transformers.integrations import deepspeed_init, deepspeed_load_checkpoint, hp_params, is_deepspeed_available
 from transformers.modeling_utils import PreTrainedModel, unwrap_model
 from transformers.pytorch_utils import is_torch_less_than_1_11
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase

From 398aaa9dd2eca0c685eabd731bbd1749e47b9ce9 Mon Sep 17 00:00:00 2001
From: Ella Charlaix
Date: Mon, 6 Nov 2023 13:55:31 +0100
Subject: [PATCH 26/29] fix transformers version

---
 setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index ca47ac1ad7..ef8203e687 100644
--- a/setup.py
+++ b/setup.py
@@ -43,7 +43,7 @@
 "onnxruntime<1.15.0",
 "transformers>=4.33.0",
 ],
- "openvino": ["openvino>=2023.1.0", "onnx", "onnxruntime", "transformers>=4.33.0"],
+ "openvino": ["openvino>=2023.1.0", "onnx", "onnxruntime", "transformers>=4.33.0,<4.35.0"],
 "nncf": ["nncf>=2.6.0"],
 "ipex": ["transformers<4.32.0", "intel-extension-for-pytorch", "onnx"],
 "diffusers": ["diffusers"],

From 892374d331c8e3120978bca46b51061335eff9e4 Mon Sep 17 00:00:00 2001
From: Ella Charlaix
Date: Mon, 6 Nov 2023 18:03:58 +0100
Subject: [PATCH 27/29] fix version

---
 setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index ef8203e687..2f3a17558f 100644
--- a/setup.py
+++ b/setup.py
@@ -12,7 +12,7 @@
 assert False, "Error: Could not open '%s' due %s\n" % (filepath, error)
 INSTALL_REQUIRE = [
- "optimum @ git+https://github.com/huggingface/optimum.git",
+ "optimum>=1.14.0",
 "transformers>=4.20.0",
 "datasets>=1.4.0",
 "sentencepiece",

From 53110771c4f361300b4b0df77de9a5363fe263e0 Mon Sep 17 00:00:00 2001
From: Ella Charlaix
Date: Mon, 6 Nov 2023 18:20:58 +0100
Subject: [PATCH 28/29] version

---
 setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index 2f3a17558f..7949f6d11d 100644
--- a/setup.py
+++ b/setup.py
@@ -43,7 +43,7 @@
 "onnxruntime<1.15.0",
 "transformers>=4.33.0",
 ],
- "openvino": ["openvino>=2023.1.0", "onnx", "onnxruntime", "transformers>=4.33.0,<4.35.0"],
+ "openvino": ["openvino>=2023.1.0", "onnx", "onnxruntime", "transformers>=4.33.0"],
 "nncf": ["nncf>=2.6.0"],
 "ipex": ["transformers<4.32.0", "intel-extension-for-pytorch", "onnx"],
 "diffusers": ["diffusers"],

From 14c3108e1a9d82172483b008d202862737f983e0 Mon Sep 17 00:00:00 2001
From: Ella Charlaix
Date: Mon, 6 Nov 2023 18:21:33 +0100
Subject: [PATCH 29/29] remove constraint ipex

---
 setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index 0c1feace30..96bd2c3063 100644
--- a/setup.py
+++ b/setup.py
@@ -45,7 +45,7 @@
 ],
 "openvino": ["openvino>=2023.1.0", "onnx", "onnxruntime", "transformers>=4.33.0"],
 "nncf": ["nncf>=2.6.0"],
- "ipex": ["transformers<4.32.0", "intel-extension-for-pytorch", "onnx"],
+ "ipex": ["intel-extension-for-pytorch", "onnx"],
 "diffusers": ["diffusers"],
 "quality": QUALITY_REQUIRE,
 "tests": 
TESTS_REQUIRE,
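
A minimal sketch of the Accelerate-based stepping pattern that the training-loop changes earlier in this series rely on: clip gradients through the Accelerator, always call optimizer.step(), and advance the LR scheduler only when Accelerate reports that the step was not skipped. This is illustrative code, not part of the patch; the model, data, and hyperparameters are made up, and a recent accelerate release is assumed.

import torch
from accelerate import Accelerator
from torch import nn

accelerator = Accelerator(gradient_accumulation_steps=4)

# Toy model, optimizer, scheduler, and data -- placeholders for illustration only.
model = nn.Linear(16, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
dataset = [(torch.randn(16), torch.tensor(0)) for _ in range(32)]
dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)

model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)

for inputs, labels in dataloader:
    with accelerator.accumulate(model):
        loss = nn.functional.cross_entropy(model(inputs), labels)
        accelerator.backward(loss)
        # Clip only on steps where gradients are actually synchronized.
        if accelerator.sync_gradients:
            accelerator.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        # Advance the scheduler only when the optimizer step was not skipped
        # (e.g. by the mixed-precision gradient scaler), mirroring the updated loop.
        if not accelerator.optimizer_step_was_skipped:
            scheduler.step()
        optimizer.zero_grad()

Letting Accelerate report skipped steps replaces the older pattern of comparing the GradScaler's scale before and after the step to decide whether the scheduler should run.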