This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

Commit

migrate distillation pruning orchestration
Signed-off-by: changwangss <[email protected]>
changwangss committed Jun 17, 2024
1 parent 75b13d1 commit 773bd1b
Showing 48 changed files with 756 additions and 7,241 deletions.
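
The diffs below migrate the examples from the deprecated PrunerConfig / PruningConfig / DistillationConfig / QuantizationConfig classes in intel_extension_for_transformers to the equivalent classes in neural_compressor.config, and drive pruning, distillation, and quantization-aware training through a single trainer.orchestrate_optimizations() call. The sketch below is a minimal reconstruction of the migrated flow assembled from the changed lines, not part of the commit itself; model, teacher_model, training_args, and the datasets are placeholders assumed to be prepared as in the existing examples.

# Minimal sketch of the migrated orchestration flow (assumes `model`, `teacher_model`,
# `training_args`, `train_dataset`, and `eval_dataset` are already prepared).
from intel_extension_for_transformers.transformers.trainer import NLPTrainer
from neural_compressor.config import (
    DistillationConfig,
    KnowledgeDistillationLossConfig,
    QuantizationAwareTrainingConfig,
    WeightPruningConfig,
)

trainer = NLPTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

# Pruning: one pruner active between steps 0 and 2, pattern-lock type, local scope.
pruning_conf = WeightPruningConfig(
    [{"start_step": 0, "end_step": 2}],
    target_sparsity=0.64,
    pruning_scope="local",
    pruning_type="pattern_lock",
)

# Distillation: cross-entropy on the labels plus KL divergence against the teacher logits.
distillation_criterion = KnowledgeDistillationLossConfig(loss_types=["CE", "KL"])
distillation_conf = DistillationConfig(
    teacher_model=teacher_model,
    criterion=distillation_criterion,
)

# Quantization-aware training with default settings.
quantization_conf = QuantizationAwareTrainingConfig()

# A single call runs pruning, distillation, and QAT together and returns the optimized model.
model = trainer.orchestrate_optimizations(
    config_list=[pruning_conf, distillation_conf, quantization_conf],
    teacher_model=teacher_model,
)
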
[first changed file]
@@ -74,13 +74,15 @@
"import transformers\n",
"from intel_extension_for_transformers.transformers import (\n",
" metrics,\n",
" PrunerConfig,\n",
" PruningConfig,\n",
" DistillationConfig,\n",
" QuantizationConfig,\n",
" OptimizedModel,\n",
" objectives\n",
")\n",
"from neural_compressor.config import (\n",
" WeightPruningConfig,\n",
" DistillationConfig,\n",
" KnowledgeDistillationLossConfig,\n",
" QuantizationAwareTrainingConfig,\n",
")\n",
"from torch.utils.data import DataLoader\n",
"from tqdm import tqdm\n",
"from trainer_qa import QuestionAnsweringTrainer\n",
@@ -214,7 +216,7 @@
" metadata={\"help\": \"Whether or not to apply prune.\"},\n",
" )\n",
" pruning_approach: Optional[str] = field(\n",
" default=\"BasicMagnitude\",\n",
" default=\"magnitude\",\n",
" metadata={\"help\": \"Pruning approach. Supported approach is basic_magnite.\"},\n",
" )\n",
" target_sparsity_ratio: Optional[float] = field(\n",
@@ -234,9 +236,9 @@
" metadata={\"help\": \"Whether or not to apply quantization.\"},\n",
" )\n",
" quantization_approach: Optional[str] = field(\n",
" default=\"PostTrainingStatic\",\n",
" metadata={\"help\": \"Quantization approach. Supported approach are PostTrainingStatic, \"\n",
" \"PostTrainingDynamic and QuantizationAwareTraining.\"},\n",
" default=\"static\",\n",
" metadata={\"help\": \"Quantization approach. Supported approach are static, \"\n",
" \"dynamic and qat.\"},\n",
" )\n",
" metric_name: Optional[str] = field(\n",
" default=None,\n",
@@ -300,7 +302,7 @@
")\n",
"optim_args = OptimizationArguments(\n",
" tune=True,\n",
" quantization_approach=\"PostTrainingStatic\"\n",
" quantization_approach=\"static\"\n",
")\n",
"log_level = training_args.get_process_log_level()"
]
@@ -730,9 +732,7 @@
"logger.info(\"***** Number of student model parameters: {:.2f}M *****\".format(\\\n",
" para_counter(model)/10**6))\n",
"\n",
"# Trace model\n",
"from neural_compressor.adaptor.torch_utils.symbolic_trace import symbolic_trace\n",
"model = symbolic_trace(model, optim_args.quantization_approach==\"QuantizationAwareTraining\")"
"# Trace model\n"
]
},
{
@@ -779,21 +779,18 @@
" tune_metric = metrics.Metric(\n",
" name=metric_name, is_relative=optim_args.is_relative, criterion=optim_args.perf_tol\n",
" )\n",
" prune_type = 'PatternLock' \\\n",
" prune_type = 'pattern_lock' \\\n",
" if optim_args.pruning_approach else optim_args.pruning_approach\n",
" target_sparsity_ratio = optim_args.target_sparsity_ratio \\\n",
" if optim_args.target_sparsity_ratio else None\n",
" pruner_config = PrunerConfig(prune_type=prune_type, target_sparsity_ratio=target_sparsity_ratio)\n",
" pruning_conf = PruningConfig(framework=\"pytorch_fx\",pruner_config=[pruner_config], metrics=tune_metric)\n",
" distillation_conf = DistillationConfig(framework=\"pytorch_fx\", metrics=tune_metric)\n",
"\n",
" objective = objectives.performance\n",
" quantization_conf = QuantizationConfig(\n",
" approach=optim_args.quantization_approach,\n",
" max_trials=600,\n",
" metrics=[tune_metric],\n",
" objectives=[objective]\n",
" )\n",
" trainer.metrics = tune_metric\n",
" pruning_conf = WeightPruningConfig([{\"start_step\": 0, \"end_step\": 2}],\n",
" target_sparsity=target_sparsity_ratio,\n",
" pruning_scope=\"local\",\n",
" pruning_type=prune_type)\n",
" distillation_criterion = KnowledgeDistillationLossConfig(loss_types=[\"CE\", \"KL\"])\n",
" distillation_conf = DistillationConfig(teacher_model=teacher_model, criterion=distillation_criterion)\n",
" quantization_conf = QuantizationAwareTrainingConfig()\n",
" conf_list = [pruning_conf, distillation_conf, quantization_conf]\n",
" model = trainer.orchestrate_optimizations(config_list=conf_list, teacher_model=teacher_model)"
]
[next changed file]
@@ -78,6 +78,12 @@
" DataCollatorWithPadding,\n",
" EvalPrediction,\n",
")\n",
"from neural_compressor.config import (\n",
" WeightPruningConfig,\n",
" DistillationConfig,\n",
" KnowledgeDistillationLossConfig,\n",
" QuantizationAwareTrainingConfig,\n",
")\n",
"from transformers.utils import check_min_version\n",
"from transformers.utils.versions import require_version\n",
"from typing import Optional\n",
@@ -430,18 +436,14 @@
" name=metric_name, is_relative=True, criterion=0.01\n",
")\n",
"\n",
"target_sparsity_ratio = None\n",
"pruner_config = PrunerConfig(prune_type='PatternLock', target_sparsity_ratio=None)\n",
"pruning_conf = PruningConfig(framework=\"pytorch_fx\",pruner_config=[pruner_config], metrics=tune_metric)\n",
"distillation_conf = DistillationConfig(framework=\"pytorch_fx\", metrics=tune_metric)\n",
"\n",
"objective = objectives.performance\n",
"quantization_conf = QuantizationConfig(\n",
" approach=\"QuantizationAwareTraining\",\n",
" max_trials=600,\n",
" metrics=[tune_metric],\n",
" objectives=[objective]\n",
")\n",
"trainer.metrics = tune_metric\n",
"pruning_conf = WeightPruningConfig([{\"start_step\": 0, \"end_step\": 2}],\n",
" target_sparsity=0.64,\n",
" pruning_scope=\"local\",\n",
" pruning_type=\"pattern_lock\")\n",
"distillation_criterion = KnowledgeDistillationLossConfig(loss_types=[\"CE\", \"KL\"])\n",
"distillation_conf = DistillationConfig(teacher_model=teacher_model, criterion=distillation_criterion)\n",
"quantization_conf = QuantizationAwareTrainingConfig()\n",
"conf_list = [pruning_conf, distillation_conf, quantization_conf]\n",
"model = trainer.orchestrate_optimizations(config_list=conf_list, teacher_model=teacher_model)"
]
[next changed file]
@@ -70,12 +70,14 @@
"from datasets import load_dataset, load_metric\n",
"from intel_extension_for_transformers.transformers import (\n",
" metrics,\n",
" PrunerConfig,\n",
" PruningConfig,\n",
" DistillationConfig,\n",
" QuantizationConfig,\n",
" objectives\n",
")\n",
"from neural_compressor.config import (\n",
" WeightPruningConfig,\n",
" DistillationConfig,\n",
" KnowledgeDistillationLossConfig,\n",
" QuantizationAwareTrainingConfig,\n",
")\n",
"from intel_extension_for_transformers.transformers.trainer import NLPTrainer\n",
"from transformers import (\n",
" AutoConfig,\n",
@@ -343,18 +345,14 @@
" name=metric_name, is_relative=True, criterion=0.01\n",
")\n",
"\n",
"target_sparsity_ratio = None\n",
"pruner_config = PrunerConfig(prune_type='PatternLock', target_sparsity_ratio=None)\n",
"pruning_conf = PruningConfig(framework=\"pytorch_fx\",pruner_config=[pruner_config], metrics=tune_metric)\n",
"distillation_conf = DistillationConfig(framework=\"pytorch_fx\", metrics=tune_metric)\n",
"\n",
"objective = objectives.performance\n",
"quantization_conf = QuantizationConfig(\n",
" approach=\"QuantizationAwareTraining\",\n",
" max_trials=600,\n",
" metrics=[tune_metric],\n",
" objectives=[objective]\n",
")\n",
"trainer.metrics = tune_metric\n",
"pruning_conf = WeightPruningConfig([{\"start_step\": 0, \"end_step\": 2}],\n",
" target_sparsity=0.64,\n",
" pruning_scope=\"local\",\n",
" pruning_type=\"pattern_lock\")\n",
"distillation_criterion = KnowledgeDistillationLossConfig(loss_types=[\"CE\", \"KL\"])\n",
"distillation_conf = DistillationConfig(teacher_model=teacher_model, criterion=distillation_criterion)\n",
"quantization_conf = QuantizationAwareTrainingConfig()\n",
"conf_list = [pruning_conf, distillation_conf, quantization_conf]\n",
"model = trainer.orchestrate_optimizations(config_list=conf_list, teacher_model=teacher_model)"
]
[next changed file]
@@ -33,13 +33,15 @@
import transformers
from intel_extension_for_transformers.transformers import (
metrics,
PrunerConfig,
PruningConfig,
DistillationConfig,
QuantizationConfig,
OptimizedModel,
objectives
)
from neural_compressor.config import (
WeightPruningConfig,
DistillationConfig,
KnowledgeDistillationLossConfig,
QuantizationAwareTrainingConfig,
)
from torch.utils.data import DataLoader
from tqdm import tqdm
from trainer_qa import QuestionAnsweringTrainer
@@ -225,7 +227,7 @@ class OptimizationArguments:
metadata={"help": "Whether or not to apply prune."},
)
pruning_approach: Optional[str] = field(
default="BasicMagnitude",
default="magnitude",
 metadata={"help": "Pruning approach. Supported approach is magnitude."},
)
target_sparsity_ratio: Optional[float] = field(
@@ -245,9 +247,9 @@ class OptimizationArguments:
metadata={"help": "Whether or not to apply quantization."},
)
quantization_approach: Optional[str] = field(
default="QuantizationAwareTraining",
metadata={"help": "Quantization approach. Supported approach are PostTrainingStatic, "
"PostTrainingDynamic and QuantizationAwareTraining."},
default="qat",
metadata={"help": "Quantization approach. Supported approach are static, "
"dynamic and qat."},
)
metric_name: Optional[str] = field(
default="eval_f1",
@@ -789,7 +791,7 @@ def get_logits(teacher_model, train_dataset, teacher_train_dataset):

# Trace model
from neural_compressor.adaptor.torch_utils.symbolic_trace import symbolic_trace
model = symbolic_trace(model, optim_args.quantization_approach=="QuantizationAwareTraining")
model = symbolic_trace(model, optim_args.quantization_approach=="qat")

# Initialize our Trainer
trainer = QuestionAnsweringTrainer(
@@ -814,23 +816,20 @@ def get_logits(teacher_model, train_dataset, teacher_train_dataset):
tune_metric = metrics.Metric(
name=metric_name, is_relative=optim_args.is_relative, criterion=optim_args.perf_tol
)
prune_type = 'PatternLock' \
prune_type = 'pattern_lock' \
if optim_args.pruning_approach else optim_args.pruning_approach
target_sparsity_ratio = optim_args.target_sparsity_ratio \
if optim_args.target_sparsity_ratio else None
pruner_config = PrunerConfig(prune_type=prune_type, target_sparsity_ratio=target_sparsity_ratio)
pruning_conf = PruningConfig(framework="pytorch_fx",pruner_config=[pruner_config], metrics=tune_metric)
distillation_conf = DistillationConfig(framework="pytorch_fx", metrics=tune_metric)

objective = objectives.performance
quantization_conf = QuantizationConfig(
approach=optim_args.quantization_approach,
max_trials=600,
metrics=[tune_metric],
objectives=[objective]
)
trainer.metrics = tune_metric
pruning_conf = WeightPruningConfig([{"start_step": 0, "end_step": 2}],
target_sparsity=target_sparsity_ratio,
pruning_scope="local",
pruning_type=prune_type)
distillation_criterion = KnowledgeDistillationLossConfig(loss_types=["CE", "KL"])
distillation_conf = DistillationConfig(teacher_model=teacher_model, criterion=distillation_criterion)
quantization_conf = QuantizationAwareTrainingConfig()
conf_list = [pruning_conf, distillation_conf, quantization_conf]
model = trainer.orchestrate_optimizations(config_list=conf_list, teacher_model=teacher_model)
model = trainer.orchestrate_optimizations(config_list=conf_list)

if optim_args.benchmark or optim_args.accuracy_only:
start_time = timeit.default_timer()
@@ -839,7 +838,7 @@ def get_logits(teacher_model, train_dataset, teacher_train_dataset):
max_eval_samples = data_args.max_eval_samples \
if data_args.max_eval_samples is not None else len(eval_dataset)
eval_samples = min(max_eval_samples, len(eval_dataset))
samples = eval_samples - (eval_samples % batch_size) \
samples = eval_samples - (eval_samples % optim_args.batch_size) \
if training_args.dataloader_drop_last else eval_samples
logger.info("metrics keys: {}".format(results.keys()))
bert_task_acc_keys = ['eval_f1', 'eval_accuracy', 'eval_matthews_correlation',
[next changed file]
@@ -550,25 +550,19 @@ def compute_metrics(p: EvalPrediction):
)
trainer.metrics = tune_metric
objective = objectives.performance
tuning_criterion = TuningCriterion(max_trials=600, objective=[objective.name])
accuracy_criterion = AccuracyCriterion(
higher_is_better=True, # optional.
criterion="relative" if optim_args.is_relative else "absolute", # optional. Available values are "relative" and "absolute".
tolerable_loss=optim_args.perf_tol, # optional.
)
if optim_args.quantization_approach != "qat":
tuning_criterion = TuningCriterion(max_trials=600, objective=[objective.name])
accuracy_criterion = AccuracyCriterion(
higher_is_better=True, # optional.
criterion="relative" if optim_args.is_relative else "absolute", # optional. Available values are "relative" and "absolute".
tolerable_loss=optim_args.perf_tol, # optional.
)
quantization_config = PostTrainingQuantConfig(
approach=optim_args.quantization_approach,
tuning_criterion=tuning_criterion,
accuracy_criterion=accuracy_criterion
)
else:
tuning_criterion = TuningCriterion(max_trials=600, objective=["performance"])
accuracy_criterion = AccuracyCriterion(
higher_is_better=True, # optional.
criterion="relative" if optim_args.is_relative else "absolute", # optional. Available values are "relative" and "absolute".
tolerable_loss=optim_args.perf_tol, # optional.
)
quantization_config = QuantizationAwareTrainingConfig(
tuning_criterion=tuning_criterion,
accuracy_criterion=accuracy_criterion
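
The hunk above is cut off by the page before the quantization-aware-training branch closes. A hedged reconstruction of the full selection logic, based only on the lines shown (with the shared AccuracyCriterion hoisted out of both branches for brevity), is sketched below; optim_args and objective are defined earlier in the same script.

# Hedged reconstruction of the truncated branch above; assumes `optim_args` and
# `objective` (objectives.performance) are already defined in the script.
from neural_compressor.config import (
    AccuracyCriterion,
    PostTrainingQuantConfig,
    QuantizationAwareTrainingConfig,
    TuningCriterion,
)

accuracy_criterion = AccuracyCriterion(
    higher_is_better=True,
    criterion="relative" if optim_args.is_relative else "absolute",
    tolerable_loss=optim_args.perf_tol,
)

if optim_args.quantization_approach != "qat":
    # "static" and "dynamic" map to post-training quantization.
    tuning_criterion = TuningCriterion(max_trials=600, objective=[objective.name])
    quantization_config = PostTrainingQuantConfig(
        approach=optim_args.quantization_approach,
        tuning_criterion=tuning_criterion,
        accuracy_criterion=accuracy_criterion,
    )
else:
    # "qat" selects quantization-aware training instead.
    tuning_criterion = TuningCriterion(max_trials=600, objective=["performance"])
    quantization_config = QuantizationAwareTrainingConfig(
        tuning_criterion=tuning_criterion,
        accuracy_criterion=accuracy_criterion,
    )
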
[next changed file]
@@ -19,9 +19,10 @@
from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from huggingface_hub import HfFolder, Repository, whoami
from intel_extension_for_transformers.transformers.config import (
from neural_compressor.config import (
DistillationConfig,
QuantizationConfig,
IntermediateLayersKnowledgeDistillationLossConfig,
QuantizationAwareTrainingConfig,
)
from intel_extension_for_transformers.transformers.utils import metrics, objectives
from intel_extension_for_transformers.transformers.trainer import NLPTrainer
@@ -769,12 +770,7 @@ def train_func(model):
tune_metric = metrics.Metric(name="")
if args.do_quantization:
objective = objectives.performance
quantization_conf = QuantizationConfig(
approach="QuantizationAwareTraining",
max_trials=600,
metrics=[tune_metric],
objectives=[objective]
)
quantization_conf = QuantizationAwareTrainingConfig()
conf_list.append(quantization_conf)

if args.do_distillation:
@@ -828,28 +824,24 @@ def train_func(model):
[['mid_block.resnets.1', ]],
[['conv_out', ]],
]

distillation_conf = DistillationConfig(
framework="pytorch_fx", metrics=tune_metric,
criterion=Criterion(
name="IntermediateLayersLoss",
layer_mappings=layer_mappings,
loss_types=["MSE"] * len(layer_mappings),
loss_weight_ratio=[1.0 / len(layer_mappings)] * len(layer_mappings),
add_origin_loss=True
)
criterion_conf = IntermediateLayersKnowledgeDistillationLossConfig(
layer_mappings=layer_mappings,
loss_types=["MSE"] * len(layer_mappings),
loss_weight_ratio=[1.0 / len(layer_mappings)] * len(layer_mappings),
add_origin_loss=True
)
distillation_conf = DistillationConfig(teacher_model=teacher_model, criterion=criterion_conf)
conf_list.append(distillation_conf)

# Initialize our Trainer
trainer = NLPTrainer(
model=model,
args=TrainingArguments(output_dir=args.output_dir),
)
trainer.metrics = tune_metric

model = trainer.orchestrate_optimizations(
config_list=conf_list,
teacher_model=teacher_model,
eval_func=lambda model:1,
train_func=train_func)

[diffs for the remaining changed files are not shown]
