class EarlyStopping:
    """
    Track the evaluation loss across successive checks and flag when it has stopped improving.

    The first non-NaN loss seen becomes the baseline. A later loss counts as an improvement only
    when it is lower than the best loss so far by more than `min_delta`; every other check
    (including a NaN loss) increments `counter`. Once `counter` reaches `patience`, `early_stop`
    is set to `True` and is never cleared again.

    Args:
        patience (:obj:`int`, `optional`, defaults to 3):
            Number of consecutive checks without improvement after which `early_stop` is set.
            Must be at least 1.
        min_delta (:obj:`float`, `optional`, defaults to 0.001):
            Minimum decrease of the monitored loss that qualifies as an improvement, i.e. a
            decrease of less than or equal to `min_delta` counts as no improvement. Must be
            non-negative.

    Raises:
        ValueError: If `patience` is smaller than 1 or `min_delta` is negative.
    """

    def __init__(self, patience: int = 3, min_delta: float = 0.001):
        if patience < 1:
            raise ValueError(f"`patience` must be at least 1, got {patience}.")
        if min_delta < 0:
            # A negative threshold would let a *worse* loss replace the best one.
            raise ValueError(f"`min_delta` must be non-negative, got {min_delta}.")

        self.patience: int = patience
        self.min_delta: float = min_delta
        # Number of consecutive checks without improvement.
        self.counter: int = 0
        # Lowest loss observed so far; `None` until a non-NaN loss has been seen.
        self.best_loss: Optional[float] = None
        self.early_stop: bool = False

    def __call__(self, val_loss: float, epoch: int) -> bool:
        """
        Record the loss of one evaluation run and update the stopping state.

        Args:
            val_loss (:obj:`float`):
                Loss measured by the evaluation that has just finished.
            epoch (:obj:`int`):
                Current epoch number. Only used in the log messages.

        Returns:
            :obj:`bool`: The current value of `early_stop`.
        """
        # Imported locally so the class does not rely on a module-level `math` import.
        import math

        # NaN compares False against everything: it must neither become the baseline (no later
        # loss could ever beat it) nor be mistaken for an improvement.
        improved = not math.isnan(val_loss) and (
            self.best_loss is None or val_loss < self.best_loss - self.min_delta
        )

        if improved:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            logger.info(f"Increased early stopping counter at epoch {epoch}: {self.counter}/{self.patience}.")

            if self.counter >= self.patience:
                logger.info(f"Early stopping at epoch {epoch}")
                self.early_stop = True

        return self.early_stop
Set-up basic logging @@ -999,13 +1038,12 @@ def set_trainable_parameters(module, requires_grad=False): if training_args.freeze_encoder: set_trainable_parameters(student_model.model.encoder, requires_grad=False) student_model.model.encoder.gradient_checkpointing = False - + if training_args.freeze_decoder: set_trainable_parameters(student_model.model.decoder, requires_grad=False) student_model.model.decoder.gradient_checkpointing = False # un-freeze LM head parameters (and consequently word embeddings), frozen when frozing decoder since tied word embedding and LM head - set_trainable_parameters(student_model.proj_out, requires_grad=True) - + set_trainable_parameters(student_model.proj_out, requires_grad=True) if training_args.freeze_embed_positions: # set_trainable_parameters(student_model.model.decoder.embed_tokens, requires_grad=False) @@ -1014,7 +1052,7 @@ def set_trainable_parameters(module, requires_grad=False): logger.info( "Disabling gradient checkpointing in the decoder since it's incompatible with `freeze_embed_positions`." ) - + logger.info( f"Number of trainable parameters: {sum(p.numel() for p in student_model.parameters() if p.requires_grad):.3e}" ) @@ -1349,12 +1387,12 @@ def compute_metrics(preds, labels): eval_steps = training_args.eval_steps # 13. 
def push_model_to_hub(
    training_args: DistillationTrainingArguments,
    repo_name: str,
    cur_step: int,
    commit_message: Optional[str] = None,
) -> None:
    """
    Upload the contents of `training_args.output_dir` to the model repository `repo_name`.

    Args:
        training_args (:obj:`DistillationTrainingArguments`):
            Training arguments; only `output_dir` is read.
        repo_name (:obj:`str`):
            Identifier of the target repository, passed to `upload_folder` as `repo_id`.
        cur_step (:obj:`int`):
            Current optimisation step, used in the default commit message.
        commit_message (:obj:`str`, `optional`):
            Commit message for the upload. Defaults to
            `"Saving final weights of step {cur_step}"`, so callers that push an intermediate
            train state can pass a message that says so.
    """
    if commit_message is None:
        commit_message = f"Saving final weights of step {cur_step}"

    upload_folder(
        folder_path=training_args.output_dir,
        repo_id=repo_name,
        repo_type="model",
        commit_message=commit_message,
    )


def unwrap_and_save(
    training_args: DistillationTrainingArguments,
    accelerator: Accelerator,
    student_model: WhisperForConditionalGeneration,
) -> None:
    """
    Unwrap the student model with `accelerator.unwrap_model` and save it to `output_dir`.

    Args:
        training_args (:obj:`DistillationTrainingArguments`):
            Training arguments; only `output_dir` is read.
        accelerator (:obj:`Accelerator`):
            Accelerator whose `unwrap_model` is applied to `student_model`.
        student_model (:obj:`WhisperForConditionalGeneration`):
            The student model to unwrap and save.
    """
    unwrapped_model = accelerator.unwrap_model(student_model)
    unwrapped_model.save_pretrained(training_args.output_dir)
f"checkpoint-{cur_step}-epoch-{epoch}") + if early_stopping.counter == 0: + best_model_tag = "-best" + + intermediate_dir = os.path.join( + training_args.output_dir, f"checkpoint-{cur_step}-epoch-{epoch}{best_model_tag}" + ) accelerator.save_state(output_dir=intermediate_dir) accelerator.wait_for_everyone() if accelerator.is_main_process: rotate_checkpoints(training_args.save_total_limit, output_dir=training_args.output_dir) if training_args.push_to_hub: - upload_folder( - folder_path=training_args.output_dir, - repo_id=repo_name, - repo_type="model", - commit_message=f"Saving train state of step {cur_step}", - ) + push_model_to_hub(training_args, repo_name, cur_step) if training_args.do_eval and (cur_step % eval_steps == 0 or cur_step == total_train_steps): train_time += time.time() - train_start @@ -1709,20 +1770,24 @@ def generate_step(batch): # flush the train metrics train_start = time.time() - # break condition - if cur_step == total_train_steps: + # Check early stopping condition + early_stopping(float(eval_metrics["loss"]), epoch) + + if early_stopping.early_stop: + if training_args.push_to_hub: + push_model_to_hub(training_args, repo_name, cur_step) + + unwrap_and_save(training_args, accelerator, student_model) - # un-wrap student model for save - student_model = accelerator.unwrap_model(student_model) - student_model.save_pretrained(training_args.output_dir) + continue_training = False + break + # break condition + if cur_step == total_train_steps: if training_args.push_to_hub: - upload_folder( - folder_path=training_args.output_dir, - repo_id=repo_name, - repo_type="model", - commit_message=f"Saving final weights of step {cur_step}", - ) + push_model_to_hub(training_args, repo_name, cur_step) + + unwrap_and_save(training_args, accelerator, student_model) continue_training = False break