diff --git a/intel_extension_for_transformers/transformers/trainer.py b/intel_extension_for_transformers/transformers/trainer.py
index 387a9466389..e3899e2a41f 100644
--- a/intel_extension_for_transformers/transformers/trainer.py
+++ b/intel_extension_for_transformers/transformers/trainer.py
@@ -387,7 +387,7 @@ def distill(
             assert False, "Please provide teacher model for DistillationConfig."
         self._eval_func = self.builtin_eval_func if eval_func is None else eval_func
         self._train_func = self.builtin_train_func if train_func is None else train_func
-        
+
         compression_manager = prepare_compression(self.model, distillation_config)
         self.compression_manager = compression_manager
         self.compression_manager.callbacks.on_train_begin()
diff --git a/tests/Nightly/test_distillation.py b/tests/Nightly/test_distillation.py
index cce4410f09d..8f818f38e51 100644
--- a/tests/Nightly/test_distillation.py
+++ b/tests/Nightly/test_distillation.py
@@ -28,7 +28,7 @@
 from neural_compressor.config import (
     DistillationConfig,
     KnowledgeDistillationLossConfig,
-) 
+)
 from intel_extension_for_transformers.transformers.trainer import NLPTrainer
 from transformers import (
     AutoModelForSequenceClassification,
@@ -76,7 +76,7 @@ def compute_metrics(p):
             preds = np.argmax(preds, axis=1)
             return metric.compute(predictions=preds, references=p.label_ids)
         origin_weight = copy.deepcopy(self.model.classifier.weight)
-        
+
         self.trainer = NLPTrainer(
             model=copy.deepcopy(self.model),
             train_dataset=self.dataset,
diff --git a/tests/Nightly/test_orchestrate_optimization.py b/tests/Nightly/test_orchestrate_optimization.py
index 7137ccc844d..d65ece8099c 100644
--- a/tests/Nightly/test_orchestrate_optimization.py
+++ b/tests/Nightly/test_orchestrate_optimization.py
@@ -85,8 +85,8 @@ def compute_metrics(p):
             name="eval_accuracy", is_relative=True, criterion=0.5
         )
         self.trainer.metrics = tune_metric
-        pruning_conf = WeightPruningConfig([{"start_step": 0, "end_step": 2}], 
-                                           target_sparsity=0.64, 
+        pruning_conf = WeightPruningConfig([{"start_step": 0, "end_step": 2}],
+                                           target_sparsity=0.64,
                                            pruning_scope="local")
         distillation_criterion = KnowledgeDistillationLossConfig(loss_types=["CE", "KL"])
         distillation_conf = DistillationConfig(teacher_model=self.teacher_model, criterion=distillation_criterion)
diff --git a/workflows/compression_aware_training/src/itrex_opt.py b/workflows/compression_aware_training/src/itrex_opt.py
index c1ba546cdc4..fcfd5f7eab7 100755
--- a/workflows/compression_aware_training/src/itrex_opt.py
+++ b/workflows/compression_aware_training/src/itrex_opt.py
@@ -778,7 +778,7 @@ def _do_quantization_aware_training(self):
         quantization_config = QuantizationAwareTrainingConfig(
             tuning_criterion=tuning_criterion,
             accuracy_criterion=accuracy_criterion
-        ) 
+        )
         early_stopping_patience = 2
         early_stopping_threshold = 0.001 # optional
         self.trainer.add_callback(transformers.EarlyStoppingCallback(early_stopping_patience, \
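
Note (not part of the patch): the trainer.py hunk sits inside NLPTrainer.distill(), which wraps the student model with neural_compressor's prepare_compression and then drives its callbacks around the training loop. Below is a minimal sketch of that flow, assuming neural_compressor 2.x; the student/teacher checkpoints are hypothetical placeholders that only illustrate the call pattern already visible in the diff and the nightly tests.

# Minimal sketch, not part of this patch. Assumes neural_compressor 2.x; the
# checkpoint names are hypothetical placeholders.
from neural_compressor.config import (
    DistillationConfig,
    KnowledgeDistillationLossConfig,
)
from neural_compressor.training import prepare_compression
from transformers import AutoModelForSequenceClassification

student = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")
teacher = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")

# Same config objects the nightly tests construct.
criterion = KnowledgeDistillationLossConfig(loss_types=["CE", "KL"])
distillation_config = DistillationConfig(teacher_model=teacher, criterion=criterion)

# Same call pattern as the patched distill(): wrap the model, then bracket the
# training loop with the compression callbacks.
compression_manager = prepare_compression(student, distillation_config)
compression_manager.callbacks.on_train_begin()
model = compression_manager.model
# ... run the training loop on `model`, invoking the per-step callbacks ...
compression_manager.callbacks.on_train_end()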