Skip to content

Commit

Permalink
fix INC trainer
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix committed Nov 3, 2023
1 parent 2d7580a commit db275b3
Showing 1 changed file with 37 additions and 22 deletions.
59 changes: 37 additions & 22 deletions optimum/intel/neural_compressor/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@
from transformers import Trainer
from transformers.data.data_collator import DataCollator
from transformers.debug_utils import DebugOption, DebugUnderflowOverflow
from transformers.deepspeed import deepspeed_init
from transformers.file_utils import WEIGHTS_NAME

# Integrations must be imported before ML frameworks:
from transformers.integrations import hp_params
from transformers.integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint
from transformers.modeling_utils import PreTrainedModel, get_parameter_dtype, unwrap_model
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from transformers.pytorch_utils import is_torch_less_than_1_11
Expand All @@ -53,7 +53,6 @@
from transformers.trainer_utils import (
EvalPrediction,
HPSearchBackend,
ShardedDDPOption,
TrainOutput,
has_length,
speed_metrics,
Expand Down Expand Up @@ -186,7 +185,7 @@ def _inner_training_loop(
# number of training epochs: num_train_epochs
# number of training steps per epoch: num_update_steps_per_epoch
# total number of training steps to execute: max_steps
total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.world_size
total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps * args.world_size

len_dataloader = None
if has_length(train_dataloader):
Expand Down Expand Up @@ -230,35 +229,51 @@ def _inner_training_loop(
else:
debug_overflow = DebugUnderflowOverflow(self.model) # noqa

delay_optimizer_creation = (
self.sharded_ddp is not None
and self.sharded_ddp != ShardedDDPOption.SIMPLE
or is_sagemaker_mp_enabled()
or self.fsdp is not None
)
if args.deepspeed:
deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(
self, num_training_steps=max_steps, resume_from_checkpoint=resume_from_checkpoint
)
self.model = deepspeed_engine.module
self.model_wrapped = deepspeed_engine
self.deepspeed = deepspeed_engine
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
elif not delay_optimizer_creation:
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.fsdp is not None or self.is_fsdp_enabled

if self.is_deepspeed_enabled:
self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps)

if not delay_optimizer_creation:
self.create_optimizer_and_scheduler(num_training_steps=max_steps)

self.state = TrainerState()
self.state.is_hyper_param_search = trial is not None

# Compute absolute values for logging, eval, and save if given as ratio
if args.logging_steps is not None:
if args.logging_steps < 1:
self.state.logging_steps = math.ceil(max_steps * args.logging_steps)
else:
self.state.logging_steps = args.logging_steps
if args.eval_steps is not None:
if args.eval_steps < 1:
self.state.eval_steps = math.ceil(max_steps * args.eval_steps)
else:
self.state.eval_steps = args.eval_steps
if args.save_steps is not None:
if args.save_steps < 1:
self.state.save_steps = math.ceil(max_steps * args.save_steps)
else:
self.state.save_steps = args.save_steps

# Activate gradient checkpointing if needed
if args.gradient_checkpointing:
self.model.gradient_checkpointing_enable()
if args.gradient_checkpointing_kwargs is None:
gradient_checkpointing_kwargs = {}
else:
gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs

self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)

model = self._wrap_model(self.model_wrapped)

if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None:
self._load_from_checkpoint(resume_from_checkpoint, model)
# ckpt loading
if resume_from_checkpoint is not None:
if self.is_deepspeed_enabled:
deepspeed_load_checkpoint(self.model_wrapped, resume_from_checkpoint)
elif is_sagemaker_mp_enabled() or self.is_fsdp_enabled:
self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped)

# for the rest of this function `model` is the outside model, whether it was wrapped or not
if model is not self.model:
Expand Down

0 comments on commit db275b3

Please sign in to comment.